diff --git a/.github/workflows/static_analysis.yml b/.github/workflows/static_analysis.yml
index 8c25dcaf..2ff7eecc 100644
--- a/.github/workflows/static_analysis.yml
+++ b/.github/workflows/static_analysis.yml
@@ -98,8 +98,9 @@ jobs:
if: always()
run: |
cd ${{ github.workspace }}
- . ftorch_venv/bin/activate # Uses .clang-tidy config file if present
- fortitude check src/
+ . ftorch_venv/bin/activate
+ fortitude check --ignore=E001,T041 src/ftorch.F90
+ fortitude check src/ftorch_test_utils.f90
# Apply C++ and C linter and formatter, clang
# Configurable using the .clang-format and .clang-tidy config files if present
@@ -113,7 +114,7 @@ jobs:
style: 'file'
tidy-checks: ''
# Use the compile_commands.json from CMake to locate headers
- database: ${{ github.workspace }}/src/build
+ database: ${{ github.workspace }}/build
# only 'update' a single comment in a pull request thread.
thread-comments: ${{ github.event_name == 'pull_request' && 'update' }}
- name: Fail fast?!
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ead82aac..0f79fb2d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,8 +22,33 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
+# Set GPU device type using consistent numbering as in PyTorch
+# https://github.com/pytorch/pytorch/blob/main/c10/core/DeviceType.h
+# Set in a single location here and passed as preprocessor flag for use
+# throughout source files.
+set(GPU_DEVICE_NONE 0)
+set(GPU_DEVICE_CUDA 1)
+set(GPU_DEVICE_XPU 12)
+set(GPU_DEVICE_MPS 13)
+option(GPU_DEVICE "Set the GPU device (NONE [default], CUDA, XPU, or MPS)" NONE)
+if("${GPU_DEVICE}" STREQUAL "OFF")
+ set(GPU_DEVICE NONE)
+endif()
+if("${GPU_DEVICE}" STREQUAL "NONE")
+ set(GPU_DEVICE_CODE ${GPU_DEVICE_NONE})
+elseif("${GPU_DEVICE}" STREQUAL "CUDA")
+ set(GPU_DEVICE_CODE ${GPU_DEVICE_CUDA})
+elseif("${GPU_DEVICE}" STREQUAL "XPU")
+ set(GPU_DEVICE_CODE ${GPU_DEVICE_XPU})
+elseif("${GPU_DEVICE}" STREQUAL "MPS")
+ set(GPU_DEVICE_CODE ${GPU_DEVICE_MPS})
+else()
+ message(SEND_ERROR "GPU_DEVICE '${GPU_DEVICE}' not recognised")
+endif()
+
+# Other GPU specific setup
include(CheckLanguage)
-if(ENABLE_CUDA)
+if("${GPU_DEVICE}" STREQUAL "CUDA")
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
enable_language(CUDA)
@@ -57,11 +82,17 @@ find_package(Torch REQUIRED)
add_library(${LIB_NAME} SHARED src/ctorch.cpp src/ftorch.F90
src/ftorch_test_utils.f90)
-if(UNIX)
- if(NOT APPLE) # only add definition for linux (not apple which is also unix)
- target_compile_definitions(${LIB_NAME} PRIVATE UNIX)
- endif()
+# Define compile definitions, including GPU devices
+set(COMPILE_DEFS "")
+if(UNIX AND NOT APPLE)
+ # only add UNIX definition for linux (not apple which is also unix)
+ set(COMPILE_DEFS UNIX)
endif()
+target_compile_definitions(
+ ${LIB_NAME}
+ PRIVATE ${COMPILE_DEFS} GPU_DEVICE=${GPU_DEVICE_CODE}
+ GPU_DEVICE_NONE=${GPU_DEVICE_NONE} GPU_DEVICE_CUDA=${GPU_DEVICE_CUDA}
+ GPU_DEVICE_XPU=${GPU_DEVICE_XPU} GPU_DEVICE_MPS=${GPU_DEVICE_MPS})
# Add an alias FTorch::ftorch for the library
add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME})
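As a rough illustration (not part of the changeset), the `GPU_DEVICE` selection performed by the CMake logic above could be sketched in Python; the numeric codes mirror PyTorch's `c10::DeviceType` values as noted in the CMake comment:

```python
# Illustrative sketch (not from the changeset) of the GPU_DEVICE -> device-code
# mapping performed by the CMake logic above. The numbers mirror PyTorch's
# c10::DeviceType enumeration.
GPU_DEVICE_CODES = {"NONE": 0, "CUDA": 1, "XPU": 12, "MPS": 13}


def gpu_device_code(gpu_device: str = "NONE") -> int:
    """Return the preprocessor code for a given -DGPU_DEVICE value."""
    key = "NONE" if gpu_device.upper() == "OFF" else gpu_device.upper()
    if key not in GPU_DEVICE_CODES:
        raise ValueError(f"GPU_DEVICE '{gpu_device}' not recognised")
    return GPU_DEVICE_CODES[key]


assert gpu_device_code("XPU") == 12
assert gpu_device_code() == 0
```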
diff --git a/README.md b/README.md
index 393f20ab..9c8f541f 100644
--- a/README.md
+++ b/README.md
@@ -165,7 +165,7 @@ To build and install the library:
| [`CMAKE_INSTALL_PREFIX`](https://cmake.org/cmake/help/latest/variable/CMAKE_INSTALL_PREFIX.html) | `</path/to/install/location>` | Location at which the library files should be installed. By default this is `/usr/local` |
| [`CMAKE_BUILD_TYPE`](https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html) | `Release` / `Debug` | Specifies build type. The default is `Debug`, use `Release` for production code|
| `CMAKE_BUILD_TESTS` | `TRUE` / `FALSE` | Specifies whether to compile FTorch's [test suite](https://cambridge-iccs.github.io/FTorch/page/testing.html) as part of the build. |
- | `ENABLE_CUDA` | `TRUE` / `FALSE` | Specifies whether to check for and enable CUDA<sup>3</sup> |
+ | `GPU_DEVICE` | `NONE` / `CUDA` / `XPU` / `MPS` | Specifies the target GPU architecture (if any) <sup>3</sup> |
<sup>1</sup> _On Windows this may need to be the full path to the compiler if CMake cannot locate it by default._
@@ -176,11 +176,9 @@ To build and install the library:
e.g. with `pip install torch`, then this should be `lib/python<3.xx>/site-packages/torch/`.
You can find the location of your torch install by importing torch from your Python environment (`import torch`) and running `print(torch.__file__)`_
- <sup>3</sup> _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU or GPU only version is installed, e.g.
- `pip install torch --index-url https://download.pytorch.org/whl/cpu`
- or
- `pip install torch --index-url https://download.pytorch.org/whl/cu118`
- (for CUDA 11.8). URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._
+ <sup>3</sup> _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU-only or GPU-enabled version is installed, e.g.
+ `pip install torch --index-url https://download.pytorch.org/whl/cpu`.
+ URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._
4. Make and install the library to the desired location with either:
```
@@ -219,12 +217,16 @@ These steps are described in more detail in the
## GPU Support
-To run on GPU requires a CUDA-compatible installation of LibTorch and two main
-adaptations to the code:
+To run on GPU requires an installation of LibTorch compatible with the GPU device
+you wish to target, and two main adaptations to the code:
-1. When saving a TorchScript model, ensure that it is on the GPU
+1. When saving a TorchScript model, ensure that it is on the appropriate GPU
+ device type. The `pt2ts.py` script has a command line argument
+ `--device_type`, which currently accepts four different device types: `cpu`
+ (default), `cuda`, `xpu`, or `mps`.
2. When using FTorch in Fortran, set the device for the input
- tensor(s) to `torch_kCUDA`, rather than `torch_kCPU`.
+ tensor(s) to the appropriate GPU device type, rather than `torch_kCPU`. There
+ are currently three options: `torch_kCUDA`, `torch_kXPU`, or `torch_kMPS`.
For detailed guidance about running on GPU, including instructions for using multiple
devices, please see the
diff --git a/conda/README.md b/conda/README.md
index 9090210a..02214191 100644
--- a/conda/README.md
+++ b/conda/README.md
@@ -38,7 +38,7 @@ cmake \
-DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX \
-DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)') \
-DCMAKE_BUILD_TYPE=Release \
- -DENABLE_CUDA=FALSE \
+ -DGPU_DEVICE=NONE \
..
cmake --build . --target install
```
@@ -65,7 +65,7 @@ cmake \
-DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX \
-DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)') \
-DCMAKE_BUILD_TYPE=Release \
- -DENABLE_CUDA=TRUE \
+ -DGPU_DEVICE=CUDA \
-DCUDA_TOOLKIT_ROOT_DIR=$CONDA_PREFIX/targets/x86_64-linux \
-Dnvtx3_dir=$CONDA_PREFIX/targets/x86_64-linux/include/nvtx3 \
..
diff --git a/examples/1_SimpleNet/CMakeLists.txt b/examples/1_SimpleNet/CMakeLists.txt
index 15e0b579..97c08698 100644
--- a/examples/1_SimpleNet/CMakeLists.txt
+++ b/examples/1_SimpleNet/CMakeLists.txt
@@ -33,9 +33,8 @@ if(CMAKE_BUILD_TESTS)
# pt2ts.py script
add_test(
NAME pt2ts
- COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py
- ${PROJECT_BINARY_DIR} # Command line argument: filepath for saving
- # the model
+ COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --filepath
+ ${PROJECT_BINARY_DIR}
WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
# 3. Check the model can be loaded from file and run in Python and that its
@@ -43,8 +42,7 @@ if(CMAKE_BUILD_TESTS)
add_test(
NAME simplenet_infer_python
COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/simplenet_infer_python.py
- ${PROJECT_BINARY_DIR} # Command line argument: filepath to find the
- # model
+ --filepath ${PROJECT_BINARY_DIR}
WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
# 4. Check the model can be loaded from file and run in Fortran and that its
diff --git a/examples/1_SimpleNet/pt2ts.py b/examples/1_SimpleNet/pt2ts.py
index 161ea2e7..35c22350 100644
--- a/examples/1_SimpleNet/pt2ts.py
+++ b/examples/1_SimpleNet/pt2ts.py
@@ -1,7 +1,6 @@
"""Load a PyTorch model and convert it to TorchScript."""
import os
-import sys
from typing import Optional
# FPTLIB-TODO
@@ -72,6 +71,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--device_type",
+ help="Device type to run the inference on",
+ type=str,
+ choices=["cpu", "cuda", "xpu", "mps"],
+ default="cpu",
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parsed_args = parser.parse_args()
+ device_type = parsed_args.device_type
+ filepath = parsed_args.filepath
+
# =====================================================
# Load model and prepare for saving
# =====================================================
@@ -97,12 +118,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# This example assumes one input of size (5)
trained_model_dummy_input = torch.ones(5)
- # FPTLIB-TODO
- # Uncomment the following lines to save for inference on GPU (rather than CPU):
- # device = torch.device('cuda')
- # trained_model = trained_model.to(device)
- # trained_model.eval()
- # trained_model_dummy_input = trained_model_dummy_input.to(device)
+ # Transfer the model and inputs to GPU device, if appropriate
+ if device_type != "cpu":
+ device = torch.device(device_type)
+ trained_model = trained_model.to(device)
+ trained_model.eval()
+ trained_model_dummy_input = trained_model_dummy_input.to(device)
# FPTLIB-TODO
# Run model for dummy inputs
@@ -117,7 +138,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# FPTLIB-TODO
# Set the name of the file you want to save the torchscript model to:
- saved_ts_filename = "saved_simplenet_model_cpu.pt"
+ saved_ts_filename = f"saved_simplenet_model_{device_type}.pt"
# A filepath may also be provided. To do this, pass the filepath as an argument to
# this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`.
@@ -141,9 +162,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# Check model saved OK
# =====================================================
- # Load torchscript and run model as a test
- # FPTLIB-TODO
- # Scale inputs as above and, if required, move inputs and mode to GPU
+ # Load torchscript and run model as a test, scaling inputs as above
trained_model_dummy_input = 2.0 * trained_model_dummy_input
trained_model_testing_outputs = trained_model(
trained_model_dummy_input,
@@ -169,7 +188,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
raise RuntimeError(model_error)
# Check that the model file is created
- filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
if not os.path.exists(os.path.join(filepath, saved_ts_filename)):
torchscript_file_error = (
f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} "
diff --git a/examples/1_SimpleNet/simplenet_infer_python.py b/examples/1_SimpleNet/simplenet_infer_python.py
index 338f83c7..2851cef9 100644
--- a/examples/1_SimpleNet/simplenet_infer_python.py
+++ b/examples/1_SimpleNet/simplenet_infer_python.py
@@ -1,7 +1,6 @@
"""Load saved SimpleNet to TorchScript and run inference example."""
import os
-import sys
import torch
@@ -49,7 +48,19 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
if __name__ == "__main__":
- filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parsed_args = parser.parse_args()
+ filepath = parsed_args.filepath
saved_model_file = os.path.join(filepath, "saved_simplenet_model_cpu.pt")
device_to_run = "cpu"
diff --git a/examples/2_ResNet18/CMakeLists.txt b/examples/2_ResNet18/CMakeLists.txt
index a3f8a7c1..56e8b63f 100644
--- a/examples/2_ResNet18/CMakeLists.txt
+++ b/examples/2_ResNet18/CMakeLists.txt
@@ -35,9 +35,8 @@ if(CMAKE_BUILD_TESTS)
# pt2ts.py script
add_test(
NAME pt2ts
- COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py
- ${PROJECT_BINARY_DIR} # Command line argument: filepath for saving
- # the model
+ COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --filepath
+ ${PROJECT_BINARY_DIR}
WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
# 3. Check the model can be loaded from file and run in Fortran and that its
diff --git a/examples/2_ResNet18/pt2ts.py b/examples/2_ResNet18/pt2ts.py
index d04cb5c4..2db59bfb 100644
--- a/examples/2_ResNet18/pt2ts.py
+++ b/examples/2_ResNet18/pt2ts.py
@@ -1,7 +1,6 @@
"""Load a PyTorch model and convert it to TorchScript."""
import os
-import sys
from typing import Optional
# FPTLIB-TODO
@@ -75,6 +74,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--device_type",
+ help="Device type to run the inference on",
+ type=str,
+ choices=["cpu", "cuda", "xpu", "mps"],
+ default="cpu",
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parsed_args = parser.parse_args()
+ device_type = parsed_args.device_type
+ filepath = parsed_args.filepath
+
# =====================================================
# Load model and prepare for saving
# =====================================================
@@ -103,12 +124,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# of resolution 244x244 in a batch size of 1.
trained_model_dummy_input = torch.ones(1, 3, 224, 224)
- # FPTLIB-TODO
- # Uncomment the following lines to save for inference on GPU (rather than CPU):
- # device = torch.device('cuda')
- # trained_model = trained_model.to(device)
- # trained_model.eval()
- # trained_model_dummy_input = trained_model_dummy_input.to(device)
+ # Transfer the model and inputs to GPU device, if appropriate
+ if device_type != "cpu":
+ device = torch.device(device_type)
+ trained_model = trained_model.to(device)
+ trained_model.eval()
+ trained_model_dummy_input = trained_model_dummy_input.to(device)
# FPTLIB-TODO
# Run model for dummy inputs
@@ -123,7 +144,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# FPTLIB-TODO
# Set the name of the file you want to save the torchscript model to:
- saved_ts_filename = "saved_resnet18_model_cpu.pt"
+ saved_ts_filename = f"saved_resnet18_model_{device_type}.pt"
# A filepath may also be provided. To do this, pass the filepath as an argument to
# this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`.
@@ -147,9 +168,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# Check model saved OK
# =====================================================
- # Load torchscript and run model as a test
- # FPTLIB-TODO
- # Scale inputs as above and, if required, move inputs and mode to GPU
+ # Load torchscript and run model as a test, scaling inputs as above
trained_model_dummy_input = 2.0 * trained_model_dummy_input
trained_model_testing_outputs = trained_model(
trained_model_dummy_input,
@@ -175,7 +194,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
raise RuntimeError(model_error)
# Check that the model file is created
- filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
if not os.path.exists(os.path.join(filepath, saved_ts_filename)):
torchscript_file_error = (
f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} "
diff --git a/examples/2_ResNet18/resnet_infer_python.py b/examples/2_ResNet18/resnet_infer_python.py
index c5590c7a..1834bb17 100644
--- a/examples/2_ResNet18/resnet_infer_python.py
+++ b/examples/2_ResNet18/resnet_infer_python.py
@@ -1,7 +1,6 @@
"""Load ResNet-18 saved to TorchScript and run inference with an example image."""
import os
-import sys
from math import isclose
import numpy as np
@@ -78,11 +77,22 @@ def check_results(output: torch.Tensor) -> None:
if __name__ == "__main__":
- filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parsed_args = parser.parse_args()
+ filepath = parsed_args.filepath
saved_model_file = os.path.join(filepath, "saved_resnet18_model_cpu.pt")
device_to_run = "cpu"
- # device_to_run = "cuda"
batch_size_to_run = 1
diff --git a/examples/3_MultiGPU/CMakeLists.txt b/examples/3_MultiGPU/CMakeLists.txt
index 81bcb4e8..dea6b5ba 100644
--- a/examples/3_MultiGPU/CMakeLists.txt
+++ b/examples/3_MultiGPU/CMakeLists.txt
@@ -18,11 +18,13 @@ find_package(FTorch)
message(STATUS "Building with Fortran PyTorch coupling")
include(CheckLanguage)
-check_language(CUDA)
-if(CMAKE_CUDA_COMPILER)
- enable_language(CUDA)
-else()
- message(ERROR "No CUDA support")
+if("${GPU_DEVICE}" STREQUAL "CUDA")
+ check_language(CUDA)
+ if(CMAKE_CUDA_COMPILER)
+ enable_language(CUDA)
+ else()
+ message(ERROR "No CUDA support")
+ endif()
endif()
# Fortran example
@@ -30,40 +32,116 @@ add_executable(multigpu_infer_fortran multigpu_infer_fortran.f90)
target_link_libraries(multigpu_infer_fortran PRIVATE FTorch::ftorch)
# Integration testing
-if (CMAKE_BUILD_TESTS)
+if(CMAKE_BUILD_TESTS)
include(CTest)
- # 1. Check the PyTorch model runs and its outputs meet expectations
- add_test(NAME simplenet COMMAND ${Python_EXECUTABLE}
- ${PROJECT_SOURCE_DIR}/simplenet.py)
-
- # 2. Check the model is saved to file in the expected location with the
- # pt2ts.py script
- add_test(
- NAME pt2ts
- COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py
- ${PROJECT_BINARY_DIR} # Command line argument: filepath for saving
- # the model
- WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
-
- # 3. Check the model can be loaded from file and run in Python and that its
- # outputs meet expectations
- add_test(
- NAME multigpu_infer_python
- COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/multigpu_infer_python.py
- ${PROJECT_BINARY_DIR} # Command line argument: filepath to find the
- # model
- WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
-
- # 4. Check the model can be loaded from file and run in Fortran and that its
- # outputs meet expectations
- add_test(
- NAME multigpu_infer_fortran
- COMMAND
- multigpu_infer_fortran ${PROJECT_BINARY_DIR}/saved_multigpu_model_cuda.pt
- # Command line argument: model file
- WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
- set_tests_properties(
- multigpu_infer_fortran PROPERTIES PASS_REGULAR_EXPRESSION
- "MultiGPU example ran successfully")
+ if("${GPU_DEVICE}" STREQUAL "CUDA")
+ # 1a. Check the PyTorch model runs on a CUDA device and its outputs meet
+ # expectations
+ add_test(NAME simplenet
+ COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/simplenet.py
+ --device_type cuda)
+
+ # 2a. Check the model is saved to file in the expected location with the
+ # pt2ts.py script
+ add_test(
+ NAME pt2ts
+ COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --device_type
+ cuda --filepath ${PROJECT_BINARY_DIR}
+ WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+
+ # 3a. Check the model can be loaded from file and run on two CUDA devices in
+ # Python and that its outputs meet expectations
+ add_test(
+ NAME multigpu_infer_python
+ COMMAND ${Python_EXECUTABLE}
+ ${PROJECT_SOURCE_DIR}/multigpu_infer_python.py --device_type cuda
+ --filepath ${PROJECT_BINARY_DIR}
+ WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+
+ # 4a. Check the model can be loaded from file and run on two CUDA devices in
+ # Fortran and that its outputs meet expectations
+ add_test(
+ NAME multigpu_infer_fortran
+ COMMAND multigpu_infer_fortran cuda
+ ${PROJECT_BINARY_DIR}/saved_multigpu_model_cuda.pt
+ # Command line arguments for device type and model file
+ WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+ set_tests_properties(
+ multigpu_infer_fortran PROPERTIES PASS_REGULAR_EXPRESSION
+ "MultiGPU example ran successfully")
+ endif()
+
+ if("${GPU_DEVICE}" STREQUAL "XPU")
+ # 1b. Check the PyTorch model runs on an XPU device and its outputs meet
+ # expectations
+ add_test(NAME simplenet
+ COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/simplenet.py
+ --device_type xpu)
+
+ # 2b. Check the model is saved to file in the expected location with the
+ # pt2ts.py script
+ add_test(
+ NAME pt2ts
+ COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --device_type
+ xpu --filepath ${PROJECT_BINARY_DIR}
+ WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+
+ # 3b. Check the model can be loaded from file and run on two XPU devices in
+ # Python and that its outputs meet expectations
+ add_test(
+ NAME multigpu_infer_python
+ COMMAND ${Python_EXECUTABLE}
+ ${PROJECT_SOURCE_DIR}/multigpu_infer_python.py --device_type xpu
+ --filepath ${PROJECT_BINARY_DIR}
+ WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+
+ # 4b. Check the model can be loaded from file and run on two XPU devices in
+ # Fortran and that its outputs meet expectations
+ add_test(
+ NAME multigpu_infer_fortran
+ COMMAND multigpu_infer_fortran xpu
+ ${PROJECT_BINARY_DIR}/saved_multigpu_model_xpu.pt
+ # Command line arguments for device type and model file
+ WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+ set_tests_properties(
+ multigpu_infer_fortran PROPERTIES PASS_REGULAR_EXPRESSION
+ "MultiGPU example ran successfully")
+ endif()
+
+ if("${GPU_DEVICE}" STREQUAL "MPS")
+ # 1c. Check the PyTorch model runs on an MPS device and its outputs meet
+ # expectations
+ add_test(NAME simplenet
+ COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/simplenet.py
+ --device_type mps)
+ # 2c. Check the model is saved to file in the expected location with the
+ # pt2ts.py script
+ add_test(
+ NAME pt2ts
+ COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --device_type
+ mps --filepath ${PROJECT_BINARY_DIR}
+ WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+
+ # 3c. Check the model can be loaded from file and run on one MPS device in
+ # Python and that its outputs meet expectations
+ add_test(
+ NAME multigpu_infer_python
+ COMMAND ${Python_EXECUTABLE}
+ ${PROJECT_SOURCE_DIR}/multigpu_infer_python.py --device_type mps
+ --filepath ${PROJECT_BINARY_DIR}
+ WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+
+ # 4c. Check the model can be loaded from file and run on one MPS device in
+ # Fortran and that its outputs meet expectations
+ add_test(
+ NAME multigpu_infer_fortran
+ COMMAND multigpu_infer_fortran mps
+ ${PROJECT_BINARY_DIR}/saved_multigpu_model_mps.pt
+ # Command line arguments for device type and model file
+ WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
+ set_tests_properties(
+ multigpu_infer_fortran PROPERTIES PASS_REGULAR_EXPRESSION
+ "MultiGPU example ran successfully")
+ endif()
endif()
diff --git a/examples/3_MultiGPU/README.md b/examples/3_MultiGPU/README.md
index 9c46d83c..7b0f2543 100644
--- a/examples/3_MultiGPU/README.md
+++ b/examples/3_MultiGPU/README.md
@@ -20,8 +20,8 @@ the TorchScript model in inference mode.
To run this example requires:
- CMake
-- Two (or more) GPU devices that support CUDA and have it installed.
-- FTorch (installed with CUDA enabled as described in main package)
+- Two (or more) CUDA or XPU GPU devices (or a single MPS device).
+- FTorch (installed with `GPU_DEVICE` set as described in the main package)
- Python 3
## Running
@@ -36,30 +36,33 @@ pip install -r requirements.txt
You can check that everything is working by running `simplenet.py`:
```
-python3 simplenet.py
+python3 simplenet.py --device_type <device_type>
```
+where `<device_type>` is `cuda`/`xpu`/`mps` as appropriate for your device.
+
As before, this defines the network and runs it with an input tensor
[0.0, 1.0, 2.0, 3.0, 4.0]. The difference is that the code will make use of the
-default CUDA device (index 0) to produce the result:
+default GPU device (index 0) to produce the result:
```
SimpleNet forward pass on CUDA device 0
tensor([[0, 2, 4, 6, 8]])
```
+for CUDA, and similarly for other device types.
To save the `SimpleNet` model to TorchScript run the modified version of the
`pt2ts.py` tool:
```
-python3 pt2ts.py
+python3 pt2ts.py --device_type <device_type>
```
-which will generate `saved_multigpu_model_cuda.pt` - the TorchScript instance
-of the network. The only difference with the earlier example is that the model
-is built to be run using CUDA rather than on CPU.
+which will generate `saved_multigpu_model_<device_type>.pt` - the TorchScript
+instance of the network. The only difference with the earlier example is that
+the model is built to be run on GPU devices rather than on CPU.
You can check that everything is working by running the
`multigpu_infer_python.py` script. It's set up such that it loops over two GPU
devices. Run with:
```
-python3 multigpu_infer_python.py
+python3 multigpu_infer_python.py --device_type <device_type>
```
This reads the model in from the TorchScript file and runs it with a different input
tensor on each GPU device: [0.0, 1.0, 2.0, 3.0, 4.0], plus the device index in each
@@ -68,6 +71,7 @@ entry. The result should be:
Output on device 0: tensor([[0., 2., 4., 6., 8.]])
Output on device 1: tensor([[ 2., 4., 6., 8., 10.]])
```
+Note that MPS will only use device 0.
At this point we no longer require Python, so can deactivate the virtual environment:
```
@@ -89,9 +93,9 @@ cmake --build .
and should match the compiler that was used to locally build FTorch.)
To run the compiled code calling the saved `SimpleNet` TorchScript from
-Fortran, run the executable with an argument of the saved model file:
+Fortran, run the executable with the device type and the saved model file as arguments:
```
-./multigpu_infer_fortran ../saved_multigpu_model_cuda.pt
+./multigpu_infer_fortran <device_type> ../saved_multigpu_model_<device_type>.pt
```
This runs the model with the same inputs as described above and should produce (some
@@ -102,6 +106,7 @@ input on device 1: [ 1.0, 2.0, 3.0, 4.0, 5.0]
output on device 0: [ 0.0, 2.0, 4.0, 6.0, 8.0]
output on device 1: [ 2.0, 4.0, 6.0, 8.0, 10.0]
```
+Again, note that MPS will only use device 0.
Alternatively, we can use `make`, instead of CMake, copying the Makefile over from the
first example:
diff --git a/examples/3_MultiGPU/multigpu_infer_fortran.f90 b/examples/3_MultiGPU/multigpu_infer_fortran.f90
index 297844e4..d5d49980 100644
--- a/examples/3_MultiGPU/multigpu_infer_fortran.f90
+++ b/examples/3_MultiGPU/multigpu_infer_fortran.f90
@@ -4,10 +4,14 @@ program inference
use, intrinsic :: iso_fortran_env, only : sp => real32
! Import our library for interfacing with PyTorch
- use ftorch, only : torch_model, torch_tensor, torch_kCUDA, torch_kCPU, &
+ use ftorch, only : torch_model, torch_tensor, &
+ torch_kCPU, torch_kCUDA, torch_kXPU, torch_kMPS, &
torch_tensor_from_array, torch_model_load, torch_model_forward, &
torch_delete
+ ! Import our tools module for testing utils
+ use ftorch_test_utils, only : assert_allclose
+
implicit none
! Set precision for reals
@@ -19,6 +23,7 @@ program inference
! Set up Fortran data structures
real(wp), dimension(5), target :: in_data
real(wp), dimension(5), target :: out_data
+ real(wp), dimension(5) :: expected
integer, parameter :: tensor_layout(1) = [1]
! Set up Torch data structures
@@ -27,15 +32,30 @@ program inference
type(torch_tensor), dimension(1) :: out_tensors
! Variables for multi-GPU setup
- integer, parameter :: num_devices = 2
- integer :: device_index, i
+ integer :: num_devices = 2
+ integer :: device_type, device_index, i
+
+ ! Flag for testing
+ logical :: test_pass
- ! Get TorchScript model file as a command line argument
+ ! Get device type as first command line argument and TorchScript model file as second command
+ ! line argument
num_args = command_argument_count()
allocate(args(num_args))
do ix = 1, num_args
call get_command_argument(ix,args(ix))
end do
+ if (trim(args(1)) == "cuda") then
+ device_type = torch_kCUDA
+ else if (trim(args(1)) == "xpu") then
+ device_type = torch_kXPU
+ else if (trim(args(1)) == "mps") then
+ device_type = torch_kMPS
+ num_devices = 1
+ else
+ write (*,*) "Error :: invalid device type", trim(args(1))
+ stop 999
+ end if
do device_index = 0, num_devices-1
@@ -46,8 +66,8 @@ program inference
! Create Torch input tensor from the above array and assign it to the first (and only)
! element in the array of input tensors.
- ! We use the torch_kCUDA device type with the given device index
- call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, torch_kCUDA, &
+ ! We use the specified GPU device type with the given device index
+ call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, device_type, &
device_index=device_index)
! Create Torch output tensor from the above array.
@@ -57,7 +77,7 @@ program inference
! Load ML model. Ensure that the same device type and device index are used
! as for the input data.
- call torch_model_load(model, args(1), torch_kCUDA, device_index=device_index)
+ call torch_model_load(model, args(2), device_type, device_index=device_index)
! Infer
call torch_model_forward(model, in_tensors, out_tensors)
@@ -66,11 +86,19 @@ program inference
write (6, 200) device_index, out_data(:)
200 format("output on device ", i1,": [", 4(f5.1,","), f5.1,"]")
+ ! Check output tensor matches expected value
+ expected = [(2 * (device_index + i), i = 0, 4)]
+ test_pass = assert_allclose(out_data, expected, test_name="MultiGPU")
+
! Cleanup
call torch_delete(model)
call torch_delete(in_tensors)
call torch_delete(out_tensors)
+ if (.not. test_pass) then
+ stop 999
+ end if
+
end do
write (*,*) "MultiGPU example ran successfully"
diff --git a/examples/3_MultiGPU/multigpu_infer_python.py b/examples/3_MultiGPU/multigpu_infer_python.py
index e6075874..cc4c1c8e 100644
--- a/examples/3_MultiGPU/multigpu_infer_python.py
+++ b/examples/3_MultiGPU/multigpu_infer_python.py
@@ -1,5 +1,7 @@
"""Load saved SimpleNet to TorchScript and run inference example."""
+import os
+
import torch
@@ -28,39 +30,76 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
# Load saved TorchScript model
model = torch.jit.load(saved_model)
# Inference
- output = model.forward(input_tensor)
-
- elif device.startswith("cuda"):
- # Add the device index to each tensor to make them differ
- input_tensor += int(device.split(":")[-1] or 0)
-
- # All previously saved modules, no matter their device, are first
- # loaded onto CPU, and then are moved to the devices they were saved
- # from, so we don't need to manually transfer the model to the GPU
- model = torch.jit.load(saved_model)
- model = model.to(device)
- input_tensor_gpu = input_tensor.to(torch.device(device))
- output_gpu = model.forward(input_tensor_gpu)
- output = output_gpu.to(torch.device("cpu"))
-
+ return model.forward(input_tensor)
+
+ if device.startswith("cuda"):
+ pass
+ elif device.startswith("xpu"):
+ # XPU devices need to be initialised before use
+ torch.xpu.init()
+ elif device.startswith("mps"):
+ pass
else:
device_error = f"Device '{device}' not recognised."
raise ValueError(device_error)
+ # Add the device index to each tensor to make them differ
+ input_tensor += int(device.split(":")[-1] or 0)
+
+ # All previously saved modules, no matter their device, are first
+ # loaded onto CPU, and then are moved to the devices they were saved
+ # from.
+ # Since we are loading one saved model to multiple devices we explicitly
+ # transfer using `.to(device)` to ensure the model is on the correct index.
+ model = torch.jit.load(saved_model)
+ model = model.to(device)
+ input_tensor_gpu = input_tensor.to(torch.device(device))
+ output_gpu = model.forward(input_tensor_gpu)
+ output = output_gpu.to(torch.device("cpu"))
+
return output
if __name__ == "__main__":
- saved_model_file = "saved_multigpu_model_cuda.pt"
-
- num_devices = 2
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parser.add_argument(
+ "--device_type",
+ help="Device type to run the inference on",
+ type=str,
+ choices=["cpu", "cuda", "xpu", "mps"],
+ default="cuda",
+ )
+ parsed_args = parser.parse_args()
+ filepath = parsed_args.filepath
+ device_type = parsed_args.device_type
+ saved_model_file = os.path.join(filepath, f"saved_multigpu_model_{device_type}.pt")
+
+ batch_size_to_run = 1
+
+ # Use 2 devices unless MPS for which there is only one
+ num_devices = 1 if device_type == "mps" else 2
for device_index in range(num_devices):
- device_to_run = f"cuda:{device_index}"
-
- batch_size_to_run = 1
+ device_to_run = f"{device_type}:{device_index}"
with torch.no_grad():
result = deploy(saved_model_file, device_to_run, batch_size_to_run)
print(f"Output on device {device_to_run}: {result}")
+
+ expected = torch.Tensor([2 * (i + device_index) for i in range(5)])
+ if not torch.allclose(result, expected):
+ result_error = (
+ f"result:\n{result}\ndoes not match expected value:\n{expected}"
+ )
+ raise ValueError(result_error)
diff --git a/examples/3_MultiGPU/pt2ts.py b/examples/3_MultiGPU/pt2ts.py
index 8242f691..823e69dd 100644
--- a/examples/3_MultiGPU/pt2ts.py
+++ b/examples/3_MultiGPU/pt2ts.py
@@ -1,7 +1,6 @@
"""Load a PyTorch model and convert it to TorchScript."""
import os
-import sys
from typing import Optional
# FPTLIB-TODO
@@ -72,6 +71,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parser.add_argument(
+ "--device_type",
+ help="Device type to run the inference on",
+ type=str,
+ choices=["cpu", "cuda", "xpu", "mps"],
+ default="cuda",
+ )
+ parsed_args = parser.parse_args()
+ filepath = parsed_args.filepath
+ device_type = parsed_args.device_type
+
# =====================================================
# Load model and prepare for saving
# =====================================================
@@ -97,12 +118,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# This example assumes one input of size (5)
trained_model_dummy_input = torch.ones(5)
- # FPTLIB-TODO
- # Uncomment the following lines to save for inference on GPU (rather than CPU):
- device = torch.device("cuda")
- trained_model = trained_model.to(device)
- trained_model.eval()
- trained_model_dummy_input = trained_model_dummy_input.to(device)
+ # Transfer the model and inputs to GPU device, if appropriate
+ if device_type != "cpu":
+ device = torch.device(device_type)
+ trained_model = trained_model.to(device)
+ trained_model.eval()
+ trained_model_dummy_input = trained_model_dummy_input.to(device)
# FPTLIB-TODO
# Run model for dummy inputs
@@ -117,7 +138,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# FPTLIB-TODO
# Set the name of the file you want to save the torchscript model to:
- saved_ts_filename = "saved_multigpu_model_cuda.pt"
+ saved_ts_filename = f"saved_multigpu_model_{device_type}.pt"
# A filepath may also be provided. To do this, pass the filepath as an argument to
# this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`.
@@ -141,11 +162,8 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# Check model saved OK
# =====================================================
- # Load torchscript and run model as a test
- # FPTLIB-TODO
- # Scale inputs as above and, if required, move inputs and mode to GPU
+ # Load torchscript and run model as a test, scaling inputs as above
trained_model_dummy_input = 2.0 * trained_model_dummy_input
- trained_model_dummy_input = trained_model_dummy_input.to("cuda")
trained_model_testing_outputs = trained_model(
trained_model_dummy_input,
)
@@ -170,7 +188,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
raise RuntimeError(model_error)
# Check that the model file is created
- filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
if not os.path.exists(os.path.join(filepath, saved_ts_filename)):
torchscript_file_error = (
f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} "
diff --git a/examples/3_MultiGPU/simplenet.py b/examples/3_MultiGPU/simplenet.py
index b7c2e6b2..46f46550 100644
--- a/examples/3_MultiGPU/simplenet.py
+++ b/examples/3_MultiGPU/simplenet.py
@@ -42,13 +42,38 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor:
if __name__ == "__main__":
- model = SimpleNet().to(torch.device("cuda"))
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--device_type",
+ help="Device type to run the inference on",
+ type=str,
+ choices=["cpu", "cuda", "xpu", "mps"],
+ default="cuda",
+ )
+ parsed_args = parser.parse_args()
+ device_type = parsed_args.device_type
+
+ model = SimpleNet().to(torch.device(device_type))
model.eval()
input_tensor = torch.Tensor([0.0, 1.0, 2.0, 3.0, 4.0])
- input_tensor_gpu = input_tensor.to(torch.device("cuda"))
+ input_tensor_gpu = input_tensor.to(torch.device(device_type))
- print(f"SimpleNet forward pass on CUDA device {input_tensor_gpu.get_device()}")
+ print(
+ f"SimpleNet forward pass on {device_type.capitalize()} device"
+ f" {input_tensor_gpu.get_device()}"
+ )
with torch.no_grad():
- output = model(input_tensor_gpu)
- print(output)
+ output_tensor = model(input_tensor_gpu).to("cpu")
+
+ print(output_tensor)
+ if not torch.allclose(output_tensor, 2 * input_tensor):
+ result_error = (
+ f"result:\n{output_tensor}\ndoes not match expected value:\n"
+ f"{2 * input_tensor}"
+ )
+ raise ValueError(result_error)
diff --git a/examples/4_MultiIO/CMakeLists.txt b/examples/4_MultiIO/CMakeLists.txt
index d8bef873..03b96dff 100644
--- a/examples/4_MultiIO/CMakeLists.txt
+++ b/examples/4_MultiIO/CMakeLists.txt
@@ -33,9 +33,8 @@ if(CMAKE_BUILD_TESTS)
# pt2ts.py script
add_test(
NAME pt2ts
- COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py
- ${PROJECT_BINARY_DIR} # Command line argument: filepath for saving
- # the model
+ COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --filepath
+ ${PROJECT_BINARY_DIR}
WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
# 3. Check the model can be loaded from file and run in Python and that its
@@ -44,7 +43,7 @@ if(CMAKE_BUILD_TESTS)
NAME multiionet_infer_python
COMMAND
${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/multiionet_infer_python.py
- ${PROJECT_BINARY_DIR} # Command line argument: filepath to find the model
+ --filepath ${PROJECT_BINARY_DIR}
WORKING_DIRECTORY ${PROJECT_BINARY_DIR})
# 4. Check the model can be loaded from file and run in Fortran and that its
diff --git a/examples/4_MultiIO/multiionet_infer_python.py b/examples/4_MultiIO/multiionet_infer_python.py
index d5c9cf2d..a615289c 100644
--- a/examples/4_MultiIO/multiionet_infer_python.py
+++ b/examples/4_MultiIO/multiionet_infer_python.py
@@ -1,7 +1,6 @@
"""Load saved MultIONet to TorchScript and run inference example."""
import os
-import sys
import torch
@@ -52,7 +51,19 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
if __name__ == "__main__":
- filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parsed_args = parser.parse_args()
+ filepath = parsed_args.filepath
saved_model_file = os.path.join(filepath, "saved_multiio_model_cpu.pt")
device_to_run = "cpu"
diff --git a/examples/4_MultiIO/pt2ts.py b/examples/4_MultiIO/pt2ts.py
index e7cd67e6..31066406 100644
--- a/examples/4_MultiIO/pt2ts.py
+++ b/examples/4_MultiIO/pt2ts.py
@@ -1,7 +1,6 @@
"""Load a PyTorch model and convert it to TorchScript."""
import os
-import sys
from typing import Optional
# FPTLIB-TODO
@@ -72,6 +71,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--device_type",
+ help="Device type to run the inference on",
+ type=str,
+ choices=["cpu", "cuda", "xpu", "mps"],
+ default="cpu",
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parsed_args = parser.parse_args()
+ device_type = parsed_args.device_type
+ filepath = parsed_args.filepath
+
# =====================================================
# Load model and prepare for saving
# =====================================================
@@ -97,12 +118,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# This example assumes one input of size (5)
trained_model_dummy_inputs = (torch.ones(4), torch.ones(4))
- # FPTLIB-TODO
- # Uncomment the following lines to save for inference on GPU (rather than CPU):
- # device = torch.device('cuda')
- # trained_model = trained_model.to(device)
- # trained_model.eval()
- # trained_model_dummy_input = trained_model_dummy_input.to(device)
+ # Transfer the model and inputs to GPU device, if appropriate
+ if device_type != "cpu":
+ device = torch.device(device_type)
+ trained_model = trained_model.to(device)
+ trained_model.eval()
+ trained_model_dummy_inputs = tuple(
+     inp.to(device) for inp in trained_model_dummy_inputs
+ )
# FPTLIB-TODO
# Run model for dummy inputs
@@ -117,7 +138,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# FPTLIB-TODO
# Set the name of the file you want to save the torchscript model to:
- saved_ts_filename = "saved_multiio_model_cpu.pt"
+ saved_ts_filename = f"saved_multiio_model_{device_type}.pt"
# A filepath may also be provided. To do this, pass the filepath as an argument to
# this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`.
@@ -141,9 +162,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# Check model saved OK
# =====================================================
- # Load torchscript and run model as a test
- # FPTLIB-TODO
- # Scale inputs as above and, if required, move inputs and mode to GPU
+ # Load torchscript and run model as a test, scaling inputs as above
trained_model_dummy_inputs = (
2.0 * trained_model_dummy_inputs[0],
3.0 * trained_model_dummy_inputs[1],
@@ -172,7 +191,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
raise RuntimeError(model_error)
# Check that the model file is created
- filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
if not os.path.exists(os.path.join(filepath, saved_ts_filename)):
torchscript_file_error = (
f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} "
diff --git a/examples/5_Looping/pt2ts.py b/examples/5_Looping/pt2ts.py
index 114bb325..962903fd 100644
--- a/examples/5_Looping/pt2ts.py
+++ b/examples/5_Looping/pt2ts.py
@@ -1,7 +1,6 @@
"""Load a PyTorch model and convert it to TorchScript."""
import os
-import sys
from typing import Optional
# FPTLIB-TODO
@@ -68,6 +67,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--device_type",
+ help="Device type to run the inference on",
+ type=str,
+ choices=["cpu", "cuda", "xpu", "mps"],
+ default="cpu",
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parsed_args = parser.parse_args()
+ device_type = parsed_args.device_type
+ filepath = parsed_args.filepath
+
# =====================================================
# Load model and prepare for saving
# =====================================================
@@ -93,12 +114,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# This example assumes two inputs of size (512x40) and (512x1)
trained_model_dummy_input = torch.ones((5), dtype=torch.float32)
- # FPTLIB-TODO
- # Uncomment the following lines to save for inference on GPU (rather than CPU):
- # device = torch.device('cuda')
- # trained_model = trained_model.to(device)
- # trained_model.eval()
- # trained_model_dummy_input = trained_model_dummy_input.to(device)
+ # Transfer the model and inputs to GPU device, if appropriate
+ if device_type != "cpu":
+ device = torch.device(device_type)
+ trained_model = trained_model.to(device)
+ trained_model.eval()
+ trained_model_dummy_input = trained_model_dummy_input.to(device)
# FPTLIB-TODO
# Run model for dummy inputs
@@ -113,7 +134,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# FPTLIB-TODO
# Set the name of the file you want to save the torchscript model to:
- saved_ts_filename = "saved_simplenet_model.pt"
+ saved_ts_filename = f"saved_simplenet_{device_type}.pt"
# A filepath may also be provided. To do this, pass the filepath as an argument to
# this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`.
@@ -135,9 +156,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# Check model saved OK
# =====================================================
- # Load torchscript and run model as a test
- # FPTLIB-TODO
- # Scale inputs as above and, if required, move inputs and mode to GPU
+ # Load torchscript and run model as a test, scaling inputs as above
trained_model_dummy_input = 2.0 * trained_model_dummy_input
trained_model_testing_outputs = trained_model(
trained_model_dummy_input,
@@ -163,7 +182,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
raise RuntimeError(model_error)
# Check that the model file is created
- filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
if not os.path.exists(os.path.join(filepath, saved_ts_filename)):
torchscript_file_error = (
f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} "
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index c69812f8..b651a594 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1,7 +1,7 @@
if(CMAKE_BUILD_TESTS)
add_subdirectory(1_SimpleNet)
add_subdirectory(2_ResNet18)
- if(ENABLE_CUDA)
+ if(NOT "${GPU_DEVICE}" STREQUAL "NONE")
add_subdirectory(3_MultiGPU)
endif()
add_subdirectory(4_MultiIO)
diff --git a/pages/cmake.md b/pages/cmake.md
index 90dc11c6..5a4ee565 100644
--- a/pages/cmake.md
+++ b/pages/cmake.md
@@ -62,24 +62,22 @@ The following CMake flags are available and can be passed as arguments through `
| [`CMAKE_INSTALL_PREFIX`](https://cmake.org/cmake/help/latest/variable/CMAKE_INSTALL_PREFIX.html) | `</path/to/install/location>` | Location at which the library files should be installed. By default this is `/usr/local` |
| [`CMAKE_BUILD_TYPE`](https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html) | `Release` / `Debug` | Specifies build type. The default is `Debug`, use `Release` for production code|
| `CMAKE_BUILD_TESTS` | `TRUE` / `FALSE` | Specifies whether to compile FTorch's [test suite](testing.html) as part of the build. |
-| `ENABLE_CUDA` | `TRUE` / `FALSE` | Specifies whether to check for and enable CUDA<sup>3</sup> |
+| `GPU_DEVICE` | `NONE` / `CUDA` / `XPU` / `MPS` | Specifies the target GPU architecture (if any) <sup>3</sup> |
> <sup>1</sup> _On Windows this may need to be the full path to the compiler if CMake
> cannot locate it by default._
->
+>
> <sup>2</sup> _The path to the Torch installation needs to allow CMake to locate the relevant Torch CMake files.
> If Torch has been [installed as LibTorch](https://pytorch.org/cppdocs/installing.html)
> then this should be the absolute path to the unzipped LibTorch distribution.
> If Torch has been installed as PyTorch in a Python [venv (virtual environment)](https://docs.python.org/3/library/venv.html),
> e.g. with `pip install torch`, then this should be `lib/python<3.xx>/site-packages/torch/`._
->
-> <sup>3</sup> _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU or GPU only version is installed, e.g.
-> `pip install torch --index-url https://download.pytorch.org/whl/cpu`
-> or
-> `pip install torch --index-url https://download.pytorch.org/whl/cu118`
-> (for CUDA 11.8). URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._
+>
+> <sup>3</sup> _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU-only or GPU-enabled version is installed, e.g.
+> `pip install torch --index-url https://download.pytorch.org/whl/cpu`.
+> URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._
For example, to build on a unix system using the gnu compilers and install to `$HOME/FTorchbin/`
we would need to run:
diff --git a/pages/developer.md b/pages/developer.md
index 13a77c30..b6b64bfd 100644
--- a/pages/developer.md
+++ b/pages/developer.md
@@ -85,6 +85,37 @@ files `src/ctorch.h` and `src/ctorch.cpp`, we refer to the Torch
[C++ API documentation](https://pytorch.org/cppdocs/api/library_root.html)
pages on the PyTorch website for details.
+### GPU device handling
+
+GPU device-specific code is handled in FTorch using codes defined in the root
+`CMakeLists.txt` file:
+```cmake
+set(GPU_DEVICE_NONE 0)
+set(GPU_DEVICE_CUDA 1)
+set(GPU_DEVICE_XPU 12)
+set(GPU_DEVICE_MPS 13)
+```
+These device codes are chosen to be consistent with the numbering used in
+PyTorch (see
+https://github.com/pytorch/pytorch/blob/main/c10/core/DeviceType.h). When a user
+specifies `-DGPU_DEVICE=XPU` (for example) in the FTorch CMake build, this is
+mapped to the appropriate device code (in this case 12). The chosen device code,
+along with all of the other defined codes, is passed to the C++ compiler in the following step:
+```cmake
+target_compile_definitions(
+ ${LIB_NAME}
+ PRIVATE ${COMPILE_DEFS} GPU_DEVICE=${GPU_DEVICE_CODE}
+ GPU_DEVICE_NONE=${GPU_DEVICE_NONE} GPU_DEVICE_CUDA=${GPU_DEVICE_CUDA}
+ GPU_DEVICE_XPU=${GPU_DEVICE_XPU} GPU_DEVICE_MPS=${GPU_DEVICE_MPS})
+```
+The chosen device code will enable the appropriate C pre-processor conditions in
+the C++ source so that the code relevant to that device type becomes active.
+
+To illustrate why this approach is needed: if the device codes and pre-processor
+conditions were removed, then building against a CPU-only or CUDA LibTorch
+installation would produce compile errors from the use of the `torch::xpu`
+module in `src/ctorch.cpp`.
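A rough Python analogue of this guard-before-use pattern (illustrative only, not code from the diff) is to test for the backend before touching it, since `torch.xpu` may be missing or unusable in a CPU-only or CUDA-only build:

```python
import torch

# Illustrative analogue of the C preprocessor guard: only reference the XPU
# backend if this PyTorch build actually provides it.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu", 0)
else:
    device = torch.device("cpu")

x = torch.ones(5, device=device)
print(device, x)
```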
### git hook
diff --git a/pages/gpu.md b/pages/gpu.md
index ae84bcfe..7266f4c5 100644
--- a/pages/gpu.md
+++ b/pages/gpu.md
@@ -4,27 +4,41 @@ title: GPU Support
## GPU Support
-In order to run a model on GPU, two main changes are required:
+In order to run a model on GPU, three main changes are required:
-1) When saving your TorchScript model, ensure that it is on the GPU.
+1) When building FTorch, specify the target GPU architecture using the
+`GPU_DEVICE` argument. That is, set
+```sh
+cmake .. -DGPU_DEVICE=<device>
+```
+as appropriate. The default setting is equivalent to
+```sh
+cmake .. -DGPU_DEVICE=NONE
+```
+i.e., CPU-only.
+
+2) When saving your TorchScript model, ensure that it is on the GPU.
For example, when using
[`pt2ts.py`](https://github.com/Cambridge-ICCS/FTorch/blob/main/utils/pt2ts.py),
-this can be done by uncommenting the following lines:
-
+this can be done by passing the `--device_type <device>` argument. This
+sets the `device_type` variable, which has the effect of transferring the model
+and any input arrays to the specified GPU device in the following lines:
```python
-device_type = torch.device("cuda")
-trained_model = trained_model.to(device_type)
-trained_model.eval()
-trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device_type)
-trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device_type)
+if device_type != "cpu":
+ trained_model = trained_model.to(device_type)
+ trained_model.eval()
+ trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device_type)
+ trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device_type)
```
-> Note: _This code also moves the dummy input tensors to the GPU.
-> Whilst not necessary for saving the model, but the tensors must also be on the GPU
-> to test that the models runs._
+> Note: _This code moves the dummy input tensors to the GPU, as well as the
+> model.
+> Whilst this is not necessary for saving the model, the tensors must be on
+> the same GPU device to test that the model runs._
-2) When calling `torch_tensor_from_array` in Fortran, the device type for the input
- tensor(s) should be set to `torch_kCUDA`, rather than `torch_kCPU`.
+3) When calling `torch_tensor_from_array` in Fortran, the device type for the
+ input tensor(s) should be set to the relevant device type (`torch_kCUDA`,
+ `torch_kXPU`, or `torch_kMPS`) rather than `torch_kCPU`.
This ensures that the inputs are on the same device type as the model.
> Note: _You do **not** need to change the device type for the output tensors as we
@@ -32,25 +46,28 @@ trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device_type)
### Multi-GPU runs
-In the case of having multiple GPU devices, as well as setting `torch_kCUDA` as the
-device type for any input tensors and models, you should also specify their device index
-as the GPU device to be targeted. This argument is optional and will default to device
-index 0 if unset.
-
-For example, the following code snippet sets up a Torch tensor with GPU device index 2:
+If you have multiple GPU devices then, as well as setting the device type
+for any input tensors and models, you should also specify the device index of
+the GPU device to be targeted. This argument is optional and will default to
+device index 0 if unset.
+For example, the following code snippet sets up a Torch tensor with CUDA GPU
+device index 2:
```fortran
device_index = 2
call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, &
torch_kCUDA, device_index=device_index)
```
-
-Whereas the following code snippet sets up a Torch tensor with (default) device index 0:
-
+Whereas the following code snippet sets up a Torch tensor with (default) CUDA
+device index 0:
```fortran
call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, &
torch_kCUDA)
```
+Similarly for the XPU device type.
+
+> Note: The MPS device type does not currently support multiple devices, so the
+> default device index should always be used.
See the
[MultiGPU example](https://github.com/Cambridge-ICCS/FTorch/tree/main/examples/3_MultiGPU)
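For comparison (an illustrative sketch, not part of the diff; it assumes at least three CUDA devices are visible), the equivalent device-index selection in the Python API looks like:

```python
import torch

# Illustrative Python equivalent of the Fortran snippets above: place a tensor on
# CUDA device index 2, or fall back to the default index 0 when no index is given.
in_data = torch.ones(5)
explicit = in_data.to(torch.device("cuda", 2))  # assumes >= 3 CUDA devices
default = in_data.to(torch.device("cuda"))      # defaults to device index 0
print(explicit.device, default.device)
```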
diff --git a/src/ctorch.cpp b/src/ctorch.cpp
index b5d5fd5f..cdc330dd 100644
--- a/src/ctorch.cpp
+++ b/src/ctorch.cpp
@@ -8,6 +8,10 @@
#include "ctorch.h"
+#ifndef GPU_DEVICE
+#define GPU_DEVICE GPU_DEVICE_NONE
+#endif
+
// =============================================================================
// --- Constant expressions
// =============================================================================
@@ -78,6 +82,7 @@ const auto get_libtorch_device(torch_device_t device_type, int device_index) {
std::cerr << "[WARNING]: device index unused for CPU-only runs" << std::endl;
}
return torch::Device(torch::kCPU);
+#if GPU_DEVICE == GPU_DEVICE_CUDA
case torch_kCUDA:
if (device_index == -1) {
std::cerr << "[WARNING]: device index unset, defaulting to 0" << std::endl;
@@ -90,6 +95,26 @@ const auto get_libtorch_device(torch_device_t device_type, int device_index) {
<< " for device count " << torch::cuda::device_count() << std::endl;
exit(EXIT_FAILURE);
}
+#endif
+ case torch_kMPS:
+ if (device_index != -1 && device_index != 0) {
+ std::cerr << "[WARNING]: Only one device is available for MPS runs" << std::endl;
+ }
+ return torch::Device(torch::kMPS);
+#if GPU_DEVICE == GPU_DEVICE_XPU
+ case torch_kXPU:
+ if (device_index == -1) {
+ std::cerr << "[WARNING]: device index unset, defaulting to 0" << std::endl;
+ device_index = 0;
+ }
+ if (device_index >= 0 && device_index < torch::xpu::device_count()) {
+ return torch::Device(torch::kXPU, device_index);
+ } else {
+ std::cerr << "[ERROR]: invalid device index " << device_index
+ << " for XPU device count " << torch::xpu::device_count() << std::endl;
+ exit(EXIT_FAILURE);
+ }
+#endif
default:
std::cerr << "[WARNING]: unknown device type, setting to torch_kCPU" << std::endl;
return torch::Device(torch::kCPU);
@@ -103,6 +128,10 @@ const torch_device_t get_ftorch_device(torch::DeviceType device_type) {
return torch_kCPU;
case torch::kCUDA:
return torch_kCUDA;
+ case torch::kXPU:
+ return torch_kXPU;
+ case torch::kMPS:
+ return torch_kMPS;
default:
std::cerr << "[ERROR]: device type " << device_type << " not implemented in FTorch"
<< std::endl;
diff --git a/src/ctorch.h b/src/ctorch.h
index 1123d6eb..8b029247 100644
--- a/src/ctorch.h
+++ b/src/ctorch.h
@@ -38,7 +38,13 @@ typedef enum {
} torch_data_t;
// Device types
-typedef enum { torch_kCPU, torch_kCUDA } torch_device_t;
+// NOTE: Defined in main CMakeLists and passed via preprocessor
+typedef enum {
+ torch_kCPU = GPU_DEVICE_NONE,
+ torch_kCUDA = GPU_DEVICE_CUDA,
+ torch_kXPU = GPU_DEVICE_XPU,
+ torch_kMPS = GPU_DEVICE_MPS,
+} torch_device_t;
// =============================================================================
// --- Functions for constructing tensors
diff --git a/src/ftorch.F90 b/src/ftorch.F90
index c3612a13..ba4939b1 100644
--- a/src/ftorch.F90
+++ b/src/ftorch.F90
@@ -52,9 +52,12 @@ module ftorch
!| Enumerator for Torch devices
! From c_torch.h (torch_device_t)
+ ! NOTE: Defined in main CMakeLists and passed via preprocessor
enum, bind(c)
- enumerator :: torch_kCPU = 0
- enumerator :: torch_kCUDA = 1
+ enumerator :: torch_kCPU = GPU_DEVICE_NONE
+ enumerator :: torch_kCUDA = GPU_DEVICE_CUDA
+ enumerator :: torch_kXPU = GPU_DEVICE_XPU
+ enumerator :: torch_kMPS = GPU_DEVICE_MPS
end enum
! ============================================================================
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index 0f57e0d8..ad99ba6e 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -71,9 +71,12 @@ module ftorch
!| Enumerator for Torch devices
! From c_torch.h (torch_device_t)
+ ! NOTE: Defined in main CMakeLists and passed via preprocessor
enum, bind(c)
- enumerator :: torch_kCPU = 0
- enumerator :: torch_kCUDA = 1
+ enumerator :: torch_kCPU = GPU_DEVICE_NONE
+ enumerator :: torch_kCUDA = GPU_DEVICE_CUDA
+ enumerator :: torch_kXPU = GPU_DEVICE_XPU
+ enumerator :: torch_kMPS = GPU_DEVICE_MPS
end enum
! ============================================================================
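
For concreteness, a sketch of what the enumerator blocks in `ftorch.F90` and `ftorch.fypp` resolve to once the preprocessor substitutes the `GPU_DEVICE_*` values passed from the top-level CMakeLists (which follow PyTorch's DeviceType numbering):

```fortran
! Expanded device enumerator, assuming the GPU_DEVICE_* values
! (NONE=0, CUDA=1, XPU=12, MPS=13) set in the top-level CMakeLists.
enum, bind(c)
  enumerator :: torch_kCPU = 0
  enumerator :: torch_kCUDA = 1
  enumerator :: torch_kXPU = 12
  enumerator :: torch_kMPS = 13
end enum
```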
diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt
index fc5705b0..fab05fd0 100644
--- a/test/unit/CMakeLists.txt
+++ b/test/unit/CMakeLists.txt
@@ -15,7 +15,7 @@ add_pfunit_ctest(test_tensor_interrogation
add_pfunit_ctest(test_operator_overloads
TEST_SOURCES test_tensor_operator_overloads.pf LINK_LIBRARIES FTorch::ftorch)
-if(ENABLE_CUDA)
+if("${GPU_DEVICE}" STREQUAL "CUDA")
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
enable_language(CUDA)
diff --git a/utils/README.md b/utils/README.md
index fffc9551..540b0dd5 100644
--- a/utils/README.md
+++ b/utils/README.md
@@ -4,13 +4,15 @@ This directory contains useful utilities for users of the library.
## `pt2ts.py`
-This is a python script that can take a PyTorch model and convert it to Torchscript.
+This is a Python script that can take a PyTorch model and convert it to
+TorchScript.
It provides the user with the option to [jit.script](https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script) or [jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace) for both CPU and GPU.
Dependencies:
- PyTorch
### Usage
+
1. Create and activate a virtual environment with PyTorch and any dependencies for your model.
2. Place the `pt2ts.py` script in the same folder as your model files.
3. Import your model into `pt2ts.py` and amend options as necessary (search for `FPTLIB-TODO`).
@@ -18,3 +20,14 @@ Dependencies:
The Torchscript model will be saved locally in the same location from which the `pt2ts.py`
script is being run.
+
+#### Command line arguments
+
+The `pt2ts.py` script is set up with the `argparse` Python module such that it
+accepts two command line arguments:
+* `--filepath`, which allows you to specify the path to the directory in which
+  the TorchScript model should be saved.
+* `--device_type`, which allows you to specify the CPU or GPU device type with
+  which to save the model. (To load and make use of a model on a particular
+  architecture, it is important that the model was targeted for that
+  architecture when it was saved.)
diff --git a/utils/pt2ts.py b/utils/pt2ts.py
index 8cea5fb3..dad047c5 100644
--- a/utils/pt2ts.py
+++ b/utils/pt2ts.py
@@ -1,7 +1,6 @@
"""Load a PyTorch model and convert it to TorchScript."""
import os
-import sys
from typing import Optional
# FPTLIB-TODO
@@ -71,6 +70,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--device_type",
+ help="Device type to run the inference on",
+ type=str,
+ choices=["cpu", "cuda", "xpu", "mps"],
+ default="cpu",
+ )
+ parser.add_argument(
+ "--filepath",
+ help="Path to the file containing the PyTorch model",
+ type=str,
+ default=os.path.dirname(__file__),
+ )
+ parsed_args = parser.parse_args()
+ device_type = parsed_args.device_type
+ filepath = parsed_args.filepath
+
# =====================================================
# Load model and prepare for saving
# =====================================================
@@ -97,13 +118,13 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
trained_model_dummy_input_1 = torch.ones((512, 40), dtype=torch.float64)
trained_model_dummy_input_2 = torch.ones((512, 1), dtype=torch.float64)
- # FPTLIB-TODO
- # Uncomment the following lines to save for inference on GPU (rather than CPU):
- # device = torch.device('cuda')
- # trained_model = trained_model.to(device)
- # trained_model.eval()
- # trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device)
- # trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device)
+    # Transfer the model and inputs to the GPU device, if appropriate
+ if device_type != "cpu":
+ device = torch.device(device_type)
+ trained_model = trained_model.to(device)
+ trained_model.eval()
+ trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device)
+ trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device)
# FPTLIB-TODO
# Run model for dummy inputs
@@ -119,7 +140,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# FPTLIB-TODO
# Set the name of the file you want to save the torchscript model to:
- saved_ts_filename = "saved_model.pt"
+ saved_ts_filename = f"saved_model_{device_type}.pt"
# A filepath may also be provided. To do this, pass the filepath as an argument to
# this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`.
@@ -141,9 +162,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
# Check model saved OK
# =====================================================
- # Load torchscript and run model as a test
- # FPTLIB-TODO
- # Scale inputs as above and, if required, move inputs and mode to GPU
+ # Load torchscript and run model as a test, scaling inputs as above
trained_model_dummy_input_1 = 2.0 * trained_model_dummy_input_1
trained_model_dummy_input_2 = 2.0 * trained_model_dummy_input_2
trained_model_testing_outputs = trained_model(
@@ -172,7 +191,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
raise RuntimeError(model_error)
# Check that the model file is created
- filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
if not os.path.exists(os.path.join(filepath, saved_ts_filename)):
torchscript_file_error = (
f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} "