Skip to content

Commit b8b1591

Browse files
committed
chore(x-goog-spanner-request-id): plug in functionality after test scaffolding
This change chops down the load of the large changes for x-goog-spanner-request-id. It depends on PR googleapis#1366 and should only be merged after that PR. Updates googleapis#1261 Requires PR googleapis#1366
1 parent fd4ee67 commit b8b1591

File tree

7 files changed

+380
-96
lines changed

7 files changed

+380
-96
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@
5151
from google.cloud.spanner_v1 import SpannerClient
5252
from google.cloud.spanner_v1._helpers import _merge_query_options
5353
from google.cloud.spanner_v1._helpers import (
54+
AtomicCounter,
5455
_metadata_with_prefix,
5556
_metadata_with_leader_aware_routing,
57+
_metadata_with_request_id,
5658
)
5759
from google.cloud.spanner_v1.batch import Batch
5860
from google.cloud.spanner_v1.batch import MutationGroups
@@ -151,6 +153,9 @@ class Database(object):
151153

152154
_spanner_api: SpannerClient = None
153155

156+
__transport_lock = threading.Lock()
157+
__transports_to_channel_id = dict()
158+
154159
def __init__(
155160
self,
156161
database_id,
@@ -188,6 +193,7 @@ def __init__(
188193
self._instance._client.default_transaction_options
189194
)
190195
self._proto_descriptors = proto_descriptors
196+
self._channel_id = 0 # It'll be created when _spanner_api is created.
191197

192198
if pool is None:
193199
pool = BurstyPool(database_role=database_role)
@@ -446,8 +452,26 @@ def spanner_api(self):
446452
client_info=client_info,
447453
client_options=client_options,
448454
)
455+
456+
with self.__transport_lock:
457+
transport = self._spanner_api._transport
458+
channel_id = self.__transports_to_channel_id.get(transport, None)
459+
if channel_id is None:
460+
channel_id = len(self.__transports_to_channel_id) + 1
461+
self.__transports_to_channel_id[transport] = channel_id
462+
self._channel_id = channel_id
463+
449464
return self._spanner_api
450465

466+
def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
467+
return _metadata_with_request_id(
468+
self._nth_client_id,
469+
self._channel_id,
470+
nth_request,
471+
nth_attempt,
472+
prior_metadata,
473+
)
474+
451475
def __eq__(self, other):
452476
if not isinstance(other, self.__class__):
453477
return NotImplemented
@@ -490,7 +514,10 @@ def create(self):
490514
database_dialect=self._database_dialect,
491515
proto_descriptors=self._proto_descriptors,
492516
)
493-
future = api.create_database(request=request, metadata=metadata)
517+
future = api.create_database(
518+
request=request,
519+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
520+
)
494521
return future
495522

496523
def exists(self):
@@ -506,7 +533,12 @@ def exists(self):
506533
metadata = _metadata_with_prefix(self.name)
507534

508535
try:
509-
api.get_database_ddl(database=self.name, metadata=metadata)
536+
api.get_database_ddl(
537+
database=self.name,
538+
metadata=self.metadata_with_request_id(
539+
self._next_nth_request, 1, metadata
540+
),
541+
)
510542
except NotFound:
511543
return False
512544
return True
@@ -523,10 +555,16 @@ def reload(self):
523555
"""
524556
api = self._instance._client.database_admin_api
525557
metadata = _metadata_with_prefix(self.name)
526-
response = api.get_database_ddl(database=self.name, metadata=metadata)
558+
response = api.get_database_ddl(
559+
database=self.name,
560+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
561+
)
527562
self._ddl_statements = tuple(response.statements)
528563
self._proto_descriptors = response.proto_descriptors
529-
response = api.get_database(name=self.name, metadata=metadata)
564+
response = api.get_database(
565+
name=self.name,
566+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
567+
)
530568
self._state = DatabasePB.State(response.state)
531569
self._create_time = response.create_time
532570
self._restore_info = response.restore_info
@@ -571,7 +609,10 @@ def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None):
571609
proto_descriptors=proto_descriptors,
572610
)
573611

574-
future = api.update_database_ddl(request=request, metadata=metadata)
612+
future = api.update_database_ddl(
613+
request=request,
614+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
615+
)
575616
return future
576617

577618
def update(self, fields):
@@ -609,7 +650,9 @@ def update(self, fields):
609650
metadata = _metadata_with_prefix(self.name)
610651

611652
future = api.update_database(
612-
database=database_pb, update_mask=field_mask, metadata=metadata
653+
database=database_pb,
654+
update_mask=field_mask,
655+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
613656
)
614657

615658
return future
@@ -622,7 +665,10 @@ def drop(self):
622665
"""
623666
api = self._instance._client.database_admin_api
624667
metadata = _metadata_with_prefix(self.name)
625-
api.drop_database(database=self.name, metadata=metadata)
668+
api.drop_database(
669+
database=self.name,
670+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
671+
)
626672

627673
def execute_partitioned_dml(
628674
self,
@@ -711,7 +757,13 @@ def execute_pdml():
711757
with SessionCheckout(self._pool) as session:
712758
add_span_event(span, "Starting BeginTransaction")
713759
txn = api.begin_transaction(
714-
session=session.name, options=txn_options, metadata=metadata
760+
session=session.name,
761+
options=txn_options,
762+
metadata=self.metadata_with_request_id(
763+
self._next_nth_request,
764+
1,
765+
metadata,
766+
),
715767
)
716768

717769
txn_selector = TransactionSelector(id=txn.id)
@@ -724,6 +776,7 @@ def execute_pdml():
724776
query_options=query_options,
725777
request_options=request_options,
726778
)
779+
727780
method = functools.partial(
728781
api.execute_streaming_sql,
729782
metadata=metadata,
@@ -736,6 +789,7 @@ def execute_pdml():
736789
metadata=metadata,
737790
transaction_selector=txn_selector,
738791
observability_options=self.observability_options,
792+
request_id_manager=self,
739793
)
740794

741795
result_set = StreamedResultSet(iterator)
@@ -745,6 +799,16 @@ def execute_pdml():
745799

746800
return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()
747801

802+
@property
803+
def _next_nth_request(self):
804+
if self._instance and self._instance._client:
805+
return self._instance._client._next_nth_request
806+
return 1
807+
808+
@property
809+
def _nth_client_id(self):
810+
return self._instance._client._nth_client_id
811+
748812
def session(self, labels=None, database_role=None):
749813
"""Factory to create a session for this database.
750814
@@ -965,7 +1029,8 @@ def restore(self, source):
9651029
)
9661030
future = api.restore_database(
9671031
request=request,
968-
metadata=metadata,
1032+
# TODO: Infer the channel_id being used.
1033+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
9691034
)
9701035
return future
9711036

@@ -1034,7 +1099,10 @@ def list_database_roles(self, page_size=None):
10341099
parent=self.name,
10351100
page_size=page_size,
10361101
)
1037-
return api.list_database_roles(request=request, metadata=metadata)
1102+
return api.list_database_roles(
1103+
request=request,
1104+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
1105+
)
10381106

10391107
def table(self, table_id):
10401108
"""Factory to create a table object within this database.
@@ -1118,7 +1186,10 @@ def get_iam_policy(self, policy_version=None):
11181186
requested_policy_version=policy_version
11191187
),
11201188
)
1121-
response = api.get_iam_policy(request=request, metadata=metadata)
1189+
response = api.get_iam_policy(
1190+
request=request,
1191+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
1192+
)
11221193
return response
11231194

11241195
def set_iam_policy(self, policy):
@@ -1140,7 +1211,10 @@ def set_iam_policy(self, policy):
11401211
resource=self.name,
11411212
policy=policy,
11421213
)
1143-
response = api.set_iam_policy(request=request, metadata=metadata)
1214+
response = api.set_iam_policy(
1215+
request=request,
1216+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
1217+
)
11441218
return response
11451219

11461220
@property

google/cloud/spanner_v1/pool.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ def bind(self, database):
256256
)
257257
resp = api.batch_create_sessions(
258258
request=request,
259-
metadata=metadata,
259+
metadata=database.metadata_with_request_id(
260+
database._next_nth_request, 1, metadata
261+
),
260262
)
261263

262264
add_span_event(
@@ -561,7 +563,9 @@ def bind(self, database):
561563
while returned_session_count < self.size:
562564
resp = api.batch_create_sessions(
563565
request=request,
564-
metadata=metadata,
566+
metadata=database.metadata_with_request_id(
567+
database._next_nth_request, 1, metadata
568+
),
565569
)
566570

567571
add_span_event(

google/cloud/spanner_v1/session.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def create(self):
170170
), MetricsCapture():
171171
session_pb = api.create_session(
172172
request=request,
173-
metadata=metadata,
173+
metadata=self._database.metadata_with_request_id(
174+
self._database._next_nth_request, 1, metadata
175+
),
174176
)
175177
self._session_id = session_pb.name.split("/")[-1]
176178

@@ -195,7 +197,8 @@ def exists(self):
195197
current_span, "Checking if Session exists", {"session.id": self._session_id}
196198
)
197199

198-
api = self._database.spanner_api
200+
database = self._database
201+
api = database.spanner_api
199202
metadata = _metadata_with_prefix(self._database.name)
200203
if self._database._route_to_leader_enabled:
201204
metadata.append(
@@ -212,7 +215,12 @@ def exists(self):
212215
metadata=metadata,
213216
) as span, MetricsCapture():
214217
try:
215-
api.get_session(name=self.name, metadata=metadata)
218+
api.get_session(
219+
name=self.name,
220+
metadata=database.metadata_with_request_id(
221+
database._next_nth_request, 1, metadata
222+
),
223+
)
216224
if span:
217225
span.set_attribute("session_found", True)
218226
except NotFound:
@@ -242,8 +250,11 @@ def delete(self):
242250
current_span, "Deleting Session", {"session.id": self._session_id}
243251
)
244252

245-
api = self._database.spanner_api
246-
metadata = _metadata_with_prefix(self._database.name)
253+
database = self._database
254+
api = database.spanner_api
255+
metadata = database.metadata_with_request_id(
256+
database._next_nth_request, 1, _metadata_with_prefix(database.name)
257+
)
247258
observability_options = getattr(self._database, "observability_options", None)
248259
with trace_call(
249260
"CloudSpanner.DeleteSession",
@@ -255,7 +266,10 @@ def delete(self):
255266
observability_options=observability_options,
256267
metadata=metadata,
257268
), MetricsCapture():
258-
api.delete_session(name=self.name, metadata=metadata)
269+
api.delete_session(
270+
name=self.name,
271+
metadata=metadata,
272+
)
259273

260274
def ping(self):
261275
"""Ping the session to keep it alive by executing "SELECT 1".
@@ -264,10 +278,17 @@ def ping(self):
264278
"""
265279
if self._session_id is None:
266280
raise ValueError("Session ID not set by back-end")
267-
api = self._database.spanner_api
268-
metadata = _metadata_with_prefix(self._database.name)
281+
database = self._database
282+
api = database.spanner_api
269283
request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
270-
api.execute_sql(request=request, metadata=metadata)
284+
api.execute_sql(
285+
request=request,
286+
metadata=database.metadata_with_request_id(
287+
database._next_nth_request,
288+
1,
289+
_metadata_with_prefix(database.name),
290+
),
291+
)
271292
self._last_use_time = datetime.now()
272293

273294
def snapshot(self, **kw):

0 commit comments

Comments
 (0)