51
51
from google .cloud .spanner_v1 import SpannerClient
52
52
from google .cloud .spanner_v1 ._helpers import _merge_query_options
53
53
from google .cloud .spanner_v1 ._helpers import (
54
+ AtomicCounter ,
54
55
_metadata_with_prefix ,
55
56
_metadata_with_leader_aware_routing ,
57
+ _metadata_with_request_id ,
56
58
)
57
59
from google .cloud .spanner_v1 .batch import Batch
58
60
from google .cloud .spanner_v1 .batch import MutationGroups
@@ -151,6 +153,9 @@ class Database(object):
151
153
152
154
_spanner_api : SpannerClient = None
153
155
156
+ __transport_lock = threading .Lock ()
157
+ __transports_to_channel_id = dict ()
158
+
154
159
def __init__ (
155
160
self ,
156
161
database_id ,
@@ -188,6 +193,7 @@ def __init__(
188
193
self ._instance ._client .default_transaction_options
189
194
)
190
195
self ._proto_descriptors = proto_descriptors
196
+ self ._channel_id = 0 # It'll be created when _spanner_api is created.
191
197
192
198
if pool is None :
193
199
pool = BurstyPool (database_role = database_role )
@@ -446,8 +452,26 @@ def spanner_api(self):
446
452
client_info = client_info ,
447
453
client_options = client_options ,
448
454
)
455
+
456
+ with self .__transport_lock :
457
+ transport = self ._spanner_api ._transport
458
+ channel_id = self .__transports_to_channel_id .get (transport , None )
459
+ if channel_id is None :
460
+ channel_id = len (self .__transports_to_channel_id ) + 1
461
+ self .__transports_to_channel_id [transport ] = channel_id
462
+ self ._channel_id = channel_id
463
+
449
464
return self ._spanner_api
450
465
466
+ def metadata_with_request_id (self , nth_request , nth_attempt , prior_metadata = []):
467
+ return _metadata_with_request_id (
468
+ self ._nth_client_id ,
469
+ self ._channel_id ,
470
+ nth_request ,
471
+ nth_attempt ,
472
+ prior_metadata ,
473
+ )
474
+
451
475
def __eq__ (self , other ):
452
476
if not isinstance (other , self .__class__ ):
453
477
return NotImplemented
@@ -490,7 +514,10 @@ def create(self):
490
514
database_dialect = self ._database_dialect ,
491
515
proto_descriptors = self ._proto_descriptors ,
492
516
)
493
- future = api .create_database (request = request , metadata = metadata )
517
+ future = api .create_database (
518
+ request = request ,
519
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
520
+ )
494
521
return future
495
522
496
523
def exists (self ):
@@ -506,7 +533,12 @@ def exists(self):
506
533
metadata = _metadata_with_prefix (self .name )
507
534
508
535
try :
509
- api .get_database_ddl (database = self .name , metadata = metadata )
536
+ api .get_database_ddl (
537
+ database = self .name ,
538
+ metadata = self .metadata_with_request_id (
539
+ self ._next_nth_request , 1 , metadata
540
+ ),
541
+ )
510
542
except NotFound :
511
543
return False
512
544
return True
@@ -523,10 +555,16 @@ def reload(self):
523
555
"""
524
556
api = self ._instance ._client .database_admin_api
525
557
metadata = _metadata_with_prefix (self .name )
526
- response = api .get_database_ddl (database = self .name , metadata = metadata )
558
+ response = api .get_database_ddl (
559
+ database = self .name ,
560
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
561
+ )
527
562
self ._ddl_statements = tuple (response .statements )
528
563
self ._proto_descriptors = response .proto_descriptors
529
- response = api .get_database (name = self .name , metadata = metadata )
564
+ response = api .get_database (
565
+ name = self .name ,
566
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
567
+ )
530
568
self ._state = DatabasePB .State (response .state )
531
569
self ._create_time = response .create_time
532
570
self ._restore_info = response .restore_info
@@ -571,7 +609,10 @@ def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None):
571
609
proto_descriptors = proto_descriptors ,
572
610
)
573
611
574
- future = api .update_database_ddl (request = request , metadata = metadata )
612
+ future = api .update_database_ddl (
613
+ request = request ,
614
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
615
+ )
575
616
return future
576
617
577
618
def update (self , fields ):
@@ -609,7 +650,9 @@ def update(self, fields):
609
650
metadata = _metadata_with_prefix (self .name )
610
651
611
652
future = api .update_database (
612
- database = database_pb , update_mask = field_mask , metadata = metadata
653
+ database = database_pb ,
654
+ update_mask = field_mask ,
655
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
613
656
)
614
657
615
658
return future
@@ -622,7 +665,10 @@ def drop(self):
622
665
"""
623
666
api = self ._instance ._client .database_admin_api
624
667
metadata = _metadata_with_prefix (self .name )
625
- api .drop_database (database = self .name , metadata = metadata )
668
+ api .drop_database (
669
+ database = self .name ,
670
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
671
+ )
626
672
627
673
def execute_partitioned_dml (
628
674
self ,
@@ -711,7 +757,13 @@ def execute_pdml():
711
757
with SessionCheckout (self ._pool ) as session :
712
758
add_span_event (span , "Starting BeginTransaction" )
713
759
txn = api .begin_transaction (
714
- session = session .name , options = txn_options , metadata = metadata
760
+ session = session .name ,
761
+ options = txn_options ,
762
+ metadata = self .metadata_with_request_id (
763
+ self ._next_nth_request ,
764
+ 1 ,
765
+ metadata ,
766
+ ),
715
767
)
716
768
717
769
txn_selector = TransactionSelector (id = txn .id )
@@ -724,6 +776,7 @@ def execute_pdml():
724
776
query_options = query_options ,
725
777
request_options = request_options ,
726
778
)
779
+
727
780
method = functools .partial (
728
781
api .execute_streaming_sql ,
729
782
metadata = metadata ,
@@ -736,6 +789,7 @@ def execute_pdml():
736
789
metadata = metadata ,
737
790
transaction_selector = txn_selector ,
738
791
observability_options = self .observability_options ,
792
+ request_id_manager = self ,
739
793
)
740
794
741
795
result_set = StreamedResultSet (iterator )
@@ -745,6 +799,17 @@ def execute_pdml():
745
799
746
800
return _retry_on_aborted (execute_pdml , DEFAULT_RETRY_BACKOFF )()
747
801
802
+ @property
803
+ def _next_nth_request (self ):
804
+ if self ._instance and self ._instance ._client :
805
+ return self ._instance ._client ._next_nth_request
806
+ raise Exception ("returning 1 for next_nth_request" )
807
+ return 1
808
+
809
+ @property
810
+ def _nth_client_id (self ):
811
+ return self ._instance ._client ._nth_client_id
812
+
748
813
def session (self , labels = None , database_role = None ):
749
814
"""Factory to create a session for this database.
750
815
@@ -965,7 +1030,8 @@ def restore(self, source):
965
1030
)
966
1031
future = api .restore_database (
967
1032
request = request ,
968
- metadata = metadata ,
1033
+ # TODO: Infer the channel_id being used.
1034
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
969
1035
)
970
1036
return future
971
1037
@@ -1034,7 +1100,10 @@ def list_database_roles(self, page_size=None):
1034
1100
parent = self .name ,
1035
1101
page_size = page_size ,
1036
1102
)
1037
- return api .list_database_roles (request = request , metadata = metadata )
1103
+ return api .list_database_roles (
1104
+ request = request ,
1105
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
1106
+ )
1038
1107
1039
1108
def table (self , table_id ):
1040
1109
"""Factory to create a table object within this database.
@@ -1118,7 +1187,10 @@ def get_iam_policy(self, policy_version=None):
1118
1187
requested_policy_version = policy_version
1119
1188
),
1120
1189
)
1121
- response = api .get_iam_policy (request = request , metadata = metadata )
1190
+ response = api .get_iam_policy (
1191
+ request = request ,
1192
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
1193
+ )
1122
1194
return response
1123
1195
1124
1196
def set_iam_policy (self , policy ):
@@ -1140,7 +1212,10 @@ def set_iam_policy(self, policy):
1140
1212
resource = self .name ,
1141
1213
policy = policy ,
1142
1214
)
1143
- response = api .set_iam_policy (request = request , metadata = metadata )
1215
+ response = api .set_iam_policy (
1216
+ request = request ,
1217
+ metadata = self .metadata_with_request_id (self ._next_nth_request , 1 , metadata ),
1218
+ )
1144
1219
return response
1145
1220
1146
1221
@property
0 commit comments