From 5cd9f04948c6a46bf3af75577a2225bff9ebc9f1 Mon Sep 17 00:00:00 2001 From: ahxxm Date: Fri, 7 Oct 2016 12:30:17 +0800 Subject: [PATCH] Refactor (#615) * make tcprelay.py less nested * import traceback module at the top * make loops in DNSResolver less nested * make manager.py less nested * introduce exception_handle decorator make try/except block more clean * apply exception_handle decorator to tcprelay * quote condition judgement * pep8 fix --- shadowsocks/asyncdns.py | 39 +++--- shadowsocks/eventloop.py | 2 +- shadowsocks/local.py | 44 +++--- shadowsocks/manager.py | 26 ++-- shadowsocks/shell.py | 47 +++++++ shadowsocks/tcprelay.py | 293 +++++++++++++++++++-------------------- 6 files changed, 248 insertions(+), 203 deletions(-) diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 44528d7..91601ea 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -276,15 +276,18 @@ class DNSResolver(object): content = f.readlines() for line in content: line = line.strip() - if line: - if line.startswith(b'nameserver'): - parts = line.split() - if len(parts) >= 2: - server = parts[1] - if common.is_ip(server) == socket.AF_INET: - if type(server) != str: - server = server.decode('utf8') - self._servers.append(server) + if not (line and line.startswith(b'nameserver')): + continue + + parts = line.split() + if len(parts) < 2: + continue + + server = parts[1] + if common.is_ip(server) == socket.AF_INET: + if type(server) != str: + server = server.decode('utf8') + self._servers.append(server) except IOError: pass if not self._servers: @@ -299,13 +302,17 @@ class DNSResolver(object): for line in f.readlines(): line = line.strip() parts = line.split() - if len(parts) >= 2: - ip = parts[0] - if common.is_ip(ip): - for i in range(1, len(parts)): - hostname = parts[i] - if hostname: - self._hosts[hostname] = ip + if len(parts) < 2: + continue + + ip = parts[0] + if not common.is_ip(ip): + continue + + for i in range(1, len(parts)): + hostname = parts[i] + if hostname: + self._hosts[hostname] = ip except IOError: self._hosts['localhost'] = '127.0.0.1' diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py index 78b532c..ce5da37 100644 --- a/shadowsocks/eventloop.py +++ b/shadowsocks/eventloop.py @@ -25,6 +25,7 @@ import os import time import socket import select +import traceback import errno import logging from collections import defaultdict @@ -204,7 +205,6 @@ class EventLoop(object): logging.debug('poll:%s', e) else: logging.error('poll:%s', e) - import traceback traceback.print_exc() continue diff --git a/shadowsocks/local.py b/shadowsocks/local.py index 4255a2e..dfc8032 100755 --- a/shadowsocks/local.py +++ b/shadowsocks/local.py @@ -27,6 +27,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns +@shell.exception_handle(self_=False, exit_code=1) def main(): shell.check_python() @@ -37,36 +38,31 @@ def main(): os.chdir(p) config = shell.get_config(True) - daemon.daemon_exec(config) - try: - logging.info("starting local at %s:%d" % - (config['local_address'], config['local_port'])) + logging.info("starting local at %s:%d" % + (config['local_address'], config['local_port'])) - dns_resolver = asyncdns.DNSResolver() - tcp_server = tcprelay.TCPRelay(config, dns_resolver, True) - udp_server = udprelay.UDPRelay(config, dns_resolver, True) - loop = eventloop.EventLoop() - dns_resolver.add_to_loop(loop) - tcp_server.add_to_loop(loop) - udp_server.add_to_loop(loop) + dns_resolver = asyncdns.DNSResolver() + tcp_server = tcprelay.TCPRelay(config, dns_resolver, True) + udp_server = udprelay.UDPRelay(config, dns_resolver, True) + loop = eventloop.EventLoop() + dns_resolver.add_to_loop(loop) + tcp_server.add_to_loop(loop) + udp_server.add_to_loop(loop) - def handler(signum, _): - logging.warn('received SIGQUIT, doing graceful shutting down..') - tcp_server.close(next_tick=True) - udp_server.close(next_tick=True) - signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), handler) + def handler(signum, _): + logging.warn('received SIGQUIT, doing graceful shutting down..') + tcp_server.close(next_tick=True) + udp_server.close(next_tick=True) + signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), handler) - def int_handler(signum, _): - sys.exit(1) - signal.signal(signal.SIGINT, int_handler) - - daemon.set_user(config.get('user', None)) - loop.run() - except Exception as e: - shell.print_exception(e) + def int_handler(signum, _): sys.exit(1) + signal.signal(signal.SIGINT, int_handler) + + daemon.set_user(config.get('user', None)) + loop.run() if __name__ == '__main__': main() diff --git a/shadowsocks/manager.py b/shadowsocks/manager.py index b42ffa9..bba542e 100644 --- a/shadowsocks/manager.py +++ b/shadowsocks/manager.py @@ -173,18 +173,20 @@ class Manager(object): self._statistics.clear() def _send_control_data(self, data): - if self._control_client_addr: - try: - self._control_socket.sendto(data, self._control_client_addr) - except (socket.error, OSError, IOError) as e: - error_no = eventloop.errno_from_exception(e) - if error_no in (errno.EAGAIN, errno.EINPROGRESS, - errno.EWOULDBLOCK): - return - else: - shell.print_exception(e) - if self._config['verbose']: - traceback.print_exc() + if not self._control_client_addr: + return + + try: + self._control_socket.sendto(data, self._control_client_addr) + except (socket.error, OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + return + else: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() def run(self): self._loop.run() diff --git a/shadowsocks/shell.py b/shadowsocks/shell.py index 6cca837..041086a 100644 --- a/shadowsocks/shell.py +++ b/shadowsocks/shell.py @@ -23,6 +23,10 @@ import json import sys import getopt import logging +import traceback + +from functools import wraps + from shadowsocks.common import to_bytes, to_str, IPNetwork from shadowsocks import encrypt @@ -53,6 +57,49 @@ def print_exception(e): traceback.print_exc() +def exception_handle(self_, err_msg=None, exit_code=None, + destroy=False, conn_err=False): + # self_: if function passes self as first arg + + def process_exception(e, self=None): + print_exception(e) + if err_msg: + logging.error(err_msg) + if exit_code: + sys.exit(1) + + if not self_: + return + + if conn_err: + addr, port = self._client_address[0], self._client_address[1] + logging.error('%s when handling connection from %s:%d' % + (e, addr, port)) + if self._config['verbose']: + traceback.print_exc() + if destroy: + self.destroy() + + def decorator(func): + if self_: + @wraps(func) + def wrapper(self, *args, **kwargs): + try: + func(self, *args, **kwargs) + except Exception as e: + process_exception(e, self) + else: + @wraps(func) + def wrapper(*args, **kwargs): + try: + func(*args, **kwargs) + except Exception as e: + process_exception(e) + + return wrapper + return decorator + + def print_shadowsocks(): version = '' try: diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 810e713..2e4772d 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -190,21 +190,23 @@ class TCPRelayHandler(object): if self._upstream_status != status: self._upstream_status = status dirty = True - if dirty: - if self._local_sock: - event = eventloop.POLL_ERR - if self._downstream_status & WAIT_STATUS_WRITING: - event |= eventloop.POLL_OUT - if self._upstream_status & WAIT_STATUS_READING: - event |= eventloop.POLL_IN - self._loop.modify(self._local_sock, event) - if self._remote_sock: - event = eventloop.POLL_ERR - if self._downstream_status & WAIT_STATUS_READING: - event |= eventloop.POLL_IN - if self._upstream_status & WAIT_STATUS_WRITING: - event |= eventloop.POLL_OUT - self._loop.modify(self._remote_sock, event) + if not dirty: + return + + if self._local_sock: + event = eventloop.POLL_ERR + if self._downstream_status & WAIT_STATUS_WRITING: + event |= eventloop.POLL_OUT + if self._upstream_status & WAIT_STATUS_READING: + event |= eventloop.POLL_IN + self._loop.modify(self._local_sock, event) + if self._remote_sock: + event = eventloop.POLL_ERR + if self._downstream_status & WAIT_STATUS_READING: + event |= eventloop.POLL_IN + if self._upstream_status & WAIT_STATUS_WRITING: + event |= eventloop.POLL_OUT + self._loop.modify(self._remote_sock, event) def _write_to_sock(self, data, sock): # write data to sock @@ -247,19 +249,20 @@ class TCPRelayHandler(object): return True def _handle_stage_connecting(self, data): - if self._is_local: - if self._ota_enable_session: - data = self._ota_chunk_data_gen(data) - data = self._encryptor.encrypt(data) - self._data_to_write_to_remote.append(data) - else: + if not self._is_local: if self._ota_enable_session: self._ota_chunk_data(data, self._data_to_write_to_remote.append) else: self._data_to_write_to_remote.append(data) - if self._is_local and not self._fastopen_connected and \ - self._config['fast_open']: + return + + if self._ota_enable_session: + data = self._ota_chunk_data_gen(data) + data = self._encryptor.encrypt(data) + self._data_to_write_to_remote.append(data) + + if self._config['fast_open'] and not self._fastopen_connected: # for sslocal and fastopen, we basically wait for data and use # sendto to connect try: @@ -293,93 +296,88 @@ class TCPRelayHandler(object): traceback.print_exc() self.destroy() + @shell.exception_handle(self_=True, destroy=True, conn_err=True) def _handle_stage_addr(self, data): - try: - if self._is_local: - cmd = common.ord(data[1]) - if cmd == CMD_UDP_ASSOCIATE: - logging.debug('UDP associate') - if self._local_sock.family == socket.AF_INET6: - header = b'\x05\x00\x00\x04' - else: - header = b'\x05\x00\x00\x01' - addr, port = self._local_sock.getsockname()[:2] - addr_to_send = socket.inet_pton(self._local_sock.family, - addr) - port_to_send = struct.pack('>H', port) - self._write_to_sock(header + addr_to_send + port_to_send, - self._local_sock) - self._stage = STAGE_UDP_ASSOC - # just wait for the client to disconnect - return - elif cmd == CMD_CONNECT: - # just trim VER CMD RSV - data = data[3:] + if self._is_local: + cmd = common.ord(data[1]) + if cmd == CMD_UDP_ASSOCIATE: + logging.debug('UDP associate') + if self._local_sock.family == socket.AF_INET6: + header = b'\x05\x00\x00\x04' else: - logging.error('unknown command %d', cmd) + header = b'\x05\x00\x00\x01' + addr, port = self._local_sock.getsockname()[:2] + addr_to_send = socket.inet_pton(self._local_sock.family, + addr) + port_to_send = struct.pack('>H', port) + self._write_to_sock(header + addr_to_send + port_to_send, + self._local_sock) + self._stage = STAGE_UDP_ASSOC + # just wait for the client to disconnect + return + elif cmd == CMD_CONNECT: + # just trim VER CMD RSV + data = data[3:] + else: + logging.error('unknown command %d', cmd) + self.destroy() + return + header_result = parse_header(data) + if header_result is None: + raise Exception('can not parse header') + addrtype, remote_addr, remote_port, header_length = header_result + logging.info('connecting %s:%d from %s:%d' % + (common.to_str(remote_addr), remote_port, + self._client_address[0], self._client_address[1])) + if self._is_local is False: + # spec https://shadowsocks.org/en/spec/one-time-auth.html + self._ota_enable_session = addrtype & ADDRTYPE_AUTH + if self._ota_enable and not self._ota_enable_session: + logging.warn('client one time auth is required') + return + if self._ota_enable_session: + if len(data) < header_length + ONETIMEAUTH_BYTES: + logging.warn('one time auth header is too short') + return None + offset = header_length + ONETIMEAUTH_BYTES + _hash = data[header_length: offset] + _data = data[:header_length] + key = self._encryptor.decipher_iv + self._encryptor.key + if onetimeauth_verify(_hash, _data, key) is False: + logging.warn('one time auth fail') self.destroy() return - header_result = parse_header(data) - if header_result is None: - raise Exception('can not parse header') - addrtype, remote_addr, remote_port, header_length = header_result - logging.info('connecting %s:%d from %s:%d' % - (common.to_str(remote_addr), remote_port, - self._client_address[0], self._client_address[1])) - if self._is_local is False: - # spec https://shadowsocks.org/en/spec/one-time-auth.html - self._ota_enable_session = addrtype & ADDRTYPE_AUTH - if self._ota_enable and not self._ota_enable_session: - logging.warn('client one time auth is required') - return - if self._ota_enable_session: - if len(data) < header_length + ONETIMEAUTH_BYTES: - logging.warn('one time auth header is too short') - return None - offset = header_length + ONETIMEAUTH_BYTES - _hash = data[header_length: offset] - _data = data[:header_length] - key = self._encryptor.decipher_iv + self._encryptor.key - if onetimeauth_verify(_hash, _data, key) is False: - logging.warn('one time auth fail') - self.destroy() - return - header_length += ONETIMEAUTH_BYTES - self._remote_address = (common.to_str(remote_addr), remote_port) - # pause reading - self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) - self._stage = STAGE_DNS - if self._is_local: - # forward address to remote - self._write_to_sock((b'\x05\x00\x00\x01' - b'\x00\x00\x00\x00\x10\x10'), - self._local_sock) - # spec https://shadowsocks.org/en/spec/one-time-auth.html - # ATYP & 0x10 == 1, then OTA is enabled. - if self._ota_enable_session: - data = common.chr(addrtype | ADDRTYPE_AUTH) + data[1:] - key = self._encryptor.cipher_iv + self._encryptor.key - data += onetimeauth_gen(data, key) - data_to_send = self._encryptor.encrypt(data) - self._data_to_write_to_remote.append(data_to_send) - # notice here may go into _handle_dns_resolved directly - self._dns_resolver.resolve(self._chosen_server[0], - self._handle_dns_resolved) - else: - if self._ota_enable_session: - data = data[header_length:] - self._ota_chunk_data(data, - self._data_to_write_to_remote.append) - elif len(data) > header_length: - self._data_to_write_to_remote.append(data[header_length:]) - # notice here may go into _handle_dns_resolved directly - self._dns_resolver.resolve(remote_addr, - self._handle_dns_resolved) - except Exception as e: - self._log_error(e) - if self._config['verbose']: - traceback.print_exc() - self.destroy() + header_length += ONETIMEAUTH_BYTES + self._remote_address = (common.to_str(remote_addr), remote_port) + # pause reading + self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) + self._stage = STAGE_DNS + if self._is_local: + # forward address to remote + self._write_to_sock((b'\x05\x00\x00\x01' + b'\x00\x00\x00\x00\x10\x10'), + self._local_sock) + # spec https://shadowsocks.org/en/spec/one-time-auth.html + # ATYP & 0x10 == 1, then OTA is enabled. + if self._ota_enable_session: + data = common.chr(addrtype | ADDRTYPE_AUTH) + data[1:] + key = self._encryptor.cipher_iv + self._encryptor.key + data += onetimeauth_gen(data, key) + data_to_send = self._encryptor.encrypt(data) + self._data_to_write_to_remote.append(data_to_send) + # notice here may go into _handle_dns_resolved directly + self._dns_resolver.resolve(self._chosen_server[0], + self._handle_dns_resolved) + else: + if self._ota_enable_session: + data = data[header_length:] + self._ota_chunk_data(data, + self._data_to_write_to_remote.append) + elif len(data) > header_length: + self._data_to_write_to_remote.append(data[header_length:]) + # notice here may go into _handle_dns_resolved directly + self._dns_resolver.resolve(remote_addr, + self._handle_dns_resolved) def _create_remote_socket(self, ip, port): addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, @@ -398,51 +396,50 @@ class TCPRelayHandler(object): remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) return remote_sock + @shell.exception_handle(self_=True) def _handle_dns_resolved(self, result, error): if error: - self._log_error(error) + addr, port = self._client_address[0], self._client_address[1] + logging.error('%s when handling connection from %s:%d' % + (error, addr, port)) + self.destroy() + return + if not (result and result[1]): self.destroy() return - if result and result[1]: - ip = result[1] - try: - self._stage = STAGE_CONNECTING - remote_addr = ip - if self._is_local: - remote_port = self._chosen_server[1] - else: - remote_port = self._remote_address[1] - if self._is_local and self._config['fast_open']: - # for fastopen: - # wait for more data arrive and send them in one SYN - self._stage = STAGE_CONNECTING - # we don't have to wait for remote since it's not - # created - self._update_stream(STREAM_UP, WAIT_STATUS_READING) - # TODO when there is already data in this packet - else: - # else do connect - remote_sock = self._create_remote_socket(remote_addr, - remote_port) - try: - remote_sock.connect((remote_addr, remote_port)) - except (OSError, IOError) as e: - if eventloop.errno_from_exception(e) == \ - errno.EINPROGRESS: - pass - self._loop.add(remote_sock, - eventloop.POLL_ERR | eventloop.POLL_OUT, - self._server) - self._stage = STAGE_CONNECTING - self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) - self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) - return - except Exception as e: - shell.print_exception(e) - if self._config['verbose']: - traceback.print_exc() - self.destroy() + ip = result[1] + self._stage = STAGE_CONNECTING + remote_addr = ip + if self._is_local: + remote_port = self._chosen_server[1] + else: + remote_port = self._remote_address[1] + + if self._is_local and self._config['fast_open']: + # for fastopen: + # wait for more data arrive and send them in one SYN + self._stage = STAGE_CONNECTING + # we don't have to wait for remote since it's not + # created + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + # TODO when there is already data in this packet + else: + # else do connect + remote_sock = self._create_remote_socket(remote_addr, + remote_port) + try: + remote_sock.connect((remote_addr, remote_port)) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) == \ + errno.EINPROGRESS: + pass + self._loop.add(remote_sock, + eventloop.POLL_ERR | eventloop.POLL_OUT, + self._server) + self._stage = STAGE_CONNECTING + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) def _write_to_sock_remote(self, data): self._write_to_sock(data, self._remote_sock) @@ -661,10 +658,6 @@ class TCPRelayHandler(object): else: logging.warn('unknown socket') - def _log_error(self, e): - logging.error('%s when handling connection from %s:%d' % - (e, self._client_address[0], self._client_address[1])) - def destroy(self): # destroy the handler and release any resources # promises: