Skip to content

Commit 3a91671

Browse files
odeke-emolavloite
andauthored
chore(x-goog-spanner-request-id): plug in functionality after test scaffolding (#1367)
* chore(x-goog-spanner-request-id): plug in functionality after test scaffolding This change chops down the load of the large changes for x-goog-spanner-request-id. It depends on PR #1366 and should only be merged after that PR. Updates #1261 Requires PR #1366 * Include batch* * Address review feedback * chore: fix formatting --------- Co-authored-by: Knut Olav Løite <[email protected]>
1 parent fd4ee67 commit 3a91671

14 files changed

+1423
-293
lines changed

google/cloud/spanner_v1/batch.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -249,17 +249,28 @@ def commit(
249249
observability_options=observability_options,
250250
metadata=metadata,
251251
), MetricsCapture():
252-
method = functools.partial(
253-
api.commit,
254-
request=request,
255-
metadata=metadata,
256-
)
252+
253+
def wrapped_method(*args, **kwargs):
254+
method = functools.partial(
255+
api.commit,
256+
request=request,
257+
metadata=database.metadata_with_request_id(
258+
# This code is retried due to ABORTED, hence nth_request
259+
# should be increased. attempt can only be increased if
260+
# we encounter UNAVAILABLE or INTERNAL.
261+
getattr(database, "_next_nth_request", 0),
262+
1,
263+
metadata,
264+
),
265+
)
266+
return method(*args, **kwargs)
267+
257268
deadline = time.time() + kwargs.get(
258269
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
259270
)
260271
default_retry_delay = kwargs.get("default_retry_delay", None)
261272
response = _retry_on_aborted_exception(
262-
method,
273+
wrapped_method,
263274
deadline=deadline,
264275
default_retry_delay=default_retry_delay,
265276
)

google/cloud/spanner_v1/database.py

+86-12
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from google.cloud.spanner_v1._helpers import (
5454
_metadata_with_prefix,
5555
_metadata_with_leader_aware_routing,
56+
_metadata_with_request_id,
5657
)
5758
from google.cloud.spanner_v1.batch import Batch
5859
from google.cloud.spanner_v1.batch import MutationGroups
@@ -151,6 +152,9 @@ class Database(object):
151152

152153
_spanner_api: SpannerClient = None
153154

155+
__transport_lock = threading.Lock()
156+
__transports_to_channel_id = dict()
157+
154158
def __init__(
155159
self,
156160
database_id,
@@ -188,6 +192,7 @@ def __init__(
188192
self._instance._client.default_transaction_options
189193
)
190194
self._proto_descriptors = proto_descriptors
195+
self._channel_id = 0 # It'll be created when _spanner_api is created.
191196

192197
if pool is None:
193198
pool = BurstyPool(database_role=database_role)
@@ -446,8 +451,26 @@ def spanner_api(self):
446451
client_info=client_info,
447452
client_options=client_options,
448453
)
454+
455+
with self.__transport_lock:
456+
transport = self._spanner_api._transport
457+
channel_id = self.__transports_to_channel_id.get(transport, None)
458+
if channel_id is None:
459+
channel_id = len(self.__transports_to_channel_id) + 1
460+
self.__transports_to_channel_id[transport] = channel_id
461+
self._channel_id = channel_id
462+
449463
return self._spanner_api
450464

465+
def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
466+
return _metadata_with_request_id(
467+
self._nth_client_id,
468+
self._channel_id,
469+
nth_request,
470+
nth_attempt,
471+
prior_metadata,
472+
)
473+
451474
def __eq__(self, other):
452475
if not isinstance(other, self.__class__):
453476
return NotImplemented
@@ -490,7 +513,10 @@ def create(self):
490513
database_dialect=self._database_dialect,
491514
proto_descriptors=self._proto_descriptors,
492515
)
493-
future = api.create_database(request=request, metadata=metadata)
516+
future = api.create_database(
517+
request=request,
518+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
519+
)
494520
return future
495521

496522
def exists(self):
@@ -506,7 +532,12 @@ def exists(self):
506532
metadata = _metadata_with_prefix(self.name)
507533

508534
try:
509-
api.get_database_ddl(database=self.name, metadata=metadata)
535+
api.get_database_ddl(
536+
database=self.name,
537+
metadata=self.metadata_with_request_id(
538+
self._next_nth_request, 1, metadata
539+
),
540+
)
510541
except NotFound:
511542
return False
512543
return True
@@ -523,10 +554,16 @@ def reload(self):
523554
"""
524555
api = self._instance._client.database_admin_api
525556
metadata = _metadata_with_prefix(self.name)
526-
response = api.get_database_ddl(database=self.name, metadata=metadata)
557+
response = api.get_database_ddl(
558+
database=self.name,
559+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
560+
)
527561
self._ddl_statements = tuple(response.statements)
528562
self._proto_descriptors = response.proto_descriptors
529-
response = api.get_database(name=self.name, metadata=metadata)
563+
response = api.get_database(
564+
name=self.name,
565+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
566+
)
530567
self._state = DatabasePB.State(response.state)
531568
self._create_time = response.create_time
532569
self._restore_info = response.restore_info
@@ -571,7 +608,10 @@ def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None):
571608
proto_descriptors=proto_descriptors,
572609
)
573610

574-
future = api.update_database_ddl(request=request, metadata=metadata)
611+
future = api.update_database_ddl(
612+
request=request,
613+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
614+
)
575615
return future
576616

577617
def update(self, fields):
@@ -609,7 +649,9 @@ def update(self, fields):
609649
metadata = _metadata_with_prefix(self.name)
610650

611651
future = api.update_database(
612-
database=database_pb, update_mask=field_mask, metadata=metadata
652+
database=database_pb,
653+
update_mask=field_mask,
654+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
613655
)
614656

615657
return future
@@ -622,7 +664,10 @@ def drop(self):
622664
"""
623665
api = self._instance._client.database_admin_api
624666
metadata = _metadata_with_prefix(self.name)
625-
api.drop_database(database=self.name, metadata=metadata)
667+
api.drop_database(
668+
database=self.name,
669+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
670+
)
626671

627672
def execute_partitioned_dml(
628673
self,
@@ -711,7 +756,13 @@ def execute_pdml():
711756
with SessionCheckout(self._pool) as session:
712757
add_span_event(span, "Starting BeginTransaction")
713758
txn = api.begin_transaction(
714-
session=session.name, options=txn_options, metadata=metadata
759+
session=session.name,
760+
options=txn_options,
761+
metadata=self.metadata_with_request_id(
762+
self._next_nth_request,
763+
1,
764+
metadata,
765+
),
715766
)
716767

717768
txn_selector = TransactionSelector(id=txn.id)
@@ -724,6 +775,7 @@ def execute_pdml():
724775
query_options=query_options,
725776
request_options=request_options,
726777
)
778+
727779
method = functools.partial(
728780
api.execute_streaming_sql,
729781
metadata=metadata,
@@ -736,6 +788,7 @@ def execute_pdml():
736788
metadata=metadata,
737789
transaction_selector=txn_selector,
738790
observability_options=self.observability_options,
791+
request_id_manager=self,
739792
)
740793

741794
result_set = StreamedResultSet(iterator)
@@ -745,6 +798,18 @@ def execute_pdml():
745798

746799
return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()
747800

801+
@property
802+
def _next_nth_request(self):
803+
if self._instance and self._instance._client:
804+
return self._instance._client._next_nth_request
805+
return 1
806+
807+
@property
808+
def _nth_client_id(self):
809+
if self._instance and self._instance._client:
810+
return self._instance._client._nth_client_id
811+
return 0
812+
748813
def session(self, labels=None, database_role=None):
749814
"""Factory to create a session for this database.
750815
@@ -965,7 +1030,7 @@ def restore(self, source):
9651030
)
9661031
future = api.restore_database(
9671032
request=request,
968-
metadata=metadata,
1033+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
9691034
)
9701035
return future
9711036

@@ -1034,7 +1099,10 @@ def list_database_roles(self, page_size=None):
10341099
parent=self.name,
10351100
page_size=page_size,
10361101
)
1037-
return api.list_database_roles(request=request, metadata=metadata)
1102+
return api.list_database_roles(
1103+
request=request,
1104+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
1105+
)
10381106

10391107
def table(self, table_id):
10401108
"""Factory to create a table object within this database.
@@ -1118,7 +1186,10 @@ def get_iam_policy(self, policy_version=None):
11181186
requested_policy_version=policy_version
11191187
),
11201188
)
1121-
response = api.get_iam_policy(request=request, metadata=metadata)
1189+
response = api.get_iam_policy(
1190+
request=request,
1191+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
1192+
)
11221193
return response
11231194

11241195
def set_iam_policy(self, policy):
@@ -1140,7 +1211,10 @@ def set_iam_policy(self, policy):
11401211
resource=self.name,
11411212
policy=policy,
11421213
)
1143-
response = api.set_iam_policy(request=request, metadata=metadata)
1214+
response = api.set_iam_policy(
1215+
request=request,
1216+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
1217+
)
11441218
return response
11451219

11461220
@property

google/cloud/spanner_v1/pool.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ def bind(self, database):
256256
)
257257
resp = api.batch_create_sessions(
258258
request=request,
259-
metadata=metadata,
259+
metadata=database.metadata_with_request_id(
260+
database._next_nth_request, 1, metadata
261+
),
260262
)
261263

262264
add_span_event(
@@ -561,7 +563,9 @@ def bind(self, database):
561563
while returned_session_count < self.size:
562564
resp = api.batch_create_sessions(
563565
request=request,
564-
metadata=metadata,
566+
metadata=database.metadata_with_request_id(
567+
database._next_nth_request, 1, metadata
568+
),
565569
)
566570

567571
add_span_event(

google/cloud/spanner_v1/session.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def create(self):
170170
), MetricsCapture():
171171
session_pb = api.create_session(
172172
request=request,
173-
metadata=metadata,
173+
metadata=self._database.metadata_with_request_id(
174+
self._database._next_nth_request, 1, metadata
175+
),
174176
)
175177
self._session_id = session_pb.name.split("/")[-1]
176178

@@ -195,7 +197,8 @@ def exists(self):
195197
current_span, "Checking if Session exists", {"session.id": self._session_id}
196198
)
197199

198-
api = self._database.spanner_api
200+
database = self._database
201+
api = database.spanner_api
199202
metadata = _metadata_with_prefix(self._database.name)
200203
if self._database._route_to_leader_enabled:
201204
metadata.append(
@@ -212,7 +215,12 @@ def exists(self):
212215
metadata=metadata,
213216
) as span, MetricsCapture():
214217
try:
215-
api.get_session(name=self.name, metadata=metadata)
218+
api.get_session(
219+
name=self.name,
220+
metadata=database.metadata_with_request_id(
221+
database._next_nth_request, 1, metadata
222+
),
223+
)
216224
if span:
217225
span.set_attribute("session_found", True)
218226
except NotFound:
@@ -242,8 +250,9 @@ def delete(self):
242250
current_span, "Deleting Session", {"session.id": self._session_id}
243251
)
244252

245-
api = self._database.spanner_api
246-
metadata = _metadata_with_prefix(self._database.name)
253+
database = self._database
254+
api = database.spanner_api
255+
metadata = _metadata_with_prefix(database.name)
247256
observability_options = getattr(self._database, "observability_options", None)
248257
with trace_call(
249258
"CloudSpanner.DeleteSession",
@@ -255,7 +264,12 @@ def delete(self):
255264
observability_options=observability_options,
256265
metadata=metadata,
257266
), MetricsCapture():
258-
api.delete_session(name=self.name, metadata=metadata)
267+
api.delete_session(
268+
name=self.name,
269+
metadata=database.metadata_with_request_id(
270+
database._next_nth_request, 1, metadata
271+
),
272+
)
259273

260274
def ping(self):
261275
"""Ping the session to keep it alive by executing "SELECT 1".
@@ -264,10 +278,17 @@ def ping(self):
264278
"""
265279
if self._session_id is None:
266280
raise ValueError("Session ID not set by back-end")
267-
api = self._database.spanner_api
268-
metadata = _metadata_with_prefix(self._database.name)
281+
database = self._database
282+
api = database.spanner_api
269283
request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
270-
api.execute_sql(request=request, metadata=metadata)
284+
api.execute_sql(
285+
request=request,
286+
metadata=database.metadata_with_request_id(
287+
database._next_nth_request,
288+
1,
289+
_metadata_with_prefix(database.name),
290+
),
291+
)
271292
self._last_use_time = datetime.now()
272293

273294
def snapshot(self, **kw):

0 commit comments

Comments
 (0)