1717from  typing  import  (
1818    TYPE_CHECKING ,
1919    Dict ,
20+     Mapping ,
2021    Optional ,
2122    Protocol ,
23+     Set ,
2224    Tuple ,
2325    Type ,
2426    Union ,
@@ -868,10 +870,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
868870        raise  ValueError ("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING." )
869871
870872    if  device  is  None :
871-         device_type  =  _find_device (model )
873+         device  =  _find_device (model )
874+         device_type  =  _find_device_type (model )
872875    elif  isinstance (device , str ):
873876        _validate_device_type (device )
877+         import  torch 
878+ 
874879        device_type  =  Device (type = device )
880+         device  =  torch .device (device )
875881    else :
876882        device_type  =  Device (device .type )
877883
@@ -884,7 +890,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
884890        layer_name  =  module_class .kernel_layer_name 
885891
886892        if  _DISABLE_KERNEL_MAPPING :
887-             _replace_forward (module , module_class )
893+             _replace_forward (device ,  module , module_class )
888894            continue 
889895
890896        kernel  =  _KERNEL_MAPPING .get ().get (str (layer_name ))
@@ -898,7 +904,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
898904            )
899905            if  not  use_fallback :
900906                raise  ValueError (f"No layer mapping for `{ layer_name }  )
901-             _replace_forward (module , module_class )
907+             _replace_forward (device ,  module , module_class )
902908            continue 
903909
904910        # Get kernel options for the device 
@@ -909,7 +915,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
909915                raise  ValueError (
910916                    f"No layer mapping for `{ layer_name } { device_type }  
911917                )
912-             _replace_forward (module , module_class )
918+             _replace_forward (device ,  module , module_class )
913919            continue 
914920
915921        repos  =  property_repos .repos 
@@ -919,7 +925,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
919925                raise  ValueError (
920926                    f"No layer mapping for `{ layer_name } { device_type }  
921927                )
922-             _replace_forward (module , module_class )
928+             _replace_forward (device ,  module , module_class )
923929            continue 
924930
925931        repo_with_mode  =  _select_repository (
@@ -932,7 +938,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
932938                raise  ValueError (
933939                    f"No repository for `{ layer_name } { mode }  
934940                )
935-             _replace_forward (module , module_class )
941+             _replace_forward (device ,  module , module_class )
936942            continue 
937943
938944        repo , repo_mode  =  repo_with_mode 
@@ -951,6 +957,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
951957        )
952958
953959        _conditionally_replace_forward (
960+             device = device ,
954961            module = module ,
955962            layer = layer ,
956963            mode = mode ,
@@ -1037,19 +1044,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
10371044        raise  TypeError (f"{ repo }  )
10381045
10391046    # ... or predefined member variables. 
1040-     torch_module_members  =  {name  for  name , _  in  inspect .getmembers (nn .Module )}
1041-     cls_members  =  {name  for  name , _  in  inspect .getmembers (cls )}
1042-     difference  =  cls_members  -  torch_module_members 
1047+     unique_members  =  _unique_layer_members (cls )
10431048    # verify if : difference ⊄ {"can_torch_compile", "has_backward"} 
1044-     if  not  difference  <=  {"can_torch_compile" , "has_backward" }:
1049+     if  not  unique_members  <=  {
1050+         "can_torch_compile" ,
1051+         "create_state" ,
1052+         "has_backward" ,
1053+         "forward_with_state" ,
1054+     }:
10451055        raise  TypeError (
10461056            f"{ repo } { check_cls .__name__ }  
10471057        )
10481058
10491059    # Check whether the forward signatures are similar. 
1050-     params  =  inspect .signature (cls .forward ).parameters 
10511060    ref_params  =  inspect .signature (check_cls .forward ).parameters 
10521061
1062+     params : Mapping [str , inspect .Parameter ]
1063+     if  _is_stateful_layer (cls ):
1064+         params  =  inspect .signature (cls .forward_with_state ).parameters 
1065+         # Get rid of the mappingproxy. 
1066+         params  =  params .copy ()
1067+         # Remove the state to be able to compare with forward. 
1068+         del  params ["state" ]
1069+     else :
1070+         params  =  inspect .signature (cls .forward ).parameters 
1071+ 
10531072    if  len (params ) !=  len (ref_params ):
10541073        raise  TypeError (
10551074            f"Forward signature of { repo } { check_cls .__name__ }  
@@ -1074,15 +1093,21 @@ def _is_rocm_platform():
10741093    return  torch .version .hip  is  not None 
10751094
10761095
1077- def  _find_device (model : "nn.Module" ) ->  Device :
1096+ def  _find_device (model : "nn.Module" ) ->  torch . device :
10781097    try :
10791098        param  =  next (model .parameters ())
10801099    except  StopIteration :
10811100        raise  ValueError (
10821101            "Cannot determine model device, provide as `device` argument to `kernelize`." 
10831102        )
10841103
1085-     dev_type  =  param .device .type 
1104+     return  param .device 
1105+ 
1106+ 
1107+ def  _find_device_type (model : "nn.Module" ) ->  Device :
1108+     device  =  _find_device (model )
1109+ 
1110+     dev_type  =  device .type 
10861111    if  dev_type  ==  "cuda" :
10871112        # Refine based on actual platform 
10881113        if  _is_rocm_platform ():
@@ -1103,6 +1128,7 @@ def _find_capability() -> int:
11031128
11041129def  _conditionally_replace_forward (
11051130    * ,
1131+     device : "torch.device" ,
11061132    module : "nn.Module" ,
11071133    layer : Type ["nn.Module" ],
11081134    mode : Mode ,
@@ -1128,15 +1154,25 @@ def _conditionally_replace_forward(
11281154                logging .info ("Layer does not support torch.compile, using fallback" )
11291155            if  needs_fallback_for_backward :
11301156                logging .info ("Layer does not support backward, using fallback" )
1131-             _replace_forward (module , module_class )
1157+             _replace_forward (device ,  module , module_class )
11321158        else :
11331159            raise  ValueError (f"Available kernel does not support mode: { mode }  )
11341160    else :
1135-         _replace_forward (module , layer )
1161+         _replace_forward (device ,  module , layer )
11361162
11371163
def _replace_forward(
    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
):
    """Rebind `module.forward` to the forward of `layer`.

    Stateless layers are bound directly. Stateful layers (those defining
    `forward_with_state`) first build their state via `layer.create_state`,
    then get a closure that threads that state into every call.
    """
    if not _is_stateful_layer(layer):
        # Plain case: borrow the kernel layer's forward as a bound method.
        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
        return

    # Stateful case: state is created once, captured by the wrapper below.
    state = layer.create_state(device, module)  # type: ignore[attr-defined]

    def stateful_forward(self, *args, **kwargs):
        return layer.forward_with_state(self, state, *args, **kwargs)

    module.forward = MethodType(stateful_forward, module)  # type: ignore[method-assign]
11401176
11411177
11421178def  _validate_layer_has_mode (
@@ -1179,3 +1215,21 @@ def _get_layer_memoize(
11791215    _CACHED_LAYER [repo ] =  layer 
11801216
11811217    return  layer 
1218+ 
1219+ 
1220+ def  _unique_layer_members (layer : Type ["nn.Module" ]) ->  Set [str ]:
1221+     import  torch .nn  as  nn 
1222+ 
1223+     torch_module_members  =  {name  for  name , _  in  inspect .getmembers (nn .Module )}
1224+     cls_members  =  {name  for  name , _  in  inspect .getmembers (layer )}
1225+     return  cls_members  -  torch_module_members 
1226+ 
1227+ 
def _is_stateful_layer(layer: Type["nn.Module"]) -> bool:
    """Return whether *layer* is a stateful kernel layer.

    A layer is stateful when it defines `forward_with_state`; such a layer
    must also define `create_state` so `_replace_forward` can build the
    state object.

    Raises:
        TypeError: if `forward_with_state` is present without `create_state`.
    """
    # Fix: the annotation was `Type[nn.Module]` (unquoted) while every sibling
    # signature in this file uses the string form `Type["nn.Module"]`; `nn`
    # is presumably only imported under TYPE_CHECKING here, so the unquoted
    # form would raise NameError at import time.
    unique = _unique_layer_members(layer)
    is_stateful = "forward_with_state" in unique
    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
        # NOTE(review): original message was truncated in this view;
        # reconstructed wording — confirm against upstream.
        raise TypeError(
            f"Stateful layer `{layer.__name__}` must define both "
            "`create_state` and `forward_with_state`."
        )
    return is_stateful
0 commit comments