Skip to content

Commit 8616572

Browse files
committed
Inject header in more Session using spots plus more tests
1 parent 3e7114a commit 8616572

File tree

3 files changed

+153
-7
lines changed

3 files changed

+153
-7
lines changed

google/cloud/spanner_v1/session.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,8 @@ def exists(self):
193193
current_span, "Checking if Session exists", {"session.id": self._session_id}
194194
)
195195

196-
api = self._database.spanner_api
196+
database = self._database
197+
api = database.spanner_api
197198
metadata = _metadata_with_prefix(self._database.name)
198199
if self._database._route_to_leader_enabled:
199200
metadata.append(
@@ -202,12 +203,16 @@ def exists(self):
202203
)
203204
)
204205

206+
all_metadata = database.metadata_with_request_id(
207+
database._next_nth_request, 1, metadata
208+
)
209+
205210
observability_options = getattr(self._database, "observability_options", None)
206211
with trace_call(
207212
"CloudSpanner.GetSession", self, observability_options=observability_options
208213
) as span:
209214
try:
210-
api.get_session(name=self.name, metadata=metadata)
215+
api.get_session(name=self.name, metadata=all_metadata)
211216
if span:
212217
span.set_attribute("session_found", True)
213218
except NotFound:
@@ -237,8 +242,11 @@ def delete(self):
237242
current_span, "Deleting Session", {"session.id": self._session_id}
238243
)
239244

240-
api = self._database.spanner_api
241-
metadata = _metadata_with_prefix(self._database.name)
245+
database = self._database
246+
api = database.spanner_api
247+
metadata = database.metadata_with_request_id(
248+
database._next_nth_request, 1, _metadata_with_prefix(database.name)
249+
)
242250
observability_options = getattr(self._database, "observability_options", None)
243251
with trace_call(
244252
"CloudSpanner.DeleteSession",
@@ -255,7 +263,10 @@ def ping(self):
255263
if self._session_id is None:
256264
raise ValueError("Session ID not set by back-end")
257265
api = self._database.spanner_api
258-
metadata = _metadata_with_prefix(self._database.name)
266+
database = self._database
267+
metadata = database.metadata_with_request_id(
268+
database._next_nth_request, 1, _metadata_with_prefix(database.name)
269+
)
259270
request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
260271
api.execute_sql(request=request, metadata=metadata)
261272
self._last_use_time = datetime.now()

google/cloud/spanner_v1/testing/interceptors.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def reset(self):
6767
self._connection = None
6868

6969

70+
X_GOOG_REQUEST_ID = "x-goog-spanner-request-id"
71+
72+
7073
class XGoogRequestIDHeaderInterceptor(ClientInterceptor):
7174
def __init__(self):
7275
self._unary_req_segments = []
@@ -77,12 +80,14 @@ def intercept(self, method, request_or_iterator, call_details):
7780
metadata = call_details.metadata
7881
x_goog_request_id = None
7982
for key, value in metadata:
80-
if key == "x-goog-spanner-request-id":
83+
if key == X_GOOG_REQUEST_ID:
8184
x_goog_request_id = value
8285
break
8386

8487
if not x_goog_request_id:
85-
raise Exception("Missing x_goog_request_id header")
88+
raise Exception(
89+
f"Missing {X_GOOG_REQUEST_ID} header in {call_details.method}"
90+
)
8691

8792
response_or_iterator = method(request_or_iterator, call_details)
8893
streaming = getattr(response_or_iterator, "__iter__", None) is not None

tests/unit/test_request_id_header.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import random
16+
import threading
17+
1518
from tests.mockserver_tests.mock_server_test_base import (
1619
MockServerTestBase,
1720
add_select1_result,
@@ -63,6 +66,133 @@ def test_snapshot_read(self):
6366
assert got_unary_segments == want_unary_segments
6467
assert got_stream_segments == want_stream_segments
6568

69+
def test_snapshot_read_concurrent(self):
70+
def select1():
71+
with self.database.snapshot() as snapshot:
72+
rows = snapshot.execute_sql("select 1")
73+
res_list = []
74+
for row in rows:
75+
self.assertEqual(1, row[0])
76+
res_list.append(row)
77+
self.assertEqual(1, len(res_list))
78+
79+
n = 10
80+
threads = []
81+
for i in range(n):
82+
th = threading.Thread(target=select1, name=f"snapshot-select1-{i}")
83+
th.run()
84+
threads.append(th)
85+
86+
random.shuffle(threads)
87+
88+
while True:
89+
n_finished = 0
90+
for thread in threads:
91+
if thread.is_alive():
92+
thread.join()
93+
else:
94+
n_finished += 1
95+
96+
if n_finished == len(threads):
97+
break
98+
99+
time.sleep(1)
100+
101+
requests = self.spanner_service.requests
102+
self.assertEqual(n * 2, len(requests), msg=requests)
103+
104+
client_id = self.database._nth_client_id
105+
channel_id = self.database._channel_id
106+
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
107+
108+
want_unary_segments = [
109+
(
110+
"/google.spanner.v1.Spanner/BatchCreateSessions",
111+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1),
112+
),
113+
(
114+
"/google.spanner.v1.Spanner/GetSession",
115+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1),
116+
),
117+
(
118+
"/google.spanner.v1.Spanner/GetSession",
119+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1),
120+
),
121+
(
122+
"/google.spanner.v1.Spanner/GetSession",
123+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1),
124+
),
125+
(
126+
"/google.spanner.v1.Spanner/GetSession",
127+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1),
128+
),
129+
(
130+
"/google.spanner.v1.Spanner/GetSession",
131+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1),
132+
),
133+
(
134+
"/google.spanner.v1.Spanner/GetSession",
135+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1),
136+
),
137+
(
138+
"/google.spanner.v1.Spanner/GetSession",
139+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1),
140+
),
141+
(
142+
"/google.spanner.v1.Spanner/GetSession",
143+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1),
144+
),
145+
(
146+
"/google.spanner.v1.Spanner/GetSession",
147+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1),
148+
),
149+
]
150+
assert got_unary_segments == want_unary_segments
151+
152+
want_stream_segments = [
153+
(
154+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
155+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 2, 1),
156+
),
157+
(
158+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
159+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1),
160+
),
161+
(
162+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
163+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1),
164+
),
165+
(
166+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
167+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1),
168+
),
169+
(
170+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
171+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1),
172+
),
173+
(
174+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
175+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1),
176+
),
177+
(
178+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
179+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1),
180+
),
181+
(
182+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
183+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1),
184+
),
185+
(
186+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
187+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1),
188+
),
189+
(
190+
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
191+
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1),
192+
),
193+
]
194+
assert got_stream_segments == want_stream_segments
195+
66196
def canonicalize_request_id_headers(self):
67197
src = self.database._x_goog_request_id_interceptor
68198
return src._stream_req_segments, src._unary_req_segments

0 commit comments

Comments
 (0)