diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index e76284864b..7b86a5653f 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -707,6 +707,10 @@ def __radd__(self, n): """ return self.__add__(n) + def reset(self): + with self.__lock: + self.__value = 0 + def _metadata_with_request_id(*args, **kwargs): return with_request_id(*args, **kwargs) diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index c006b965cf..e0e8c44058 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -70,6 +70,7 @@ except ImportError: # pragma: NO COVER HAS_GOOGLE_CLOUD_MONITORING_INSTALLED = False +from google.cloud.spanner_v1._helpers import AtomicCounter _CLIENT_INFO = client_info.ClientInfo(client_library_version=__version__) EMULATOR_ENV_VAR = "SPANNER_EMULATOR_HOST" @@ -182,6 +183,8 @@ class Client(ClientWithProject): SCOPE = (SPANNER_ADMIN_SCOPE,) """The scopes required for Google Cloud Spanner.""" + NTH_CLIENT = AtomicCounter() + def __init__( self, project=None, @@ -263,6 +266,12 @@ def __init__( "default_transaction_options must be an instance of DefaultTransactionOptions" ) self._default_transaction_options = default_transaction_options + self._nth_client_id = Client.NTH_CLIENT.increment() + self._nth_request = AtomicCounter(0) + + @property + def _next_nth_request(self): + return self._nth_request.increment() @property def credentials(self): diff --git a/google/cloud/spanner_v1/request_id_header.py b/google/cloud/spanner_v1/request_id_header.py index 8376778273..74a5bb1253 100644 --- a/google/cloud/spanner_v1/request_id_header.py +++ b/google/cloud/spanner_v1/request_id_header.py @@ -37,6 +37,6 @@ def generate_rand_uint64(): def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]): req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}" - all_metadata = other_metadata.copy() + all_metadata = (other_metadata or []).copy() all_metadata.append((REQ_ID_HEADER_KEY, req_id)) return all_metadata diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py index 54afda11e0..5af89fea42 100644 --- a/google/cloud/spanner_v1/testing/database_test.py +++ b/google/cloud/spanner_v1/testing/database_test.py @@ -25,6 +25,7 @@ from google.cloud.spanner_v1.testing.interceptors import ( MethodCountInterceptor, MethodAbortInterceptor, + XGoogRequestIDHeaderInterceptor, ) @@ -34,6 +35,8 @@ class TestDatabase(Database): currently, and we don't want to make changes in the Database class for testing purpose as this is a hack to use interceptors in tests.""" + _interceptors = [] + def __init__( self, database_id, @@ -74,6 +77,8 @@ def spanner_api(self): client_options = client._client_options if self._instance.emulator_host is not None: channel = grpc.insecure_channel(self._instance.emulator_host) + self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor() + self._interceptors.append(self._x_goog_request_id_interceptor) channel = grpc.intercept_channel(channel, *self._interceptors) transport = SpannerGrpcTransport(channel=channel) self._spanner_api = SpannerClient( @@ -110,3 +115,7 @@ def _create_spanner_client_for_tests(self, client_options, credentials): client_options=client_options, transport=transport, ) + + def reset(self): + if self._x_goog_request_id_interceptor: + self._x_goog_request_id_interceptor.reset() diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py index a8b015a87d..bf5e271e26 100644 --- a/google/cloud/spanner_v1/testing/interceptors.py +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -13,6 +13,8 @@ # limitations under the License. from collections import defaultdict +import threading + from grpc_interceptor import ClientInterceptor from google.api_core.exceptions import Aborted @@ -63,3 +65,72 @@ def reset(self): self._method_to_abort = None self._count = 0 self._connection = None + + +X_GOOG_REQUEST_ID = "x-goog-spanner-request-id" + + +class XGoogRequestIDHeaderInterceptor(ClientInterceptor): + # TODO:(@odeke-em): delete this guard when PR #1367 is merged. + X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED = False + + def __init__(self): + self._unary_req_segments = [] + self._stream_req_segments = [] + self.__lock = threading.Lock() + + def intercept(self, method, request_or_iterator, call_details): + metadata = call_details.metadata + x_goog_request_id = None + for key, value in metadata: + if key == X_GOOG_REQUEST_ID: + x_goog_request_id = value + break + + if self.X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED and not x_goog_request_id: + raise Exception( + f"Missing {X_GOOG_REQUEST_ID} header in {call_details.method}" + ) + + response_or_iterator = method(request_or_iterator, call_details) + streaming = getattr(response_or_iterator, "__iter__", None) is not None + + if self.X_GOOG_REQUEST_ID_FUNCTIONALITY_MERGED: + with self.__lock: + if streaming: + self._stream_req_segments.append( + (call_details.method, parse_request_id(x_goog_request_id)) + ) + else: + self._unary_req_segments.append( + (call_details.method, parse_request_id(x_goog_request_id)) + ) + + return response_or_iterator + + @property + def unary_request_ids(self): + return self._unary_req_segments + + @property + def stream_request_ids(self): + return self._stream_req_segments + + def reset(self): + self._stream_req_segments.clear() + self._unary_req_segments.clear() + + +def parse_request_id(request_id_str): + splits = request_id_str.split(".") + version, rand_process_id, client_id, channel_id, nth_request, nth_attempt = list( + map(lambda v: int(v), splits) + ) + return ( + version, + rand_process_id, + client_id, + channel_id, + nth_request, + nth_attempt, + ) diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index f60dbbe72a..f8971a6098 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -22,8 +22,6 @@ from google.cloud.spanner_v1 import ( TransactionOptions, ResultSetMetadata, - ExecuteSqlRequest, - ExecuteBatchDmlRequest, ) from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc @@ -107,6 +105,7 @@ def CreateSession(self, request, context): def BatchCreateSessions(self, request, context): self._requests.append(request) + self.mock_spanner.pop_error(context) sessions = [] for i in range(request.session_count): sessions.append( @@ -186,9 +185,7 @@ def BeginTransaction(self, request, context): self._requests.append(request) return self.__create_transaction(request.session, request.options) - def __maybe_create_transaction( - self, request: ExecuteSqlRequest | ExecuteBatchDmlRequest - ): + def __maybe_create_transaction(self, request): started_transaction = None if not request.transaction.begin == TransactionOptions(): started_transaction = self.__create_transaction( diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index b332c88d7c..7b4538d601 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -153,6 +153,7 @@ def setup_class(cls): def teardown_class(cls): if MockServerTestBase.server is not None: MockServerTestBase.server.stop(grace=None) + Client.NTH_CLIENT.reset() MockServerTestBase.server = None def setup_method(self, *args, **kwargs): @@ -186,6 +187,8 @@ def instance(self) -> Instance: def database(self) -> Database: if self._database is None: self._database = self.instance.database( - "test-database", pool=FixedSizePool(size=10) + "test-database", + pool=FixedSizePool(size=10), + enable_interceptors_in_tests=True, ) return self._database diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index ddc91ea522..ff4743f1f6 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -21,6 +21,10 @@ from google.cloud.spanner_v1 import TypeCode from google.api_core.retry import Retry from google.api_core import gapic_v1 +from google.cloud.spanner_v1._helpers import ( + AtomicCounter, + _metadata_with_request_id, +) from tests._helpers import ( HAS_OPENTELEMETRY_INSTALLED, @@ -197,6 +201,11 @@ def test_begin_ok(self): [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + # TODO(@odeke-em): enable with PR #1367. + # ( + # "x-goog-spanner-request-id", + # f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + # ), ], ) @@ -301,6 +310,11 @@ def test_rollback_ok(self): [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + # TODO(@odeke-em): enable with PR #1367. + # ( + # "x-goog-spanner-request-id", + # f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + # ), ], ) @@ -492,6 +506,11 @@ def _commit_helper( [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + # TODO(@odeke-em): enable with PR #1367. + # ( + # "x-goog-spanner-request-id", + # f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + # ), ], ) self.assertEqual(actual_request_options, expected_request_options) @@ -666,6 +685,11 @@ def _execute_update_helper( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + # TODO(@odeke-em): enable with PR #1367. + # ( + # "x-goog-spanner-request-id", + # f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + # ), ], ) @@ -859,6 +883,11 @@ def _batch_update_helper( metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + # TODO(@odeke-em): enable with PR #1367. + # ( + # "x-goog-spanner-request-id", + # f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.1.1", + # ), ], retry=retry, timeout=timeout, @@ -974,6 +1003,11 @@ def test_context_mgr_success(self): [ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), + # TODO(@odeke-em): enable with PR #1367. + # ( + # "x-goog-spanner-request-id", + # f"1.{REQ_RAND_PROCESS_ID}.{_Client.NTH_CLIENT.value}.1.2.1", + # ), ], ) @@ -1004,11 +1038,19 @@ def test_context_mgr_failure(self): class _Client(object): + NTH_CLIENT = AtomicCounter() + def __init__(self): from google.cloud.spanner_v1 import ExecuteSqlRequest self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1") self.directed_read_options = None + self._nth_client_id = _Client.NTH_CLIENT.increment() + self._nth_request = AtomicCounter() + + @property + def _next_nth_request(self): + return self._nth_request.increment() class _Instance(object): @@ -1024,6 +1066,27 @@ def __init__(self): self._directed_read_options = None self.default_transaction_options = DefaultTransactionOptions() + @property + def _next_nth_request(self): + return self._instance._client._next_nth_request + + @property + def _nth_client_id(self): + return self._instance._client._nth_client_id + + def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]): + return _metadata_with_request_id( + self._nth_client_id, + self._channel_id, + nth_request, + nth_attempt, + prior_metadata, + ) + + @property + def _channel_id(self): + return 1 + class _Session(object): _transaction = None