Skip to content

Commit e67cc9e

Browse files
committed
Experiment with wrapping for gapic retries
1 parent c31c03f commit e67cc9e

File tree

5 files changed

+216
-65
lines changed

5 files changed

+216
-65
lines changed

google/cloud/spanner_v1/_helpers.py

Lines changed: 129 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ def _retry(
576576

577577

578578
def _check_rst_stream_error(exc):
579+
print("\033[31mrst_", exc, "\033[00m")
579580
resumable_error = (
580581
any(
581582
resumable_message in exc.message
@@ -589,6 +590,11 @@ def _check_rst_stream_error(exc):
589590
raise
590591

591592

593+
def _check_unavailable(exc):
594+
print("\033[31mcheck_unavailable", exc, "\033[00m")
595+
raise
596+
597+
592598
def _metadata_with_leader_aware_routing(value, **kw):
593599
"""Create RPC metadata containing a leader aware routing header
594600
@@ -763,96 +769,164 @@ def __init__(self, original_callable: Callable):
763769

764770

765771
def inject_retry_header_control(api):
766-
return
767-
monkey_patch(type(api))
772+
# monkey_patch(type(api))
773+
# monkey_patch(api)
774+
pass
768775

769-
memoize_map = dict()
770776

771-
def monkey_patch(obj):
772-
return
777+
def monkey_patch(typ):
778+
keys = dir(typ)
779+
attempts = dict()
780+
for key in keys:
781+
if key.startswith("_"):
782+
continue
773783

774-
"""
775-
klass = obj
776-
attrs = dir(klass)
777-
for attr_key in attrs:
778-
if attr_key.startswith('_'):
784+
if key != "batch_create_sessions":
779785
continue
780786

781-
attr_value = getattr(obj, attr_key)
782-
if not callable(attr_value):
787+
fn = getattr(typ, key)
788+
789+
signature = inspect.signature(fn)
790+
if signature.parameters.get("metadata", None) is None:
783791
continue
784792

785-
signature = inspect.signature(attr_value)
786-
print(attr_key, signature.parameters)
793+
print("fn.__call__", inspect.getsource(fn))
787794

788-
call = attr_value
789-
# Our goal is to replace the runtime pass through.
790-
def wrapped(*args, **kwargs):
791-
print(attr_key, 'called')
792-
return call(*args, **kwargs)
795+
def as_proxy(db, *args, **kwargs):
796+
print("db_key", hex(id(db)))
797+
print("as_proxy", args, kwargs)
798+
metadata = kwargs.get("metadata", None)
799+
if not metadata:
800+
return fn(db, *args, **kwargs)
793801

794-
setattr(klass, attr_key, wrapped)
802+
hash_key = hex(id(db)) + "." + hex(id(key))
803+
attempts.setdefault(hash_key, 0)
804+
attempts[hash_key] += 1
805+
# 4. Find all the headers that match the target header key.
806+
all_metadata = []
807+
for mkey, value in metadata:
808+
if mkey is not REQ_ID_HEADER_KEY:
809+
continue
795810

796-
return
797-
"""
811+
splits = value.split(".")
812+
# 5. Increment the original_attempt with that of our re-invocation count.
813+
print("\033[34mkey", mkey, "\033[00m", splits)
814+
hdr_attempt_plus_reinvocation = int(splits[-1]) + attempts[hash_key]
815+
splits[-1] = str(hdr_attempt_plus_reinvocation)
816+
value = ".".join(splits)
798817

818+
all_metadata.append((mkey, value))
819+
820+
kwargs["metadata"] = all_metadata
821+
return fn(db, *args, **kwargs)
822+
823+
setattr(typ, key, as_proxy)
824+
825+
826+
def alt_foo():
827+
memoize_map = dict()
799828
orig_get_attr = getattr(obj, "__getattribute__")
829+
hex_orig = hex(id(orig_get_attr))
830+
hex_patched = None
831+
800832
def patched_getattribute(obj, key, *args, **kwargs):
801-
if key.startswith('_'):
833+
if key.startswith("_"):
802834
return orig_get_attr(obj, key, *args, **kwargs)
803835

804-
orig_value = orig_get_attr(obj, key, *args, **kwargs)
805-
if not callable(orig_value):
806-
return orig_value
836+
if key != "batch_create_sessions":
837+
return orig_get_attr(obj, key, *args, **kwargs)
807838

808839
map_key = hex(id(key)) + hex(id(obj))
809840
memoized = memoize_map.get(map_key, None)
810841
if memoized:
811-
print("memoized_hit", key, '\033[35m', inspect.getsource(orig_value), '\033[00m')
842+
if False:
843+
print(
844+
"memoized_hit",
845+
key,
846+
"\033[35m",
847+
inspect.getsource(orig_value),
848+
"\033[00m",
849+
)
850+
print("memoized_hit", key, "\033[35m", map_key, "\033[00m")
812851
return memoized
813852

853+
orig_value = orig_get_attr(obj, key, *args, **kwargs)
854+
if not callable(orig_value):
855+
return orig_value
856+
814857
signature = inspect.signature(orig_value)
815-
if signature.parameters.get('metadata', None) is None:
858+
if signature.parameters.get("metadata", None) is None:
816859
return orig_value
817860

818-
print(key, '\033[34m', map_key, '\033[00m', signature, signature.parameters.get('metadata', None))
861+
if False:
862+
print(
863+
key,
864+
"\033[34m",
865+
map_key,
866+
"\033[00m",
867+
signature,
868+
signature.parameters.get("metadata", None),
869+
)
870+
871+
if False:
872+
stack = inspect.stack()
873+
ends = stack[-50:-20]
874+
for i, st in enumerate(ends):
875+
print(i, st.filename, st.lineno)
876+
877+
print(
878+
"\033[33mmonkey patching now\033[00m",
879+
key,
880+
"hex_orig",
881+
hex_orig,
882+
"hex_patched",
883+
hex_patched,
884+
)
819885
counters = dict(attempt=0)
886+
820887
def patched_method(*aargs, **kkwargs):
821-
counters['attempt'] += 1
822-
metadata = kkwargs.get('metadata', None)
888+
counters["attempt"] += 1
889+
print("counters", counters)
890+
metadata = kkwargs.get("metadata", None)
823891
if not metadata:
824892
return orig_value(*aargs, **kkwargs)
825893

826894
# 4. Find all the headers that match the target header key.
827895
all_metadata = []
828896
for mkey, value in metadata:
829897
if mkey is REQ_ID_HEADER_KEY:
830-
attempt = counters['attempt']
898+
attempt = counters["attempt"]
831899
if attempt > 1:
832900
# 5. Increment the original_attempt with that of our re-invocation count.
833901
splits = value.split(".")
834-
print('\033[34mkey', mkey, '\033[00m', splits)
835-
hdr_attempt_plus_reinvocation = (
836-
int(splits[-1]) + attempt
837-
)
902+
print("\033[34mkey", mkey, "\033[00m", splits)
903+
hdr_attempt_plus_reinvocation = int(splits[-1]) + attempt
838904
splits[-1] = str(hdr_attempt_plus_reinvocation)
839905
value = ".".join(splits)
840906

841907
all_metadata.append((mkey, value))
842908

843909
kwargs["metadata"] = all_metadata
844-
return orig_value(*aargs, **kkwargs)
910+
911+
try:
912+
return orig_value(*aargs, **kkwargs)
913+
914+
except (InternalServerError, ServiceUnavailable) as exc:
915+
print("caught this exception, incrementing", exc)
916+
counters["attempt"] += 1
917+
raise exc
845918

846919
memoize_map[map_key] = patched_method
847920
return patched_method
848921

849-
setattr(obj, '__getattribute__', patched_getattribute)
922+
hex_patched = hex(id(patched_getattribute))
923+
setattr(obj, "__getattribute__", patched_getattribute)
850924

851925

852926
def foo(api):
853927
global patched
854928
global patched_mu
855-
929+
856930
# For each method, add an _attempt value that'll then be
857931
# retrieved for each retry.
858932
# 1. Patch the __getattribute__ method to match items in our manifest.
@@ -878,20 +952,29 @@ def patched_getattribute(obj, key, *args, **kwargs):
878952
patched_key = hex(id(key)) + hex(id(obj))
879953
patched_mu.acquire()
880954
already_patched = patched.get(patched_key, None)
881-
955+
882956
other_attempts = dict(attempts=0)
957+
883958
# 3. Wrap the callable attribute and then capture its metadata keyed argument.
884959
def wrapped_attr(*args, **kwargs):
885-
print("\033[31m", key, "attempt", other_attempts['attempts'], "\033[00m")
886-
other_attempts['attempts'] += 1
960+
print("\033[31m", key, "attempt", other_attempts["attempts"], "\033[00m")
961+
other_attempts["attempts"] += 1
887962

888963
metadata = kwargs.get("metadata", [])
889964
if not metadata:
890965
# Increment the reinvocation count.
891966
wrapped_attr._attempt += 1
892967
return attr(*args, **kwargs)
893968

894-
print("\033[35mwrapped_attr", key, args, kwargs, 'attempt', wrapped_attr._attempt, "\033[00m")
969+
print(
970+
"\033[35mwrapped_attr",
971+
key,
972+
args,
973+
kwargs,
974+
"attempt",
975+
wrapped_attr._attempt,
976+
"\033[00m",
977+
)
895978

896979
# 4. Find all the headers that match the target header key.
897980
all_metadata = []
@@ -900,7 +983,7 @@ def wrapped_attr(*args, **kwargs):
900983
if wrapped_attr._attempt > 0:
901984
# 5. Increment the original_attempt with that of our re-invocation count.
902985
splits = value.split(".")
903-
print('\033[34mkey', mkey, '\033[00m', splits)
986+
print("\033[34mkey", mkey, "\033[00m", splits)
904987
hdr_attempt_plus_reinvocation = (
905988
int(splits[-1]) + wrapped_attr._attempt
906989
)
@@ -916,13 +999,13 @@ def wrapped_attr(*args, **kwargs):
916999

9171000
if already_patched:
9181001
print("patched_key \033[32m", patched_key, key, "\033[00m", already_patched)
919-
setattr(attr, 'patched', True)
1002+
setattr(attr, "patched", True)
9201003
# Increment the reinvocation count.
9211004
patched_mu.release()
9221005
return already_patched
9231006

9241007
patched[patched_key] = wrapped_attr
925-
setattr(wrapped_attr, '_attempt', 0)
1008+
setattr(wrapped_attr, "_attempt", 0)
9261009
patched_mu.release()
9271010
return wrapped_attr
9281011

google/cloud/spanner_v1/batch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def commit(
250250
observability_options=observability_options,
251251
metadata=metadata,
252252
), MetricsCapture():
253+
253254
def wrapped_method(*args, **kwargs):
254255
method = functools.partial(
255256
api.commit,

google/cloud/spanner_v1/database.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -432,15 +432,6 @@ def logger(self):
432432

433433
@property
434434
def spanner_api(self):
435-
"""Helper for session-related API calls."""
436-
api = self.__generate_spanner_api()
437-
if not api:
438-
return api
439-
440-
monkey_patch(api)
441-
return api
442-
443-
def __generate_spanner_api(self):
444435
"""Helper for session-related API calls."""
445436
if self._spanner_api is None:
446437
client_info = self._instance._client._client_info

0 commit comments

Comments
 (0)