diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index b80c5dd..5b5d6b2 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -23,7 +23,9 @@ import socket import struct +import logging import common +import eventloop _request_count = 1 @@ -66,18 +68,7 @@ QTYPE_CNAME = 5 QCLASS_IN = 1 -def parse_ip(addrtype, data, length, offset): - if addrtype == QTYPE_A: - return socket.inet_ntop(socket.AF_INET, data[offset:offset + length]) - elif addrtype == QTYPE_AAAA: - return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length]) - elif addrtype == QTYPE_CNAME: - return parse_name(data, offset, length)[1] - else: - return data - - -def pack_address(address): +def build_address(address): address = address.strip('.') labels = address.split('.') results = [] @@ -91,17 +82,28 @@ def pack_address(address): return ''.join(results) -def pack_request(address): +def build_request(address, qtype): global _request_count header = struct.pack('!HBBHHHH', _request_count, 1, 0, 1, 0, 0, 0) - addr = pack_address(address) - qtype_qclass = struct.pack('!HH', QTYPE_ANY, QCLASS_IN) + addr = build_address(address) + qtype_qclass = struct.pack('!HH', qtype, QCLASS_IN) _request_count += 1 if _request_count > 65535: _request_count = 1 return header + addr + qtype_qclass +def parse_ip(addrtype, data, length, offset): + if addrtype == QTYPE_A: + return socket.inet_ntop(socket.AF_INET, data[offset:offset + length]) + elif addrtype == QTYPE_AAAA: + return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length]) + elif addrtype == QTYPE_CNAME: + return parse_name(data, offset, length)[1] + else: + return data + + def parse_name(data, offset, length=512): p = offset if (ord(data[offset]) & (128 + 64)) == (128 + 64): @@ -110,7 +112,7 @@ def parse_name(data, offset, length=512): pointer = pointer & 0x3FFF if pointer == offset: return (0, None) - return (2, parse_name(data, pointer)) + return (2, parse_name(data, pointer)[1]) else: labels = [] l = ord(data[p]) @@ -173,7 +175,7 @@ def parse_record(data, offset, question=False): return len + 4, (name, None, record_type, record_class, None, None) -def unpack_response(data): +def parse_response(data): try: if len(data) >= 12: header = struct.unpack('!HBBHHHH', data[:12]) @@ -214,39 +216,185 @@ def unpack_response(data): offset += l if r: ars.append(r) - - return ans - + response = DNSResponse() + if qds: + response.hostname = qds[0][0] + for an in ans: + response.answers.append((an[1], an[2], an[3])) + return response except Exception as e: import traceback traceback.print_exc() return None -def resolve(address, callback): - # TODO async - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.SOL_UDP) - req = pack_request(address) - if req is None: - # TODO - return - sock.sendto(req, ('8.8.8.8', 53)) - res, addr = sock.recvfrom(1024) - parsed_res = unpack_response(res) - callback(parsed_res) +def is_ip(address): + for family in (socket.AF_INET, socket.AF_INET6): + try: + socket.inet_pton(family, address) + return True + except (OSError, IOError): + pass + return False + + +class DNSResponse(object): + def __init__(self): + self.hostname = None + self.answers = [] # each: (addr, type, class) + + def __str__(self): + return '%s: %s' % (self.hostname, str(self.answers)) + + +STATUS_IPV4 = 0 +STATUS_IPV6 = 1 + + +class DNSResolver(object): + + def __init__(self): + self._loop = None + self._hostname_status = {} + self._hostname_to_cb = {} + self._cb_to_hostname = {} + # TODO add caching + # TODO try ipv4 and ipv6 sequencely + self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + self._sock.setblocking(False) + self._parse_config() + + def _parse_config(self): + try: + with open('/etc/resolv.conf', 'rb') as f: + servers = [] + content = f.readlines() + for line in content: + line = line.strip() + if line: + if line.startswith('nameserver'): + parts = line.split(' ') + if len(parts) >= 2: + server = parts[1] + if is_ip(server): + servers.append(server) + # TODO support more servers + if servers: + self._dns_server = (servers[0], 53) + return + except IOError: + pass + self._dns_server = ('8.8.8.8', 53) + + def add_to_loop(self, loop): + self._loop = loop + loop.add(self._sock, eventloop.POLL_IN) + loop.add_handler(self.handle_events) + + def _handle_data(self, data): + response = parse_response(data) + if response and response.hostname: + hostname = response.hostname + callbacks = self._hostname_to_cb.get(hostname, []) + ip = None + for answer in response.answers: + if answer[1] in (QTYPE_A, QTYPE_AAAA) and \ + answer[2] == QCLASS_IN: + ip = answer[0] + break + if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \ + == STATUS_IPV4: + self._hostname_status[hostname] = STATUS_IPV6 + self._send_req(hostname, QTYPE_AAAA) + return + for callback in callbacks: + if self._cb_to_hostname.__contains__(callback): + del self._cb_to_hostname[callback] + callback((hostname, ip), None) + if self._hostname_to_cb.__contains__(hostname): + del self._hostname_to_cb[hostname] + + def handle_events(self, events): + for sock, fd, event in events: + if sock != self._sock: + continue + if event & eventloop.POLL_ERR: + logging.error('dns socket err') + self._loop.remove(self._sock) + self._sock.close() + self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + self._sock.setblocking(False) + self._loop.add(self._sock, eventloop.POLL_IN) + else: + data, addr = sock.recvfrom(1024) + if addr != self._dns_server: + logging.warn('received a packet other than our dns') + break + self._handle_data(data) + break + + def remove_callback(self, callback): + hostname = self._cb_to_hostname.get(callback) + if hostname: + del self._cb_to_hostname[callback] + arr = self._hostname_to_cb.get(hostname, None) + if arr: + arr.remove(callback) + if not arr: + del self._hostname_to_cb[hostname] + + def _send_req(self, hostname, qtype): + logging.debug('resolving %s with type %d using server %s', hostname, + qtype, self._dns_server) + req = build_request(hostname, qtype) + self._sock.sendto(req, self._dns_server) + + def resolve(self, hostname, callback): + if not hostname: + callback(None, Exception('empty hostname')) + elif is_ip(hostname): + callback(hostname, None) + else: + arr = self._hostname_to_cb.get(hostname, None) + if not arr: + self._hostname_status[hostname] = STATUS_IPV4 + self._send_req(hostname, QTYPE_A) + self._hostname_to_cb[hostname] = [callback] + self._cb_to_hostname[callback] = hostname + else: + arr.append(callback) def test(): - def _callback(address): - print address + logging.getLogger('').handlers = [] + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', filemode='a+') - resolve('www.twitter.com', _callback) - resolve('www.google.com', _callback) - resolve('ipv6.google.com', _callback) - resolve('ipv6.l.google.com', _callback) - resolve('www.baidu.com', _callback) - resolve('www.a.shifen.com', _callback) - resolve('m.baidu.jp', _callback) + def _callback(address, error): + print error, address + + loop = eventloop.EventLoop() + resolver = DNSResolver() + resolver.add_to_loop(loop) + + resolver.resolve('8.8.8.8', _callback) + resolver.resolve('www.twitter.com', _callback) + resolver.resolve('www.google.com', _callback) + resolver.resolve('ipv6.google.com', _callback) + resolver.resolve('ipv6.l.google.com', _callback) + resolver.resolve('www.gmail.com', _callback) + resolver.resolve('r4---sn-3qqp-ioql.googlevideo.com', _callback) + resolver.resolve('www.baidu.com', _callback) + resolver.resolve('www.a.shifen.com', _callback) + resolver.resolve('m.baidu.jp', _callback) + resolver.resolve('www.youku.com', _callback) + resolver.resolve('www.twitter.com', _callback) + resolver.resolve('ipv6.google.com', _callback) + + loop.run() if __name__ == '__main__':