
Commit a8a6564

ahadnagy and danieldk authored
Add ROCm device discovery (#122)
* Add ROCm device discovery
* Ruff
* Address review comments
* Ruff
* Reorg torch import
* Remove redundant import
* Apply suggestions from code review

  Co-authored-by: Daniël de Kok <[email protected]>

* Address review comments
* Validate device type
* Clean diff
* black
* Sync test with repo changes
* black again

---------

Co-authored-by: Daniël de Kok <[email protected]>
1 parent c89e0fa commit a8a6564

File tree

5 files changed: +189 -9 lines changed


docs/source/layers.md

Lines changed: 10 additions & 1 deletion
@@ -135,6 +135,10 @@ kernel_layer_mapping = {
         "cuda": LayerRepository(
             repo_id="kernels-community/activation",
             layer_name="SiluAndMul",
+        ),
+        "rocm": LayerRepository(
+            repo_id="kernels-community/activation",
+            layer_name="SiluAndMul",
         )
     }
 }
@@ -261,7 +265,6 @@ Capabilities behave as follows:
   an existing kernel, the new kernel will replace the old kernel.
 - When there are multiple kernels that support a capability, the kernel
   with the smaller capability interval will be used. E.g. given:
-
 - `KernelA` with `min_capability=80` and `max_capability=89`;
 - `KernelB` with `min_capability=75` and `max_capability=89`;
 - `kernelize` runs on a system with capability 8.6.
@@ -271,6 +274,12 @@ Capabilities behave as follows:
   tend to be more optimized for a specific set of GPUs. **This behavior
   might still change in the future.**
 
+### Registering kernels for specific ROCm capabilities
+
+Registering kernels for the ROCm architecture follows the exact same
+pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict
+a kernel to a range of ROCm capabilities.
+
 ### Loading from a local repository for testing
 
 The `LocalLayerRepository` class is provided to load a repository from
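
The new docs section above describes restricting a ROCm kernel to a capability range. For reference, a minimal sketch of such a registration, assuming the same `Device`-keyed mapping and `register_kernel_mapping` call that the docs use for CUDA capability ranges (the capability values 90–94 are illustrative, not part of this commit):

```python
from kernels import Device, LayerRepository, ROCMProperties, register_kernel_mapping

# Restrict the ROCm kernel to a capability range, mirroring the CUDA pattern.
kernel_layer_mapping = {
    "SiluAndMul": {
        Device(
            type="rocm",
            properties=ROCMProperties(min_capability=90, max_capability=94),
        ): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        ),
    }
}

register_kernel_mapping(kernel_layer_mapping)
```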

pytest.ini

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 [pytest]
 markers =
     cuda_only: marks tests that should only hosts with CUDA GPUs
+    rocm_only: marks tests that should only run on hosts with ROCm GPUs
     darwin_only: marks tests that should only run on macOS

src/kernels/layer.py

Lines changed: 112 additions & 6 deletions
@@ -37,7 +37,6 @@
 import torch
 from torch import nn
 
-
 _DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
 
 
@@ -122,6 +121,8 @@ def create_repo(self) -> _DeviceRepos:
         """Create an appropriate repository set for this device type."""
         if self.type == "cuda":
             return _CUDARepos()
+        elif self.type == "rocm":
+            return _ROCMRepos()
         elif self.type == "mps":
             return _MPSRepos()
         else:
@@ -181,6 +182,51 @@ def __hash__(self):
         return hash((self.min_capability, self.max_capability))
 
 
+@dataclass(frozen=True)
+class ROCMProperties:
+    """
+    ROCM-specific device properties for capability-based kernel selection.
+
+    This class defines ROCM compute capability constraints for kernel selection, allowing kernels to specify
+    minimum and maximum ROCM compute capabilities they support.
+
+    Args:
+        min_capability (`int`):
+            Minimum ROCM compute capability required (e.g., 75 for compute capability 7.5).
+        max_capability (`int`):
+            Maximum ROCM compute capability supported (e.g., 90 for compute capability 9.0).
+
+    Example:
+        ```python
+        from kernels import ROCMProperties, Device
+
+        # Define ROCM properties for modern GPUs (compute capability 7.5 to 9.0)
+        rocm_props = ROCMProperties(min_capability=75, max_capability=90)
+
+        # Create a device with these properties
+        device = Device(type="rocm", properties=rocm_props)
+        ```
+
+    Note:
+        ROCM compute capabilities are represented as integers where the major and minor versions are concatenated.
+        For example, compute capability 7.5 is represented as 75, and 8.6 is represented as 86.
+    """
+
+    min_capability: int
+    max_capability: int
+
+    def __eq__(self, other):
+        if not isinstance(other, ROCMProperties):
+            return NotImplemented
+        return (
+            self.min_capability == other.min_capability
+            and self.max_capability == other.max_capability
+        )
+
+    def __hash__(self):
+        return hash((self.min_capability, self.max_capability))
+
+
 class LayerRepositoryProtocol(Protocol):
     @property
     def layer_name(self) -> str: ...
@@ -452,6 +498,46 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
         self.repos_by_capability.insert(min_capability, max_capability, repos)
 
 
+class _ROCMRepos(_DeviceRepos):
+    _repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]]
+
+    def __init__(self):
+        super().__init__()
+        self.repos_by_capability = IntervalTree()
+
+    @property
+    def repos(
+        self,
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
+        capability = _find_capability()
+        return self.repos_by_capability.find_smallest_interval(capability)
+
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
+        assert device.properties is None or isinstance(
+            device.properties, ROCMProperties
+        )
+
+        min_capability = (
+            0 if device.properties is None else device.properties.min_capability
+        )
+        max_capability = (
+            sys.maxsize
+            if device.properties is None
+            else device.properties.max_capability
+        )
+
+        self.repos_by_capability.insert(min_capability, max_capability, repos)
+
+
+def _validate_device_type(device_type: str) -> None:
+    """Validate that the device type is supported."""
+    supported_devices = {"cuda", "rocm", "mps", "cpu"}
+    if device_type not in supported_devices:
+        raise ValueError(
+            f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
+        )
+
+
 _KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
     "_KERNEL_MAPPING", default={}
 )
@@ -703,8 +789,8 @@ def kernelize(
             The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
             kernelizes the model for training with `torch.compile`.
         device (`Union[str, torch.device]`, *optional*):
-            The device type to load kernels for. The device type will be inferred from the model parameters
-            when not provided.
+            The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu".
+            The device type will be inferred from the model parameters when not provided.
         use_fallback (`bool`, *optional*, defaults to `True`):
             Whether to use the original forward method of modules when no compatible kernel could be found.
             If set to `False`, an exception will be raised in such cases.
@@ -746,7 +832,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        kernelized_model = kernelize(model)
        ```
    """
-    import torch
 
    if mode == Mode.FALLBACK:
        raise ValueError("Mode.FALLBACK can only be used to register kernel mappings.")
@@ -760,7 +845,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     if device is None:
         device_type = _find_device(model)
     elif isinstance(device, str):
-        device_type = Device(type=torch.device(device).type)
+        _validate_device_type(device)
+        device_type = Device(type=device)
     else:
         device_type = Device(device.type)
 
@@ -948,6 +1034,18 @@ def _validate_layer(*, check_cls, cls):
         )
 
 
+def _is_cuda_platform():
+    import torch
+
+    return torch.version.cuda is not None
+
+
+def _is_rocm_platform():
+    import torch
+
+    return torch.version.hip is not None
+
+
 def _find_device(model: "nn.Module") -> Device:
     try:
         param = next(model.parameters())
@@ -956,7 +1054,15 @@ def _find_device(model: "nn.Module") -> Device:
             "Cannot determine model device, provide as `device` argument to `kernelize`."
         )
 
-    return Device(type=param.device.type)
+    dev_type = param.device.type
+    if dev_type == "cuda":
+        # Refine based on actual platform
+        if _is_rocm_platform():
+            return Device(type="rocm")
+        elif _is_cuda_platform():
+            return Device(type="cuda")
+
+    return Device(type=dev_type)
 
 
 @lru_cache
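
Taken together, these changes let `kernelize` accept `device="rocm"` directly and infer ROCm from a model's parameters. A minimal, illustrative sketch of both paths (the plain `nn.Sequential` model is an assumption for brevity; `kernelize` only swaps layers annotated for hub kernels, so an unannotated model passes through unchanged):

```python
import torch
from torch import nn

from kernels import Mode, kernelize

# Plain module used only for illustration.
model = nn.Sequential(nn.Linear(64, 64))

# Explicit device string: "rocm" is now accepted by _validate_device_type.
kernelized = kernelize(model, device="rocm", mode=Mode.INFERENCE)

# Inferred device: on a ROCm build of PyTorch the parameters report "cuda",
# and _find_device refines that to Device(type="rocm") via torch.version.hip.
if torch.cuda.is_available():
    kernelized = kernelize(model.to("cuda"), mode=Mode.INFERENCE)
```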

tests/conftest.py

Lines changed: 12 additions & 1 deletion
@@ -3,11 +3,22 @@
 import pytest
 import torch
 
-has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0
+has_cuda = (
+    hasattr(torch.version, "cuda")
+    and torch.version.cuda is not None
+    and torch.cuda.device_count() > 0
+)
+has_rocm = (
+    hasattr(torch.version, "hip")
+    and torch.version.hip is not None
+    and torch.cuda.device_count() > 0
+)
 
 
 def pytest_runtest_setup(item):
     if "cuda_only" in item.keywords and not has_cuda:
         pytest.skip("skipping CUDA-only test on host without CUDA")
+    if "rocm_only" in item.keywords and not has_rocm:
+        pytest.skip("skipping ROCm-only test on host without ROCm")
     if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
         pytest.skip("skipping macOS-only test on non-macOS platform")

tests/test_layer.py

Lines changed: 54 additions & 1 deletion
@@ -34,7 +34,11 @@
         "cuda": LayerRepository(
             repo_id="kernels-test/op-without-fake-test",
             layer_name="SiluAndMul",
-        )
+        ),
+        "rocm": LayerRepository(
+            repo_id="kernels-test/op-without-fake-test",
+            layer_name="SiluAndMul",
+        ),
     },
     "SiluAndMulStringDevice": {
         "cuda": LayerRepository(
@@ -126,6 +130,55 @@ def test_hub_forward(cls, device):
     assert silu_and_mul_with_kernel.n_calls == 1
 
 
+@pytest.mark.rocm_only
+def test_hub_forward_rocm():
+    torch.manual_seed(0)
+
+    silu_and_mul = SiluAndMul()
+    X = torch.randn((32, 64))
+    Y = silu_and_mul(X)
+
+    silu_and_mul_with_kernel = kernelize(
+        SiluAndMulNoCompileKernel(), device="rocm", mode=Mode.INFERENCE
+    )
+    Y_kernel = silu_and_mul_with_kernel(X)
+
+    torch.testing.assert_close(Y_kernel, Y)
+
+    assert silu_and_mul.n_calls == 1
+    # Should use kernel (n_calls == 0) if ROCm kernel is available, otherwise fallback (n_calls == 1)
+    # The exact behavior depends on whether the test kernel exists for ROCm
+    assert silu_and_mul_with_kernel.n_calls in [0, 1]
+
+
+def test_rocm_kernel_mapping():
+    """Test that ROCm shorthand device mapping works correctly."""
+    kernel_layer_mapping = {
+        "SiluAndMul": {
+            "rocm": LayerRepository(
+                repo_id="kernels-community/activation",
+                layer_name="SiluAndMul",
+            )
+        }
+    }
+
+    # Test that the mapping is processed correctly
+    with use_kernel_mapping(kernel_layer_mapping, inherit_mapping=False):
+        mapping = _KERNEL_MAPPING.get()
+
+        # Verify the mapping exists
+        assert "SiluAndMul" in mapping
+        assert "rocm" in mapping["SiluAndMul"]
+
+        # Verify the repository is correctly stored
+        rocm_repos = mapping["SiluAndMul"]["rocm"]
+        assert rocm_repos is not None
+        assert (
+            rocm_repos.repos[Mode.FALLBACK]._repo_id == "kernels-community/activation"
+        )
+        assert rocm_repos.repos[Mode.FALLBACK].layer_name == "SiluAndMul"
+
+
 @pytest.mark.cuda_only
 def test_capability():
     linear = TorchLinearWithCounter(32, 32).to("cuda")
