wtpsplit already supports two inference paths:
- standard (eager) PyTorch inference (SaT(...).to("cuda"), half(), etc.)
- ONNX Runtime inference via ort_providers / exported ONNX models, with ONNX export scripts for SaT/WtP
It would be useful to add a third optimized path based on torch.compile / TorchInductor for users who want faster PyTorch-native inference without exporting to ONNX.
NVIDIA AITune looks relevant here because it can inspect a PyTorch model, try runtime backends, and activate an optimized backend for inference.
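# Proposed API; optimize() does not exist in wtpsplit yet
from wtpsplit import SaT
sat = SaT("sat-3l-sm")
sat.optimize(backend="torchinductor")
# optionally, use NVIDIA AITune as the backend instead
sat.optimize(backend="aitune")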