@@ -576,6 +576,7 @@ def _retry(
576
576
577
577
578
578
def _check_rst_stream_error (exc ):
579
+ print ("\033 [31mrst_" , exc , "\033 [00m" )
579
580
resumable_error = (
580
581
any (
581
582
resumable_message in exc .message
@@ -589,6 +590,11 @@ def _check_rst_stream_error(exc):
589
590
raise
590
591
591
592
593
+ def _check_unavailable (exc ):
594
+ print ("\033 [31mcheck_unavailable" , exc , "\033 [00m" )
595
+ raise
596
+
597
+
592
598
def _metadata_with_leader_aware_routing (value , ** kw ):
593
599
"""Create RPC metadata containing a leader aware routing header
594
600
@@ -763,96 +769,164 @@ def __init__(self, original_callable: Callable):
763
769
764
770
765
771
def inject_retry_header_control (api ):
766
- return
767
- monkey_patch (type (api ))
772
+ # monkey_patch(type(api))
773
+ # monkey_patch(api)
774
+ pass
768
775
769
- memoize_map = dict ()
770
776
771
- def monkey_patch (obj ):
772
- return
777
+ def monkey_patch (typ ):
778
+ keys = dir (typ )
779
+ attempts = dict ()
780
+ for key in keys :
781
+ if key .startswith ("_" ):
782
+ continue
773
783
774
- """
775
- klass = obj
776
- attrs = dir(klass)
777
- for attr_key in attrs:
778
- if attr_key.startswith('_'):
784
+ if key != "batch_create_sessions" :
779
785
continue
780
786
781
- attr_value = getattr(obj, attr_key)
782
- if not callable(attr_value):
787
+ fn = getattr (typ , key )
788
+
789
+ signature = inspect .signature (fn )
790
+ if signature .parameters .get ("metadata" , None ) is None :
783
791
continue
784
792
785
- signature = inspect.signature(attr_value)
786
- print(attr_key, signature.parameters)
793
+ print ("fn.__call__" , inspect .getsource (fn ))
787
794
788
- call = attr_value
789
- # Our goal is to replace the runtime pass through.
790
- def wrapped(*args, **kwargs):
791
- print(attr_key, 'called')
792
- return call(*args, **kwargs)
795
+ def as_proxy (db , * args , ** kwargs ):
796
+ print ("db_key" , hex (id (db )))
797
+ print ("as_proxy" , args , kwargs )
798
+ metadata = kwargs .get ("metadata" , None )
799
+ if not metadata :
800
+ return fn (db , * args , ** kwargs )
793
801
794
- setattr(klass, attr_key, wrapped)
802
+ hash_key = hex (id (db )) + "." + hex (id (key ))
803
+ attempts .setdefault (hash_key , 0 )
804
+ attempts [hash_key ] += 1
805
+ # 4. Find all the headers that match the target header key.
806
+ all_metadata = []
807
+ for mkey , value in metadata :
808
+ if mkey is not REQ_ID_HEADER_KEY :
809
+ continue
795
810
796
- return
797
- """
811
+ splits = value .split ("." )
812
+ # 5. Increment the original_attempt with that of our re-invocation count.
813
+ print ("\033 [34mkey" , mkey , "\033 [00m" , splits )
814
+ hdr_attempt_plus_reinvocation = int (splits [- 1 ]) + attempts [hash_key ]
815
+ splits [- 1 ] = str (hdr_attempt_plus_reinvocation )
816
+ value = "." .join (splits )
798
817
818
+ all_metadata .append ((mkey , value ))
819
+
820
+ kwargs ["metadata" ] = all_metadata
821
+ return fn (db , * args , ** kwargs )
822
+
823
+ setattr (typ , key , as_proxy )
824
+
825
+
826
+ def alt_foo ():
827
+ memoize_map = dict ()
799
828
orig_get_attr = getattr (obj , "__getattribute__" )
829
+ hex_orig = hex (id (orig_get_attr ))
830
+ hex_patched = None
831
+
800
832
def patched_getattribute (obj , key , * args , ** kwargs ):
801
- if key .startswith ('_' ):
833
+ if key .startswith ("_" ):
802
834
return orig_get_attr (obj , key , * args , ** kwargs )
803
835
804
- orig_value = orig_get_attr (obj , key , * args , ** kwargs )
805
- if not callable (orig_value ):
806
- return orig_value
836
+ if key != "batch_create_sessions" :
837
+ return orig_get_attr (obj , key , * args , ** kwargs )
807
838
808
839
map_key = hex (id (key )) + hex (id (obj ))
809
840
memoized = memoize_map .get (map_key , None )
810
841
if memoized :
811
- print ("memoized_hit" , key , '\033 [35m' , inspect .getsource (orig_value ), '\033 [00m' )
842
+ if False :
843
+ print (
844
+ "memoized_hit" ,
845
+ key ,
846
+ "\033 [35m" ,
847
+ inspect .getsource (orig_value ),
848
+ "\033 [00m" ,
849
+ )
850
+ print ("memoized_hit" , key , "\033 [35m" , map_key , "\033 [00m" )
812
851
return memoized
813
852
853
+ orig_value = orig_get_attr (obj , key , * args , ** kwargs )
854
+ if not callable (orig_value ):
855
+ return orig_value
856
+
814
857
signature = inspect .signature (orig_value )
815
- if signature .parameters .get (' metadata' , None ) is None :
858
+ if signature .parameters .get (" metadata" , None ) is None :
816
859
return orig_value
817
860
818
- print (key , '\033 [34m' , map_key , '\033 [00m' , signature , signature .parameters .get ('metadata' , None ))
861
+ if False :
862
+ print (
863
+ key ,
864
+ "\033 [34m" ,
865
+ map_key ,
866
+ "\033 [00m" ,
867
+ signature ,
868
+ signature .parameters .get ("metadata" , None ),
869
+ )
870
+
871
+ if False :
872
+ stack = inspect .stack ()
873
+ ends = stack [- 50 :- 20 ]
874
+ for i , st in enumerate (ends ):
875
+ print (i , st .filename , st .lineno )
876
+
877
+ print (
878
+ "\033 [33mmonkey patching now\033 [00m" ,
879
+ key ,
880
+ "hex_orig" ,
881
+ hex_orig ,
882
+ "hex_patched" ,
883
+ hex_patched ,
884
+ )
819
885
counters = dict (attempt = 0 )
886
+
820
887
def patched_method (* aargs , ** kkwargs ):
821
- counters ['attempt' ] += 1
822
- metadata = kkwargs .get ('metadata' , None )
888
+ counters ["attempt" ] += 1
889
+ print ("counters" , counters )
890
+ metadata = kkwargs .get ("metadata" , None )
823
891
if not metadata :
824
892
return orig_value (* aargs , ** kkwargs )
825
893
826
894
# 4. Find all the headers that match the target header key.
827
895
all_metadata = []
828
896
for mkey , value in metadata :
829
897
if mkey is REQ_ID_HEADER_KEY :
830
- attempt = counters [' attempt' ]
898
+ attempt = counters [" attempt" ]
831
899
if attempt > 1 :
832
900
# 5. Increment the original_attempt with that of our re-invocation count.
833
901
splits = value .split ("." )
834
- print ('\033 [34mkey' , mkey , '\033 [00m' , splits )
835
- hdr_attempt_plus_reinvocation = (
836
- int (splits [- 1 ]) + attempt
837
- )
902
+ print ("\033 [34mkey" , mkey , "\033 [00m" , splits )
903
+ hdr_attempt_plus_reinvocation = int (splits [- 1 ]) + attempt
838
904
splits [- 1 ] = str (hdr_attempt_plus_reinvocation )
839
905
value = "." .join (splits )
840
906
841
907
all_metadata .append ((mkey , value ))
842
908
843
909
kwargs ["metadata" ] = all_metadata
844
- return orig_value (* aargs , ** kkwargs )
910
+
911
+ try :
912
+ return orig_value (* aargs , ** kkwargs )
913
+
914
+ except (InternalServerError , ServiceUnavailable ) as exc :
915
+ print ("caught this exception, incrementing" , exc )
916
+ counters ["attempt" ] += 1
917
+ raise exc
845
918
846
919
memoize_map [map_key ] = patched_method
847
920
return patched_method
848
921
849
- setattr (obj , '__getattribute__' , patched_getattribute )
922
+ hex_patched = hex (id (patched_getattribute ))
923
+ setattr (obj , "__getattribute__" , patched_getattribute )
850
924
851
925
852
926
def foo (api ):
853
927
global patched
854
928
global patched_mu
855
-
929
+
856
930
# For each method, add an _attempt value that'll then be
857
931
# retrieved for each retry.
858
932
# 1. Patch the __getattribute__ method to match items in our manifest.
@@ -878,20 +952,29 @@ def patched_getattribute(obj, key, *args, **kwargs):
878
952
patched_key = hex (id (key )) + hex (id (obj ))
879
953
patched_mu .acquire ()
880
954
already_patched = patched .get (patched_key , None )
881
-
955
+
882
956
other_attempts = dict (attempts = 0 )
957
+
883
958
# 3. Wrap the callable attribute and then capture its metadata keyed argument.
884
959
def wrapped_attr (* args , ** kwargs ):
885
- print ("\033 [31m" , key , "attempt" , other_attempts [' attempts' ], "\033 [00m" )
886
- other_attempts [' attempts' ] += 1
960
+ print ("\033 [31m" , key , "attempt" , other_attempts [" attempts" ], "\033 [00m" )
961
+ other_attempts [" attempts" ] += 1
887
962
888
963
metadata = kwargs .get ("metadata" , [])
889
964
if not metadata :
890
965
# Increment the reinvocation count.
891
966
wrapped_attr ._attempt += 1
892
967
return attr (* args , ** kwargs )
893
968
894
- print ("\033 [35mwrapped_attr" , key , args , kwargs , 'attempt' , wrapped_attr ._attempt , "\033 [00m" )
969
+ print (
970
+ "\033 [35mwrapped_attr" ,
971
+ key ,
972
+ args ,
973
+ kwargs ,
974
+ "attempt" ,
975
+ wrapped_attr ._attempt ,
976
+ "\033 [00m" ,
977
+ )
895
978
896
979
# 4. Find all the headers that match the target header key.
897
980
all_metadata = []
@@ -900,7 +983,7 @@ def wrapped_attr(*args, **kwargs):
900
983
if wrapped_attr ._attempt > 0 :
901
984
# 5. Increment the original_attempt with that of our re-invocation count.
902
985
splits = value .split ("." )
903
- print (' \033 [34mkey' , mkey , ' \033 [00m' , splits )
986
+ print (" \033 [34mkey" , mkey , " \033 [00m" , splits )
904
987
hdr_attempt_plus_reinvocation = (
905
988
int (splits [- 1 ]) + wrapped_attr ._attempt
906
989
)
@@ -916,13 +999,13 @@ def wrapped_attr(*args, **kwargs):
916
999
917
1000
if already_patched :
918
1001
print ("patched_key \033 [32m" , patched_key , key , "\033 [00m" , already_patched )
919
- setattr (attr , ' patched' , True )
1002
+ setattr (attr , " patched" , True )
920
1003
# Increment the reinvocation count.
921
1004
patched_mu .release ()
922
1005
return already_patched
923
1006
924
1007
patched [patched_key ] = wrapped_attr
925
- setattr (wrapped_attr , ' _attempt' , 0 )
1008
+ setattr (wrapped_attr , " _attempt" , 0 )
926
1009
patched_mu .release ()
927
1010
return wrapped_attr
928
1011
0 commit comments