diff --git a/.coveragerc b/.coveragerc index 3fbea75..37b481a 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,6 +1,6 @@ [run] branch = True -source = saltyrtc +source = saltyrtc.server [report] exclude_lines = diff --git a/.travis-install-libsodium.sh b/.travis-install-libsodium.sh deleted file mode 100755 index f6dba9f..0000000 --- a/.travis-install-libsodium.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash -set -ev - -if [ -d "${HOME}/libsodium/lib" ]; then - exit 0; -fi - -cd ${HOME} -git clone --depth 1 -b stable https://github.com/jedisct1/libsodium.git libsodium-git -cd libsodium-git -./autogen.sh -./configure --prefix=${HOME}/libsodium && make && make install - -exit 0; diff --git a/.travis.yml b/.travis.yml index 33bf829..4af63a5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,15 +1,22 @@ # Use container system sudo: false -# Cache pip & libsodium +# Install APT dependencies +addons: + apt: + sources: + - sourceline: "ppa:chris-lea/libsodium" + packages: + - libsodium-dev + +# Cache pip cache: directories: - - $HOME/libsodium - $HOME/.cache/pip # Clean up pip log before_cache: - - rm -f $HOME/.cache/pip/log/debug.log +- rm -f $HOME/.cache/pip/log/debug.log # Build matrix language: python @@ -17,57 +24,72 @@ matrix: include: - python: "3.4" env: - - EVENT_LOOP=asyncio - - TIMEOUT=2.0 + - EVENT_LOOP=asyncio + - TIMEOUT=2.0 + - IGNORE=tests/test_cli.py - python: "3.5" env: - - EVENT_LOOP=asyncio - - TIMEOUT=2.0 + - EVENT_LOOP=asyncio + - TIMEOUT=2.0 - python: "3.5" env: - - EVENT_LOOP=uvloop - - TIMEOUT=2.0 + - EVENT_LOOP=uvloop + - TIMEOUT=2.0 before_script: "pip install .[uvloop]" - python: "3.6" env: - - EVENT_LOOP=asyncio - - TIMEOUT=2.0 + - EVENT_LOOP=asyncio + - TIMEOUT=2.0 - python: "3.6" env: - - EVENT_LOOP=uvloop - - TIMEOUT=2.0 + - EVENT_LOOP=uvloop + - TIMEOUT=2.0 before_script: "pip install .[uvloop]" +# TODO: Enable once 3.7 support has been added +# - python: "3.7" +# env: +# - EVENT_LOOP=asyncio +# - TIMEOUT=2.0 +# - python: "3.7" +# env: +# - EVENT_LOOP=uvloop +# - TIMEOUT=2.0 +# before_script: "pip install .[uvloop]" - python: "pypy3" # 2017-08-05: It's pypy3-5.8.0-beta env: - - EVENT_LOOP=asyncio - - TIMEOUT=16.0 + - EVENT_LOOP=asyncio + - TIMEOUT=16.0 # TODO: Re-enable once pypy3 is able to compile uvloop # - python: "pypy3" # 2017-08-05: It's pypy3-5.8.0-beta # env: -# - EVENT_LOOP=uvloop -# - TIMEOUT=16.0 +# - EVENT_LOOP=uvloop +# - TIMEOUT=16.0 # before_script: "pip install .[uvloop]" # Install dependencies -before_install: - - ./.travis-install-libsodium.sh - - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/libsodium/lib install: - - pip install -U setuptools pip - - "pip install .[dev]" - - pip install codecov +- pip install -U setuptools pip +- "pip install .[dev]" +- pip install codecov # Run flake8, isort, check docs & tests +# TODO: Re-enable isort once #85 has been resolved script: - - > - if [[ "$TRAVIS_PYTHON_VERSION" != "pypy3" ]]; then - flake8 . || travis_terminate 1; - isort -rc -c . || (isort -rc -df . && return 1) || travis_terminate 1; - fi - - python setup.py checkdocs - - py.test --cov-config .coveragerc --cov=saltyrtc.server --loop=$EVENT_LOOP --timeout=$TIMEOUT +- > + if [[ "$TRAVIS_PYTHON_VERSION" != "pypy3" ]]; then + flake8 . || travis_terminate 1; + #isort -rc -c . || (isort -rc -df . 
&& return 1) || travis_terminate 1; + fi +- python setup.py checkdocs +- > + py.test \ + --ignore=$IGNORE \ + --cov-config=.coveragerc \ + --cov=saltyrtc.server \ + --loop=$EVENT_LOOP \ + --timeout=$TIMEOUT # After success after_success: - - codecov +- codecov diff --git a/saltyrtc/server/bin.py b/saltyrtc/server/bin.py index c24fcca..0ec7612 100755 --- a/saltyrtc/server/bin.py +++ b/saltyrtc/server/bin.py @@ -272,3 +272,7 @@ def main(): finally: if obj['logging_handler'] is not None: obj['logging_handler'].pop_application() + + +if __name__ == '__main__': + main() diff --git a/saltyrtc/server/common.py b/saltyrtc/server/common.py index 2edb4f6..a9e6585 100644 --- a/saltyrtc/server/common.py +++ b/saltyrtc/server/common.py @@ -21,6 +21,7 @@ 'OverflowSentinel', 'SubProtocol', 'CloseCode', + 'ClientState', 'AddressType', 'MessageType', 'available_slot_range', @@ -82,12 +83,44 @@ class CloseCode(enum.IntEnum): initiator_could_not_decrypt = 3005 no_shared_tasks = 3006 invalid_key = 3007 + timeout = 3008 @property def is_valid_drop_reason(self): return self.value in _drop_reasons +@enum.unique +class ClientState(enum.IntEnum): + """ + The state of a :class:`PathClient`. + + .. important:: States MUST follow the exact order as enumerated + below. A client cannot go back a state or skip + a state in between. For example, a *dropped* client + MUST have been formerly *authenticated*. + """ + # The client is connected but is not allowed to communicate + # with another client. + restricted = 1 + + # The client has been authenticated and may communicate with + # other clients (of different type). + authenticated = 2 + + # The client has been dropped by another client. + dropped = 3 + + @property + def next(self): + """ + Return the subsequent state. + + Raises :exc:`ValueError` in case there is no subsequent state. 
+ """ + return ClientState(self + 1) + + @enum.unique class AddressType(enum.IntEnum): server = 0x00 diff --git a/saltyrtc/server/message.py b/saltyrtc/server/message.py index 1805e9b..18d977f 100644 --- a/saltyrtc/server/message.py +++ b/saltyrtc/server/message.py @@ -14,6 +14,7 @@ NONCE_FORMATTER, NONCE_LENGTH, AddressType, + ClientState, CloseCode, MessageType, OverflowSentinel, @@ -192,8 +193,8 @@ def pack(self, client): # Encrypt payload if required if self.encrypted: - if not client.authenticated: - raise MessageFlowError('Cannot encrypt payload, no box available') + if client.state != ClientState.authenticated: + raise MessageFlowError('Cannot encrypt payload, not authenticated') payload = self._encrypt_payload(client, nonce, payload) # Append payload and return as bytes @@ -224,7 +225,7 @@ def unpack(cls, client, data): expect_type = None if destination_type == AddressType.server: data = data[NONCE_LENGTH:] - if not client.authenticated and client.type is None: + if client.state == ClientState.restricted and client.type is None: payload = None # Try client-auth (encrypted) diff --git a/saltyrtc/server/protocol.py b/saltyrtc/server/protocol.py index af08f15..9dcf53e 100644 --- a/saltyrtc/server/protocol.py +++ b/saltyrtc/server/protocol.py @@ -1,4 +1,5 @@ import asyncio +import enum import os import struct @@ -14,6 +15,7 @@ KEEP_ALIVE_TIMEOUT, KEY_LENGTH, AddressType, + ClientState, OverflowSentinel, available_slot_range, is_initiator_id, @@ -24,18 +26,25 @@ InternalError, MessageError, MessageFlowError, - SignalingError, SlotsFullError, ) from .message import unpack __all__ = ( 'Path', + 'PathClientTasks', 'PathClient', 'Protocol', ) +@enum.unique +class _TaskQueueState(enum.IntEnum): + open = 1 + closed = 2 + cancelled = 3 + + class Path: __slots__ = ('_slots', 'log', 'initiator_key', 'number', 'attached') @@ -49,26 +58,35 @@ def __init__(self, initiator_key, number, attached=True): @property def empty(self): """ - Return `True` in case the path is empty. A call to this property - will also remove clients from the path whose connections are - closed but have not been removed from the path. (However, in - case that the path is not empty, this property does not ensure - that all disconnected clients will be removed.) - """ - for client in self._slots.values(): - if client is not None: - if client.connection_closed.done(): - self.remove_client(client) - self.log.notice('Removed dead client {}', client) - else: - return False - return True + Return whether the path is empty. + """ + return all((client is None for client in self._slots.values())) + + def has_client(self, client): + """ + Return whether a client's :class:`PathClient` instance is still + available on the path. + + Arguments: + - `client`: The :class:`PathClient` instance to look for. + + Raises :exc:`KeyError` in case the client has not been assigned + an ID yet. + """ + # Note: No need to check for an unassigned ID since the server's ID will never + # be available in the slots. + return self._slots[client.id] == client def get_initiator(self): """ - Return the initiator's :class:`PathClient` instance or `None`. + Return the initiator's :class:`PathClient` instance. + + Raises :exc:`KeyError` if there is no initiator. 
""" - return self._slots.get(AddressType.initiator) + client = self._slots[AddressType.initiator] + if client is None: + raise KeyError('No initiator found') + return client def set_initiator(self, initiator): """ @@ -77,6 +95,9 @@ def set_initiator(self, initiator): Arguments: - `initiator`: A :class:`PathClient` instance. + Raises :exc:`ValueError` in case of a state violation on the + :class:`PathClient`. + Return the previously set initiator or `None`. """ previous_initiator = self._slots.get(AddressType.initiator) @@ -85,31 +106,36 @@ def set_initiator(self, initiator): # Update initiator's log name initiator.update_log_name(AddressType.initiator) # Authenticated, assign id - initiator.authenticated = True - initiator.id = AddressType.initiator + initiator.authenticate(AddressType.initiator) # Return previous initiator return previous_initiator def get_responder(self, id_): """ - Return a responder's :class:`PathClient` instance or `None`. + Return a responder's :class:`PathClient` instance. Arguments: - `id_`: The identifier of the responder. - Raises :exc:`ValueError` if `id_` is not a valid responder - identifier. + Raises: + - :exc:`ValueError`: If `id_` is not a valid responder + identifier. + - :exc:`KeyError`: If `id_` cannot be associated to a + :class:`PathClient` instance. """ if not is_responder_id(id_): raise ValueError('Invalid responder identifier') - return self._slots.get(id_) + client = self._slots[id_] + if client is None: + raise KeyError('No responder found') + return client def get_responder_ids(self): """ - Return a list of responder's identifiers (slots). + Return an iterable of responder's identifiers (slots). """ - return [id_ for id_, responder in self._slots.items() - if is_responder_id(id_) and responder is not None] + return (id_ for id_, responder in self._slots.items() + if is_responder_id(id_) and responder is not None) def add_responder(self, responder): """ @@ -118,7 +144,10 @@ def add_responder(self, responder): Arguments: - `client`: A :class:`PathClient` instance. - Raises :exc:`SlotsFullError` if no free slot exists on the path. + Raises: + - :exc:`SlotsFullError` if no free slot exists on the path. + - :exc:`ValueError` in case of a state violation on the + :class:`PathClient`. Return the assigned slot identifier. """ @@ -129,8 +158,7 @@ def add_responder(self, responder): # Update responder's log name responder.update_log_name(id_) # Authenticated, set and return assigned slot id - responder.authenticated = True - responder.id = id_ + responder.authenticate(id_) return id_ raise SlotsFullError('No free slots on path') @@ -139,26 +167,23 @@ def remove_client(self, client): Remove a client (initiator or responder) from the :class:`Path`. + .. important:: Shall only be called from the client's + own :class:`Protocol` instance or from another client's + :class.`Protocol` instance in case it is dropping a client. + Arguments: - `client`: The :class:`PathClient` instance. - Raises :exc:`ValueError` in case the client provided an + Raises :exc:`KeyError` in case the client provided an invalid slot identifier. """ - - if not client.authenticated: - # Client has not been authenticated. Nothing to do. + if client.state == ClientState.restricted: + # Client has never been authenticated. Nothing to do. 
return id_ = client.id - # Get client instance - try: - slot_client = self._slots[id_] - except KeyError: - raise ValueError('Invalid slot identifier: {}'.format(id_)) - # Compare client instances - if client != slot_client: + if client != self._slots[id_]: # Note: This is absolutely fine and happens when another initiator # takes the place of a previous initiator. return @@ -168,10 +193,66 @@ def remove_client(self, client): self.log.debug('Removed {}', 'initiator' if is_initiator_id(id_) else 'responder') +# TODO: We should be able to use a NamedTuple for this once we drop Python 3.4 support +class PathClientTasks: + __slots__ = ( + 'task_loop', + 'receive_loop', + 'keep_alive_loop', + ) + + def __init__( + self, + task_loop=None, receive_loop=None, keep_alive_loop=None, + loop=None + ): + if loop is None: + asyncio.get_event_loop() + self.task_loop = self._ensure_future_or_none(task_loop, loop) + self.receive_loop = self._ensure_future_or_none(receive_loop, loop=loop) + self.keep_alive_loop = self._ensure_future_or_none(keep_alive_loop, loop=loop) + + @property + def tasks(self): + """ + Return all tasks (including those who are set to `None`) as a + tuple. + """ + return ( + self.task_loop, + self.receive_loop, + self.keep_alive_loop, + ) + + @property + def valid(self): + """ + Return all valid tasks (i.e. those who are not set to `None`) + as an iterable. + """ + return (task for task in self.tasks if task is not None) + + def cancel_all_but_task_loop(self): + """ + Cancel all valid tasks but the task queue. + """ + for task in self.valid: + if task != self.task_loop: + task.cancel() + + @staticmethod + def _ensure_future_or_none(coroutine_or_task, loop): + if coroutine_or_task is None: + return None + return asyncio.ensure_future(coroutine_or_task, loop=loop) + + class PathClient: __slots__ = ( '_loop', + '_state', '_connection', + '_connection_closed_future', '_client_key', '_server_permanent_key', '_server_session_key', @@ -187,11 +268,11 @@ class PathClient: '_keep_alive_interval', 'log', 'type', - 'authenticated', 'keep_alive_timeout', 'keep_alive_pings', + 'tasks', '_task_queue', - '_task_queue_closed', + '_task_queue_state', ) def __init__( @@ -199,7 +280,10 @@ def __init__( server_session_key=None, loop=None ): self._loop = asyncio.get_event_loop() if loop is None else loop + self._state = ClientState.restricted self._connection = connection + connection_closed_future = asyncio.Future(loop=self._loop) + self._connection_closed_future = connection_closed_future self._client_key = initiator_key self._server_permanent_key = None self._server_session_key = server_session_key @@ -213,13 +297,18 @@ def __init__( self._keep_alive_interval = KEEP_ALIVE_INTERVAL_DEFAULT self.log = util.get_logger('path.{}.client.{:x}'.format(path_number, id(self))) self.type = None - self.authenticated = False self.keep_alive_timeout = KEEP_ALIVE_TIMEOUT self.keep_alive_pings = 0 + self.tasks = None + + # Schedule connection closed future + def _connection_closed(_): + connection_closed_future.set_result(connection.close_code) + self._connection.connection_lost_waiter.add_done_callback(_connection_closed) # Queue for tasks to be run on the client (relay messages, closing, ...) 
self._task_queue = asyncio.Queue(loop=self._loop) - self._task_queue_closed = False + self._task_queue_state = _TaskQueueState.open def __str__(self): type_ = self.type @@ -229,20 +318,33 @@ def __str__(self): type_, self._id, hex(id(self))) @property - def connection_closed(self): + def state(self): """ - Return the 'connection_closed' future of the underlying - WebSocket connection. + Return the current :class:`ClientState` of the client. """ - return self._connection.connection_closed + return self._state + + @state.setter + def state(self, state): + """ + Update the :class:`ClientState` of the client. + + Raises :exc:`ValueError` in case the state is not following + the strict state order as defined by :class`ClientState`. + """ + if state != self.state.next: + raise ValueError('State {} cannot be updated to {}'.format(self.state, state)) + self.log.debug('State {} -> {}', self._state.name, state.name) + self._state = state @property - def tasks_complete(self): + def connection_closed_future(self): """ - Return whether the underlying task queue is complete (empty). - Will also return `True` in case the task queue has been closed. + Resolves once the connection has been closed. + + Return the close code. """ - return self._task_queue.empty() or self._task_queue_closed + return self._connection_closed_future @property def id(self): @@ -251,14 +353,6 @@ def id(self): """ return self._id - @id.setter - def id(self, id_): - """ - Assign the id. Only :class:`Path` may set the id! - """ - self._id = id_ - self.log.debug('Assigned id: {}', id_) - @property def keep_alive_interval(self): """ @@ -418,6 +512,19 @@ def set_client_key(self, public_key): self._box = libnacl.public.Box(self.server_key, public_key) self.log.debug('Client key updated') + def authenticate(self, id_): + """ + Authenticate the client and assign it an id. + + .. important:: Only :class:`Path` may call this! + + Raises :exc:`ValueError` in case the previous state was not + :attr:`ClientState.restricted`. + """ + self.state = ClientState.authenticated + self._id = id_ + self.log.debug('Assigned id: {}', id_) + def update_log_name(self, slot_id): """ Update the logger's name by the assigned slot identifier. @@ -486,23 +593,38 @@ def p2p_allowed(self, destination_type): Return `True` if :class:`RawMessage` instances are allowed and can be sent to the requested :class:`AddressType`. """ - return self.authenticated and self.type != destination_type + return self.state == ClientState.authenticated and self.type != destination_type @asyncio.coroutine - def enqueue_task(self, coroutine_or_task): + def enqueue_task(self, coroutine_or_task, ignore_closed=False): """ Enqueue a coroutine or task into the task queue of the client. + .. important:: Only the following tasks shall be enqueued: + - Messages from the server towards this client. + - Messages from other clients **towards** this + client (i.e. relayed messages). + - Delayed close operations towards this client. + + .. note:: Coroutines will be closed and :class:`asyncio.Task`s + will be cancelled when the task queue has been closed + (unless `ignore_closed` has been set to `True`) or + cancelled. The coroutine or task must be prepared for + that. + Arguments: - `coroutine_or_task`: A coroutine or a :class:`asyncio.Task`. - - Raises :exc:`SignalingError` in case the task queue is closed. + - `ignore_closed`: Whether the coroutine or + :class:`asyncio.Task` should be enqueued even if the task + queue has been closed. 
""" - if self._task_queue_closed: - raise SignalingError('Task queue is already closed') - yield from self._task_queue.put(coroutine_or_task) + if (self._task_queue_state == _TaskQueueState.open + or (ignore_closed and self._task_queue_state == _TaskQueueState.closed)): + yield from self._task_queue.put(coroutine_or_task) + else: + self._cancel_coroutine_or_task(coroutine_or_task, mark_as_done=False) @asyncio.coroutine def dequeue_task(self): @@ -510,25 +632,109 @@ def dequeue_task(self): Dequeue and return a coroutine or task from the task queue of the client. - Shall only be called from the client's :class:`Protocol` - instance. - - Raises :exc:`SignalingError` in case the task queue is closed. + .. warning:: Shall only be called from the client's + :class:`Protocol` instance. """ - if self._task_queue_closed: - raise SignalingError('Task queue is already closed') return (yield from self._task_queue.get()) + def task_done(self, task): + """ + Mark a previously dequeued task as processed. + + Raises :exc:`InternalError` if called more times than there + were tasks placed in the queue. + """ + self.log.debug('Done task {}', task) + try: + self._task_queue.task_done() + except ValueError: + raise InternalError('More tasks marked as done as were enqueued') + def close_task_queue(self): """ - Close the task queue, so no further tasks can be enqueued. + Close the task queue to prevent further enqueues. Will do + nothing in case the task queue has already been closed or + cancelled. - Raises :exc:`SignalingError` in case the task queue was already - closed. + .. note:: Unlike :func:`~PathClient.cancel_task_queue`, this does + not cancel any pending tasks. """ - if self._task_queue_closed: - raise SignalingError('Task queue is already closed') - self._task_queue_closed = True + # Ignore if already closed or cancelled + if self._task_queue_state >= _TaskQueueState.closed: + return + + # Update state + self._task_queue_state = _TaskQueueState.closed + self.log.debug('Closed task queue') + + def cancel_task_queue(self): + """ + Cancel all pending tasks of the task queue and prevent further + enqueues. Will do nothing in case the task queue has already + been cancelled. + """ + # Ignore if already cancelled + if self._task_queue_state >= _TaskQueueState.cancelled: + return + + # Cancel all pending tasks + # + # Add a 'done' callback to each task in order to mark the task queue as 'closed' + # after all functions, which may want to handle the cancellation, have handled + # that cancellation. + # + # This for example prevents a 'disconnect' message from being sent before a + # 'send-error' message has been sent, see: + # https://github.com/saltyrtc/saltyrtc-server-python/issues/77 + self._task_queue_state = _TaskQueueState.cancelled + self.log.debug('Cancelling {} queued tasks', self._task_queue.qsize()) + while True: + try: + coroutine_or_task = self._task_queue.get_nowait() + except asyncio.QueueEmpty: + break + self._cancel_coroutine_or_task(coroutine_or_task, mark_as_done=True) + + def _cancel_coroutine_or_task(self, coroutine_or_task, mark_as_done=False): + """ + Cancel a coroutine or a :class:`asyncio.Task`. + + Arguments: + - `coroutine_or_task`: The coroutine or + :class:`asyncio.Task` to be cancelled. + - `mark_as_done`: Whether to mark the task as *processed* + on the task queue. Defaults to `False`. 
+ """ + if asyncio.iscoroutine(coroutine_or_task): + self.log.debug('Closing queued coroutine {}', coroutine_or_task) + coroutine_or_task.close() + if mark_as_done: + self.task_done(coroutine_or_task) + else: + if mark_as_done: + coroutine_or_task.add_done_callback(self.task_done) + # Note: We need to check for .cancelled first since a task is also marked + # .done when it is cancelled. + if coroutine_or_task.cancelled(): + self.log.debug('Already cancelled task {}', coroutine_or_task) + elif coroutine_or_task.done(): + exc = coroutine_or_task.exception() + if exc is not None: + message = 'Ignoring exception of queued task {}: {}' + self.log.debug(message, coroutine_or_task, repr(exc)) + else: + message = 'Ignoring completion of queued task {}' + self.log.debug(message, coroutine_or_task) + else: + self.log.debug('Cancelling queued task {}', coroutine_or_task) + coroutine_or_task.cancel() + + @asyncio.coroutine + def join_task_queue(self): + """ + Block until all tasks of the task queue have been processed. + """ + yield from self._task_queue.join() @asyncio.coroutine def send(self, message): @@ -548,6 +754,7 @@ def send(self, message): yield from self._connection.send(data) except websockets.ConnectionClosed as exc: self.log.debug('Connection closed while sending') + self.close_task_queue() raise Disconnected(exc.code) from exc @asyncio.coroutine @@ -555,11 +762,17 @@ def receive(self): """ Disconnected """ + # Safeguard + # Note: This should never happen since the receive queue will + # be stopped when a client is being dropped. + assert self.state < ClientState.dropped + # Receive data try: data = yield from self._connection.recv() except websockets.ConnectionClosed as exc: self.log.debug('Connection closed while receiving') + self.close_task_queue() raise Disconnected(exc.code) from exc self.log.debug('Received message') @@ -576,16 +789,82 @@ def ping(self): """ self.log.debug('Sending ping') try: - return (yield from self._connection.ping()) + pong_future = yield from self._connection.ping() except websockets.ConnectionClosed as exc: self.log.debug('Connection closed while pinging') + self.close_task_queue() + raise Disconnected(exc.code) from exc + return self._wait_pong(pong_future) + + @asyncio.coroutine + def _wait_pong(self, pong_future): + """ + Disconnected + """ + try: + yield from pong_future + except websockets.ConnectionClosed as exc: + self.log.debug('Connection closed while waiting for pong') + self.close_task_queue() raise Disconnected(exc.code) from exc @asyncio.coroutine def close(self, code=1000): + """ + Initiate the closing procedure and wait for the connection to + become closed. + + Arguments: + - `close`: The close code. + """ + # Close the task queue to ensure no further tasks can be + # enqueued while the client is in the closing process. + self.close_task_queue() + # Note: We are not sending a reason for security reasons. yield from self._connection.close(code=code) + def drop(self, code): + """ + Drop this client. Will enqueue the closing procedure and cancel + the receive loop as well as the keep alive loop of the client. + + Return the enqueue operation in form of a + :class:`asyncio.Task`. + + .. important:: This should only be called by clients dropping + another client or when the server is closing. + + Arguments: + - `close`: The close code. + """ + # Enqueue the close procedure on our own task queue. + # Note: The closing procedure would interrupt further send operations, thus we + # MUST enqueue it as a coroutine and NOT wrap in a Future. 
That way, it + # will not initiate the closing procedure before this client has executed + # all other pending tasks. + self.log.debug('Scheduling delayed closing procedure', code) + close_coroutine = self.close(code=code) + enqueue_task = asyncio.ensure_future( + self.enqueue_task(close_coroutine, ignore_closed=True), loop=self._loop) + + # Close the task queue to ensure no further tasks can be + # enqueued while the client is in the closing process. + self.close_task_queue() + + # Cancel all loops for the client but the task queue. + # Note: This will ensure that all messages forwarded towards the client to be + # dropped will still be forwarded. But the to be dropped client will not be + # able to send any more messages towards the server or relay messages + # towards other clients. + self.log.debug('Cancelling all running tasks but the task loop') + self.tasks.cancel_all_but_task_loop() + + # Mark as dropped + self.state = ClientState.dropped + self.log.debug('Client dropped, close code: {}', code) + return enqueue_task + class Protocol: PATH_LENGTH = KEY_LENGTH * 2 diff --git a/saltyrtc/server/server.py b/saltyrtc/server/server.py index 8e10a3a..d9bc824 100644 --- a/saltyrtc/server/server.py +++ b/saltyrtc/server/server.py @@ -1,6 +1,5 @@ import asyncio import binascii -import inspect from collections import OrderedDict from typing import ( Dict, @@ -15,6 +14,7 @@ NONCE_LENGTH, RELAY_TIMEOUT, AddressType, + ClientState, CloseCode, MessageType, SubProtocol, @@ -26,6 +26,7 @@ from .exception import ( Disconnected, DowngradeError, + InternalError, MessageError, MessageFlowError, PathError, @@ -46,6 +47,7 @@ from .protocol import ( Path, PathClient, + PathClientTasks, Protocol, ) @@ -60,16 +62,16 @@ 'ServerProtocol', 'Paths', 'Server', - 'TASK_LOOP_TIMEOUT', ) -TASK_LOOP_TIMEOUT = 600.0 +_TASK_QUEUE_JOIN_TIMEOUT = 10.0 @asyncio.coroutine def serve( ssl_context, keys, paths=None, host=None, port=8765, loop=None, - event_callbacks: Dict[Event, List[Coroutine]] = None, server_class=None + event_callbacks: Dict[Event, List[Coroutine]] = None, server_class=None, + ws_kwargs=None, ): """ Start serving SaltyRTC Signalling Clients. @@ -94,6 +96,15 @@ def serve( occurs. - `server_class`: An optional :class:`Server` class to create an instance from. + - `ws_kwargs`: Additional keyword arguments passed to + :func:`websockets.server.serve`. Note that the fields `ssl`, + `host`, `port`, `loop`, `subprotocols` and `ping_interval` + will be overridden. + + If the `compression` field is not explicitly set, + compression will be disabled (since the data to be compressed + is already encrypted, compression will have little to no + positive effect). Raises :exc:`ServerKeyError` in case one or more keys have been repeated. 
""" @@ -115,16 +126,20 @@ def serve( for callback in callbacks: server.register_event_callback(event, callback) - # Start server - ws_server = yield from websockets.serve( - server.handler, - ssl=ssl_context, - host=host, - port=port, - subprotocols=server.subprotocols - ) + # Prepare arguments for the WS server + if ws_kwargs is None: + ws_kwargs = {} + ws_kwargs['ssl'] = ssl_context + ws_kwargs['host'] = host + ws_kwargs['port'] = port + ws_kwargs.setdefault('compression', None) + ws_kwargs['ping_interval'] = None # Disable the keep-alive of the transport library + ws_kwargs['subprotocols'] = server.subprotocols + + # Start WS server + ws_server = yield from websockets.serve(server.handler, **ws_kwargs) - # Set server instance + # Set WS server instance server.server = ws_server # Return server @@ -142,7 +157,7 @@ class ServerProtocol(Protocol): 'handler_task' ) - def __init__(self, server, subprotocol, loop=None): + def __init__(self, server, subprotocol, connection, ws_path, loop=None): self._log = util.get_logger('server.protocol') self._loop = asyncio.get_event_loop() if loop is None else loop @@ -153,39 +168,6 @@ def __init__(self, server, subprotocol, loop=None): # Path and client instance self.path = None self.client = None - - # Handler task that is set after 'connection_made' has been called - self.handler_task = None - - # Determine subprotocol selection function - # Might be a static method, might be a normal method, see - # https://github.com/aaugustin/websockets/pull/132 - protocol = websockets.WebSocketServerProtocol - select_subprotocol = inspect.getattr_static(protocol, 'select_subprotocol') - if isinstance(select_subprotocol, staticmethod): - self._select_subprotocol = protocol.select_subprotocol - else: - def _select_subprotocol(client_subprotocols, server_subprotocols): - # noinspection PyTypeChecker - return protocol.select_subprotocol( - None, client_subprotocols, server_subprotocols) - self._select_subprotocol = _select_subprotocol - - def connection_made(self, connection, ws_path): - self.handler_task = asyncio.ensure_future( - self.handler(connection, ws_path), loop=self._loop) - - @asyncio.coroutine - def close(self, code=1000): - # Note: The client will be set as early as possible without any yielding. 
- # Thus, self.client is either set and can be closed or the connection - # is already closing (see the corresponding lines in 'handler' and - # 'get_path_client') - if self.client is not None: - yield from self.client.close(code=code) - - @asyncio.coroutine - def handler(self, connection, ws_path): self._log.debug('New connection on WS path {}', ws_path) # Get path and client instance as early as possible @@ -193,80 +175,96 @@ def handler(self, connection, ws_path): path, client = self.get_path_client(connection, ws_path) except PathError as exc: self._log.notice('Closing due to path error: {}', exc) - yield from connection.close(code=CloseCode.protocol_error.value) - self._server.raise_event( - Event.disconnected, None, CloseCode.protocol_error.value) - return - client.log.info('Connection established') - client.log.debug('Worker started') - # Store path and client - self.path = path - self.client = client - self._server.register(self) + @asyncio.coroutine + def close_with_protocol_error(): + yield from connection.close(code=CloseCode.protocol_error.value) + self._server.raise_event( + Event.disconnected, None, CloseCode.protocol_error.value) + handler_coroutine = close_with_protocol_error() + else: + handler_coroutine = self.handler() + client.log.info('Connection established') + client.log.debug('Worker started') - # Start task queue - client.log.debug('Starting to poll for enqueued tasks') - task_loop_task = asyncio.ensure_future(self.task_loop(), loop=self._loop) + # Store path and client + self.path = path + self.client = client + self._server.register(self) + + # Start handler task + self.handler_task = asyncio.ensure_future(handler_coroutine, loop=self._loop) + + @asyncio.coroutine + def handler(self): + client, path = self.client, self.path # Handle client until disconnected or an exception occurred hex_path = binascii.hexlify(self.path.initiator_key).decode('ascii') close_future = asyncio.Future(loop=self._loop) + try: - yield from self.handle_client(task_loop_task) + yield from self.handle_client() except Disconnected as exc: client.log.info('Connection closed (code: {})', exc.reason) close_future.set_result(None) self._server.raise_event(Event.disconnected, hex_path, exc.reason) + except PingTimeoutError: + client.log.info('Closing because of a ping timeout') + close_future = client.close(CloseCode.timeout) + self._server.raise_event( + Event.disconnected, hex_path, CloseCode.timeout) except SlotsFullError as exc: client.log.notice('Closing because all path slots are full: {}', exc) close_future = client.close(code=CloseCode.path_full_error.value) self._server.raise_event( - Event.disconnected, hex_path, CloseCode.path_full_error.value) + Event.disconnected, hex_path, CloseCode.path_full_error.value) except ServerKeyError as exc: client.log.notice('Closing due to server key error: {}', exc) close_future = client.close(code=CloseCode.invalid_key.value) self._server.raise_event( - Event.disconnected, hex_path, CloseCode.invalid_key.value) + Event.disconnected, hex_path, CloseCode.invalid_key.value) + except InternalError as exc: + client.log.exception('Closing due to an internal error:', exc) + close_future = client.close(code=CloseCode.internal_error.value) + self._server.raise_event( + Event.disconnected, hex_path, CloseCode.internal_error.value) except SignalingError as exc: client.log.notice('Closing due to protocol error: {}', exc) close_future = client.close(code=CloseCode.protocol_error.value) self._server.raise_event( - Event.disconnected, hex_path, 
CloseCode.protocol_error.value) + Event.disconnected, hex_path, CloseCode.protocol_error.value) except Exception as exc: client.log.exception('Closing due to exception:', exc) close_future = client.close(code=CloseCode.internal_error.value) self._server.raise_event( - Event.disconnected, hex_path, CloseCode.internal_error.value) + Event.disconnected, hex_path, CloseCode.internal_error.value) else: # Note: This should not ever happen since 'handle_client' - # contains an inifinite loop that only stops due to an exception. + # contains an infinite loop that only stops due to an exception. client.log.error('Client closed without exception') close_future.set_result(None) - # Remove client from path - path.remove_client(client) - self._server.paths.clean(path) - - # Wait for the task loop to complete - # Note: This will ensure all relay messages are cancelled and a 'send-error' is - # sent before the 'disconnected' message is being created. - if client.tasks_complete: - task_loop_task.cancel() - client.log.debug('Waiting for the task loop to finish') + # Schedule closing of the client + # Note: This ensures the client is closed soon even if the task queue is holding + # us up. + close_future = asyncio.ensure_future(close_future, loop=self._loop) + + # Wait until all queued tasks have been processed + # Note: This ensure that a send-error message (and potentially other messages) + # are enqueued towards other clients before the disconnect message. + client.log.debug('Joining task queue') try: - yield from asyncio.wait_for(task_loop_task, TASK_LOOP_TIMEOUT, - loop=self._loop) + yield from asyncio.wait_for( + client.join_task_queue(), _TASK_QUEUE_JOIN_TIMEOUT, loop=self._loop) except asyncio.TimeoutError: - client.log.error('Task loop has been killed after {} seconds!', - TASK_LOOP_TIMEOUT) - except Exception as exc: - client.log.exception('Task loop returned with an exception: {}', exc) - client.close_task_queue() - client.log.debug('Task queue finished and closed') + client.log.error( + 'Task queue did not close after {} seconds', _TASK_QUEUE_JOIN_TIMEOUT) + else: + client.log.debug('Task queue closed') # Send disconnected message if client was authenticated - if client.authenticated: + if client.state == ClientState.authenticated: # Initiator: Send to all responders if client.type == AddressType.initiator: responder_ids = path.get_responder_ids() @@ -279,19 +277,35 @@ def handler(self, connection, ws_path): AddressType.server, responder_id, client.id) responder.log.debug('Enqueueing disconnected message') coroutines.append(responder.enqueue_task(responder.send(message))) - yield from asyncio.gather(*coroutines, loop=self._loop) + try: + yield from asyncio.gather(*coroutines, loop=self._loop) + except Exception as exc: + description = 'Error while dispatching disconnected messages to ' \ + 'responders:' + client.log.exception(description, exc) # Responder: Send to initiator (if present) elif client.type == AddressType.responder: - initiator = path.get_initiator() - initiator_connected = initiator is not None - if initiator_connected: - # Create message and add send coroutine to task queue of the initiator + try: + initiator = path.get_initiator() + except KeyError: + pass # No initiator present + else: + # Create message and add send coroutine to task queue of the + # initiator message = DisconnectedMessage.create( AddressType.server, initiator.id, client.id) initiator.log.debug('Enqueueing disconnected message') - yield from initiator.enqueue_task(initiator.send(message)) + try: + yield from 
initiator.enqueue_task(initiator.send(message)) + except Exception as exc: + description = 'Error while dispatching disconnected message' \ + 'to initiator:' + client.log.exception(description, exc) else: - client.log.error('Invalid address type: {}'.format(client.type)) + client.log.error('Invalid address type: {}', client.type) + else: + client.log.debug( + 'Skipping disconnected message due to {} state', client.state.name) # Wait for the connection to be closed yield from close_future @@ -301,6 +315,22 @@ def handler(self, connection, ws_path): self._server.unregister(self) client.log.debug('Worker stopped') + @asyncio.coroutine + def close(self, code): + """ + Close the underlying connection and stop the protocol. + + Arguments: + - `code`: The close code. + """ + # Note: The client will be set as early as possible without any yielding. + # Thus, self.client is either set and can be closed or the connection + # is already closing (see the constructor and 'get_path_client') + if self.client is not None: + # We need to use 'drop' in order to prevent the server from sending a + # 'disconnect' message for each client. + yield from self._drop_client(self.client, code) + def get_path_client(self, connection, ws_path): # Extract public key from path initiator_key = ws_path[1:] @@ -324,7 +354,7 @@ def get_path_client(self, connection, ws_path): return path, client @asyncio.coroutine - def handle_client(self, task_loop_task): + def handle_client(self): """ SignalingError PathError @@ -334,65 +364,128 @@ def handle_client(self, task_loop_task): SlotsFullError DowngradeError ServerKeyError + InternalError """ - client = self.client + path, client = self.path, self.client # Do handshake client.log.debug('Starting handshake') yield from self.handshake() client.log.info('Handshake completed') - # Prepare tasks - coroutines = [] + # Task: Execute enqueued tasks + client.log.debug('Starting to poll for enqueued tasks') + task_loop = self.task_loop() + + # Check if the client is still connected to the path or has already been dropped. + # Note: This can happen when the client is being picked up and dropped by another + # client while running the handshake. To prevent other race conditions, we + # have to add the client instance to the path early during the handshake. 
+ is_connected = path.has_client(client) # Task: Poll for messages - hex_path = binascii.hexlify(self.path.initiator_key).decode('ascii') + hex_path = binascii.hexlify(path.initiator_key).decode('ascii') + receive_loop = None if client.type == AddressType.initiator: - client.log.debug('Starting runner for initiator') self._server.raise_event(Event.initiator_connected, hex_path) - coroutines.append(self.initiator_receive_loop()) + if is_connected: + client.log.debug('Starting runner for initiator') + receive_loop = self.initiator_receive_loop() elif client.type == AddressType.responder: - client.log.debug('Starting runner for responder') self._server.raise_event(Event.responder_connected, hex_path) - coroutines.append(self.responder_receive_loop()) + if is_connected: + client.log.debug('Starting runner for responder') + receive_loop = self.responder_receive_loop() else: raise ValueError('Invalid address type: {}'.format(client.type)) # Task: Keep alive - client.log.debug('Starting keep-alive task') - coroutines.append(self.keep_alive_loop()) + if is_connected: + client.log.debug('Starting keep-alive task') + keep_alive_loop = self.keep_alive_loop() + else: + keep_alive_loop = None + + # Move the tasks into a context and store it on the path + client.tasks = PathClientTasks( + task_loop=task_loop, + receive_loop=receive_loop, + keep_alive_loop=keep_alive_loop, + loop=self._loop + ) # Wait until complete - # Note: We also add the task loop into this list to catch any - # errors that bubble up in tasks of this client. - tasks = [task_loop_task] - tasks += [asyncio.ensure_future(coroutine, loop=self._loop) - for coroutine in coroutines] + # + # Note: We also add the task loop into this list to catch any errors that bubble + # up in tasks of this client. + # + # Warning: This is probably the most complicated piece of code in the server. + # Avoid touching this! + tasks = set(client.tasks.valid) while True: done, pending = yield from asyncio.wait( tasks, loop=self._loop, return_when=asyncio.FIRST_COMPLETED) + is_connected = path.has_client(client) + exc = None for task in done: - client.log.debug('Task done {}', done) - exc = task.exception() - - # Cancel pending tasks - # Note: Be careful not to cancel the task loop - for pending_task in pending: - if pending_task != task_loop_task: - client.log.debug('Cancelling task {}', pending_task) - pending_task.cancel() + client.log.debug('Task done {}, connected={}', task, is_connected) + + # Determine the exception to be raised + # Note: The first task will set the exception that will be raised. + if task.cancelled(): + if task != client.tasks.task_loop and not is_connected: + # If the client has been dropped, we need to wait for the task + # loop to return. So, remove the task from the list and continue. + tasks.remove(task) + break + if exc is None: + exc = InternalError('A vital task has been cancelled') + client.log.error('Task {} has been cancelled', task) + continue + + task_exc = task.exception() + if task_exc is None: + connection_closed_future = client.connection_closed_future + if not connection_closed_future.done(): + client.log.error('Task {} returned unexpectedly', task) + task_exc = InternalError('A task returned unexpectedly') + else: + # Note: This can happen in case a task returned due to the + # connection becoming closed. Since this doesn't raise an + # exception, we need to do it ourselves. 
+ task_exc = Disconnected(connection_closed_future.result()) - # Raise (or re-raise) if exc is None: - if task == task_loop_task: - # Task loop may return early and it's okay - client.log.debug('Task loop returned early') - tasks.remove(task_loop_task) - else: - client.log.error('Task {} returned unexpectedly', task) - raise SignalingError('A task returned unexpectedly') - else: - raise exc + exc = task_exc + + # Continue if we have no exception + # Note: This may only happen in case the client has been dropped and we need + # to wait for the task loop to return. + if exc is None: + continue + + # Cancel pending tasks + for pending_task in pending: + client.log.debug('Cancelling task {}', pending_task) + pending_task.cancel() + + # Cancel the task queue and remove client from path + # Note: Removing the client needs to be done here since the re-raise hands + # the task back into the event loop allowing other tasks to get the + # client's path instance from the path while it is already effectively + # disconnected. + client.cancel_task_queue() + path = self.path + try: + path.remove_client(client) + except KeyError: + # We can safely ignore this since clients will be removed immediately + # from the path in case they are being dropped by another client. + pass + self._server.paths.clean(path) + + # Finally, raise the exception + raise exc @asyncio.coroutine def handshake(self): @@ -446,11 +539,10 @@ def handshake_initiator(self, message): # Authenticated previous_initiator = path.set_initiator(initiator) if previous_initiator is not None: - # Drop previous initiator using the task queue of the previous initiator + # Drop previous initiator using its task queue path.log.debug('Dropping previous initiator {}', previous_initiator) previous_initiator.log.debug('Dropping (another initiator connected)') - coroutine = previous_initiator.close(code=CloseCode.drop_by_initiator.value) - yield from previous_initiator.enqueue_task(coroutine) + self._drop_client(previous_initiator, CloseCode.drop_by_initiator) # Send new-initiator message if any responder is present responder_ids = path.get_responder_ids() @@ -465,7 +557,7 @@ def handshake_initiator(self, message): yield from asyncio.gather(*coroutines, loop=self._loop) # Send server-auth - responder_ids = path.get_responder_ids() + responder_ids = list(path.get_responder_ids()) message = ServerAuthMessage.create( AddressType.server, initiator.id, initiator.cookie_in, sign_keys=len(self._server.keys) > 0, responder_ids=responder_ids) @@ -500,9 +592,11 @@ def handshake_responder(self, message): id_ = path.add_responder(responder) # Send new-responder message if initiator is present - initiator = path.get_initiator() - initiator_connected = initiator is not None - if initiator_connected: + try: + initiator = path.get_initiator() + except KeyError: + initiator = None + else: # Create message and add send coroutine to task queue of the initiator message = NewResponderMessage.create(AddressType.server, initiator.id, id_) initiator.log.debug('Enqueueing new-responder message') @@ -512,58 +606,65 @@ def handshake_responder(self, message): message = ServerAuthMessage.create( AddressType.server, responder.id, responder.cookie_in, sign_keys=len(self._server.keys) > 0, - initiator_connected=initiator_connected) + initiator_connected=initiator is not None) responder.log.debug('Sending server-auth without responder ids') yield from responder.send(message) @asyncio.coroutine def task_loop(self): client = self.client - while not (client.connection_closed.done() and 
client.tasks_complete): - try: - # Get a task from the queue - task = yield from client.dequeue_task() - except asyncio.CancelledError: - break + while not client.connection_closed_future.done(): + # Get a task from the queue + task = yield from client.dequeue_task() - # Wait and catch exceptions, ignore cancelled tasks + # Wait and handle exceptions client.log.debug('Waiting for task to complete {}', task) try: yield from task - except asyncio.CancelledError: - client.log.debug('Task cancelled {}', task) - - client.log.debug('Stopped polling for tasks') + except Exception as exc: + if isinstance(exc, asyncio.CancelledError): + client.log.debug('Cancelling active task {}', task) + else: + client.log.debug('Stopping active task {}, ', task) + if asyncio.iscoroutine(task): + task.close() + client.task_done(task) + else: + task.add_done_callback(client.task_done) + raise + client.task_done(task) @asyncio.coroutine def initiator_receive_loop(self): path, initiator = self.path, self.client - while not initiator.connection_closed.done(): + while not initiator.connection_closed_future.done(): # Receive relay message or drop-responder message = yield from initiator.receive() # Relay if isinstance(message, RawMessage): # Lookup responder - responder = path.get_responder(message.destination) + try: + responder = path.get_responder(message.destination) + except KeyError: + responder = None # Send to responder yield from self.relay_message(responder, message.destination, message) # Drop-responder elif message.type == MessageType.drop_responder: # Lookup responder - responder = path.get_responder(message.responder_id) - if responder is not None: + try: + responder = path.get_responder(message.responder_id) + except KeyError: + log_message = 'Responder {} already dropped, nothing to do' + path.log.debug(log_message, message.responder_id) + else: # Drop responder using its task queue path.log.debug( 'Dropping responder {}, reason: {}', responder, message.reason) responder.log.debug( 'Dropping (requested by initiator), reason: {}', message.reason) - # TODO: Mark responder as dropped and don't send relay messages - coroutine = responder.close(code=message.reason.value) - yield from responder.enqueue_task(coroutine) - else: - log_message = 'Responder {} already dropped, nothing to do' - path.log.debug(log_message, message.responder_id) + self._drop_client(responder, message.reason.value) else: error = "Expected relay message or 'drop-responder', got '{}'" raise MessageFlowError(error.format(message.type)) @@ -571,14 +672,17 @@ def initiator_receive_loop(self): @asyncio.coroutine def responder_receive_loop(self): path, responder = self.path, self.client - while not responder.connection_closed.done(): + while not responder.connection_closed_future.done(): # Receive relay message message = yield from responder.receive() # Relay if isinstance(message, RawMessage): # Lookup initiator - initiator = path.get_initiator() + try: + initiator = path.get_initiator() + except KeyError: + initiator = None # Send to initiator yield from self.relay_message(initiator, AddressType.initiator, message) else: @@ -642,14 +746,14 @@ def keep_alive_loop(self): PingTimeoutError """ client = self.client - while not client.connection_closed.done(): + while not client.connection_closed_future.done(): # Wait yield from asyncio.sleep(client.keep_alive_interval, loop=self._loop) # Send ping and wait for pong client.log.debug('Ping') + pong_future = yield from client.ping() try: - pong_future = yield from client.ping() yield from 
asyncio.wait_for( pong_future, client.keep_alive_timeout, loop=self._loop) except asyncio.TimeoutError: @@ -711,11 +815,33 @@ def _validate_subprotocol(self, client_subprotocols): self.client.log.debug( 'Checking for subprotocol downgrade, client: {}, server: {}', client_subprotocols, self._server.subprotocols) - chosen = self._select_subprotocol( + chosen = websockets.WebSocketServerProtocol.select_subprotocol( client_subprotocols, self._server.subprotocols) if chosen != self.subprotocol.value: raise DowngradeError('Subprotocol downgrade detected') + def _drop_client(self, client, code): + """ + Mark the client as closed, schedule the closing procedure on + the client's task queue, remove it from the path and return the + drop operation in form of a :class:`asyncio.Task`. + + .. important:: This should only be called by clients dropping + another client or when the server is closing. + + Arguments: + - `client`: The client to be dropped. + - `close`: The close code. + """ + # Drop the client + drop_task = client.drop(code) + + # Remove the client from the path + path = self.path + path.remove_client(client) + + return drop_task + class Paths: __slots__ = ('_log', 'number', 'paths') @@ -752,6 +878,9 @@ def __init__(self, keys, paths, loop=None): self._log = util.get_logger('server') self._loop = asyncio.get_event_loop() if loop is None else loop + # Protocol class + self._protocol_class = ServerProtocol + # WebSocket server instance self._server = None @@ -765,8 +894,9 @@ def __init__(self, keys, paths, loop=None): # Store paths self.paths = paths - # Store server protocols + # Store server protocols and closing task self.protocols = set() + self._close_task = None # Event Registry self._events = EventRegistry() @@ -782,6 +912,11 @@ def server(self, server): @asyncio.coroutine def handler(self, connection, ws_path): + # Closing? Drop immediately + if self._close_task is not None: + yield from connection.close(CloseCode.going_away.value) + return + # Convert sub-protocol try: subprotocol = SubProtocol(connection.subprotocol) @@ -796,8 +931,8 @@ def handler(self, connection, ws_path): yield from connection.close(code=CloseCode.subprotocol_error.value) self.raise_event(Event.disconnected, None, CloseCode.subprotocol_error.value) else: - protocol = ServerProtocol(self, subprotocol, loop=self._loop) - protocol.connection_made(connection, ws_path) + protocol = self._protocol_class( + self, subprotocol, connection, ws_path, loop=self._loop) yield from protocol.handler_task def register(self, protocol): @@ -825,37 +960,36 @@ def close(self): """ Close open connections and the server. """ - asyncio.ensure_future(self._close_after_all_protocols_closed(), loop=self._loop) + if self._close_task is None: + self._close_task = asyncio.ensure_future( + self._close_after_all_protocols_closed(), loop=self._loop) @asyncio.coroutine def wait_closed(self): """ - Wait until all connections and the server itself is closed. + Wait until all connections and the server itself has been + closed. """ - yield from self._wait_connections_closed() yield from self.server.wait_closed() - @asyncio.coroutine - def _wait_connections_closed(self): - """ - Wait until all connections to the server have been closed. 
- """ - if len(self.protocols) > 0: - tasks = [protocol.handler_task for protocol in self.protocols] - yield from asyncio.gather(*tasks, loop=self._loop) - @asyncio.coroutine def _close_after_all_protocols_closed(self, timeout=None): # Schedule closing all protocols - self._log.debug('Closing protocols') + self._log.info('Closing protocols') if len(self.protocols) > 0: - tasks = [protocol.close(code=CloseCode.going_away.value) - for protocol in self.protocols] + @asyncio.coroutine + def _close_and_wait(): + # Wait until all connections have been scheduled to be closed + tasks = [protocol.close(CloseCode.going_away.value) + for protocol in self.protocols] + yield from asyncio.gather(*tasks, loop=self._loop) + + # Wait until all protocols have returned + tasks = [protocol.handler_task for protocol in self.protocols] + yield from asyncio.gather(*tasks, loop=self._loop) - # Wait until all protocols are closed (we need the server to be active for the - # WebSocket close protocol) - yield from asyncio.wait(tasks, loop=self._loop, timeout=timeout) + yield from asyncio.wait_for(_close_and_wait(), timeout, loop=self._loop) # Now we can close the server - self._log.debug('Closing server') + self._log.info('Closing server') self.server.close() diff --git a/setup.cfg b/setup.cfg index 61345e2..f537393 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bdist_wheel] -python-tag = py34.py35.py36 +python-tag = py34.py35.py36.py37 [flake8] max-line-length = 90 diff --git a/setup.py b/setup.py index 92fbfcf..999541a 100644 --- a/setup.py +++ b/setup.py @@ -40,14 +40,15 @@ def read(file): # Note: These are just tools that aren't required, so a version range # is not necessary here. tests_require = [ - 'pytest>=3.0.7', - 'pytest-asyncio>=0.5.0', - 'pytest-cov>=2.4.0', - 'flake8>=3.3.0', - 'isort>=4.2.5', + 'pytest>=3.7.3', + 'pytest-asyncio>=0.9.0', + 'pytest-cov>=2.5.1', + 'pytest-mock>=1.10.0', + 'flake8>=3.5.0', + 'isort>=4.3.4', 'collective.checkdocs>=0.2', 'Pygments>=2.2.0', # required by checkdocs - 'ordered-set>=3.0.0', # required by TestServer class + 'ordered-set>=3.0.1', # required by TestServer class ] + logging_require setup( @@ -57,7 +58,7 @@ def read(file): install_requires=[ 'libnacl>=1.5.0,<2', 'click>=6.7', # doesn't seem to follow semantic versioning (see #57) - 'websockets>=3.2,<4', + 'websockets>=7.0,<8', 'u-msgpack-python>=2.3,<3', ], tests_require=tests_require, @@ -99,6 +100,7 @@ def read(file): 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Communications', 'Topic :: Internet', 'Topic :: Security', diff --git a/tests/cert.pem b/tests/cert.pem index 71a9833..d21d1a5 100644 --- a/tests/cert.pem +++ b/tests/cert.pem @@ -1,18 +1,12 @@ ------BEGIN EC PARAMETERS----- -BggqhkjOPQMBBw== ------END EC PARAMETERS----- ------BEGIN EC PRIVATE KEY----- -MHcCAQEEIIzGxguRVwC6nIW2L0aSF1WOb7f1Gg1YrkJvYYnGyeCwoAoGCCqGSM49 -AwEHoUQDQgAEnePy9uKM8yWCPp68xWknNL5FiUpRzmdOuKCxor36Ofk4haOcpsTJ -eCTHF1wrAK9/c+JJaJw3Ka5pv7VJBqi4HQ== ------END EC PRIVATE KEY----- -----BEGIN CERTIFICATE----- -MIIBbjCCARWgAwIBAgIJALwoZ5qw/ta4MAoGCCqGSM49BAMCMBQxEjAQBgNVBAMM -CTEyNy4wLjAuMTAeFw0xNzAzMjMxNjUyMDFaFw0yNzAzMjExNjUyMDFaMBQxEjAQ -BgNVBAMMCTEyNy4wLjAuMTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABJ3j8vbi +MIIBsDCCAVWgAwIBAgIJAJxiRGRw5YGAMAoGCCqGSM49BAMCMBQxEjAQBgNVBAMM +CWxvY2FsaG9zdDAeFw0xODA4MjcxNTM0NTlaFw0yODA4MjQxNTM0NTlaMBQxEjAQ +BgNVBAMMCWxvY2FsaG9zdDBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABJ3j8vbi 
jPMlgj6evMVpJzS+RYlKUc5nTrigsaK9+jn5OIWjnKbEyXgkxxdcKwCvf3PiSWic -Nymuab+1SQaouB2jUDBOMB0GA1UdDgQWBBQXZoZAscS/dwnH3J9NeEq/BxmtjDAf -BgNVHSMEGDAWgBQXZoZAscS/dwnH3J9NeEq/BxmtjDAMBgNVHRMEBTADAQH/MAoG -CCqGSM49BAMCA0cAMEQCIAnSwYX5xX94PsyQeCXcnnbvYnrlOAqhbMYnr+3sHRIG -AiBd0N6ORu7jZzVm/hFuVrVwchYzkmYPbNpRDHTYwbiPtQ== +Nymuab+1SQaouB2jgY8wgYwwHQYDVR0OBBYEFBdmhkCxxL93Ccfcn014Sr8HGa2M +MB8GA1UdIwQYMBaAFBdmhkCxxL93Ccfcn014Sr8HGa2MMA8GA1UdEwEB/wQFMAMB +Af8wCwYDVR0PBAQDAgGmMCwGA1UdEQQlMCOCCWxvY2FsaG9zdIcEfwAAAYcQAAAA +AAAAAAAAAAAAAAAAATAKBggqhkjOPQQDAgNJADBGAiEA0/23gmC4lEfzHl46ArMl +ss694KFRZ6128ngxawPdECgCIQDyPFh3PdWxHKMgAQHazxczthAswWnrGAkt7C+w +oT5WpQ== -----END CERTIFICATE----- diff --git a/tests/conftest.py b/tests/conftest.py index fead690..3ebc66f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -61,9 +61,11 @@ def pytest_namespace(): 'have_uvloop': pytest.mark.skipif(not have_uvloop, reason='requires uvloop'), 'no_uvloop': pytest.mark.skipif( have_uvloop, reason='requires uvloop to be not installed'), - 'ip': '127.0.0.1', + 'host': 'localhost', 'port': 8766, 'cli_path': os.path.join(sys.exec_prefix, 'bin', 'saltyrtc-server'), + 'key': os.path.normpath( + os.path.join(os.path.abspath(__file__), os.pardir, 'key.pem')), 'cert': os.path.normpath( os.path.join(os.path.abspath(__file__), os.pardir, 'cert.pem')), 'dh_params': os.path.normpath( @@ -101,7 +103,7 @@ def unused_tcp_port(): Find an unused localhost TCP port from 1024-65535 and return it. """ with closing(socket.socket()) as sock: - sock.bind((pytest.saltyrtc.ip, 0)) + sock.bind((pytest.saltyrtc.host, 0)) return sock.getsockname()[1] @@ -152,17 +154,6 @@ def _get_timeout(timeout=None, request=None, config=None): return timeout -def _sleep(**kwargs): - """ - Sleep *timeout* seconds. - """ - @asyncio.coroutine - def __sleep(timeout=None): - yield from asyncio.sleep(_get_timeout(timeout=timeout, **kwargs)) - - return __sleep - - @pytest.fixture(scope='module') def event_loop(request): """ @@ -222,11 +213,16 @@ def responder_key(): @pytest.fixture(scope='module') -def sleep(request): +def sleep(event_loop): """ Sleep *timeout* seconds. 
""" - return _sleep(request=request) + @asyncio.coroutine + def _sleep(delay, **kwargs): + kwargs.setdefault('loop', event_loop) + yield from asyncio.sleep(delay, **kwargs) + + return _sleep @pytest.fixture(scope='module') @@ -267,26 +263,33 @@ def __init__(self, *args, timeout=None, **kwargs): def raise_event(self, event: Event, *data): super().raise_event(event, *data) if event == Event.disconnected: - self._most_recent_connection_closed_future.set_result(None) + self._most_recent_connection_closed_future.set_result(data) self._most_recent_connection_closed_future = asyncio.Future(loop=self._loop) @asyncio.coroutine def wait_connections_closed(self): self._log.debug('#protocols remaining: {}', len(self.protocols)) + + @asyncio.coroutine + def _wait_connections_closed(): + if len(self.protocols) > 0: + tasks = [protocol.handler_task for protocol in self.protocols] + yield from asyncio.gather(*tasks, loop=self._loop) + yield from asyncio.wait_for( - self._wait_connections_closed(), timeout=self.timeout, loop=self._loop) + _wait_connections_closed(), timeout=self.timeout, loop=self._loop) @asyncio.coroutine def wait_most_recent_connection_closed(self, connection_closed_future=None): # If there is no future, we simply wait for the 'disconnected' event if connection_closed_future is None: connection_closed_future = self._most_recent_connection_closed_future - yield from asyncio.wait_for( - connection_closed_future, timeout=self.timeout, loop=self._loop) + return (yield from asyncio.wait_for( + connection_closed_future, timeout=self.timeout, loop=self._loop)) def wait_connection_closed_marker(self): protocol = self.protocols[-1] - connection_closed_future = protocol.client.connection_closed + connection_closed_future = protocol.client.connection_closed_future return functools.partial( self.wait_most_recent_connection_closed, connection_closed_future=connection_closed_future) @@ -301,13 +304,13 @@ def server_factory(request, event_loop, server_permanent_keys): os.environ['PYTHONASYNCIODEBUG'] = '1' # Enable logging - util.enable_logging(level=logbook.TRACE, redirect_loggers={ + util.enable_logging(level=logbook.DEBUG, redirect_loggers={ 'asyncio': logbook.WARNING, 'websockets': logbook.WARNING, }) - # Push handler - logging_handler = logbook.StderrHandler() + # Push handlers + logging_handler = logbook.StderrHandler(bubble=True) logging_handler.push_application() _server_instances = [] @@ -320,9 +323,10 @@ def _server_factory(permanent_keys=None): port = unused_tcp_port() coroutine = serve( util.create_ssl_context( - pytest.saltyrtc.cert, dh_params_file=pytest.saltyrtc.dh_params), + pytest.saltyrtc.cert, keyfile=pytest.saltyrtc.key, + dh_params_file=pytest.saltyrtc.dh_params), permanent_keys, - host=pytest.saltyrtc.ip, + host=pytest.saltyrtc.host, port=port, loop=event_loop, server_class=TestServer, @@ -330,7 +334,7 @@ def _server_factory(permanent_keys=None): server_ = event_loop.run_until_complete(coroutine) # Inject timeout and address (little bit of a hack but meh...) server_.timeout = _get_timeout(request=request) - server_.address = (pytest.saltyrtc.ip, port) + server_.address = (pytest.saltyrtc.host, port) _server_instances.append(server_) @@ -363,6 +367,46 @@ def server_no_key(server_factory): return server_factory(permanent_keys=[]) +@pytest.fixture +def log_handler(request): + """ + Return a :class:`logbook.TestHandler` instance where log records + can be accessed. 
+ """ + log_handler = logbook.TestHandler(level=logbook.DEBUG, bubble=True) + log_handler._ignore_filter = lambda _: False + log_handler._error_level = logbook.ERROR + log_handler.push_application() + + def fin(): + log_handler.pop_application() + request.addfinalizer(fin) + + return log_handler + + +@pytest.fixture +def evaluate_log(log_handler): + """ + Ensure that no test is logging (handled) errors. + """ + yield + errors = [record for record in log_handler.records + if (record.level >= log_handler._error_level + and not log_handler._ignore_filter(record))] + assert(len(errors) == 0) + + +@pytest.fixture +def log_ignore_filter(log_handler): + """ + Ignore specific log entries with a filter callback. + """ + def _set_filter(callback): + log_handler._ignore_filter = callback + return _set_filter + + class _DefaultBox: pass @@ -376,6 +420,7 @@ def __init__(self, ws_client, pack_message, unpack_message, request, timeout=Non self.session_key = None self.box = None + @asyncio.coroutine def send(self, nonce, message, box=_DefaultBox, timeout=None, pack=True): if timeout is None: timeout = self.timeout @@ -384,6 +429,7 @@ def send(self, nonce, message, box=_DefaultBox, timeout=None, pack=True): box=self.box if box == _DefaultBox else box, timeout=timeout, pack=pack )) + @asyncio.coroutine def recv(self, box=_DefaultBox, timeout=None): if timeout is None: timeout = self.timeout @@ -397,7 +443,17 @@ def close(self): @pytest.fixture(scope='module') -def ws_client_factory(initiator_key, event_loop, server): +def client_kwargs(event_loop): + return { + 'compression': None, + 'subprotocols': pytest.saltyrtc.subprotocols, + 'ping_interval': None, + 'loop': event_loop, + } + + +@pytest.fixture(scope='module') +def ws_client_factory(initiator_key, event_loop, client_kwargs, server): """ Return a simplified :class:`websockets.client.connect` wrapper where no parameters are required. 
@@ -415,20 +471,16 @@ def _ws_client_factory(server=None, path=None, **kwargs): server = server_ if path is None: path = '{}/{}'.format(url(*server.address), key_path(initiator_key)) - _kwargs = { - 'subprotocols': pytest.saltyrtc.subprotocols, - 'ssl': ssl_context, - 'loop': event_loop, - } + _kwargs = client_kwargs.copy() _kwargs.update(kwargs) - return websockets.connect(path, **_kwargs) + return websockets.connect(path, ssl=ssl_context, **_kwargs) return _ws_client_factory @pytest.fixture(scope='module') def client_factory( - request, initiator_key, event_loop, server, server_permanent_keys, responder_key, - pack_nonce, pack_message, unpack_message + request, initiator_key, event_loop, client_kwargs, server, server_permanent_keys, + responder_key, pack_nonce, pack_message, unpack_message ): """ Return a simplified :class:`websockets.client.connect` wrapper @@ -456,16 +508,12 @@ def _client_factory( cookie = random_cookie() if permanent_key is None: permanent_key = server_permanent_keys[0].pk - _kwargs = { - 'subprotocols': pytest.saltyrtc.subprotocols, - 'ssl': ssl_context, - 'loop': event_loop, - } + _kwargs = client_kwargs.copy() _kwargs.update(kwargs) if ws_client is None: ws_client = yield from websockets.connect( '{}/{}'.format(url(*server.address), key_path(path)), - **_kwargs + ssl=ssl_context, **_kwargs ) client = Client( ws_client, pack_message, unpack_message, diff --git a/tests/key.pem b/tests/key.pem new file mode 100644 index 0000000..c6f9a8e --- /dev/null +++ b/tests/key.pem @@ -0,0 +1,8 @@ +-----BEGIN EC PARAMETERS----- +BggqhkjOPQMBBw== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIIzGxguRVwC6nIW2L0aSF1WOb7f1Gg1YrkJvYYnGyeCwoAoGCCqGSM49 +AwEHoUQDQgAEnePy9uKM8yWCPp68xWknNL5FiUpRzmdOuKCxor36Ofk4haOcpsTJ +eCTHF1wrAK9/c+JJaJw3Ka5pv7VJBqi4HQ== +-----END EC PRIVATE KEY----- diff --git a/tests/openssl.cnf b/tests/openssl.cnf new file mode 100644 index 0000000..f812338 --- /dev/null +++ b/tests/openssl.cnf @@ -0,0 +1,36 @@ +[req] +distinguished_name = req_distinguished_name +x509_extensions = v3_ca +string_mask = utf8only + +[req_distinguished_name] +countryName = Country Name (2 letter code) +countryName_default = +countryName_min = 2 +countryName_max = 2 +stateOrProvinceName = State or Province Name (full name) +stateOrProvinceName_default = +localityName = Locality Name (eg, city) +localityName_default = +0.organizationName = Organization Name (eg, company) +0.organizationName_default = +organizationalUnitName = Organizational Unit Name (eg, section) +organizationalUnitName_default = +commonName = Common Name (localhost) +commonName_default = localhost +commonName_max = 64 +emailAddress = Email Address +emailAddress_default = +emailAddress_max = 40 + +[v3_ca] +subjectKeyIdentifier = hash +authorityKeyIdentifier = keyid:always,issuer +basicConstraints = critical, CA:true +keyUsage = cRLSign, keyCertSign, digitalSignature, keyEncipherment +subjectAltName = @alt_names + +[alt_names] +DNS.1 = localhost +IP.1 = 127.0.0.1 +IP.2 = ::1 diff --git a/tests/test_base.py b/tests/test_base.py index 6c5d025..43ad291 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -6,6 +6,7 @@ import websockets +@pytest.mark.usefixtures('evaluate_log') class TestPrerequisities: @pytest.mark.asyncio def test_server_handshake(self, ws_client_factory): diff --git a/tests/test_cli.py b/tests/test_cli.py index 588e19e..36a9439 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,6 +13,7 @@ ) +@pytest.mark.usefixtures('evaluate_log') class TestCLI: 
@pytest.mark.asyncio def test_invalid_command(self, cli): @@ -84,6 +85,7 @@ def test_serve_cert_missing(self, cli): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-p', '8443', ) assert 'It is REQUIRED' in exc_info.value.output @@ -109,6 +111,7 @@ def test_serve_invalid_key_file(self, cli, tmpdir): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', str(keyfile), '-p', '8443', ) @@ -122,6 +125,7 @@ def test_serve_invalid_dh_params_file(self, cli, tmpdir): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-dhp', str(dh_params_file), '-p', '8443', @@ -135,6 +139,7 @@ def test_serve_invalid_hex_encoded_key(self, cli): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', key, '-p', '8443', ) @@ -146,6 +151,7 @@ def test_serve_invalid_host(self, cli): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-h', 'meow', '-p', '8443', @@ -158,6 +164,7 @@ def test_serve_invalid_port(self, cli): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-p', 'meow', ) @@ -169,6 +176,7 @@ def test_serve_invalid_loop(self, cli): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-p', '8443', '-l', 'meow', @@ -182,6 +190,7 @@ def test_serve_uvloop_unavailable(self, cli): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-p', '8443', '-l', 'uvloop', @@ -193,6 +202,7 @@ def test_serve_asyncio(self, cli): output = yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-p', '8443', signal=signal.SIGINT, @@ -204,6 +214,7 @@ def test_serve_asyncio_dh_params(self, cli): output = yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-dhp', pytest.saltyrtc.dh_params, '-p', '8443', @@ -216,6 +227,7 @@ def test_serve_asyncio_hex_encoded_key(self, cli): output = yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', open(pytest.saltyrtc.permanent_key_primary, 'r').read(), '-p', '8443', signal=signal.SIGINT, @@ -228,6 +240,7 @@ def test_serve_asyncio_plus_logging(self, cli): '-v', '7', 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-p', '8443', signal=signal.SIGINT, @@ -241,6 +254,7 @@ def test_serve_uvloop(self, cli): output = yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-p', '8443', '-l', 'uvloop', @@ -254,6 +268,7 @@ def test_serve_uvloop_dh_params(self, cli): output = yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-dhp', pytest.saltyrtc.dh_params, '-p', '8443', @@ -269,6 +284,7 @@ def test_serve_uvloop_plus_logging(self, cli): '-v', '7', 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-p', '8443', '-l', 'uvloop', @@ -282,6 +298,7 @@ def test_serve_asyncio_restart(self, cli): output = yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', 
pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-p', '8443', signal=[signal.SIGHUP, signal.SIGINT], @@ -331,6 +348,7 @@ def test_serve_repeated_key(self, cli): yield from cli(*[ 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-p', '8443', ] + key_arguments) assert 'key has been supplied more than once' in exc_info.value.output @@ -343,6 +361,7 @@ def test_serve_invalid_2nd_key_file(self, cli, tmpdir): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-k', str(keyfile), '-p', '8443', @@ -356,6 +375,7 @@ def test_serve_invalid_2nd_hex_encoded_key(self, cli): yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-k', key, '-p', '8443', @@ -374,6 +394,7 @@ def test_serve_asyncio_2nd_key(self, cli): output = yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_primary, '-k', pytest.saltyrtc.permanent_key_secondary, '-p', '8443', @@ -395,6 +416,7 @@ def test_serve_asyncio_2nd_key_reversed(self, cli): output = yield from cli( 'serve', '-sc', pytest.saltyrtc.cert, + '-sk', pytest.saltyrtc.key, '-k', pytest.saltyrtc.permanent_key_secondary, '-k', pytest.saltyrtc.permanent_key_primary, '-p', '8443', diff --git a/tests/test_protocol.py b/tests/test_protocol.py index bd3ee47..ffb648a 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -3,28 +3,35 @@ compliant to the SaltyRTC protocol. """ import asyncio -import collections import libnacl.public import pytest import websockets +from saltyrtc.server import ServerProtocol from saltyrtc.server.common import ( SIGNED_KEYS_CIPHERTEXT_LENGTH, + ClientState, CloseCode, ) -from saltyrtc.server.events import Event class _FakePathClient: def __init__(self): - self.connection_closed = asyncio.Future() - self.connection_closed.set_result(None) + self.connection_closed_future = asyncio.Future() + self.connection_closed_future.set_result(None) + self.state = ClientState.restricted + self.id = None def update_log_name(self, id_): pass + def authenticate(self, id_): + self.id = id_ + self.state = ClientState.authenticated + +@pytest.mark.usefixtures('evaluate_log') class TestProtocol: @pytest.mark.asyncio def test_no_subprotocols(self, server, ws_client_factory): @@ -532,7 +539,7 @@ def test_client_factory_handshake( yield from server.wait_connections_closed() @pytest.mark.asyncio - def test_keep_alive_pings_initiator(self, server, client_factory): + def test_keep_alive_pings_initiator(self, sleep, server, client_factory): """ Check that the server sends ping messages in the requested interval. @@ -544,7 +551,7 @@ def test_keep_alive_pings_initiator(self, server, client_factory): ) # Wait for two pings (including pongs) - yield from asyncio.sleep(2.1) + yield from sleep(2.1) # Check ping counter assert len(server.protocols) == 1 @@ -556,7 +563,7 @@ def test_keep_alive_pings_initiator(self, server, client_factory): yield from server.wait_connections_closed() @pytest.mark.asyncio - def test_keep_alive_pings_responder(self, server, client_factory): + def test_keep_alive_pings_responder(self, sleep, server, client_factory): """ Check that the server sends ping messages in the requested interval. 
@@ -568,7 +575,7 @@ def test_keep_alive_pings_responder(self, server, client_factory): ) # Wait for two pings (including pongs) - yield from asyncio.sleep(1.1) + yield from sleep(1.1) # Check ping counter assert len(server.protocols) == 1 @@ -580,7 +587,7 @@ def test_keep_alive_pings_responder(self, server, client_factory): yield from server.wait_connections_closed() @pytest.mark.asyncio - def test_keep_alive_ignore_invalid(self, server, client_factory): + def test_keep_alive_ignore_invalid(self, sleep, server, client_factory): """ Check that the server ignores invalid keep alive intervals. """ @@ -591,7 +598,7 @@ def test_keep_alive_ignore_invalid(self, server, client_factory): ) # Wait for a second - yield from asyncio.sleep(1.1) + yield from sleep(1.1) # Check ping counter assert len(server.protocols) == 1 @@ -607,9 +614,8 @@ def test_keep_alive_timeout( self, ws_client_factory, server, client_factory ): """ - Monkey-patch the the server's keep alive interval and timeout - and check that the server sends us a ping and waits for a - pong. + Monkey-patch the server's keep alive interval and timeout and + check that the server sends a ping and waits for a pong. """ # Create client and patch it to not answer pings ws_client = yield from ws_client_factory() @@ -622,13 +628,12 @@ def test_keep_alive_timeout( protocol.client.keep_alive_timeout = 0.001 # Initiator handshake - client, i = yield from client_factory( - ws_client=ws_client, initiator_handshake=True) + yield from client_factory(ws_client=ws_client, initiator_handshake=True) # Expect protocol error yield from server.wait_connections_closed() - assert not client.ws_client.open - assert client.ws_client.close_code == CloseCode.protocol_error + assert not ws_client.open + assert ws_client.close_code == CloseCode.timeout @pytest.mark.asyncio def test_initiator_invalid_source_after_handshake( @@ -933,7 +938,7 @@ def test_drop_responder_with_reason( 'reason': CloseCode.internal_error.value, }) - # Responder: Expect reason 'handover' + # Responder: Expect reason 'internal error' yield from connection_closed_future() assert not responder.ws_client.open assert responder.ws_client.close_code == CloseCode.internal_error @@ -1203,9 +1208,221 @@ def test_relay_receiver_offline( yield from initiator.close() yield from server.wait_connections_closed() + @pytest.mark.asyncio + def test_relay_send_and_close( + self, pack_nonce, cookie_factory, server, client_factory + ): + """ + Ensure relay messages are being dispatched in case the client + closes after having sent a couple of relay messages.
+ """ + # Initiator handshake + initiator, i = yield from client_factory(initiator_handshake=True) + i['rccsn'] = 98798981 + i['rcck'] = cookie_factory() + + # Responder handshake + responder, r = yield from client_factory(responder_handshake=True) + r['iccsn'] = 2 ** 23 + r['icck'] = cookie_factory() + + # new-responder + yield from initiator.recv() + + # Send 3 relay messages: initiator --> responder + expected_data = b'\xfe' * 2**16 # 64 KiB + for _ in range(3): + nonce = pack_nonce(i['rcck'], i['id'], r['id'], i['rccsn']) + yield from initiator.send(nonce, expected_data, box=None) + i['rccsn'] += 1 + + # Close initiator + yield from initiator.close() + + # Receive 3 relay messages: initiator --> responder + for _ in range(3): + actual_data, *_ = yield from responder.recv(box=None) + assert actual_data == expected_data + + # Bye + yield from responder.close() + yield from server.wait_connections_closed() + + @pytest.mark.asyncio + def test_relay_send_before_close_responder( + self, pack_nonce, cookie_factory, server, client_factory + ): + """ + Ensure relay messages are being dispatched in case the receiver + is being closed (drop responder) after the sender has sent the + relay messages. + """ + # Initiator handshake + initiator, i = yield from client_factory(initiator_handshake=True) + i['rccsn'] = 98798981 + i['rcck'] = cookie_factory() + + # Responder handshake + responder, r = yield from client_factory(responder_handshake=True) + responder_closed_future = server.wait_connection_closed_marker() + r['iccsn'] = 2 ** 23 + r['icck'] = cookie_factory() + + # new-responder + yield from initiator.recv() + + # Send 6 relay messages: initiator --> responder + expected_data = b'\xfe' * 2**15 # 32 KiB + for _ in range(6): + nonce = pack_nonce(i['rcck'], i['id'], r['id'], i['rccsn']) + yield from initiator.send(nonce, expected_data, box=None) + i['rccsn'] += 1 + + # Drop responder + yield from initiator.send(pack_nonce(i['cck'], 0x01, 0x00, i['ccsn']), { + 'type': 'drop-responder', + 'id': r['id'], + }) + i['ccsn'] += 1 + + # Receive 6 relay messages: initiator --> responder + for _ in range(6): + actual_data, *_ = yield from responder.recv(box=None) + assert actual_data == expected_data + + # Responder: Expect drop by initiator + yield from responder_closed_future() + assert not responder.ws_client.open + assert responder.ws_client.close_code == CloseCode.drop_by_initiator + + # Bye + yield from initiator.close() + yield from server.wait_connections_closed() + + @pytest.mark.asyncio + def test_relay_send_before_close_initiator( + self, pack_nonce, cookie_factory, server, client_factory + ): + """ + Ensure relay messages are being dispatched in case the receiver + is being closed (drop initiator) after the sender has sent the + relay messages. 
+ """ + # Initiator handshake + first_initiator, i = yield from client_factory(initiator_handshake=True) + connection_closed_future = server.wait_connection_closed_marker() + i['rccsn'] = 98798981 + i['rcck'] = cookie_factory() + + # Responder handshake + responder, r = yield from client_factory(responder_handshake=True) + r['iccsn'] = 2 ** 23 + r['icck'] = cookie_factory() + + # new-responder + yield from first_initiator.recv() + + # Send 6 relay messages: initiator <-- responder + expected_data = b'\xfe' * 2**15 # 32 KiB + for _ in range(6): + nonce = pack_nonce(r['icck'], r['id'], i['id'], r['iccsn']) + yield from responder.send(nonce, expected_data, box=None) + r['iccsn'] += 1 + + # Second initiator handshake + second_initiator, i = yield from client_factory(initiator_handshake=True) + # Responder is connected + assert i['responders'] == [r['id']] + + # new-initiator + yield from responder.recv() + + # Receive 6 relay messages: initiator <-- responder + for _ in range(6): + actual_data, *_ = yield from first_initiator.recv(box=None) + assert actual_data == expected_data + + # First initiator: Expect drop by initiator + yield from connection_closed_future() + assert not first_initiator.ws_client.open + assert first_initiator.ws_client.close_code == CloseCode.drop_by_initiator + + # Bye + yield from responder.close() + yield from second_initiator.close() + yield from server.wait_connections_closed() + + @pytest.mark.asyncio + def test_relay_send_after_close( + self, mocker, event_loop, pack_nonce, cookie_factory, server, client_factory, + initiator_key + ): + """ + When the responder is being dropped by the initiator, the + responder's task loop may await a long-blocking task before it + is being closed. Ensure that the initiator is not able to + enqueue further messages to the responder at that point. 
+ """ + # Mock the protocol to release the 'done_future' once the closing procedure has + # been initiated + class _MockProtocol(ServerProtocol): + def _drop_client(self, *args, **kwargs): + result = super()._drop_client(*args, **kwargs) + done_future.set_result(None) + return result + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Initiator handshake + initiator, i = yield from client_factory(initiator_handshake=True) + i['rccsn'] = 98798981 + i['rcck'] = cookie_factory() + + # Responder handshake + responder, r = yield from client_factory(responder_handshake=True) + r['iccsn'] = 2 ** 23 + r['icck'] = cookie_factory() + + # new-responder + yield from initiator.recv() + + # Get responder's PathClient instance + path = server.paths.get(initiator_key.pk) + path_client = path.get_responder(r['id']) + done_future = asyncio.Future(loop=event_loop) + + # Create long-blocking task + @asyncio.coroutine + def blocking_task(): + yield from done_future + + # Enqueue long-blocking task + yield from path_client.enqueue_task(blocking_task()) + + # Drop responder + yield from initiator.send(pack_nonce(i['cck'], 0x01, 0x00, i['ccsn']), { + 'type': 'drop-responder', + 'id': r['id'], + }) + i['ccsn'] += 1 + + # Send relay message: initiator --> responder + nonce = pack_nonce(i['rcck'], i['id'], r['id'], i['rccsn']) + yield from initiator.send(nonce, b'\xfe' * 2**15, box=None) + i['rccsn'] += 1 + + # Responder: Expect drop by initiator + with pytest.raises(websockets.ConnectionClosed): + yield from responder.recv(box=None) + assert responder.ws_client.close_code == CloseCode.drop_by_initiator + + # Bye + yield from initiator.close() + yield from server.wait_connections_closed() + @pytest.mark.asyncio def test_relay_receiver_connection_lost( - self, event_loop, ws_client_factory, initiator_key, pack_nonce, + self, mocker, event_loop, ws_client_factory, initiator_key, pack_nonce, cookie_factory, server, client_factory ): """ @@ -1242,12 +1459,10 @@ def test_relay_receiver_connection_lost( path_client = path.get_responder(0x02) # Mock responder instance: Block sending and let the next ping time out - forever_blocking_future = asyncio.Future(loop=event_loop) - @asyncio.coroutine def _mock_send(*_): path_client.log.notice('... NOT') - yield from forever_blocking_future + yield from asyncio.Future(loop=event_loop) @asyncio.coroutine def _mock_ping(*_): @@ -1255,11 +1470,11 @@ def _mock_ping(*_): # Dunno why asyncio treats this as a non-coroutine but it does, # so this workaround is required future = asyncio.Future(loop=event_loop) - future.set_result(forever_blocking_future) + future.set_result(asyncio.Future(loop=event_loop)) return future - path_client._connection.send = _mock_send - path_client._connection.ping = _mock_ping + mocker.patch.object(path_client._connection, 'send', _mock_send) + mocker.patch.object(path_client._connection, 'ping', _mock_ping) # Send relay message: initiator --> responder (mocked) nonce = pack_nonce(i['rcck'], i['id'], 0x02, i['rccsn']) @@ -1269,7 +1484,7 @@ def _mock_ping(*_): }, box=None) i['rccsn'] += 1 - # Receive send-error message: initiator <-- initiator + # Receive send-error message: initiator <-- initiator (mocked) message, _, sck, s, d, scsn = yield from initiator.recv(timeout=10.0) assert s == 0x00 assert d == i['id'] @@ -1497,6 +1712,8 @@ def test_path_full(self, event_loop, server, client_factory): and check that the correct error code (Path Full) is being returned. 
""" + assert len(server.protocols) == 0 + tasks = [client_factory(responder_handshake=True, timeout=20.0) for _ in range(0x02, 0x100)] clients = yield from asyncio.gather(*tasks, loop=event_loop) @@ -1514,59 +1731,6 @@ def test_path_full(self, event_loop, server, client_factory): yield from asyncio.gather(*tasks, loop=event_loop) yield from server.wait_connections_closed() - @pytest.mark.asyncio - def test_event_emitted( - self, initiator_key, responder_key, cookie_factory, server, client_factory - ): - # Dictionary where fired events are added - events_fired = collections.defaultdict(list) - - @asyncio.coroutine - def callback(event: Event, *data): - events_fired[event].append(data) - - # Register event callback for all events - for event in Event: - server.register_event_callback(event, callback) - - # Initiator handshake - initiator, i = yield from client_factory(initiator_handshake=True) - i['rccsn'] = 456987 - i['rcck'] = cookie_factory() - i['rbox'] = libnacl.public.Box(sk=initiator_key, pk=responder_key.pk) - - # Responder handshake - responder, r = yield from client_factory(responder_handshake=True) - r['iccsn'] = 2 ** 24 - r['icck'] = cookie_factory() - r['ibox'] = libnacl.public.Box(sk=responder_key, pk=initiator_key.pk) - - yield from initiator.recv() - assert set(events_fired.keys()) == { - Event.initiator_connected, - Event.responder_connected, - } - assert events_fired[Event.initiator_connected] == [ - (initiator_key.hex_pk().decode('ascii'),) - ] - assert events_fired[Event.responder_connected] == [ - (initiator_key.hex_pk().decode('ascii'),) - ] - - yield from initiator.close() - yield from responder.close() - yield from server.wait_connections_closed() - - assert set(events_fired.keys()) == { - Event.initiator_connected, - Event.responder_connected, - Event.disconnected, - } - assert events_fired[Event.disconnected] == [ - (initiator_key.hex_pk().decode('ascii'), 1000), - (initiator_key.hex_pk().decode('ascii'), 1000), - ] - @pytest.mark.asyncio def test_explicit_permanent_key_unavailable( self, server_no_key, server, client_factory @@ -1632,18 +1796,21 @@ def test_explicit_permanent_key( yield from server.wait_connections_closed() @pytest.mark.asyncio - def test_initiator_disconnected( - self, server, client_factory, - ): + def test_initiator_disconnected(self, server, client_factory): + """ + Check that the server sends a 'disconnected' message to all + responders of the associated path when the initiator + disconnects. + """ # Client handshakes initiator, i = yield from client_factory(initiator_handshake=True) - responder1, r = yield from client_factory(responder_handshake=True) - responder2, r = yield from client_factory(responder_handshake=True) + responder1, _ = yield from client_factory(responder_handshake=True) + responder2, _ = yield from client_factory(responder_handshake=True) # Disconnect initiator yield from initiator.close() - # Expect 'disconnected' msgs sent to all responders + # Expect 'disconnected' messages sent to all responders msg1, *_ = yield from responder1.recv() msg2, *_ = yield from responder2.recv() assert msg1 == msg2 == {'type': 'disconnected', 'id': i['id']} @@ -1653,9 +1820,11 @@ def test_initiator_disconnected( yield from server.wait_connections_closed() @pytest.mark.asyncio - def test_responder_disconnected( - self, server, client_factory, - ): + def test_responder_disconnected(self, server, client_factory): + """ + Check that the server sends 'disconnected' message to the + initiator when a responder disconnects. 
+ """ # Client handshakes responder, r = yield from client_factory(responder_handshake=True) initiator, i = yield from client_factory(initiator_handshake=True) @@ -1663,9 +1832,39 @@ def test_responder_disconnected( # Disconnect initiator yield from responder.close() - # Expect 'disconnected' msg sent to initiator + # Expect 'disconnected' message sent to initiator msg, *_ = yield from initiator.recv() assert msg == {'type': 'disconnected', 'id': r['id']} yield from initiator.close() yield from server.wait_connections_closed() + + @pytest.mark.asyncio + def test_drop_responder_no_disconnect( + self, pack_nonce, server, client_factory + ): + """ + Ensure that dropping a responder explicitly does not trigger a + 'disconnected' message being sent to the initiator. + """ + # Client handshakes + initiator, i = yield from client_factory(initiator_handshake=True) + responder, r = yield from client_factory(responder_handshake=True) + + # Ignore 'new-responder' message + message, *_ = yield from initiator.recv() + assert message['type'] == 'new-responder' + + # Drop responder + yield from initiator.send(pack_nonce(i['cck'], 0x01, 0x00, i['ccsn']), { + 'type': 'drop-responder', + 'id': r['id'], + }) + + # Ensure no further message is being received + with pytest.raises(asyncio.TimeoutError): + yield from initiator.recv(timeout=1.0) + + # Bye + yield from initiator.close() + yield from server.wait_connections_closed() diff --git a/tests/test_server.py b/tests/test_server.py index 6231264..05d6c3a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,12 +2,23 @@ The tests provided in this module make sure that the server instance behaves as expected. """ +import asyncio +import collections import pytest -from saltyrtc import server +from saltyrtc.server import ( + AddressType, + CloseCode, + RawMessage, + ServerProtocol, + exception, + serve, +) +from saltyrtc.server.events import Event +@pytest.mark.usefixtures('evaluate_log') class TestServer: @pytest.mark.asyncio def test_repeated_permanent_keys(self, server_permanent_keys): @@ -15,6 +26,553 @@ def test_repeated_permanent_keys(self, server_permanent_keys): Ensure the server does not accept repeated keys. """ keys = server_permanent_keys + [server_permanent_keys[1]] - with pytest.raises(server.ServerKeyError) as exc_info: - yield from server.serve(None, keys) + with pytest.raises(exception.ServerKeyError) as exc_info: + yield from serve(None, keys) assert 'Repeated permanent keys' in str(exc_info.value) + + @pytest.mark.asyncio + def test_task_returned_connection_open( + self, mocker, log_ignore_filter, log_handler, sleep, server, client_factory, + ): + """ + Ensure the server handles a task returning early while the + connection is still running. 
+ """ + def _filter(record): + return 'returned unexpectedly' in record.message \ + or (record.exception_message is not None + and 'returned unexpectedly' in record.exception_message) + log_ignore_filter(_filter) + + # Mock the initiator receive loop to return after a brief timeout + class _MockProtocol(ServerProtocol): + @asyncio.coroutine + def initiator_receive_loop(self): + # ZZzzzZZzz + yield from sleep(0.1) + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Initiator handshake + initiator, _ = yield from client_factory(initiator_handshake=True) + + # Expect internal error + yield from server.wait_connections_closed() + assert not initiator.ws_client.open + assert initiator.ws_client.close_code == CloseCode.internal_error + assert len([record for record in log_handler.records if _filter(record)]) == 2 + + @pytest.mark.asyncio + def test_task_cancelled_connection_open( + self, mocker, log_ignore_filter, log_handler, sleep, server, client_factory + ): + """ + Ensure the server handles a task being cancelled early while + the connection is still running. + """ + def _filter(record): + return 'has been cancelled' in record.message \ + or (record.exception_message is not None + and 'has been cancelled' in record.exception_message) + log_ignore_filter(_filter) + + # Mock the initiator receive loop and cancel itself after a brief timeout + class _MockProtocol(ServerProtocol): + def initiator_receive_loop(self): + receive_loop = asyncio.ensure_future( + super().initiator_receive_loop(), loop=self._loop) + + @asyncio.coroutine + def _cancel_loop(): + yield from sleep(0.1) + receive_loop.cancel() + + asyncio.ensure_future(_cancel_loop(), loop=self._loop) + return receive_loop + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Initiator handshake + initiator, _ = yield from client_factory(initiator_handshake=True) + + # Expect internal error + yield from server.wait_connections_closed() + assert not initiator.ws_client.open + assert initiator.ws_client.close_code == CloseCode.internal_error + assert len([record for record in log_handler.records if _filter(record)]) == 2 + + @pytest.mark.asyncio + def test_task_returned_connection_closed( + self, mocker, event_loop, log_handler, sleep, server, client_factory + ): + """ + Ensure the server does gracefully handle a task returning when + the connection is already closed. 
+ """ + # Mock the initiator receive loop to be able to notify when it returns + receive_loop_closed_future = asyncio.Future(loop=event_loop) + + class _MockProtocol(ServerProtocol): + @asyncio.coroutine + def initiator_receive_loop(self): + connection_closed_future = self.client._connection_closed_future + self.client._connection_closed_future = asyncio.Future(loop=self._loop) + + # ZZzzzZZzz + yield from sleep(0.1) + + # Replace the future with the previous one to prevent an exception + @asyncio.coroutine + def _revert_future(): + yield from sleep(0.05) + self.client._connection_closed_future = connection_closed_future + asyncio.ensure_future(_revert_future(), loop=self._loop) + + # Resolve the connection closed future and the loop future + self.client._connection_closed_future.set_result(1337) + receive_loop_closed_future.set_result(sleep(0.1)) + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Initiator handshake + initiator, _ = yield from client_factory(initiator_handshake=True) + + # Wait for the receive loop to return (and the waiter it returns) + yield from (yield from receive_loop_closed_future) + + # Bye + yield from initiator.ws_client.close() + yield from server.wait_connections_closed() + partials = ('Task done', 'result=None') + assert len([record for record in log_handler.records + if all(partial in record.message for partial in partials)]) == 1 + + @pytest.mark.asyncio + def test_disconnect_during_receive( + self, mocker, log_handler, sleep, server, client_factory + ): + """ + Check that the server handles a disconnect correctly when the + receive loop returns. + """ + # Mock the initiator keep alive loop to stay quiet + class _MockProtocol(ServerProtocol): + @asyncio.coroutine + def keep_alive_loop(self): + yield from sleep(60.0) + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Initiator handshake & disconnect immediately + initiator, _ = yield from client_factory(initiator_handshake=True) + yield from initiator.ws_client.close() + + # Expect disconnect during receive in the log + yield from server.wait_connections_closed() + assert len([record for record in log_handler.records + if 'closed while receiving' in record.message]) == 1 + + @pytest.mark.asyncio + def test_disconnect_during_send( + self, mocker, event_loop, log_handler, ws_client_factory, server + ): + """ + Check that the server handles a disconnect correctly when the + server tries to send something while the client is already gone. + """ + close_future = asyncio.Future(loop=event_loop) + + # Mock the handshake to wait until the client has been closed + class _MockProtocol(ServerProtocol): + @asyncio.coroutine + def handshake(self): + yield from close_future + return (yield from super().handshake()) + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Connect & disconnect immediately + ws_client = yield from ws_client_factory() + yield from ws_client.close() + close_future.set_result(None) + + # Expect disconnect during send in the log + yield from server.wait_connections_closed() + assert len([record for record in log_handler.records + if 'closed while sending' in record.message]) == 1 + + @pytest.mark.asyncio + def test_disconnect_during_task( + self, mocker, event_loop, log_handler, sleep, server, client_factory + ): + """ + Check that the server handles a disconnect correctly when a + task (that awaits a send operation) is awaited. 
+ """ + close_future = asyncio.Future(loop=event_loop) + + # Mock the loops to stay quiet and enqueue a relay task + class _MockProtocol(ServerProtocol): + @asyncio.coroutine + def initiator_receive_loop(self): + yield from close_future + + @asyncio.coroutine + def _send_task(): + message = RawMessage(AddressType.server, self.client.id, b'\x00') + message._nonce = b'\x00' * 24 + yield from self.client.send(message) + + yield from self.client.enqueue_task(_send_task()) + yield from sleep(60.0) + + @asyncio.coroutine + def keep_alive_loop(self): + yield from sleep(60.0) + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Initiator handshake & disconnect immediately + initiator, _ = yield from client_factory(initiator_handshake=True) + yield from initiator.ws_client.close() + close_future.set_result(None) + + # Expect disconnect during send in the log + yield from server.wait_connections_closed() + partials = ['closed while sending', 'Stopping active task', 'Task done'] + assert len([record for record in log_handler.records + if any(partial in record.message for partial in partials)]) == 3 + + @pytest.mark.asyncio + def test_disconnect_keep_alive_ping( + self, mocker, event_loop, log_handler, sleep, ws_client_factory, + initiator_key, server, client_factory + ): + """ + Check that the server handles a disconnect correctly when + sending a ping. + """ + # Mock the initiator receive loop to return after a brief timeout + class _MockProtocol(ServerProtocol): + @asyncio.coroutine + def initiator_receive_loop(self): + # Wait until closed (and a little further) + yield from self.client.connection_closed_future + yield from sleep(0.1) + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Connect client to server + ws_client = yield from ws_client_factory() + + # Patch server's keep alive interval and timeout + assert len(server.protocols) == 1 + protocol = next(iter(server.protocols)) + protocol.client._keep_alive_interval = 0.1 + + # Initiator handshake + yield from client_factory(ws_client=ws_client, initiator_handshake=True) + connection_closed_future = server.wait_connection_closed_marker() + + # Get path instance of server and initiator's PathClient instance + path = server.paths.get(initiator_key.pk) + path_client = path.get_initiator() + + # Delay sending a ping + ping = path_client._connection.ping + ready_future = asyncio.Future(loop=event_loop) + + @asyncio.coroutine + def _mock_ping(*args): + yield from ready_future + return (yield from ping(*args)) + + mocker.patch.object(path_client._connection, 'ping', _mock_ping) + + # Let the server know we're ready once the connection has been closed. + # The server will now try to send a ping. + yield from ws_client.close() + ready_future.set_result(None) + + # Expect a normal closure (seen on the server side) + close_code = yield from connection_closed_future() + assert close_code == 1000 + yield from server.wait_connections_closed() + assert len([record for record in log_handler.records + if 'closed while pinging' in record.message]) == 1 + + @pytest.mark.asyncio + def test_disconnect_keep_alive_pong( + self, mocker, sleep, log_handler, ws_client_factory, server, client_factory + ): + """ + Check that the server handles a disconnect correctly when + waiting for a pong. 
+ """ + # Mock the initiator receive loop to return after a brief timeout + class _MockProtocol(ServerProtocol): + @asyncio.coroutine + def initiator_receive_loop(self): + # Wait until closed + yield from self.client.connection_closed_future + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Create client and patch it to not answer pings + ws_client = yield from ws_client_factory() + ws_client.pong = asyncio.coroutine(lambda *args, **kwargs: None) + + # Patch server's keep alive interval and timeout + assert len(server.protocols) == 1 + protocol = next(iter(server.protocols)) + protocol.client._keep_alive_interval = 0.1 + protocol.client.keep_alive_timeout = 60.0 + + # Initiator handshake + yield from client_factory(ws_client=ws_client, initiator_handshake=True) + connection_closed_future = server.wait_connection_closed_marker() + + # Ensure the server can send a ping before closing + yield from sleep(0.25) + yield from ws_client.close() + + # Expect a normal closure (seen on the server side) + close_code = yield from connection_closed_future() + assert close_code == 1000 + yield from server.wait_connections_closed() + assert len([record for record in log_handler.records + if 'closed while waiting for pong' in record.message]) == 1 + + @pytest.mark.asyncio + def test_misbehaving_coroutine( + self, mocker, event_loop, sleep, log_ignore_filter, log_handler, + initiator_key, server, client_factory + ): + """ + Check that the server handles a misbehaving coroutine + correctly. + """ + log_ignore_filter(lambda record: 'queue did not close' in record.message) + + # Initiator handshake + initiator, _ = yield from client_factory(initiator_handshake=True) + connection_closed_future = server.wait_connection_closed_marker() + + # Mock the task queue join timeout + mocker.patch('saltyrtc.server.server._TASK_QUEUE_JOIN_TIMEOUT', 0.1) + + # Get path instance of server and initiator's PathClient instance + path = server.paths.get(initiator_key.pk) + path_client = path.get_initiator() + + @asyncio.coroutine + def bad_coroutine(cancelled_future): + try: + yield from sleep(60.0) + except asyncio.CancelledError: + cancelled_future.set_result(None) + yield from sleep(60.0) + raise + + @asyncio.coroutine + def enqueue_bad_coroutine(): + cancelled_future = asyncio.Future(loop=event_loop) + yield from path_client.enqueue_task(bad_coroutine(cancelled_future)) + return cancelled_future + + # Enqueue misbehaving coroutine + # Note: We need to add two of these since one of them will be dequeued + # immediately and waited for which runs in a different code + # section. + active_coroutine_cancelled_future = yield from enqueue_bad_coroutine() + queued_coroutine_cancelled_future = yield from enqueue_bad_coroutine() + + # Close and wait + yield from initiator.ws_client.close() + + # Expect a normal closure (seen on the server side) + close_code = yield from connection_closed_future() + assert close_code == 1000 + yield from server.wait_connections_closed() + + # The active coroutine was activated and thus will be cancelled + assert active_coroutine_cancelled_future.result() is None + # Since the active coroutine does not re-raise the cancellation, it should + # never be marked as cancelled by the task loop. + assert len([record for record in log_handler.records + if 'Cancelling active task' in record.message]) == 0 + + # The queued coroutine was never waited for and it has not been added as a task + # to the event loop either. Thus, it will not be cancelled. 
+ assert not queued_coroutine_cancelled_future.done() + # The queued task will be cancelled. + assert len([record for record in log_handler.records + if 'Cancelling 1 queued tasks' in record.message]) == 1 + # Ensure it has been picked up as a coroutine + assert len([record for record in log_handler.records + if 'Closing queued coroutine' in record.message]) == 1 + + # Check log messages + assert len([record for record in log_handler.records + if 'queue did not close' in record.message]) == 1 + + @pytest.mark.asyncio + def test_misbehaving_task( + self, mocker, event_loop, sleep, log_ignore_filter, log_handler, + initiator_key, server, client_factory + ): + """ + Check that the server handles a misbehaving task correctly. + """ + log_ignore_filter(lambda record: 'queue did not close' in record.message) + + # Initiator handshake + initiator, _ = yield from client_factory(initiator_handshake=True) + connection_closed_future = server.wait_connection_closed_marker() + + # Mock the task queue join timeout + mocker.patch('saltyrtc.server.server._TASK_QUEUE_JOIN_TIMEOUT', 0.1) + + # Get path instance of server and initiator's PathClient instance + path = server.paths.get(initiator_key.pk) + path_client = path.get_initiator() + + @asyncio.coroutine + def bad_coroutine(cancelled_future): + try: + yield from sleep(60.0) + except asyncio.CancelledError: + cancelled_future.set_result(None) + yield from sleep(60.0) + raise + + @asyncio.coroutine + def enqueue_bad_task(): + cancelled_future = asyncio.Future(loop=event_loop) + yield from path_client.enqueue_task( + asyncio.ensure_future(bad_coroutine(cancelled_future))) + return cancelled_future + + # Enqueue misbehaving task + # Note: We need to add two of these since one of them will be dequeued + # immediately and waited for which runs in a different code + # section. + active_task_cancelled_future = yield from enqueue_bad_task() + queued_task_cancelled_future = yield from enqueue_bad_task() + + # Close and wait + yield from initiator.ws_client.close() + + # Expect a normal closure (seen on the server side) + close_code = yield from connection_closed_future() + assert close_code == 1000 + yield from server.wait_connections_closed() + + # The active task will be implicitly cancelled by cancellation of the task loop + assert active_task_cancelled_future.result() is None + # Since the active task does not re-raise the cancellation, it should never be + # marked as cancelled by the task loop. + assert len([record for record in log_handler.records + if 'Cancelling active task' in record.message]) == 0 + + # The queued task has been scheduled on the event loop and thus will be + # cancelled by the task queue cancellation. + assert queued_task_cancelled_future.result() is None + # The queued task will be cancelled. + assert len([record for record in log_handler.records + if 'Cancelling 1 queued tasks' in record.message]) == 1 + # Ensure it has been picked up as a task + assert len([record for record in log_handler.records + if 'Cancelling queued task' in record.message]) == 1 + + # Check log messages + assert len([record for record in log_handler.records + if 'queue did not close' in record.message]) == 1 + + @pytest.mark.asyncio + def test_event_emitted(self, initiator_key, server, client_factory): + """ + Ensure the server does emit events as expected. 
+ """ + # Dictionary where fired events are added + events_fired = collections.defaultdict(list) + + @asyncio.coroutine + def callback(event: Event, *data): + events_fired[event].append(data) + + # Register event callback for all events + for event in Event: + server.register_event_callback(event, callback) + + # Initiator handshake + initiator, _ = yield from client_factory(initiator_handshake=True) + + # Responder handshake + responder, _ = yield from client_factory(responder_handshake=True) + + yield from initiator.recv() + assert set(events_fired.keys()) == { + Event.initiator_connected, + Event.responder_connected, + } + assert events_fired[Event.initiator_connected] == [ + (initiator_key.hex_pk().decode('ascii'),) + ] + assert events_fired[Event.responder_connected] == [ + (initiator_key.hex_pk().decode('ascii'),) + ] + + yield from initiator.close() + yield from responder.close() + yield from server.wait_connections_closed() + + assert set(events_fired.keys()) == { + Event.initiator_connected, + Event.responder_connected, + Event.disconnected, + } + assert events_fired[Event.disconnected] == [ + (initiator_key.hex_pk().decode('ascii'), 1000), + (initiator_key.hex_pk().decode('ascii'), 1000), + ] + + @pytest.mark.asyncio + def test_error_after_disconnect( + self, mocker, server, client_factory + ): + """ + Ensure the server does not error after the client's disconnect + procedure has been started. + + This test exists to prevent a regression. Previously it was + possible to enqueue tasks on a client whose task queue has + already been closed. + """ + # Initiator handshake + initiator, _ = yield from client_factory(initiator_handshake=True) + connection_closed_future = server.wait_connection_closed_marker() + + # Mock the responder's client handler to wait before raising an exception + class _MockProtocol(ServerProtocol): + @asyncio.coroutine + def handle_client(self): + try: + yield from super().handle_client() + except Exception: + # Hold back the exception until the initiator has closed its + # connection to provoke a race condition + yield from connection_closed_future() + raise + + mocker.patch.object(server, '_protocol_class', _MockProtocol) + + # Responder handshake + responder, _ = yield from client_factory(responder_handshake=True) + + # Disconnect the responder first, then the initiator. + # The initiator may trigger some behaviour on the responder resulting in an + # exception being logged. Thus, we don't have to assert anything here. + yield from responder.close() + yield from initiator.close() + yield from server.wait_connections_closed()