1
+ import dataclasses
1
2
from collections import defaultdict
2
3
from collections .abc import Callable
3
4
from typing import Any
4
5
from unittest .mock import patch
5
6
6
7
import grpc
7
8
import pytest
9
+ from django .test import override_settings
8
10
from google .protobuf .message import Message
9
11
from sentry_protos .taskbroker .v1 .taskbroker_pb2 import (
10
12
TASK_ACTIVATION_STATUS_RETRY ,
18
20
from sentry .testutils .pytest .fixtures import django_db_all
19
21
20
22
23
+ @dataclasses .dataclass
24
+ class MockServiceCall :
25
+ response : Any
26
+ metadata : tuple [tuple [str , str | bytes ], ...] | None = None
27
+
28
+
21
29
class MockServiceMethod :
22
30
"""Stub for grpc service methods"""
23
31
@@ -40,24 +48,42 @@ def __call__(self, *args, **kwargs):
40
48
tail = self .responses [1 :]
41
49
self .responses = tail + [res ]
42
50
43
- if isinstance (res , Exception ):
44
- raise res
45
- return res
51
+ if isinstance (res .response , Exception ):
52
+ raise res .response
53
+ return res .response
54
+
55
+ def with_call (self , * args , ** kwargs ):
56
+ res = self .responses [0 ]
57
+ if res .metadata :
58
+ assert res .metadata == kwargs .get ("metadata" ), "Metadata mismatch"
59
+ if isinstance (res .response , Exception ):
60
+ raise res .response
61
+ return (res .response , None )
46
62
47
63
48
64
class MockChannel :
49
65
def __init__ (self ):
50
66
self ._responses = defaultdict (list )
51
67
52
68
def unary_unary (
53
- self , path : str , request_serializer : Callable , response_deserializer : Callable , ** kwargs
69
+ self ,
70
+ path : str ,
71
+ request_serializer : Callable ,
72
+ response_deserializer : Callable ,
73
+ * args ,
74
+ ** kwargs ,
54
75
):
55
76
return MockServiceMethod (
56
77
path , self ._responses .get (path , []), request_serializer , response_deserializer
57
78
)
58
79
59
- def add_response (self , path : str , resp : Message | Exception ):
60
- self ._responses [path ].append (resp )
80
+ def add_response (
81
+ self ,
82
+ path : str ,
83
+ resp : Message | Exception ,
84
+ metadata : tuple [tuple [str , str | bytes ], ...] | None = None ,
85
+ ):
86
+ self ._responses [path ].append (MockServiceCall (response = resp , metadata = metadata ))
61
87
62
88
63
89
class MockGrpcError (grpc .RpcError ):
@@ -73,6 +99,9 @@ def code(self) -> grpc.StatusCode:
73
99
def details (self ) -> str :
74
100
return self ._message
75
101
102
+ def result (self ):
103
+ raise self
104
+
76
105
77
106
@django_db_all
78
107
def test_get_task_ok ():
@@ -100,6 +129,39 @@ def test_get_task_ok():
100
129
assert result .namespace == "testing"
101
130
102
131
132
+ @django_db_all
133
+ @override_settings (TASKWORKER_SHARED_SECRET = "a long secret value" )
134
+ def test_get_task_with_interceptor ():
135
+ channel = MockChannel ()
136
+ channel .add_response (
137
+ "/sentry_protos.taskbroker.v1.ConsumerService/GetTask" ,
138
+ GetTaskResponse (
139
+ task = TaskActivation (
140
+ id = "abc123" ,
141
+ namespace = "testing" ,
142
+ taskname = "do_thing" ,
143
+ parameters = "" ,
144
+ headers = {},
145
+ processing_deadline_duration = 10 ,
146
+ )
147
+ ),
148
+ metadata = (
149
+ (
150
+ "sentry-signature" ,
151
+ "3202702605c1b65055c28e7c78a5835e760830cff3e9f995eb7ad5f837130b1f" ,
152
+ ),
153
+ ),
154
+ )
155
+ with patch ("sentry.taskworker.client.grpc.insecure_channel" ) as mock_channel :
156
+ mock_channel .return_value = channel
157
+ client = TaskworkerClient ("localhost:50051" , 1 )
158
+ result = client .get_task ()
159
+
160
+ assert result
161
+ assert result .id
162
+ assert result .namespace == "testing"
163
+
164
+
103
165
@django_db_all
104
166
def test_get_task_with_namespace ():
105
167
channel = MockChannel ()
0 commit comments