Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Implement signature-based authentication interceptors #4791

Merged
merged 14 commits into from
Jan 21, 2025
Merged
144 changes: 19 additions & 125 deletions src/py/flwr/client/grpc_rere_client/client_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,71 +15,18 @@
"""Flower client interceptor."""


import base64
import collections
from collections.abc import Sequence
from logging import WARNING
from typing import Any, Callable, Optional, Union
from typing import Any, Callable

import grpc
from cryptography.hazmat.primitives.asymmetric import ec
from google.protobuf.message import Message as GrpcMessage

from flwr.common.logger import log
from flwr.common import now
from flwr.common.constant import PUBLIC_KEY_HEADER, SIGNATURE_HEADER, TIMESTAMP_HEADER
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
bytes_to_public_key,
compute_hmac,
generate_shared_key,
public_key_to_bytes,
sign_message,
)
from flwr.proto.fab_pb2 import GetFabRequest # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
DeleteNodeRequest,
PingRequest,
PullMessagesRequest,
PullTaskInsRequest,
PushMessagesRequest,
PushTaskResRequest,
)
from flwr.proto.run_pb2 import GetRunRequest # pylint: disable=E0611

_PUBLIC_KEY_HEADER = "public-key"
_AUTH_TOKEN_HEADER = "auth-token"

Request = Union[
CreateNodeRequest,
DeleteNodeRequest,
PullTaskInsRequest,
PushTaskResRequest,
GetRunRequest,
PingRequest,
GetFabRequest,
PullMessagesRequest,
PushMessagesRequest,
]


def _get_value_from_tuples(
key_string: str, tuples: Sequence[tuple[str, Union[str, bytes]]]
) -> bytes:
value = next((value for key, value in tuples if key == key_string), "")
if isinstance(value, str):
return value.encode()

return value


class _ClientCallDetails(
collections.namedtuple(
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials")
),
grpc.ClientCallDetails, # type: ignore
):
"""Details for each client call.

The class will be passed on as the first argument in continuation function.
In our case, `AuthenticateClientInterceptor` adds new metadata to the construct.
"""


class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore
Expand All @@ -91,86 +38,33 @@ def __init__(
public_key: ec.EllipticCurvePublicKey,
):
self.private_key = private_key
self.public_key = public_key
self.shared_secret: Optional[bytes] = None
self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None
self.encoded_public_key = base64.urlsafe_b64encode(
public_key_to_bytes(self.public_key)
)
self.public_key_bytes = public_key_to_bytes(public_key)

def intercept_unary_unary(
self,
continuation: Callable[[Any, Any], Any],
client_call_details: grpc.ClientCallDetails,
request: Request,
request: GrpcMessage,
) -> grpc.Call:
"""Flower client interceptor.

Intercept unary call from client and add necessary authentication header in the
RPC metadata.
"""
metadata = []
postprocess = False
if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)

# Always add the public key header
metadata.append(
(
_PUBLIC_KEY_HEADER,
self.encoded_public_key,
)
)

if isinstance(request, CreateNodeRequest):
postprocess = True
elif isinstance(
request,
(
DeleteNodeRequest,
PullTaskInsRequest,
PushTaskResRequest,
GetRunRequest,
PingRequest,
GetFabRequest,
PullMessagesRequest,
PushMessagesRequest,
),
):
if self.shared_secret is None:
raise RuntimeError("Failure to compute hmac")

message_bytes = request.SerializeToString(deterministic=True)
metadata.append(
(
_AUTH_TOKEN_HEADER,
base64.urlsafe_b64encode(
compute_hmac(self.shared_secret, message_bytes)
),
)
)
metadata = list(client_call_details.metadata or [])

client_call_details = _ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
metadata,
client_call_details.credentials,
)
# Add the public key
metadata.append((PUBLIC_KEY_HEADER, self.public_key_bytes))

response = continuation(client_call_details, request)
if postprocess:
server_public_key_bytes = base64.urlsafe_b64decode(
_get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata())
)
# Add timestamp
timestamp = now().isoformat()
metadata.append((TIMESTAMP_HEADER, timestamp))

if server_public_key_bytes != b"":
self.server_public_key = bytes_to_public_key(server_public_key_bytes)
else:
log(WARNING, "Can't get server public key, SuperLink may be offline")
# Sign and add the signature
signature = sign_message(self.private_key, timestamp.encode("ascii"))
metadata.append((SIGNATURE_HEADER, signature))

if self.server_public_key is not None:
self.shared_secret = generate_shared_key(
self.private_key, self.server_public_key
)
# Overwrite the metadata
details = client_call_details._replace(metadata=metadata)

return response
return continuation(details, request)
Loading