From b1776438dbadc893750bbe45a7cb66044669f8e6 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sat, 4 Jan 2025 13:18:06 +0000 Subject: [PATCH 01/10] init --- .../server/superlink/linkstate/in_memory_linkstate.py | 6 +++++- src/py/flwr/server/superlink/linkstate/linkstate.py | 8 +++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index d22072b41621..c4cbf6aaf754 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -307,7 +307,7 @@ def num_task_res(self) -> int: return len(self.task_res_store) def create_node( - self, ping_interval: float, public_key: Optional[bytes] = None + self, ping_interval: float ) -> int: """Create, store in the link state, and return `node_id`.""" # Sample a random int64 as node_id @@ -366,6 +366,10 @@ def get_nodes(self, run_id: int) -> set[int]: if online_until > current_time } + def set_node_public_key(self, node_id: int, public_key: bytes) -> None: + """Store `public_key` for the specified `node_id`.""" + self.public_key_to_node_id[public_key] = node_id + def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" return self.public_key_to_node_id.get(node_public_key) diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 4f3c16a5460a..1458ee4e4a3a 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -154,9 +154,7 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: """Get all TaskIns IDs for the given run_id.""" @abc.abstractmethod - def create_node( - self, ping_interval: float, public_key: Optional[bytes] = None - ) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" @abc.abstractmethod @@ -173,6 +171,10 @@ def get_nodes(self, run_id: int) -> set[int]: an empty `Set` MUST be returned. """ + @abc.abstractmethod + def set_node_public_key(self, node_id: int, public_key: bytes) -> None: + """Store `public_key` for the specified `node_id`.""" + @abc.abstractmethod def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" From 77352c3762cd90d45e18948610d06cf1d8b824bb Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sun, 5 Jan 2025 16:26:23 +0000 Subject: [PATCH 02/10] add and implement two new methods --- .../fleet/grpc_rere/server_interceptor.py | 3 +- .../grpc_rere/server_interceptor_test.py | 107 +++--------------- .../linkstate/in_memory_linkstate.py | 48 ++++---- .../server/superlink/linkstate/linkstate.py | 8 +- .../superlink/linkstate/linkstate_test.py | 44 +++---- .../superlink/linkstate/sqlite_linkstate.py | 62 ++++++---- 6 files changed, 106 insertions(+), 166 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index c07ee0788493..6cafaaa21459 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -223,5 +223,6 @@ def _create_authenticated_node( # No `node_id` exists for the provided `public_key` # Handle `CreateNode` here instead of calling the default method handler # Note: the innermost `CreateNode` method will never be called - node_id = state.create_node(request.ping_interval, public_key_bytes) + node_id = state.create_node(request.ping_interval) + state.set_node_public_key(node_id, public_key_bytes) return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index a0ff7a77304a..9984b93f3e84 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -161,9 +161,7 @@ def test_unsuccessful_create_node_with_metadata(self) -> None: def test_successful_delete_node_with_metadata(self) -> None: """Test server interceptor for deleting node.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = DeleteNodeRequest(node=Node(node_id=node_id)) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -191,9 +189,7 @@ def test_successful_delete_node_with_metadata(self) -> None: def test_unsuccessful_delete_node_with_metadata(self) -> None: """Test server interceptor for deleting node unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = DeleteNodeRequest(node=Node(node_id=node_id)) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) @@ -217,9 +213,7 @@ def test_unsuccessful_delete_node_with_metadata(self) -> None: def test_successful_pull_task_ins_with_metadata(self) -> None: """Test server interceptor for pull task ins.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PullTaskInsRequest(node=Node(node_id=node_id)) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -247,9 +241,7 @@ def test_successful_pull_task_ins_with_metadata(self) -> None: def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: """Test server interceptor for pull task ins unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PullTaskInsRequest(node=Node(node_id=node_id)) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) @@ -273,9 +265,7 @@ def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: def test_successful_push_task_res_with_metadata(self) -> None: """Test server interceptor for push task res.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. PushTaskRes is only allowed in running status. _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) @@ -311,9 +301,7 @@ def test_successful_push_task_res_with_metadata(self) -> None: def test_unsuccessful_push_task_res_with_metadata(self) -> None: """Test server interceptor for push task res unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. PushTaskRes is only allowed in running status. _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) @@ -344,9 +332,7 @@ def test_unsuccessful_push_task_res_with_metadata(self) -> None: def test_successful_get_run_with_metadata(self) -> None: """Test server interceptor for get run.""" # Prepare - self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. GetRun is only allowed in running status. _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) @@ -378,9 +364,7 @@ def test_successful_get_run_with_metadata(self) -> None: def test_unsuccessful_get_run_with_metadata(self) -> None: """Test server interceptor for get run unsuccessfully.""" # Prepare - self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) request = GetRunRequest(run_id=run_id) node_private_key, _ = generate_key_pairs() @@ -405,9 +389,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: def test_successful_ping_with_metadata(self) -> None: """Test server interceptor for ping.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PingRequest(node=Node(node_id=node_id)) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -435,9 +417,7 @@ def test_successful_ping_with_metadata(self) -> None: def test_unsuccessful_ping_with_metadata(self) -> None: """Test server interceptor for ping unsuccessfully.""" # Prepare - node_id = self.state.create_node( - ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) - ) + node_id = self._create_node_and_set_public_key() request = PingRequest(node=Node(node_id=node_id)) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) @@ -458,65 +438,8 @@ def test_unsuccessful_ping_with_metadata(self) -> None: ), ) - def test_successful_restore_node(self) -> None: - """Test server interceptor for restoring node.""" - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - response, call = self._create_node.with_call( - request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - expected_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._server_public_key) - ).decode(), - ) - - node = response.node - node_node_id = node.node_id - - assert call.initial_metadata()[0] == expected_metadata - assert isinstance(response, CreateNodeResponse) - - request = DeleteNodeRequest(node=node) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - response, call = self._delete_node.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - assert isinstance(response, DeleteNodeResponse) - assert grpc.StatusCode.OK == call.code() - - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - response, call = self._create_node.with_call( - request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - expected_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._server_public_key) - ).decode(), - ) - - assert call.initial_metadata()[0] == expected_metadata - assert isinstance(response, CreateNodeResponse) - assert response.node.node_id == node_node_id + def _create_node_and_set_public_key(self) -> int: + node_id = self.state.create_node(ping_interval=30) + pk_bytes = public_key_to_bytes(self._node_public_key) + self.state.set_node_public_key(node_id, pk_bytes) + return node_id diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index c4cbf6aaf754..ccce6cdd6e05 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -62,6 +62,7 @@ def __init__(self) -> None: # Map node_id to (online_until, ping_interval) self.node_ids: dict[int, tuple[float, float]] = {} self.public_key_to_node_id: dict[bytes, int] = {} + self.node_id_to_public_key: dict[int, bytes] = {} # Map run_id to RunRecord self.run_ids: dict[int, RunRecord] = {} @@ -306,9 +307,7 @@ def num_task_res(self) -> int: """ return len(self.task_res_store) - def create_node( - self, ping_interval: float - ) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" # Sample a random int64 as node_id node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) @@ -318,33 +317,18 @@ def create_node( log(ERROR, "Unexpected node registration failure.") return 0 - if public_key is not None: - if ( - public_key in self.public_key_to_node_id - or node_id in self.public_key_to_node_id.values() - ): - log(ERROR, "Unexpected node registration failure.") - return 0 - - self.public_key_to_node_id[public_key] = node_id - self.node_ids[node_id] = (time.time() + ping_interval, ping_interval) return node_id - def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: + def delete_node(self, node_id: int) -> None: """Delete a node.""" with self.lock: if node_id not in self.node_ids: raise ValueError(f"Node {node_id} not found") - if public_key is not None: - if ( - public_key not in self.public_key_to_node_id - or node_id not in self.public_key_to_node_id.values() - ): - raise ValueError("Public key or node_id not found") - - del self.public_key_to_node_id[public_key] + # Remove node ID <> public key mappings + if pk := self.node_id_to_public_key.pop(node_id, None): + del self.public_key_to_node_id[pk] del self.node_ids[node_id] @@ -367,8 +351,24 @@ def get_nodes(self, run_id: int) -> set[int]: } def set_node_public_key(self, node_id: int, public_key: bytes) -> None: - """Store `public_key` for the specified `node_id`.""" - self.public_key_to_node_id[public_key] = node_id + """Set `public_key` for the specified `node_id`.""" + with self.lock: + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + + if public_key in self.public_key_to_node_id: + raise ValueError("Public key already in use") + + self.public_key_to_node_id[public_key] = node_id + self.node_id_to_public_key[node_id] = public_key + + def get_node_public_key(self, node_id: int) -> Optional[bytes]: + """Get `public_key` for the specified `node_id`.""" + with self.lock: + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + + return self.node_id_to_public_key.get(node_id) def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index 1458ee4e4a3a..e1eccf2b8b2f 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -158,7 +158,7 @@ def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" @abc.abstractmethod - def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: + def delete_node(self, node_id: int) -> None: """Remove `node_id` from the link state.""" @abc.abstractmethod @@ -173,7 +173,11 @@ def get_nodes(self, run_id: int) -> set[int]: @abc.abstractmethod def set_node_public_key(self, node_id: int, public_key: bytes) -> None: - """Store `public_key` for the specified `node_id`.""" + """Set `public_key` for the specified `node_id`.""" + + @abc.abstractmethod + def get_node_public_key(self, node_id: int) -> Optional[bytes]: + """Get `public_key` for the specified `node_id`.""" @abc.abstractmethod def get_node_id(self, node_public_key: bytes) -> Optional[int]: diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 3edaf72ec20c..d3e391c5b62a 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -588,7 +588,8 @@ def test_create_node_public_key(self) -> None: run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute - node_id = state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) @@ -602,15 +603,21 @@ def test_create_node_public_key_twice(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) # Execute - new_node_id = state.create_node(ping_interval=10, public_key=public_key) + new_node_id = state.create_node(ping_interval=10) + try: + state.set_node_public_key(new_node_id, public_key) + except ValueError: + state.delete_node(new_node_id) + else: + raise AssertionError("Should have raised ValueError") retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) # Assert - assert new_node_id == 0 assert len(retrieved_node_ids) == 1 assert retrieved_node_id == node_id @@ -639,10 +646,11 @@ def test_delete_node_public_key(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) # Execute - state.delete_node(node_id, public_key=public_key) + state.delete_node(node_id) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) @@ -660,7 +668,7 @@ def test_delete_node_public_key_none(self) -> None: # Execute & Assert with self.assertRaises(ValueError): - state.delete_node(node_id, public_key=public_key) + state.delete_node(node_id) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(public_key) @@ -668,25 +676,6 @@ def test_delete_node_public_key_none(self) -> None: assert len(retrieved_node_ids) == 0 assert retrieved_node_id is None - def test_delete_node_wrong_public_key(self) -> None: - """Test deleting a client node with wrong public key.""" - # Prepare - state: LinkState = self.state_factory() - public_key = b"mock" - wrong_public_key = b"mock_mock" - run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = state.create_node(ping_interval=10, public_key=public_key) - - # Execute & Assert - with self.assertRaises(ValueError): - state.delete_node(node_id, public_key=wrong_public_key) - - retrieved_node_ids = state.get_nodes(run_id) - retrieved_node_id = state.get_node_id(public_key) - - assert len(retrieved_node_ids) == 1 - assert retrieved_node_id == node_id - def test_get_node_id_wrong_public_key(self) -> None: """Test retrieving a client node with wrong public key.""" # Prepare @@ -696,7 +685,8 @@ def test_get_node_id_wrong_public_key(self) -> None: run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute - state.create_node(ping_interval=10, public_key=public_key) + node_id = state.create_node(ping_interval=10) + state.set_node_public_key(node_id, public_key) retrieved_node_ids = state.get_nodes(run_id) retrieved_node_id = state.get_node_id(wrong_public_key) diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index e8311dfaac5e..cc773f7b93de 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -72,14 +72,14 @@ SQL_CREATE_TABLE_CREDENTIAL = """ CREATE TABLE IF NOT EXISTS credential( - private_key BLOB PRIMARY KEY, - public_key BLOB + private_key BLOB PRIMARY KEY, + public_key BLOB ); """ SQL_CREATE_TABLE_PUBLIC_KEY = """ CREATE TABLE IF NOT EXISTS public_key( - public_key BLOB UNIQUE + public_key BLOB PRIMARY KEY ); """ @@ -635,9 +635,7 @@ def get_task_ids_from_run_id(self, run_id: int) -> set[UUID]: return {UUID(row["task_id"]) for row in rows} - def create_node( - self, ping_interval: float, public_key: Optional[bytes] = None - ) -> int: + def create_node(self, ping_interval: float) -> int: """Create, store in the link state, and return `node_id`.""" # Sample a random uint64 as node_id uint64_node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) @@ -645,13 +643,6 @@ def create_node( # Convert the uint64 value to sint64 for SQLite sint64_node_id = convert_uint64_to_sint64(uint64_node_id) - query = "SELECT node_id FROM node WHERE public_key = :public_key;" - row = self.query(query, {"public_key": public_key}) - - if len(row) > 0: - log(ERROR, "Unexpected node registration failure.") - return 0 - query = ( "INSERT INTO node " "(node_id, online_until, ping_interval, public_key) " @@ -665,7 +656,7 @@ def create_node( sint64_node_id, time.time() + ping_interval, ping_interval, - public_key, + b"", # Initialize with an empty public key ), ) except sqlite3.IntegrityError: @@ -675,7 +666,7 @@ def create_node( # Note: we need to return the uint64 value of the node_id return uint64_node_id - def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: + def delete_node(self, node_id: int) -> None: """Delete a node.""" # Convert the uint64 value to sint64 for SQLite sint64_node_id = convert_uint64_to_sint64(node_id) @@ -683,10 +674,6 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: query = "DELETE FROM node WHERE node_id = ?" params = (sint64_node_id,) - if public_key is not None: - query += " AND public_key = ?" - params += (public_key,) # type: ignore - if self.conn is None: raise AttributeError("LinkState is not initialized.") @@ -694,7 +681,7 @@ def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: with self.conn: rows = self.conn.execute(query, params) if rows.rowcount < 1: - raise ValueError("Public key or node_id not found") + raise ValueError(f"Node {node_id} not found") except KeyError as exc: log(ERROR, {"query": query, "data": params, "exception": exc}) @@ -722,6 +709,41 @@ def get_nodes(self, run_id: int) -> set[int]: result: set[int] = {convert_sint64_to_uint64(row["node_id"]) for row in rows} return result + def set_node_public_key(self, node_id: int, public_key: bytes) -> None: + """Set `public_key` for the specified `node_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_node_id = convert_uint64_to_sint64(node_id) + + # Check if the node exists in the `node` table + query = "SELECT 1 FROM node WHERE node_id = ?" + if not self.query(query, (sint64_node_id,)): + raise ValueError(f"Node {node_id} not found") + + # Check if the public key is already in use in the `node` table + query = "SELECT 1 FROM node WHERE public_key = ?" + if self.query(query, (public_key,)): + raise ValueError("Public key already in use") + + # Update the `node` table to set the public key for the given node ID + query = "UPDATE node SET public_key = ? WHERE node_id = ?" + self.query(query, (public_key, sint64_node_id)) + + def get_node_public_key(self, node_id: int) -> Optional[bytes]: + """Get `public_key` for the specified `node_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_node_id = convert_uint64_to_sint64(node_id) + + # Query the public key for the given node_id + query = "SELECT public_key FROM node WHERE node_id = ?" + rows = self.query(query, (sint64_node_id,)) + + # If no result is found, return None + if not rows: + raise ValueError(f"Node {node_id} not found") + + # Return the public key if it is not empty, otherwise return None + return rows[0]["public_key"] or None + def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" query = "SELECT node_id FROM node WHERE public_key = :public_key;" From dc260850f3f358fa3c3aaeafcb7f4e3eccf6b1bf Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Sun, 5 Jan 2025 18:06:40 +0000 Subject: [PATCH 03/10] rm unnecessary test --- .../superlink/linkstate/linkstate_test.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index d3e391c5b62a..fd1051e1cbfc 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -658,24 +658,6 @@ def test_delete_node_public_key(self) -> None: assert len(retrieved_node_ids) == 0 assert retrieved_node_id is None - def test_delete_node_public_key_none(self) -> None: - """Test deleting a client node with public key.""" - # Prepare - state: LinkState = self.state_factory() - public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) - node_id = 0 - - # Execute & Assert - with self.assertRaises(ValueError): - state.delete_node(node_id) - - retrieved_node_ids = state.get_nodes(run_id) - retrieved_node_id = state.get_node_id(public_key) - - assert len(retrieved_node_ids) == 0 - assert retrieved_node_id is None - def test_get_node_id_wrong_public_key(self) -> None: """Test retrieving a client node with wrong public key.""" # Prepare From 0732c43f33b22986888affc8ca712cbf9df91ca3 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 7 Jan 2025 19:22:52 +0000 Subject: [PATCH 04/10] init --- .../grpc_rere_client/client_interceptor.py | 138 ++------- src/py/flwr/common/constant.py | 6 + .../fleet/grpc_rere/server_interceptor.py | 264 +++++++----------- 3 files changed, 120 insertions(+), 288 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 041860957db7..e5b16009e563 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -15,67 +15,18 @@ """Flower client interceptor.""" -import base64 -import collections -from collections.abc import Sequence -from logging import WARNING -from typing import Any, Callable, Optional, Union +from typing import Any, Callable import grpc from cryptography.hazmat.primitives.asymmetric import ec +from google.protobuf.message import Message as GrpcMessage -from flwr.common.logger import log +from flwr.common import now +from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_public_key, - compute_hmac, - generate_shared_key, public_key_to_bytes, + sign_message, ) -from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611 -from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 - CreateNodeRequest, - DeleteNodeRequest, - PingRequest, - PullTaskInsRequest, - PushTaskResRequest, -) -from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611 - -_PUBLIC_KEY_HEADER = "public-key" -_AUTH_TOKEN_HEADER = "auth-token" - -Request = Union[ - CreateNodeRequest, - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, -] - - -def _get_value_from_tuples( - key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] -) -> bytes: - value = next((value for key, value in tuples if key == key_string), "") - if isinstance(value, str): - return value.encode() - - return value - - -class _ClientCallDetails( - collections.namedtuple( - "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") - ), - grpc.ClientCallDetails, # type: ignore -): - """Details for each client call. - - The class will be passed on as the first argument in continuation function. - In our case, `AuthenticateClientInterceptor` adds new metadata to the construct. - """ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore @@ -87,84 +38,33 @@ def __init__( public_key: ec.EllipticCurvePublicKey, ): self.private_key = private_key - self.public_key = public_key - self.shared_secret: Optional[bytes] = None - self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None - self.encoded_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self.public_key) - ) + self.public_key_bytes = public_key_to_bytes(public_key) def intercept_unary_unary( self, continuation: Callable[[Any, Any], Any], client_call_details: grpc.ClientCallDetails, - request: Request, + request: GrpcMessage, ) -> grpc.Call: """Flower client interceptor. Intercept unary call from client and add necessary authentication header in the RPC metadata. """ - metadata = [] - postprocess = False - if client_call_details.metadata is not None: - metadata = list(client_call_details.metadata) - - # Always add the public key header - metadata.append( - ( - _PUBLIC_KEY_HEADER, - self.encoded_public_key, - ) - ) - - if isinstance(request, CreateNodeRequest): - postprocess = True - elif isinstance( - request, - ( - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, - ), - ): - if self.shared_secret is None: - raise RuntimeError("Failure to compute hmac") - - message_bytes = request.SerializeToString(deterministic=True) - metadata.append( - ( - _AUTH_TOKEN_HEADER, - base64.urlsafe_b64encode( - compute_hmac(self.shared_secret, message_bytes) - ), - ) - ) + metadata = list(client_call_details.metadata or []) - client_call_details = _ClientCallDetails( - client_call_details.method, - client_call_details.timeout, - metadata, - client_call_details.credentials, - ) + # Add the public key + metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes)) - response = continuation(client_call_details, request) - if postprocess: - server_public_key_bytes = base64.urlsafe_b64decode( - _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata()) - ) + # Add timestamp + timestamp = now().isoformat() + metadata.append((TIMESTAMP_HEADER, timestamp)) - if server_public_key_bytes != b"": - self.server_public_key = bytes_to_public_key(server_public_key_bytes) - else: - log(WARNING, "Can't get server public key, SuperLink may be offline") + # Sign and add the signature + signature = sign_message(self.private_key, timestamp.encode("ascii")) + metadata.append((SIGNATURE_HEADER, signature)) - if self.server_public_key is not None: - self.shared_secret = generate_shared_key( - self.private_key, self.server_public_key - ) + # Overwrite the metadata + details = client_call_details._replace(metadata=metadata) - return response + return continuation(details, request) diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 9ea23e78c009..1017cf5dc154 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -115,6 +115,12 @@ CREDENTIALS_DIR = ".credentials" AUTH_TYPE = "auth_type" +# Constants for node authentication +PUBLIC_KEY_HEADER = "public-key-bin" # Must end with "-bin" for binary data +SIGNATURE_HEADER = "signature-bin" # Must end with "-bin" for binary data +TIMESTAMP_HEADER = "timestamp" +TIMESTAMP_TOLERANCE = 10 # Tolerance for timestamp verification + class MessageType: """Message type.""" diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 6cafaaa21459..cb0eb38597c7 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -15,91 +15,54 @@ """Flower server interceptor.""" -import base64 -from collections.abc import Sequence -from logging import INFO, WARNING -from typing import Any, Callable, Optional, Union +import datetime +from typing import Any, Callable, Optional, cast import grpc -from cryptography.hazmat.primitives.asymmetric import ec - -from flwr.common.logger import log +from google.protobuf.message import Message as GrpcMessage + +from flwr.common import now +from flwr.common.constant import ( + PUBLIC_KEY_HEADER, + SIGNATURE_HEADER, + TIMESTAMP_HEADER, + TIMESTAMP_TOLERANCE, +) from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_private_key, bytes_to_public_key, - generate_shared_key, - verify_hmac, + verify_signature, ) -from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, - DeleteNodeRequest, - DeleteNodeResponse, - PingRequest, - PingResponse, - PullTaskInsRequest, - PullTaskInsResponse, - PushTaskResRequest, - PushTaskResResponse, ) -from flwr.proto.node_pb2 import Node # pylint: disable=E0611 -from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.server.superlink.linkstate import LinkStateFactory -_PUBLIC_KEY_HEADER = "public-key" -_AUTH_TOKEN_HEADER = "auth-token" - -Request = Union[ - CreateNodeRequest, - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, -] - -Response = Union[ - CreateNodeResponse, - DeleteNodeResponse, - PullTaskInsResponse, - PushTaskResResponse, - GetRunResponse, - PingResponse, - GetFabResponse, -] - -def _get_value_from_tuples( - key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] -) -> bytes: - value = next((value for key, value in tuples if key == key_string), "") - if isinstance(value, str): - return value.encode() +def _unary_unary_rpc_terminator(message: str) -> grpc.RpcMethodHandler: + def terminate(_request: GrpcMessage, context: grpc.ServicerContext) -> GrpcMessage: + context.abort(grpc.StatusCode.UNAUTHENTICATED, message) + raise RuntimeError("Should not reach this point") # Make mypy happy - return value + return grpc.unary_unary_rpc_method_handler(terminate) class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore - """Server interceptor for node authentication.""" - - def __init__(self, state_factory: LinkStateFactory): + """Server interceptor for node authentication. + + Parameters + ---------- + state_factory : LinkStateFactory + A factory for creating new instances of LinkState. + auto_auth : bool + If True, automatically authenticates nodes without verifying their public keys. + If False, only nodes with pre-stored public keys in the LinkState can be + authenticated. + """ + + def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False): self.state_factory = state_factory - state = self.state_factory.state() - - self.node_public_keys = state.get_node_public_keys() - if len(self.node_public_keys) == 0: - log(WARNING, "Authentication enabled, but no known public keys configured") - - private_key = state.get_server_private_key() - public_key = state.get_server_public_key() - - if private_key is None or public_key is None: - raise ValueError("Error loading authentication keys") - - self.server_private_key = bytes_to_private_key(private_key) - self.encoded_server_public_key = base64.urlsafe_b64encode(public_key) + self.auto_auth = auto_auth def intercept_service( self, @@ -112,117 +75,80 @@ def intercept_service( metadata sent by the node. Continue RPC call if node is authenticated, else, terminate RPC call by setting context to abort. """ + state = self.state_factory.state() + metadata_dict = dict(handler_call_details.invocation_metadata) + + # Retrieve info from the metadata + try: + node_pk_bytes = cast(bytes, metadata_dict[PUBLIC_KEY_HEADER]) + timestamp_iso = cast(str, metadata_dict[TIMESTAMP_HEADER]) + signature = cast(bytes, metadata_dict[SIGNATURE_HEADER]) + except KeyError: + return _unary_unary_rpc_terminator("Missing authentication metadata") + + if not self.auto_auth: + # Abort the RPC call if the node public key is not found + if node_pk_bytes not in state.get_node_public_keys(): + return _unary_unary_rpc_terminator("Public key not recognized") + + # Verify the signature + node_pk = bytes_to_public_key(node_pk_bytes) + if not verify_signature(node_pk, timestamp_iso.encode("ascii"), signature): + return _unary_unary_rpc_terminator("Invalid signature") + + # Verify the timestamp + current = now() + time_diff = current - datetime.datetime.fromisoformat(timestamp_iso) + # Abort the RPC call if the timestamp is too old or in the future + if not 0 <= time_diff.total_seconds() < TIMESTAMP_TOLERANCE: + return _unary_unary_rpc_terminator("Invalid timestamp") + + # Continue the RPC call + expected_node_id = state.get_node_id(node_pk_bytes) + if not handler_call_details.method.endswith("CreateNode"): + if expected_node_id is None: + return _unary_unary_rpc_terminator("Invalid node ID") # One of the method handlers in # `flwr.server.superlink.fleet.grpc_rere.fleet_server.FleetServicer` method_handler: grpc.RpcMethodHandler = continuation(handler_call_details) - return self._generic_auth_unary_method_handler(method_handler) + return self._wrap_method_handler( + method_handler, expected_node_id, node_pk_bytes + ) - def _generic_auth_unary_method_handler( - self, method_handler: grpc.RpcMethodHandler + def _wrap_method_handler( + self, + method_handler: grpc.RpcMethodHandler, + expected_node_id: Optional[int], + node_public_key: bytes, ) -> grpc.RpcMethodHandler: def _generic_method_handler( - request: Request, + request: GrpcMessage, context: grpc.ServicerContext, - ) -> Response: - node_public_key_bytes = base64.urlsafe_b64decode( - _get_value_from_tuples( - _PUBLIC_KEY_HEADER, context.invocation_metadata() - ) - ) - if node_public_key_bytes not in self.node_public_keys: - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - - if isinstance(request, CreateNodeRequest): - response = self._create_authenticated_node( - node_public_key_bytes, request, context - ) - log( - INFO, - "AuthenticateServerInterceptor: Created node_id=%s", - response.node.node_id, - ) - return response - - # Verify hmac value - hmac_value = base64.urlsafe_b64decode( - _get_value_from_tuples( - _AUTH_TOKEN_HEADER, context.invocation_metadata() - ) - ) - public_key = bytes_to_public_key(node_public_key_bytes) - - if not self._verify_hmac(public_key, request, hmac_value): - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - - # Verify node_id - node_id = self.state_factory.state().get_node_id(node_public_key_bytes) - - if not self._verify_node_id(node_id, request): - context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied") - - return method_handler.unary_unary(request, context) # type: ignore + ) -> GrpcMessage: + # Verify the node ID + if not isinstance(request, CreateNodeRequest): + try: + if request.node.node_id != expected_node_id: # type: ignore + raise ValueError + except (AttributeError, ValueError): + context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid node ID") + + response: GrpcMessage = method_handler.unary_unary(request, context) + + # Set the public key after a successful CreateNode request + if isinstance(response, CreateNodeResponse): + state = self.state_factory.state() + try: + state.set_node_public_key(response.node.node_id, node_public_key) + except ValueError: + context.abort( + grpc.StatusCode.UNAUTHENTICATED, "Public key already in use" + ) + + return response return grpc.unary_unary_rpc_method_handler( _generic_method_handler, request_deserializer=method_handler.request_deserializer, response_serializer=method_handler.response_serializer, ) - - def _verify_node_id( - self, - node_id: Optional[int], - request: Union[ - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, - GetRunRequest, - PingRequest, - GetFabRequest, - ], - ) -> bool: - if node_id is None: - return False - if isinstance(request, PushTaskResRequest): - if len(request.task_res_list) == 0: - return False - return request.task_res_list[0].task.producer.node_id == node_id - if isinstance(request, GetRunRequest): - return node_id in self.state_factory.state().get_nodes(request.run_id) - return request.node.node_id == node_id - - def _verify_hmac( - self, public_key: ec.EllipticCurvePublicKey, request: Request, hmac_value: bytes - ) -> bool: - shared_secret = generate_shared_key(self.server_private_key, public_key) - message_bytes = request.SerializeToString(deterministic=True) - return verify_hmac(shared_secret, message_bytes, hmac_value) - - def _create_authenticated_node( - self, - public_key_bytes: bytes, - request: CreateNodeRequest, - context: grpc.ServicerContext, - ) -> CreateNodeResponse: - context.send_initial_metadata( - ( - ( - _PUBLIC_KEY_HEADER, - self.encoded_server_public_key, - ), - ) - ) - state = self.state_factory.state() - node_id = state.get_node_id(public_key_bytes) - - # Handle `CreateNode` here instead of calling the default method handler - # Return previously assigned `node_id` for the provided `public_key` - if node_id is not None: - state.acknowledge_ping(node_id, request.ping_interval) - return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) - - # No `node_id` exists for the provided `public_key` - # Handle `CreateNode` here instead of calling the default method handler - # Note: the innermost `CreateNode` method will never be called - node_id = state.create_node(request.ping_interval) - state.set_node_public_key(node_id, public_key_bytes) - return CreateNodeResponse(node=Node(node_id=node_id, anonymous=False)) From 28b0f9ac3a63383dee90afd8a07ceb517ba3f84c Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 9 Jan 2025 17:18:10 +0000 Subject: [PATCH 05/10] Update src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py Co-authored-by: Javier --- .../flwr/server/superlink/fleet/grpc_rere/server_interceptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index cb0eb38597c7..db389ac79adc 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -100,7 +100,7 @@ def intercept_service( current = now() time_diff = current - datetime.datetime.fromisoformat(timestamp_iso) # Abort the RPC call if the timestamp is too old or in the future - if not 0 <= time_diff.total_seconds() < TIMESTAMP_TOLERANCE: + if not 0 < time_diff.total_seconds() < TIMESTAMP_TOLERANCE: return _unary_unary_rpc_terminator("Invalid timestamp") # Continue the RPC call From ec6a74384fd7e60c044eb25bae1b9a270d5b279c Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Thu, 9 Jan 2025 17:23:47 +0000 Subject: [PATCH 06/10] use error message --- .../server/superlink/fleet/grpc_rere/server_interceptor.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index db389ac79adc..38ef0f829dc0 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -140,10 +140,8 @@ def _generic_method_handler( state = self.state_factory.state() try: state.set_node_public_key(response.node.node_id, node_public_key) - except ValueError: - context.abort( - grpc.StatusCode.UNAUTHENTICATED, "Public key already in use" - ) + except ValueError as e: + context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e)) return response From e49d1f0f87a2117bcbac1418d13f8d6e4d24bb30 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 21 Jan 2025 18:58:43 +0000 Subject: [PATCH 07/10] feat(framework) Implement signature-based authentication (#4782) --- .../client_interceptor_test.py | 276 ++-------- .../grpc_rere/server_interceptor_test.py | 504 +++++++----------- 2 files changed, 241 insertions(+), 539 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index a029b926423f..34a0ae6bd91f 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -15,28 +15,28 @@ """Flower client interceptor tests.""" -import base64 -import inspect import threading import unittest from collections.abc import Sequence from concurrent import futures from logging import DEBUG, INFO, WARN -from typing import Optional, Union, get_args +from typing import Any, Callable, Optional, Union import grpc +from google.protobuf.message import Message as GrpcMessage +from parameterized import parameterized from flwr.client.grpc_rere_client.connection import grpc_request_response from flwr.common import GRPC_MAX_MESSAGE_LENGTH, serde +from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER from flwr.common.logger import log from flwr.common.message import Message, Metadata from flwr.common.record import RecordSet from flwr.common.retry_invoker import RetryInvoker, exponential from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, generate_key_pairs, - generate_shared_key, public_key_to_bytes, + verify_signature, ) from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -48,13 +48,10 @@ PushTaskResRequest, PushTaskResResponse, ) -from flwr.proto.fleet_pb2_grpc import FleetServicer from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611 from flwr.proto.task_pb2 import Task, TaskIns # pylint: disable=E0611 -from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request - class _MockServicer: """Mock Servicer for Flower clients.""" @@ -65,35 +62,24 @@ def __init__(self) -> None: self._received_client_metadata: Optional[ Sequence[tuple[str, Union[str, bytes]]] ] = None - self.server_private_key, self.server_public_key = generate_key_pairs() self._received_message_bytes: bytes = b"" def unary_unary( - self, request: Request, context: grpc.ServicerContext - ) -> Union[ - CreateNodeResponse, DeleteNodeResponse, PushTaskResResponse, PullTaskInsResponse - ]: + self, request: GrpcMessage, context: grpc.ServicerContext + ) -> GrpcMessage: """Handle unary call.""" with self._lock: self._received_client_metadata = context.invocation_metadata() self._received_message_bytes = request.SerializeToString(deterministic=True) if isinstance(request, CreateNodeRequest): - context.send_initial_metadata( - ( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self.server_public_key) - ), - ), - ) - ) return CreateNodeResponse(node=Node(node_id=123)) if isinstance(request, DeleteNodeRequest): return DeleteNodeResponse() if isinstance(request, PushTaskResRequest): return PushTaskResResponse() + if isinstance(request, GetRunRequest): + return GetRunResponse() return PullTaskInsResponse( task_ins_list=[ @@ -153,16 +139,6 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: server.add_generic_rpc_handlers((generic_handler,)) -def _get_value_from_tuples( - key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]] -) -> bytes: - value = next((value for key, value in tuples if key == key_string), "") - if isinstance(value, str): - return value.encode() - - return value - - def _init_retry_invoker() -> RetryInvoker: return RetryInvoker( wait_gen_factory=exponential, @@ -201,6 +177,36 @@ def _init_retry_invoker() -> RetryInvoker: ) +def _create_node(conn: Any) -> None: + create_node = conn[2] + create_node() + + +def _delete_node(conn: Any) -> None: + _, _, create_node, delete_node, _, _ = conn + create_node() + delete_node() + + +def _receive(conn: Any) -> None: + receive, _, create_node, _, _, _ = conn + create_node() + receive() + + +def _send(conn: Any) -> None: + receive, send, create_node, _, _, _ = conn + create_node() + receive() + send(Message(Metadata(0, "", 123, 0, "", "", 0, ""), RecordSet())) + + +def _get_run(conn: Any) -> None: + _, _, create_node, _, get_run, _ = conn + create_node() + get_run(0) + + class TestAuthenticateClientInterceptor(unittest.TestCase): """Test for client interceptor client authentication.""" @@ -219,7 +225,10 @@ def setUp(self) -> None: self._connection = grpc_request_response self._address = f"localhost:{port}" - def test_client_auth_create_node(self) -> None: + @parameterized.expand( + [(_create_node,), (_delete_node,), (_receive,), (_send,), (_get_run,)] + ) # type: ignore + def test_client_auth_rpc(self, grpc_call: Callable[[Any], None]) -> None: """Test client authentication during create node.""" # Prepare retry_invoker = _init_retry_invoker() @@ -233,190 +242,25 @@ def test_client_auth_create_node(self) -> None: None, (self._client_private_key, self._client_public_key), ) as conn: - _, _, create_node, _, _, _ = conn - assert create_node is not None - create_node() + grpc_call(conn) received_metadata = self._servicer.received_client_metadata() assert received_metadata is not None - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) + metadata_dict = dict(received_metadata) + actual_public_key = metadata_dict[PUBLIC_KEY_HEADER] + signature = metadata_dict[SIGNATURE_HEADER] + timestamp = metadata_dict[TIMESTAMP_HEADER] - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) + expected_public_key = public_key_to_bytes(self._client_public_key) # Assert + assert isinstance(signature, bytes) + assert isinstance(timestamp, str) assert actual_public_key == expected_public_key - - def test_client_auth_delete_node(self) -> None: - """Test client authentication during delete node.""" - # Prepare - retry_invoker = _init_retry_invoker() - - # Execute - with self._connection( - self._address, - True, - retry_invoker, - GRPC_MAX_MESSAGE_LENGTH, - None, - (self._client_private_key, self._client_public_key), - ) as conn: - _, _, create_node, delete_node, _, _ = conn - assert create_node is not None - create_node() - assert delete_node is not None - delete_node() - - received_metadata = self._servicer.received_client_metadata() - assert received_metadata is not None - - shared_secret = generate_shared_key( - self._servicer.server_private_key, self._client_public_key - ) - expected_hmac = base64.urlsafe_b64encode( - compute_hmac(shared_secret, self._servicer.received_message_bytes()) + assert verify_signature( + self._client_public_key, timestamp.encode("ascii"), signature ) - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) - actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - - # Assert - assert actual_public_key == expected_public_key - assert actual_hmac == expected_hmac - - def test_client_auth_receive(self) -> None: - """Test client authentication during receive node.""" - # Prepare - retry_invoker = _init_retry_invoker() - - # Execute - with self._connection( - self._address, - True, - retry_invoker, - GRPC_MAX_MESSAGE_LENGTH, - None, - (self._client_private_key, self._client_public_key), - ) as conn: - receive, _, create_node, _, _, _ = conn - assert create_node is not None - create_node() - assert receive is not None - receive() - - received_metadata = self._servicer.received_client_metadata() - assert received_metadata is not None - - shared_secret = generate_shared_key( - self._servicer.server_private_key, self._client_public_key - ) - expected_hmac = base64.urlsafe_b64encode( - compute_hmac(shared_secret, self._servicer.received_message_bytes()) - ) - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) - actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - - # Assert - assert actual_public_key == expected_public_key - assert actual_hmac == expected_hmac - - def test_client_auth_send(self) -> None: - """Test client authentication during send node.""" - # Prepare - retry_invoker = _init_retry_invoker() - message = Message(Metadata(0, "", 123, 0, "", "", 0, ""), RecordSet()) - - # Execute - with self._connection( - self._address, - True, - retry_invoker, - GRPC_MAX_MESSAGE_LENGTH, - None, - (self._client_private_key, self._client_public_key), - ) as conn: - receive, send, create_node, _, _, _ = conn - assert create_node is not None - create_node() - assert receive is not None - receive() - assert send is not None - send(message) - - received_metadata = self._servicer.received_client_metadata() - assert received_metadata is not None - - shared_secret = generate_shared_key( - self._servicer.server_private_key, self._client_public_key - ) - expected_hmac = base64.urlsafe_b64encode( - compute_hmac(shared_secret, self._servicer.received_message_bytes()) - ) - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) - actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - - # Assert - assert actual_public_key == expected_public_key - assert actual_hmac == expected_hmac - - def test_client_auth_get_run(self) -> None: - """Test client authentication during send node.""" - # Prepare - retry_invoker = _init_retry_invoker() - - # Execute - with self._connection( - self._address, - True, - retry_invoker, - GRPC_MAX_MESSAGE_LENGTH, - None, - (self._client_private_key, self._client_public_key), - ) as conn: - _, _, create_node, _, get_run, _ = conn - assert create_node is not None - create_node() - assert get_run is not None - get_run(0) - - received_metadata = self._servicer.received_client_metadata() - assert received_metadata is not None - - shared_secret = generate_shared_key( - self._servicer.server_private_key, self._client_public_key - ) - expected_hmac = base64.urlsafe_b64encode( - compute_hmac(shared_secret, self._servicer.received_message_bytes()) - ) - actual_public_key = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, received_metadata - ) - actual_hmac = _get_value_from_tuples(_AUTH_TOKEN_HEADER, received_metadata) - expected_public_key = base64.urlsafe_b64encode( - public_key_to_bytes(self._client_public_key) - ) - - # Assert - assert actual_public_key == expected_public_key - assert actual_hmac == expected_hmac def test_without_servicer(self) -> None: """Test client authentication without servicer.""" @@ -439,20 +283,6 @@ def test_without_servicer(self) -> None: assert self._servicer.received_client_metadata() is None - def test_fleet_requests_included(self) -> None: - """Test if all Fleet requests are included in the authentication mode.""" - # Prepare - requests = get_args(Request) - rpc_names = {req.__qualname__.removesuffix("Request") for req in requests} - expected_rpc_names = { - name - for name, ref in inspect.getmembers(FleetServicer) - if inspect.isfunction(ref) - } - - # Assert - assert expected_rpc_names == rpc_names - if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index 9984b93f3e84..6861d0235c31 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -15,21 +15,28 @@ """Flower server interceptor tests.""" -import base64 +import datetime import unittest +from typing import Any, Callable import grpc - -from flwr.common import ConfigsRecord -from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS, Status +from parameterized import parameterized + +from flwr.common import ConfigsRecord, now +from flwr.common.constant import ( + FLEET_API_GRPC_RERE_DEFAULT_ADDRESS, + PUBLIC_KEY_HEADER, + SIGNATURE_HEADER, + TIMESTAMP_HEADER, + Status, +) from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, generate_key_pairs, - generate_shared_key, - private_key_to_bytes, public_key_to_bytes, + sign_message, ) from flwr.common.typing import RunStatus +from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611 from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, CreateNodeResponse, @@ -49,11 +56,7 @@ from flwr.server.superlink.ffs.ffs_factory import FfsFactory from flwr.server.superlink.linkstate.linkstate_factory import LinkStateFactory -from .server_interceptor import ( - _AUTH_TOKEN_HEADER, - _PUBLIC_KEY_HEADER, - AuthenticateServerInterceptor, -) +from .server_interceptor import AuthenticateServerInterceptor class TestServerInterceptor(unittest.TestCase): # pylint: disable=R0902 @@ -61,18 +64,13 @@ class TestServerInterceptor(unittest.TestCase): # pylint: disable=R0902 def setUp(self) -> None: """Initialize mock stub and server interceptor.""" - self._node_private_key, self._node_public_key = generate_key_pairs() - self._server_private_key, self._server_public_key = generate_key_pairs() + self.node_sk, self.node_pk = generate_key_pairs() state_factory = LinkStateFactory(":flwr-in-memory-state:") self.state = state_factory.state() ffs_factory = FfsFactory(".") self.ffs = ffs_factory.ffs() - self.state.store_server_private_public_key( - private_key_to_bytes(self._server_private_key), - public_key_to_bytes(self._server_public_key), - ) - self.state.store_node_public_keys({public_key_to_bytes(self._node_public_key)}) + self.state.store_node_public_keys({public_key_to_bytes(self.node_pk)}) self._server_interceptor = AuthenticateServerInterceptor(state_factory) self._server: grpc.Server = _run_fleet_api_grpc_rere( @@ -114,332 +112,206 @@ def setUp(self) -> None: request_serializer=PingRequest.SerializeToString, response_deserializer=PingResponse.FromString, ) + self._get_fab = self._channel.unary_unary( + "/flwr.proto.Fleet/GetFab", + request_serializer=GetFabRequest.SerializeToString, + response_deserializer=GetFabResponse.FromString, + ) def tearDown(self) -> None: """Clean up grpc server.""" self._server.stop(None) - def test_successful_create_node_with_metadata(self) -> None: - """Test server interceptor for creating node.""" - # Prepare - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute - response, call = self._create_node.with_call( + def _make_metadata(self) -> list[Any]: + """Create metadata with signature and timestamp.""" + timestamp = now().isoformat() + signature = sign_message(self.node_sk, timestamp.encode("ascii")) + return [ + (PUBLIC_KEY_HEADER, public_key_to_bytes(self.node_pk)), + (SIGNATURE_HEADER, signature), + (TIMESTAMP_HEADER, timestamp), + ] + + def _make_metadata_with_invalid_signature(self) -> list[Any]: + """Create metadata with invalid signature.""" + timestamp = now().isoformat() + sk, _ = generate_key_pairs() + signature = sign_message(sk, timestamp.encode("ascii")) + return [ + (PUBLIC_KEY_HEADER, public_key_to_bytes(self.node_pk)), + (SIGNATURE_HEADER, signature), + (TIMESTAMP_HEADER, timestamp), + ] + + def _make_metadata_with_invalid_public_key(self) -> list[Any]: + """Create metadata with invalid public key.""" + timestamp = now().isoformat() + signature = sign_message(self.node_sk, timestamp.encode("ascii")) + _, pk = generate_key_pairs() + return [ + (PUBLIC_KEY_HEADER, public_key_to_bytes(pk)), + (SIGNATURE_HEADER, signature), + (TIMESTAMP_HEADER, timestamp), + ] + + def _make_metadata_with_invalid_timestamp(self) -> list[Any]: + """Create metadata with invalid timestamp.""" + timestamp = (now() - datetime.timedelta(seconds=99)).isoformat() + signature = sign_message(self.node_sk, timestamp.encode("ascii")) + return [ + (PUBLIC_KEY_HEADER, public_key_to_bytes(self.node_pk)), + (SIGNATURE_HEADER, signature), + (TIMESTAMP_HEADER, timestamp), + ] + + def _test_create_node(self, metadata: list[Any]) -> Any: + """Test CreateNode.""" + return self._create_node.with_call( request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - expected_metadata = ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode( - public_key_to_bytes(self._server_public_key) - ).decode(), - ) - - # Assert - assert call.initial_metadata()[0] == expected_metadata - assert isinstance(response, CreateNodeResponse) - - def test_unsuccessful_create_node_with_metadata(self) -> None: - """Test server interceptor for creating node unsuccessfully.""" - # Prepare - _, node_public_key = generate_key_pairs() - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(node_public_key) + metadata=metadata, ) - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._create_node.with_call( - request=CreateNodeRequest(), - metadata=((_PUBLIC_KEY_HEADER, public_key_bytes),), - ) - - def test_successful_delete_node_with_metadata(self) -> None: - """Test server interceptor for deleting node.""" - # Prepare + def _test_delete_node(self, metadata: list[Any]) -> Any: + """Test DeleteNode.""" node_id = self._create_node_and_set_public_key() - request = DeleteNodeRequest(node=Node(node_id=node_id)) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute - response, call = self._delete_node.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, DeleteNodeResponse) - assert grpc.StatusCode.OK == call.code() + req = DeleteNodeRequest(node=Node(node_id=node_id)) + return self._delete_node.with_call(request=req, metadata=metadata) - def test_unsuccessful_delete_node_with_metadata(self) -> None: - """Test server interceptor for deleting node unsuccessfully.""" - # Prepare + def _test_pull_task_ins(self, metadata: list[Any]) -> Any: + """Test PullTaskIns.""" node_id = self._create_node_and_set_public_key() - request = DeleteNodeRequest(node=Node(node_id=node_id)) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._delete_node.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) + req = PullTaskInsRequest(node=Node(node_id=node_id)) + return self._pull_task_ins.with_call(request=req, metadata=metadata) - def test_successful_pull_task_ins_with_metadata(self) -> None: - """Test server interceptor for pull task ins.""" - # Prepare - node_id = self._create_node_and_set_public_key() - request = PullTaskInsRequest(node=Node(node_id=node_id)) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute - response, call = self._pull_task_ins.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, PullTaskInsResponse) - assert grpc.StatusCode.OK == call.code() - - def test_unsuccessful_pull_task_ins_with_metadata(self) -> None: - """Test server interceptor for pull task ins unsuccessfully.""" - # Prepare - node_id = self._create_node_and_set_public_key() - request = PullTaskInsRequest(node=Node(node_id=node_id)) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._pull_task_ins.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - def test_successful_push_task_res_with_metadata(self) -> None: - """Test server interceptor for push task res.""" - # Prepare + def _test_push_task_res(self, metadata: list[Any]) -> Any: + """Test PushTaskRes.""" node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. PushTaskRes is only allowed in running status. - _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) - _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) - request = PushTaskResRequest( + self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) + req = PushTaskResRequest( + node=Node(node_id=node_id), task_res_list=[ TaskRes(task=Task(producer=Node(node_id=node_id)), run_id=run_id) - ] - ) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) + ], ) + return self._push_task_res.with_call(request=req, metadata=metadata) - # Execute - response, call = self._push_task_res.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, PushTaskResResponse) - assert grpc.StatusCode.OK == call.code() - - def test_unsuccessful_push_task_res_with_metadata(self) -> None: - """Test server interceptor for push task res unsuccessfully.""" - # Prepare + def _test_get_run(self, metadata: list[Any]) -> Any: + """Test GetRun.""" node_id = self._create_node_and_set_public_key() run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) - # Transition status to running. PushTaskRes is only allowed in running status. - _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) - _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) - request = PushTaskResRequest( - task_res_list=[TaskRes(task=Task(producer=Node(node_id=node_id)))] - ) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError) as e: - self._push_task_res.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - assert e.exception.code() == grpc.StatusCode.UNAUTHENTICATED - - def test_successful_get_run_with_metadata(self) -> None: - """Test server interceptor for get run.""" - # Prepare - self._create_node_and_set_public_key() - run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) # Transition status to running. GetRun is only allowed in running status. - _ = self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) - _ = self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) - request = GetRunRequest(run_id=run_id) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute - response, call = self._get_run.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, GetRunResponse) - assert grpc.StatusCode.OK == call.code() + self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) + req = GetRunRequest(node=Node(node_id=node_id), run_id=run_id) + return self._get_run.with_call(request=req, metadata=metadata) - def test_unsuccessful_get_run_with_metadata(self) -> None: - """Test server interceptor for get run unsuccessfully.""" - # Prepare - self._create_node_and_set_public_key() - run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) - request = GetRunRequest(run_id=run_id) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._get_run.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - def test_successful_ping_with_metadata(self) -> None: - """Test server interceptor for ping.""" - # Prepare + def _test_ping(self, metadata: list[Any]) -> Any: + """Test Ping.""" node_id = self._create_node_and_set_public_key() - request = PingRequest(node=Node(node_id=node_id)) - shared_secret = generate_shared_key( - self._node_private_key, self._server_public_key - ) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) - ) + req = PingRequest(node=Node(node_id=node_id)) + return self._ping.with_call(request=req, metadata=metadata) - # Execute - response, call = self._ping.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) - - # Assert - assert isinstance(response, PingResponse) - assert grpc.StatusCode.OK == call.code() - - def test_unsuccessful_ping_with_metadata(self) -> None: - """Test server interceptor for ping unsuccessfully.""" - # Prepare + def _test_get_fab(self, metadata: list[Any]) -> Any: + """Test GetFab.""" + fab_hash = self.ffs.put(b"mock fab content", {}) node_id = self._create_node_and_set_public_key() - request = PingRequest(node=Node(node_id=node_id)) - node_private_key, _ = generate_key_pairs() - shared_secret = generate_shared_key(node_private_key, self._server_public_key) - hmac_value = base64.urlsafe_b64encode( - compute_hmac(shared_secret, request.SerializeToString(deterministic=True)) - ) - public_key_bytes = base64.urlsafe_b64encode( - public_key_to_bytes(self._node_public_key) + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) + # Transition status to running. PushTaskRes is only allowed in running status. + self.state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + self.state.update_run_status(run_id, RunStatus(Status.RUNNING, "", "")) + req = GetFabRequest( + node=Node(node_id=node_id), + run_id=run_id, + hash_str=fab_hash, ) - - # Execute & Assert - with self.assertRaises(grpc.RpcError): - self._ping.with_call( - request=request, - metadata=( - (_PUBLIC_KEY_HEADER, public_key_bytes), - (_AUTH_TOKEN_HEADER, hmac_value), - ), - ) + return self._get_fab.with_call(request=req, metadata=metadata) def _create_node_and_set_public_key(self) -> int: node_id = self.state.create_node(ping_interval=30) - pk_bytes = public_key_to_bytes(self._node_public_key) + pk_bytes = public_key_to_bytes(self.node_pk) self.state.set_node_public_key(node_id, pk_bytes) return node_id + + @parameterized.expand( + [ + (_test_create_node,), + (_test_delete_node,), + (_test_pull_task_ins,), + (_test_push_task_res,), + (_test_get_run,), + (_test_ping,), + (_test_get_fab,), + ] + ) # type: ignore + def test_successful_rpc_with_metadata( + self, rpc: Callable[[Any, list[Any]], Any] + ) -> None: + """Test server interceptor for RPC.""" + # Execute + _, call = rpc(self, self._make_metadata()) + + # Assert + assert call.code() == grpc.StatusCode.OK + + @parameterized.expand( + [ + (_test_create_node,), + (_test_delete_node,), + (_test_pull_task_ins,), + (_test_push_task_res,), + (_test_get_run,), + (_test_ping,), + (_test_get_fab,), + ] + ) # type: ignore + def test_unsuccessful_rpc_with_invalid_signature( + self, rpc: Callable[[Any, list[Any]], Any] + ) -> None: + """Test server interceptor for RPC unsuccessfully.""" + # Execute & Assert + with self.assertRaises(grpc.RpcError) as cm: + rpc(self, self._make_metadata_with_invalid_signature()) + assert cm.exception.code() == grpc.StatusCode.UNAUTHENTICATED + + @parameterized.expand( + [ + (_test_create_node,), + (_test_delete_node,), + (_test_pull_task_ins,), + (_test_push_task_res,), + (_test_get_run,), + (_test_ping,), + (_test_get_fab,), + ] + ) # type: ignore + def test_unsuccessful_rpc_with_invalid_public_key( + self, rpc: Callable[[Any, list[Any]], Any] + ) -> None: + """Test server interceptor for RPC unsuccessfully.""" + # Execute & Assert + with self.assertRaises(grpc.RpcError) as cm: + rpc(self, self._make_metadata_with_invalid_public_key()) + assert cm.exception.code() == grpc.StatusCode.UNAUTHENTICATED + + @parameterized.expand( + [ + (_test_create_node,), + (_test_delete_node,), + (_test_pull_task_ins,), + (_test_push_task_res,), + (_test_get_run,), + (_test_ping,), + (_test_get_fab,), + ] + ) # type: ignore + def test_unsuccessful_rpc_with_invalid_timestamp( + self, rpc: Callable[[Any, list[Any]], Any] + ) -> None: + """Test server interceptor for RPC unsuccessfully.""" + # Execute & Assert + with self.assertRaises(grpc.RpcError) as cm: + rpc(self, self._make_metadata_with_invalid_timestamp()) + assert cm.exception.code() == grpc.StatusCode.UNAUTHENTICATED From e418b9f548a498c0aa10233bfc25349a3120aba6 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 21 Jan 2025 22:38:22 +0000 Subject: [PATCH 08/10] Update src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py --- .../server/superlink/fleet/grpc_rere/server_interceptor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 38ef0f829dc0..7cf7d3b8ec7f 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -55,9 +55,9 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore state_factory : LinkStateFactory A factory for creating new instances of LinkState. auto_auth : bool - If True, automatically authenticates nodes without verifying their public keys. - If False, only nodes with pre-stored public keys in the LinkState can be - authenticated. + If True, nodes are authenticated without requiring their public keys to be + pre-stored in the LinkState. If False, only nodes with pre-stored public keys + can be authenticated. """ def __init__(self, state_factory: LinkStateFactory, auto_auth: bool = False): From 4d5b1685da8131ee7be54934386ac5ec08d503ea Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 21 Jan 2025 22:39:23 +0000 Subject: [PATCH 09/10] Update src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py --- .../flwr/server/superlink/fleet/grpc_rere/server_interceptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 7cf7d3b8ec7f..3935a27fa61f 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -54,7 +54,7 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore ---------- state_factory : LinkStateFactory A factory for creating new instances of LinkState. - auto_auth : bool + auto_auth : bool (default: False) If True, nodes are authenticated without requiring their public keys to be pre-stored in the LinkState. If False, only nodes with pre-stored public keys can be authenticated. From 4acc84af8ccf7ff3f984186f9d34f71dc794c600 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 21 Jan 2025 22:52:47 +0000 Subject: [PATCH 10/10] del node if fail --- .../server/superlink/fleet/grpc_rere/server_interceptor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py index 3935a27fa61f..2197ee266ac9 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor.py @@ -55,8 +55,8 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore state_factory : LinkStateFactory A factory for creating new instances of LinkState. auto_auth : bool (default: False) - If True, nodes are authenticated without requiring their public keys to be - pre-stored in the LinkState. If False, only nodes with pre-stored public keys + If True, nodes are authenticated without requiring their public keys to be + pre-stored in the LinkState. If False, only nodes with pre-stored public keys can be authenticated. """ @@ -141,6 +141,8 @@ def _generic_method_handler( try: state.set_node_public_key(response.node.node_id, node_public_key) except ValueError as e: + # Remove newly created node if setting the public key fails + state.delete_node(response.node.node_id) context.abort(grpc.StatusCode.UNAUTHENTICATED, str(e)) return response