Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/so-saf/websockify
Browse files Browse the repository at this point in the history
  • Loading branch information
CendioOssman committed Aug 29, 2024
2 parents 245fd08 + 0af3404 commit 417210f
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 16 deletions.
124 changes: 115 additions & 9 deletions tests/test_token_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,29 @@
from unittest.mock import patch, mock_open, MagicMock
from jwcrypto import jwt, jwk

from websockify.token_plugins import ReadOnlyTokenFile, JWTTokenApi, TokenRedis
from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis

class ParseSourceArgumentsTestCase(unittest.TestCase):
def test_parameterized(self):
params = [
('', ['']),
(':', ['', '']),
('::', ['', '', '']),
('"', ['"']),
('""', ['""']),
('"""', ['"""']),
('"localhost"', ['localhost']),
('"localhost":', ['localhost', '']),
('"localhost"::', ['localhost', '', '']),
('"local:host"', ['local:host']),
('"local:host:"pass"', ['"local', 'host', "pass"]),
('"local":"host"', ['local', 'host']),
('"local":host"', ['local', 'host"']),
('localhost:6379:1:pass"word:"my-app-namespace:dev"',
['localhost', '6379', '1', 'pass"word', 'my-app-namespace:dev']),
]
for src, args in params:
self.assertEqual(args, parse_source_args(src))

class ReadOnlyTokenFileTestCase(unittest.TestCase):
patch('os.path.isdir', MagicMock(return_value=False))
Expand Down Expand Up @@ -267,13 +289,50 @@ def test_invalid_token(self, mock_redis):
instance.get.assert_called_once_with('testhost')
self.assertIsNone(result)

@patch('redis.Redis')
def test_token_without_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234')
token = 'testhost'

def mock_redis_get(key):
self.assertEqual(key, token)
return b'remote_host:remote_port'

instance = mock_redis.return_value
instance.get = mock_redis_get

result = plugin.lookup(token)

self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')

@patch('redis.Redis')
def test_token_with_namespace(self, mock_redis):
plugin = TokenRedis('127.0.0.1:1234:::namespace')
token = 'testhost'

def mock_redis_get(key):
self.assertEqual(key, "namespace:" + token)
return b'remote_host:remote_port'

instance = mock_redis.return_value
instance.get = mock_redis_get

result = plugin.lookup(token)

self.assertIsNotNone(result)
self.assertEqual(result[0], 'remote_host')
self.assertEqual(result[1], 'remote_port')

def test_src_only_host(self):
plugin = TokenRedis('127.0.0.1')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")

def test_src_with_host_port(self):
plugin = TokenRedis('127.0.0.1:1234')
Expand All @@ -282,6 +341,7 @@ def test_src_with_host_port(self):
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")

def test_src_with_host_port_db(self):
plugin = TokenRedis('127.0.0.1:1234:2')
Expand All @@ -290,6 +350,7 @@ def test_src_with_host_port_db(self):
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")

def test_src_with_host_port_db_pass(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret')
Expand All @@ -298,67 +359,112 @@ def test_src_with_host_port_db_pass(self):
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "")

def test_src_with_host_empty_port_empty_db_pass(self):
def test_src_with_host_port_db_pass_namespace(self):
plugin = TokenRedis('127.0.0.1:1234:2:verysecret:namespace')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "namespace:")

def test_src_with_host_empty_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::verysecret')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "")

def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self):
plugin = TokenRedis('127.0.0.1::::')

def test_src_with_host_empty_port_empty_db_empty_pass(self):
self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")

def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:::')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")

def test_src_with_host_empty_port_empty_db_no_pass(self):
def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")

def test_src_with_host_empty_port_no_db_no_pass(self):
def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")

def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self):
plugin = TokenRedis('127.0.0.1::::namespace')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "namespace:")

def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self):
plugin = TokenRedis('127.0.0.1::::"ns1:ns2"')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "ns1:ns2:")

def test_src_with_host_empty_port_db_no_pass(self):
def test_src_with_host_empty_port_db_no_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")

def test_src_with_host_port_empty_db_pass(self):
def test_src_with_host_port_empty_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1:1234::verysecret')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 1234)
self.assertEqual(plugin._db, 0)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "")

def test_src_with_host_empty_port_db_pass(self):
def test_src_with_host_empty_port_db_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:verysecret')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, 'verysecret')
self.assertEqual(plugin._namespace, "")

def test_src_with_host_empty_port_db_empty_pass(self):
def test_src_with_host_empty_port_db_empty_pass_no_namespace(self):
plugin = TokenRedis('127.0.0.1::2:')

self.assertEqual(plugin._server, '127.0.0.1')
self.assertEqual(plugin._port, 6379)
self.assertEqual(plugin._db, 2)
self.assertEqual(plugin._password, None)
self.assertEqual(plugin._namespace, "")
55 changes: 48 additions & 7 deletions websockify/token_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@

logger = logging.getLogger(__name__)

_SOURCE_SPLIT_REGEX = re.compile(
r'(?<=^)"([^"]+)"(?=:|$)'
r'|(?<=:)"([^"]+)"(?=:|$)'
r'|(?<=^)([^:]*)(?=:|$)'
r'|(?<=:)([^:]*)(?=:|$)',
)


def parse_source_args(src):
"""It works like src.split(":") but with the ability to use a colon
if you wrap the word in quotation marks.
a:b:c:d -> ['a', 'b', 'c', 'd'
a:"b:c":c -> ['a', 'b:c', 'd']
"""
matches = _SOURCE_SPLIT_REGEX.findall(src)
return [m[0] or m[1] or m[2] or m[3] for m in matches]


class BasePlugin():
def __init__(self, src):
Expand Down Expand Up @@ -178,9 +196,9 @@ class TokenRedis(BasePlugin):
The token source is in the format:
host[:port[:db[:password]]]
host[:port[:db[:password[:namespace]]]]
where port, db and password are optional. If port or db are left empty
where port, db, password and namespace are optional. If port or db are left empty
they will take its default value, ie. 6379 and 0 respectively.
If your redis server is using the default port (6379) then you can use:
Expand All @@ -192,9 +210,18 @@ class TokenRedis(BasePlugin):
my-redis-host:::verysecretpass
You can also specify a namespace. In this case, the tokens
will be stored in the format '{namespace}:{token}'
my-redis-host::::my-app-namespace
Or if your namespace is nested, you can wrap it in quotes:
my-redis-host::::"first-ns:second-ns"
In the more general case you will use:
my-redis-host:6380:1:verysecretpass
my-redis-host:6380:1:verysecretpass:my-app-namespace
The TokenRedis plugin expects the format of the target in one of these two
formats:
Expand Down Expand Up @@ -234,8 +261,9 @@ def __init__(self, src):
self._port = 6379
self._db = 0
self._password = None
self._namespace = ""
try:
fields = src.split(":")
fields = parse_source_args(src)
if len(fields) == 1:
self._server = fields[0]
elif len(fields) == 2:
Expand All @@ -256,15 +284,28 @@ def __init__(self, src):
self._db = 0
if not self._password:
self._password = None
elif len(fields) == 5:
self._server, self._port, self._db, self._password, self._namespace = fields
if not self._port:
self._port = 6379
if not self._db:
self._db = 0
if not self._password:
self._password = None
if not self._namespace:
self._namespace = ""
else:
raise ValueError
self._port = int(self._port)
self._db = int(self._db)
logger.info("TokenRedis backend initilized (%s:%s)" %
if self._namespace:
self._namespace += ":"

logger.info("TokenRedis backend initialized (%s:%s)" %
(self._server, self._port))
except ValueError:
logger.error("The provided --token-source='%s' is not in the "
"expected format <host>[:<port>[:<db>[:<password>]]]" %
"expected format <host>[:<port>[:<db>[:<password>[:<namespace>]]]]" %
src)
sys.exit()

Expand All @@ -278,7 +319,7 @@ def lookup(self, token):
logger.info("resolving token '%s'" % token)
client = redis.Redis(host=self._server, port=self._port,
db=self._db, password=self._password)
stuff = client.get(token)
stuff = client.get(self._namespace + token)
if stuff is None:
return None
else:
Expand Down

0 comments on commit 417210f

Please sign in to comment.