This commit is contained in:
clowwindy
2014-06-18 15:50:05 +08:00
parent a0e1a9f1b0
commit 743d3cddb5
5 changed files with 33 additions and 90 deletions

View File

@ -27,6 +27,7 @@ import errno
import struct
import logging
import traceback
import random
import encrypt
import eventloop
from common import parse_header
@ -96,6 +97,8 @@ class TCPRelayHandler(object):
self._upstream_status = WAIT_STATUS_READING
self._downstream_status = WAIT_STATUS_INIT
self._remote_address = None
if is_local:
self._chosen_server = self._get_a_server()
fd_to_handlers[local_sock.fileno()] = self
local_sock.setblocking(False)
local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
@ -112,6 +115,15 @@ class TCPRelayHandler(object):
def remote_address(self):
return self._remote_address
def _get_a_server(self):
server = self._config['server']
server_port = self._config['server_port']
if type(server_port) == list:
server_port = random.choice(server_port)
logging.debug('chosen server: %s:%d', server, server_port)
# TODO support multiple server IP
return server, server_port
def _update_activity(self):
self._server.update_activity(self)
@ -190,8 +202,7 @@ class TCPRelayHandler(object):
data = ''.join(self._data_to_write_to_local)
l = len(data)
s = self._remote_sock.sendto(data, MSG_FASTOPEN,
(self._config['server'],
self._config['server_port']))
self._chosen_server)
if s < l:
data = data[s:]
self._data_to_write_to_local = [data]
@ -255,7 +266,7 @@ class TCPRelayHandler(object):
data_to_send = self._encryptor.encrypt(data)
self._data_to_write_to_remote.append(data_to_send)
# notice here may go into _handle_dns_resolved directly
self._dns_resolver.resolve(self._config['server'],
self._dns_resolver.resolve(self._chosen_server[0],
self._handle_dns_resolved)
else:
if len(data) > header_length:
@ -283,8 +294,7 @@ class TCPRelayHandler(object):
remote_addr = self._remote_address[0]
remote_port = self._remote_address[1]
if self._is_local:
remote_addr = self._config['server']
remote_port = self._config['server_port']
remote_addr, remote_port = self._chosen_server
addrs = socket.getaddrinfo(ip, remote_port, 0,
socket.SOCK_STREAM,
socket.SOL_TCP)