From a768ff8c44c2a8dc92ceac561ca92f3194f848f0 Mon Sep 17 00:00:00 2001 From: Mark Story Date: Wed, 19 Feb 2025 14:19:50 -0500 Subject: [PATCH] feat(taskworker) Add signature based authentication to RPC calls Longer term we may be able to use service mesh authentication, but the requirements for that incur additional infrastructure complexity. This level of authentication will prevent untrusted clients from fetching and updating tasks. Refs getsentry/taskbroker#57 --- src/sentry/conf/server.py | 4 ++ src/sentry/taskworker/client.py | 66 ++++++++++++++++++++++- tests/sentry/taskworker/test_client.py | 74 +++++++++++++++++++++++--- 3 files changed, 137 insertions(+), 7 deletions(-) diff --git a/src/sentry/conf/server.py b/src/sentry/conf/server.py index 47dd011f081d7c..c181ac8c320813 100644 --- a/src/sentry/conf/server.py +++ b/src/sentry/conf/server.py @@ -1344,6 +1344,9 @@ def SOCIAL_AUTH_DEFAULT_USERNAME() -> str: } # Taskworker settings # +# Shared secret used to sign RPC requests to taskbrokers +TASKWORKER_SHARED_SECRET: str | None = None + # The list of modules that workers will import after starting up # Like celery, taskworkers need to import task modules to make tasks # accessible to the worker. @@ -1353,6 +1356,7 @@ def SOCIAL_AUTH_DEFAULT_USERNAME() -> str: ) TASKWORKER_ROUTER: str = "sentry.taskworker.router.DefaultRouter" TASKWORKER_ROUTES: dict[str, str] = {} + # Schedules for taskworker tasks to be spawned on. TASKWORKER_SCHEDULES: ScheduleConfigMap = {} diff --git a/src/sentry/taskworker/client.py b/src/sentry/taskworker/client.py index c8bc02110b0e95..3185e4dd51d80d 100644 --- a/src/sentry/taskworker/client.py +++ b/src/sentry/taskworker/client.py @@ -1,8 +1,14 @@ +import hashlib +import hmac import logging import random +from collections.abc import Callable from datetime import datetime +from typing import Any import grpc +from django.conf import settings +from google.protobuf.message import Message from sentry_protos.taskbroker.v1.taskbroker_pb2 import ( FetchNextTask, GetTaskRequest, @@ -18,6 +24,59 @@ logger = logging.getLogger("sentry.taskworker.client") +class ClientCallDetails(grpc.ClientCallDetails): + """ + Subclass of grpc.ClientCallDetails that allows metadata to be updated + """ + + def __init__( + self, + method: str, + timeout: float | None, + metadata: tuple[tuple[str, str | bytes], ...] | None, + credentials: grpc.CallCredentials | None, + ): + self.timeout = timeout + self.method = method + self.metadata = metadata + self.credentials = credentials + + +# Type alias based on grpc-stubs +ContinuationType = Callable[[ClientCallDetails, Message], Any] + + +# The type stubs for grpc.UnaryUnaryClientInterceptor have generics +# but the implementation in grpc does not, and providing the type parameters +# results in a runtime error. +class RequestSignatureInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore[type-arg] + def __init__(self, shared_secret: str): + self._secret = shared_secret.encode("utf-8") + + def intercept_unary_unary( + self, + continuation: ContinuationType, + client_call_details: grpc.ClientCallDetails, + request: Message, + ) -> Any: + request_body = request.SerializeToString() + method = client_call_details.method.encode("utf-8") + + signing_payload = method + b":" + request_body + signature = hmac.new(self._secret, signing_payload, hashlib.sha256).hexdigest() + + metadata = list(client_call_details.metadata) if client_call_details.metadata else [] + metadata.append(("sentry-signature", signature)) + + call_details_with_meta = ClientCallDetails( + client_call_details.method, + client_call_details.timeout, + tuple(metadata), + client_call_details.credentials, + ) + return continuation(call_details_with_meta, request) + + class TaskworkerClient: """ Taskworker RPC client wrapper @@ -33,7 +92,12 @@ def __init__(self, host: str, num_brokers: int | None) -> None: grpc_options = [("grpc.service_config", grpc_config)] logger.info("Connecting to %s with options %s", self._host, grpc_options) - self._channel = grpc.insecure_channel(self._host, options=grpc_options) + channel = grpc.insecure_channel(self._host, options=grpc_options) + if settings.TASKWORKER_SHARED_SECRET: + channel = grpc.intercept_channel( + channel, RequestSignatureInterceptor(settings.TASKWORKER_SHARED_SECRET) + ) + self._channel = channel self._stub = ConsumerServiceStub(self._channel) def loadbalance(self, host: str, num_brokers: int) -> str: diff --git a/tests/sentry/taskworker/test_client.py b/tests/sentry/taskworker/test_client.py index 3bb4cdbb16f2f0..e552bfc91b7a56 100644 --- a/tests/sentry/taskworker/test_client.py +++ b/tests/sentry/taskworker/test_client.py @@ -1,3 +1,4 @@ +import dataclasses from collections import defaultdict from collections.abc import Callable from typing import Any @@ -5,6 +6,7 @@ import grpc import pytest +from django.test import override_settings from google.protobuf.message import Message from sentry_protos.taskbroker.v1.taskbroker_pb2 import ( TASK_ACTIVATION_STATUS_RETRY, @@ -18,6 +20,12 @@ from sentry.testutils.pytest.fixtures import django_db_all +@dataclasses.dataclass +class MockServiceCall: + response: Any + metadata: tuple[tuple[str, str | bytes], ...] | None = None + + class MockServiceMethod: """Stub for grpc service methods""" @@ -40,9 +48,17 @@ def __call__(self, *args, **kwargs): tail = self.responses[1:] self.responses = tail + [res] - if isinstance(res, Exception): - raise res - return res + if isinstance(res.response, Exception): + raise res.response + return res.response + + def with_call(self, *args, **kwargs): + res = self.responses[0] + if res.metadata: + assert res.metadata == kwargs.get("metadata"), "Metadata mismatch" + if isinstance(res.response, Exception): + raise res.response + return (res.response, None) class MockChannel: @@ -50,14 +66,24 @@ def __init__(self): self._responses = defaultdict(list) def unary_unary( - self, path: str, request_serializer: Callable, response_deserializer: Callable, **kwargs + self, + path: str, + request_serializer: Callable, + response_deserializer: Callable, + *args, + **kwargs, ): return MockServiceMethod( path, self._responses.get(path, []), request_serializer, response_deserializer ) - def add_response(self, path: str, resp: Message | Exception): - self._responses[path].append(resp) + def add_response( + self, + path: str, + resp: Message | Exception, + metadata: tuple[tuple[str, str | bytes], ...] | None = None, + ): + self._responses[path].append(MockServiceCall(response=resp, metadata=metadata)) class MockGrpcError(grpc.RpcError): @@ -73,6 +99,9 @@ def code(self) -> grpc.StatusCode: def details(self) -> str: return self._message + def result(self): + raise self + @django_db_all def test_get_task_ok(): @@ -100,6 +129,39 @@ def test_get_task_ok(): assert result.namespace == "testing" +@django_db_all +@override_settings(TASKWORKER_SHARED_SECRET="a long secret value") +def test_get_task_with_interceptor(): + channel = MockChannel() + channel.add_response( + "/sentry_protos.taskbroker.v1.ConsumerService/GetTask", + GetTaskResponse( + task=TaskActivation( + id="abc123", + namespace="testing", + taskname="do_thing", + parameters="", + headers={}, + processing_deadline_duration=10, + ) + ), + metadata=( + ( + "sentry-signature", + "3202702605c1b65055c28e7c78a5835e760830cff3e9f995eb7ad5f837130b1f", + ), + ), + ) + with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel: + mock_channel.return_value = channel + client = TaskworkerClient("localhost:50051", 1) + result = client.get_task() + + assert result + assert result.id + assert result.namespace == "testing" + + @django_db_all def test_get_task_with_namespace(): channel = MockChannel()