Skip to content

Commit a768ff8

Browse files
committed
feat(taskworker) Add signature based authentication to RPC calls
Longer term we may be able to use service mesh authentication, but the requirements for that incur additional infrastructure complexity. This level of authentication will prevent untrusted clients from fetching and updating tasks. Refs getsentry/taskbroker#57
1 parent 4fb28df commit a768ff8

File tree

3 files changed

+137
-7
lines changed

3 files changed

+137
-7
lines changed

src/sentry/conf/server.py

+4
Original file line numberDiff line numberDiff line change
@@ -1344,6 +1344,9 @@ def SOCIAL_AUTH_DEFAULT_USERNAME() -> str:
13441344
}
13451345

13461346
# Taskworker settings #
1347+
# Shared secret used to sign RPC requests to taskbrokers
1348+
TASKWORKER_SHARED_SECRET: str | None = None
1349+
13471350
# The list of modules that workers will import after starting up
13481351
# Like celery, taskworkers need to import task modules to make tasks
13491352
# accessible to the worker.
@@ -1353,6 +1356,7 @@ def SOCIAL_AUTH_DEFAULT_USERNAME() -> str:
13531356
)
13541357
TASKWORKER_ROUTER: str = "sentry.taskworker.router.DefaultRouter"
13551358
TASKWORKER_ROUTES: dict[str, str] = {}
1359+
13561360
# Schedules for taskworker tasks to be spawned on.
13571361
TASKWORKER_SCHEDULES: ScheduleConfigMap = {}
13581362

src/sentry/taskworker/client.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import hashlib
2+
import hmac
13
import logging
24
import random
5+
from collections.abc import Callable
36
from datetime import datetime
7+
from typing import Any
48

59
import grpc
10+
from django.conf import settings
11+
from google.protobuf.message import Message
612
from sentry_protos.taskbroker.v1.taskbroker_pb2 import (
713
FetchNextTask,
814
GetTaskRequest,
@@ -18,6 +24,59 @@
1824
logger = logging.getLogger("sentry.taskworker.client")
1925

2026

27+
class ClientCallDetails(grpc.ClientCallDetails):
28+
"""
29+
Subclass of grpc.ClientCallDetails that allows metadata to be updated
30+
"""
31+
32+
def __init__(
33+
self,
34+
method: str,
35+
timeout: float | None,
36+
metadata: tuple[tuple[str, str | bytes], ...] | None,
37+
credentials: grpc.CallCredentials | None,
38+
):
39+
self.timeout = timeout
40+
self.method = method
41+
self.metadata = metadata
42+
self.credentials = credentials
43+
44+
45+
# Type alias based on grpc-stubs
46+
ContinuationType = Callable[[ClientCallDetails, Message], Any]
47+
48+
49+
# The type stubs for grpc.UnaryUnaryClientInterceptor have generics
50+
# but the implementation in grpc does not, and providing the type parameters
51+
# results in a runtime error.
52+
class RequestSignatureInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore[type-arg]
53+
def __init__(self, shared_secret: str):
54+
self._secret = shared_secret.encode("utf-8")
55+
56+
def intercept_unary_unary(
57+
self,
58+
continuation: ContinuationType,
59+
client_call_details: grpc.ClientCallDetails,
60+
request: Message,
61+
) -> Any:
62+
request_body = request.SerializeToString()
63+
method = client_call_details.method.encode("utf-8")
64+
65+
signing_payload = method + b":" + request_body
66+
signature = hmac.new(self._secret, signing_payload, hashlib.sha256).hexdigest()
67+
68+
metadata = list(client_call_details.metadata) if client_call_details.metadata else []
69+
metadata.append(("sentry-signature", signature))
70+
71+
call_details_with_meta = ClientCallDetails(
72+
client_call_details.method,
73+
client_call_details.timeout,
74+
tuple(metadata),
75+
client_call_details.credentials,
76+
)
77+
return continuation(call_details_with_meta, request)
78+
79+
2180
class TaskworkerClient:
2281
"""
2382
Taskworker RPC client wrapper
@@ -33,7 +92,12 @@ def __init__(self, host: str, num_brokers: int | None) -> None:
3392
grpc_options = [("grpc.service_config", grpc_config)]
3493

3594
logger.info("Connecting to %s with options %s", self._host, grpc_options)
36-
self._channel = grpc.insecure_channel(self._host, options=grpc_options)
95+
channel = grpc.insecure_channel(self._host, options=grpc_options)
96+
if settings.TASKWORKER_SHARED_SECRET:
97+
channel = grpc.intercept_channel(
98+
channel, RequestSignatureInterceptor(settings.TASKWORKER_SHARED_SECRET)
99+
)
100+
self._channel = channel
37101
self._stub = ConsumerServiceStub(self._channel)
38102

39103
def loadbalance(self, host: str, num_brokers: int) -> str:

tests/sentry/taskworker/test_client.py

+68-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import dataclasses
12
from collections import defaultdict
23
from collections.abc import Callable
34
from typing import Any
45
from unittest.mock import patch
56

67
import grpc
78
import pytest
9+
from django.test import override_settings
810
from google.protobuf.message import Message
911
from sentry_protos.taskbroker.v1.taskbroker_pb2 import (
1012
TASK_ACTIVATION_STATUS_RETRY,
@@ -18,6 +20,12 @@
1820
from sentry.testutils.pytest.fixtures import django_db_all
1921

2022

23+
@dataclasses.dataclass
24+
class MockServiceCall:
25+
response: Any
26+
metadata: tuple[tuple[str, str | bytes], ...] | None = None
27+
28+
2129
class MockServiceMethod:
2230
"""Stub for grpc service methods"""
2331

@@ -40,24 +48,42 @@ def __call__(self, *args, **kwargs):
4048
tail = self.responses[1:]
4149
self.responses = tail + [res]
4250

43-
if isinstance(res, Exception):
44-
raise res
45-
return res
51+
if isinstance(res.response, Exception):
52+
raise res.response
53+
return res.response
54+
55+
def with_call(self, *args, **kwargs):
56+
res = self.responses[0]
57+
if res.metadata:
58+
assert res.metadata == kwargs.get("metadata"), "Metadata mismatch"
59+
if isinstance(res.response, Exception):
60+
raise res.response
61+
return (res.response, None)
4662

4763

4864
class MockChannel:
4965
def __init__(self):
5066
self._responses = defaultdict(list)
5167

5268
def unary_unary(
53-
self, path: str, request_serializer: Callable, response_deserializer: Callable, **kwargs
69+
self,
70+
path: str,
71+
request_serializer: Callable,
72+
response_deserializer: Callable,
73+
*args,
74+
**kwargs,
5475
):
5576
return MockServiceMethod(
5677
path, self._responses.get(path, []), request_serializer, response_deserializer
5778
)
5879

59-
def add_response(self, path: str, resp: Message | Exception):
60-
self._responses[path].append(resp)
80+
def add_response(
81+
self,
82+
path: str,
83+
resp: Message | Exception,
84+
metadata: tuple[tuple[str, str | bytes], ...] | None = None,
85+
):
86+
self._responses[path].append(MockServiceCall(response=resp, metadata=metadata))
6187

6288

6389
class MockGrpcError(grpc.RpcError):
@@ -73,6 +99,9 @@ def code(self) -> grpc.StatusCode:
7399
def details(self) -> str:
74100
return self._message
75101

102+
def result(self):
103+
raise self
104+
76105

77106
@django_db_all
78107
def test_get_task_ok():
@@ -100,6 +129,39 @@ def test_get_task_ok():
100129
assert result.namespace == "testing"
101130

102131

132+
@django_db_all
133+
@override_settings(TASKWORKER_SHARED_SECRET="a long secret value")
134+
def test_get_task_with_interceptor():
135+
channel = MockChannel()
136+
channel.add_response(
137+
"/sentry_protos.taskbroker.v1.ConsumerService/GetTask",
138+
GetTaskResponse(
139+
task=TaskActivation(
140+
id="abc123",
141+
namespace="testing",
142+
taskname="do_thing",
143+
parameters="",
144+
headers={},
145+
processing_deadline_duration=10,
146+
)
147+
),
148+
metadata=(
149+
(
150+
"sentry-signature",
151+
"3202702605c1b65055c28e7c78a5835e760830cff3e9f995eb7ad5f837130b1f",
152+
),
153+
),
154+
)
155+
with patch("sentry.taskworker.client.grpc.insecure_channel") as mock_channel:
156+
mock_channel.return_value = channel
157+
client = TaskworkerClient("localhost:50051", 1)
158+
result = client.get_task()
159+
160+
assert result
161+
assert result.id
162+
assert result.namespace == "testing"
163+
164+
103165
@django_db_all
104166
def test_get_task_with_namespace():
105167
channel = MockChannel()

0 commit comments

Comments
 (0)