@@ -27,26 +27,43 @@ def __init__(self, model_path: str):
         Args:
             model_path: Path to the ONNX model file.
         """
+        self.model_path = model_path
+        self.providers = ["CPUExecutionProvider"]
         self.session = ort.InferenceSession(
             model_path,
-            providers=["CPUExecutionProvider"],  # TODO: Add GPU support
+            providers=self.providers,
         )
 
     def to(
         self,
-        device: torch.device,  # noqa: ARG002
+        device: torch.device,
     ) -> ONNXModelWrapper:
         """Moves the model to the specified device.
 
-        This is a no-op for the ONNX model wrapper. GPU support is not implemented.
-
         Args:
-            device: The target device (unused).
+            device: The target device (cuda or cpu).
 
         Returns:
             self
         """
-        # TODO: Add GPU support by changing provider
+        if device.type == "cuda":
+            # Check whether ONNX Runtime was built with CUDA support
+            cuda_provider = "CUDAExecutionProvider"
+            if cuda_provider in ort.get_available_providers():
+                self.providers = [cuda_provider, "CPUExecutionProvider"]
+                # Reinitialize the session so it picks up the CUDA provider
+                self.session = ort.InferenceSession(
+                    self.model_path,
+                    providers=self.providers,
+                )
+            else:
+                pass  # CUDA requested but unavailable; keep the CPU session
+        else:
+            self.providers = ["CPUExecutionProvider"]
+            self.session = ort.InferenceSession(
+                self.model_path,
+                providers=self.providers,
+            )
         return self
 
     def type(
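
A minimal sketch (not part of the diff) of how the fallback above behaves, assuming a placeholder model path and that `ONNXModelWrapper` is the class being patched here:

```python
import onnxruntime as ort
import torch

wrapper = ONNXModelWrapper("model.onnx")  # placeholder path
wrapper.to(torch.device("cuda"))

# get_providers() reports what the session actually uses: CUDA first when
# ONNX Runtime has GPU support, otherwise CPU only (the silent fallback).
print(wrapper.session.get_providers())
```
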
@@ -105,8 +122,6 @@ def __call__(
 
         Returns:
             A torch tensor containing the model output.
-
-        Note that only_return_standard_out is not used in the ONNX runtime.
         """
         # Convert inputs to numpy
         X_np = X.cpu().numpy() if isinstance(X, torch.Tensor) else X
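
For reference, the conversion on the last context line above is cheap: for a CPU tensor, `.cpu()` is a no-op and `.numpy()` shares memory rather than copying (a tensor with `requires_grad=True` would need `.detach()` first). A small sketch with a stand-in input:

```python
import numpy as np
import torch

X = torch.randn(4, 8)  # stand-in for the real model input
X_np = X.cpu().numpy() if isinstance(X, torch.Tensor) else X
assert isinstance(X_np, np.ndarray)
```
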
@@ -125,8 +140,11 @@ def __call__(
         # Run inference
         outputs = self.session.run(None, onnx_inputs)
 
-        # Convert back to a torch tensor
-        return torch.from_numpy(outputs[0])
+        # Convert back to a torch tensor and move it to the session's device
+        output_tensor = torch.from_numpy(outputs[0])
+        if "CUDAExecutionProvider" in self.providers:
+            output_tensor = output_tensor.cuda()  # run() returns CPU arrays
+        return output_tensor
 
 
 class ModelWrapper(nn.Module):
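
One caveat with the new return path: `session.run` always materializes outputs as CPU numpy arrays, so the `.cuda()` call pays a host-to-device copy on top of the device-to-host copy inside ONNX Runtime. If that round trip matters, the IOBinding API can keep the output on the GPU. A hedged sketch; the names `"input"` and `"output"` are placeholders for whatever `session.get_inputs()` / `session.get_outputs()` report:

```python
io_binding = wrapper.session.io_binding()
io_binding.bind_cpu_input("input", X_np)  # placeholder input name
io_binding.bind_output("output", device_type="cuda", device_id=0)
wrapper.session.run_with_iobinding(io_binding)

# The result stays on the GPU as an OrtValue instead of round-tripping
# through host memory.
ort_value = io_binding.get_outputs()[0]
```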