Skip to content

Commit 866eecc

Browse files
committed
chore(x-goog-spanner-request-id): plug in functionality after test scaffolding
This change chops down the load of the large changes for x-goog-spanner-request-id. It depends on PR googleapis#1366 and should only be merged after that PR. Updates googleapis#1261 Requires PR googleapis#1366
1 parent e064474 commit 866eecc

File tree

8 files changed

+2088
-675
lines changed

8 files changed

+2088
-675
lines changed

google/cloud/spanner_v1/client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
except ImportError: # pragma: NO COVER
7171
HAS_GOOGLE_CLOUD_MONITORING_INSTALLED = False
7272

73+
from google.cloud.spanner_v1._helpers import AtomicCounter
7374

7475
_CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__)
7576
EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST"
@@ -182,6 +183,8 @@ class Client(ClientWithProject):
182183
SCOPE = (SPANNER_ADMIN_SCOPE,)
183184
"""The scopes required for Google Cloud Spanner."""
184185

186+
NTH_CLIENT = AtomicCounter()
187+
185188
def __init__(
186189
self,
187190
project=None,
@@ -261,6 +264,12 @@ def __init__(
261264
"default_transaction_options must be an instance of DefaultTransactionOptions"
262265
)
263266
self._default_transaction_options = default_transaction_options
267+
self._nth_client_id = Client.NTH_CLIENT.increment()
268+
self._nth_request = AtomicCounter(0)
269+
270+
@property
271+
def _next_nth_request(self):
272+
return self._nth_request.increment()
264273

265274
@property
266275
def credentials(self):

google/cloud/spanner_v1/database.py

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@
5151
from google.cloud.spanner_v1 import SpannerClient
5252
from google.cloud.spanner_v1._helpers import _merge_query_options
5353
from google.cloud.spanner_v1._helpers import (
54+
AtomicCounter,
5455
_metadata_with_prefix,
5556
_metadata_with_leader_aware_routing,
57+
_metadata_with_request_id,
5658
)
5759
from google.cloud.spanner_v1.batch import Batch
5860
from google.cloud.spanner_v1.batch import MutationGroups
@@ -151,6 +153,9 @@ class Database(object):
151153

152154
_spanner_api: SpannerClient = None
153155

156+
__transport_lock = threading.Lock()
157+
__transports_to_channel_id = dict()
158+
154159
def __init__(
155160
self,
156161
database_id,
@@ -188,6 +193,7 @@ def __init__(
188193
self._instance._client.default_transaction_options
189194
)
190195
self._proto_descriptors = proto_descriptors
196+
self._channel_id = 0 # It'll be created when _spanner_api is created.
191197

192198
if pool is None:
193199
pool = BurstyPool(database_role=database_role)
@@ -446,8 +452,26 @@ def spanner_api(self):
446452
client_info=client_info,
447453
client_options=client_options,
448454
)
455+
456+
with self.__transport_lock:
457+
transport = self._spanner_api._transport
458+
channel_id = self.__transports_to_channel_id.get(transport, None)
459+
if channel_id is None:
460+
channel_id = len(self.__transports_to_channel_id) + 1
461+
self.__transports_to_channel_id[transport] = channel_id
462+
self._channel_id = channel_id
463+
449464
return self._spanner_api
450465

466+
def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
467+
return _metadata_with_request_id(
468+
self._nth_client_id,
469+
self._channel_id,
470+
nth_request,
471+
nth_attempt,
472+
prior_metadata,
473+
)
474+
451475
def __eq__(self, other):
452476
if not isinstance(other, self.__class__):
453477
return NotImplemented
@@ -490,7 +514,10 @@ def create(self):
490514
database_dialect=self._database_dialect,
491515
proto_descriptors=self._proto_descriptors,
492516
)
493-
future = api.create_database(request=request, metadata=metadata)
517+
future = api.create_database(
518+
request=request,
519+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
520+
)
494521
return future
495522

496523
def exists(self):
@@ -506,7 +533,12 @@ def exists(self):
506533
metadata = _metadata_with_prefix(self.name)
507534

508535
try:
509-
api.get_database_ddl(database=self.name, metadata=metadata)
536+
api.get_database_ddl(
537+
database=self.name,
538+
metadata=self.metadata_with_request_id(
539+
self._next_nth_request, 1, metadata
540+
),
541+
)
510542
except NotFound:
511543
return False
512544
return True
@@ -523,10 +555,16 @@ def reload(self):
523555
"""
524556
api = self._instance._client.database_admin_api
525557
metadata = _metadata_with_prefix(self.name)
526-
response = api.get_database_ddl(database=self.name, metadata=metadata)
558+
response = api.get_database_ddl(
559+
database=self.name,
560+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
561+
)
527562
self._ddl_statements = tuple(response.statements)
528563
self._proto_descriptors = response.proto_descriptors
529-
response = api.get_database(name=self.name, metadata=metadata)
564+
response = api.get_database(
565+
name=self.name,
566+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
567+
)
530568
self._state = DatabasePB.State(response.state)
531569
self._create_time = response.create_time
532570
self._restore_info = response.restore_info
@@ -571,7 +609,10 @@ def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None):
571609
proto_descriptors=proto_descriptors,
572610
)
573611

574-
future = api.update_database_ddl(request=request, metadata=metadata)
612+
future = api.update_database_ddl(
613+
request=request,
614+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
615+
)
575616
return future
576617

577618
def update(self, fields):
@@ -609,7 +650,9 @@ def update(self, fields):
609650
metadata = _metadata_with_prefix(self.name)
610651

611652
future = api.update_database(
612-
database=database_pb, update_mask=field_mask, metadata=metadata
653+
database=database_pb,
654+
update_mask=field_mask,
655+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
613656
)
614657

615658
return future
@@ -622,7 +665,10 @@ def drop(self):
622665
"""
623666
api = self._instance._client.database_admin_api
624667
metadata = _metadata_with_prefix(self.name)
625-
api.drop_database(database=self.name, metadata=metadata)
668+
api.drop_database(
669+
database=self.name,
670+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
671+
)
626672

627673
def execute_partitioned_dml(
628674
self,
@@ -711,7 +757,13 @@ def execute_pdml():
711757
with SessionCheckout(self._pool) as session:
712758
add_span_event(span, "Starting BeginTransaction")
713759
txn = api.begin_transaction(
714-
session=session.name, options=txn_options, metadata=metadata
760+
session=session.name,
761+
options=txn_options,
762+
metadata=self.metadata_with_request_id(
763+
self._next_nth_request,
764+
1,
765+
metadata,
766+
),
715767
)
716768

717769
txn_selector = TransactionSelector(id=txn.id)
@@ -724,6 +776,7 @@ def execute_pdml():
724776
query_options=query_options,
725777
request_options=request_options,
726778
)
779+
727780
method = functools.partial(
728781
api.execute_streaming_sql,
729782
metadata=metadata,
@@ -736,6 +789,7 @@ def execute_pdml():
736789
metadata=metadata,
737790
transaction_selector=txn_selector,
738791
observability_options=self.observability_options,
792+
request_id_manager=self,
739793
)
740794

741795
result_set = StreamedResultSet(iterator)
@@ -745,6 +799,17 @@ def execute_pdml():
745799

746800
return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()
747801

802+
@property
803+
def _next_nth_request(self):
804+
if self._instance and self._instance._client:
805+
return self._instance._client._next_nth_request
806+
raise Exception("returning 1 for next_nth_request")
807+
return 1
808+
809+
@property
810+
def _nth_client_id(self):
811+
return self._instance._client._nth_client_id
812+
748813
def session(self, labels=None, database_role=None):
749814
"""Factory to create a session for this database.
750815
@@ -965,7 +1030,8 @@ def restore(self, source):
9651030
)
9661031
future = api.restore_database(
9671032
request=request,
968-
metadata=metadata,
1033+
# TODO: Infer the channel_id being used.
1034+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
9691035
)
9701036
return future
9711037

@@ -1034,7 +1100,10 @@ def list_database_roles(self, page_size=None):
10341100
parent=self.name,
10351101
page_size=page_size,
10361102
)
1037-
return api.list_database_roles(request=request, metadata=metadata)
1103+
return api.list_database_roles(
1104+
request=request,
1105+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
1106+
)
10381107

10391108
def table(self, table_id):
10401109
"""Factory to create a table object within this database.
@@ -1118,7 +1187,10 @@ def get_iam_policy(self, policy_version=None):
11181187
requested_policy_version=policy_version
11191188
),
11201189
)
1121-
response = api.get_iam_policy(request=request, metadata=metadata)
1190+
response = api.get_iam_policy(
1191+
request=request,
1192+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
1193+
)
11221194
return response
11231195

11241196
def set_iam_policy(self, policy):
@@ -1140,7 +1212,10 @@ def set_iam_policy(self, policy):
11401212
resource=self.name,
11411213
policy=policy,
11421214
)
1143-
response = api.set_iam_policy(request=request, metadata=metadata)
1215+
response = api.set_iam_policy(
1216+
request=request,
1217+
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
1218+
)
11441219
return response
11451220

11461221
@property

0 commit comments

Comments
 (0)