3737 import torch
3838 from torch import nn
3939
40-
4140_DISABLE_KERNEL_MAPPING : bool = bool (int (os .environ .get ("DISABLE_KERNEL_MAPPING" , "0" )))
4241
4342
@@ -122,6 +121,8 @@ def create_repo(self) -> _DeviceRepos:
122121 """Create an appropriate repository set for this device type."""
123122 if self .type == "cuda" :
124123 return _CUDARepos ()
124+ elif self .type == "rocm" :
125+ return _ROCMRepos ()
125126 elif self .type == "mps" :
126127 return _MPSRepos ()
127128 else :
@@ -181,6 +182,51 @@ def __hash__(self):
181182 return hash ((self .min_capability , self .max_capability ))
182183
183184
@dataclass(frozen=True)
class ROCMProperties:
    """
    ROCM-specific device properties for capability-based kernel selection.

    Kernels can declare the range of ROCM compute capabilities they support;
    this class carries that constraint as an inclusive ``[min, max]`` pair.

    Args:
        min_capability (`int`):
            Minimum ROCM compute capability required (e.g., 75 for compute capability 7.5).
        max_capability (`int`):
            Maximum ROCM compute capability supported (e.g., 90 for compute capability 9.0).

    Example:
        ```python
        from kernels import ROCMProperties, Device

        # Define ROCM properties for modern GPUs (compute capability 7.5 to 9.0)
        rocm_props = ROCMProperties(min_capability=75, max_capability=90)

        # Create a device with these properties
        device = Device(type="rocm", properties=rocm_props)
        ```

    Note:
        Capabilities are encoded as integers with major and minor versions
        concatenated: 7.5 becomes 75, 8.6 becomes 86.
    """

    # Inclusive capability bounds used for interval lookups.
    min_capability: int
    max_capability: int

    def __eq__(self, other):
        # Only compare against other ROCMProperties; defer to the other
        # operand (NotImplemented) for anything else.
        if not isinstance(other, ROCMProperties):
            return NotImplemented
        return (self.min_capability, self.max_capability) == (
            other.min_capability,
            other.max_capability,
        )

    def __hash__(self):
        # Keep the hash consistent with __eq__: same capability pair,
        # same hash.
        return hash((self.min_capability, self.max_capability))
228+
229+
184230class LayerRepositoryProtocol (Protocol ):
185231 @property
186232 def layer_name (self ) -> str : ...
@@ -452,6 +498,46 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
452498 self .repos_by_capability .insert (min_capability , max_capability , repos )
453499
454500
class _ROCMRepos(_DeviceRepos):
    """Repository set for ROCM devices, keyed by compute-capability interval."""

    _repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]]

    def __init__(self):
        super().__init__()
        # Maps [min_capability, max_capability] intervals to mode->repo dicts.
        self.repos_by_capability = IntervalTree()

    @property
    def repos(
        self,
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
        # Pick the tightest interval covering the current device capability.
        return self.repos_by_capability.find_smallest_interval(_find_capability())

    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
        props = device.properties
        assert props is None or isinstance(props, ROCMProperties)

        # No properties means the repos apply to every capability.
        if props is None:
            lower, upper = 0, sys.maxsize
        else:
            lower, upper = props.min_capability, props.max_capability

        self.repos_by_capability.insert(lower, upper, repos)
530+
531+
532+ def _validate_device_type (device_type : str ) -> None :
533+ """Validate that the device type is supported."""
534+ supported_devices = {"cuda" , "rocm" , "mps" , "cpu" }
535+ if device_type not in supported_devices :
536+ raise ValueError (
537+ f"Unsupported device type '{ device_type } '. Supported device types are: { ', ' .join (sorted (supported_devices ))} "
538+ )
539+
540+
455541_KERNEL_MAPPING : ContextVar [Dict [str , Dict [str , _DeviceRepos ]]] = ContextVar (
456542 "_KERNEL_MAPPING" , default = {}
457543)
@@ -703,8 +789,8 @@ def kernelize(
703789 The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
704790 kernelizes the model for training with `torch.compile`.
705791 device (`Union[str, torch.device]`, *optional*):
706- The device type to load kernels for. The device type will be inferred from the model parameters
707- when not provided.
792+ The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu".
793+ The device type will be inferred from the model parameters when not provided.
708794 use_fallback (`bool`, *optional*, defaults to `True`):
709795 Whether to use the original forward method of modules when no compatible kernel could be found.
710796 If set to `False`, an exception will be raised in such cases.
@@ -746,7 +832,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
746832 kernelized_model = kernelize(model)
747833 ```
748834 """
749- import torch
750835
751836 if mode == Mode .FALLBACK :
752837 raise ValueError ("Mode.FALLBACK can only be used to register kernel mappings." )
@@ -760,7 +845,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
760845 if device is None :
761846 device_type = _find_device (model )
762847 elif isinstance (device , str ):
763- device_type = Device (type = torch .device (device ).type )
848+ _validate_device_type (device )
849+ device_type = Device (type = device )
764850 else :
765851 device_type = Device (device .type )
766852
@@ -948,6 +1034,18 @@ def _validate_layer(*, check_cls, cls):
9481034 )
9491035
9501036
1037+ def _is_cuda_platform ():
1038+ import torch
1039+
1040+ return torch .version .cuda is not None
1041+
1042+
1043+ def _is_rocm_platform ():
1044+ import torch
1045+
1046+ return torch .version .hip is not None
1047+
1048+
def _find_device(model: "nn.Module") -> Device:
    """Infer the kernel device from the model's first parameter."""
    try:
        first_param = next(model.parameters())
    except StopIteration:
        raise ValueError(
            "Cannot determine model device, provide as `device` argument to `kernelize`."
        )

    device_type = first_param.device.type
    if device_type == "cuda":
        # torch reports ROCm devices as "cuda" too; disambiguate using the
        # build information so ROCm-specific kernel repos can be selected.
        if _is_rocm_platform():
            return Device(type="rocm")
        if _is_cuda_platform():
            return Device(type="cuda")

    return Device(type=device_type)
9601066
9611067
9621068@lru_cache
0 commit comments