Commit 6f4d549

Add support for stateful layers
1 parent 055a953 commit 6f4d549

File tree

2 files changed: +112, -17 lines changed


src/kernels/layer.py

Lines changed: 70 additions & 17 deletions
@@ -7,6 +7,7 @@
 import sys
 import warnings
 from abc import ABC, abstractmethod
+from collections import OrderedDict
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass
@@ -19,6 +20,7 @@
     Dict,
     Optional,
     Protocol,
+    Set,
     Tuple,
     Type,
     Union,
@@ -868,10 +870,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")

     if device is None:
-        device_type = _find_device(model)
+        device = _find_device(model)
+        device_type = _find_device_type(model)
     elif isinstance(device, str):
         _validate_device_type(device)
+        import torch
+
         device_type = Device(type=device)
+        device = torch.device(device)
     else:
         device_type = Device(device.type)

@@ -884,7 +890,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         layer_name = module_class.kernel_layer_name

         if _DISABLE_KERNEL_MAPPING:
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         kernel = _KERNEL_MAPPING.get().get(str(layer_name))
@@ -898,7 +904,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             if not use_fallback:
                 raise ValueError(f"No layer mapping for `{layer_name}`")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         # Get kernel options for the device
@@ -909,7 +915,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repos = property_repos.repos
@@ -919,7 +925,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo_with_mode = _select_repository(
@@ -932,7 +938,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 raise ValueError(
                     f"No repository for `{layer_name}` for configuration mode={mode}"
                 )
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
             continue

         repo, repo_mode = repo_with_mode
@@ -951,6 +957,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )

         _conditionally_replace_forward(
+            device=device,
             module=module,
             layer=layer,
             mode=mode,
@@ -1037,19 +1044,30 @@ def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
         raise TypeError(f"{repo} must not override nn.Module constructor.")

     # ... or predefined member variables.
-    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
-    cls_members = {name for name, _ in inspect.getmembers(cls)}
-    difference = cls_members - torch_module_members
+    unique_members = _unique_layer_members(cls)
     # verify if : difference ⊄ {"can_torch_compile", "has_backward"}
-    if not difference <= {"can_torch_compile", "has_backward"}:
+    if not unique_members <= {
+        "can_torch_compile",
+        "create_state",
+        "has_backward",
+        "forward_with_state",
+    }:
         raise TypeError(
             f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
         )

     # Check whether the forward signatures are similar.
-    params = inspect.signature(cls.forward).parameters
     ref_params = inspect.signature(check_cls.forward).parameters

+    if _is_stateful_layer(cls):
+        params = inspect.signature(cls.forward_with_state).parameters
+        # Get rid of the mappingproxy.
+        params = OrderedDict(params)
+        # Remove the state to be able to compare with forward.
+        del params["state"]
+    else:
+        params = inspect.signature(cls.forward).parameters
+
     if len(params) != len(ref_params):
         raise TypeError(
             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
@@ -1074,15 +1092,21 @@ def _is_rocm_platform():
     return torch.version.hip is not None


-def _find_device(model: "nn.Module") -> Device:
+def _find_device(model: "nn.Module") -> torch.device:
     try:
         param = next(model.parameters())
     except StopIteration:
         raise ValueError(
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )

-    dev_type = param.device.type
+    return param.device
+
+
+def _find_device_type(model: "nn.Module") -> Device:
+    device = _find_device(model)
+
+    dev_type = device.type
     if dev_type == "cuda":
         # Refine based on actual platform
         if _is_rocm_platform():
@@ -1103,6 +1127,7 @@ def _find_capability() -> int:

 def _conditionally_replace_forward(
     *,
+    device: "torch.device",
     module: "nn.Module",
     layer: Type["nn.Module"],
     mode: Mode,
@@ -1128,15 +1153,25 @@ def _conditionally_replace_forward(
                 logging.info("Layer does not support torch.compile, using fallback")
             if needs_fallback_for_backward:
                 logging.info("Layer does not support backward, using fallback")
-            _replace_forward(module, module_class)
+            _replace_forward(device, module, module_class)
         else:
             raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
-        _replace_forward(module, layer)
+        _replace_forward(device, module, layer)


-def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+def _replace_forward(
+    device: "torch.device", module: "nn.Module", layer: Type["nn.Module"]
+):
+    if _is_stateful_layer(layer):
+        state = layer.create_state(device, module)
+
+        def forward(self, *args, **kwargs):
+            return layer.forward_with_state(self, state, *args, **kwargs)
+
+        module.forward = MethodType(forward, module)
+    else:
+        module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]


 def _validate_layer_has_mode(
@@ -1179,3 +1214,21 @@ def _get_layer_memoize(
     _CACHED_LAYER[repo] = layer

     return layer
+
+
+def _unique_layer_members(layer: Type["nn.Module"]) -> Set[str]:
+    import torch.nn as nn
+
+    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
+    cls_members = {name for name, _ in inspect.getmembers(layer)}
+    return cls_members - torch_module_members
+
+
+def _is_stateful_layer(layer: Type[nn.Module]) -> bool:
+    unique = _unique_layer_members(layer)
+    is_stateful = "forward_with_state" in unique
+    if is_stateful and len(unique & {"create_state", "forward_with_state"}) != 2:
+        raise TypeError(
+            f"Stateful layer `{layer.__name__}` must implement both `create_state` and `forward_with_state` or neither."
+        )
+    return is_stateful
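
The new `_replace_forward` path above implies a small contract for stateful layers: the kernel layer class exposes `create_state(device, module)`, called once when `kernelize` replaces the module's forward, and `forward_with_state(self, state, ...)`, called through a closure that captures the returned state. The contents of the `kernels-test/state-test` repository are not part of this commit, so the sketch below is only a hypothetical illustration of a class that satisfies that contract; the class name and the dict-based state layout are made up.

import torch
import torch.nn as nn
import torch.nn.functional as F


class StatefulReLUSketch(nn.Module):
    can_torch_compile = True
    has_backward = True

    @staticmethod
    def create_state(device: torch.device, module: nn.Module) -> dict:
        # Runs once when kernelize() replaces the module's forward; it may
        # inspect the wrapped module (e.g. its hidden_size) and allocate
        # device-resident buffers up front.
        return {"hidden_size": module.hidden_size, "device": device}

    def forward_with_state(self, state: dict, x: torch.Tensor) -> torch.Tensor:
        # `self` is the original module instance; `state` is threaded in by
        # the closure installed in _replace_forward, so the visible forward
        # signature stays identical to the reference layer's forward(x).
        return F.relu(x)

Note how this lines up with the updated `_validate_layer`: the class adds no members beyond the allowed set, and `forward_with_state` minus the `state` parameter matches the reference `forward` signature.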

tests/test_layer.py

Lines changed: 42 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch.testing import assert_close

 from kernels import (
     CUDAProperties,
@@ -321,6 +322,47 @@ def test_local_layer_repo(device):
     assert linear.n_calls == 0


+def test_stateful_layer(device):
+    @use_kernel_forward_from_hub("ReluWithHiddenSize")
+    class ReluWithHiddenSize(nn.Module):
+        hidden_size: int
+
+        def __init__(self, hidden_size: int):
+            super().__init__()
+            self.hidden_size = hidden_size
+
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return F.relu(x)
+
+    model = ReluWithHiddenSize(hidden_size=64).to(device)
+    x = torch.randn((32, 64), device=device)
+    y_ref = model(x)
+
+    with use_kernel_mapping(
+        {
+            "ReluWithHiddenSize": {
+                "cuda": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+                "xpu": LayerRepository(
+                    repo_id="kernels-test/state-test",
+                    layer_name="StatefulReLU",
+                ),
+            }
+        },
+        inherit_mapping=False,
+    ):
+        model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
+
+    y = model(x)
+    assert_close(y, y_ref)
+
+    model = torch.compile(model, fullgraph=True)
+    y = model(x)
+    assert_close(y, y_ref)
+
+
 @pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
 @pytest.mark.parametrize("device", ["cuda"])
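
The test above exercises the new path end to end through the Hub. For a rough local check that does not need the `kernels-test/state-test` repository, the private helpers added in this commit could be wired up by hand; the `LocalStatefulReLU` class and its scalar state below are illustrative only and not part of the commit or of any published kernel.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Private helpers introduced in src/kernels/layer.py by this commit.
from kernels.layer import _is_stateful_layer, _replace_forward


class LocalStatefulReLU(nn.Module):
    @staticmethod
    def create_state(device: torch.device, module: nn.Module) -> torch.Tensor:
        # One-time setup; here just a device-resident scalar that the
        # wrapped forward closes over.
        return torch.zeros((), device=device)

    def forward_with_state(self, state: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x) + state


module = nn.ReLU()
assert _is_stateful_layer(LocalStatefulReLU)

# Installs a closure-based forward that threads `state` into forward_with_state.
_replace_forward(torch.device("cpu"), module, LocalStatefulReLU)

x = torch.randn(4)
assert torch.allclose(module(x), F.relu(x))

Because the state lives in the closure rather than on the module, it never appears in `state_dict()`, which is consistent with the `fullgraph=True` compile check in the test above.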
