@@ -7,6 +7,7 @@
 import sys
 import warnings
 from abc import ABC, abstractmethod
+from collections import OrderedDict
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass
@@ -19,6 +20,7 @@
     Dict,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,10 +870,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")

     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
         device_type = Device(type=device)
+        device = torch.device(device)
     else:
         device_type = Device(device.type)

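For orientation, a standalone sketch (not this library's code) of the two branches above: with no `device` argument the concrete torch.device is inferred from the model's parameters, and a string argument is normalized to a torch.device so that stateful layers later receive a real device object. `resolve_device` is a hypothetical helper for illustration only:

import torch

def resolve_device(model: torch.nn.Module, device=None) -> torch.device:
    if device is None:
        # Infer the concrete device from the first parameter,
        # mirroring _find_device above.
        device = next(model.parameters()).device
    elif isinstance(device, str):
        # Normalize "cuda"/"cpu"/... into a torch.device,
        # mirroring the new `device = torch.device(device)` line.
        device = torch.device(device)
    return device

model = torch.nn.Linear(4, 4)
print(resolve_device(model))         # cpu, inferred from the parameters
print(resolve_device(model, "cpu"))  # cpu, normalized from the string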
@@ -884,7 +890,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name

         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +904,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         # Get kernel options for the device
@@ -909,7 +915,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repos = property_repos.repos
@@ -919,7 +925,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo_with_mode = _select_repository(
@@ -932,7 +938,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for mode `{mode}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo, repo_mode = repo_with_mode
@@ -951,6 +957,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )

         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1044,30 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override the nn.Module constructor.")

     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
     # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )

     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters

+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+        # Get rid of the mappingproxy.
+        params = OrderedDict(params)
+        # Remove the state to be able to compare with forward.
+        del params["state"]
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match that of `{check_cls.__name__}`"
         )
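To make the relaxed validation concrete, here is a hypothetical stateful layer shaped to pass the checks above: its only members beyond nn.Module are in the allowed set, and once the `state` parameter is dropped, `forward_with_state` has the same signature as the `forward` it replaces (the class name and kernel are invented for illustration):

import torch
from torch import nn

class SiluAndMulStateful(nn.Module):
    has_backward = False
    can_torch_compile = True

    @staticmethod
    def create_state(device: torch.device, module: nn.Module) -> dict:
        # One-time, per-module setup; anything expensive or device-bound
        # (buffers, plans, handles, ...) can be prepared here.
        return {"device": device}

    def forward_with_state(self, state: dict, x: torch.Tensor) -> torch.Tensor:
        # Minus `state`, this matches forward(self, x) of the replaced layer.
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]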
@@ -1074,15 +1092,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None


-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )

-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
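The split separates two concerns: `_find_device` now returns the concrete torch.device of the first parameter (index included, e.g. cuda:1), which `create_state` needs, while `_find_device_type` reduces it to the coarse `Device` kind used for mapping lookups, refining "cuda" into ROCm where appropriate. A standalone approximation of the refinement (returning a plain string instead of the library's `Device`):

import torch

def find_device_type(model: torch.nn.Module) -> str:
    device = next(model.parameters()).device  # concrete device, e.g. cuda:1
    if device.type == "cuda" and torch.version.hip is not None:
        return "rocm"  # ROCm builds of torch also report device type "cuda"
    return device.type

print(find_device_type(torch.nn.Linear(2, 2)))  # cpu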
@@ -1103,6 +1127,7 @@ def _find_capability() -> int:

 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1153,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)


-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(device, module)
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]


 def _validate_layer_has_mode(
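The key design point in the new `_replace_forward`: for a stateful layer, `state` is created once per wrapped module and captured in a closure, so every later call through `module.forward` reuses it without repeating the setup. A standalone sketch of that binding, with a hypothetical `Scale` layer:

from types import MethodType

import torch
from torch import nn

class Scale(nn.Module):  # hypothetical stateful replacement layer
    @staticmethod
    def create_state(device, module):
        return torch.tensor(2.0, device=device)  # built exactly once

    def forward_with_state(self, state, x):
        return x * state

target = nn.Identity()
state = Scale.create_state(torch.device("cpu"), target)

def forward(self, *args, **kwargs):
    # The closure carries `state`; `self` is the wrapped module.
    return Scale.forward_with_state(self, state, *args, **kwargs)

target.forward = MethodType(forward, target)
print(target(torch.ones(2)))  # tensor([2., 2.])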
@@ -1179,3 +1214,21 @@ def _get_layer_memoize(
     _CACHED_LAYER[repo] = layer

     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must define both `create_state` and `forward_with_state`."
+        )
+    return is_stateful