diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..90669ef --- /dev/null +++ b/Dockerfile @@ -0,0 +1,45 @@ +# Dockerfile for ShadowsocksR based alpine +# Reference URL: https://github.com/shadowsocksrr/shadowsocksr shadowsocksr-3.2.2 +# repositories URL: https://github.com/V7hinc/ssr_server_client + +FROM python:3.7-alpine +MAINTAINER V7hinc + +ENV INSTALL_PATH=/usr/local/shadowsocksr +ENV ssr_config_path=$INSTALL_PATH/conf +ENV autostart_shell=/autostart.sh + +# 安装必要插件 +RUN set -ex;\ +apk add --no-cache jq; + +COPY conf/ssr_client_config_sample.json $ssr_config_path/ssr_client_config.json +COPY conf/ssr_server_config_sample.json $ssr_config_path/ssr_server_config.json +COPY shadowsocks $INSTALL_PATH/shadowsocks + + +# 编写开机启动脚本 +RUN set -x;\ +echo "#!/bin/sh" >> $autostart_shell;\ +# 开机选择开启server模式还是client模式 +echo "ssr_client() {" >> $autostart_shell;\ +echo "python $INSTALL_PATH/shadowsocks/local.py -d start -c $ssr_config_path/ssr_client_config.json --pid-file=$INSTALL_PATH/ssr.pid --log-file=$INSTALL_PATH/ssr.log" >> $autostart_shell;\ +echo "}" >> $autostart_shell;\ +echo "ssr_server() {" >> $autostart_shell;\ +echo "python $INSTALL_PATH/shadowsocks/server.py -c $ssr_config_path/ssr_server_config.json" >> $autostart_shell;\ +echo "}" >> $autostart_shell;\ +echo "case \$1 in" >> $autostart_shell;\ +echo "client) ssr_client ;;" >> $autostart_shell;\ +echo "server) ssr_server ;;" >> $autostart_shell;\ +echo "esac" >> $autostart_shell;\ +# 保持前台 +echo "/bin/sh;" >> $autostart_shell;\ +chmod 755 $autostart_shell; + +WORKDIR $ssr_config_path +VOLUME $ssr_config_path + +EXPOSE 1080 + +ENTRYPOINT ["/autostart.sh"] +CMD ["server"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..616edb7 --- /dev/null +++ b/README.md @@ -0,0 +1,57 @@ +# ssr_server_client + +> shadowsocks源码来源于https://github.com/shadowsocksrr/shadowsocksr shadowsocksr-3.2.2 + +使用步骤 +1. 拉取代码 +```shell script +git clone https://github.com/V7hinc/ssr_server_client.git +``` +2.想要使用服务端(server)功能 +> 修改[conf/ssr_server_config_sample.json](conf/ssr_server_config_sample.json)配置文件 +``` +"port_password":{ + "9999":"xxx", + "8989":"xxx", + "8080":"xxx" +}, +# 键值对是 端口:密码 +# 想开放多个端口就写多条 +# 其他关于加密混淆方式的修改可以自行了解修改 +``` +3.想要使用客户端(client)功能 +> 修改[conf/ssr_client_config_sample.json](conf/ssr_client_config_sample.json)配置文件 +``` +{ + "server": "你的ssr服务器ip", # 填写你的ssr服务器ip + "server_ipv6": "::", + "local_address": "0.0.0.0", + "local_port": 1080, # 本地使用的端口,可不修改 + "server_port": 8989, # 填写ssr服务端开放的端口 + "password": "xxx", # 填写ssr服务端开放的端口对应的密码 +# 需设置与ssr服务端相匹配的配置 +# 加密方式需与服务端相匹配,否则将不能用 +# 其他关于加密混淆方式的修改可以自行了解修改 +``` +4. 构建docker镜像 +```shell script +docker build -t V7hinc/ssr_server_client . +# docker仓库https://hub.docker.com/r/v7hinc/ssr_server_client +``` +4. 启动server端 +```shell script +# 这里映射的端口需跟配置文件配置的端口相对应 +docker run -p 9999:9999 -p 8989:8989 -p 8080:8080 --name ssr_server --restart=always -v /usr/local/shadowsocksr/conf -dit v7hinc/ssr_server_client server +``` + +5. 启动client端 +```shell script +# 这里映射的端口需跟配置文件配置的端口相对应 +docker run -p 1080:1080 --name ssr_client --restart=always -v /usr/local/shadowsocksr/conf -dit v7hinc/ssr_server_client client +``` + +6. 客户端测试 +```shell script +curl ifconfig.me --socks 127.0.0.1:1080 +# 如成功将显示服务端出口ip +``` \ No newline at end of file diff --git a/conf/ssr_client_config_sample.json b/conf/ssr_client_config_sample.json new file mode 100644 index 0000000..bb45f6e --- /dev/null +++ b/conf/ssr_client_config_sample.json @@ -0,0 +1,25 @@ +{ + "server": "你的ssr服务器ip", + "server_ipv6": "::", + "local_address": "0.0.0.0", + "local_port": 1080, + "server_port": 8989, + "password": "xxx", + + + + "method": "aes-256-cfb", + "protocol": "auth_sha1_v4", + "protocol_param": "5", + "obfs": "http_simple", + "obfs_param": "", + "speed_limit_per_con": 0, + "speed_limit_per_user": 0, + "additional_ports": {}, + "timeout": 120, + "udp_timeout": 60, + "dns_ipv6": false, + "connect_verbose_info": 0, + "redirect": "", + "fast_open": false +} \ No newline at end of file diff --git a/conf/ssr_server_config_sample.json b/conf/ssr_server_config_sample.json new file mode 100644 index 0000000..610c950 --- /dev/null +++ b/conf/ssr_server_config_sample.json @@ -0,0 +1,25 @@ +{ + "server": "0.0.0.0", + "server_ipv6": "::", + "local_address": "127.0.0.1", + "local_port": 1080, + "port_password": { + "9999": "xxx", + "8989": "xxx", + "8080": "xxx" + }, + "method": "aes-256-cfb", + "protocol": "auth_sha1_v4", + "protocol_param": "5", + "obfs": "http_simple", + "obfs_param": "", + "speed_limit_per_con": 0, + "speed_limit_per_user": 0, + "additional_ports": {}, + "timeout": 120, + "udp_timeout": 60, + "dns_ipv6": false, + "connect_verbose_info": 0, + "redirect": "", + "fast_open": false +} \ No newline at end of file diff --git a/shadowsocks/__init__.py b/shadowsocks/__init__.py new file mode 100644 index 0000000..dc3abd4 --- /dev/null +++ b/shadowsocks/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/python +# +# Copyright 2012-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement diff --git a/shadowsocks/__pycache__/__init__.cpython-37.pyc b/shadowsocks/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..712bcfc Binary files /dev/null and b/shadowsocks/__pycache__/__init__.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/asyncdns.cpython-37.pyc b/shadowsocks/__pycache__/asyncdns.cpython-37.pyc new file mode 100644 index 0000000..fd2f9e8 Binary files /dev/null and b/shadowsocks/__pycache__/asyncdns.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/common.cpython-37.pyc b/shadowsocks/__pycache__/common.cpython-37.pyc new file mode 100644 index 0000000..274aa49 Binary files /dev/null and b/shadowsocks/__pycache__/common.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/daemon.cpython-37.pyc b/shadowsocks/__pycache__/daemon.cpython-37.pyc new file mode 100644 index 0000000..ce7dfbc Binary files /dev/null and b/shadowsocks/__pycache__/daemon.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/encrypt.cpython-37.pyc b/shadowsocks/__pycache__/encrypt.cpython-37.pyc new file mode 100644 index 0000000..a414c2f Binary files /dev/null and b/shadowsocks/__pycache__/encrypt.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/eventloop.cpython-37.pyc b/shadowsocks/__pycache__/eventloop.cpython-37.pyc new file mode 100644 index 0000000..d51995d Binary files /dev/null and b/shadowsocks/__pycache__/eventloop.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/lru_cache.cpython-37.pyc b/shadowsocks/__pycache__/lru_cache.cpython-37.pyc new file mode 100644 index 0000000..76285ab Binary files /dev/null and b/shadowsocks/__pycache__/lru_cache.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/obfs.cpython-37.pyc b/shadowsocks/__pycache__/obfs.cpython-37.pyc new file mode 100644 index 0000000..83b9846 Binary files /dev/null and b/shadowsocks/__pycache__/obfs.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/shell.cpython-37.pyc b/shadowsocks/__pycache__/shell.cpython-37.pyc new file mode 100644 index 0000000..729f832 Binary files /dev/null and b/shadowsocks/__pycache__/shell.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/tcprelay.cpython-37.pyc b/shadowsocks/__pycache__/tcprelay.cpython-37.pyc new file mode 100644 index 0000000..5e1a6ba Binary files /dev/null and b/shadowsocks/__pycache__/tcprelay.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/udprelay.cpython-37.pyc b/shadowsocks/__pycache__/udprelay.cpython-37.pyc new file mode 100644 index 0000000..7185dc4 Binary files /dev/null and b/shadowsocks/__pycache__/udprelay.cpython-37.pyc differ diff --git a/shadowsocks/__pycache__/version.cpython-37.pyc b/shadowsocks/__pycache__/version.cpython-37.pyc new file mode 100644 index 0000000..a35ef3f Binary files /dev/null and b/shadowsocks/__pycache__/version.cpython-37.pyc differ diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py new file mode 100644 index 0000000..868ea61 --- /dev/null +++ b/shadowsocks/asyncdns.py @@ -0,0 +1,592 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2014-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import socket +import struct +import re +import logging + +if __name__ == '__main__': + import sys + import inspect + + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) + sys.path.insert(0, os.path.join(file_path, '../')) + +from shadowsocks import common, lru_cache, eventloop, shell + +CACHE_SWEEP_INTERVAL = 30 + +VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d_-]{1,63}(? 63: + return None + results.append(common.chr(l)) + results.append(label) + results.append(b'\0') + return b''.join(results) + + +def build_request(address, qtype): + request_id = os.urandom(2) + header = struct.pack('!BBHHHH', 1, 0, 1, 0, 0, 0) + addr = build_address(address) + qtype_qclass = struct.pack('!HH', qtype, QCLASS_IN) + return request_id + header + addr + qtype_qclass + + +def parse_ip(addrtype, data, length, offset): + if addrtype == QTYPE_A: + return socket.inet_ntop(socket.AF_INET, data[offset:offset + length]) + elif addrtype == QTYPE_AAAA: + return socket.inet_ntop(socket.AF_INET6, data[offset:offset + length]) + elif addrtype in [QTYPE_CNAME, QTYPE_NS]: + return parse_name(data, offset)[1] + else: + return data[offset:offset + length] + + +def parse_name(data, offset): + p = offset + labels = [] + l = common.ord(data[p]) + while l > 0: + if (l & (128 + 64)) == (128 + 64): + # pointer + pointer = struct.unpack('!H', data[p:p + 2])[0] + pointer &= 0x3FFF + r = parse_name(data, pointer) + labels.append(r[1]) + p += 2 + # pointer is the end + return p - offset, b'.'.join(labels) + else: + labels.append(data[p + 1:p + 1 + l]) + p += 1 + l + l = common.ord(data[p]) + return p - offset + 1, b'.'.join(labels) + + +# rfc1035 +# record +# 1 1 1 1 1 1 +# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | | +# / / +# / NAME / +# | | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | TYPE | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | CLASS | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | TTL | +# | | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +# | RDLENGTH | +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--| +# / RDATA / +# / / +# +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +def parse_record(data, offset, question=False): + nlen, name = parse_name(data, offset) + if not question: + record_type, record_class, record_ttl, record_rdlength = struct.unpack( + '!HHiH', data[offset + nlen:offset + nlen + 10] + ) + ip = parse_ip(record_type, data, record_rdlength, offset + nlen + 10) + return nlen + 10 + record_rdlength, \ + (name, ip, record_type, record_class, record_ttl) + else: + record_type, record_class = struct.unpack( + '!HH', data[offset + nlen:offset + nlen + 4] + ) + return nlen + 4, (name, None, record_type, record_class, None, None) + + +def parse_header(data): + if len(data) >= 12: + header = struct.unpack('!HBBHHHH', data[:12]) + res_id = header[0] + res_qr = header[1] & 128 + res_tc = header[1] & 2 + res_ra = header[2] & 128 + res_rcode = header[2] & 15 + # assert res_tc == 0 + # assert res_rcode in [0, 3] + res_qdcount = header[3] + res_ancount = header[4] + res_nscount = header[5] + res_arcount = header[6] + return (res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, + res_ancount, res_nscount, res_arcount) + return None + + +def parse_response(data): + try: + if len(data) >= 12: + header = parse_header(data) + if not header: + return None + res_id, res_qr, res_tc, res_ra, res_rcode, res_qdcount, \ + res_ancount, res_nscount, res_arcount = header + + qds = [] + ans = [] + offset = 12 + for i in range(0, res_qdcount): + l, r = parse_record(data, offset, True) + offset += l + if r: + qds.append(r) + for i in range(0, res_ancount): + l, r = parse_record(data, offset) + offset += l + if r: + ans.append(r) + for i in range(0, res_nscount): + l, r = parse_record(data, offset) + offset += l + for i in range(0, res_arcount): + l, r = parse_record(data, offset) + offset += l + response = DNSResponse() + if qds: + response.hostname = qds[0][0] + for an in qds: + response.questions.append((an[1], an[2], an[3])) + for an in ans: + response.answers.append((an[1], an[2], an[3])) + return response + except Exception as e: + shell.print_exception(e) + return None + + +def is_valid_hostname(hostname): + if len(hostname) > 255: + return False + if hostname[-1] == b'.': + hostname = hostname[:-1] + return all(VALID_HOSTNAME.match(x) for x in hostname.split(b'.')) + + +class DNSResponse(object): + def __init__(self): + self.hostname = None + self.questions = [] # each: (addr, type, class) + self.answers = [] # each: (addr, type, class) + + def __str__(self): + return '%s: %s' % (self.hostname, str(self.answers)) + + +STATUS_IPV4 = 0 +STATUS_IPV6 = 1 + + +class DNSResolver(object): + def __init__(self, black_hostname_list=None): + self._loop = None + self._hosts = {} + self._hostname_status = {} + self._hostname_to_cb = {} + self._cb_to_hostname = {} + self._cache = lru_cache.LRUCache(timeout=300) + # read black_hostname_list from config + if type(black_hostname_list) != list: + self._black_hostname_list = [] + else: + self._black_hostname_list = list(map( + (lambda t: t if type(t) == bytes else t.encode('utf8')), + black_hostname_list + )) + logging.info('black_hostname_list init as : ' + str(self._black_hostname_list)) + self._sock = None + self._servers = None + self._parse_resolv() + self._parse_hosts() + # TODO monitor hosts change and reload hosts + # TODO parse /etc/gai.conf and follow its rules + + def _parse_resolv(self): + self._servers = [] + try: + with open('dns.conf', 'rb') as f: + content = f.readlines() + for line in content: + line = line.strip() + if line: + parts = line.split(b' ', 1) + if len(parts) >= 2: + server = parts[0] + port = int(parts[1]) + else: + server = parts[0] + port = 53 + if common.is_ip(server) == socket.AF_INET: + if type(server) != str: + server = server.decode('utf8') + self._servers.append((server, port)) + except IOError: + pass + if not self._servers: + try: + with open('/etc/resolv.conf', 'rb') as f: + content = f.readlines() + for line in content: + line = line.strip() + if line: + if line.startswith(b'nameserver'): + parts = line.split() + if len(parts) >= 2: + server = parts[1] + if common.is_ip(server) == socket.AF_INET: + if type(server) != str: + server = server.decode('utf8') + self._servers.append((server, 53)) + except IOError: + pass + if not self._servers: + self._servers = [('8.8.4.4', 53), ('8.8.8.8', 53)] + logging.info('dns server: %s' % (self._servers,)) + + def _parse_hosts(self): + etc_path = '/etc/hosts' + if 'WINDIR' in os.environ: + etc_path = os.environ['WINDIR'] + '/system32/drivers/etc/hosts' + try: + with open(etc_path, 'rb') as f: + for line in f.readlines(): + line = line.strip() + if b"#" in line: + line = line[:line.find(b'#')] + parts = line.split() + if len(parts) >= 2: + ip = parts[0] + if common.is_ip(ip): + for i in range(1, len(parts)): + hostname = parts[i] + if hostname: + self._hosts[hostname] = ip + except IOError: + self._hosts['localhost'] = '127.0.0.1' + + def add_to_loop(self, loop): + if self._loop: + raise Exception('already add to loop') + self._loop = loop + # TODO when dns server is IPv6 + self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + self._sock.setblocking(False) + loop.add(self._sock, eventloop.POLL_IN, self) + loop.add_periodic(self.handle_periodic) + + def _call_callback(self, hostname, ip, error=None): + callbacks = self._hostname_to_cb.get(hostname, []) + for callback in callbacks: + if callback in self._cb_to_hostname: + del self._cb_to_hostname[callback] + if ip or error: + callback((hostname, ip), error) + else: + callback((hostname, None), + Exception('unable to parse hostname %s' % hostname)) + if hostname in self._hostname_to_cb: + del self._hostname_to_cb[hostname] + if hostname in self._hostname_status: + del self._hostname_status[hostname] + + def _handle_data(self, data): + response = parse_response(data) + if response and response.hostname: + hostname = response.hostname + ip = None + for answer in response.answers: + if answer[1] in (QTYPE_A, QTYPE_AAAA) and \ + answer[2] == QCLASS_IN: + ip = answer[0] + break + if IPV6_CONNECTION_SUPPORT: + if not ip and self._hostname_status.get(hostname, STATUS_IPV4) \ + == STATUS_IPV6: + self._hostname_status[hostname] = STATUS_IPV4 + self._send_req(hostname, QTYPE_A) + else: + if ip: + self._cache[hostname] = ip + self._call_callback(hostname, ip) + elif self._hostname_status.get(hostname, None) == STATUS_IPV4: + for question in response.questions: + if question[1] == QTYPE_A: + self._call_callback(hostname, None) + break + else: + if not ip and self._hostname_status.get(hostname, STATUS_IPV6) \ + == STATUS_IPV4: + self._hostname_status[hostname] = STATUS_IPV6 + self._send_req(hostname, QTYPE_AAAA) + else: + if ip: + self._cache[hostname] = ip + self._call_callback(hostname, ip) + elif self._hostname_status.get(hostname, None) == STATUS_IPV6: + for question in response.questions: + if question[1] == QTYPE_AAAA: + self._call_callback(hostname, None) + break + + def handle_event(self, sock, fd, event): + if sock != self._sock: + return + if event & eventloop.POLL_ERR: + logging.error('dns socket err') + self._loop.remove(self._sock) + self._sock.close() + # TODO when dns server is IPv6 + self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, + socket.SOL_UDP) + self._sock.setblocking(False) + self._loop.add(self._sock, eventloop.POLL_IN, self) + else: + data, addr = sock.recvfrom(1024) + if addr not in self._servers: + logging.warn('received a packet other than our dns') + return + self._handle_data(data) + + def handle_periodic(self): + self._cache.sweep() + + def remove_callback(self, callback): + hostname = self._cb_to_hostname.get(callback) + if hostname: + del self._cb_to_hostname[callback] + arr = self._hostname_to_cb.get(hostname, None) + if arr: + arr.remove(callback) + if not arr: + del self._hostname_to_cb[hostname] + if hostname in self._hostname_status: + del self._hostname_status[hostname] + + def _send_req(self, hostname, qtype): + req = build_request(hostname, qtype) + for server in self._servers: + logging.debug('resolving %s with type %d using server %s', + hostname, qtype, server) + self._sock.sendto(req, server) + + def resolve(self, hostname, callback): + if type(hostname) != bytes: + hostname = hostname.encode('utf8') + if not hostname: + callback(None, Exception('empty hostname')) + elif common.is_ip(hostname): + callback((hostname, hostname), None) + elif hostname in self._hosts: + logging.debug('hit hosts: %s', hostname) + ip = self._hosts[hostname] + callback((hostname, ip), None) + elif hostname in self._cache: + logging.debug('hit cache: %s ==>> %s', hostname, self._cache[hostname]) + ip = self._cache[hostname] + callback((hostname, ip), None) + elif any(hostname.endswith(t) for t in self._black_hostname_list): + callback(None, Exception('hostname <%s> is block by the black hostname list' % hostname)) + return + else: + if not is_valid_hostname(hostname): + callback(None, Exception('invalid hostname: %s' % hostname)) + return + if False: + addrs = socket.getaddrinfo(hostname, 0, 0, + socket.SOCK_DGRAM, socket.SOL_UDP) + if addrs: + af, socktype, proto, canonname, sa = addrs[0] + logging.debug('DNS resolve %s %s' % (hostname, sa[0])) + self._cache[hostname] = sa[0] + callback((hostname, sa[0]), None) + return + arr = self._hostname_to_cb.get(hostname, None) + if not arr: + if IPV6_CONNECTION_SUPPORT: + self._hostname_status[hostname] = STATUS_IPV6 + self._send_req(hostname, QTYPE_AAAA) + else: + self._hostname_status[hostname] = STATUS_IPV4 + self._send_req(hostname, QTYPE_A) + self._hostname_to_cb[hostname] = [callback] + self._cb_to_hostname[callback] = hostname + else: + arr.append(callback) + # TODO send again only if waited too long + if IPV6_CONNECTION_SUPPORT: + self._send_req(hostname, QTYPE_AAAA) + else: + self._send_req(hostname, QTYPE_A) + + def close(self): + if self._sock: + if self._loop: + self._loop.remove_periodic(self.handle_periodic) + self._loop.remove(self._sock) + self._sock.close() + self._sock = None + + +def test(): + black_hostname_list = [ + 'baidu.com', + 'yahoo.com', + ] + dns_resolver = DNSResolver(black_hostname_list=black_hostname_list) + loop = eventloop.EventLoop() + dns_resolver.add_to_loop(loop) + + global counter + counter = 0 + + def make_callback(): + global counter + + def callback(result, error): + global counter + # TODO: what can we assert? + print(result, error) + counter += 1 + if counter == 12: + dns_resolver.close() + loop.stop() + + a_callback = callback + return a_callback + + assert (make_callback() != make_callback()) + + dns_resolver.resolve(b'google.com', make_callback()) + dns_resolver.resolve('google.com', make_callback()) + dns_resolver.resolve('baidu.com', make_callback()) + dns_resolver.resolve('map.baidu.com', make_callback()) + dns_resolver.resolve('yahoo.com', make_callback()) + dns_resolver.resolve('example.com', make_callback()) + dns_resolver.resolve('ipv6.google.com', make_callback()) + dns_resolver.resolve('www.facebook.com', make_callback()) + dns_resolver.resolve('ns2.google.com', make_callback()) + dns_resolver.resolve('invalid.@!#$%^&$@.hostname', make_callback()) + dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'long.hostname', make_callback()) + dns_resolver.resolve('toooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' + 'long.hostname', make_callback()) + loop.run() + # test black_hostname_list + dns_resolver = DNSResolver(black_hostname_list=[]) + assert type(dns_resolver._black_hostname_list) == list + assert len(dns_resolver._black_hostname_list) == 0 + dns_resolver.close() + dns_resolver = DNSResolver(black_hostname_list=123) + assert type(dns_resolver._black_hostname_list) == list + assert len(dns_resolver._black_hostname_list) == 0 + dns_resolver.close() + dns_resolver = DNSResolver(black_hostname_list=None) + assert type(dns_resolver._black_hostname_list) == list + assert len(dns_resolver._black_hostname_list) == 0 + dns_resolver.close() + dns_resolver = DNSResolver() + assert type(dns_resolver._black_hostname_list) == list + assert dns_resolver._black_hostname_list.__len__() == 0 + dns_resolver.close() + + +if __name__ == '__main__': + test() diff --git a/shadowsocks/common.py b/shadowsocks/common.py new file mode 100644 index 0000000..d022f3a --- /dev/null +++ b/shadowsocks/common.py @@ -0,0 +1,448 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2013-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import socket +import struct +import logging +import binascii +import re + +from shadowsocks import lru_cache + +def compat_ord(s): + if type(s) == int: + return s + return _ord(s) + + +def compat_chr(d): + if bytes == str: + return _chr(d) + return bytes([d]) + + +_ord = ord +_chr = chr +ord = compat_ord +chr = compat_chr + +connect_log = logging.debug + +def to_bytes(s): + if bytes != str: + if type(s) == str: + return s.encode('utf-8') + return s + + +def to_str(s): + if bytes != str: + if type(s) == bytes: + return s.decode('utf-8') + return s + +def int32(x): + if x > 0xFFFFFFFF or x < 0: + x &= 0xFFFFFFFF + if x > 0x7FFFFFFF: + x = int(0x100000000 - x) + if x < 0x80000000: + return -x + else: + return -2147483648 + return x + +def inet_ntop(family, ipstr): + if family == socket.AF_INET: + return to_bytes(socket.inet_ntoa(ipstr)) + elif family == socket.AF_INET6: + import re + v6addr = ':'.join(('%02X%02X' % (ord(i), ord(j))).lstrip('0') + for i, j in zip(ipstr[::2], ipstr[1::2])) + v6addr = re.sub('::+', '::', v6addr, count=1) + return to_bytes(v6addr) + + +def inet_pton(family, addr): + addr = to_str(addr) + if family == socket.AF_INET: + return socket.inet_aton(addr) + elif family == socket.AF_INET6: + if '.' in addr: # a v4 addr + v4addr = addr[addr.rindex(':') + 1:] + v4addr = socket.inet_aton(v4addr) + v4addr = ['%02X' % ord(x) for x in v4addr] + v4addr.insert(2, ':') + newaddr = addr[:addr.rindex(':') + 1] + ''.join(v4addr) + return inet_pton(family, newaddr) + dbyts = [0] * 8 # 8 groups + grps = addr.split(':') + for i, v in enumerate(grps): + if v: + dbyts[i] = int(v, 16) + else: + for j, w in enumerate(grps[::-1]): + if w: + dbyts[7 - j] = int(w, 16) + else: + break + break + return b''.join((chr(i // 256) + chr(i % 256)) for i in dbyts) + else: + raise RuntimeError("What family?") + + +def is_ip(address): + for family in (socket.AF_INET, socket.AF_INET6): + try: + if type(address) != str: + address = address.decode('utf8') + inet_pton(family, address) + return family + except (TypeError, ValueError, OSError, IOError): + pass + return False + + +def sync_str_bytes(obj, target_example): + """sync (obj)'s type to (target_example)'s type""" + if type(obj) != type(target_example): + if type(target_example) == str: + obj = to_str(obj) + if type(target_example) == bytes: + obj = to_bytes(obj) + return obj + + +def match_regex(regex, text): + # avoid 'cannot use a string pattern on a bytes-like object' + regex = sync_str_bytes(regex, text) + regex = re.compile(regex) + for item in regex.findall(text): + return True + return False + + +def patch_socket(): + if not hasattr(socket, 'inet_pton'): + socket.inet_pton = inet_pton + + if not hasattr(socket, 'inet_ntop'): + socket.inet_ntop = inet_ntop + + +patch_socket() + + +ADDRTYPE_IPV4 = 1 +ADDRTYPE_IPV6 = 4 +ADDRTYPE_HOST = 3 + + +def pack_addr(address): + address_str = to_str(address) + for family in (socket.AF_INET, socket.AF_INET6): + try: + r = socket.inet_pton(family, address_str) + if family == socket.AF_INET6: + return b'\x04' + r + else: + return b'\x01' + r + except (TypeError, ValueError, OSError, IOError): + pass + if len(address) > 255: + address = address[:255] # TODO + return b'\x03' + chr(len(address)) + address + +def pre_parse_header(data): + if not data: + return None + datatype = ord(data[0]) + if datatype == 0x80: + if len(data) <= 2: + return None + rand_data_size = ord(data[1]) + if rand_data_size + 2 >= len(data): + logging.warn('header too short, maybe wrong password or ' + 'encryption method') + return None + data = data[rand_data_size + 2:] + elif datatype == 0x81: + data = data[1:] + elif datatype == 0x82: + if len(data) <= 3: + return None + rand_data_size = struct.unpack('>H', data[1:3])[0] + if rand_data_size + 3 >= len(data): + logging.warn('header too short, maybe wrong password or ' + 'encryption method') + return None + data = data[rand_data_size + 3:] + elif datatype == 0x88 or (~datatype & 0xff) == 0x88: + if len(data) <= 7 + 7: + return None + data_size = struct.unpack('>H', data[1:3])[0] + ogn_data = data + data = data[:data_size] + crc = binascii.crc32(data) & 0xffffffff + if crc != 0xffffffff: + logging.warn('uncorrect CRC32, maybe wrong password or ' + 'encryption method') + return None + start_pos = 3 + ord(data[3]) + data = data[start_pos:-4] + if data_size < len(ogn_data): + data += ogn_data[data_size:] + return data + +def parse_header(data): + addrtype = ord(data[0]) + dest_addr = None + dest_port = None + header_length = 0 + connecttype = (addrtype & 0x8) and 1 or 0 + addrtype &= ~0x8 + if addrtype == ADDRTYPE_IPV4: + if len(data) >= 7: + dest_addr = socket.inet_ntoa(data[1:5]) + dest_port = struct.unpack('>H', data[5:7])[0] + header_length = 7 + else: + logging.warn('header is too short') + elif addrtype == ADDRTYPE_HOST: + if len(data) > 2: + addrlen = ord(data[1]) + if len(data) >= 4 + addrlen: + dest_addr = data[2:2 + addrlen] + dest_port = struct.unpack('>H', data[2 + addrlen:4 + + addrlen])[0] + header_length = 4 + addrlen + else: + logging.warn('header is too short') + else: + logging.warn('header is too short') + elif addrtype == ADDRTYPE_IPV6: + if len(data) >= 19: + dest_addr = socket.inet_ntop(socket.AF_INET6, data[1:17]) + dest_port = struct.unpack('>H', data[17:19])[0] + header_length = 19 + else: + logging.warn('header is too short') + else: + logging.warn('unsupported addrtype %d, maybe wrong password or ' + 'encryption method' % addrtype) + if dest_addr is None: + return None + return connecttype, addrtype, to_bytes(dest_addr), dest_port, header_length + + +class IPNetwork(object): + ADDRLENGTH = {socket.AF_INET: 32, socket.AF_INET6: 128, False: 0} + + def __init__(self, addrs): + self.addrs_str = addrs + self._network_list_v4 = [] + self._network_list_v6 = [] + if type(addrs) == str: + addrs = addrs.split(',') + list(map(self.add_network, addrs)) + + def add_network(self, addr): + if addr is "": + return + block = addr.split('/') + addr_family = is_ip(block[0]) + addr_len = IPNetwork.ADDRLENGTH[addr_family] + if addr_family is socket.AF_INET: + ip, = struct.unpack("!I", socket.inet_aton(block[0])) + elif addr_family is socket.AF_INET6: + hi, lo = struct.unpack("!QQ", inet_pton(addr_family, block[0])) + ip = (hi << 64) | lo + else: + raise Exception("Not a valid CIDR notation: %s" % addr) + if len(block) is 1: + prefix_size = 0 + while (ip & 1) == 0 and ip is not 0: + ip >>= 1 + prefix_size += 1 + logging.warn("You did't specify CIDR routing prefix size for %s, " + "implicit treated as %s/%d" % (addr, addr, addr_len)) + elif block[1].isdigit() and int(block[1]) <= addr_len: + prefix_size = addr_len - int(block[1]) + ip >>= prefix_size + else: + raise Exception("Not a valid CIDR notation: %s" % addr) + if addr_family is socket.AF_INET: + self._network_list_v4.append((ip, prefix_size)) + else: + self._network_list_v6.append((ip, prefix_size)) + + def __contains__(self, addr): + addr_family = is_ip(addr) + if addr_family is socket.AF_INET: + ip, = struct.unpack("!I", socket.inet_aton(addr)) + return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1], + self._network_list_v4)) + elif addr_family is socket.AF_INET6: + hi, lo = struct.unpack("!QQ", inet_pton(addr_family, addr)) + ip = (hi << 64) | lo + return any(map(lambda n_ps: n_ps[0] == ip >> n_ps[1], + self._network_list_v6)) + else: + return False + + def __cmp__(self, other): + return cmp(self.addrs_str, other.addrs_str) + + def __eq__(self, other): + return self.addrs_str == other.addrs_str + + def __ne__(self, other): + return self.addrs_str != other.addrs_str + +class PortRange(object): + def __init__(self, range_str): + self.range_str = to_str(range_str) + self.range = set() + range_str = to_str(range_str).split(',') + for item in range_str: + try: + int_range = item.split('-') + if len(int_range) == 1: + if item: + self.range.add(int(item)) + elif len(int_range) == 2: + int_range[0] = int(int_range[0]) + int_range[1] = int(int_range[1]) + if int_range[0] < 0: + int_range[0] = 0 + if int_range[1] > 65535: + int_range[1] = 65535 + i = int_range[0] + while i <= int_range[1]: + self.range.add(i) + i += 1 + except Exception as e: + logging.error(e) + + def __contains__(self, val): + return val in self.range + + def __cmp__(self, other): + return cmp(self.range_str, other.range_str) + + def __eq__(self, other): + return self.range_str == other.range_str + + def __ne__(self, other): + return self.range_str != other.range_str + +class UDPAsyncDNSHandler(object): + dns_cache = lru_cache.LRUCache(timeout=1800) + def __init__(self, params): + self.params = params + self.remote_addr = None + self.call_back = None + + def resolve(self, dns_resolver, remote_addr, call_back): + if remote_addr in UDPAsyncDNSHandler.dns_cache: + if call_back: + call_back("", remote_addr, UDPAsyncDNSHandler.dns_cache[remote_addr], self.params) + else: + self.call_back = call_back + self.remote_addr = remote_addr + dns_resolver.resolve(remote_addr[0], self._handle_dns_resolved) + UDPAsyncDNSHandler.dns_cache.sweep() + + def _handle_dns_resolved(self, result, error): + if error: + logging.error("%s when resolve DNS" % (error,)) #drop + return self.call_back(error, self.remote_addr, None, self.params) + if result: + ip = result[1] + if ip: + return self.call_back("", self.remote_addr, ip, self.params) + logging.warning("can't resolve %s" % (self.remote_addr,)) + return self.call_back("fail to resolve", self.remote_addr, None, self.params) + +def test_inet_conv(): + ipv4 = b'8.8.4.4' + b = inet_pton(socket.AF_INET, ipv4) + assert inet_ntop(socket.AF_INET, b) == ipv4 + ipv6 = b'2404:6800:4005:805::1011' + b = inet_pton(socket.AF_INET6, ipv6) + assert inet_ntop(socket.AF_INET6, b) == ipv6 + + +def test_parse_header(): + assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \ + (0, ADDRTYPE_HOST, b'www.google.com', 80, 18) + assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \ + (0, ADDRTYPE_IPV4, b'8.8.8.8', 53, 7) + assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00' + b'\x00\x10\x11\x00\x50')) == \ + (0, ADDRTYPE_IPV6, b'2404:6800:4005:805::1011', 80, 19) + + +def test_pack_header(): + assert pack_addr(b'8.8.8.8') == b'\x01\x08\x08\x08\x08' + assert pack_addr(b'2404:6800:4005:805::1011') == \ + b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00\x00\x10\x11' + assert pack_addr(b'www.google.com') == b'\x03\x0ewww.google.com' + + +def test_ip_network(): + ip_network = IPNetwork('127.0.0.0/24,::ff:1/112,::1,192.168.1.1,192.0.2.0') + assert '127.0.0.1' in ip_network + assert '127.0.1.1' not in ip_network + assert ':ff:ffff' in ip_network + assert '::ffff:1' not in ip_network + assert '::1' in ip_network + assert '::2' not in ip_network + assert '192.168.1.1' in ip_network + assert '192.168.1.2' not in ip_network + assert '192.0.2.1' in ip_network + assert '192.0.3.1' in ip_network # 192.0.2.0 is treated as 192.0.2.0/23 + assert 'www.google.com' not in ip_network + + +def test_sync_str_bytes(): + assert sync_str_bytes(b'a\.b', b'a\.b') == b'a\.b' + assert sync_str_bytes('a\.b', b'a\.b') == b'a\.b' + assert sync_str_bytes(b'a\.b', 'a\.b') == 'a\.b' + assert sync_str_bytes('a\.b', 'a\.b') == 'a\.b' + pass + + +def test_match_regex(): + assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b') + assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b') + assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b') + assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b') + assert match_regex(r'\bgoogle\.com\b', b' google.com ') + pass + +if __name__ == '__main__': + test_sync_str_bytes() + test_match_regex() + test_inet_conv() + test_parse_header() + test_pack_header() + test_ip_network() diff --git a/shadowsocks/crypto/__init__.py b/shadowsocks/crypto/__init__.py new file mode 100644 index 0000000..401c7b7 --- /dev/null +++ b/shadowsocks/crypto/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement diff --git a/shadowsocks/crypto/__pycache__/__init__.cpython-37.pyc b/shadowsocks/crypto/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..bf18386 Binary files /dev/null and b/shadowsocks/crypto/__pycache__/__init__.cpython-37.pyc differ diff --git a/shadowsocks/crypto/__pycache__/openssl.cpython-37.pyc b/shadowsocks/crypto/__pycache__/openssl.cpython-37.pyc new file mode 100644 index 0000000..c23deb8 Binary files /dev/null and b/shadowsocks/crypto/__pycache__/openssl.cpython-37.pyc differ diff --git a/shadowsocks/crypto/__pycache__/rc4_md5.cpython-37.pyc b/shadowsocks/crypto/__pycache__/rc4_md5.cpython-37.pyc new file mode 100644 index 0000000..08f9a01 Binary files /dev/null and b/shadowsocks/crypto/__pycache__/rc4_md5.cpython-37.pyc differ diff --git a/shadowsocks/crypto/__pycache__/sodium.cpython-37.pyc b/shadowsocks/crypto/__pycache__/sodium.cpython-37.pyc new file mode 100644 index 0000000..57bb94a Binary files /dev/null and b/shadowsocks/crypto/__pycache__/sodium.cpython-37.pyc differ diff --git a/shadowsocks/crypto/__pycache__/table.cpython-37.pyc b/shadowsocks/crypto/__pycache__/table.cpython-37.pyc new file mode 100644 index 0000000..8a42010 Binary files /dev/null and b/shadowsocks/crypto/__pycache__/table.cpython-37.pyc differ diff --git a/shadowsocks/crypto/__pycache__/util.cpython-37.pyc b/shadowsocks/crypto/__pycache__/util.cpython-37.pyc new file mode 100644 index 0000000..42a08b0 Binary files /dev/null and b/shadowsocks/crypto/__pycache__/util.cpython-37.pyc differ diff --git a/shadowsocks/crypto/ctypes_libsodium.py b/shadowsocks/crypto/ctypes_libsodium.py new file mode 100644 index 0000000..efecfd4 --- /dev/null +++ b/shadowsocks/crypto/ctypes_libsodium.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python + +# Copyright (c) 2014 clowwindy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import logging +from ctypes import CDLL, c_char_p, c_int, c_ulonglong, byref, \ + create_string_buffer, c_void_p + +__all__ = ['ciphers'] + +libsodium = None +loaded = False + +buf_size = 2048 + +# for salsa20 and chacha20 +BLOCK_SIZE = 64 + + +def load_libsodium(): + global loaded, libsodium, buf + + from ctypes.util import find_library + for p in ('sodium',): + libsodium_path = find_library(p) + if libsodium_path: + break + else: + raise Exception('libsodium not found') + logging.info('loading libsodium from %s', libsodium_path) + libsodium = CDLL(libsodium_path) + libsodium.sodium_init.restype = c_int + libsodium.crypto_stream_salsa20_xor_ic.restype = c_int + libsodium.crypto_stream_salsa20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + libsodium.crypto_stream_chacha20_xor_ic.restype = c_int + libsodium.crypto_stream_chacha20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + + libsodium.sodium_init() + + buf = create_string_buffer(buf_size) + loaded = True + + +class Salsa20Crypto(object): + def __init__(self, cipher_name, key, iv, op): + if not loaded: + load_libsodium() + self.key = key + self.iv = iv + self.key_ptr = c_char_p(key) + self.iv_ptr = c_char_p(iv) + if cipher_name == b'salsa20': + self.cipher = libsodium.crypto_stream_salsa20_xor_ic + elif cipher_name == b'chacha20': + self.cipher = libsodium.crypto_stream_chacha20_xor_ic + else: + raise Exception('Unknown cipher') + # byte counter, not block counter + self.counter = 0 + + def update(self, data): + global buf_size, buf + l = len(data) + + # we can only prepend some padding to make the encryption align to + # blocks + padding = self.counter % BLOCK_SIZE + if buf_size < padding + l: + buf_size = (padding + l) * 2 + buf = create_string_buffer(buf_size) + + if padding: + data = (b'\0' * padding) + data + self.cipher(byref(buf), c_char_p(data), padding + l, + self.iv_ptr, int(self.counter / BLOCK_SIZE), self.key_ptr) + self.counter += l + # buf is copied to a str object when we access buf.raw + # strip off the padding + return buf.raw[padding:padding + l] + + +ciphers = { + b'salsa20': (32, 8, Salsa20Crypto), + b'chacha20': (32, 8, Salsa20Crypto), +} + + +def test_salsa20(): + from shadowsocks.crypto import util + + cipher = Salsa20Crypto(b'salsa20', b'k' * 32, b'i' * 16, 1) + decipher = Salsa20Crypto(b'salsa20', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_chacha20(): + from shadowsocks.crypto import util + + cipher = Salsa20Crypto(b'chacha20', b'k' * 32, b'i' * 16, 1) + decipher = Salsa20Crypto(b'chacha20', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +if __name__ == '__main__': + test_chacha20() + test_salsa20() diff --git a/shadowsocks/crypto/ctypes_openssl.py b/shadowsocks/crypto/ctypes_openssl.py new file mode 100644 index 0000000..0ef8ce0 --- /dev/null +++ b/shadowsocks/crypto/ctypes_openssl.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python + +# Copyright (c) 2014 clowwindy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import logging +from ctypes import CDLL, c_char_p, c_int, c_long, byref,\ + create_string_buffer, c_void_p + +__all__ = ['ciphers'] + +libcrypto = None +loaded = False + +buf_size = 2048 + + +def load_openssl(): + global loaded, libcrypto, buf + + from ctypes.util import find_library + for p in ('crypto', 'eay32', 'libeay32'): + libcrypto_path = find_library(p) + if libcrypto_path: + break + else: + raise Exception('libcrypto(OpenSSL) not found') + logging.info('loading libcrypto from %s', libcrypto_path) + libcrypto = CDLL(libcrypto_path) + libcrypto.EVP_get_cipherbyname.restype = c_void_p + libcrypto.EVP_CIPHER_CTX_new.restype = c_void_p + + libcrypto.EVP_CipherInit_ex.argtypes = (c_void_p, c_void_p, c_char_p, + c_char_p, c_char_p, c_int) + + libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p, + c_char_p, c_int) + + libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,) + libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,) + if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'): + libcrypto.OpenSSL_add_all_ciphers() + + buf = create_string_buffer(buf_size) + loaded = True + + +def load_cipher(cipher_name): + func_name = b'EVP_' + cipher_name.replace(b'-', b'_') + if bytes != str: + func_name = str(func_name, 'utf-8') + cipher = getattr(libcrypto, func_name, None) + if cipher: + cipher.restype = c_void_p + return cipher() + return None + + +class CtypesCrypto(object): + def __init__(self, cipher_name, key, iv, op): + if not loaded: + load_openssl() + self._ctx = None + cipher = libcrypto.EVP_get_cipherbyname(cipher_name) + if not cipher: + cipher = load_cipher(cipher_name) + if not cipher: + raise Exception('cipher %s not found in libcrypto' % cipher_name) + key_ptr = c_char_p(key) + iv_ptr = c_char_p(iv) + self._ctx = libcrypto.EVP_CIPHER_CTX_new() + if not self._ctx: + raise Exception('can not create cipher context') + r = libcrypto.EVP_CipherInit_ex(self._ctx, cipher, None, + key_ptr, iv_ptr, c_int(op)) + if not r: + self.clean() + raise Exception('can not initialize cipher context') + + def update(self, data): + global buf_size, buf + cipher_out_len = c_long(0) + l = len(data) + if buf_size < l: + buf_size = l * 2 + buf = create_string_buffer(buf_size) + libcrypto.EVP_CipherUpdate(self._ctx, byref(buf), + byref(cipher_out_len), c_char_p(data), l) + # buf is copied to a str object when we access buf.raw + return buf.raw[:cipher_out_len.value] + + def __del__(self): + self.clean() + + def clean(self): + if self._ctx: + libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx) + libcrypto.EVP_CIPHER_CTX_free(self._ctx) + + +ciphers = { + b'aes-128-cfb': (16, 16, CtypesCrypto), + b'aes-192-cfb': (24, 16, CtypesCrypto), + b'aes-256-cfb': (32, 16, CtypesCrypto), + b'aes-128-ofb': (16, 16, CtypesCrypto), + b'aes-192-ofb': (24, 16, CtypesCrypto), + b'aes-256-ofb': (32, 16, CtypesCrypto), + b'aes-128-ctr': (16, 16, CtypesCrypto), + b'aes-192-ctr': (24, 16, CtypesCrypto), + b'aes-256-ctr': (32, 16, CtypesCrypto), + b'aes-128-cfb8': (16, 16, CtypesCrypto), + b'aes-192-cfb8': (24, 16, CtypesCrypto), + b'aes-256-cfb8': (32, 16, CtypesCrypto), + b'aes-128-cfb1': (16, 16, CtypesCrypto), + b'aes-192-cfb1': (24, 16, CtypesCrypto), + b'aes-256-cfb1': (32, 16, CtypesCrypto), + b'bf-cfb': (16, 8, CtypesCrypto), + b'camellia-128-cfb': (16, 16, CtypesCrypto), + b'camellia-192-cfb': (24, 16, CtypesCrypto), + b'camellia-256-cfb': (32, 16, CtypesCrypto), + b'cast5-cfb': (16, 8, CtypesCrypto), + b'des-cfb': (8, 8, CtypesCrypto), + b'idea-cfb': (16, 8, CtypesCrypto), + b'rc2-cfb': (16, 8, CtypesCrypto), + b'rc4': (16, 0, CtypesCrypto), + b'seed-cfb': (16, 16, CtypesCrypto), +} + + +def run_method(method): + from shadowsocks.crypto import util + + cipher = CtypesCrypto(method, b'k' * 32, b'i' * 16, 1) + decipher = CtypesCrypto(method, b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_aes_128_cfb(): + run_method(b'aes-128-cfb') + + +def test_aes_256_cfb(): + run_method(b'aes-256-cfb') + + +def test_aes_128_cfb8(): + run_method(b'aes-128-cfb8') + + +def test_aes_256_ofb(): + run_method(b'aes-256-ofb') + + +def test_aes_256_ctr(): + run_method(b'aes-256-ctr') + + +def test_bf_cfb(): + run_method(b'bf-cfb') + + +def test_rc4(): + run_method(b'rc4') + + +if __name__ == '__main__': + test_aes_128_cfb() diff --git a/shadowsocks/crypto/openssl.py b/shadowsocks/crypto/openssl.py new file mode 100644 index 0000000..e4980a4 --- /dev/null +++ b/shadowsocks/crypto/openssl.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +from ctypes import c_char_p, c_int, c_long, byref, \ + create_string_buffer, c_void_p + +from shadowsocks import common +from shadowsocks.crypto import util + +__all__ = ['ciphers'] + +libcrypto = None +loaded = False + +buf_size = 2048 + +ctx_cleanup = None + + +def load_openssl(): + global loaded, libcrypto, buf, ctx_cleanup + + libcrypto = util.find_library(('crypto', 'eay32'), + 'EVP_get_cipherbyname', + 'libcrypto') + if libcrypto is None: + raise Exception('libcrypto(OpenSSL) not found') + + libcrypto.EVP_get_cipherbyname.restype = c_void_p + libcrypto.EVP_CIPHER_CTX_new.restype = c_void_p + + libcrypto.EVP_CipherInit_ex.argtypes = (c_void_p, c_void_p, c_char_p, + c_char_p, c_char_p, c_int) + + libcrypto.EVP_CipherUpdate.argtypes = (c_void_p, c_void_p, c_void_p, + c_char_p, c_int) + + if hasattr(libcrypto, "EVP_CIPHER_CTX_cleanup"): + libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,) + ctx_cleanup = libcrypto.EVP_CIPHER_CTX_cleanup + else: + libcrypto.EVP_CIPHER_CTX_reset.argtypes = (c_void_p,) + ctx_cleanup = libcrypto.EVP_CIPHER_CTX_reset + libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,) + + libcrypto.RAND_bytes.restype = c_int + libcrypto.RAND_bytes.argtypes = (c_void_p, c_int) + + if hasattr(libcrypto, 'OpenSSL_add_all_ciphers'): + libcrypto.OpenSSL_add_all_ciphers() + + buf = create_string_buffer(buf_size) + loaded = True + + +def load_cipher(cipher_name): + func_name = 'EVP_' + cipher_name.replace('-', '_') + cipher = getattr(libcrypto, func_name, None) + if cipher: + cipher.restype = c_void_p + return cipher() + return None + + +def rand_bytes(length): + if not loaded: + load_openssl() + buf = create_string_buffer(length) + r = libcrypto.RAND_bytes(buf, length) + if r <= 0: + raise Exception('RAND_bytes return error') + return buf.raw + + +class OpenSSLCrypto(object): + def __init__(self, cipher_name, key, iv, op): + self._ctx = None + if not loaded: + load_openssl() + cipher = libcrypto.EVP_get_cipherbyname(common.to_bytes(cipher_name)) + if not cipher: + cipher = load_cipher(cipher_name) + if not cipher: + raise Exception('cipher %s not found in libcrypto' % cipher_name) + key_ptr = c_char_p(key) + iv_ptr = c_char_p(iv) + self._ctx = libcrypto.EVP_CIPHER_CTX_new() + if not self._ctx: + raise Exception('can not create cipher context') + r = libcrypto.EVP_CipherInit_ex(self._ctx, cipher, None, + key_ptr, iv_ptr, c_int(op)) + if not r: + self.clean() + raise Exception('can not initialize cipher context') + + def update(self, data): + global buf_size, buf + cipher_out_len = c_long(0) + l = len(data) + if buf_size < l: + buf_size = l * 2 + buf = create_string_buffer(buf_size) + libcrypto.EVP_CipherUpdate(self._ctx, byref(buf), + byref(cipher_out_len), c_char_p(data), l) + # buf is copied to a str object when we access buf.raw + return buf.raw[:cipher_out_len.value] + + def __del__(self): + self.clean() + + def clean(self): + if self._ctx: + ctx_cleanup(self._ctx) + libcrypto.EVP_CIPHER_CTX_free(self._ctx) + self._ctx = None + + +ciphers = { + # CBC mode need a special use way that different from other. + # CBC mode encrypt message with 16n length, and need 16n+1 length space to decrypt it , otherwise don't decrypt it + 'aes-128-cbc': (16, 16, OpenSSLCrypto), + 'aes-192-cbc': (24, 16, OpenSSLCrypto), + 'aes-256-cbc': (32, 16, OpenSSLCrypto), + 'aes-128-gcm': (16, 16, OpenSSLCrypto), + 'aes-192-gcm': (24, 16, OpenSSLCrypto), + 'aes-256-gcm': (32, 16, OpenSSLCrypto), + 'aes-128-cfb': (16, 16, OpenSSLCrypto), + 'aes-192-cfb': (24, 16, OpenSSLCrypto), + 'aes-256-cfb': (32, 16, OpenSSLCrypto), + 'aes-128-ofb': (16, 16, OpenSSLCrypto), + 'aes-192-ofb': (24, 16, OpenSSLCrypto), + 'aes-256-ofb': (32, 16, OpenSSLCrypto), + 'aes-128-ctr': (16, 16, OpenSSLCrypto), + 'aes-192-ctr': (24, 16, OpenSSLCrypto), + 'aes-256-ctr': (32, 16, OpenSSLCrypto), + 'aes-128-cfb8': (16, 16, OpenSSLCrypto), + 'aes-192-cfb8': (24, 16, OpenSSLCrypto), + 'aes-256-cfb8': (32, 16, OpenSSLCrypto), + 'aes-128-cfb1': (16, 16, OpenSSLCrypto), + 'aes-192-cfb1': (24, 16, OpenSSLCrypto), + 'aes-256-cfb1': (32, 16, OpenSSLCrypto), + 'bf-cfb': (16, 8, OpenSSLCrypto), + 'camellia-128-cfb': (16, 16, OpenSSLCrypto), + 'camellia-192-cfb': (24, 16, OpenSSLCrypto), + 'camellia-256-cfb': (32, 16, OpenSSLCrypto), + 'cast5-cfb': (16, 8, OpenSSLCrypto), + 'des-cfb': (8, 8, OpenSSLCrypto), + 'idea-cfb': (16, 8, OpenSSLCrypto), + 'rc2-cfb': (16, 8, OpenSSLCrypto), + 'rc4': (16, 0, OpenSSLCrypto), + 'seed-cfb': (16, 16, OpenSSLCrypto), +} + + +def run_method(method): + cipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 1) + decipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_aes_128_cfb(): + run_method('aes-128-cfb') + + +def test_aes_256_cfb(): + run_method('aes-256-cfb') + + +def test_aes_128_cfb8(): + run_method('aes-128-cfb8') + + +def test_aes_256_ofb(): + run_method('aes-256-ofb') + + +def test_aes_256_ctr(): + run_method('aes-256-ctr') + + +def test_bf_cfb(): + run_method('bf-cfb') + + +def test_rc4(): + run_method('rc4') + + +def test_all(): + for k, v in ciphers.items(): + print(k) + try: + run_method(k) + except AssertionError as e: + eprint("AssertionError===========" + k) + eprint(e) + + +def eprint(*args, **kwargs): + import sys + print(*args, file=sys.stderr, **kwargs) + + +if __name__ == '__main__': + test_all() diff --git a/shadowsocks/crypto/rc4_md5.py b/shadowsocks/crypto/rc4_md5.py new file mode 100644 index 0000000..b0105ec --- /dev/null +++ b/shadowsocks/crypto/rc4_md5.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import hashlib + +from shadowsocks.crypto import openssl + +__all__ = ['ciphers'] + + +def create_cipher(alg, key, iv, op, key_as_bytes=0, d=None, salt=None, + i=1, padding=1): + md5 = hashlib.md5() + md5.update(key) + md5.update(iv) + rc4_key = md5.digest() + return openssl.OpenSSLCrypto(b'rc4', rc4_key, b'', op) + + +ciphers = { + 'rc4-md5': (16, 16, create_cipher), + 'rc4-md5-6': (16, 6, create_cipher), +} + + +def test(): + from shadowsocks.crypto import util + + cipher = create_cipher('rc4-md5', b'k' * 32, b'i' * 16, 1) + decipher = create_cipher('rc4-md5', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +if __name__ == '__main__': + test() diff --git a/shadowsocks/crypto/sodium.py b/shadowsocks/crypto/sodium.py new file mode 100644 index 0000000..390941c --- /dev/null +++ b/shadowsocks/crypto/sodium.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +from ctypes import c_char_p, c_int, c_ulong, c_ulonglong, byref, \ + create_string_buffer, c_void_p + +import logging + +from shadowsocks.crypto import util + +__all__ = ['ciphers'] + +libsodium = None +loaded = False + +buf_size = 2048 + +# for salsa20 and chacha20 and chacha20-ietf +BLOCK_SIZE = 64 + + +def load_libsodium(): + global loaded, libsodium, buf + + libsodium = util.find_library('sodium', 'crypto_stream_salsa20_xor_ic', + 'libsodium') + if libsodium is None: + raise Exception('libsodium not found') + + libsodium.crypto_stream_salsa20_xor_ic.restype = c_int + libsodium.crypto_stream_salsa20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + libsodium.crypto_stream_chacha20_xor_ic.restype = c_int + libsodium.crypto_stream_chacha20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + + try: + libsodium.crypto_stream_chacha20_ietf_xor_ic.restype = c_int + libsodium.crypto_stream_chacha20_ietf_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulong, + c_char_p) + except: + logging.info("ChaCha20 IETF not support.") + pass + + try: + libsodium.crypto_stream_xsalsa20_xor_ic.restype = c_int + libsodium.crypto_stream_xsalsa20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + except: + logging.info("XSalsa20 not support.") + pass + + try: + libsodium.crypto_stream_xchacha20_xor_ic.restype = c_int + libsodium.crypto_stream_xchacha20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + except: + logging.info("XChaCha20 not support. XChaCha20 only support since libsodium v1.0.12") + pass + + buf = create_string_buffer(buf_size) + loaded = True + + +class SodiumCrypto(object): + def __init__(self, cipher_name, key, iv, op): + if not loaded: + load_libsodium() + self.key = key + self.iv = iv + self.key_ptr = c_char_p(key) + self.iv_ptr = c_char_p(iv) + if cipher_name == 'salsa20': + self.cipher = libsodium.crypto_stream_salsa20_xor_ic + elif cipher_name == 'chacha20': + self.cipher = libsodium.crypto_stream_chacha20_xor_ic + elif cipher_name == 'chacha20-ietf': + self.cipher = libsodium.crypto_stream_chacha20_ietf_xor_ic + elif cipher_name == 'xchacha20': + self.cipher = libsodium.crypto_stream_xchacha20_xor_ic + elif cipher_name == 'xsalsa20': + self.cipher = libsodium.crypto_stream_xsalsa20_xor_ic + else: + raise Exception('Unknown cipher') + # byte counter, not block counter + self.counter = 0 + + def update(self, data): + global buf_size, buf + l = len(data) + + # we can only prepend some padding to make the encryption align to + # blocks + padding = self.counter % BLOCK_SIZE + if buf_size < padding + l: + buf_size = (padding + l) * 2 + buf = create_string_buffer(buf_size) + + if padding: + data = (b'\0' * padding) + data + self.cipher(byref(buf), c_char_p(data), padding + l, + self.iv_ptr, int(self.counter / BLOCK_SIZE), self.key_ptr) + self.counter += l + # buf is copied to a str object when we access buf.raw + # strip off the padding + return buf.raw[padding:padding + l] + + def clean(self): + pass + + +ciphers = { + 'salsa20': (32, 8, SodiumCrypto), + 'chacha20': (32, 8, SodiumCrypto), + 'chacha20-ietf': (32, 12, SodiumCrypto), + 'xchacha20': (32, 24, SodiumCrypto), + 'xsalsa20': (32, 24, SodiumCrypto), +} + + +def test_salsa20(): + cipher = SodiumCrypto('salsa20', b'k' * 32, b'i' * 16, 1) + decipher = SodiumCrypto('salsa20', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_chacha20(): + cipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 1) + decipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_chacha20_ietf(): + cipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 1) + decipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 0) + + util.run_cipher(cipher, decipher) + + +def test_xchacha20(): + cipher = SodiumCrypto('xchacha20', b'k' * 32, b'i' * 24, 1) + decipher = SodiumCrypto('xchacha20', b'k' * 32, b'i' * 24, 0) + + util.run_cipher(cipher, decipher) + + +def test_xsalsa20(): + cipher = SodiumCrypto('xsalsa20', b'k' * 32, b'i' * 24, 1) + decipher = SodiumCrypto('xsalsa20', b'k' * 32, b'i' * 24, 0) + + util.run_cipher(cipher, decipher) + + +if __name__ == '__main__': + test_chacha20_ietf() + test_chacha20() + test_salsa20() + test_xchacha20() + test_xsalsa20() diff --git a/shadowsocks/crypto/table.py b/shadowsocks/crypto/table.py new file mode 100644 index 0000000..c7a6e84 --- /dev/null +++ b/shadowsocks/crypto/table.py @@ -0,0 +1,187 @@ +# !/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import string +import struct +import hashlib + + +__all__ = ['ciphers'] + +cached_tables = {} + +if hasattr(string, 'maketrans'): + maketrans = string.maketrans + translate = string.translate +else: + maketrans = bytes.maketrans + translate = bytes.translate + + +def get_table(key): + m = hashlib.md5() + m.update(key) + s = m.digest() + a, b = struct.unpack(' list + """ + find lib in windows in all the directory in path env + + :param name: can end with `.dll` or not + :return: lib results list + """ + # modified from ctypes.util + # ctypes.util.find_library just returns first result he found + # but we want to try them all + # because on Windows, users may have both 32bit and 64bit version installed + results = [] + for directory in os.environ['PATH'].split(os.pathsep): + fname = os.path.join(directory, name) + if os.path.isfile(fname): + results.append(fname) + if fname.lower().endswith(".dll"): + continue + fname = fname + ".dll" + if os.path.isfile(fname): + results.append(fname) + return results + + +def find_library(possible_lib_names, search_symbol, library_name): + import ctypes.util + from ctypes import CDLL + + paths = [] + + if type(possible_lib_names) not in (list, tuple): + possible_lib_names = [possible_lib_names] + + lib_names = [] + for lib_name in possible_lib_names: + lib_names.append(lib_name) + lib_names.append('lib' + lib_name) + + for name in lib_names: + if os.name == "nt": + paths.extend(find_library_nt(name)) + else: + path = ctypes.util.find_library(name) + if path: + paths.append(path) + + # always find lib on extend path that to avoid ```CDLL()``` failed on some strange linux environment + # in that case ```ctypes.util.find_library()``` have different find path from ```CDLL()``` + if True: + # We may get here when find_library fails because, for example, + # the user does not have sufficient privileges to access those + # tools underlying find_library on linux. + import glob + + for name in lib_names: + patterns = [ + '/usr/local/lib*/lib%s.*' % name, + '/usr/lib*/lib%s.*' % name, + 'lib%s.*' % name, + '%s.dll' % name] + + for pat in patterns: + files = glob.glob(pat) + if files: + paths.extend(files) + for path in paths: + try: + lib = CDLL(path) + if hasattr(lib, search_symbol): + logging.info('loading %s from %s', library_name, path) + return lib + else: + logging.warn('can\'t find symbol %s in %s', search_symbol, + path) + except Exception: + if path == paths[-1]: + raise + return None + + +def run_cipher(cipher, decipher): + from os import urandom + import random + import time + + BLOCK_SIZE = 16384 + rounds = 1 * 1024 + plain = urandom(BLOCK_SIZE * rounds) + + results = [] + pos = 0 + print('test start') + start = time.time() + while pos < len(plain): + l = random.randint(100, 32768) + c = cipher.update(plain[pos:pos + l]) + results.append(c) + pos += l + pos = 0 + c = b''.join(results) + results = [] + while pos < len(plain): + l = random.randint(100, 32768) + results.append(decipher.update(c[pos:pos + l])) + pos += l + end = time.time() + print('speed: %d bytes/s' % (BLOCK_SIZE * rounds / (end - start))) + assert b''.join(results) == plain + + +def test_find_library(): + assert find_library('c', 'strcpy', 'libc') is not None + assert find_library(['c'], 'strcpy', 'libc') is not None + assert find_library(('c',), 'strcpy', 'libc') is not None + assert find_library(('crypto', 'eay32'), 'EVP_CipherUpdate', + 'libcrypto') is not None + assert find_library('notexist', 'strcpy', 'libnotexist') is None + assert find_library('c', 'symbol_not_exist', 'c') is None + assert find_library(('notexist', 'c', 'crypto', 'eay32'), + 'EVP_CipherUpdate', 'libc') is not None + + +if __name__ == '__main__': + test_find_library() diff --git a/shadowsocks/daemon.py b/shadowsocks/daemon.py new file mode 100644 index 0000000..8dc5608 --- /dev/null +++ b/shadowsocks/daemon.py @@ -0,0 +1,208 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2014-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import logging +import signal +import time +from shadowsocks import common, shell + +# this module is ported from ShadowVPN daemon.c + + +def daemon_exec(config): + if 'daemon' in config: + if os.name != 'posix': + raise Exception('daemon mode is only supported on Unix') + command = config['daemon'] + if not command: + command = 'start' + pid_file = config['pid-file'] + log_file = config['log-file'] + if command == 'start': + daemon_start(pid_file, log_file) + elif command == 'stop': + daemon_stop(pid_file) + # always exit after daemon_stop + sys.exit(0) + elif command == 'restart': + daemon_stop(pid_file) + daemon_start(pid_file, log_file) + else: + raise Exception('unsupported daemon command %s' % command) + + +def write_pid_file(pid_file, pid): + import fcntl + import stat + + try: + fd = os.open(pid_file, os.O_RDWR | os.O_CREAT, + stat.S_IRUSR | stat.S_IWUSR) + except OSError as e: + shell.print_exception(e) + return -1 + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + assert flags != -1 + flags |= fcntl.FD_CLOEXEC + r = fcntl.fcntl(fd, fcntl.F_SETFD, flags) + assert r != -1 + # There is no platform independent way to implement fcntl(fd, F_SETLK, &fl) + # via fcntl.fcntl. So use lockf instead + try: + fcntl.lockf(fd, fcntl.LOCK_EX | fcntl.LOCK_NB, 0, 0, os.SEEK_SET) + except IOError: + r = os.read(fd, 32) + if r: + logging.error('already started at pid %s' % common.to_str(r)) + else: + logging.error('already started') + os.close(fd) + return -1 + os.ftruncate(fd, 0) + os.write(fd, common.to_bytes(str(pid))) + return 0 + + +def freopen(f, mode, stream): + oldf = open(f, mode) + oldfd = oldf.fileno() + newfd = stream.fileno() + os.close(newfd) + os.dup2(oldfd, newfd) + + +def daemon_start(pid_file, log_file): + + def handle_exit(signum, _): + if signum == signal.SIGTERM: + sys.exit(0) + sys.exit(1) + + signal.signal(signal.SIGINT, handle_exit) + signal.signal(signal.SIGTERM, handle_exit) + + # fork only once because we are sure parent will exit + pid = os.fork() + assert pid != -1 + + if pid > 0: + # parent waits for its child + time.sleep(5) + sys.exit(0) + + # child signals its parent to exit + ppid = os.getppid() + pid = os.getpid() + if write_pid_file(pid_file, pid) != 0: + os.kill(ppid, signal.SIGINT) + sys.exit(1) + + os.setsid() + signal.signal(signal.SIG_IGN, signal.SIGHUP) + + print('started') + os.kill(ppid, signal.SIGTERM) + + sys.stdin.close() + try: + freopen(log_file, 'a', sys.stdout) + freopen(log_file, 'a', sys.stderr) + except IOError as e: + shell.print_exception(e) + sys.exit(1) + + +def daemon_stop(pid_file): + import errno + try: + with open(pid_file) as f: + buf = f.read() + pid = common.to_str(buf) + if not buf: + logging.error('not running') + except IOError as e: + shell.print_exception(e) + if e.errno == errno.ENOENT: + # always exit 0 if we are sure daemon is not running + logging.error('not running') + return + sys.exit(1) + pid = int(pid) + if pid > 0: + try: + os.kill(pid, signal.SIGTERM) + except OSError as e: + if e.errno == errno.ESRCH: + logging.error('not running') + # always exit 0 if we are sure daemon is not running + return + shell.print_exception(e) + sys.exit(1) + else: + logging.error('pid is not positive: %d', pid) + + # sleep for maximum 10s + for i in range(0, 200): + try: + # query for the pid + os.kill(pid, 0) + except OSError as e: + if e.errno == errno.ESRCH: + break + time.sleep(0.05) + else: + logging.error('timed out when stopping pid %d', pid) + sys.exit(1) + print('stopped') + os.unlink(pid_file) + + +def set_user(username): + if username is None: + return + + import pwd + import grp + + try: + pwrec = pwd.getpwnam(username) + except KeyError: + logging.error('user not found: %s' % username) + raise + user = pwrec[0] + uid = pwrec[2] + gid = pwrec[3] + + cur_uid = os.getuid() + if uid == cur_uid: + return + if cur_uid != 0: + logging.error('can not set user as nonroot user') + # will raise later + + # inspired by supervisor + if hasattr(os, 'setgroups'): + groups = [grprec[2] for grprec in grp.getgrall() if user in grprec[3]] + groups.insert(0, gid) + os.setgroups(groups) + os.setgid(gid) + os.setuid(uid) diff --git a/shadowsocks/encrypt.py b/shadowsocks/encrypt.py new file mode 100644 index 0000000..a9fd02a --- /dev/null +++ b/shadowsocks/encrypt.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python +# +# Copyright 2012-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging + +from shadowsocks import common, lru_cache +from shadowsocks.crypto import rc4_md5, openssl, sodium, table + + +method_supported = {} +method_supported.update(rc4_md5.ciphers) +method_supported.update(openssl.ciphers) +method_supported.update(sodium.ciphers) +method_supported.update(table.ciphers) + + +def random_string(length): + try: + return os.urandom(length) + except NotImplementedError as e: + return openssl.rand_bytes(length) + +cached_keys = lru_cache.LRUCache(timeout=180) + + +def try_cipher(key, method=None): + Encryptor(key, method) + + +def EVP_BytesToKey(password, key_len, iv_len, cache): + # equivalent to OpenSSL's EVP_BytesToKey() with count 1 + # so that we make the same key and iv as nodejs version + cached_key = '%s-%d-%d' % (password, key_len, iv_len) + r = cached_keys.get(cached_key, None) + if r: + return r + m = [] + i = 0 + while len(b''.join(m)) < (key_len + iv_len): + md5 = hashlib.md5() + data = password + if i > 0: + data = m[i - 1] + password + md5.update(data) + m.append(md5.digest()) + i += 1 + ms = b''.join(m) + key = ms[:key_len] + iv = ms[key_len:key_len + iv_len] + if cache: + cached_keys[cached_key] = (key, iv) + cached_keys.sweep() + return key, iv + + +class Encryptor(object): + def __init__(self, key, method, iv = None, cache = False): + self.key = key + self.method = method + self.iv = None + self.iv_sent = False + self.cipher_iv = b'' + self.iv_buf = b'' + self.cipher_key = b'' + self.decipher = None + self.cache = cache + method = method.lower() + self._method_info = self.get_method_info(method) + if self._method_info: + if iv is None or len(iv) != self._method_info[1]: + self.cipher = self.get_cipher(key, method, 1, + random_string(self._method_info[1])) + else: + self.cipher = self.get_cipher(key, method, 1, iv) + else: + logging.error('method %s not supported' % method) + sys.exit(1) + + def get_method_info(self, method): + method = method.lower() + m = method_supported.get(method) + return m + + def iv_len(self): + return len(self.cipher_iv) + + def get_cipher(self, password, method, op, iv): + password = common.to_bytes(password) + m = self._method_info + if m[0] > 0: + key, iv_ = EVP_BytesToKey(password, m[0], m[1], self.cache) + else: + # key_length == 0 indicates we should use the key directly + key, iv = password, b'' + + iv = iv[:m[1]] + if op == 1: + # this iv is for cipher not decipher + self.cipher_iv = iv[:m[1]] + self.cipher_key = key + return m[2](method, key, iv, op) + + def encrypt(self, buf): + if len(buf) == 0: + if not self.iv_sent: + self.iv_sent = True + return self.cipher_iv + return buf + if self.iv_sent: + return self.cipher.update(buf) + else: + self.iv_sent = True + return self.cipher_iv + self.cipher.update(buf) + + def decrypt(self, buf): + if len(buf) == 0: + return buf + if self.decipher is not None: #optimize + return self.decipher.update(buf) + + decipher_iv_len = self._method_info[1] + if len(self.iv_buf) <= decipher_iv_len: + self.iv_buf += buf + if len(self.iv_buf) > decipher_iv_len: + decipher_iv = self.iv_buf[:decipher_iv_len] + self.decipher = self.get_cipher(self.key, self.method, 0, + iv=decipher_iv) + buf = self.iv_buf[decipher_iv_len:] + del self.iv_buf + return self.decipher.update(buf) + else: + return b'' + + def dispose(self): + if self.decipher is not None: + self.decipher.clean() + self.decipher = None + +def encrypt_all(password, method, op, data): + result = [] + method = method.lower() + (key_len, iv_len, m) = method_supported[method] + if key_len > 0: + key, _ = EVP_BytesToKey(password, key_len, iv_len, True) + else: + key = password + if op: + iv = random_string(iv_len) + result.append(iv) + else: + iv = data[:iv_len] + data = data[iv_len:] + cipher = m(method, key, iv, op) + result.append(cipher.update(data)) + return b''.join(result) + +def encrypt_key(password, method): + method = method.lower() + (key_len, iv_len, m) = method_supported[method] + if key_len > 0: + key, _ = EVP_BytesToKey(password, key_len, iv_len, True) + else: + key = password + return key + +def encrypt_iv_len(method): + method = method.lower() + (key_len, iv_len, m) = method_supported[method] + return iv_len + +def encrypt_new_iv(method): + method = method.lower() + (key_len, iv_len, m) = method_supported[method] + return random_string(iv_len) + +def encrypt_all_iv(key, method, op, data, ref_iv): + result = [] + method = method.lower() + (key_len, iv_len, m) = method_supported[method] + if op: + iv = ref_iv[0] + result.append(iv) + else: + iv = data[:iv_len] + data = data[iv_len:] + ref_iv[0] = iv + cipher = m(method, key, iv, op) + result.append(cipher.update(data)) + return b''.join(result) + + +CIPHERS_TO_TEST = [ + 'aes-128-cfb', + 'aes-256-cfb', + 'rc4-md5', + 'salsa20', + 'chacha20', + 'table', +] + + +def test_encryptor(): + from os import urandom + plain = urandom(10240) + for method in CIPHERS_TO_TEST: + logging.warn(method) + encryptor = Encryptor(b'key', method) + decryptor = Encryptor(b'key', method) + cipher = encryptor.encrypt(plain) + plain2 = decryptor.decrypt(cipher) + assert plain == plain2 + + +def test_encrypt_all(): + from os import urandom + plain = urandom(10240) + for method in CIPHERS_TO_TEST: + logging.warn(method) + cipher = encrypt_all(b'key', method, 1, plain) + plain2 = encrypt_all(b'key', method, 0, cipher) + assert plain == plain2 + + +if __name__ == '__main__': + test_encrypt_all() + test_encryptor() diff --git a/shadowsocks/encrypt_test.py b/shadowsocks/encrypt_test.py new file mode 100644 index 0000000..d5e5077 --- /dev/null +++ b/shadowsocks/encrypt_test.py @@ -0,0 +1,51 @@ +from __future__ import absolute_import, division, print_function, \ + with_statement + +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../')) + + +from shadowsocks.crypto import rc4_md5 +from shadowsocks.crypto import openssl +from shadowsocks.crypto import sodium +from shadowsocks.crypto import table + +def run(func): + try: + func() + except: + pass + +def run_n(func, name): + try: + func(name) + except: + pass + +def main(): + print("\n""rc4_md5") + rc4_md5.test() + print("\n""aes-256-cfb") + openssl.test_aes_256_cfb() + print("\n""aes-128-cfb") + openssl.test_aes_128_cfb() + print("\n""bf-cfb") + run(openssl.test_bf_cfb) + print("\n""camellia-128-cfb") + run_n(openssl.run_method, "camellia-128-cfb") + print("\n""cast5-cfb") + run_n(openssl.run_method, "cast5-cfb") + print("\n""idea-cfb") + run_n(openssl.run_method, "idea-cfb") + print("\n""seed-cfb") + run_n(openssl.run_method, "seed-cfb") + print("\n""salsa20") + run(sodium.test_salsa20) + print("\n""chacha20") + run(sodium.test_chacha20) + +if __name__ == '__main__': + main() + diff --git a/shadowsocks/eventloop.py b/shadowsocks/eventloop.py new file mode 100644 index 0000000..341620e --- /dev/null +++ b/shadowsocks/eventloop.py @@ -0,0 +1,258 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2013-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# from ssloop +# https://github.com/clowwindy/ssloop + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import time +import socket +import select +import errno +import logging +from collections import defaultdict + +from shadowsocks import shell + + +__all__ = ['EventLoop', 'POLL_NULL', 'POLL_IN', 'POLL_OUT', 'POLL_ERR', + 'POLL_HUP', 'POLL_NVAL', 'EVENT_NAMES'] + +POLL_NULL = 0x00 +POLL_IN = 0x01 +POLL_OUT = 0x04 +POLL_ERR = 0x08 +POLL_HUP = 0x10 +POLL_NVAL = 0x20 + + +EVENT_NAMES = { + POLL_NULL: 'POLL_NULL', + POLL_IN: 'POLL_IN', + POLL_OUT: 'POLL_OUT', + POLL_ERR: 'POLL_ERR', + POLL_HUP: 'POLL_HUP', + POLL_NVAL: 'POLL_NVAL', +} + +# we check timeouts every TIMEOUT_PRECISION seconds +TIMEOUT_PRECISION = 2 + + +class KqueueLoop(object): + + MAX_EVENTS = 1024 + + def __init__(self): + self._kqueue = select.kqueue() + self._fds = {} + + def _control(self, fd, mode, flags): + events = [] + if mode & POLL_IN: + events.append(select.kevent(fd, select.KQ_FILTER_READ, flags)) + if mode & POLL_OUT: + events.append(select.kevent(fd, select.KQ_FILTER_WRITE, flags)) + for e in events: + self._kqueue.control([e], 0) + + def poll(self, timeout): + if timeout < 0: + timeout = None # kqueue behaviour + events = self._kqueue.control(None, KqueueLoop.MAX_EVENTS, timeout) + results = defaultdict(lambda: POLL_NULL) + for e in events: + fd = e.ident + if e.filter == select.KQ_FILTER_READ: + results[fd] |= POLL_IN + elif e.filter == select.KQ_FILTER_WRITE: + results[fd] |= POLL_OUT + return results.items() + + def register(self, fd, mode): + self._fds[fd] = mode + self._control(fd, mode, select.KQ_EV_ADD) + + def unregister(self, fd): + self._control(fd, self._fds[fd], select.KQ_EV_DELETE) + del self._fds[fd] + + def modify(self, fd, mode): + self.unregister(fd) + self.register(fd, mode) + + def close(self): + self._kqueue.close() + + +class SelectLoop(object): + + def __init__(self): + self._r_list = set() + self._w_list = set() + self._x_list = set() + + def poll(self, timeout): + r, w, x = select.select(self._r_list, self._w_list, self._x_list, + timeout) + results = defaultdict(lambda: POLL_NULL) + for p in [(r, POLL_IN), (w, POLL_OUT), (x, POLL_ERR)]: + for fd in p[0]: + results[fd] |= p[1] + return results.items() + + def register(self, fd, mode): + if mode & POLL_IN: + self._r_list.add(fd) + if mode & POLL_OUT: + self._w_list.add(fd) + if mode & POLL_ERR: + self._x_list.add(fd) + + def unregister(self, fd): + if fd in self._r_list: + self._r_list.remove(fd) + if fd in self._w_list: + self._w_list.remove(fd) + if fd in self._x_list: + self._x_list.remove(fd) + + def modify(self, fd, mode): + self.unregister(fd) + self.register(fd, mode) + + def close(self): + pass + + +class EventLoop(object): + def __init__(self): + if hasattr(select, 'epoll'): + self._impl = select.epoll() + model = 'epoll' + elif hasattr(select, 'kqueue'): + self._impl = KqueueLoop() + model = 'kqueue' + elif hasattr(select, 'select'): + self._impl = SelectLoop() + model = 'select' + else: + raise Exception('can not find any available functions in select ' + 'package') + self._fdmap = {} # (f, handler) + self._last_time = time.time() + self._periodic_callbacks = [] + self._stopping = False + logging.debug('using event model: %s', model) + + def poll(self, timeout=None): + events = self._impl.poll(timeout) + return [(self._fdmap[fd][0], fd, event) for fd, event in events] + + def add(self, f, mode, handler): + fd = f.fileno() + self._fdmap[fd] = (f, handler) + self._impl.register(fd, mode) + + def remove(self, f): + fd = f.fileno() + del self._fdmap[fd] + self._impl.unregister(fd) + + def removefd(self, fd): + del self._fdmap[fd] + self._impl.unregister(fd) + + def add_periodic(self, callback): + self._periodic_callbacks.append(callback) + + def remove_periodic(self, callback): + self._periodic_callbacks.remove(callback) + + def modify(self, f, mode): + fd = f.fileno() + self._impl.modify(fd, mode) + + def stop(self): + self._stopping = True + + def run(self): + events = [] + while not self._stopping: + asap = False + try: + events = self.poll(TIMEOUT_PRECISION) + except (OSError, IOError) as e: + if errno_from_exception(e) in (errno.EPIPE, errno.EINTR): + # EPIPE: Happens when the client closes the connection + # EINTR: Happens when received a signal + # handles them as soon as possible + asap = True + logging.debug('poll:%s', e) + else: + logging.error('poll:%s', e) + import traceback + traceback.print_exc() + continue + + handle = False + for sock, fd, event in events: + handler = self._fdmap.get(fd, None) + if handler is not None: + handler = handler[1] + try: + handle = handler.handle_event(sock, fd, event) or handle + except (OSError, IOError) as e: + shell.print_exception(e) + now = time.time() + if asap or now - self._last_time >= TIMEOUT_PRECISION: + for callback in self._periodic_callbacks: + callback() + self._last_time = now + if events and not handle: + time.sleep(0.001) + + def __del__(self): + self._impl.close() + + +# from tornado +def errno_from_exception(e): + """Provides the errno from an Exception object. + + There are cases that the errno attribute was not set so we pull + the errno out of the args but if someone instatiates an Exception + without any args you will get a tuple error. So this function + abstracts all that behavior to give you a safe way to get the + errno. + """ + + if hasattr(e, 'errno'): + return e.errno + elif e.args: + return e.args[0] + else: + return None + + +# from tornado +def get_sock_error(sock): + error_number = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + return socket.error(error_number, os.strerror(error_number)) diff --git a/shadowsocks/local.py b/shadowsocks/local.py new file mode 100644 index 0000000..9f54d93 --- /dev/null +++ b/shadowsocks/local.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import sys +import os +import logging +import signal + +if __name__ == '__main__': + import inspect + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) + sys.path.insert(0, os.path.join(file_path, '../')) + +from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, asyncdns + + +def main(): + shell.check_python() + + # fix py2exe + if hasattr(sys, "frozen") and sys.frozen in \ + ("windows_exe", "console_exe"): + p = os.path.dirname(os.path.abspath(sys.executable)) + os.chdir(p) + + config = shell.get_config(True) + + if not config.get('dns_ipv6', False): + asyncdns.IPV6_CONNECTION_SUPPORT = False + + daemon.daemon_exec(config) + logging.info("local start with protocol[%s] password [%s] method [%s] obfs [%s] obfs_param [%s]" % + (config['protocol'], config['password'], config['method'], config['obfs'], config['obfs_param'])) + + try: + logging.info("starting local at %s:%d" % + (config['local_address'], config['local_port'])) + + dns_resolver = asyncdns.DNSResolver() + tcp_server = tcprelay.TCPRelay(config, dns_resolver, True) + udp_server = udprelay.UDPRelay(config, dns_resolver, True) + loop = eventloop.EventLoop() + dns_resolver.add_to_loop(loop) + tcp_server.add_to_loop(loop) + udp_server.add_to_loop(loop) + + def handler(signum, _): + logging.warn('received SIGQUIT, doing graceful shutting down..') + tcp_server.close(next_tick=True) + udp_server.close(next_tick=True) + signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), handler) + + def int_handler(signum, _): + sys.exit(1) + signal.signal(signal.SIGINT, int_handler) + + daemon.set_user(config.get('user', None)) + loop.run() + except Exception as e: + shell.print_exception(e) + sys.exit(1) + +if __name__ == '__main__': + main() diff --git a/shadowsocks/logrun.sh b/shadowsocks/logrun.sh new file mode 100644 index 0000000..fc081e1 --- /dev/null +++ b/shadowsocks/logrun.sh @@ -0,0 +1,8 @@ +#!/bin/bash +cd `dirname $0` +python_ver=$(ls /usr/bin|grep -e "^python[23]\.[1-9]\+$"|tail -1) +eval $(ps -ef | grep "[0-9] ${python_ver} server\\.py a" | awk '{print "kill "$2}') +ulimit -n 512000 +nohup ${python_ver} server.py a>> ssserver.log 2>&1 & + + diff --git a/shadowsocks/lru_cache.py b/shadowsocks/lru_cache.py new file mode 100644 index 0000000..ab0d210 --- /dev/null +++ b/shadowsocks/lru_cache.py @@ -0,0 +1,179 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import collections +import logging +import time + +if __name__ == '__main__': + import os, sys, inspect + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) + sys.path.insert(0, os.path.join(file_path, '../')) + +try: + from collections import OrderedDict +except: + from shadowsocks.ordereddict import OrderedDict + +# this LRUCache is optimized for concurrency, not QPS +# n: concurrency, keys stored in the cache +# m: visits not timed out, proportional to QPS * timeout +# get & set is O(1), not O(n). thus we can support very large n +# sweep is O((n - m)) or O(1024) at most, +# no metter how large the cache or timeout value is + +SWEEP_MAX_ITEMS = 1024 + +class LRUCache(collections.MutableMapping): + """This class is not thread safe""" + + def __init__(self, timeout=60, close_callback=None, *args, **kwargs): + self.timeout = timeout + self.close_callback = close_callback + self._store = {} + self._keys_to_last_time = OrderedDict() + self.update(dict(*args, **kwargs)) # use the free update to set keys + + def __getitem__(self, key): + # O(1) + t = time.time() + last_t = self._keys_to_last_time[key] + del self._keys_to_last_time[key] + self._keys_to_last_time[key] = t + return self._store[key] + + def __setitem__(self, key, value): + # O(1) + t = time.time() + if key in self._keys_to_last_time: + del self._keys_to_last_time[key] + self._keys_to_last_time[key] = t + self._store[key] = value + + def __delitem__(self, key): + # O(1) + last_t = self._keys_to_last_time[key] + del self._store[key] + del self._keys_to_last_time[key] + + def __contains__(self, key): + return key in self._store + + def __iter__(self): + return iter(self._store) + + def __len__(self): + return len(self._store) + + def first(self): + if len(self._keys_to_last_time) > 0: + for key in self._keys_to_last_time: + return key + + def sweep(self, sweep_item_cnt = SWEEP_MAX_ITEMS): + # O(n - m) + now = time.time() + c = 0 + while c < sweep_item_cnt: + if len(self._keys_to_last_time) == 0: + break + for key in self._keys_to_last_time: + break + last_t = self._keys_to_last_time[key] + if now - last_t <= self.timeout: + break + value = self._store[key] + del self._store[key] + del self._keys_to_last_time[key] + if self.close_callback is not None: + self.close_callback(value) + c += 1 + if c: + logging.debug('%d keys swept' % c) + return c < SWEEP_MAX_ITEMS + + def clear(self, keep): + now = time.time() + c = 0 + while len(self._keys_to_last_time) > keep: + if len(self._keys_to_last_time) == 0: + break + for key in self._keys_to_last_time: + break + last_t = self._keys_to_last_time[key] + value = self._store[key] + if self.close_callback is not None: + self.close_callback(value) + del self._store[key] + del self._keys_to_last_time[key] + c += 1 + if c: + logging.debug('%d keys swept' % c) + return c < SWEEP_MAX_ITEMS + +def test(): + c = LRUCache(timeout=0.3) + + c['a'] = 1 + assert c['a'] == 1 + c['a'] = 1 + + time.sleep(0.5) + c.sweep() + assert 'a' not in c + + c['a'] = 2 + c['b'] = 3 + time.sleep(0.2) + c.sweep() + assert c['a'] == 2 + assert c['b'] == 3 + + time.sleep(0.2) + c.sweep() + c['b'] + time.sleep(0.2) + c.sweep() + assert 'a' not in c + assert c['b'] == 3 + + time.sleep(0.5) + c.sweep() + assert 'a' not in c + assert 'b' not in c + + global close_cb_called + close_cb_called = False + + def close_cb(t): + global close_cb_called + assert not close_cb_called + close_cb_called = True + + c = LRUCache(timeout=0.1, close_callback=close_cb) + c['s'] = 1 + c['s'] + time.sleep(0.1) + c['s'] + time.sleep(0.3) + c.sweep() + +if __name__ == '__main__': + test() diff --git a/shadowsocks/manager.py b/shadowsocks/manager.py new file mode 100644 index 0000000..80d0a32 --- /dev/null +++ b/shadowsocks/manager.py @@ -0,0 +1,291 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import errno +import traceback +import socket +import logging +import json +import collections + +from shadowsocks import common, eventloop, tcprelay, udprelay, asyncdns, shell + + +BUF_SIZE = 1506 +STAT_SEND_LIMIT = 50 + + +class Manager(object): + + def __init__(self, config): + self._config = config + self._relays = {} # (tcprelay, udprelay) + self._loop = eventloop.EventLoop() + self._dns_resolver = asyncdns.DNSResolver() + self._dns_resolver.add_to_loop(self._loop) + + self._statistics = collections.defaultdict(int) + self._control_client_addr = None + try: + manager_address = common.to_str(config['manager_address']) + if ':' in manager_address: + addr = manager_address.rsplit(':', 1) + addr = addr[0], int(addr[1]) + addrs = socket.getaddrinfo(addr[0], addr[1]) + if addrs: + family = addrs[0][0] + else: + logging.error('invalid address: %s', manager_address) + exit(1) + else: + addr = manager_address + family = socket.AF_UNIX + self._control_socket = socket.socket(family, + socket.SOCK_DGRAM) + self._control_socket.bind(addr) + self._control_socket.setblocking(False) + except (OSError, IOError) as e: + logging.error(e) + logging.error('can not bind to manager address') + exit(1) + self._loop.add(self._control_socket, + eventloop.POLL_IN, self) + self._loop.add_periodic(self.handle_periodic) + + port_password = config['port_password'] + del config['port_password'] + for port, password in port_password.items(): + a_config = config.copy() + a_config['server_port'] = int(port) + a_config['password'] = password + self.add_port(a_config) + + def add_port(self, config): + port = int(config['server_port']) + servers = self._relays.get(port, None) + if servers: + logging.error("server already exists at %s:%d" % (config['server'], + port)) + return + logging.info("adding server at %s:%d" % (config['server'], port)) + t = tcprelay.TCPRelay(config, self._dns_resolver, False, + stat_callback=self.stat_callback) + u = udprelay.UDPRelay(config, self._dns_resolver, False, + stat_callback=self.stat_callback) + t.add_to_loop(self._loop) + u.add_to_loop(self._loop) + self._relays[port] = (t, u) + + def remove_port(self, config): + port = int(config['server_port']) + servers = self._relays.get(port, None) + if servers: + logging.info("removing server at %s:%d" % (config['server'], port)) + t, u = servers + t.close(next_tick=False) + u.close(next_tick=False) + del self._relays[port] + else: + logging.error("server not exist at %s:%d" % (config['server'], + port)) + + def handle_event(self, sock, fd, event): + if sock == self._control_socket and event == eventloop.POLL_IN: + data, self._control_client_addr = sock.recvfrom(BUF_SIZE) + parsed = self._parse_command(data) + if parsed: + command, config = parsed + a_config = self._config.copy() + if config: + # let the command override the configuration file + a_config.update(config) + if 'server_port' not in a_config: + logging.error('can not find server_port in config') + else: + if command == 'add': + self.add_port(a_config) + self._send_control_data(b'ok') + elif command == 'remove': + self.remove_port(a_config) + self._send_control_data(b'ok') + elif command == 'ping': + self._send_control_data(b'pong') + else: + logging.error('unknown command %s', command) + + def _parse_command(self, data): + # commands: + # add: {"server_port": 8000, "password": "foobar"} + # remove: {"server_port": 8000"} + data = common.to_str(data) + parts = data.split(':', 1) + if len(parts) < 2: + return data, None + command, config_json = parts + try: + config = shell.parse_json_in_str(config_json) + return command, config + except Exception as e: + logging.error(e) + return None + + def stat_callback(self, port, data_len): + self._statistics[port] += data_len + + def handle_periodic(self): + r = {} + i = 0 + + def send_data(data_dict): + if data_dict: + # use compact JSON format (without space) + data = common.to_bytes(json.dumps(data_dict, + separators=(',', ':'))) + self._send_control_data(b'stat: ' + data) + + for k, v in self._statistics.items(): + r[k] = v + i += 1 + # split the data into segments that fit in UDP packets + if i >= STAT_SEND_LIMIT: + send_data(r) + r.clear() + i = 0 + if len(r) > 0 : + send_data(r) + self._statistics.clear() + + def _send_control_data(self, data): + if self._control_client_addr: + try: + self._control_socket.sendto(data, self._control_client_addr) + except (socket.error, OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + return + else: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + + def run(self): + self._loop.run() + + +def run(config): + Manager(config).run() + + +def test(): + import time + import threading + import struct + from shadowsocks import encrypt + + logging.basicConfig(level=5, + format='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + enc = [] + eventloop.TIMEOUT_PRECISION = 1 + + def run_server(): + config = shell.get_config(True) + config = config.copy() + a_config = { + 'server': '127.0.0.1', + 'local_port': 1081, + 'port_password': { + '8381': 'foobar1', + '8382': 'foobar2' + }, + 'method': 'aes-256-cfb', + 'manager_address': '127.0.0.1:6001', + 'timeout': 60, + 'fast_open': False, + 'verbose': 2 + } + config.update(a_config) + manager = Manager(config) + enc.append(manager) + manager.run() + + t = threading.Thread(target=run_server) + t.start() + time.sleep(1) + manager = enc[0] + cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + cli.connect(('127.0.0.1', 6001)) + + # test add and remove + time.sleep(1) + cli.send(b'add: {"server_port":7001, "password":"asdfadsfasdf"}') + time.sleep(1) + assert 7001 in manager._relays + data, addr = cli.recvfrom(1506) + assert b'ok' in data + + cli.send(b'remove: {"server_port":8381}') + time.sleep(1) + assert 8381 not in manager._relays + data, addr = cli.recvfrom(1506) + assert b'ok' in data + logging.info('add and remove test passed') + + # test statistics for TCP + header = common.pack_addr(b'google.com') + struct.pack('>H', 80) + data = encrypt.encrypt_all(b'asdfadsfasdf', 'aes-256-cfb', 1, + header + b'GET /\r\n\r\n') + tcp_cli = socket.socket() + tcp_cli.connect(('127.0.0.1', 7001)) + tcp_cli.send(data) + tcp_cli.recv(4096) + tcp_cli.close() + + data, addr = cli.recvfrom(1506) + data = common.to_str(data) + assert data.startswith('stat: ') + data = data.split('stat:')[1] + stats = shell.parse_json_in_str(data) + assert '7001' in stats + logging.info('TCP statistics test passed') + + # test statistics for UDP + header = common.pack_addr(b'127.0.0.1') + struct.pack('>H', 80) + data = encrypt.encrypt_all(b'foobar2', 'aes-256-cfb', 1, + header + b'test') + udp_cli = socket.socket(type=socket.SOCK_DGRAM) + udp_cli.sendto(data, ('127.0.0.1', 8382)) + tcp_cli.close() + + data, addr = cli.recvfrom(1506) + data = common.to_str(data) + assert data.startswith('stat: ') + data = data.split('stat:')[1] + stats = json.loads(data) + assert '8382' in stats + logging.info('UDP statistics test passed') + + manager._loop.stop() + t.join() + + +if __name__ == '__main__': + test() diff --git a/shadowsocks/obfs.py b/shadowsocks/obfs.py new file mode 100644 index 0000000..499ca1b --- /dev/null +++ b/shadowsocks/obfs.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging + +from shadowsocks import common +from shadowsocks.obfsplugin import plain, http_simple, obfs_tls, verify, auth, auth_chain + + +method_supported = {} +method_supported.update(plain.obfs_map) +method_supported.update(http_simple.obfs_map) +method_supported.update(obfs_tls.obfs_map) +method_supported.update(verify.obfs_map) +method_supported.update(auth.obfs_map) +method_supported.update(auth_chain.obfs_map) + +def mu_protocol(): + return ["auth_aes128_md5", "auth_aes128_sha1", "auth_chain_a", "auth_chain_b", "auth_chain_c", "auth_chain_d", "auth_chain_e"] + +class server_info(object): + def __init__(self, data): + self.data = data + +class obfs(object): + def __init__(self, method): + method = common.to_str(method) + self.method = method + self._method_info = self.get_method_info(method) + if self._method_info: + self.obfs = self.get_obfs(method) + else: + raise Exception('obfs plugin [%s] not supported' % method) + + def init_data(self): + return self.obfs.init_data() + + def set_server_info(self, server_info): + return self.obfs.set_server_info(server_info) + + def get_server_info(self): + return self.obfs.get_server_info() + + def get_method_info(self, method): + method = method.lower() + m = method_supported.get(method) + return m + + def get_obfs(self, method): + m = self._method_info + return m[0](method) + + def get_overhead(self, direction): + return self.obfs.get_overhead(direction) + + def client_pre_encrypt(self, buf): + return self.obfs.client_pre_encrypt(buf) + + def client_encode(self, buf): + return self.obfs.client_encode(buf) + + def client_decode(self, buf): + return self.obfs.client_decode(buf) + + def client_post_decrypt(self, buf): + return self.obfs.client_post_decrypt(buf) + + def server_pre_encrypt(self, buf): + return self.obfs.server_pre_encrypt(buf) + + def server_encode(self, buf): + return self.obfs.server_encode(buf) + + def server_decode(self, buf): + return self.obfs.server_decode(buf) + + def server_post_decrypt(self, buf): + return self.obfs.server_post_decrypt(buf) + + def client_udp_pre_encrypt(self, buf): + return self.obfs.client_udp_pre_encrypt(buf) + + def client_udp_post_decrypt(self, buf): + return self.obfs.client_udp_post_decrypt(buf) + + def server_udp_pre_encrypt(self, buf, uid): + return self.obfs.server_udp_pre_encrypt(buf, uid) + + def server_udp_post_decrypt(self, buf): + return self.obfs.server_udp_post_decrypt(buf) + + def dispose(self): + self.obfs.dispose() + del self.obfs + diff --git a/shadowsocks/obfsplugin/__init__.py b/shadowsocks/obfsplugin/__init__.py new file mode 100644 index 0000000..401c7b7 --- /dev/null +++ b/shadowsocks/obfsplugin/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement diff --git a/shadowsocks/obfsplugin/__pycache__/__init__.cpython-37.pyc b/shadowsocks/obfsplugin/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..bfd0baa Binary files /dev/null and b/shadowsocks/obfsplugin/__pycache__/__init__.cpython-37.pyc differ diff --git a/shadowsocks/obfsplugin/__pycache__/auth.cpython-37.pyc b/shadowsocks/obfsplugin/__pycache__/auth.cpython-37.pyc new file mode 100644 index 0000000..c7dbe08 Binary files /dev/null and b/shadowsocks/obfsplugin/__pycache__/auth.cpython-37.pyc differ diff --git a/shadowsocks/obfsplugin/__pycache__/auth_chain.cpython-37.pyc b/shadowsocks/obfsplugin/__pycache__/auth_chain.cpython-37.pyc new file mode 100644 index 0000000..0aa1dbd Binary files /dev/null and b/shadowsocks/obfsplugin/__pycache__/auth_chain.cpython-37.pyc differ diff --git a/shadowsocks/obfsplugin/__pycache__/http_simple.cpython-37.pyc b/shadowsocks/obfsplugin/__pycache__/http_simple.cpython-37.pyc new file mode 100644 index 0000000..4d2716f Binary files /dev/null and b/shadowsocks/obfsplugin/__pycache__/http_simple.cpython-37.pyc differ diff --git a/shadowsocks/obfsplugin/__pycache__/obfs_tls.cpython-37.pyc b/shadowsocks/obfsplugin/__pycache__/obfs_tls.cpython-37.pyc new file mode 100644 index 0000000..03b3ade Binary files /dev/null and b/shadowsocks/obfsplugin/__pycache__/obfs_tls.cpython-37.pyc differ diff --git a/shadowsocks/obfsplugin/__pycache__/plain.cpython-37.pyc b/shadowsocks/obfsplugin/__pycache__/plain.cpython-37.pyc new file mode 100644 index 0000000..44536db Binary files /dev/null and b/shadowsocks/obfsplugin/__pycache__/plain.cpython-37.pyc differ diff --git a/shadowsocks/obfsplugin/__pycache__/verify.cpython-37.pyc b/shadowsocks/obfsplugin/__pycache__/verify.cpython-37.pyc new file mode 100644 index 0000000..c027333 Binary files /dev/null and b/shadowsocks/obfsplugin/__pycache__/verify.cpython-37.pyc differ diff --git a/shadowsocks/obfsplugin/auth.py b/shadowsocks/obfsplugin/auth.py new file mode 100644 index 0000000..a745e09 --- /dev/null +++ b/shadowsocks/obfsplugin/auth.py @@ -0,0 +1,787 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging +import binascii +import base64 +import time +import datetime +import random +import math +import struct +import zlib +import hmac +import hashlib + +import shadowsocks +from shadowsocks import common, lru_cache, encrypt +from shadowsocks.obfsplugin import plain +from shadowsocks.common import to_bytes, to_str, ord, chr + +def create_auth_sha1_v4(method): + return auth_sha1_v4(method) + +def create_auth_aes128_md5(method): + return auth_aes128_sha1(method, hashlib.md5) + +def create_auth_aes128_sha1(method): + return auth_aes128_sha1(method, hashlib.sha1) + +obfs_map = { + 'auth_sha1_v4': (create_auth_sha1_v4,), + 'auth_sha1_v4_compatible': (create_auth_sha1_v4,), + 'auth_aes128_md5': (create_auth_aes128_md5,), + 'auth_aes128_sha1': (create_auth_aes128_sha1,), +} + +def match_begin(str1, str2): + if len(str1) >= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class auth_base(plain.plain): + def __init__(self, method): + super(auth_base, self).__init__(method) + self.method = method + self.no_compatible_method = '' + self.overhead = 7 + + def init_data(self): + return '' + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + return (buf, False) + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + return (buf, True, False) + + def not_match_return(self, buf): + self.raw_trans = True + self.overhead = 0 + if self.method == self.no_compatible_method: + return (b'E'*2048, False) + return (buf, False) + +class client_queue(object): + def __init__(self, begin_id): + self.front = begin_id - 64 + self.back = begin_id + 1 + self.alloc = {} + self.enable = True + self.last_update = time.time() + + def update(self): + self.last_update = time.time() + + def is_active(self): + return time.time() - self.last_update < 60 * 3 + + def re_enable(self, connection_id): + self.enable = True + self.front = connection_id - 64 + self.back = connection_id + 1 + self.alloc = {} + + def insert(self, connection_id): + if not self.enable: + logging.warn('obfs auth: not enable') + return False + if not self.is_active(): + self.re_enable(connection_id) + self.update() + if connection_id < self.front: + logging.warn('obfs auth: deprecated id, someone replay attack') + return False + if connection_id > self.front + 0x4000: + logging.warn('obfs auth: wrong id') + return False + if connection_id in self.alloc: + logging.warn('obfs auth: duplicate id, someone replay attack') + return False + if self.back <= connection_id: + self.back = connection_id + 1 + self.alloc[connection_id] = 1 + while (self.front in self.alloc) or self.front + 0x1000 < self.back: + if self.front in self.alloc: + del self.alloc[self.front] + self.front += 1 + return True + +class obfs_auth_v2_data(object): + def __init__(self): + self.client_id = lru_cache.LRUCache() + self.local_client_id = b'' + self.connection_id = 0 + self.set_max_client(64) # max active client count + + def update(self, client_id, connection_id): + if client_id in self.client_id: + self.client_id[client_id].update() + + def set_max_client(self, max_client): + self.max_client = max_client + self.max_buffer = max(self.max_client * 2, 1024) + + def insert(self, client_id, connection_id): + if self.client_id.get(client_id, None) is None or not self.client_id[client_id].enable: + if self.client_id.first() is None or len(self.client_id) < self.max_client: + if client_id not in self.client_id: + #TODO: check + self.client_id[client_id] = client_queue(connection_id) + else: + self.client_id[client_id].re_enable(connection_id) + return self.client_id[client_id].insert(connection_id) + + if not self.client_id[self.client_id.first()].is_active(): + del self.client_id[self.client_id.first()] + if client_id not in self.client_id: + #TODO: check + self.client_id[client_id] = client_queue(connection_id) + else: + self.client_id[client_id].re_enable(connection_id) + return self.client_id[client_id].insert(connection_id) + + logging.warn('auth_sha1_v2: no inactive client') + return False + else: + return self.client_id[client_id].insert(connection_id) + +class auth_sha1_v4(auth_base): + def __init__(self, method): + super(auth_sha1_v4, self).__init__(method) + self.recv_buf = b'' + self.unit_len = 8100 + self.decrypt_packet_num = 0 + self.raw_trans = False + self.has_sent_header = False + self.has_recv_header = False + self.client_id = 0 + self.connection_id = 0 + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.salt = b"auth_sha1_v4" + self.no_compatible_method = 'auth_sha1_v4' + + def init_data(self): + return obfs_auth_v2_data() + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + + def rnd_data(self, buf_size): + if buf_size > 1200: + return b'\x01' + + if buf_size > 400: + rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 256) + else: + rnd_data = os.urandom(struct.unpack('>H', os.urandom(2))[0] % 512) + + if len(rnd_data) < 128: + return common.chr(len(rnd_data) + 1) + rnd_data + else: + return common.chr(255) + struct.pack('>H', len(rnd_data) + 3) + rnd_data + + def pack_data(self, buf): + data = self.rnd_data(len(buf)) + buf + data_len = len(data) + 8 + crc = binascii.crc32(struct.pack('>H', data_len)) & 0xFFFF + data = struct.pack('H', data_len) + data + adler32 = zlib.adler32(data) & 0xFFFFFFFF + data += struct.pack('H', data_len) + self.salt + self.server_info.key) & 0xFFFFFFFF + data = struct.pack('H', data_len) + data + data += hmac.new(self.server_info.iv + self.server_info.key, data, hashlib.sha1).digest()[:10] + return data + + def auth_data(self): + utc_time = int(time.time()) & 0xFFFFFFFF + if self.server_info.data.connection_id > 0xFF000000: + self.server_info.data.local_client_id = b'' + if not self.server_info.data.local_client_id: + self.server_info.data.local_client_id = os.urandom(4) + logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) + self.server_info.data.connection_id = struct.unpack(' self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 4: + crc = struct.pack('H', self.recv_buf[:2])[0] + if length >= 8192 or length < 7: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data error') + if length > len(self.recv_buf): + break + + if struct.pack('H', self.recv_buf[5:7])[0] + 4 + out_buf += self.recv_buf[pos:length - 4] + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return out_buf + + def server_pre_encrypt(self, buf): + if self.raw_trans: + return buf + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return (buf, False) + self.recv_buf += buf + out_buf = b'' + sendback = False + + if not self.has_recv_header: + if len(self.recv_buf) <= 6: + return (b'', False) + crc = struct.pack('H', self.recv_buf[:2])[0] + if length > len(self.recv_buf): + return (b'', False) + sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, self.recv_buf[:length - 10], hashlib.sha1).digest()[:10] + if sha1data != self.recv_buf[length - 10:length]: + logging.error('auth_sha1_v4 data uncorrect auth HMAC-SHA1') + return self.not_match_return(self.recv_buf) + pos = common.ord(self.recv_buf[6]) + if pos < 255: + pos += 6 + else: + pos = struct.unpack('>H', self.recv_buf[7:9])[0] + 6 + out_buf = self.recv_buf[pos:length - 10] + if len(out_buf) < 12: + logging.info('auth_sha1_v4: too short, data %s' % (binascii.hexlify(self.recv_buf),)) + return self.not_match_return(self.recv_buf) + utc_time = struct.unpack(' self.max_time_dif: + logging.info('auth_sha1_v4: wrong timestamp, time_dif %d, data %s' % (time_dif, binascii.hexlify(out_buf),)) + return self.not_match_return(self.recv_buf) + elif self.server_info.data.insert(client_id, connection_id): + self.has_recv_header = True + out_buf = out_buf[12:] + self.client_id = client_id + self.connection_id = connection_id + else: + logging.info('auth_sha1_v4: auth fail, data %s' % (binascii.hexlify(out_buf),)) + return self.not_match_return(self.recv_buf) + self.recv_buf = self.recv_buf[length:] + self.has_recv_header = True + sendback = True + + while len(self.recv_buf) > 4: + crc = struct.pack('H', self.recv_buf[:2])[0] + if length >= 8192 or length < 7: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + logging.info('auth_sha1_v4: over size') + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + if struct.pack('H', self.recv_buf[5:7])[0] + 4 + out_buf += self.recv_buf[pos:length - 4] + self.recv_buf = self.recv_buf[length:] + if pos == length - 4: + sendback = True + + if out_buf: + self.server_info.data.update(self.client_id, self.connection_id) + self.decrypt_packet_num += 1 + return (out_buf, sendback) + +class obfs_auth_mu_data(object): + def __init__(self): + self.user_id = {} + self.local_client_id = b'' + self.connection_id = 0 + self.set_max_client(64) # max active client count + + def update(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if client_id in local_client_id: + local_client_id[client_id].update() + + def set_max_client(self, max_client): + self.max_client = max_client + self.max_buffer = max(self.max_client * 2, 1024) + + def insert(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: + if local_client_id.first() is None or len(local_client_id) < self.max_client: + if client_id not in local_client_id: + #TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + if not local_client_id[local_client_id.first()].is_active(): + del local_client_id[local_client_id.first()] + if client_id not in local_client_id: + #TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + logging.warn('auth_aes128: no inactive client') + return False + else: + return local_client_id[client_id].insert(connection_id) + +class auth_aes128_sha1(auth_base): + def __init__(self, method, hashfunc): + super(auth_aes128_sha1, self).__init__(method) + self.hashfunc = hashfunc + self.recv_buf = b'' + self.unit_len = 8100 + self.raw_trans = False + self.has_sent_header = False + self.has_recv_header = False + self.client_id = 0 + self.connection_id = 0 + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.salt = hashfunc == hashlib.md5 and b"auth_aes128_md5" or b"auth_aes128_sha1" + self.no_compatible_method = hashfunc == hashlib.md5 and "auth_aes128_md5" or 'auth_aes128_sha1' + self.extra_wait_size = struct.unpack('>H', os.urandom(2))[0] % 1024 + self.pack_id = 1 + self.recv_id = 1 + self.user_id = None + self.user_key = None + self.last_rnd_len = 0 + self.overhead = 9 + + def init_data(self): + return obfs_auth_mu_data() + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + + def trapezoid_random_float(self, d): + if d == 0: + return random.random() + s = random.random() + a = 1 - d + return (math.sqrt(a * a + 4 * d * s) - a) / (2 * d) + + def trapezoid_random_int(self, max_val, d): + v = self.trapezoid_random_float(d) + return int(v * max_val) + + def rnd_data_len(self, buf_size, full_buf_size): + if full_buf_size >= self.server_info.buffer_size: + return 0 + tcp_mss = self.server_info.tcp_mss + rev_len = tcp_mss - buf_size - 9 + if rev_len == 0: + return 0 + if rev_len < 0: + if rev_len > -tcp_mss: + return self.trapezoid_random_int(rev_len + tcp_mss, -0.3) + return common.ord(os.urandom(1)[0]) % 32 + if buf_size > 900: + return struct.unpack('>H', os.urandom(2))[0] % rev_len + return self.trapezoid_random_int(rev_len, -0.3) + + def rnd_data(self, buf_size, full_buf_size): + data_len = self.rnd_data_len(buf_size, full_buf_size) + + if data_len < 128: + return common.chr(data_len + 1) + os.urandom(data_len) + + return common.chr(255) + struct.pack(' 400: + rnd_len = struct.unpack(' 0xFF000000: + self.server_info.data.local_client_id = b'' + if not self.server_info.data.local_client_id: + self.server_info.data.local_client_id = os.urandom(4) + logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) + self.server_info.data.connection_id = struct.unpack(' self.unit_len: + ret += self.pack_data(buf[:self.unit_len], ogn_data_len) + buf = buf[self.unit_len:] + ret += self.pack_data(buf, ogn_data_len) + self.last_rnd_len = ogn_data_len + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 8192 or length < 7: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data error') + if length > len(self.recv_buf): + break + + if hmac.new(mac_key, self.recv_buf[:length - 4], self.hashfunc).digest()[:4] != self.recv_buf[length - 4:length]: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data uncorrect checksum') + + self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF + pos = common.ord(self.recv_buf[4]) + if pos < 255: + pos += 4 + else: + pos = struct.unpack(' self.unit_len: + ret += self.pack_data(buf[:self.unit_len], ogn_data_len) + buf = buf[self.unit_len:] + ret += self.pack_data(buf, ogn_data_len) + self.last_rnd_len = ogn_data_len + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return (buf, False) + self.recv_buf += buf + out_buf = b'' + sendback = False + + if not self.has_recv_header: + if len(self.recv_buf) >= 7 or len(self.recv_buf) in [2, 3]: + recv_len = min(len(self.recv_buf), 7) + mac_key = self.server_info.recv_iv + self.server_info.key + sha1data = hmac.new(mac_key, self.recv_buf[:1], self.hashfunc).digest()[:recv_len - 1] + if sha1data != self.recv_buf[1:recv_len]: + return self.not_match_return(self.recv_buf) + + if len(self.recv_buf) < 31: + return (b'', False) + sha1data = hmac.new(mac_key, self.recv_buf[7:27], self.hashfunc).digest()[:4] + if sha1data != self.recv_buf[27:31]: + logging.error('%s data uncorrect auth HMAC-SHA1 from %s:%d, data %s' % (self.no_compatible_method, self.server_info.client, self.server_info.client_port, binascii.hexlify(self.recv_buf))) + if len(self.recv_buf) < 31 + self.extra_wait_size: + return (b'', False) + return self.not_match_return(self.recv_buf) + + uid = self.recv_buf[7:11] + if uid in self.server_info.users: + self.user_id = uid + self.user_key = self.hashfunc(self.server_info.users[uid]).digest() + self.server_info.update_user_func(uid) + else: + if not self.server_info.users: + self.user_key = self.server_info.key + else: + self.user_key = self.server_info.recv_iv + encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc') + head = encryptor.decrypt(b'\x00' * 16 + self.recv_buf[11:27] + b'\x00') # need an extra byte or recv empty + length = struct.unpack(' self.max_time_dif: + logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head))) + return self.not_match_return(self.recv_buf) + elif self.server_info.data.insert(self.user_id, client_id, connection_id): + self.has_recv_header = True + out_buf = self.recv_buf[31 + rnd_len:length - 4] + self.client_id = client_id + self.connection_id = connection_id + else: + logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) + return self.not_match_return(self.recv_buf) + self.recv_buf = self.recv_buf[length:] + self.has_recv_header = True + sendback = True + + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 8192 or length < 7: + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 0: + logging.info(self.no_compatible_method + ': over size') + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + if hmac.new(mac_key, self.recv_buf[:length - 4], self.hashfunc).digest()[:4] != self.recv_buf[length - 4:length]: + logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 0: + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data uncorrect checksum') + + self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF + pos = common.ord(self.recv_buf[4]) + if pos < 255: + pos += 4 + else: + pos = struct.unpack('> 17) ^ (y >> 26)) + self.v1 = x + return (x + y) & xorshift128plus.max_int + + def init_from_bin(self, bin): + if len(bin) < 16: + bin += b'\0' * 16 + self.v0 = struct.unpack('= len(str2): + if str1[:len(str2)] == str2: + return True + return False + + +class auth_base(plain.plain): + def __init__(self, method): + super(auth_base, self).__init__(method) + self.method = method + self.no_compatible_method = '' + self.overhead = 4 + + def init_data(self): + return '' + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + return (buf, False) + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + return (buf, True, False) + + def not_match_return(self, buf): + self.raw_trans = True + self.overhead = 0 + if self.method == self.no_compatible_method: + return (b'E' * 2048, False) + return (buf, False) + + +class client_queue(object): + def __init__(self, begin_id): + self.front = begin_id - 64 + self.back = begin_id + 1 + self.alloc = {} + self.enable = True + self.last_update = time.time() + self.ref = 0 + + def update(self): + self.last_update = time.time() + + def addref(self): + self.ref += 1 + + def delref(self): + if self.ref > 0: + self.ref -= 1 + + def is_active(self): + return (self.ref > 0) and (time.time() - self.last_update < 60 * 10) + + def re_enable(self, connection_id): + self.enable = True + self.front = connection_id - 64 + self.back = connection_id + 1 + self.alloc = {} + + def insert(self, connection_id): + if not self.enable: + logging.warn('obfs auth: not enable') + return False + if not self.is_active(): + self.re_enable(connection_id) + self.update() + if connection_id < self.front: + logging.warn('obfs auth: deprecated id, someone replay attack') + return False + if connection_id > self.front + 0x4000: + logging.warn('obfs auth: wrong id') + return False + if connection_id in self.alloc: + logging.warn('obfs auth: duplicate id, someone replay attack') + return False + if self.back <= connection_id: + self.back = connection_id + 1 + self.alloc[connection_id] = 1 + while (self.front in self.alloc) or self.front + 0x1000 < self.back: + if self.front in self.alloc: + del self.alloc[self.front] + self.front += 1 + self.addref() + return True + + +class obfs_auth_chain_data(object): + def __init__(self, name): + self.name = name + self.user_id = {} + self.local_client_id = b'' + self.connection_id = 0 + self.set_max_client(64) # max active client count + + def update(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if client_id in local_client_id: + local_client_id[client_id].update() + + def set_max_client(self, max_client): + self.max_client = max_client + self.max_buffer = max(self.max_client * 2, 1024) + + def insert(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: + if local_client_id.first() is None or len(local_client_id) < self.max_client: + if client_id not in local_client_id: + # TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + if not local_client_id[local_client_id.first()].is_active(): + del local_client_id[local_client_id.first()] + if client_id not in local_client_id: + # TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + logging.warn(self.name + ': no inactive client') + return False + else: + return local_client_id[client_id].insert(connection_id) + + def remove(self, user_id, client_id): + if user_id in self.user_id: + local_client_id = self.user_id[user_id] + if client_id in local_client_id: + local_client_id[client_id].delref() + + +class auth_chain_a(auth_base): + def __init__(self, method): + super(auth_chain_a, self).__init__(method) + self.hashfunc = hashlib.md5 + self.recv_buf = b'' + self.unit_len = 2800 + self.raw_trans = False + self.has_sent_header = False + self.has_recv_header = False + self.client_id = 0 + self.connection_id = 0 + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.salt = b"auth_chain_a" + self.no_compatible_method = 'auth_chain_a' + self.pack_id = 1 + self.recv_id = 1 + self.user_id = None + self.user_id_num = 0 + self.user_key = None + self.overhead = 4 + self.client_over_head = 4 + self.last_client_hash = b'' + self.last_server_hash = b'' + self.random_client = xorshift128plus() + self.random_server = xorshift128plus() + self.encryptor = None + + def init_data(self): + return obfs_auth_chain_data(self.method) + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + + def trapezoid_random_float(self, d): + if d == 0: + return random.random() + s = random.random() + a = 1 - d + return (math.sqrt(a * a + 4 * d * s) - a) / (2 * d) + + def trapezoid_random_int(self, max_val, d): + v = self.trapezoid_random_float(d) + return int(v * max_val) + + def rnd_data_len(self, buf_size, last_hash, random): + if buf_size > 1440: + return 0 + random.init_from_bin_len(last_hash, buf_size) + if buf_size > 1300: + return random.next() % 31 + if buf_size > 900: + return random.next() % 127 + if buf_size > 400: + return random.next() % 521 + return random.next() % 1021 + + def udp_rnd_data_len(self, last_hash, random): + random.init_from_bin(last_hash) + return random.next() % 127 + + def rnd_start_pos(self, rand_len, random): + if rand_len > 0: + return random.next() % 8589934609 % rand_len + return 0 + + def rnd_data(self, buf_size, buf, last_hash, random): + rand_len = self.rnd_data_len(buf_size, last_hash, random) + + rnd_data_buf = rand_bytes(rand_len) + + if buf_size == 0: + return rnd_data_buf + else: + if rand_len > 0: + start_pos = self.rnd_start_pos(rand_len, random) + return rnd_data_buf[:start_pos] + buf + rnd_data_buf[start_pos:] + else: + return buf + + def pack_client_data(self, buf): + buf = self.encryptor.encrypt(buf) + data = self.rnd_data(len(buf), buf, self.last_client_hash, self.random_client) + mac_key = self.user_key + struct.pack(' 0xFF000000: + self.server_info.data.local_client_id = b'' + if not self.server_info.data.local_client_id: + self.server_info.data.local_client_id = rand_bytes(4) + logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) + self.server_info.data.connection_id = struct.unpack(' self.unit_len: + ret += self.pack_client_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_client_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 4096: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data error') + + if length + 4 > len(self.recv_buf): + break + + server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() + if server_hash[:2] != self.recv_buf[length + 2: length + 4]: + logging.info('%s: checksum error, data %s' + % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data uncorrect checksum') + + pos = 2 + if data_len > 0 and rand_len > 0: + pos = 2 + self.rnd_start_pos(rand_len, self.random_server) + out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos]) + self.last_server_hash = server_hash + if self.recv_id == 1: + self.server_info.tcp_mss = struct.unpack(' self.unit_len: + ret += self.pack_server_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_server_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return (buf, False) + self.recv_buf += buf + out_buf = b'' + sendback = False + + if not self.has_recv_header: + if len(self.recv_buf) >= 12 or len(self.recv_buf) in [7, 8]: + recv_len = min(len(self.recv_buf), 12) + mac_key = self.server_info.recv_iv + self.server_info.key + md5data = hmac.new(mac_key, self.recv_buf[:4], self.hashfunc).digest() + if md5data[:recv_len - 4] != self.recv_buf[4:recv_len]: + return self.not_match_return(self.recv_buf) + + if len(self.recv_buf) < 12 + 24: + return (b'', False) + + self.last_client_hash = md5data + uid = struct.unpack(' self.max_time_dif: + logging.info('%s: wrong timestamp, time_dif %d, data %s' % ( + self.no_compatible_method, time_dif, binascii.hexlify(head) + )) + return self.not_match_return(self.recv_buf) + elif self.server_info.data.insert(self.user_id, client_id, connection_id): + self.has_recv_header = True + self.client_id = client_id + self.connection_id = connection_id + else: + logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) + return self.not_match_return(self.recv_buf) + + self.on_recv_auth_data(utc_time) + self.encryptor = encrypt.Encryptor( + to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') + self.recv_buf = self.recv_buf[36:] + self.has_recv_header = True + sendback = True + + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 4096: + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 1: + logging.info(self.no_compatible_method + ': over size') + return (b'E' * 2048, False) + else: + raise Exception('server_post_decrype data error') + + if length + 4 > len(self.recv_buf): + break + + client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() + if client_hash[:2] != self.recv_buf[length + 2: length + 4]: + logging.info('%s: checksum error, data %s' % ( + self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]) + )) + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 1: + return (b'E' * 2048, False) + else: + raise Exception('server_post_decrype data uncorrect checksum') + + self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF + pos = 2 + if data_len > 0 and rand_len > 0: + pos = 2 + self.rnd_start_pos(rand_len, self.random_client) + out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos]) + self.last_client_hash = client_hash + self.recv_buf = self.recv_buf[length + 4:] + if data_len == 0: + sendback = True + + if out_buf: + self.server_info.data.update(self.user_id, self.client_id, self.connection_id) + return (out_buf, sendback) + + def client_udp_pre_encrypt(self, buf): + if self.user_key is None: + if b':' in to_bytes(self.server_info.protocol_param): + try: + items = to_bytes(self.server_info.protocol_param).split(':') + self.user_key = self.hashfunc(items[1]).digest() + self.user_id = struct.pack('= 1440: + return 0 + random.init_from_bin_len(last_hash, buf_size) + pos = bisect.bisect_left(self.data_size_list, buf_size + self.server_info.overhead) + final_pos = pos + random.next() % (len(self.data_size_list)) + # 假设random均匀分布,则越长的原始数据长度越容易if false + if final_pos < len(self.data_size_list): + return self.data_size_list[final_pos] - buf_size - self.server_info.overhead + + # 上面if false后选择2号补全数组,此处有更精细的长度分段 + pos = bisect.bisect_left(self.data_size_list2, buf_size + self.server_info.overhead) + final_pos = pos + random.next() % (len(self.data_size_list2)) + if final_pos < len(self.data_size_list2): + return self.data_size_list2[final_pos] - buf_size - self.server_info.overhead + # final_pos 总是分布在pos~(data_size_list2.len-1)之间 + if final_pos < pos + len(self.data_size_list2) - 1: + return 0 + # 有1/len(self.data_size_list2)的概率不满足上一个if + + if buf_size > 1300: + return random.next() % 31 + if buf_size > 900: + return random.next() % 127 + if buf_size > 400: + return random.next() % 521 + return random.next() % 1021 + + +class auth_chain_c(auth_chain_b): + def __init__(self, method): + super(auth_chain_c, self).__init__(method) + self.salt = b"auth_chain_c" + self.no_compatible_method = 'auth_chain_c' + self.data_size_list0 = [] + + def init_data_size(self, key): + if self.data_size_list0: + self.data_size_list0 = [] + random = xorshift128plus() + random.init_from_bin(key) + # 补全数组长为12~24-1 + list_len = random.next() % (8 + 16) + (4 + 8) + for i in range(0, list_len): + self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440)) + self.data_size_list0.sort() + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + self.init_data_size(self.server_info.key) + + def rnd_data_len(self, buf_size, last_hash, random): + other_data_size = buf_size + self.server_info.overhead + # 一定要在random使用前初始化,以保证服务器与客户端同步,保证包大小验证结果正确 + random.init_from_bin_len(last_hash, buf_size) + # final_pos 总是分布在pos~(data_size_list0.len-1)之间 + # 除非data_size_list0中的任何值均过小使其全部都无法容纳buf + if other_data_size >= self.data_size_list0[-1]: + if other_data_size >= 1440: + return 0 + if other_data_size > 1300: + return random.next() % 31 + if other_data_size > 900: + return random.next() % 127 + if other_data_size > 400: + return random.next() % 521 + return random.next() % 1021 + + pos = bisect.bisect_left(self.data_size_list0, other_data_size) + # random select a size in the leftover data_size_list0 + final_pos = pos + random.next() % (len(self.data_size_list0) - pos) + return self.data_size_list0[final_pos] - other_data_size + + +class auth_chain_d(auth_chain_b): + def __init__(self, method): + super(auth_chain_d, self).__init__(method) + self.salt = b"auth_chain_d" + self.no_compatible_method = 'auth_chain_d' + self.data_size_list0 = [] + + def check_and_patch_data_size(self, random): + # append new item + # when the biggest item(first time) or the last append item(other time) are not big enough. + # but set a limit size (64) to avoid stack overflow. + if self.data_size_list0[-1] < 1300 and len(self.data_size_list0) < 64: + self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440)) + self.check_and_patch_data_size(random) + + def init_data_size(self, key): + if self.data_size_list0: + self.data_size_list0 = [] + random = xorshift128plus() + random.init_from_bin(key) + # 补全数组长为12~24-1 + list_len = random.next() % (8 + 16) + (4 + 8) + for i in range(0, list_len): + self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440)) + self.data_size_list0.sort() + old_len = len(self.data_size_list0) + self.check_and_patch_data_size(random) + # if check_and_patch_data_size are work, re-sort again. + if old_len != len(self.data_size_list0): + self.data_size_list0.sort() + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + self.init_data_size(self.server_info.key) + + def rnd_data_len(self, buf_size, last_hash, random): + other_data_size = buf_size + self.server_info.overhead + # if other_data_size > the bigest item in data_size_list0, not padding any data + if other_data_size >= self.data_size_list0[-1]: + return 0 + + random.init_from_bin_len(last_hash, buf_size) + pos = bisect.bisect_left(self.data_size_list0, other_data_size) + # random select a size in the leftover data_size_list0 + final_pos = pos + random.next() % (len(self.data_size_list0) - pos) + return self.data_size_list0[final_pos] - other_data_size + + +class auth_chain_e(auth_chain_d): + def __init__(self, method): + super(auth_chain_e, self).__init__(method) + self.salt = b"auth_chain_e" + self.no_compatible_method = 'auth_chain_e' + + def rnd_data_len(self, buf_size, last_hash, random): + random.init_from_bin_len(last_hash, buf_size) + other_data_size = buf_size + self.server_info.overhead + # if other_data_size > the bigest item in data_size_list0, not padding any data + if other_data_size >= self.data_size_list0[-1]: + return 0 + + # use the mini size in the data_size_list0 + pos = bisect.bisect_left(self.data_size_list0, other_data_size) + return self.data_size_list0[pos] - other_data_size + + +# auth_chain_f +# when every connect create, generate size_list will different when every day or every custom time interval which set in the config +class auth_chain_f(auth_chain_e): + def __init__(self, method): + super(auth_chain_f, self).__init__(method) + self.salt = b"auth_chain_f" + self.no_compatible_method = 'auth_chain_f' + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + try: + self.key_change_interval = int(server_info.protocol_param.split('#')[1]) # config are in second + except: + self.key_change_interval = 60 * 60 * 24 # a day by second + + def on_recv_auth_data(self, utc_time): + self.key_change_datetime_key = int(utc_time / self.key_change_interval) + self.key_change_datetime_key_bytes = [] # big bit first list + for i in range(7, -1, -1): # big-ending compare to c + self.key_change_datetime_key_bytes.append((self.key_change_datetime_key >> (8 * i)) & 0xFF) + self.init_data_size(self.server_info.key) + + def init_data_size(self, key): + if self.data_size_list0: + self.data_size_list0 = [] + random = xorshift128plus() + # key xor with key_change_datetime_key + new_key = bytearray(key) + new_key_str = '' + for i in range(0, 8): + new_key[i] ^= self.key_change_datetime_key_bytes[i] + new_key_str += chr(new_key[i]) + for i in range(8, len(new_key)): + new_key_str += chr(new_key[i]) + random.init_from_bin(to_bytes(new_key_str)) + # 补全数组长为12~24-1 + list_len = random.next() % (8 + 16) + (4 + 8) + for i in range(0, list_len): + self.data_size_list0.append(int(random.next() % 2340 % 2040 % 1440)) + self.data_size_list0.sort() + old_len = len(self.data_size_list0) + self.check_and_patch_data_size(random) + # if check_and_patch_data_size are work, re-sort again. + if old_len != len(self.data_size_list0): + self.data_size_list0.sort() diff --git a/shadowsocks/obfsplugin/http_simple.py b/shadowsocks/obfsplugin/http_simple.py new file mode 100644 index 0000000..ff3c5fd --- /dev/null +++ b/shadowsocks/obfsplugin/http_simple.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging +import binascii +import struct +import base64 +import datetime +import random + +from shadowsocks import common +from shadowsocks.obfsplugin import plain +from shadowsocks.common import to_bytes, to_str, ord, chr + +def create_http_simple_obfs(method): + return http_simple(method) + +def create_http_post_obfs(method): + return http_post(method) + +def create_random_head_obfs(method): + return random_head(method) + +obfs_map = { + 'http_simple': (create_http_simple_obfs,), + 'http_simple_compatible': (create_http_simple_obfs,), + 'http_post': (create_http_post_obfs,), + 'http_post_compatible': (create_http_post_obfs,), + 'random_head': (create_random_head_obfs,), + 'random_head_compatible': (create_random_head_obfs,), +} + +def match_begin(str1, str2): + if len(str1) >= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class http_simple(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.host = None + self.port = 0 + self.recv_buffer = b'' + # TODO user config user_agent + self.user_agent = [b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0", + b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0", + b"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36", + b"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/535.11 (KHTML, like Gecko) Ubuntu/11.10 Chromium/27.0.1453.93 Chrome/27.0.1453.93 Safari/537.36", + b"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:35.0) Gecko/20100101 Firefox/35.0", + b"Mozilla/5.0 (compatible; WOW64; MSIE 10.0; Windows NT 6.2)", + b"Mozilla/5.0 (Windows; U; Windows NT 6.1; en-US) AppleWebKit/533.20.25 (KHTML, like Gecko) Version/5.0.4 Safari/533.20.27", + b"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.3; Trident/7.0; .NET4.0E; .NET4.0C)", + b"Mozilla/5.0 (Windows NT 6.3; Trident/7.0; rv:11.0) like Gecko", + b"Mozilla/5.0 (Linux; Android 4.4; Nexus 5 Build/BuildID) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/30.0.0.0 Mobile Safari/537.36", + b"Mozilla/5.0 (iPad; CPU OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3", + b"Mozilla/5.0 (iPhone; CPU iPhone OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3"] + + def encode_head(self, buf): + hexstr = binascii.hexlify(buf) + chs = [] + for i in range(0, len(hexstr), 2): + chs.append(b"%" + hexstr[i:i+2]) + return b''.join(chs) + + def client_encode(self, buf): + if self.has_sent_header: + return buf + head_size = len(self.server_info.iv) + self.server_info.head_len + if len(buf) - head_size > 64: + headlen = head_size + random.randint(0, 64) + else: + headlen = len(buf) + headdata = buf[:headlen] + buf = buf[headlen:] + port = b'' + if self.server_info.port != 80: + port = b':' + to_bytes(str(self.server_info.port)) + body = None + hosts = (self.server_info.obfs_param or self.server_info.host) + pos = hosts.find("#") + if pos >= 0: + body = hosts[pos + 1:].replace("\n", "\r\n") + body = body.replace("\\n", "\r\n") + hosts = hosts[:pos] + hosts = hosts.split(',') + host = random.choice(hosts) + http_head = b"GET /" + self.encode_head(headdata) + b" HTTP/1.1\r\n" + http_head += b"Host: " + to_bytes(host) + port + b"\r\n" + if body: + http_head += body + "\r\n\r\n" + else: + http_head += b"User-Agent: " + random.choice(self.user_agent) + b"\r\n" + http_head += b"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nDNT: 1\r\nConnection: keep-alive\r\n\r\n" + self.has_sent_header = True + return http_head + buf + + def client_decode(self, buf): + if self.has_recv_header: + return (buf, False) + pos = buf.find(b'\r\n\r\n') + if pos >= 0: + self.has_recv_header = True + return (buf[pos + 4:], False) + else: + return (b'', False) + + def server_encode(self, buf): + if self.has_sent_header: + return buf + + header = b'HTTP/1.1 200 OK\r\nConnection: keep-alive\r\nContent-Encoding: gzip\r\nContent-Type: text/html\r\nDate: ' + header += to_bytes(datetime.datetime.now().strftime('%a, %d %b %Y %H:%M:%S GMT')) + header += b'\r\nServer: nginx\r\nVary: Accept-Encoding\r\n\r\n' + self.has_sent_header = True + return header + buf + + def get_data_from_http_header(self, buf): + ret_buf = b'' + lines = buf.split(b'\r\n') + if lines and len(lines) > 1: + hex_items = lines[0].split(b'%') + if hex_items and len(hex_items) > 1: + for index in range(1, len(hex_items)): + if len(hex_items[index]) < 2: + ret_buf += binascii.unhexlify('0' + hex_items[index]) + break + elif len(hex_items[index]) > 2: + ret_buf += binascii.unhexlify(hex_items[index][:2]) + break + else: + ret_buf += binascii.unhexlify(hex_items[index]) + return ret_buf + return b'' + + def get_host_from_http_header(self, buf): + ret_buf = b'' + lines = buf.split(b'\r\n') + if lines and len(lines) > 1: + for line in lines: + if match_begin(line, b"Host: "): + return common.to_str(line[6:]) + + def not_match_return(self, buf): + self.has_sent_header = True + self.has_recv_header = True + if self.method == 'http_simple': + return (b'E'*2048, False, False) + return (buf, True, False) + + def error_return(self, buf): + self.has_sent_header = True + self.has_recv_header = True + return (b'E'*2048, False, False) + + def server_decode(self, buf): + if self.has_recv_header: + return (buf, True, False) + + self.recv_buffer += buf + buf = self.recv_buffer + if len(buf) > 10: + if match_begin(buf, b'GET ') or match_begin(buf, b'POST '): + if len(buf) > 65536: + self.recv_buffer = None + logging.warn('http_simple: over size') + return self.not_match_return(buf) + else: #not http header, run on original protocol + self.recv_buffer = None + logging.debug('http_simple: not match begin') + return self.not_match_return(buf) + else: + return (b'', True, False) + + if b'\r\n\r\n' in buf: + datas = buf.split(b'\r\n\r\n', 1) + ret_buf = self.get_data_from_http_header(buf) + host = self.get_host_from_http_header(buf) + if host and self.server_info.obfs_param: + pos = host.find(":") + if pos >= 0: + host = host[:pos] + hosts = self.server_info.obfs_param.split(',') + if host not in hosts: + return self.not_match_return(buf) + if len(ret_buf) < 4: + return self.error_return(buf) + if len(datas) > 1: + ret_buf += datas[1] + if len(ret_buf) >= 13: + self.has_recv_header = True + return (ret_buf, True, False) + return self.not_match_return(buf) + else: + return (b'', True, False) + +class http_post(http_simple): + def __init__(self, method): + super(http_post, self).__init__(method) + + def boundary(self): + return to_bytes(''.join([random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") for i in range(32)])) + + def client_encode(self, buf): + if self.has_sent_header: + return buf + head_size = len(self.server_info.iv) + self.server_info.head_len + if len(buf) - head_size > 64: + headlen = head_size + random.randint(0, 64) + else: + headlen = len(buf) + headdata = buf[:headlen] + buf = buf[headlen:] + port = b'' + if self.server_info.port != 80: + port = b':' + to_bytes(str(self.server_info.port)) + body = None + hosts = (self.server_info.obfs_param or self.server_info.host) + pos = hosts.find("#") + if pos >= 0: + body = hosts[pos + 1:].replace("\\n", "\r\n") + hosts = hosts[:pos] + hosts = hosts.split(',') + host = random.choice(hosts) + http_head = b"POST /" + self.encode_head(headdata) + b" HTTP/1.1\r\n" + http_head += b"Host: " + to_bytes(host) + port + b"\r\n" + if body: + http_head += body + "\r\n\r\n" + else: + http_head += b"User-Agent: " + random.choice(self.user_agent) + b"\r\n" + http_head += b"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Encoding: gzip, deflate\r\n" + http_head += b"Content-Type: multipart/form-data; boundary=" + self.boundary() + b"\r\nDNT: 1\r\n" + http_head += b"Connection: keep-alive\r\n\r\n" + self.has_sent_header = True + return http_head + buf + + def not_match_return(self, buf): + self.has_sent_header = True + self.has_recv_header = True + if self.method == 'http_post': + return (b'E'*2048, False, False) + return (buf, True, False) + +class random_head(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.raw_trans_sent = False + self.raw_trans_recv = False + self.send_buffer = b'' + + def client_encode(self, buf): + if self.raw_trans_sent: + return buf + self.send_buffer += buf + if not self.has_sent_header: + self.has_sent_header = True + data = os.urandom(common.ord(os.urandom(1)[0]) % 96 + 4) + crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff + return data + struct.pack('= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class obfs_auth_data(object): + def __init__(self): + self.client_data = lru_cache.LRUCache(60 * 5) + self.client_id = os.urandom(32) + self.startup_time = int(time.time() - 60 * 30) & 0xFFFFFFFF + self.ticket_buf = {} + +class tls_ticket_auth(plain.plain): + def __init__(self, method): + self.method = method + self.handshake_status = 0 + self.send_buffer = b'' + self.recv_buffer = b'' + self.client_id = b'' + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.tls_version = b'\x03\x03' + self.overhead = 5 + + def init_data(self): + return obfs_auth_data() + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def sni(self, url): + url = common.to_bytes(url) + data = b"\x00" + struct.pack('>H', len(url)) + url + data = b"\x00\x00" + struct.pack('>H', len(data) + 2) + struct.pack('>H', len(data)) + data + return data + + def pack_auth_data(self, client_id): + utc_time = int(time.time()) & 0xFFFFFFFF + data = struct.pack('>I', utc_time) + os.urandom(18) + data += hmac.new(self.server_info.key + client_id, data, hashlib.sha1).digest()[:10] + return data + + def client_encode(self, buf): + if self.handshake_status == -1: + return buf + if self.handshake_status == 8: + ret = b'' + while len(buf) > 2048: + size = min(struct.unpack('>H', os.urandom(2))[0] % 4096 + 100, len(buf)) + ret += b"\x17" + self.tls_version + struct.pack('>H', size) + buf[:size] + buf = buf[size:] + if len(buf) > 0: + ret += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf + return ret + if len(buf) > 0: + self.send_buffer += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf + if self.handshake_status == 0: + self.handshake_status = 1 + data = self.tls_version + self.pack_auth_data(self.server_info.data.client_id) + b"\x20" + self.server_info.data.client_id + binascii.unhexlify(b"001cc02bc02fcca9cca8cc14cc13c00ac014c009c013009c0035002f000a" + b"0100") + ext = binascii.unhexlify(b"ff01000100") + host = self.server_info.obfs_param or self.server_info.host + if host and host[-1] in string.digits: + host = '' + hosts = host.split(',') + host = random.choice(hosts) + ext += self.sni(host) + ext += b"\x00\x17\x00\x00" + if host not in self.server_info.data.ticket_buf: + self.server_info.data.ticket_buf[host] = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 17 + 8) * 16) + ext += b"\x00\x23" + struct.pack('>H', len(self.server_info.data.ticket_buf[host])) + self.server_info.data.ticket_buf[host] + ext += binascii.unhexlify(b"000d001600140601060305010503040104030301030302010203") + ext += binascii.unhexlify(b"000500050100000000") + ext += binascii.unhexlify(b"00120000") + ext += binascii.unhexlify(b"75500000") + ext += binascii.unhexlify(b"000b00020100") + ext += binascii.unhexlify(b"000a0006000400170018") + data += struct.pack('>H', len(ext)) + ext + data = b"\x01\x00" + struct.pack('>H', len(data)) + data + data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data + return data + elif self.handshake_status == 1 and len(buf) == 0: + data = b"\x14" + self.tls_version + b"\x00\x01\x01" #ChangeCipherSpec + data += b"\x16" + self.tls_version + b"\x00\x20" + os.urandom(22) #Finished + data += hmac.new(self.server_info.key + self.server_info.data.client_id, data, hashlib.sha1).digest()[:10] + ret = data + self.send_buffer + self.send_buffer = b'' + self.handshake_status = 8 + return ret + return b'' + + def client_decode(self, buf): + if self.handshake_status == -1: + return (buf, False) + + if self.handshake_status == 8: + ret = b'' + self.recv_buffer += buf + while len(self.recv_buffer) > 5: + if ord(self.recv_buffer[0]) != 0x17: + logging.info("data = %s" % (binascii.hexlify(self.recv_buffer))) + raise Exception('server_decode appdata error') + size = struct.unpack('>H', self.recv_buffer[3:5])[0] + if len(self.recv_buffer) < size + 5: + break + buf = self.recv_buffer[5:size+5] + ret += buf + self.recv_buffer = self.recv_buffer[size+5:] + return (ret, False) + + if len(buf) < 11 + 32 + 1 + 32: + raise Exception('client_decode data error') + verify = buf[11:33] + if hmac.new(self.server_info.key + self.server_info.data.client_id, verify, hashlib.sha1).digest()[:10] != buf[33:43]: + raise Exception('client_decode data error') + if hmac.new(self.server_info.key + self.server_info.data.client_id, buf[:-10], hashlib.sha1).digest()[:10] != buf[-10:]: + raise Exception('client_decode data error') + return (b'', True) + + def server_encode(self, buf): + if self.handshake_status == -1: + return buf + if (self.handshake_status & 8) == 8: + ret = b'' + while len(buf) > 2048: + size = min(struct.unpack('>H', os.urandom(2))[0] % 4096 + 100, len(buf)) + ret += b"\x17" + self.tls_version + struct.pack('>H', size) + buf[:size] + buf = buf[size:] + if len(buf) > 0: + ret += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf + return ret + self.handshake_status |= 8 + data = self.tls_version + self.pack_auth_data(self.client_id) + b"\x20" + self.client_id + binascii.unhexlify(b"c02f000005ff01000100") + data = b"\x02\x00" + struct.pack('>H', len(data)) + data #server hello + data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data + if random.randint(0, 8) < 1: + ticket = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 164) * 2 + 64) + ticket = struct.pack('>H', len(ticket) + 4) + b"\x04\x00" + struct.pack('>H', len(ticket)) + ticket + data += b"\x16" + self.tls_version + ticket #New session ticket + data += b"\x14" + self.tls_version + b"\x00\x01\x01" #ChangeCipherSpec + finish_len = random.choice([32, 40]) + data += b"\x16" + self.tls_version + struct.pack('>H', finish_len) + os.urandom(finish_len - 10) #Finished + data += hmac.new(self.server_info.key + self.client_id, data, hashlib.sha1).digest()[:10] + if buf: + data += self.server_encode(buf) + return data + + def decode_error_return(self, buf): + self.handshake_status = -1 + if self.overhead > 0: + self.server_info.overhead -= self.overhead + self.overhead = 0 + if self.method in ['tls1.2_ticket_auth', 'tls1.2_ticket_fastauth']: + return (b'E'*2048, False, False) + return (buf, True, False) + + def server_decode(self, buf): + if self.handshake_status == -1: + return (buf, True, False) + + if (self.handshake_status & 4) == 4: + ret = b'' + self.recv_buffer += buf + while len(self.recv_buffer) > 5: + if ord(self.recv_buffer[0]) != 0x17 or ord(self.recv_buffer[1]) != 0x3 or ord(self.recv_buffer[2]) != 0x3: + logging.info("data = %s" % (binascii.hexlify(self.recv_buffer))) + raise Exception('server_decode appdata error') + size = struct.unpack('>H', self.recv_buffer[3:5])[0] + if len(self.recv_buffer) < size + 5: + break + ret += self.recv_buffer[5:size+5] + self.recv_buffer = self.recv_buffer[size+5:] + return (ret, True, False) + + if (self.handshake_status & 1) == 1: + self.recv_buffer += buf + buf = self.recv_buffer + verify = buf + if len(buf) < 11: + raise Exception('server_decode data error') + if not match_begin(buf, b"\x14" + self.tls_version + b"\x00\x01\x01"): #ChangeCipherSpec + raise Exception('server_decode data error') + buf = buf[6:] + if not match_begin(buf, b"\x16" + self.tls_version + b"\x00"): #Finished + raise Exception('server_decode data error') + verify_len = struct.unpack('>H', buf[3:5])[0] + 1 # 11 - 10 + if len(verify) < verify_len + 10: + return (b'', False, False) + if hmac.new(self.server_info.key + self.client_id, verify[:verify_len], hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]: + raise Exception('server_decode data error') + self.recv_buffer = verify[verify_len + 10:] + status = self.handshake_status + self.handshake_status |= 4 + ret = self.server_decode(b'') + return ret; + + #raise Exception("handshake data = %s" % (binascii.hexlify(buf))) + self.recv_buffer += buf + buf = self.recv_buffer + ogn_buf = buf + if len(buf) < 3: + return (b'', False, False) + if not match_begin(buf, b'\x16\x03\x01'): + return self.decode_error_return(ogn_buf) + buf = buf[3:] + header_len = struct.unpack('>H', buf[:2])[0] + if header_len > len(buf) - 2: + return (b'', False, False) + + self.recv_buffer = self.recv_buffer[header_len + 5:] + self.handshake_status = 1 + buf = buf[2:header_len + 2] + if not match_begin(buf, b'\x01\x00'): #client hello + logging.info("tls_auth not client hello message") + return self.decode_error_return(ogn_buf) + buf = buf[2:] + if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: + logging.info("tls_auth wrong message size") + return self.decode_error_return(ogn_buf) + buf = buf[2:] + if not match_begin(buf, self.tls_version): + logging.info("tls_auth wrong tls version") + return self.decode_error_return(ogn_buf) + buf = buf[2:] + verifyid = buf[:32] + buf = buf[32:] + sessionid_len = ord(buf[0]) + if sessionid_len < 32: + logging.info("tls_auth wrong sessionid_len") + return self.decode_error_return(ogn_buf) + sessionid = buf[1:sessionid_len + 1] + buf = buf[sessionid_len+1:] + self.client_id = sessionid + sha1 = hmac.new(self.server_info.key + sessionid, verifyid[:22], hashlib.sha1).digest()[:10] + utc_time = struct.unpack('>I', verifyid[:4])[0] + time_dif = common.int32((int(time.time()) & 0xffffffff) - utc_time) + if self.server_info.obfs_param: + try: + self.max_time_dif = int(self.server_info.obfs_param) + except: + pass + if self.max_time_dif > 0 and (time_dif < -self.max_time_dif or time_dif > self.max_time_dif \ + or common.int32(utc_time - self.server_info.data.startup_time) < -self.max_time_dif / 2): + logging.info("tls_auth wrong time") + return self.decode_error_return(ogn_buf) + if sha1 != verifyid[22:]: + logging.info("tls_auth wrong sha1") + return self.decode_error_return(ogn_buf) + if self.server_info.data.client_data.get(verifyid[:22]): + logging.info("replay attack detect, id = %s" % (binascii.hexlify(verifyid))) + return self.decode_error_return(ogn_buf) + self.server_info.data.client_data.sweep() + self.server_info.data.client_data[verifyid[:22]] = sessionid + if len(self.recv_buffer) >= 11: + ret = self.server_decode(b'') + return (ret[0], True, True) + # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) + return (b'', False, True) + diff --git a/shadowsocks/obfsplugin/plain.py b/shadowsocks/obfsplugin/plain.py new file mode 100644 index 0000000..8c6355c --- /dev/null +++ b/shadowsocks/obfsplugin/plain.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging + +from shadowsocks.common import ord + +def create_obfs(method): + return plain(method) + +obfs_map = { + 'plain': (create_obfs,), + 'origin': (create_obfs,), +} + +class plain(object): + def __init__(self, method): + self.method = method + self.server_info = None + + def init_data(self): + return b'' + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return 0 + + def get_server_info(self): + return self.server_info + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_pre_encrypt(self, buf): + return buf + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + # (buffer_to_recv, is_need_to_encode_and_send_back) + return (buf, False) + + def client_post_decrypt(self, buf): + return buf + + def server_pre_encrypt(self, buf): + return buf + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) + return (buf, True, False) + + def server_post_decrypt(self, buf): + return (buf, False) + + def client_udp_pre_encrypt(self, buf): + return buf + + def client_udp_post_decrypt(self, buf): + return buf + + def server_udp_pre_encrypt(self, buf, uid): + return buf + + def server_udp_post_decrypt(self, buf): + return (buf, None) + + def dispose(self): + pass + + def get_head_size(self, buf, def_value): + if len(buf) < 2: + return def_value + head_type = ord(buf[0]) & 0x7 + if head_type == 1: + return 7 + if head_type == 4: + return 19 + if head_type == 3: + return 4 + ord(buf[1]) + return def_value + diff --git a/shadowsocks/obfsplugin/verify.py b/shadowsocks/obfsplugin/verify.py new file mode 100644 index 0000000..0dc0ca6 --- /dev/null +++ b/shadowsocks/obfsplugin/verify.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging +import binascii +import base64 +import time +import datetime +import random +import struct +import zlib +import hmac +import hashlib + +import shadowsocks +from shadowsocks import common +from shadowsocks.obfsplugin import plain +from shadowsocks.common import to_bytes, to_str, ord, chr + +def create_verify_deflate(method): + return verify_deflate(method) + +obfs_map = { + 'verify_deflate': (create_verify_deflate,), +} + +def match_begin(str1, str2): + if len(str1) >= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class obfs_verify_data(object): + def __init__(self): + pass + +class verify_base(plain.plain): + def __init__(self, method): + super(verify_base, self).__init__(method) + self.method = method + + def init_data(self): + return obfs_verify_data() + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + return (buf, False) + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + return (buf, True, False) + +class verify_deflate(verify_base): + def __init__(self, method): + super(verify_deflate, self).__init__(method) + self.recv_buf = b'' + self.unit_len = 32700 + self.decrypt_packet_num = 0 + self.raw_trans = False + + def pack_data(self, buf): + if len(buf) == 0: + return b'' + data = zlib.compress(buf) + data = struct.pack('>H', len(data)) + data[2:] + return data + + def client_pre_encrypt(self, buf): + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 32768 or length < 6: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data error') + if length > len(self.recv_buf): + break + + out_buf += zlib.decompress(b'x\x9c' + self.recv_buf[2:length]) + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return out_buf + + def server_pre_encrypt(self, buf): + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return (buf, False) + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 32768 or length < 6: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + out_buf += zlib.decompress(b'\x78\x9c' + self.recv_buf[2:length]) + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return (out_buf, False) + diff --git a/shadowsocks/ordereddict.py b/shadowsocks/ordereddict.py new file mode 100644 index 0000000..e1918f5 --- /dev/null +++ b/shadowsocks/ordereddict.py @@ -0,0 +1,214 @@ +import collections + +################################################################################ +### OrderedDict +################################################################################ + +class OrderedDict(dict): + 'Dictionary that remembers insertion order' + # An inherited dict maps keys to values. + # The inherited dict provides __getitem__, __len__, __contains__, and get. + # The remaining methods are order-aware. + # Big-O running times for all methods are the same as regular dictionaries. + + # The internal self.__map dict maps keys to links in a doubly linked list. + # The circular doubly linked list starts and ends with a sentinel element. + # The sentinel element never gets deleted (this simplifies the algorithm). + # Each link is stored as a list of length three: [PREV, NEXT, KEY]. + + def __init__(*args, **kwds): + '''Initialize an ordered dictionary. The signature is the same as + regular dictionaries, but keyword arguments are not recommended because + their insertion order is arbitrary. + + ''' + if not args: + raise TypeError("descriptor '__init__' of 'OrderedDict' object " + "needs an argument") + self = args[0] + args = args[1:] + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + try: + self.__root + except AttributeError: + self.__root = root = [] # sentinel node + root[:] = [root, root, None] + self.__map = {} + self.__update(*args, **kwds) + + def __setitem__(self, key, value, dict_setitem=dict.__setitem__): + 'od.__setitem__(i, y) <==> od[i]=y' + # Setting a new item creates a new link at the end of the linked list, + # and the inherited dictionary is updated with the new key/value pair. + if key not in self: + root = self.__root + last = root[0] + last[1] = root[0] = self.__map[key] = [last, root, key] + return dict_setitem(self, key, value) + + def __delitem__(self, key, dict_delitem=dict.__delitem__): + 'od.__delitem__(y) <==> del od[y]' + # Deleting an existing item uses self.__map to find the link which gets + # removed by updating the links in the predecessor and successor nodes. + dict_delitem(self, key) + link_prev, link_next, _ = self.__map.pop(key) + link_prev[1] = link_next # update link_prev[NEXT] + link_next[0] = link_prev # update link_next[PREV] + + def __iter__(self): + 'od.__iter__() <==> iter(od)' + # Traverse the linked list in order. + root = self.__root + curr = root[1] # start at the first node + while curr is not root: + yield curr[2] # yield the curr[KEY] + curr = curr[1] # move to next node + + def __reversed__(self): + 'od.__reversed__() <==> reversed(od)' + # Traverse the linked list in reverse order. + root = self.__root + curr = root[0] # start at the last node + while curr is not root: + yield curr[2] # yield the curr[KEY] + curr = curr[0] # move to previous node + + def clear(self): + 'od.clear() -> None. Remove all items from od.' + root = self.__root + root[:] = [root, root, None] + self.__map.clear() + dict.clear(self) + + # -- the following methods do not depend on the internal structure -- + + def keys(self): + 'od.keys() -> list of keys in od' + return list(self) + + def values(self): + 'od.values() -> list of values in od' + return [self[key] for key in self] + + def items(self): + 'od.items() -> list of (key, value) pairs in od' + return [(key, self[key]) for key in self] + + def iterkeys(self): + 'od.iterkeys() -> an iterator over the keys in od' + return iter(self) + + def itervalues(self): + 'od.itervalues -> an iterator over the values in od' + for k in self: + yield self[k] + + def iteritems(self): + 'od.iteritems -> an iterator over the (key, value) pairs in od' + for k in self: + yield (k, self[k]) + + update = collections.MutableMapping.update + + __update = update # let subclasses override update without breaking __init__ + + __marker = object() + + def pop(self, key, default=__marker): + '''od.pop(k[,d]) -> v, remove specified key and return the corresponding + value. If key is not found, d is returned if given, otherwise KeyError + is raised. + + ''' + if key in self: + result = self[key] + del self[key] + return result + if default is self.__marker: + raise KeyError(key) + return default + + def setdefault(self, key, default=None): + 'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od' + if key in self: + return self[key] + self[key] = default + return default + + def popitem(self, last=True): + '''od.popitem() -> (k, v), return and remove a (key, value) pair. + Pairs are returned in LIFO order if last is true or FIFO order if false. + + ''' + if not self: + raise KeyError('dictionary is empty') + key = next(reversed(self) if last else iter(self)) + value = self.pop(key) + return key, value + + def __repr__(self, _repr_running={}): + 'od.__repr__() <==> repr(od)' + call_key = id(self), _get_ident() + if call_key in _repr_running: + return '...' + _repr_running[call_key] = 1 + try: + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, self.items()) + finally: + del _repr_running[call_key] + + def __reduce__(self): + 'Return state information for pickling' + items = [[k, self[k]] for k in self] + inst_dict = vars(self).copy() + for k in vars(OrderedDict()): + inst_dict.pop(k, None) + if inst_dict: + return (self.__class__, (items,), inst_dict) + return self.__class__, (items,) + + def copy(self): + 'od.copy() -> a shallow copy of od' + return self.__class__(self) + + @classmethod + def fromkeys(cls, iterable, value=None): + '''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S. + If not specified, the value defaults to None. + + ''' + self = cls() + for key in iterable: + self[key] = value + return self + + def __eq__(self, other): + '''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive + while comparison to a regular mapping is order-insensitive. + + ''' + if isinstance(other, OrderedDict): + return dict.__eq__(self, other) and all(_imap(_eq, self, other)) + return dict.__eq__(self, other) + + def __ne__(self, other): + 'od.__ne__(y) <==> od!=y' + return not self == other + + # -- the following methods support python 3.x style dictionary views -- + + def viewkeys(self): + "od.viewkeys() -> a set-like object providing a view on od's keys" + return KeysView(self) + + def viewvalues(self): + "od.viewvalues() -> an object providing a view on od's values" + return ValuesView(self) + + def viewitems(self): + "od.viewitems() -> a set-like object providing a view on od's items" + return ItemsView(self) + diff --git a/shadowsocks/run.sh b/shadowsocks/run.sh new file mode 100644 index 0000000..4372030 --- /dev/null +++ b/shadowsocks/run.sh @@ -0,0 +1,7 @@ +#!/bin/bash +cd `dirname $0` +python_ver=$(ls /usr/bin|grep -e "^python[23]\.[1-9]\+$"|tail -1) +eval $(ps -ef | grep "[0-9] ${python_ver} server\\.py a" | awk '{print "kill "$2}') +ulimit -n 512000 +nohup ${python_ver} server.py a>> /dev/null 2>&1 & + diff --git a/shadowsocks/server.py b/shadowsocks/server.py new file mode 100644 index 0000000..dced0a3 --- /dev/null +++ b/shadowsocks/server.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import sys +import os +import logging +import signal + +if __name__ == '__main__': + import inspect + + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) + sys.path.insert(0, os.path.join(file_path, '../')) + +from shadowsocks import shell, daemon, eventloop, tcprelay, udprelay, \ + asyncdns, manager, common + + +def main(): + shell.check_python() + + config = shell.get_config(False) + + shell.log_shadowsocks_version() + + daemon.daemon_exec(config) + + try: + import resource + logging.info( + 'current process RLIMIT_NOFILE resource: soft %d hard %d' % resource.getrlimit(resource.RLIMIT_NOFILE)) + except ImportError: + pass + + if config['port_password']: + pass + else: + config['port_password'] = {} + server_port = config['server_port'] + if type(server_port) == list: + for a_server_port in server_port: + config['port_password'][a_server_port] = config['password'] + else: + config['port_password'][str(server_port)] = config['password'] + + if not config.get('dns_ipv6', False): + asyncdns.IPV6_CONNECTION_SUPPORT = False + + if config.get('manager_address', 0): + logging.info('entering manager mode') + manager.run(config) + return + + tcp_servers = [] + udp_servers = [] + dns_resolver = asyncdns.DNSResolver(config['black_hostname_list']) + if int(config['workers']) > 1: + stat_counter_dict = None + else: + stat_counter_dict = {} + port_password = config['port_password'] + config_password = config.get('password', 'm') + del config['port_password'] + for port, password_obfs in port_password.items(): + method = config["method"] + protocol = config.get("protocol", 'origin') + protocol_param = config.get("protocol_param", '') + obfs = config.get("obfs", 'plain') + obfs_param = config.get("obfs_param", '') + bind = config.get("out_bind", '') + bindv6 = config.get("out_bindv6", '') + if type(password_obfs) == list: + password = password_obfs[0] + obfs = common.to_str(password_obfs[1]) + if len(password_obfs) > 2: + protocol = common.to_str(password_obfs[2]) + elif type(password_obfs) == dict: + password = password_obfs.get('password', config_password) + method = common.to_str(password_obfs.get('method', method)) + protocol = common.to_str(password_obfs.get('protocol', protocol)) + protocol_param = common.to_str(password_obfs.get('protocol_param', protocol_param)) + obfs = common.to_str(password_obfs.get('obfs', obfs)) + obfs_param = common.to_str(password_obfs.get('obfs_param', obfs_param)) + bind = password_obfs.get('out_bind', bind) + bindv6 = password_obfs.get('out_bindv6', bindv6) + else: + password = password_obfs + a_config = config.copy() + ipv6_ok = False + logging.info("server start with protocol[%s] password [%s] method [%s] obfs [%s] obfs_param [%s]" % + (protocol, password, method, obfs, obfs_param)) + if 'server_ipv6' in a_config: + try: + if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == b"[" and a_config['server_ipv6'][ + -1] == b"]": + a_config['server_ipv6'] = a_config['server_ipv6'][1:-1] + a_config['server_port'] = int(port) + a_config['password'] = password + a_config['method'] = method + a_config['protocol'] = protocol + a_config['protocol_param'] = protocol_param + a_config['obfs'] = obfs + a_config['obfs_param'] = obfs_param + a_config['out_bind'] = bind + a_config['out_bindv6'] = bindv6 + a_config['server'] = common.to_str(a_config['server_ipv6']) + logging.info("starting server at [%s]:%d" % + (a_config['server'], int(port))) + tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False, stat_counter=stat_counter_dict)) + udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False, stat_counter=stat_counter_dict)) + if a_config['server_ipv6'] == b"::": + ipv6_ok = True + except Exception as e: + shell.print_exception(e) + + try: + a_config = config.copy() + a_config['server_port'] = int(port) + a_config['password'] = password + a_config['method'] = method + a_config['protocol'] = protocol + a_config['protocol_param'] = protocol_param + a_config['obfs'] = obfs + a_config['obfs_param'] = obfs_param + a_config['out_bind'] = bind + a_config['out_bindv6'] = bindv6 + logging.info("starting server at %s:%d" % + (a_config['server'], int(port))) + tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False, stat_counter=stat_counter_dict)) + udp_servers.append(udprelay.UDPRelay(a_config, dns_resolver, False, stat_counter=stat_counter_dict)) + except Exception as e: + if not ipv6_ok: + shell.print_exception(e) + + def run_server(): + def child_handler(signum, _): + logging.warn('received SIGQUIT, doing graceful shutting down..') + list(map(lambda s: s.close(next_tick=True), + tcp_servers + udp_servers)) + + signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), + child_handler) + + def int_handler(signum, _): + sys.exit(1) + + signal.signal(signal.SIGINT, int_handler) + + try: + loop = eventloop.EventLoop() + dns_resolver.add_to_loop(loop) + list(map(lambda s: s.add_to_loop(loop), tcp_servers + udp_servers)) + + daemon.set_user(config.get('user', None)) + loop.run() + except Exception as e: + shell.print_exception(e) + sys.exit(1) + + if int(config['workers']) > 1: + if os.name == 'posix': + children = [] + is_child = False + for i in range(0, int(config['workers'])): + r = os.fork() + if r == 0: + logging.info('worker started') + is_child = True + run_server() + break + else: + children.append(r) + if not is_child: + def handler(signum, _): + for pid in children: + try: + os.kill(pid, signum) + os.waitpid(pid, 0) + except OSError: # child may already exited + pass + sys.exit() + + signal.signal(signal.SIGTERM, handler) + signal.signal(signal.SIGQUIT, handler) + signal.signal(signal.SIGINT, handler) + + # master + for a_tcp_server in tcp_servers: + a_tcp_server.close() + for a_udp_server in udp_servers: + a_udp_server.close() + dns_resolver.close() + + for child in children: + os.waitpid(child, 0) + else: + logging.warn('worker is only available on Unix/Linux') + run_server() + else: + run_server() + + +if __name__ == '__main__': + main() diff --git a/shadowsocks/shell.py b/shadowsocks/shell.py new file mode 100644 index 0000000..a1547d0 --- /dev/null +++ b/shadowsocks/shell.py @@ -0,0 +1,451 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import json +import sys +import getopt +import logging +from shadowsocks.common import to_bytes, to_str, IPNetwork, PortRange +from shadowsocks import encrypt + +VERBOSE_LEVEL = 5 + +verbose = 0 + + +def check_python(): + info = sys.version_info + if info[0] == 2 and not info[1] >= 6: + print('Python 2.6+ required') + sys.exit(1) + elif info[0] == 3 and not info[1] >= 3: + print('Python 3.3+ required') + sys.exit(1) + elif info[0] not in [2, 3]: + print('Python version not supported') + sys.exit(1) + + +def print_exception(e): + global verbose + logging.error(e) + if verbose > 0: + import traceback + traceback.print_exc() + + +def __version(): + version_str = '' + try: + import pkg_resources + version_str = pkg_resources.get_distribution('shadowsocks').version + except Exception: + try: + from shadowsocks import version + version_str = version.version() + except Exception: + pass + return version_str + + +def print_shadowsocks(): + print('ShadowsocksR %s' % __version()) + + +def log_shadowsocks_version(): + logging.info('ShadowsocksR %s' % __version()) + + +def find_config(): + user_config_path = 'user-config.json' + config_path = 'config.json' + + def sub_find(file_name): + if os.path.exists(file_name): + return file_name + file_name = os.path.join(os.path.abspath('..'), file_name) + return file_name if os.path.exists(file_name) else None + + return sub_find(user_config_path) or sub_find(config_path) + + +def check_config(config, is_local): + if config.get('daemon', None) == 'stop': + # no need to specify configuration for daemon stop + return + + if is_local and not config.get('password', None): + logging.error('password not specified') + print_help(is_local) + sys.exit(2) + + if not is_local and not config.get('password', None) \ + and not config.get('port_password', None): + logging.error('password or port_password not specified') + print_help(is_local) + sys.exit(2) + + if 'local_port' in config: + config['local_port'] = int(config['local_port']) + + if 'server_port' in config and type(config['server_port']) != list: + config['server_port'] = int(config['server_port']) + + if config.get('local_address', '') in [b'0.0.0.0']: + logging.warning('warning: local set to listen on 0.0.0.0, it\'s not safe') + if config.get('server', '') in ['127.0.0.1', 'localhost']: + logging.warning('warning: server set to listen on %s:%s, are you sure?' % + (to_str(config['server']), config['server_port'])) + if config.get('timeout', 300) < 100: + logging.warning('warning: your timeout %d seems too short' % + int(config.get('timeout'))) + if config.get('timeout', 300) > 600: + logging.warning('warning: your timeout %d seems too long' % + int(config.get('timeout'))) + if config.get('password') in [b'mypassword']: + logging.error('DON\'T USE DEFAULT PASSWORD! Please change it in your ' + 'config.json!') + sys.exit(1) + if config.get('user', None) is not None: + if os.name != 'posix': + logging.error('user can be used only on Unix') + sys.exit(1) + + encrypt.try_cipher(config['password'], config['method']) + + +def get_config(is_local): + global verbose + config = {} + config_path = None + logging.basicConfig(level=logging.INFO, + format='%(levelname)-s: %(message)s') + if is_local: + shortopts = 'hd:s:b:p:k:l:m:O:o:G:g:c:t:vq' + longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'user=', + 'version'] + else: + shortopts = 'hd:s:p:k:m:O:o:G:g:c:t:vq' + longopts = ['help', 'fast-open', 'pid-file=', 'log-file=', 'workers=', + 'forbidden-ip=', 'user=', 'manager-address=', 'version'] + try: + optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts) + for key, value in optlist: + if key == '-c': + config_path = value + elif key in ('-h', '--help'): + print_help(is_local) + sys.exit(0) + elif key == '--version': + print_shadowsocks() + sys.exit(0) + else: + continue + + if config_path is None: + config_path = find_config() + + if config_path: + logging.debug('loading config from %s' % config_path) + with open(config_path, 'rb') as f: + try: + config = parse_json_in_str(remove_comment(f.read().decode('utf8'))) + except ValueError as e: + logging.error('found an error in config.json: %s', str(e)) + sys.exit(1) + + v_count = 0 + for key, value in optlist: + if key == '-p': + config['server_port'] = int(value) + elif key == '-k': + config['password'] = to_bytes(value) + elif key == '-l': + config['local_port'] = int(value) + elif key == '-s': + config['server'] = to_str(value) + elif key == '-m': + config['method'] = to_str(value) + elif key == '-O': + config['protocol'] = to_str(value) + elif key == '-o': + config['obfs'] = to_str(value) + elif key == '-G': + config['protocol_param'] = to_str(value) + elif key == '-g': + config['obfs_param'] = to_str(value) + elif key == '-b': + config['local_address'] = to_str(value) + elif key == '-v': + v_count += 1 + # '-vv' turns on more verbose mode + config['verbose'] = v_count + elif key == '-t': + config['timeout'] = int(value) + elif key == '--fast-open': + config['fast_open'] = True + elif key == '--workers': + config['workers'] = int(value) + elif key == '--manager-address': + config['manager_address'] = value + elif key == '--user': + config['user'] = to_str(value) + elif key == '--forbidden-ip': + config['forbidden_ip'] = to_str(value) + + elif key == '-d': + config['daemon'] = to_str(value) + elif key == '--pid-file': + config['pid-file'] = to_str(value) + elif key == '--log-file': + config['log-file'] = to_str(value) + elif key == '-q': + v_count -= 1 + config['verbose'] = v_count + else: + continue + except getopt.GetoptError as e: + print(e, file=sys.stderr) + print_help(is_local) + sys.exit(2) + + if not config: + logging.error('config not specified') + print_help(is_local) + sys.exit(2) + + config['password'] = to_bytes(config.get('password', b'')) + config['method'] = to_str(config.get('method', 'aes-256-cfb')) + config['protocol'] = to_str(config.get('protocol', 'origin')) + config['protocol_param'] = to_str(config.get('protocol_param', '')) + config['obfs'] = to_str(config.get('obfs', 'plain')) + config['obfs_param'] = to_str(config.get('obfs_param', '')) + config['port_password'] = config.get('port_password', None) + config['additional_ports'] = config.get('additional_ports', {}) + config['additional_ports_only'] = config.get('additional_ports_only', False) + config['timeout'] = int(config.get('timeout', 300)) + config['udp_timeout'] = int(config.get('udp_timeout', 120)) + config['udp_cache'] = int(config.get('udp_cache', 64)) + config['fast_open'] = config.get('fast_open', False) + config['workers'] = config.get('workers', 1) + config['pid-file'] = config.get('pid-file', '/var/run/shadowsocksr.pid') + config['log-file'] = config.get('log-file', '/var/log/shadowsocksr.log') + config['verbose'] = config.get('verbose', False) + config['connect_verbose_info'] = config.get('connect_verbose_info', 0) + config['local_address'] = to_str(config.get('local_address', '127.0.0.1')) + config['local_port'] = config.get('local_port', 1080) + if is_local: + if config.get('server', None) is None: + logging.error('server addr not specified') + print_local_help() + sys.exit(2) + else: + config['server'] = to_str(config['server']) + else: + config['server'] = to_str(config.get('server', '0.0.0.0')) + config['black_hostname_list'] = to_str(config.get('black_hostname_list', '')).split(',') + if len(config['black_hostname_list']) == 1 and config['black_hostname_list'][0] == '': + config['black_hostname_list'] = [] + try: + config['forbidden_ip'] = \ + IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128')) + except Exception as e: + logging.error(e) + sys.exit(2) + try: + config['forbidden_port'] = PortRange(config.get('forbidden_port', '')) + except Exception as e: + logging.error(e) + sys.exit(2) + try: + config['ignore_bind'] = \ + IPNetwork(config.get('ignore_bind', '127.0.0.0/8,::1/128,10.0.0.0/8,192.168.0.0/16')) + except Exception as e: + logging.error(e) + sys.exit(2) + config['server_port'] = config.get('server_port', 8388) + + logging.getLogger('').handlers = [] + logging.addLevelName(VERBOSE_LEVEL, 'VERBOSE') + if config['verbose'] >= 2: + level = VERBOSE_LEVEL + elif config['verbose'] == 1: + level = logging.DEBUG + elif config['verbose'] == -1: + level = logging.WARN + elif config['verbose'] <= -2: + level = logging.ERROR + else: + level = logging.INFO + verbose = config['verbose'] + logging.basicConfig(level=level, + format='%(asctime)s %(levelname)-8s %(filename)s:%(lineno)s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + + check_config(config, is_local) + + return config + + +def print_help(is_local): + if is_local: + print_local_help() + else: + print_server_help() + + +def print_local_help(): + print('''usage: sslocal [OPTION]... +A fast tunnel proxy that helps you bypass firewalls. + +You can supply configurations via either config file or command line arguments. + +Proxy options: + -c CONFIG path to config file + -s SERVER_ADDR server address + -p SERVER_PORT server port, default: 8388 + -b LOCAL_ADDR local binding address, default: 127.0.0.1 + -l LOCAL_PORT local port, default: 1080 + -k PASSWORD password + -m METHOD encryption method, default: aes-256-cfb + -o OBFS obfsplugin, default: http_simple + -t TIMEOUT timeout in seconds, default: 300 + --fast-open use TCP_FASTOPEN, requires Linux 3.7+ + +General options: + -h, --help show this help message and exit + -d start/stop/restart daemon mode + --pid-file PID_FILE pid file for daemon mode + --log-file LOG_FILE log file for daemon mode + --user USER username to run as + -v, -vv verbose mode + -q, -qq quiet mode, only show warnings/errors + --version show version information + +Online help: +''') + + +def print_server_help(): + print('''usage: ssserver [OPTION]... +A fast tunnel proxy that helps you bypass firewalls. + +You can supply configurations via either config file or command line arguments. + +Proxy options: + -c CONFIG path to config file + -s SERVER_ADDR server address, default: 0.0.0.0 + -p SERVER_PORT server port, default: 8388 + -k PASSWORD password + -m METHOD encryption method, default: aes-256-cfb + -o OBFS obfsplugin, default: http_simple + -t TIMEOUT timeout in seconds, default: 300 + --fast-open use TCP_FASTOPEN, requires Linux 3.7+ + --workers WORKERS number of workers, available on Unix/Linux + --forbidden-ip IPLIST comma seperated IP list forbidden to connect + --manager-address ADDR optional server manager UDP address, see wiki + +General options: + -h, --help show this help message and exit + -d start/stop/restart daemon mode + --pid-file PID_FILE pid file for daemon mode + --log-file LOG_FILE log file for daemon mode + --user USER username to run as + -v, -vv verbose mode + -q, -qq quiet mode, only show warnings/errors + --version show version information + +Online help: +''') + + +def _decode_list(data): + rv = [] + for item in data: + if hasattr(item, 'encode'): + item = item.encode('utf-8') + elif isinstance(item, list): + item = _decode_list(item) + elif isinstance(item, dict): + item = _decode_dict(item) + rv.append(item) + return rv + + +def _decode_dict(data): + rv = {} + for key, value in data.items(): + if hasattr(value, 'encode'): + value = value.encode('utf-8') + elif isinstance(value, list): + value = _decode_list(value) + elif isinstance(value, dict): + value = _decode_dict(value) + rv[key] = value + return rv + + +class JSFormat: + def __init__(self): + self.state = 0 + + def push(self, ch): + ch = ord(ch) + if self.state == 0: + if ch == ord('"'): + self.state = 1 + return to_str(chr(ch)) + elif ch == ord('/'): + self.state = 3 + else: + return to_str(chr(ch)) + elif self.state == 1: + if ch == ord('"'): + self.state = 0 + return to_str(chr(ch)) + elif ch == ord('\\'): + self.state = 2 + return to_str(chr(ch)) + elif self.state == 2: + self.state = 1 + if ch == ord('"'): + return to_str(chr(ch)) + return "\\" + to_str(chr(ch)) + elif self.state == 3: + if ch == ord('/'): + self.state = 4 + else: + return "/" + to_str(chr(ch)) + elif self.state == 4: + if ch == ord('\n'): + self.state = 0 + return "\n" + return "" + + +def remove_comment(json): + fmt = JSFormat() + return "".join([fmt.push(c) for c in json]) + + +def parse_json_in_str(data): + # parse json and convert everything from unicode to str + return json.loads(data, object_hook=_decode_dict) diff --git a/shadowsocks/stop.sh b/shadowsocks/stop.sh new file mode 100644 index 0000000..d7d2958 --- /dev/null +++ b/shadowsocks/stop.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +python_ver=$(ls /usr/bin|grep -e "^python[23]\.[1-9]\+$"|tail -1) +eval $(ps -ef | grep "[0-9] ${python_ver} server\\.py a" | awk '{print "kill "$2}') + diff --git a/shadowsocks/tail.sh b/shadowsocks/tail.sh new file mode 100644 index 0000000..aa37139 --- /dev/null +++ b/shadowsocks/tail.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +tail -f ssserver.log diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py new file mode 100644 index 0000000..45a650e --- /dev/null +++ b/shadowsocks/tcprelay.py @@ -0,0 +1,1483 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import time +import socket +import errno +import struct +import logging +import binascii +import traceback +import random +import platform +import threading + +from shadowsocks import encrypt, obfs, eventloop, shell, common, lru_cache, version +from shadowsocks.common import pre_parse_header, parse_header + +# we clear at most TIMEOUTS_CLEAN_SIZE timeouts each time +TIMEOUTS_CLEAN_SIZE = 512 + +MSG_FASTOPEN = 0x20000000 + +# SOCKS command definition +CMD_CONNECT = 1 +CMD_BIND = 2 +CMD_UDP_ASSOCIATE = 3 + +# for each opening port, we have a TCP Relay + +# for each connection, we have a TCP Relay Handler to handle the connection + +# for each handler, we have 2 sockets: +# local: connected to the client +# remote: connected to remote server + +# for each handler, it could be at one of several stages: + +# as sslocal: +# stage 0 SOCKS hello received from local, send hello to local +# stage 1 addr received from local, query DNS for remote +# stage 2 UDP assoc +# stage 3 DNS resolved, connect to remote +# stage 4 still connecting, more data from local received +# stage 5 remote connected, piping local and remote + +# as ssserver: +# stage 0 just jump to stage 1 +# stage 1 addr received from local, query DNS for remote +# stage 3 DNS resolved, connect to remote +# stage 4 still connecting, more data from local received +# stage 5 remote connected, piping local and remote + +STAGE_INIT = 0 +STAGE_ADDR = 1 +STAGE_UDP_ASSOC = 2 +STAGE_DNS = 3 +STAGE_CONNECTING = 4 +STAGE_STREAM = 5 +STAGE_DESTROYED = -1 + +# for each handler, we have 2 stream directions: +# upstream: from client to server direction +# read local and write to remote +# downstream: from server to client direction +# read remote and write to local + +STREAM_UP = 0 +STREAM_DOWN = 1 + +# for each stream, it's waiting for reading, or writing, or both +WAIT_STATUS_INIT = 0 +WAIT_STATUS_READING = 1 +WAIT_STATUS_WRITING = 2 +WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING + +NETWORK_MTU = 1500 +TCP_MSS = NETWORK_MTU - 40 +BUF_SIZE = 32 * 1024 +UDP_MAX_BUF_SIZE = 65536 + +class SpeedTester(object): + def __init__(self, max_speed = 0): + self.max_speed = max_speed * 1024 + self.last_time = time.time() + self.sum_len = 0 + + def update_limit(self, max_speed): + self.max_speed = max_speed * 1024 + + def add(self, data_len): + if self.max_speed > 0: + cut_t = time.time() + self.sum_len -= (cut_t - self.last_time) * self.max_speed + if self.sum_len < 0: + self.sum_len = 0 + self.last_time = cut_t + self.sum_len += data_len + + def isExceed(self): + if self.max_speed > 0: + cut_t = time.time() + self.sum_len -= (cut_t - self.last_time) * self.max_speed + if self.sum_len < 0: + self.sum_len = 0 + self.last_time = cut_t + return self.sum_len >= self.max_speed + return False + +class TCPRelayHandler(object): + def __init__(self, server, fd_to_handlers, loop, local_sock, config, + dns_resolver, is_local): + self._server = server + self._fd_to_handlers = fd_to_handlers + self._loop = loop + self._local_sock = local_sock + self._remote_sock = None + self._remote_sock_v6 = None + self._local_sock_fd = None + self._remote_sock_fd = None + self._remotev6_sock_fd = None + self._remote_udp = False + self._config = config + self._dns_resolver = dns_resolver + self._add_ref = 0 + if not self._create_encryptor(config): + return + + self._client_address = local_sock.getpeername()[:2] + self._accept_address = local_sock.getsockname()[:2] + self._user = None + self._user_id = server._listen_port + self._update_tcp_mss(local_sock) + + # TCP Relay works as either sslocal or ssserver + # if is_local, this is sslocal + self._is_local = is_local + self._encrypt_correct = True + self._obfs = obfs.obfs(config['obfs']) + self._protocol = obfs.obfs(config['protocol']) + self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(self._is_local) + self._recv_buffer_size = BUF_SIZE - self._overhead + + server_info = obfs.server_info(server.obfs_data) + server_info.host = config['server'] + server_info.port = server._listen_port + #server_info.users = server.server_users + #server_info.update_user_func = self._update_user + server_info.client = self._client_address[0] + server_info.client_port = self._client_address[1] + server_info.protocol_param = '' + server_info.obfs_param = config['obfs_param'] + server_info.iv = self._encryptor.cipher_iv + server_info.recv_iv = b'' + server_info.key_str = common.to_bytes(config['password']) + server_info.key = self._encryptor.cipher_key + server_info.head_len = 30 + server_info.tcp_mss = self._tcp_mss + server_info.buffer_size = self._recv_buffer_size + server_info.overhead = self._overhead + self._obfs.set_server_info(server_info) + + server_info = obfs.server_info(server.protocol_data) + server_info.host = config['server'] + server_info.port = server._listen_port + server_info.users = server.server_users + server_info.update_user_func = self._update_user + server_info.client = self._client_address[0] + server_info.client_port = self._client_address[1] + server_info.protocol_param = config['protocol_param'] + server_info.obfs_param = '' + server_info.iv = self._encryptor.cipher_iv + server_info.recv_iv = b'' + server_info.key_str = common.to_bytes(config['password']) + server_info.key = self._encryptor.cipher_key + server_info.head_len = 30 + server_info.tcp_mss = self._tcp_mss + server_info.buffer_size = self._recv_buffer_size + server_info.overhead = self._overhead + self._protocol.set_server_info(server_info) + + self._redir_list = config.get('redirect', ["*#0.0.0.0:0"]) + self._is_redirect = False + self._bind = config.get('out_bind', '') + self._bindv6 = config.get('out_bindv6', '') + self._ignore_bind_list = config.get('ignore_bind', []) + + self._fastopen_connected = False + self._data_to_write_to_local = [] + self._data_to_write_to_remote = [] + self._udp_data_send_buffer = b'' + self._upstream_status = WAIT_STATUS_READING + self._downstream_status = WAIT_STATUS_INIT + self._remote_address = None + + self._forbidden_iplist = config.get('forbidden_ip', None) + self._forbidden_portset = config.get('forbidden_port', None) + if is_local: + self._chosen_server = self._get_a_server() + + self.last_activity = 0 + self._update_activity() + self._server.add_connection(1) + self._server.stat_add(self._client_address[0], 1) + self._add_ref = 1 + self.speed_tester_u = SpeedTester(config.get("speed_limit_per_con", 0)) + self.speed_tester_d = SpeedTester(config.get("speed_limit_per_con", 0)) + self._recv_u_max_size = BUF_SIZE + self._recv_d_max_size = BUF_SIZE + self._recv_pack_id = 0 + self._udp_send_pack_id = 0 + self._udpv6_send_pack_id = 0 + + local_sock.setblocking(False) + local_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + self._local_sock_fd = local_sock.fileno() + fd_to_handlers[self._local_sock_fd] = self + loop.add(local_sock, eventloop.POLL_IN | eventloop.POLL_ERR, self._server) + self._stage = STAGE_INIT + + def __hash__(self): + # default __hash__ is id / 16 + # we want to eliminate collisions + return id(self) + + @property + def remote_address(self): + return self._remote_address + + def _get_a_server(self): + server = self._config['server'] + server_port = self._config['server_port'] + if type(server_port) == list: + server_port = random.choice(server_port) + if type(server) == list: + server = random.choice(server) + logging.debug('chosen server: %s:%d', server, server_port) + return server, server_port + + def _update_tcp_mss(self, local_sock): + self._tcp_mss = TCP_MSS + try: + tcp_mss = local_sock.getsockopt(socket.SOL_TCP, socket.TCP_MAXSEG) + if tcp_mss > 500 and tcp_mss <= 1500: + self._tcp_mss = tcp_mss + logging.debug("TCP MSS = %d" % (self._tcp_mss,)) + except: + pass + + def _create_encryptor(self, config): + try: + self._encryptor = encrypt.Encryptor(config['password'], + config['method'], None, True) + return True + except Exception: + self._stage = STAGE_DESTROYED + logging.error('create encryptor fail at port %d', self._server._listen_port) + + def _update_user(self, user): + self._user = user + self._user_id = struct.unpack(' 6: + length = struct.unpack('>H', self._udp_data_send_buffer[:2])[0] + + if length > len(self._udp_data_send_buffer): + break + + data = self._udp_data_send_buffer[:length] + self._udp_data_send_buffer = self._udp_data_send_buffer[length:] + + frag = common.ord(data[2]) + if frag != 0: + logging.warn('drop a message since frag is %d' % (frag,)) + continue + else: + data = data[3:] + header_result = parse_header(data) + if header_result is None: + continue + connecttype, addrtype, dest_addr, dest_port, header_length = header_result + if (addrtype & 7) == 3: + af = common.is_ip(dest_addr) + if af == False: + handler = common.UDPAsyncDNSHandler(data[header_length:]) + handler.resolve(self._dns_resolver, (dest_addr, dest_port), self._handle_server_dns_resolved) + else: + return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr, data[header_length:]) + else: + return self._handle_server_dns_resolved("", (dest_addr, dest_port), dest_addr, data[header_length:]) + + except Exception as e: + #trace = traceback.format_exc() + #logging.error(trace) + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + uncomplete = True + else: + shell.print_exception(e) + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return False + return True + else: + try: + if self._encrypt_correct: + if sock == self._remote_sock: + self._server.add_transfer_u(self._user, len(data)) + self._update_activity(len(data)) + if data: + l = len(data) + s = sock.send(data) + if s < l: + data = data[s:] + uncomplete = True + else: + return + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + uncomplete = True + else: + #traceback.print_exc() + shell.print_exception(e) + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return False + except Exception as e: + shell.print_exception(e) + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return False + if uncomplete: + if sock == self._local_sock: + self._data_to_write_to_local.append(data) + self._update_stream(STREAM_DOWN, WAIT_STATUS_WRITING) + elif sock == self._remote_sock: + self._data_to_write_to_remote.append(data) + self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) + else: + logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) + else: + if sock == self._local_sock: + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + elif sock == self._remote_sock: + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + else: + logging.error('write_all_to_sock:unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) + return True + + def _handle_server_dns_resolved(self, error, remote_addr, server_addr, data): + if error: + return + try: + addrs = socket.getaddrinfo(server_addr, remote_addr[1], 0, socket.SOCK_DGRAM, socket.SOL_UDP) + if not addrs: # drop + return + af, socktype, proto, canonname, sa = addrs[0] + if af == socket.AF_INET6: + self._remote_sock_v6.sendto(data, (server_addr, remote_addr[1])) + if self._udpv6_send_pack_id == 0: + addr, port = self._remote_sock_v6.getsockname()[:2] + common.connect_log('UDPv6 sendto %s(%s):%d from %s:%d by user %d' % + (common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id)) + self._udpv6_send_pack_id += 1 + else: + self._remote_sock.sendto(data, (server_addr, remote_addr[1])) + if self._udp_send_pack_id == 0: + addr, port = self._remote_sock.getsockname()[:2] + common.connect_log('UDP sendto %s(%s):%d from %s:%d by user %d' % + (common.to_str(remote_addr[0]), common.to_str(server_addr), remote_addr[1], addr, port, self._user_id)) + self._udp_send_pack_id += 1 + return True + except Exception as e: + shell.print_exception(e) + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + + def _get_redirect_host(self, client_address, ogn_data): + host_list = self._redir_list or ["*#0.0.0.0:0"] + + if type(host_list) != list: + host_list = [host_list] + + items_sum = common.to_str(host_list[0]).rsplit('#', 1) + if len(items_sum) < 2: + hash_code = binascii.crc32(ogn_data) + addrs = socket.getaddrinfo(client_address[0], client_address[1], 0, socket.SOCK_STREAM, socket.SOL_TCP) + af, socktype, proto, canonname, sa = addrs[0] + address_bytes = common.inet_pton(af, sa[0]) + if af == socket.AF_INET6: + addr = struct.unpack('>Q', address_bytes[8:])[0] + elif af == socket.AF_INET: + addr = struct.unpack('>I', address_bytes)[0] + else: + addr = 0 + + host_port = [] + match_port = False + for host in host_list: + items = common.to_str(host).rsplit(':', 1) + if len(items) > 1: + try: + port = int(items[1]) + if port == self._server._listen_port: + match_port = True + host_port.append((items[0], port)) + except: + pass + else: + host_port.append((host, 80)) + + if match_port: + last_host_port = host_port + host_port = [] + for host in last_host_port: + if host[1] == self._server._listen_port: + host_port.append(host) + + return host_port[((hash_code & 0xffffffff) + addr) % len(host_port)] + + else: + host_port = [] + for host in host_list: + items_sum = common.to_str(host).rsplit('#', 1) + items_match = common.to_str(items_sum[0]).rsplit(':', 1) + items = common.to_str(items_sum[1]).rsplit(':', 1) + if len(items_match) > 1: + if items_match[1] != "*": + try: + if self._server._listen_port != int(items_match[1]) and int(items_match[1]) != 0: + continue + except: + pass + + if items_match[0] != "*" and common.match_regex( + items_match[0], ogn_data) == False: + continue + if len(items) > 1: + try: + port = int(items[1]) + return (items[0], port) + except: + pass + else: + return (items[0], 80) + + return ("0.0.0.0", 0) + + def _handel_protocol_error(self, client_address, ogn_data): + logging.warn("Protocol ERROR, TCP ogn data %s from %s:%d via port %d by UID %d" % (binascii.hexlify(ogn_data), client_address[0], client_address[1], self._server._listen_port, self._user_id)) + self._encrypt_correct = False + #create redirect or disconnect by hash code + host, port = self._get_redirect_host(client_address, ogn_data) + if port == 0: + raise Exception('can not parse header') + data = b"\x03" + common.to_bytes(common.chr(len(host))) + common.to_bytes(host) + struct.pack('>H', port) + self._is_redirect = True + logging.warn("TCP data redir %s:%d %s" % (host, port, binascii.hexlify(data))) + return data + ogn_data + + def _handle_stage_connecting(self, data): + if self._is_local: + if self._encryptor is not None: + data = self._protocol.client_pre_encrypt(data) + data = self._encryptor.encrypt(data) + data = self._obfs.client_encode(data) + if data: + self._data_to_write_to_remote.append(data) + if self._is_local and not self._fastopen_connected and \ + self._config['fast_open']: + # for sslocal and fastopen, we basically wait for data and use + # sendto to connect + try: + # only connect once + self._fastopen_connected = True + remote_sock = \ + self._create_remote_socket(self._chosen_server[0], + self._chosen_server[1]) + self._loop.add(remote_sock, eventloop.POLL_ERR, self._server) + data = b''.join(self._data_to_write_to_remote) + l = len(data) + s = remote_sock.sendto(data, MSG_FASTOPEN, self._chosen_server) + if s < l: + data = data[s:] + self._data_to_write_to_remote = [data] + else: + self._data_to_write_to_remote = [] + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) == errno.EINPROGRESS: + # in this case data is not sent at all + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + elif eventloop.errno_from_exception(e) == errno.ENOTCONN: + logging.error('fast open not supported on this OS') + self._config['fast_open'] = False + self.destroy() + else: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + + def _get_head_size(self, buf, def_value): + if len(buf) < 2: + return def_value + head_type = common.ord(buf[0]) & 0xF + if head_type == 1: + return 7 + if head_type == 4: + return 19 + if head_type == 3: + return 4 + common.ord(buf[1]) + return def_value + + def _handle_stage_addr(self, ogn_data, data): + try: + if self._is_local: + cmd = common.ord(data[1]) + if cmd == CMD_UDP_ASSOCIATE: + logging.debug('UDP associate') + if self._local_sock.family == socket.AF_INET6: + header = b'\x05\x00\x00\x04' + else: + header = b'\x05\x00\x00\x01' + addr, port = self._local_sock.getsockname()[:2] + addr_to_send = socket.inet_pton(self._local_sock.family, + addr) + port_to_send = struct.pack('>H', port) + self._write_to_sock(header + addr_to_send + port_to_send, + self._local_sock) + self._stage = STAGE_UDP_ASSOC + # just wait for the client to disconnect + return + elif cmd == CMD_CONNECT: + # just trim VER CMD RSV + data = data[3:] + else: + logging.error('invalid command %d', cmd) + self.destroy() + return + + before_parse_data = data + if self._is_local: + header_result = parse_header(data) + else: + data = pre_parse_header(data) + if data is None: + data = self._handel_protocol_error(self._client_address, ogn_data) + header_result = parse_header(data) + if header_result is not None: + try: + common.to_str(header_result[2]) + except Exception as e: + header_result = None + if header_result is None: + data = self._handel_protocol_error(self._client_address, ogn_data) + header_result = parse_header(data) + self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(self._is_local) + self._recv_buffer_size = BUF_SIZE - self._overhead + server_info = self._obfs.get_server_info() + server_info.buffer_size = self._recv_buffer_size + server_info = self._protocol.get_server_info() + server_info.buffer_size = self._recv_buffer_size + connecttype, addrtype, remote_addr, remote_port, header_length = header_result + if connecttype != 0: + pass + #common.connect_log('UDP over TCP by user %d' % + # (self._user_id, )) + else: + common.connect_log('TCP request %s:%d by user %d' % + (common.to_str(remote_addr), remote_port, self._user_id)) + self._remote_address = (common.to_str(remote_addr), remote_port) + self._remote_udp = (connecttype != 0) + # pause reading + self._update_stream(STREAM_UP, WAIT_STATUS_WRITING) + self._stage = STAGE_DNS + if self._is_local: + # forward address to remote + self._write_to_sock((b'\x05\x00\x00\x01' + b'\x00\x00\x00\x00\x10\x10'), + self._local_sock) + head_len = self._get_head_size(data, 30) + self._obfs.obfs.server_info.head_len = head_len + self._protocol.obfs.server_info.head_len = head_len + if self._encryptor is not None: + data = self._protocol.client_pre_encrypt(data) + data_to_send = self._encryptor.encrypt(data) + data_to_send = self._obfs.client_encode(data_to_send) + if data_to_send: + self._data_to_write_to_remote.append(data_to_send) + # notice here may go into _handle_dns_resolved directly + self._dns_resolver.resolve(self._chosen_server[0], + self._handle_dns_resolved) + else: + if len(data) > header_length: + self._data_to_write_to_remote.append(data[header_length:]) + # notice here may go into _handle_dns_resolved directly + self._dns_resolver.resolve(remote_addr, + self._handle_dns_resolved) + except Exception as e: + self._log_error(e) + if self._config['verbose']: + traceback.print_exc() + self.destroy() + + def _socket_bind_addr(self, sock, af): + bind_addr = '' + if self._bind and af == socket.AF_INET: + bind_addr = self._bind + elif self._bindv6 and af == socket.AF_INET6: + bind_addr = self._bindv6 + else: + bind_addr = self._accept_address[0] + + bind_addr = bind_addr.replace("::ffff:", "") + if bind_addr in self._ignore_bind_list: + bind_addr = None + if bind_addr: + local_addrs = socket.getaddrinfo(bind_addr, 0, 0, socket.SOCK_STREAM, socket.SOL_TCP) + if local_addrs[0][0] == af: + logging.debug("bind %s" % (bind_addr,)) + try: + sock.bind((bind_addr, 0)) + except Exception as e: + logging.warn("bind %s fail" % (bind_addr,)) + + def _create_remote_socket(self, ip, port): + if self._remote_udp: + addrs_v6 = socket.getaddrinfo("::", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP) + addrs = socket.getaddrinfo("0.0.0.0", 0, 0, socket.SOCK_DGRAM, socket.SOL_UDP) + else: + addrs = socket.getaddrinfo(ip, port, 0, socket.SOCK_STREAM, socket.SOL_TCP) + if len(addrs) == 0: + raise Exception("getaddrinfo failed for %s:%d" % (ip, port)) + af, socktype, proto, canonname, sa = addrs[0] + if not self._remote_udp and not self._is_redirect: + if self._forbidden_iplist: + if common.to_str(sa[0]) in self._forbidden_iplist: + if self._remote_address: + raise Exception('IP %s is in forbidden list, when connect to %s:%d via port %d by UID %d' % + (common.to_str(sa[0]), self._remote_address[0], self._remote_address[1], self._server._listen_port, self._user_id)) + raise Exception('IP %s is in forbidden list, reject' % + common.to_str(sa[0])) + if self._forbidden_portset: + if sa[1] in self._forbidden_portset: + if self._remote_address: + raise Exception('Port %d is in forbidden list, when connect to %s:%d via port %d by UID %d' % + (sa[1], self._remote_address[0], self._remote_address[1], self._server._listen_port, self._user_id)) + raise Exception('Port %d is in forbidden list, reject' % sa[1]) + remote_sock = socket.socket(af, socktype, proto) + self._remote_sock = remote_sock + self._remote_sock_fd = remote_sock.fileno() + self._fd_to_handlers[self._remote_sock_fd] = self + + if self._remote_udp: + af, socktype, proto, canonname, sa = addrs_v6[0] + remote_sock_v6 = socket.socket(af, socktype, proto) + self._remote_sock_v6 = remote_sock_v6 + self._remotev6_sock_fd = remote_sock_v6.fileno() + self._fd_to_handlers[self._remotev6_sock_fd] = self + + remote_sock.setblocking(False) + if self._remote_udp: + remote_sock_v6.setblocking(False) + + if not self._is_local: + self._socket_bind_addr(remote_sock, af) + self._socket_bind_addr(remote_sock_v6, af) + else: + remote_sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + if not self._is_local: + self._socket_bind_addr(remote_sock, af) + return remote_sock + + def _handle_dns_resolved(self, result, error): + if error: + self._log_error(error) + self.destroy() + return + if result: + ip = result[1] + if ip: + try: + self._stage = STAGE_CONNECTING + remote_addr = ip + if self._is_local: + remote_port = self._chosen_server[1] + else: + remote_port = self._remote_address[1] + + if self._is_local and self._config['fast_open']: + # for fastopen: + # wait for more data to arrive and send them in one SYN + self._stage = STAGE_CONNECTING + # we don't have to wait for remote since it's not + # created + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + # TODO when there is already data in this packet + else: + # else do connect + remote_sock = self._create_remote_socket(remote_addr, + remote_port) + if self._remote_udp: + self._loop.add(remote_sock, + eventloop.POLL_IN, + self._server) + if self._remote_sock_v6: + self._loop.add(self._remote_sock_v6, + eventloop.POLL_IN, + self._server) + else: + try: + remote_sock.connect((remote_addr, remote_port)) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in (errno.EINPROGRESS, + errno.EWOULDBLOCK): + pass # always goto here + else: + raise e + addr, port = self._remote_sock.getsockname()[:2] + common.connect_log('TCP connecting %s(%s):%d from %s:%d by user %d' % + (common.to_str(self._remote_address[0]), common.to_str(remote_addr), remote_port, addr, port, self._user_id)) + + self._loop.add(remote_sock, + eventloop.POLL_ERR | eventloop.POLL_OUT, + self._server) + self._stage = STAGE_CONNECTING + self._update_stream(STREAM_UP, WAIT_STATUS_READWRITING) + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + if self._remote_udp: + while self._data_to_write_to_remote: + data = self._data_to_write_to_remote[0] + del self._data_to_write_to_remote[0] + self._write_to_sock(data, self._remote_sock) + return + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + + def _get_read_size(self, sock, recv_buffer_size, up): + if self._overhead == 0: + return recv_buffer_size + buffer_size = len(sock.recv(recv_buffer_size, socket.MSG_PEEK)) + frame_size = self._tcp_mss - self._overhead + if up: + buffer_size = min(buffer_size, self._recv_u_max_size) + self._recv_u_max_size = min(self._recv_u_max_size + frame_size, BUF_SIZE) + else: + buffer_size = min(buffer_size, self._recv_d_max_size) + self._recv_d_max_size = min(self._recv_d_max_size + frame_size, BUF_SIZE) + if buffer_size == recv_buffer_size: + return buffer_size + if buffer_size > frame_size: + buffer_size = int(buffer_size / frame_size) * frame_size + return buffer_size + + def _on_local_read(self): + # handle all local read events and dispatch them to methods for + # each stage + if not self._local_sock: + return + is_local = self._is_local + if is_local: + recv_buffer_size = self._get_read_size(self._local_sock, self._recv_buffer_size, True) + else: + recv_buffer_size = BUF_SIZE + data = None + try: + data = self._local_sock.recv(recv_buffer_size) + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in \ + (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK): + return + if not data: + self.destroy() + return + + self.speed_tester_u.add(len(data)) + self._server.speed_tester_u(self._user_id).add(len(data)) + ogn_data = data + if not is_local: + if self._encryptor is not None: + if self._encrypt_correct: + try: + obfs_decode = self._obfs.server_decode(data) + if self._stage == STAGE_INIT: + self._overhead = self._obfs.get_overhead(self._is_local) + self._protocol.get_overhead(self._is_local) + server_info = self._protocol.get_server_info() + server_info.overhead = self._overhead + except Exception as e: + shell.print_exception(e) + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return + if obfs_decode[2]: + data = self._obfs.server_encode(b'') + try: + self._write_to_sock(data, self._local_sock) + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return + if obfs_decode[1]: + if not self._protocol.obfs.server_info.recv_iv: + iv_len = len(self._protocol.obfs.server_info.iv) + self._protocol.obfs.server_info.recv_iv = obfs_decode[0][:iv_len] + data = self._encryptor.decrypt(obfs_decode[0]) + else: + data = obfs_decode[0] + try: + data, sendback = self._protocol.server_post_decrypt(data) + if sendback: + backdata = self._protocol.server_pre_encrypt(b'') + backdata = self._encryptor.encrypt(backdata) + backdata = self._obfs.server_encode(backdata) + try: + self._write_to_sock(backdata, self._local_sock) + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return + except Exception as e: + shell.print_exception(e) + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return + else: + return + if not data: + return + if self._stage == STAGE_STREAM: + if self._is_local: + if self._encryptor is not None: + data = self._protocol.client_pre_encrypt(data) + data = self._encryptor.encrypt(data) + data = self._obfs.client_encode(data) + self._write_to_sock(data, self._remote_sock) + elif is_local and self._stage == STAGE_INIT: + # TODO check auth method + self._write_to_sock(b'\x05\00', self._local_sock) + self._stage = STAGE_ADDR + elif self._stage == STAGE_CONNECTING: + self._handle_stage_connecting(data) + elif (is_local and self._stage == STAGE_ADDR) or \ + (not is_local and self._stage == STAGE_INIT): + self._handle_stage_addr(ogn_data, data) + + def _on_remote_read(self, is_remote_sock): + # handle all remote read events + data = None + try: + if self._remote_udp: + if is_remote_sock: + data, addr = self._remote_sock.recvfrom(UDP_MAX_BUF_SIZE) + else: + data, addr = self._remote_sock_v6.recvfrom(UDP_MAX_BUF_SIZE) + port = struct.pack('>H', addr[1]) + try: + ip = socket.inet_aton(addr[0]) + data = b'\x00\x01' + ip + port + data + except Exception as e: + ip = socket.inet_pton(socket.AF_INET6, addr[0]) + data = b'\x00\x04' + ip + port + data + size = len(data) + 2 + data = struct.pack('>H', size) + data + #logging.info('UDP over TCP recvfrom %s:%d %d bytes to %s:%d' % (addr[0], addr[1], len(data), self._client_address[0], self._client_address[1])) + else: + if self._is_local: + recv_buffer_size = BUF_SIZE + else: + recv_buffer_size = self._get_read_size(self._remote_sock, self._recv_buffer_size, False) + data = self._remote_sock.recv(recv_buffer_size) + self._recv_pack_id += 1 + except (OSError, IOError) as e: + if eventloop.errno_from_exception(e) in \ + (errno.ETIMEDOUT, errno.EAGAIN, errno.EWOULDBLOCK, 10035): #errno.WSAEWOULDBLOCK + return + if not data: + self.destroy() + return + + self.speed_tester_d.add(len(data)) + self._server.speed_tester_d(self._user_id).add(len(data)) + if self._encryptor is not None: + if self._is_local: + try: + obfs_decode = self._obfs.client_decode(data) + except Exception as e: + shell.print_exception(e) + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return + if obfs_decode[1]: + send_back = self._obfs.client_encode(b'') + self._write_to_sock(send_back, self._remote_sock) + if not self._protocol.obfs.server_info.recv_iv: + iv_len = len(self._protocol.obfs.server_info.iv) + self._protocol.obfs.server_info.recv_iv = obfs_decode[0][:iv_len] + data = self._encryptor.decrypt(obfs_decode[0]) + try: + data = self._protocol.client_post_decrypt(data) + if self._recv_pack_id == 1: + self._tcp_mss = self._protocol.get_server_info().tcp_mss + except Exception as e: + shell.print_exception(e) + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + return + else: + if self._encrypt_correct: + data = self._protocol.server_pre_encrypt(data) + data = self._encryptor.encrypt(data) + data = self._obfs.server_encode(data) + self._server.add_transfer_d(self._user, len(data)) + self._update_activity(len(data)) + else: + return + try: + self._write_to_sock(data, self._local_sock) + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + logging.error("exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + + def _on_local_write(self): + # handle local writable event + if self._data_to_write_to_local: + data = b''.join(self._data_to_write_to_local) + self._data_to_write_to_local = [] + self._write_to_sock(data, self._local_sock) + else: + self._update_stream(STREAM_DOWN, WAIT_STATUS_READING) + + def _on_remote_write(self): + # handle remote writable event + self._stage = STAGE_STREAM + if self._data_to_write_to_remote: + data = b''.join(self._data_to_write_to_remote) + self._data_to_write_to_remote = [] + self._write_to_sock(data, self._remote_sock) + else: + self._update_stream(STREAM_UP, WAIT_STATUS_READING) + + def _on_local_error(self): + if self._local_sock: + err = eventloop.get_sock_error(self._local_sock) + if err.errno not in [errno.ECONNRESET, errno.EPIPE]: + logging.error(err) + logging.error("local error, exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + + def _on_remote_error(self): + if self._remote_sock: + err = eventloop.get_sock_error(self._remote_sock) + if err.errno not in [errno.ECONNRESET]: + logging.error(err) + if self._remote_address: + logging.error("remote error, when connect to %s:%d" % (self._remote_address[0], self._remote_address[1])) + else: + logging.error("remote error, exception from %s:%d" % (self._client_address[0], self._client_address[1])) + self.destroy() + + def handle_event(self, sock, fd, event): + # handle all events in this handler and dispatch them to methods + handle = False + if self._stage == STAGE_DESTROYED: + logging.debug('ignore handle_event: destroyed') + return True + if self._user is not None and self._user not in self._server.server_users: + self.destroy() + return True + if fd == self._remote_sock_fd or fd == self._remotev6_sock_fd: + if event & eventloop.POLL_ERR: + handle = True + self._on_remote_error() + elif event & (eventloop.POLL_IN | eventloop.POLL_HUP): + if not self.speed_tester_d.isExceed() and not self._server.speed_tester_d(self._user_id).isExceed(): + handle = True + self._on_remote_read(sock == self._remote_sock) + else: + self._recv_d_max_size = self._tcp_mss - self._overhead + elif event & eventloop.POLL_OUT: + handle = True + self._on_remote_write() + elif fd == self._local_sock_fd: + if event & eventloop.POLL_ERR: + handle = True + self._on_local_error() + elif event & (eventloop.POLL_IN | eventloop.POLL_HUP): + if not self.speed_tester_u.isExceed() and not self._server.speed_tester_u(self._user_id).isExceed(): + handle = True + self._on_local_read() + else: + self._recv_u_max_size = self._tcp_mss - self._overhead + elif event & eventloop.POLL_OUT: + handle = True + self._on_local_write() + else: + logging.warn('unknown socket from %s:%d' % (self._client_address[0], self._client_address[1])) + try: + self._loop.removefd(fd) + except Exception as e: + shell.print_exception(e) + try: + del self._fd_to_handlers[fd] + except Exception as e: + shell.print_exception(e) + sock.close() + + return handle + + def _log_error(self, e): + logging.error('%s when handling connection from %s:%d' % + (e, self._client_address[0], self._client_address[1])) + + def stage(self): + return self._stage + + def destroy(self): + # destroy the handler and release any resources + # promises: + # 1. destroy won't make another destroy() call inside + # 2. destroy releases resources so it prevents future call to destroy + # 3. destroy won't raise any exceptions + # if any of the promises are broken, it indicates a bug has been + # introduced! mostly likely memory leaks, etc + if self._stage == STAGE_DESTROYED: + # this couldn't happen + logging.debug('already destroyed') + return + self._stage = STAGE_DESTROYED + if self._remote_address: + logging.debug('destroy: %s:%d' % + self._remote_address) + else: + logging.debug('destroy') + if self._remote_sock: + logging.debug('destroying remote') + try: + self._loop.removefd(self._remote_sock_fd) + except Exception as e: + shell.print_exception(e) + try: + if self._remote_sock_fd is not None: + del self._fd_to_handlers[self._remote_sock_fd] + except Exception as e: + shell.print_exception(e) + self._remote_sock.close() + self._remote_sock = None + if self._remote_sock_v6: + logging.debug('destroying remote_v6') + try: + self._loop.removefd(self._remotev6_sock_fd) + except Exception as e: + shell.print_exception(e) + try: + if self._remotev6_sock_fd is not None: + del self._fd_to_handlers[self._remotev6_sock_fd] + except Exception as e: + shell.print_exception(e) + self._remote_sock_v6.close() + self._remote_sock_v6 = None + if self._local_sock: + logging.debug('destroying local') + try: + self._loop.removefd(self._local_sock_fd) + except Exception as e: + shell.print_exception(e) + try: + if self._local_sock_fd is not None: + del self._fd_to_handlers[self._local_sock_fd] + except Exception as e: + shell.print_exception(e) + self._local_sock.close() + self._local_sock = None + if self._obfs: + self._obfs.dispose() + self._obfs = None + if self._protocol: + self._protocol.dispose() + self._protocol = None + + if self._encryptor: + self._encryptor.dispose() + self._encryptor = None + self._dns_resolver.remove_callback(self._handle_dns_resolved) + self._server.remove_handler(self) + if self._add_ref > 0: + self._server.add_connection(-1) + self._server.stat_add(self._client_address[0], -1) + + #import gc + #gc.collect() + #logging.debug("gc %s" % (gc.garbage,)) + +class TCPRelay(object): + def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None): + self._config = config + self._is_local = is_local + self._dns_resolver = dns_resolver + self._closed = False + self._eventloop = None + self._fd_to_handlers = {} + self.server_transfer_ul = 0 + self.server_transfer_dl = 0 + self.server_users = {} + self.server_users_cfg = {} + self.server_user_transfer_ul = {} + self.server_user_transfer_dl = {} + self.mu = False + self._speed_tester_u = {} + self._speed_tester_d = {} + self.server_connections = 0 + self.protocol_data = obfs.obfs(config['protocol']).init_data() + self.obfs_data = obfs.obfs(config['obfs']).init_data() + + if config.get('connect_verbose_info', 0) > 0: + common.connect_log = logging.info + + self._timeout = config['timeout'] + self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout, + close_callback=self._close_tcp_client) + + if is_local: + listen_addr = config['local_address'] + listen_port = config['local_port'] + else: + listen_addr = config['server'] + listen_port = config['server_port'] + self._listen_port = listen_port + + if common.to_str(config['protocol']) in obfs.mu_protocol(): + self._update_users(None, None) + + addrs = socket.getaddrinfo(listen_addr, listen_port, 0, + socket.SOCK_STREAM, socket.SOL_TCP) + if len(addrs) == 0: + raise Exception("can't get addrinfo for %s:%d" % + (listen_addr, listen_port)) + af, socktype, proto, canonname, sa = addrs[0] + server_socket = socket.socket(af, socktype, proto) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_socket.bind(sa) + server_socket.setblocking(False) + if config['fast_open']: + try: + server_socket.setsockopt(socket.SOL_TCP, 23, 5) + except socket.error: + logging.error('warning: fast open is not available') + self._config['fast_open'] = False + server_socket.listen(config.get('max_connect', 1024)) + self._server_socket = server_socket + self._server_socket_fd = server_socket.fileno() + self._stat_counter = stat_counter + self._stat_callback = stat_callback + + def add_to_loop(self, loop): + if self._eventloop: + raise Exception('already add to loop') + if self._closed: + raise Exception('already closed') + self._eventloop = loop + self._eventloop.add(self._server_socket, + eventloop.POLL_IN | eventloop.POLL_ERR, self) + self._eventloop.add_periodic(self.handle_periodic) + + def remove_handler(self, client): + if hash(client) in self._timeout_cache: + del self._timeout_cache[hash(client)] + + def add_connection(self, val): + self.server_connections += val + logging.debug('server port %5d connections = %d' % (self._listen_port, self.server_connections,)) + + def get_ud(self): + return (self.server_transfer_ul, self.server_transfer_dl) + + def get_users_ud(self): + return (self.server_user_transfer_ul.copy(), self.server_user_transfer_dl.copy()) + + def _update_users(self, protocol_param, acl): + if protocol_param is None: + protocol_param = self._config['protocol_param'] + param = common.to_bytes(protocol_param).split(b'#') + if len(param) == 2: + self.mu = True + user_list = param[1].split(b',') + if user_list: + for user in user_list: + items = user.split(b':') + if len(items) == 2: + user_int_id = int(items[0]) + uid = struct.pack('= stat_dict.get(-1, 0) + connections_step: + logging.info('port %d connections up to %d' % (port, newval)) + stat_dict[-1] = stat_dict.get(-1, 0) + connections_step + elif newval <= stat_dict.get(-1, 0) - connections_step: + logging.info('port %d connections down to %d' % (port, newval)) + stat_dict[-1] = stat_dict.get(-1, 0) - connections_step + + def stat_add(self, local_addr, val): + if self._stat_counter is not None: + if self._listen_port not in self._stat_counter: + self._stat_counter[self._listen_port] = {} + newval = self._stat_counter[self._listen_port].get(local_addr, 0) + val + logging.debug('port %d addr %s connections %d' % (self._listen_port, local_addr, newval)) + self._stat_counter[self._listen_port][local_addr] = newval + self.update_stat(self._listen_port, self._stat_counter[self._listen_port], val) + if newval <= 0: + if local_addr in self._stat_counter[self._listen_port]: + del self._stat_counter[self._listen_port][local_addr] + + newval = self._stat_counter.get(0, 0) + val + self._stat_counter[0] = newval + logging.debug('Total connections %d' % newval) + + connections_step = 50 + if newval >= self._stat_counter.get(-1, 0) + connections_step: + logging.info('Total connections up to %d' % newval) + self._stat_counter[-1] = self._stat_counter.get(-1, 0) + connections_step + elif newval <= self._stat_counter.get(-1, 0) - connections_step: + logging.info('Total connections down to %d' % newval) + self._stat_counter[-1] = self._stat_counter.get(-1, 0) - connections_step + + def update_activity(self, client, data_len): + if data_len and self._stat_callback: + self._stat_callback(self._listen_port, data_len) + + self._timeout_cache[hash(client)] = client + + def _sweep_timeout(self): + self._timeout_cache.sweep() + + def _close_tcp_client(self, client): + if client.remote_address: + logging.debug('timed out: %s:%d' % + client.remote_address) + else: + logging.debug('timed out') + client.destroy() + + def handle_event(self, sock, fd, event): + # handle events and dispatch to handlers + handle = False + if sock: + logging.log(shell.VERBOSE_LEVEL, 'fd %d %s', fd, + eventloop.EVENT_NAMES.get(event, event)) + if sock == self._server_socket: + if event & eventloop.POLL_ERR: + # TODO + raise Exception('server_socket error') + handler = None + handle = True + try: + logging.debug('accept') + conn = self._server_socket.accept() + handler = TCPRelayHandler(self, self._fd_to_handlers, + self._eventloop, conn[0], self._config, + self._dns_resolver, self._is_local) + if handler.stage() == STAGE_DESTROYED: + conn[0].close() + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + if error_no in (errno.EAGAIN, errno.EINPROGRESS, + errno.EWOULDBLOCK): + return + else: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + if handler: + handler.destroy() + else: + if sock: + handler = self._fd_to_handlers.get(fd, None) + if handler: + handle = handler.handle_event(sock, fd, event) + else: + logging.warn('unknown fd') + handle = True + try: + self._eventloop.removefd(fd) + except Exception as e: + shell.print_exception(e) + sock.close() + else: + logging.warn('poll removed fd') + handle = True + if fd in self._fd_to_handlers: + try: + del self._fd_to_handlers[fd] + except Exception as e: + shell.print_exception(e) + return handle + + def handle_periodic(self): + if self._closed: + if self._server_socket: + self._eventloop.removefd(self._server_socket_fd) + self._server_socket.close() + self._server_socket = None + logging.info('closed TCP port %d', self._listen_port) + for handler in list(self._fd_to_handlers.values()): + handler.destroy() + self._sweep_timeout() + + def close(self, next_tick=False): + logging.debug('TCP close') + self._closed = True + if not next_tick: + if self._eventloop: + self._eventloop.remove_periodic(self.handle_periodic) + self._eventloop.removefd(self._server_socket_fd) + self._server_socket.close() + for handler in list(self._fd_to_handlers.values()): + handler.destroy() diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py new file mode 100644 index 0000000..cebe974 --- /dev/null +++ b/shadowsocks/udprelay.py @@ -0,0 +1,658 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# SOCKS5 UDP Request +# +----+------+------+----------+----------+----------+ +# |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | +# +----+------+------+----------+----------+----------+ +# | 2 | 1 | 1 | Variable | 2 | Variable | +# +----+------+------+----------+----------+----------+ + +# SOCKS5 UDP Response +# +----+------+------+----------+----------+----------+ +# |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | +# +----+------+------+----------+----------+----------+ +# | 2 | 1 | 1 | Variable | 2 | Variable | +# +----+------+------+----------+----------+----------+ + +# shadowsocks UDP Request (before encrypted) +# +------+----------+----------+----------+ +# | ATYP | DST.ADDR | DST.PORT | DATA | +# +------+----------+----------+----------+ +# | 1 | Variable | 2 | Variable | +# +------+----------+----------+----------+ + +# shadowsocks UDP Response (before encrypted) +# +------+----------+----------+----------+ +# | ATYP | DST.ADDR | DST.PORT | DATA | +# +------+----------+----------+----------+ +# | 1 | Variable | 2 | Variable | +# +------+----------+----------+----------+ + +# shadowsocks UDP Request and Response (after encrypted) +# +-------+--------------+ +# | IV | PAYLOAD | +# +-------+--------------+ +# | Fixed | Variable | +# +-------+--------------+ + +# HOW TO NAME THINGS +# ------------------ +# `dest` means destination server, which is from DST fields in the SOCKS5 +# request +# `local` means local server of shadowsocks +# `remote` means remote server of shadowsocks +# `client` means UDP clients that connects to other servers +# `server` means the UDP server that handles user requests + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import time +import socket +import logging +import struct +import errno +import random +import binascii +import traceback +import threading + +from shadowsocks import encrypt, obfs, eventloop, lru_cache, common, shell +from shadowsocks.common import pre_parse_header, parse_header, pack_addr + +# for each handler, we have 2 stream directions: +# upstream: from client to server direction +# read local and write to remote +# downstream: from server to client direction +# read remote and write to local + +STREAM_UP = 0 +STREAM_DOWN = 1 + +# for each stream, it's waiting for reading, or writing, or both +WAIT_STATUS_INIT = 0 +WAIT_STATUS_READING = 1 +WAIT_STATUS_WRITING = 2 +WAIT_STATUS_READWRITING = WAIT_STATUS_READING | WAIT_STATUS_WRITING + +BUF_SIZE = 65536 +DOUBLE_SEND_BEG_IDS = 16 +POST_MTU_MIN = 500 +POST_MTU_MAX = 1400 +SENDING_WINDOW_SIZE = 8192 + +STAGE_INIT = 0 +STAGE_RSP_ID = 1 +STAGE_DNS = 2 +STAGE_CONNECTING = 3 +STAGE_STREAM = 4 +STAGE_DESTROYED = -1 + +CMD_CONNECT = 0 +CMD_RSP_CONNECT = 1 +CMD_CONNECT_REMOTE = 2 +CMD_RSP_CONNECT_REMOTE = 3 +CMD_POST = 4 +CMD_SYN_STATUS = 5 +CMD_POST_64 = 6 +CMD_SYN_STATUS_64 = 7 +CMD_DISCONNECT = 8 + +CMD_VER_STR = b"\x08" + +RSP_STATE_EMPTY = b"" +RSP_STATE_REJECT = b"\x00" +RSP_STATE_CONNECTED = b"\x01" +RSP_STATE_CONNECTEDREMOTE = b"\x02" +RSP_STATE_ERROR = b"\x03" +RSP_STATE_DISCONNECT = b"\x04" +RSP_STATE_REDIRECT = b"\x05" + +def client_key(source_addr, server_af): + # notice this is server af, not dest af + return '%s:%s:%d' % (source_addr[0], source_addr[1], server_af) + +class UDPRelay(object): + def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None): + self._config = config + if config.get('connect_verbose_info', 0) > 0: + common.connect_log = logging.info + if is_local: + self._listen_addr = config['local_address'] + self._listen_port = config['local_port'] + self._remote_addr = config['server'] + self._remote_port = config['server_port'] + else: + self._listen_addr = config['server'] + self._listen_port = config['server_port'] + self._remote_addr = None + self._remote_port = None + self._dns_resolver = dns_resolver + self._password = common.to_bytes(config['password']) + self._method = config['method'] + self._timeout = config['timeout'] + self._is_local = is_local + self._udp_cache_size = config['udp_cache'] + self._cache = lru_cache.LRUCache(timeout=config['udp_timeout'], + close_callback=self._close_client_pair) + self._cache_dns_client = lru_cache.LRUCache(timeout=10, + close_callback=self._close_client_pair) + self._client_fd_to_server_addr = {} + #self._dns_cache = lru_cache.LRUCache(timeout=1800) + self._eventloop = None + self._closed = False + self.server_transfer_ul = 0 + self.server_transfer_dl = 0 + self.server_users = {} + self.server_user_transfer_ul = {} + self.server_user_transfer_dl = {} + + if common.to_bytes(config['protocol']) in obfs.mu_protocol(): + self._update_users(None, None) + + self.protocol_data = obfs.obfs(config['protocol']).init_data() + self._protocol = obfs.obfs(config['protocol']) + server_info = obfs.server_info(self.protocol_data) + server_info.host = self._listen_addr + server_info.port = self._listen_port + server_info.users = self.server_users + server_info.protocol_param = config['protocol_param'] + server_info.obfs_param = '' + server_info.iv = b'' + server_info.recv_iv = b'' + server_info.key_str = common.to_bytes(config['password']) + server_info.key = encrypt.encrypt_key(self._password, self._method) + server_info.head_len = 30 + server_info.tcp_mss = 1452 + server_info.buffer_size = BUF_SIZE + server_info.overhead = 0 + self._protocol.set_server_info(server_info) + + self._sockets = set() + self._fd_to_handlers = {} + self._reqid_to_hd = {} + self._data_to_write_to_server_socket = [] + + self._timeout_cache = lru_cache.LRUCache(timeout=self._timeout, + close_callback=self._close_tcp_client) + + self._bind = config.get('out_bind', '') + self._bindv6 = config.get('out_bindv6', '') + self._ignore_bind_list = config.get('ignore_bind', []) + + if 'forbidden_ip' in config: + self._forbidden_iplist = config['forbidden_ip'] + else: + self._forbidden_iplist = None + if 'forbidden_port' in config: + self._forbidden_portset = config['forbidden_port'] + else: + self._forbidden_portset = None + + addrs = socket.getaddrinfo(self._listen_addr, self._listen_port, 0, + socket.SOCK_DGRAM, socket.SOL_UDP) + if len(addrs) == 0: + raise Exception("can't get addrinfo for %s:%d" % + (self._listen_addr, self._listen_port)) + af, socktype, proto, canonname, sa = addrs[0] + server_socket = socket.socket(af, socktype, proto) + server_socket.bind((self._listen_addr, self._listen_port)) + server_socket.setblocking(False) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 1024) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024) + self._server_socket = server_socket + self._stat_callback = stat_callback + + def _get_a_server(self): + server = self._config['server'] + server_port = self._config['server_port'] + if type(server_port) == list: + server_port = random.choice(server_port) + if type(server) == list: + server = random.choice(server) + logging.debug('chosen server: %s:%d', server, server_port) + return server, server_port + + def get_ud(self): + return (self.server_transfer_ul, self.server_transfer_dl) + + def get_users_ud(self): + ret = (self.server_user_transfer_ul.copy(), self.server_user_transfer_dl.copy()) + return ret + + def _update_users(self, protocol_param, acl): + if protocol_param is None: + protocol_param = self._config['protocol_param'] + param = common.to_bytes(protocol_param).split(b'#') + if len(param) == 2: + user_list = param[1].split(b',') + if user_list: + for user in user_list: + items = user.split(b':') + if len(items) == 2: + user_int_id = int(items[0]) + uid = struct.pack(' header_length + 13 and data[header_length + 4 : header_length + 12] == b"\x00\x01\x00\x00\x00\x00\x00\x00": + is_dns = True + else: + pass + if sa[1] == 53 and is_dns: #DNS + logging.debug("DNS query %s from %s:%d" % (common.to_str(sa[0]), r_addr[0], r_addr[1])) + self._cache_dns_client[key] = (client, uid) + else: + self._cache[key] = (client, uid) + self._client_fd_to_server_addr[client.fileno()] = (r_addr, af) + + self._sockets.add(client.fileno()) + self._eventloop.add(client, eventloop.POLL_IN, self) + + logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets))) + + if uid is not None: + user_id = struct.unpack(' 255: + # drop + return + data = pack_addr(r_addr[0]) + struct.pack('>H', r_addr[1]) + data + ref_iv = [encrypt.encrypt_new_iv(self._method)] + self._protocol.obfs.server_info.iv = ref_iv[0] + data = self._protocol.server_udp_pre_encrypt(data, client_uid) + response = encrypt.encrypt_all_iv(self._protocol.obfs.server_info.key, self._method, 1, + data, ref_iv) + if not response: + return + else: + ref_iv = [0] + data = encrypt.encrypt_all_iv(self._protocol.obfs.server_info.key, self._method, 0, + data, ref_iv) + if not data: + return + self._protocol.obfs.server_info.recv_iv = ref_iv[0] + data = self._protocol.client_udp_post_decrypt(data) + header_result = parse_header(data) + if header_result is None: + return + #connecttype, dest_addr, dest_port, header_length = header_result + #logging.debug('UDP handle_client %s:%d to %s:%d' % (common.to_str(r_addr[0]), r_addr[1], dest_addr, dest_port)) + + response = b'\x00\x00\x00' + data + + if client_addr: + if client_uid: + self.add_transfer_d(client_uid, len(response)) + else: + self.server_transfer_dl += len(response) + self.write_to_server_socket(response, client_addr[0]) + if client_dns_pair: + logging.debug("remove dns client %s:%d" % (client_addr[0][0], client_addr[0][1])) + del self._cache_dns_client[key] + self._close_client(client_dns_pair[0]) + else: + # this packet is from somewhere else we know + # simply drop that packet + pass + + def write_to_server_socket(self, data, addr): + uncomplete = False + retry = 0 + try: + self._server_socket.sendto(data, addr) + data = None + while self._data_to_write_to_server_socket: + data_buf = self._data_to_write_to_server_socket[0] + retry = data_buf[1] + 1 + del self._data_to_write_to_server_socket[0] + data, addr = data_buf[0] + self._server_socket.sendto(data, addr) + except (OSError, IOError) as e: + error_no = eventloop.errno_from_exception(e) + uncomplete = True + if error_no in (errno.EWOULDBLOCK,): + pass + else: + shell.print_exception(e) + return False + #if uncomplete and data is not None and retry < 3: + # self._data_to_write_to_server_socket.append([(data, addr), retry]) + #''' + + def add_to_loop(self, loop): + if self._eventloop: + raise Exception('already add to loop') + if self._closed: + raise Exception('already closed') + self._eventloop = loop + + server_socket = self._server_socket + self._eventloop.add(server_socket, + eventloop.POLL_IN | eventloop.POLL_ERR, self) + loop.add_periodic(self.handle_periodic) + + def remove_handler(self, client): + if hash(client) in self._timeout_cache: + del self._timeout_cache[hash(client)] + + def update_activity(self, client): + self._timeout_cache[hash(client)] = client + + def _sweep_timeout(self): + self._timeout_cache.sweep() + + def _close_tcp_client(self, client): + if client.remote_address: + logging.debug('timed out: %s:%d' % + client.remote_address) + else: + logging.debug('timed out') + client.destroy() + client.destroy_local() + + def handle_event(self, sock, fd, event): + if sock == self._server_socket: + if event & eventloop.POLL_ERR: + logging.error('UDP server_socket err') + try: + self._handle_server() + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + elif sock and (fd in self._sockets): + if event & eventloop.POLL_ERR: + logging.error('UDP client_socket err') + try: + self._handle_client(sock) + except Exception as e: + shell.print_exception(e) + if self._config['verbose']: + traceback.print_exc() + else: + if sock: + handler = self._fd_to_handlers.get(fd, None) + if handler: + handler.handle_event(sock, event) + else: + logging.warn('poll removed fd') + + def handle_periodic(self): + if self._closed: + self._cache.clear(0) + self._cache_dns_client.clear(0) + if self._eventloop: + self._eventloop.remove_periodic(self.handle_periodic) + self._eventloop.remove(self._server_socket) + if self._server_socket: + self._server_socket.close() + self._server_socket = None + logging.info('closed UDP port %d', self._listen_port) + else: + before_sweep_size = len(self._sockets) + self._cache.sweep() + self._cache_dns_client.sweep() + if before_sweep_size != len(self._sockets): + logging.debug('UDP port %5d sockets %d' % (self._listen_port, len(self._sockets))) + self._sweep_timeout() + + def close(self, next_tick=False): + logging.debug('UDP close') + self._closed = True + if not next_tick: + if self._eventloop: + self._eventloop.remove_periodic(self.handle_periodic) + self._eventloop.remove(self._server_socket) + self._server_socket.close() + self._cache.clear(0) + self._cache_dns_client.clear(0) diff --git a/shadowsocks/version.py b/shadowsocks/version.py new file mode 100644 index 0000000..7454f3f --- /dev/null +++ b/shadowsocks/version.py @@ -0,0 +1,20 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright 2017 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +def version(): + return 'SSRR 3.2.2 2018-05-22' +