diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 6896cf4d4a41..e5b16009e563 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -15,71 +15,18 @@ """Flower client interceptor.""" -import base64 -import collections -from collections.abc import Sequence -from logging import WARNING -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import grpc from cryptography.hazmat.primitives.asymmetric import ec +from google.protobuf.message import Message as GrpcMessage -from flwr.common.logger import log +from flwr.common import now +from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_public_key, - compute_hmac, - generate_shared_key, public_key_to_bytes, + sign_message, ) -from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611 -from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 - CreateNodeRequest, - DeleteNodeRequest, - PingRequest, - PullMessagesRequest, - PullTaskInsRequest, - PushMessagesRequest, - PushTaskResRequest, -) -from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611 - -_PUBLIC_KEY_HEADER = "public-key" -_AUTH_TOKEN_HEADER = "auth-token" - -Request = Union[ - CreateNodeRequest, - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, - PullMessagesRequest, - PushMessagesRequest, -] - - -def _get_value_from_tuples( - key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] -) -> bytes: - value = next((value for key, value in tuples if key == key_string), "") - if isinstance(value, str): - return value.encode() - - return value - - -class _ClientCallDetails( - collections.namedtuple( - "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") - ), - grpc.ClientCallDetails, # type: ignore -): - """Details for each client call. - - The class will be passed on as the first argument in continuation function. - In our case, `AuthenticateClientInterceptor` adds new metadata to the construct. - """ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore @@ -91,86 +38,33 @@ def __init__( public_key: ec.EllipticCurvePublicKey, ): self.private_key = private_key - self.public_key = public_key - self.shared_secret: Optional[bytes] = None - self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None - self.encoded_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self.public_key) - ) + self.public_key_bytes = public_key_to_bytes(public_key) def intercept_unary_unary( self, continuation: Callable[[Any, Any], Any], client_call_details: grpc.ClientCallDetails, - request: Request, + request: GrpcMessage, ) -> grpc.Call: """Flower client interceptor. Intercept unary call from client and add necessary authentication header in the RPC metadata. """ - metadata = [] - postprocess = False - if client_call_details.metadata is not None: - metadata = list(client_call_details.metadata) - - # Always add the public key header - metadata.append( - ( - _PUBLIC_KEY_HEADER, - self.encoded_public_key, - ) - ) - - if isinstance(request, CreateNodeRequest): - postprocess = True - elif isinstance( - request, - ( - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, - PullMessagesRequest, - PushMessagesRequest, - ), - ): - if self.shared_secret is None: - raise RuntimeError("Failure to compute hmac") - - message_bytes = request.SerializeToString(deterministic=True) - metadata.append( - ( - _AUTH_TOKEN_HEADER, - base64.urlsafe_b64encode( - compute_hmac(self.shared_secret, message_bytes) - ), - ) - ) + metadata = list(client_call_details.metadata or []) - client_call_details = _ClientCallDetails( - client_call_details.method, - client_call_details.timeout, - metadata, - client_call_details.credentials, - ) + # Add the public key + metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes)) - response = continuation(client_call_details, request) - if postprocess: - server_public_key_bytes = base64.urlsafe_b64decode( - _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata()) - ) + # Add timestamp + timestamp = now().isoformat() + metadata.append((TIMESTAMP_HEADER, timestamp)) - if server_public_key_bytes != b"": - self.server_public_key = bytes_to_public_key(server_public_key_bytes) - else: - log(WARNING, "Can't get server public key, SuperLink may be offline") + # Sign and add the signature + signature = sign_message(self.private_key, timestamp.encode("ascii")) + metadata.append((SIGNATURE_HEADER, signature)) - if self.server_public_key is not None: - self.shared_secret = generate_shared_key( - self.private_key, self.server_public_key - ) + # Overwrite the metadata + details = client_call_details._replace(metadata=metadata) - return response + return continuation(details, request) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index a029b926423f..34a0ae6bd91f 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -15,28 +15,28 @@ """Flower client interceptor tests.""" -import base64 -import inspect import threading import unittest from collections.abc import Sequence from concurrent import futures from logging import DEBUG, INFO, WARN -from typing import Optional, Union, get_args +from typing import Any, Callable, Optional, Union import grpc +from google.protobuf.message import Message as GrpcMessage +from parameterized import parameterized from flwr.client.grpc_rere_client.connection import grpc_request_response from flwr.common import GRPC_MAX_MESSAGE_LENGTH, serde +from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.record import RecordSet from flwr.common.retry_invoker import RetryInvoker, exponential from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, generate_key_pairs, - generate_shared_key, public_key_to_bytes, + verify_signature, ) from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -48,13 +48,10 @@ PushTaskResRequest, PushTaskResResponse, ) -from flwr.proto.fleet_pb2_grpc import FleetServicer from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 -from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request - class _MockServicer: """Mock Servicer for Flower clients.""" @@ -65,35 +62,24 @@ def __init__(self) -> None: self._received_client_metadata: Optional[ Sequence[tuple[str, Union[str, bytes]]] ] = None - self.server_private_key, self.server_public_key = generate_key_pairs() self._received_message_bytes: bytes = b"" def unary_unary( - self, request: Request, context: grpc.ServicerContext - ) -> Union[ - CreateNodeResponse, DeleteNodeResponse, PushTaskResResponse, PullTaskInsResponse - ]: + self, request: GrpcMessage, context: grpc.ServicerContext + ) -> GrpcMessage: """Handle unary call.""" with self._lock: self._received_client_metadata = context.invocation_metadata() self._received_message_bytes = request.SerializeToString(deterministic=True) if isinstance(request, CreateNodeRequest): - context.send_initial_metadata( - ( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self.server_public_key) - ), - ), - ) - ) return CreateNodeResponse(node=Node(node_id=123)) if isinstance(request, DeleteNodeRequest): return DeleteNodeResponse() if isinstance(request, PushTaskResRequest): return PushTaskResResponse() + if isinstance(request, GetRunRequest): + return GetRunResponse() return PullTaskInsResponse( task_ins_list=[ @@ -153,16 +139,6 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: server.add_generic_rpc_handlers((generic_handler,)) -def _get_value_from_tuples( - key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] -) -> bytes: - value = next((value for key, value in tuples if key == key_string), "") - if isinstance(value, str): - return value.encode() - - return value - - def _init_retry_invoker() -> RetryInvoker: return RetryInvoker( wait_gen_factory=exponential, @@ -201,6 +177,36 @@ def _init_retry_invoker() -> RetryInvoker: ) +def _create_node(conn: Any) -> None: + create_node = conn[2] + create_node() + + +def _delete_node(conn: Any) -> None: + _, _, create_node, delete_node, _, _ = conn + create_node() + delete_node() + + +def _receive(conn: Any) -> None: + receive, _, create_node, _, _, _ = conn + create_node() + receive() + + +def _send(conn: Any) -> None: + receive, send, create_node, _, _, _ = conn + create_node() + receive() + send(Message(Metadata(0, "", 123, 0, "", "", 0, ""), RecordSet())) + + +def _get_run(conn: Any) -> None: + _, _, create_node, _, get_run, _ = conn + create_node() + get_run(0) + + class TestAuthenticateClientInterceptor(unittest.TestCase): """Test for client interceptor client authentication.""" @@ -219,7 +225,10 @@ def setUp(self) -> None: self._connection = grpc_request_response self._address = f"localhost:{port}" - def test_client_auth_create_node(self) -> None: + @parameterized.expand( + [(_create_node,), (_delete_node,), (_receive,), (_send,), (_get_run,)] + ) # type: ignore + def test_client_auth_rpc(self, grpc_call: Callable[[Any], None]) -> None: """Test client authentication during create node.""" # Prepare retry_invoker = _init_retry_invoker() @@ -233,190 +242,25 @@ def test_client_auth_create_node(self) -> None: None, (self._client_private_key, self._client_public_key), ) as conn: - _, _, create_node, _, _, _ = conn - assert create_node is not None - create_node() + grpc_call(conn) received_metadata = self._servicer.received_client_metadata() assert received_metadata is not None - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) + metadata_dict = dict(received_metadata) + actual_public_key = metadata_dict[PUBLIC_KEY_HEADER] + signature = metadata_dict[SIGNATURE_HEADER] + timestamp = metadata_dict[TIMESTAMP_HEADER] - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) + expected_public_key = public_key_to_bytes(self._client_public_key) # Assert + assert isinstance(signature, bytes) + assert isinstance(timestamp, str) assert actual_public_key == expected_public_key - - def test_client_auth_delete_node(self) -> None: - """Test client authentication during delete node.""" - # Prepare - retry_invoker = _init_retry_invoker() - - # Execute - with self._connection( - self._address, - True, - retry_invoker, - GRPC_MAX_MESSAGE_LENGTH, - None, - (self._client_private_key, self._client_public_key), - ) as conn: - _, _, create_node, delete_node, _, _ = conn - assert create_node is not None - create_node() - assert delete_node is not None - delete_node() - - received_metadata = self._servicer.received_client_metadata() - assert received_metadata is not None - - shared_secret = generate_shared_key( - self._servicer.server_private_key, self._client_public_key - ) - expected_hmac = base64.urlsafe_b64encode( - compute_hmac(shared_secret, self._servicer.received_message_bytes()) + assert verify_signature( + self._client_public_key, timestamp.encode("ascii"), signature ) - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) - actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - - # Assert - assert actual_public_key == expected_public_key - assert actual_hmac == expected_hmac - - def test_client_auth_receive(self) -> None: - """Test client authentication during receive node.""" - # Prepare - retry_invoker = _init_retry_invoker() - - # Execute - with self._connection( - self._address, - True, - retry_invoker, - GRPC_MAX_MESSAGE_LENGTH, - None, - (self._client_private_key, self._client_public_key), - ) as conn: - receive, _, create_node, _, _, _ = conn - assert create_node is not None - create_node() - assert receive is not None - receive() - - received_metadata = self._servicer.received_client_metadata() - assert received_metadata is not None - - shared_secret = generate_shared_key( - self._servicer.server_private_key, self._client_public_key - ) - expected_hmac = base64.urlsafe_b64encode( - compute_hmac(shared_secret, self._servicer.received_message_bytes()) - ) - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) - actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - - # Assert - assert actual_public_key == expected_public_key - assert actual_hmac == expected_hmac - - def test_client_auth_send(self) -> None: - """Test client authentication during send node.""" - # Prepare - retry_invoker = _init_retry_invoker() - message = Message(Metadata(0, "", 123, 0, "", "", 0, ""), RecordSet()) - - # Execute - with self._connection( - self._address, - True, - retry_invoker, - GRPC_MAX_MESSAGE_LENGTH, - None, - (self._client_private_key, self._client_public_key), - ) as conn: - receive, send, create_node, _, _, _ = conn - assert create_node is not None - create_node() - assert receive is not None - receive() - assert send is not None - send(message) - - received_metadata = self._servicer.received_client_metadata() - assert received_metadata is not None - - shared_secret = generate_shared_key( - self._servicer.server_private_key, self._client_public_key - ) - expected_hmac = base64.urlsafe_b64encode( - compute_hmac(shared_secret, self._servicer.received_message_bytes()) - ) - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) - actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - - # Assert - assert actual_public_key == expected_public_key - assert actual_hmac == expected_hmac - - def test_client_auth_get_run(self) -> None: - """Test client authentication during send node.""" - # Prepare - retry_invoker = _init_retry_invoker() - - # Execute - with self._connection( - self._address, - True, - retry_invoker, - GRPC_MAX_MESSAGE_LENGTH, - None, - (self._client_private_key, self._client_public_key), - ) as conn: - _, _, create_node, _, get_run, _ = conn - assert create_node is not None - create_node() - assert get_run is not None - get_run(0) - - received_metadata = self._servicer.received_client_metadata() - assert received_metadata is not None - - shared_secret = generate_shared_key( - self._servicer.server_private_key, self._client_public_key - ) - expected_hmac = base64.urlsafe_b64encode( - compute_hmac(shared_secret, self._servicer.received_message_bytes()) - ) - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) - actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - - # Assert - assert actual_public_key == expected_public_key - assert actual_hmac == expected_hmac def test_without_servicer(self) -> None: """Test client authentication without servicer.""" @@ -439,20 +283,6 @@ def test_without_servicer(self) -> None: assert self._servicer.received_client_metadata() is None - def test_fleet_requests_included(self) -> None: - """Test if all Fleet requests are included in the authentication mode.""" - # Prepare - requests = get_args(Request) - rpc_names = {req.__qualname__.removesuffix("Request") for req in requests} - expected_rpc_names = { - name - for name, ref in inspect.getmembers(FleetServicer) - if inspect.isfunction(ref) - } - - # Assert - assert expected_rpc_names == rpc_names - if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index e0d31da2bddb..4a5554edfbbf 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -112,6 +112,12 @@ ACCESS_TOKEN_KEY = "access_token" REFRESH_TOKEN_KEY = "refresh_token" +# Constants for node authentication +PUBLIC_KEY_HEADER = "public-key-bin" # Must end with "-bin" for binary data +SIGNATURE_HEADER = "signature-bin" # Must end with "-bin" for binary data +TIMESTAMP_HEADER = "timestamp" +TIMESTAMP_TOLERANCE = 10 # Tolerance for timestamp verification + class MessageType: """Message type.""" diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 167507e54f8d..2197ee266ac9 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -15,91 +15,54 @@ """Flower server interceptor.""" -import base64 -from collections.abc import Sequence -from logging import INFO, WARNING -from typing import Any, Callable, Optional, Union +import datetime +from typing import Any, Callable, Optional, cast import grpc -from cryptography.hazmat.primitives.asymmetric import ec - -from flwr.common.logger import log +from google.protobuf.message import Message as GrpcMessage + +from flwr.common import now +from flwr.common.constant import ( + PUBLIC_KEY_HEADER, + SIGNATURE_HEADER, + TIMESTAMP_HEADER, + TIMESTAMP_TOLERANCE, +) from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_private_key, bytes_to_public_key, - generate_shared_key, - verify_hmac, + verify_signature, ) -from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, - DeleteNodeRequest, - DeleteNodeResponse, - PingRequest, - PingResponse, - PullTaskInsRequest, - PullTaskInsResponse, - PushTaskResRequest, - PushTaskResResponse, ) -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.linkstate import LinkStateFactory -_PUBLIC_KEY_HEADER = "public-key" -_AUTH_TOKEN_HEADER = "auth-token" - -Request = Union[ - CreateNodeRequest, - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, -] - -Response = Union[ - CreateNodeResponse, - DeleteNodeResponse, - PullTaskInsResponse, - PushTaskResResponse, - GetRunResponse, - PingResponse, - GetFabResponse, -] - -def _get_value_from_tuples( - key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] -) -> bytes: - value = next((value for key, value in tuples if key == key_string), "") - if isinstance(value, str): - return value.encode() +def _unary_unary_rpc_terminator(message: str) -> grpc.RpcMethodHandler: + def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage: + context.abort(grpc.StatusCode.UNAUTHENTICATED, message) + raise RuntimeError("Should not reach this point") # Make mypy happy - return value + return grpc.unary_unary_rpc_method_handler(terminate) class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore - """Server interceptor for node authentication.""" - - def __init__(self, state_factory: LinkStateFactory): + """Server interceptor for node authentication. + + Parameters + ---------- + state_factory : LinkStateFactory + A factory for creating new instances of LinkState. + auto_auth : bool (default: False) + If True, nodes are authenticated without requiring their public keys to be + pre-stored in the LinkState. If False, only nodes with pre-stored public keys + can be authenticated. + """ + + def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False): self.state_factory = state_factory - state = self.state_factory.state() - - self.node_public_keys = state.get_node_public_keys() - if len(self.node_public_keys) == 0: - log(WARNING, "Authentication enabled, but no known public keys configured") - - private_key = state.get_server_private_key() - public_key = state.get_server_public_key() - - if private_key is None or public_key is None: - raise ValueError("Error loading authentication keys") - - self.server_private_key = bytes_to_private_key(private_key) - self.encoded_server_public_key = base64.urlsafe_b64encode(public_key) + self.auto_auth = auto_auth def intercept_service( self, @@ -112,117 +75,80 @@ def intercept_service( metadata sent by the node. Continue RPC call if node is authenticated, else, terminate RPC call by setting context to abort. """ + state = self.state_factory.state() + metadata_dict = dict(handler_call_details.invocation_metadata) + + # Retrieve info from the metadata + try: + node_pk_bytes = cast(bytes, metadata_dict[PUBLIC_KEY_HEADER]) + timestamp_iso = cast(str, metadata_dict[TIMESTAMP_HEADER]) + signature = cast(bytes, metadata_dict[SIGNATURE_HEADER]) + except KeyError: + return _unary_unary_rpc_terminator("Missing authentication metadata") + + if not self.auto_auth: + # Abort the RPC call if the node public key is not found + if node_pk_bytes not in state.get_node_public_keys(): + return _unary_unary_rpc_terminator("Public key not recognized") + + # Verify the signature + node_pk = bytes_to_public_key(node_pk_bytes) + if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature): + return _unary_unary_rpc_terminator("Invalid signature") + + # Verify the timestamp + current = now() + time_diff = current - datetime.datetime.fromisoformat(timestamp_iso) + # Abort the RPC call if the timestamp is too old or in the future + if not 0 < time_diff.total_seconds() < TIMESTAMP_TOLERANCE: + return _unary_unary_rpc_terminator("Invalid timestamp") + + # Continue the RPC call + expected_node_id = state.get_node_id(node_pk_bytes) + if not handler_call_details.method.endswith("CreateNode"): + if expected_node_id is None: + return _unary_unary_rpc_terminator("Invalid node ID") # One of the method handlers in # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer` method_handler: grpc.RpcMethodHandler = continuation(handler_call_details) - return self._generic_auth_unary_method_handler(method_handler) + return self._wrap_method_handler( + method_handler, expected_node_id, node_pk_bytes + ) - def _generic_auth_unary_method_handler( - self, method_handler: grpc.RpcMethodHandler + def _wrap_method_handler( + self, + method_handler: grpc.RpcMethodHandler, + expected_node_id: Optional[int], + node_public_key: bytes, ) -> grpc.RpcMethodHandler: def _generic_method_handler( - request: Request, + request: GrpcMessage, context: grpc.ServicerContext, - ) -> Response: - node_public_key_bytes = base64.urlsafe_b64decode( - _get_value_from_tuples( - _PUBLIC_KEY_HEADER, context.invocation_metadata() - ) - ) - if node_public_key_bytes not in self.node_public_keys: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - - if isinstance(request, CreateNodeRequest): - response = self._create_authenticated_node( - node_public_key_bytes, request, context - ) - log( - INFO, - "AuthenticateServerInterceptor: Created node_id=%s", - response.node.node_id, - ) - return response - - # Verify hmac value - hmac_value = base64.urlsafe_b64decode( - _get_value_from_tuples( - _AUTH_TOKEN_HEADER, context.invocation_metadata() - ) - ) - public_key = bytes_to_public_key(node_public_key_bytes) - - if not self._verify_hmac(public_key, request, hmac_value): - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - - # Verify node_id - node_id = self.state_factory.state().get_node_id(node_public_key_bytes) - - if not self._verify_node_id(node_id, request): - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - - return method_handler.unary_unary(request, context) # type: ignore + ) -> GrpcMessage: + # Verify the node ID + if not isinstance(request, CreateNodeRequest): + try: + if request.node.node_id != expected_node_id: # type: ignore + raise ValueError + except (AttributeError, ValueError): + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID") + + response: GrpcMessage = method_handler.unary_unary(request, context) + + # Set the public key after a successful CreateNode request + if isinstance(response, CreateNodeResponse): + state = self.state_factory.state() + try: + state.set_node_public_key(response.node.node_id, node_public_key) + except ValueError as e: + # Remove newly created node if setting the public key fails + state.delete_node(response.node.node_id) + context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e)) + + return response return grpc.unary_unary_rpc_method_handler( _generic_method_handler, request_deserializer=method_handler.request_deserializer, response_serializer=method_handler.response_serializer, ) - - def _verify_node_id( - self, - node_id: Optional[int], - request: Union[ - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, - ], - ) -> bool: - if node_id is None: - return False - if isinstance(request, PushTaskResRequest): - if len(request.task_res_list) == 0: - return False - return request.task_res_list[0].task.producer.node_id == node_id - if isinstance(request, GetRunRequest): - return node_id in self.state_factory.state().get_nodes(request.run_id) - return request.node.node_id == node_id - - def _verify_hmac( - self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes - ) -> bool: - shared_secret = generate_shared_key(self.server_private_key, public_key) - message_bytes = request.SerializeToString(deterministic=True) - return verify_hmac(shared_secret, message_bytes, hmac_value) - - def _create_authenticated_node( - self, - public_key_bytes: bytes, - request: CreateNodeRequest, - context: grpc.ServicerContext, - ) -> CreateNodeResponse: - context.send_initial_metadata( - ( - ( - _PUBLIC_KEY_HEADER, - self.encoded_server_public_key, - ), - ) - ) - state = self.state_factory.state() - node_id = state.get_node_id(public_key_bytes) - - # Handle `CreateNode` here instead of calling the default method handler - # Return previously assigned `node_id` for the provided `public_key` - if node_id is not None: - state.acknowledge_ping(node_id, request.ping_interval) - return CreateNodeResponse(node=Node(node_id=node_id)) - - # No `node_id` exists for the provided `public_key` - # Handle `CreateNode` here instead of calling the default method handler - # Note: the innermost `CreateNode` method will never be called - node_id = state.create_node(request.ping_interval) - state.set_node_public_key(node_id, public_key_bytes) - return CreateNodeResponse(node=Node(node_id=node_id)) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 9984b93f3e84..6861d0235c31 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -15,21 +15,28 @@ """Flower server interceptor tests.""" -import base64 +import datetime import unittest +from typing import Any, Callable import grpc - -from flwr.common import ConfigsRecord -from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS, Status +from parameterized import parameterized + +from flwr.common import ConfigsRecord, now +from flwr.common.constant import ( + FLEET_API_GRPC_RERE_DEFAULT_ADDRESS, + PUBLIC_KEY_HEADER, + SIGNATURE_HEADER, + TIMESTAMP_HEADER, + Status, +) from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, generate_key_pairs, - generate_shared_key, - private_key_to_bytes, public_key_to_bytes, + sign_message, ) from flwr.common.typing import RunStatus +from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -49,11 +56,7 @@ from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.linkstate.linkstate_factory import LinkStateFactory -from .server_interceptor import ( - _AUTH_TOKEN_HEADER, - _PUBLIC_KEY_HEADER, - AuthenticateServerInterceptor, -) +from .server_interceptor import AuthenticateServerInterceptor class TestServerInterceptor(unittest.TestCase): # pylint: disable=R0902 @@ -61,18 +64,13 @@ class TestServerInterceptor(unittest.TestCase): # pylint: disable=R0902 def setUp(self) -> None: """Initialize mock stub and server interceptor.""" - self._node_private_key, self._node_public_key = generate_key_pairs() - self._server_private_key, self._server_public_key = generate_key_pairs() + self.node_sk, self.node_pk = generate_key_pairs() state_factory = LinkStateFactory(":flwr-in-memory-state:") self.state = state_factory.state() ffs_factory = FfsFactory(".") self.ffs = ffs_factory.ffs() - self.state.store_server_private_public_key( - private_key_to_bytes(self._server_private_key), - public_key_to_bytes(self._server_public_key), - ) - self.state.store_node_public_keys({public_key_to_bytes(self._node_public_key)}) + self.state.store_node_public_keys({public_key_to_bytes(self.node_pk)}) self._server_interceptor = AuthenticateServerInterceptor(state_factory) self._server: grpc.Server = _run_fleet_api_grpc_rere( @@ -114,332 +112,206 @@ def setUp(self) -> None: request_serializer=PingRequest.SerializeToString, response_deserializer=PingResponse.FromString, ) + self._get_fab = self._channel.unary_unary( + "/flwr.proto.Fleet/GetFab", + request_serializer=GetFabRequest.SerializeToString, + response_deserializer=GetFabResponse.FromString, + ) def tearDown(self) -> None: """Clean up grpc server.""" self._server.stop(None) - def test_successful_create_node_with_metadata(self) -> None: - """Test server interceptor for creating node.""" - # Prepare - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute - response, call = self._create_node.with_call( + def _make_metadata(self) -> list[Any]: + """Create metadata with signature and timestamp.""" + timestamp = now().isoformat() + signature = sign_message(self.node_sk, timestamp.encode("ascii")) + return [ + (PUBLIC_KEY_HEADER, public_key_to_bytes(self.node_pk)), + (SIGNATURE_HEADER, signature), + (TIMESTAMP_HEADER, timestamp), + ] + + def _make_metadata_with_invalid_signature(self) -> list[Any]: + """Create metadata with invalid signature.""" + timestamp = now().isoformat() + sk, _ = generate_key_pairs() + signature = sign_message(sk, timestamp.encode("ascii")) + return [ + (PUBLIC_KEY_HEADER, public_key_to_bytes(self.node_pk)), + (SIGNATURE_HEADER, signature), + (TIMESTAMP_HEADER, timestamp), + ] + + def _make_metadata_with_invalid_public_key(self) -> list[Any]: + """Create metadata with invalid public key.""" + timestamp = now().isoformat() + signature = sign_message(self.node_sk, timestamp.encode("ascii")) + _, pk = generate_key_pairs() + return [ + (PUBLIC_KEY_HEADER, public_key_to_bytes(pk)), + (SIGNATURE_HEADER, signature), + (TIMESTAMP_HEADER, timestamp), + ] + + def _make_metadata_with_invalid_timestamp(self) -> list[Any]: + """Create metadata with invalid timestamp.""" + timestamp = (now() - datetime.timedelta(seconds=99)).isoformat() + signature = sign_message(self.node_sk, timestamp.encode("ascii")) + return [ + (PUBLIC_KEY_HEADER, public_key_to_bytes(self.node_pk)), + (SIGNATURE_HEADER, signature), + (TIMESTAMP_HEADER, timestamp), + ] + + def _test_create_node(self, metadata: list[Any]) -> Any: + """Test CreateNode.""" + return self._create_node.with_call( request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - expected_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._server_public_key) - ).decode(), - ) - - # Assert - assert call.initial_metadata()[0] == expected_metadata - assert isinstance(response, CreateNodeResponse) - - def test_unsuccessful_create_node_with_metadata(self) -> None: - """Test server interceptor for creating node unsuccessfully.""" - # Prepare - _, node_public_key = generate_key_pairs() - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(node_public_key) + metadata=metadata, ) - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._create_node.with_call( - request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - def test_successful_delete_node_with_metadata(self) -> None: - """Test server interceptor for deleting node.""" - # Prepare + def _test_delete_node(self, metadata: list[Any]) -> Any: + """Test DeleteNode.""" node_id = self._create_node_and_set_public_key() - request = DeleteNodeRequest(node=Node(node_id=node_id)) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute - response, call = self._delete_node.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, DeleteNodeResponse) - assert grpc.StatusCode.OK == call.code() + req = DeleteNodeRequest(node=Node(node_id=node_id)) + return self._delete_node.with_call(request=req, metadata=metadata) - def test_unsuccessful_delete_node_with_metadata(self) -> None: - """Test server interceptor for deleting node unsuccessfully.""" - # Prepare + def _test_pull_task_ins(self, metadata: list[Any]) -> Any: + """Test PullTaskIns.""" node_id = self._create_node_and_set_public_key() - request = DeleteNodeRequest(node=Node(node_id=node_id)) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._delete_node.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) + req = PullTaskInsRequest(node=Node(node_id=node_id)) + return self._pull_task_ins.with_call(request=req, metadata=metadata) - def test_successful_pull_task_ins_with_metadata(self) -> None: - """Test server interceptor for pull task ins.""" - # Prepare - node_id = self._create_node_and_set_public_key() - request = PullTaskInsRequest(node=Node(node_id=node_id)) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute - response, call = self._pull_task_ins.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, PullTaskInsResponse) - assert grpc.StatusCode.OK == call.code() - - def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: - """Test server interceptor for pull task ins unsuccessfully.""" - # Prepare - node_id = self._create_node_and_set_public_key() - request = PullTaskInsRequest(node=Node(node_id=node_id)) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._pull_task_ins.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - def test_successful_push_task_res_with_metadata(self) -> None: - """Test server interceptor for push task res.""" - # Prepare + def _test_push_task_res(self, metadata: list[Any]) -> Any: + """Test PushTaskRes.""" node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. PushTaskRes is only allowed in running status. - _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) - _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) - request = PushTaskResRequest( + self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) + req = PushTaskResRequest( + node=Node(node_id=node_id), task_res_list=[ TaskRes(task=Task(producer=Node(node_id=node_id)), run_id=run_id) - ] - ) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) + ], ) + return self._push_task_res.with_call(request=req, metadata=metadata) - # Execute - response, call = self._push_task_res.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, PushTaskResResponse) - assert grpc.StatusCode.OK == call.code() - - def test_unsuccessful_push_task_res_with_metadata(self) -> None: - """Test server interceptor for push task res unsuccessfully.""" - # Prepare + def _test_get_run(self, metadata: list[Any]) -> Any: + """Test GetRun.""" node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) - # Transition status to running. PushTaskRes is only allowed in running status. - _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) - _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) - request = PushTaskResRequest( - task_res_list=[TaskRes(task=Task(producer=Node(node_id=node_id)))] - ) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError) as e: - self._push_task_res.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - assert e.exception.code() == grpc.StatusCode.UNAUTHENTICATED - - def test_successful_get_run_with_metadata(self) -> None: - """Test server interceptor for get run.""" - # Prepare - self._create_node_and_set_public_key() - run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. GetRun is only allowed in running status. - _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) - _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) - request = GetRunRequest(run_id=run_id) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute - response, call = self._get_run.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, GetRunResponse) - assert grpc.StatusCode.OK == call.code() + self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) + req = GetRunRequest(node=Node(node_id=node_id), run_id=run_id) + return self._get_run.with_call(request=req, metadata=metadata) - def test_unsuccessful_get_run_with_metadata(self) -> None: - """Test server interceptor for get run unsuccessfully.""" - # Prepare - self._create_node_and_set_public_key() - run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) - request = GetRunRequest(run_id=run_id) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._get_run.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - def test_successful_ping_with_metadata(self) -> None: - """Test server interceptor for ping.""" - # Prepare + def _test_ping(self, metadata: list[Any]) -> Any: + """Test Ping.""" node_id = self._create_node_and_set_public_key() - request = PingRequest(node=Node(node_id=node_id)) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) + req = PingRequest(node=Node(node_id=node_id)) + return self._ping.with_call(request=req, metadata=metadata) - # Execute - response, call = self._ping.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, PingResponse) - assert grpc.StatusCode.OK == call.code() - - def test_unsuccessful_ping_with_metadata(self) -> None: - """Test server interceptor for ping unsuccessfully.""" - # Prepare + def _test_get_fab(self, metadata: list[Any]) -> Any: + """Test GetFab.""" + fab_hash = self.ffs.put(b"mock fab content", {}) node_id = self._create_node_and_set_public_key() - request = PingRequest(node=Node(node_id=node_id)) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) + # Transition status to running. PushTaskRes is only allowed in running status. + self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) + req = GetFabRequest( + node=Node(node_id=node_id), + run_id=run_id, + hash_str=fab_hash, ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._ping.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) + return self._get_fab.with_call(request=req, metadata=metadata) def _create_node_and_set_public_key(self) -> int: node_id = self.state.create_node(ping_interval=30) - pk_bytes = public_key_to_bytes(self._node_public_key) + pk_bytes = public_key_to_bytes(self.node_pk) self.state.set_node_public_key(node_id, pk_bytes) return node_id + + @parameterized.expand( + [ + (_test_create_node,), + (_test_delete_node,), + (_test_pull_task_ins,), + (_test_push_task_res,), + (_test_get_run,), + (_test_ping,), + (_test_get_fab,), + ] + ) # type: ignore + def test_successful_rpc_with_metadata( + self, rpc: Callable[[Any, list[Any]], Any] + ) -> None: + """Test server interceptor for RPC.""" + # Execute + _, call = rpc(self, self._make_metadata()) + + # Assert + assert call.code() == grpc.StatusCode.OK + + @parameterized.expand( + [ + (_test_create_node,), + (_test_delete_node,), + (_test_pull_task_ins,), + (_test_push_task_res,), + (_test_get_run,), + (_test_ping,), + (_test_get_fab,), + ] + ) # type: ignore + def test_unsuccessful_rpc_with_invalid_signature( + self, rpc: Callable[[Any, list[Any]], Any] + ) -> None: + """Test server interceptor for RPC unsuccessfully.""" + # Execute & Assert + with self.assertRaises(grpc.RpcError) as cm: + rpc(self, self._make_metadata_with_invalid_signature()) + assert cm.exception.code() == grpc.StatusCode.UNAUTHENTICATED + + @parameterized.expand( + [ + (_test_create_node,), + (_test_delete_node,), + (_test_pull_task_ins,), + (_test_push_task_res,), + (_test_get_run,), + (_test_ping,), + (_test_get_fab,), + ] + ) # type: ignore + def test_unsuccessful_rpc_with_invalid_public_key( + self, rpc: Callable[[Any, list[Any]], Any] + ) -> None: + """Test server interceptor for RPC unsuccessfully.""" + # Execute & Assert + with self.assertRaises(grpc.RpcError) as cm: + rpc(self, self._make_metadata_with_invalid_public_key()) + assert cm.exception.code() == grpc.StatusCode.UNAUTHENTICATED + + @parameterized.expand( + [ + (_test_create_node,), + (_test_delete_node,), + (_test_pull_task_ins,), + (_test_push_task_res,), + (_test_get_run,), + (_test_ping,), + (_test_get_fab,), + ] + ) # type: ignore + def test_unsuccessful_rpc_with_invalid_timestamp( + self, rpc: Callable[[Any, list[Any]], Any] + ) -> None: + """Test server interceptor for RPC unsuccessfully.""" + # Execute & Assert + with self.assertRaises(grpc.RpcError) as cm: + rpc(self, self._make_metadata_with_invalid_timestamp()) + assert cm.exception.code() == grpc.StatusCode.UNAUTHENTICATED