Skip to content

Commit a988871

Browse files
committed
Add support for stateful layers
1 parent 055a953 commit a988871

File tree

2 files changed

+113
-17
lines changed

2 files changed

+113
-17
lines changed

src/kernels/layer.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
from typing import (
1818
TYPE_CHECKING,
1919
Dict,
20+
Mapping,
2021
Optional,
2122
Protocol,
23+
Set,
2224
Tuple,
2325
Type,
2426
Union,
@@ -868,10 +870,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
868870
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
869871

870872
if device is None:
871-
device_type = _find_device(model)
873+
device = _find_device(model)
874+
device_type = _find_device_type(model)
872875
elif isinstance(device, str):
873876
_validate_device_type(device)
877+
import torch
878+
874879
device_type = Device(type=device)
880+
device = torch.device(device)
875881
else:
876882
device_type = Device(device.type)
877883

@@ -884,7 +890,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
884890
layer_name = module_class.kernel_layer_name
885891

886892
if _DISABLE_KERNEL_MAPPING:
887-
_replace_forward(module, module_class)
893+
_replace_forward(device, module, module_class)
888894
continue
889895

890896
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +904,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
898904
)
899905
if not use_fallback:
900906
raise ValueError(f"No layer mapping for `{layer_name}`")
901-
_replace_forward(module, module_class)
907+
_replace_forward(device, module, module_class)
902908
continue
903909

904910
# Get kernel options for the device
@@ -909,7 +915,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
909915
raise ValueError(
910916
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
911917
)
912-
_replace_forward(module, module_class)
918+
_replace_forward(device, module, module_class)
913919
continue
914920

915921
repos = property_repos.repos
@@ -919,7 +925,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
919925
raise ValueError(
920926
f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
921927
)
922-
_replace_forward(module, module_class)
928+
_replace_forward(device, module, module_class)
923929
continue
924930

925931
repo_with_mode = _select_repository(
@@ -932,7 +938,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
932938
raise ValueError(
933939
f"No repository for `{layer_name}` for configuration mode={mode}"
934940
)
935-
_replace_forward(module, module_class)
941+
_replace_forward(device, module, module_class)
936942
continue
937943

938944
repo, repo_mode = repo_with_mode
@@ -951,6 +957,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
951957
)
952958

953959
_conditionally_replace_forward(
960+
device=device,
954961
module=module,
955962
layer=layer,
956963
mode=mode,
@@ -1037,19 +1044,31 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
10371044
raise TypeError(f"{repo} must not override nn.Module constructor.")
10381045

10391046
# ... or predefined member variables.
1040-
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
1041-
cls_members = {name for name, _ in inspect.getmembers(cls)}
1042-
difference = cls_members - torch_module_members
1047+
unique_members = _unique_layer_members(cls)
10431048
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
1044-
if not difference <= {"can_torch_compile", "has_backward"}:
1049+
if not unique_members <= {
1050+
"can_torch_compile",
1051+
"create_state",
1052+
"has_backward",
1053+
"forward_with_state",
1054+
}:
10451055
raise TypeError(
10461056
f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
10471057
)
10481058

10491059
# Check whether the forward signatures are similar.
1050-
params = inspect.signature(cls.forward).parameters
10511060
ref_params = inspect.signature(check_cls.forward).parameters
10521061

1062+
params: Mapping[str, inspect.Parameter]
1063+
if _is_stateful_layer(cls):
1064+
params = inspect.signature(cls.forward_with_state).parameters
1065+
# Get rid of the mappingproxy.
1066+
params = params.copy()
1067+
# Remove the state to be able to compare with forward.
1068+
del params["state"]
1069+
else:
1070+
params = inspect.signature(cls.forward).parameters
1071+
10531072
if len(params) != len(ref_params):
10541073
raise TypeError(
10551074
f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,15 +1093,21 @@ def _is_rocm_platform():
10741093
return torch.version.hip is not None
10751094

10761095

1077-
def _find_device(model: "nn.Module") -> Device:
1096+
def _find_device(model: "nn.Module") -> torch.device:
10781097
try:
10791098
param = next(model.parameters())
10801099
except StopIteration:
10811100
raise ValueError(
10821101
"Cannot determine model device, provide as `device` argument to `kernelize`."
10831102
)
10841103

1085-
dev_type = param.device.type
1104+
return param.device
1105+
1106+
1107+
def _find_device_type(model: "nn.Module") -> Device:
1108+
device = _find_device(model)
1109+
1110+
dev_type = device.type
10861111
if dev_type == "cuda":
10871112
# Refine based on actual platform
10881113
if _is_rocm_platform():
@@ -1103,6 +1128,7 @@ def _find_capability() -> int:
11031128

11041129
def _conditionally_replace_forward(
11051130
*,
1131+
device: "torch.device",
11061132
module: "nn.Module",
11071133
layer: Type["nn.Module"],
11081134
mode: Mode,
@@ -1128,15 +1154,25 @@ def _conditionally_replace_forward(
11281154
logging.info("Layer does not support torch.compile, using fallback")
11291155
if needs_fallback_for_backward:
11301156
logging.info("Layer does not support backward, using fallback")
1131-
_replace_forward(module, module_class)
1157+
_replace_forward(device, module, module_class)
11321158
else:
11331159
raise ValueError(f"Available kernel does not support mode: {mode}")
11341160
else:
1135-
_replace_forward(module, layer)
1161+
_replace_forward(device, module, layer)
11361162

11371163

1138-
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
1139-
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
1164+
def _replace_forward(
1165+
device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
1166+
):
1167+
if _is_stateful_layer(layer):
1168+
state = layer.create_state(device, module) # type: ignore[attr-defined]
1169+
1170+
def forward(self, *args, **kwargs):
1171+
return layer.forward_with_state(self, state, *args, **kwargs)
1172+
1173+
module.forward = MethodType(forward, module)
1174+
else:
1175+
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
11401176

11411177

11421178
def _validate_layer_has_mode(
@@ -1179,3 +1215,21 @@ def _get_layer_memoize(
11791215
_CACHED_LAYER[repo] = layer
11801216

11811217
return layer
1218+
1219+
1220+
def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
1221+
import torch.nn as nn
1222+
1223+
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
1224+
cls_members = {name for name, _ in inspect.getmembers(layer)}
1225+
return cls_members - torch_module_members
1226+
1227+
1228+
def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
1229+
unique = _unique_layer_members(layer)
1230+
is_stateful = "forward_with_state" in unique
1231+
if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
1232+
raise TypeError(
1233+
f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
1234+
)
1235+
return is_stateful

tests/test_layer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
import torch.nn as nn
77
from torch.nn import functional as F
8+
from torch.testing import assert_close
89

910
from kernels import (
1011
CUDAProperties,
@@ -321,6 +322,47 @@ def test_local_layer_repo(device):
321322
assert linear.n_calls == 0
322323

323324

325+
def test_stateful_layer(device):
326+
@use_kernel_forward_from_hub("ReluWithHiddenSize")
327+
class ReluWithHiddenSize(nn.Module):
328+
hidden_size: int
329+
330+
def __init__(self, hidden_size: int):
331+
super().__init__()
332+
self.hidden_size = hidden_size
333+
334+
def forward(self, x: torch.Tensor) -> torch.Tensor:
335+
return F.relu(x)
336+
337+
model = ReluWithHiddenSize(hidden_size=64).to(device)
338+
x = torch.randn((32, 64), device=device)
339+
y_ref = model(x)
340+
341+
with use_kernel_mapping(
342+
{
343+
"ReluWithHiddenSize": {
344+
"cuda": LayerRepository(
345+
repo_id="kernels-test/state-test",
346+
layer_name="StatefulReLU",
347+
),
348+
"xpu": LayerRepository(
349+
repo_id="kernels-test/state-test",
350+
layer_name="StatefulReLU",
351+
),
352+
}
353+
},
354+
inherit_mapping=False,
355+
):
356+
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
357+
358+
y = model(x)
359+
assert_close(y, y_ref)
360+
361+
model = torch.compile(model, fullgraph=True)
362+
y = model(x)
363+
assert_close(y, y_ref)
364+
365+
324366
@pytest.mark.cuda_only
325367
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
326368
@pytest.mark.parametrize("device", ["cuda"])

0 commit comments

Comments
 (0)