diff --git a/src/sentry/conf/server.py b/src/sentry/conf/server.py index 47dd011f081d7c..c181ac8c320813 100644 --- a/src/sentry/conf/server.py +++ b/src/sentry/conf/server.py @@ -1344,6 +1344,9 @@ def SOCIAL_AUTH_DEFAULT_USERNAME() -> str: } # Taskworker settings # +# Shared secret used to sign RPC requests to taskbrokers +TASKWORKER_SHARED_SECRET: str | None = None + # The list of modules that workers will import after starting up # Like celery, taskworkers need to import task modules to make tasks # accessible to the worker. @@ -1353,6 +1356,7 @@ def SOCIAL_AUTH_DEFAULT_USERNAME() -> str: ) TASKWORKER_ROUTER: str = "sentry.taskworker.router.DefaultRouter" TASKWORKER_ROUTES: dict[str, str] = {} + # Schedules for taskworker tasks to be spawned on. TASKWORKER_SCHEDULES: ScheduleConfigMap = {} diff --git a/src/sentry/taskworker/client.py b/src/sentry/taskworker/client.py index c8bc02110b0e95..3185e4dd51d80d 100644 --- a/src/sentry/taskworker/client.py +++ b/src/sentry/taskworker/client.py @@ -1,8 +1,14 @@ +import hashlib +import hmac import logging import random +from collections.abc import Callable from datetime import datetime +from typing import Any import grpc +from django.conf import settings +from google.protobuf.message import Message from sentry_protos.taskbroker.v1.taskbroker_pb2 import ( FetchNextTask, GetTaskRequest, @@ -18,6 +24,59 @@ logger = logging.getLogger("sentry.taskworker.client") +class ClientCallDetails(grpc.ClientCallDetails): + """ + Subclass of grpc.ClientCallDetails that allows metadata to be updated + """ + + def __init__( + self, + method: str, + timeout: float | None, + metadata: tuple[tuple[str, str | bytes], ...] | None, + credentials: grpc.CallCredentials | None, + ): + self.timeout = timeout + self.method = method + self.metadata = metadata + self.credentials = credentials + + +# Type alias based on grpc-stubs +ContinuationType = Callable[[ClientCallDetails, Message], Any] + + +# The type stubs for grpc.UnaryUnaryClientInterceptor have generics +# but the implementation in grpc does not, and providing the type parameters +# results in a runtime error. +class RequestSignatureInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore[type-arg] + def __init__(self, shared_secret: str): + self._secret = shared_secret.encode("utf-8") + + def intercept_unary_unary( + self, + continuation: ContinuationType, + client_call_details: grpc.ClientCallDetails, + request: Message, + ) -> Any: + request_body = request.SerializeToString() + method = client_call_details.method.encode("utf-8") + + signing_payload = method + b":" + request_body + signature = hmac.new(self._secret, signing_payload, hashlib.sha256).hexdigest() + + metadata = list(client_call_details.metadata) if client_call_details.metadata else [] + metadata.append(("sentry-signature", signature)) + + call_details_with_meta = ClientCallDetails( + client_call_details.method, + client_call_details.timeout, + tuple(metadata), + client_call_details.credentials, + ) + return continuation(call_details_with_meta, request) + + class TaskworkerClient: """ Taskworker RPC client wrapper @@ -33,7 +92,12 @@ def __init__(self, host: str, num_brokers: int | None) -> None: grpc_options = [("grpc.service_config", grpc_config)] logger.info("Connecting to %s with options %s", self._host, grpc_options) - self._channel = grpc.insecure_channel(self._host, options=grpc_options) + channel = grpc.insecure_channel(self._host, options=grpc_options) + if settings.TASKWORKER_SHARED_SECRET: + channel = grpc.intercept_channel( + channel, RequestSignatureInterceptor(settings.TASKWORKER_SHARED_SECRET) + ) + self._channel = channel self._stub = ConsumerServiceStub(self._channel) def loadbalance(self, host: str, num_brokers: int) -> str: diff --git a/tests/sentry/taskworker/test_client.py b/tests/sentry/taskworker/test_client.py index 3bb4cdbb16f2f0..e552bfc91b7a56 100644 --- a/tests/sentry/taskworker/test_client.py +++ b/tests/sentry/taskworker/test_client.py @@ -1,3 +1,4 @@ +import dataclasses from collections import defaultdict from collections.abc import Callable from typing import Any @@ -5,6 +6,7 @@ import grpc import pytest +from django.test import override_settings from google.protobuf.message import Message from sentry_protos.taskbroker.v1.taskbroker_pb2 import ( TASK_ACTIVATION_STATUS_RETRY, @@ -18,6 +20,12 @@ from sentry.testutils.pytest.fixtures import django_db_all +@dataclasses.dataclass +class MockServiceCall: + response: Any + metadata: tuple[tuple[str, str | bytes], ...] | None = None + + class MockServiceMethod: """Stub for grpc service methods""" @@ -40,9 +48,17 @@ def __call__(self, *args, **kwargs): tail = self.responses[1:] self.responses = tail + [res] - if isinstance(res, Exception): - raise res - return res + if isinstance(res.response, Exception): + raise res.response + return res.response + + def with_call(self, *args, **kwargs): + res = self.responses[0] + if res.metadata: + assert res.metadata == kwargs.get("metadata"), "Metadata mismatch" + if isinstance(res.response, Exception): + raise res.response + return (res.response, None) class MockChannel: @@ -50,14 +66,24 @@ def __init__(self): self._responses = defaultdict(list) def unary_unary( - self, path: str, request_serializer: Callable, response_deserializer: Callable, **kwargs + self, + path: str, + request_serializer: Callable, + response_deserializer: Callable, + *args, + **kwargs, ): return MockServiceMethod( path, self._responses.get(path, []), request_serializer, response_deserializer ) - def add_response(self, path: str, resp: Message | Exception): - self._responses[path].append(resp) + def add_response( + self, + path: str, + resp: Message | Exception, + metadata: tuple[tuple[str, str | bytes], ...] | None = None, + ): + self._responses[path].append(MockServiceCall(response=resp, metadata=metadata)) class MockGrpcError(grpc.RpcError): @@ -73,6 +99,9 @@ def code(self) -> grpc.StatusCode: def details(self) -> str: return self._message + def result(self): + raise self + @django_db_all def test_get_task_ok(): @@ -100,6 +129,39 @@ def test_get_task_ok(): assert result.namespace == "testing" +@django_db_all +@override_settings(TASKWORKER_SHARED_SECRET="a long secret value") +def test_get_task_with_interceptor(): + channel = MockChannel() + channel.add_response( + "/sentry_protos.taskbroker.v1.ConsumerService/GetTask", + GetTaskResponse( + task=TaskActivation( + id="abc123", + namespace="testing", + taskname="do_thing", + parameters="", + headers={}, + processing_deadline_duration=10, + ) + ), + metadata=( + ( + "sentry-signature", + "3202702605c1b65055c28e7c78a5835e760830cff3e9f995eb7ad5f837130b1f", + ), + ), + ) + with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel: + mock_channel.return_value = channel + client = TaskworkerClient("localhost:50051", 1) + result = client.get_task() + + assert result + assert result.id + assert result.namespace == "testing" + + @django_db_all def test_get_task_with_namespace(): channel = MockChannel()