Commit 52e87ce

allow to move to gpu
1 parent 79dafa3

File tree

1 file changed (+28, -10 lines)

1 file changed

+28
-10
lines changed

src/tabpfn/misc/onnx_wrapper.py

Lines changed: 28 additions & 10 deletions
@@ -27,26 +27,43 @@ def __init__(self, model_path: str):
         Args:
             model_path: Path to the ONNX model file.
         """
+        self.model_path = model_path
+        self.providers = ["CPUExecutionProvider"]
         self.session = ort.InferenceSession(
             model_path,
-            providers=["CPUExecutionProvider"],  # TODO: Add GPU support
+            providers=self.providers,
         )
 
     def to(
         self,
-        device: torch.device,  # noqa: ARG002
+        device: torch.device,
     ) -> ONNXModelWrapper:
         """Moves the model to the specified device.
 
-        This is a no-op for the ONNX model wrapper. GPU support is not implemented.
-
         Args:
-            device: The target device (unused).
+            device: The target device (cuda or cpu).
 
         Returns:
             self
         """
-        # TODO: Add GPU support by changing provider
+        if device.type == "cuda":
+            # Check if CUDA is available in ONNX Runtime
+            cuda_provider = "CUDAExecutionProvider"
+            if cuda_provider in ort.get_available_providers():
+                self.providers = [cuda_provider, "CPUExecutionProvider"]
+                # Reinitialize session with CUDA provider
+                self.session = ort.InferenceSession(
+                    self.model_path,
+                    providers=self.providers,
+                )
+            else:
+                pass
+        else:
+            self.providers = ["CPUExecutionProvider"]
+            self.session = ort.InferenceSession(
+                self.model_path,
+                providers=self.providers,
+            )
         return self
 
     def type(
@@ -105,8 +122,6 @@ def __call__(
 
         Returns:
             A torch tensor containing the model output.
-
-        Note that only_return_standard_out is not used in the ONNX runtime.
         """
         # Convert inputs to numpy
         X_np = X.cpu().numpy() if isinstance(X, torch.Tensor) else X
@@ -125,8 +140,11 @@ def __call__(
         # Run inference
         outputs = self.session.run(None, onnx_inputs)
 
-        # Convert back to a torch tensor
-        return torch.from_numpy(outputs[0])
+        # Convert back to torch tensor and move to the appropriate device
+        output_tensor = torch.from_numpy(outputs[0])
+        if "CUDAExecutionProvider" in self.providers:
+            output_tensor = output_tensor.cuda()
+        return output_tensor
 
 
 class ModelWrapper(nn.Module):
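
For context, a minimal usage sketch of the new behavior. The import path is taken from the file path above; "model.onnx" is a hypothetical placeholder, and only the ONNXModelWrapper.to() semantics shown in the diff are assumed.

    # A minimal sketch, not part of the commit: the import path follows the
    # changed file's location, and "model.onnx" is a hypothetical placeholder.
    import onnxruntime as ort
    import torch

    from tabpfn.misc.onnx_wrapper import ONNXModelWrapper

    wrapper = ONNXModelWrapper("model.onnx")  # session starts on CPUExecutionProvider

    # Moving to a CUDA device rebuilds the session with CUDAExecutionProvider
    # when ONNX Runtime reports it as available; otherwise the CPU session is kept.
    if "CUDAExecutionProvider" in ort.get_available_providers():
        wrapper = wrapper.to(torch.device("cuda"))

    print(wrapper.providers)
    # ['CUDAExecutionProvider', 'CPUExecutionProvider'] on a GPU machine,
    # ['CPUExecutionProvider'] otherwise

    # Moving back to CPU reinitializes a CPU-only session.
    wrapper = wrapper.to(torch.device("cpu"))

Note that, per the last hunk, outputs are transferred with .cuda() (the default CUDA device) whenever CUDAExecutionProvider is active, so a specific device index passed to to() is not propagated to the output transfer, and an unavailable CUDA provider is silently ignored rather than raising.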
