Skip to content

Commit

Permalink
feat(framework) Implement signature-based authentication interceptors (
Browse files Browse the repository at this point in the history
…#4791)

Co-authored-by: Javier <[email protected]>
  • Loading branch information
panh99 and jafermarq authored Jan 21, 2025
1 parent 2527881 commit 48f1bfe
Show file tree
Hide file tree
Showing 5 changed files with 361 additions and 833 deletions.
144 changes: 19 additions & 125 deletions src/py/flwr/client/grpc_rere_client/client_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,71 +15,18 @@
"""Flower client interceptor."""


import base64
import collections
from collections.abc import Sequence
from logging import WARNING
from typing import Any, Callable, Optional, Union
from typing import Any, Callable

import grpc
from cryptography.hazmat.primitives.asymmetric import ec
from google.protobuf.message import Message as GrpcMessage

from flwr.common.logger import log
from flwr.common import now
from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
bytes_to_public_key,
compute_hmac,
generate_shared_key,
public_key_to_bytes,
sign_message,
)
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
PingRequest,
PullMessagesRequest,
PullTaskInsRequest,
PushMessagesRequest,
PushTaskResRequest,
)
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611

_PUBLIC_KEY_HEADER = "public-key"
_AUTH_TOKEN_HEADER = "auth-token"

Request = Union[
CreateNodeRequest,
DeleteNodeRequest,
PullTaskInsRequest,
PushTaskResRequest,
GetRunRequest,
PingRequest,
GetFabRequest,
PullMessagesRequest,
PushMessagesRequest,
]


def _get_value_from_tuples(
key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
) -> bytes:
value = next((value for key, value in tuples if key == key_string), "")
if isinstance(value, str):
return value.encode()

return value


class _ClientCallDetails(
collections.namedtuple(
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
),
grpc.ClientCallDetails, # type: ignore
):
"""Details for each client call.
The class will be passed on as the first argument in continuation function.
In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
"""


class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
Expand All @@ -91,86 +38,33 @@ def __init__(
public_key: ec.EllipticCurvePublicKey,
):
self.private_key = private_key
self.public_key = public_key
self.shared_secret: Optional[bytes] = None
self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
self.encoded_public_key = base64.urlsafe_b64encode(
public_key_to_bytes(self.public_key)
)
self.public_key_bytes = public_key_to_bytes(public_key)

def intercept_unary_unary(
self,
continuation: Callable[[Any, Any], Any],
client_call_details: grpc.ClientCallDetails,
request: Request,
request: GrpcMessage,
) -> grpc.Call:
"""Flower client interceptor.
Intercept unary call from client and add necessary authentication header in the
RPC metadata.
"""
metadata = []
postprocess = False
if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)

# Always add the public key header
metadata.append(
(
_PUBLIC_KEY_HEADER,
self.encoded_public_key,
)
)

if isinstance(request, CreateNodeRequest):
postprocess = True
elif isinstance(
request,
(
DeleteNodeRequest,
PullTaskInsRequest,
PushTaskResRequest,
GetRunRequest,
PingRequest,
GetFabRequest,
PullMessagesRequest,
PushMessagesRequest,
),
):
if self.shared_secret is None:
raise RuntimeError("Failure to compute hmac")

message_bytes = request.SerializeToString(deterministic=True)
metadata.append(
(
_AUTH_TOKEN_HEADER,
base64.urlsafe_b64encode(
compute_hmac(self.shared_secret, message_bytes)
),
)
)
metadata = list(client_call_details.metadata or [])

client_call_details = _ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
metadata,
client_call_details.credentials,
)
# Add the public key
metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes))

response = continuation(client_call_details, request)
if postprocess:
server_public_key_bytes = base64.urlsafe_b64decode(
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
)
# Add timestamp
timestamp = now().isoformat()
metadata.append((TIMESTAMP_HEADER, timestamp))

if server_public_key_bytes != b"":
self.server_public_key = bytes_to_public_key(server_public_key_bytes)
else:
log(WARNING, "Can't get server public key, SuperLink may be offline")
# Sign and add the signature
signature = sign_message(self.private_key, timestamp.encode("ascii"))
metadata.append((SIGNATURE_HEADER, signature))

if self.server_public_key is not None:
self.shared_secret = generate_shared_key(
self.private_key, self.server_public_key
)
# Overwrite the metadata
details = client_call_details._replace(metadata=metadata)

return response
return continuation(details, request)
Loading

0 comments on commit 48f1bfe

Please sign in to comment.