Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(taskworker) Add signature based authentication to RPC calls #85533

Merged
merged 1 commit into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sentry/conf/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,9 @@ def SOCIAL_AUTH_DEFAULT_USERNAME() -> str:
}

# Taskworker settings #
# Shared secret used to sign RPC requests to taskbrokers
TASKWORKER_SHARED_SECRET: str | None = None

# The list of modules that workers will import after starting up
# Like celery, taskworkers need to import task modules to make tasks
# accessible to the worker.
Expand All @@ -1353,6 +1356,7 @@ def SOCIAL_AUTH_DEFAULT_USERNAME() -> str:
)
TASKWORKER_ROUTER: str = "sentry.taskworker.router.DefaultRouter"
TASKWORKER_ROUTES: dict[str, str] = {}

# Schedules for taskworker tasks to be spawned on.
TASKWORKER_SCHEDULES: ScheduleConfigMap = {}

Expand Down
66 changes: 65 additions & 1 deletion src/sentry/taskworker/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import hashlib
import hmac
import logging
import random
from collections.abc import Callable
from datetime import datetime
from typing import Any

import grpc
from django.conf import settings
from google.protobuf.message import Message
from sentry_protos.taskbroker.v1.taskbroker_pb2 import (
FetchNextTask,
GetTaskRequest,
Expand All @@ -18,6 +24,59 @@
logger = logging.getLogger("sentry.taskworker.client")


class ClientCallDetails(grpc.ClientCallDetails):
"""
Subclass of grpc.ClientCallDetails that allows metadata to be updated
"""

def __init__(
self,
method: str,
timeout: float | None,
metadata: tuple[tuple[str, str | bytes], ...] | None,
credentials: grpc.CallCredentials | None,
):
self.timeout = timeout
self.method = method
self.metadata = metadata
self.credentials = credentials


# Type alias based on grpc-stubs
ContinuationType = Callable[[ClientCallDetails, Message], Any]


# The type stubs for grpc.UnaryUnaryClientInterceptor have generics
# but the implementation in grpc does not, and providing the type parameters
# results in a runtime error.
class RequestSignatureInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore[type-arg]
def __init__(self, shared_secret: str):
self._secret = shared_secret.encode("utf-8")

def intercept_unary_unary(
self,
continuation: ContinuationType,
client_call_details: grpc.ClientCallDetails,
request: Message,
) -> Any:
request_body = request.SerializeToString()
method = client_call_details.method.encode("utf-8")

signing_payload = method + b":" + request_body
signature = hmac.new(self._secret, signing_payload, hashlib.sha256).hexdigest()

metadata = list(client_call_details.metadata) if client_call_details.metadata else []
metadata.append(("sentry-signature", signature))

call_details_with_meta = ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
tuple(metadata),
client_call_details.credentials,
)
return continuation(call_details_with_meta, request)


class TaskworkerClient:
"""
Taskworker RPC client wrapper
Expand All @@ -33,7 +92,12 @@ def __init__(self, host: str, num_brokers: int | None) -> None:
grpc_options = [("grpc.service_config", grpc_config)]

logger.info("Connecting to %s with options %s", self._host, grpc_options)
self._channel = grpc.insecure_channel(self._host, options=grpc_options)
channel = grpc.insecure_channel(self._host, options=grpc_options)
if settings.TASKWORKER_SHARED_SECRET:
channel = grpc.intercept_channel(
channel, RequestSignatureInterceptor(settings.TASKWORKER_SHARED_SECRET)
)
self._channel = channel
self._stub = ConsumerServiceStub(self._channel)

def loadbalance(self, host: str, num_brokers: int) -> str:
Expand Down
74 changes: 68 additions & 6 deletions tests/sentry/taskworker/test_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import dataclasses
from collections import defaultdict
from collections.abc import Callable
from typing import Any
from unittest.mock import patch

import grpc
import pytest
from django.test import override_settings
from google.protobuf.message import Message
from sentry_protos.taskbroker.v1.taskbroker_pb2 import (
TASK_ACTIVATION_STATUS_RETRY,
Expand All @@ -18,6 +20,12 @@
from sentry.testutils.pytest.fixtures import django_db_all


@dataclasses.dataclass
class MockServiceCall:
response: Any
metadata: tuple[tuple[str, str | bytes], ...] | None = None


class MockServiceMethod:
"""Stub for grpc service methods"""

Expand All @@ -40,24 +48,42 @@ def __call__(self, *args, **kwargs):
tail = self.responses[1:]
self.responses = tail + [res]

if isinstance(res, Exception):
raise res
return res
if isinstance(res.response, Exception):
raise res.response
return res.response

def with_call(self, *args, **kwargs):
res = self.responses[0]
if res.metadata:
assert res.metadata == kwargs.get("metadata"), "Metadata mismatch"
if isinstance(res.response, Exception):
raise res.response
return (res.response, None)


class MockChannel:
def __init__(self):
self._responses = defaultdict(list)

def unary_unary(
self, path: str, request_serializer: Callable, response_deserializer: Callable, **kwargs
self,
path: str,
request_serializer: Callable,
response_deserializer: Callable,
*args,
**kwargs,
):
return MockServiceMethod(
path, self._responses.get(path, []), request_serializer, response_deserializer
)

def add_response(self, path: str, resp: Message | Exception):
self._responses[path].append(resp)
def add_response(
self,
path: str,
resp: Message | Exception,
metadata: tuple[tuple[str, str | bytes], ...] | None = None,
):
self._responses[path].append(MockServiceCall(response=resp, metadata=metadata))


class MockGrpcError(grpc.RpcError):
Expand All @@ -73,6 +99,9 @@ def code(self) -> grpc.StatusCode:
def details(self) -> str:
return self._message

def result(self):
raise self


@django_db_all
def test_get_task_ok():
Expand Down Expand Up @@ -100,6 +129,39 @@ def test_get_task_ok():
assert result.namespace == "testing"


@django_db_all
@override_settings(TASKWORKER_SHARED_SECRET="a long secret value")
def test_get_task_with_interceptor():
channel = MockChannel()
channel.add_response(
"/sentry_protos.taskbroker.v1.ConsumerService/GetTask",
GetTaskResponse(
task=TaskActivation(
id="abc123",
namespace="testing",
taskname="do_thing",
parameters="",
headers={},
processing_deadline_duration=10,
)
),
metadata=(
(
"sentry-signature",
"3202702605c1b65055c28e7c78a5835e760830cff3e9f995eb7ad5f837130b1f",
),
),
)
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
mock_channel.return_value = channel
client = TaskworkerClient("localhost:50051", 1)
result = client.get_task()

assert result
assert result.id
assert result.namespace == "testing"


@django_db_all
def test_get_task_with_namespace():
channel = MockChannel()
Expand Down
Loading