     Dict,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,9 +869,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
 
     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
+        device = torch.device(device)
         device_type = Device(type=device)
     else:
         device_type = Device(device.type)
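Reviewer note: `kernelize` now tracks two values side by side, the concrete `torch.device` (threaded through to `_replace_forward` for state creation) and the coarser `Device` type used for mapping lookups. A minimal usage sketch of the branches above, assuming the library's public `kernelize`/`Mode` API:

```python
import torch.nn as nn

from kernels import Mode, kernelize

model = nn.Sequential(nn.Linear(32, 32))

# String device: validated, normalized to torch.device("cuda"),
# and wrapped in Device(type="cuda") for mapping lookups.
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)

# No device given: both values are inferred from the first model parameter.
model = kernelize(model, mode=Mode.INFERENCE)
```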
@@ -884,7 +889,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name
 
         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +903,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         # Get kernel options for the device
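Every fallback site now passes the resolved device so that a stateful fallback layer can still build its state. The toggle itself is unchanged: a sketch of both outcomes, continuing the example above and assuming `kernelize`'s existing `use_fallback` flag:

```python
from kernels import Mode, kernelize

# Default: layers without a matching mapping keep their original forward.
model = kernelize(model, mode=Mode.INFERENCE)

# Strict: a missing mapping raises instead of silently falling back.
try:
    model = kernelize(model, mode=Mode.INFERENCE, use_fallback=False)
except ValueError as err:
    print(f"kernelize failed: {err}")
```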
@@ -909,7 +914,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repos = property_repos.repos
@@ -919,7 +924,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo_with_mode = _select_repository(
@@ -932,7 +937,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for mode `{mode}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue
 
         repo, repo_mode = repo_with_mode
@@ -951,6 +956,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1043,26 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override the `nn.Module` constructor.")
 
     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
-    # verify: difference ⊆ {"can_torch_compile", "has_backward"}
+    unique_members = _unique_layer_members(cls)
+    # verify: unique_members ⊆ {"can_torch_compile", "create_state", "has_backward", "forward_with_state"}
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )
 
     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters
 
+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`."
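The widened allow-list means a hub layer may now declare the `create_state`/`forward_with_state` pair while still being barred from a custom constructor or any other extra member. A hypothetical stateful layer that this validation is meant to accept (class name, state shape, and kernel body are illustrative only, not from this PR):

```python
import torch
import torch.nn as nn


class SiluAndMulStateful(nn.Module):
    # No __init__ override: validation still rejects constructors, so
    # per-module setup happens in create_state, called once by kernelize.
    def create_state(self, device: torch.device) -> dict:
        # `self` here is the *wrapped module*, since kernelize invokes
        # this as layer.create_state(module, device) on the class.
        return {"workspace": torch.empty(1024, device=device)}

    def forward_with_state(self, state: dict, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]
```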
@@ -1074,15 +1087,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None
 
 
-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> "torch.device":
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )
 
-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
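Splitting the old helper in two keeps both pieces of information: the exact parameter device (index included, e.g. `cuda:1`) for `create_state`, and the refined `Device` type for mapping lookups. A rough illustration of the distinction, calling these private helpers directly:

```python
import torch.nn as nn

model = nn.Linear(8, 8).to("cuda:1")

_find_device(model)       # torch.device("cuda", 1): concrete, index included
_find_device_type(model)  # Device("rocm") on ROCm builds, Device("cuda") otherwise
```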
@@ -1103,6 +1122,7 @@ def _find_capability() -> int:
 
 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1148,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)
 
 
-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(module, device)
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)  # type: ignore[method-assign]
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
 
 
 def _validate_layer_has_mode(
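Note that the stateful branch binds the closure with `MethodType`, exactly like the stateless branch: assigning the bare function to `module.forward` would leave it unbound, so the first call argument would land in `self`. A standalone repro of that pitfall:

```python
from types import MethodType


class Box:
    pass


def forward(self, x):
    return (self, x)


box = Box()

box.forward = forward
# box.forward("a") -> TypeError: "a" is consumed as `self`, `x` is missing.

box.forward = MethodType(forward, box)
box.forward("a")  # (box, "a"): `self` is properly bound to the instance.
```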
@@ -1179,3 +1209,21 @@ def _get_layer_memoize(
     _CACHED_LAYER[repo] = layer
 
     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type["nn.Module"]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must define both `create_state` and `forward_with_state`."
+        )
+    return is_stateful
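`forward_with_state` acts as the opt-in marker, and `create_state` must come along with it. A quick check of the three cases, assuming the helpers above are importable:

```python
import torch.nn as nn


class Plain(nn.Module):
    def forward(self, x):
        return x


class Stateful(nn.Module):
    def create_state(self, device):
        return {}

    def forward_with_state(self, state, x):
        return x


class Broken(nn.Module):
    def forward_with_state(self, state, x):  # create_state is missing
        return x


_is_stateful_layer(Plain)     # False: "forward" exists on nn.Module, so it is not unique
_is_stateful_layer(Stateful)  # True
# _is_stateful_layer(Broken) raises TypeError: both members are required.
```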