diff --git a/.github/workflows/static_analysis.yml b/.github/workflows/static_analysis.yml index 8c25dcaf..2ff7eecc 100644 --- a/.github/workflows/static_analysis.yml +++ b/.github/workflows/static_analysis.yml @@ -98,8 +98,9 @@ jobs: if: always() run: | cd ${{ github.workspace }} - . ftorch_venv/bin/activate # Uses .clang-tidy config file if present - fortitude check src/ + . ftorch_venv/bin/activate + fortitude check --ignore=E001,T041 src/ftorch.F90 + fortitude check src/ftorch_test_utils.f90 # Apply C++ and C linter and formatter, clang # Configurable using the .clang-format and .clang-tidy config files if present @@ -113,7 +114,7 @@ jobs: style: 'file' tidy-checks: '' # Use the compile_commands.json from CMake to locate headers - database: ${{ github.workspace }}/src/build + database: ${{ github.workspace }}/build # only 'update' a single comment in a pull request thread. thread-comments: ${{ github.event_name == 'pull_request' && 'update' }} - name: Fail fast?! diff --git a/CMakeLists.txt b/CMakeLists.txt index ead82aac..0f79fb2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,8 +22,33 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) +# Set GPU device type using consistent numbering as in PyTorch +# https://github.com/pytorch/pytorch/blob/main/c10/core/DeviceType.h +# Set in a single location here and passed as preprocessor flag for use +# throughout source files. +set(GPU_DEVICE_NONE 0) +set(GPU_DEVICE_CUDA 1) +set(GPU_DEVICE_XPU 12) +set(GPU_DEVICE_MPS 13) +option(GPU_DEVICE "Set the GPU device (NONE [default], CUDA, XPU, or MPS)" NONE) +if("${GPU_DEVICE}" STREQUAL "OFF") + set(GPU_DEVICE NONE) +endif() +if("${GPU_DEVICE}" STREQUAL "NONE") + set(GPU_DEVICE_CODE ${GPU_DEVICE_NONE}) +elseif("${GPU_DEVICE}" STREQUAL "CUDA") + set(GPU_DEVICE_CODE ${GPU_DEVICE_CUDA}) +elseif("${GPU_DEVICE}" STREQUAL "XPU") + set(GPU_DEVICE_CODE ${GPU_DEVICE_XPU}) +elseif("${GPU_DEVICE}" STREQUAL "MPS") + set(GPU_DEVICE_CODE ${GPU_DEVICE_MPS}) +else() + message(SEND_ERROR "GPU_DEVICE '${GPU_DEVICE}' not recognised") +endif() + +# Other GPU specific setup include(CheckLanguage) -if(ENABLE_CUDA) +if("${GPU_DEVICE}" STREQUAL "CUDA") check_language(CUDA) if(CMAKE_CUDA_COMPILER) enable_language(CUDA) @@ -57,11 +82,17 @@ find_package(Torch REQUIRED) add_library(${LIB_NAME} SHARED src/ctorch.cpp src/ftorch.F90 src/ftorch_test_utils.f90) -if(UNIX) - if(NOT APPLE) # only add definition for linux (not apple which is also unix) - target_compile_definitions(${LIB_NAME} PRIVATE UNIX) - endif() +# Define compile definitions, including GPU devices +set(COMPILE_DEFS "") +if(UNIX AND NOT APPLE) + # only add UNIX definition for linux (not apple which is also unix) + set(COMPILE_DEFS UNIX) endif() +target_compile_definitions( + ${LIB_NAME} + PRIVATE ${COMPILE_DEFS} GPU_DEVICE=${GPU_DEVICE_CODE} + GPU_DEVICE_NONE=${GPU_DEVICE_NONE} GPU_DEVICE_CUDA=${GPU_DEVICE_CUDA} + GPU_DEVICE_XPU=${GPU_DEVICE_XPU} GPU_DEVICE_MPS=${GPU_DEVICE_MPS}) # Add an alias FTorch::ftorch for the library add_library(${PROJECT_NAME}::${LIB_NAME} ALIAS ${LIB_NAME}) diff --git a/README.md b/README.md index 393f20ab..9c8f541f 100644 --- a/README.md +++ b/README.md @@ -165,7 +165,7 @@ To build and install the library: | [`CMAKE_INSTALL_PREFIX`](https://cmake.org/cmake/help/latest/variable/CMAKE_INSTALL_PREFIX.html) | `` | Location at which the library files should be installed. 
By default this is `/usr/local` | | [`CMAKE_BUILD_TYPE`](https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html) | `Release` / `Debug` | Specifies build type. The default is `Debug`, use `Release` for production code| | `CMAKE_BUILD_TESTS` | `TRUE` / `FALSE` | Specifies whether to compile FTorch's [test suite](https://cambridge-iccs.github.io/FTorch/page/testing.html) as part of the build. | - | `ENABLE_CUDA` | `TRUE` / `FALSE` | Specifies whether to check for and enable CUDA3 | + | `GPU_DEVICE` | `NONE` / `CUDA` / `XPU` / `MPS` | Specifies the target GPU architecture (if any) 3 | 1 _On Windows this may need to be the full path to the compiler if CMake cannot locate it by default._ @@ -176,11 +176,9 @@ To build and install the library: e.g. with `pip install torch`, then this should be `lib/python<3.xx>/site-packages/torch/`. You can find the location of your torch install by importing torch from your Python environment (`import torch`) and running `print(torch.__file__)`_ - 3 _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU or GPU only version is installed, e.g. - `pip install torch --index-url https://download.pytorch.org/whl/cpu` - or - `pip install torch --index-url https://download.pytorch.org/whl/cu118` - (for CUDA 11.8). URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._ + 3 _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU-only or GPU-enabled version is installed, e.g. + `pip install torch --index-url https://download.pytorch.org/whl/cpu`. + URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._ 4. Make and install the library to the desired location with either: ``` @@ -219,12 +217,16 @@ These steps are described in more detail in the ## GPU Support -To run on GPU requires a CUDA-compatible installation of LibTorch and two main -adaptations to the code: +To run on GPU requires an installation of LibTorch compatible for the GPU device +you wish to target and two main adaptations to the code: -1. When saving a TorchScript model, ensure that it is on the GPU +1. When saving a TorchScript model, ensure that it is on the appropriate GPU + device type. The `pt2ts.py` script has a command line argument + `--device_type`, which currently accepts four different device types: `cpu` + (default), `cuda`, `xpu`, or `mps`. 2. When using FTorch in Fortran, set the device for the input - tensor(s) to `torch_kCUDA`, rather than `torch_kCPU`. + tensor(s) to the appropriate GPU device type, rather than `torch_kCPU`. There + are currently three options: `torch_kCUDA`, `torch_kXPU`, or `torch_kMPS`. For detailed guidance about running on GPU, including instructions for using multiple devices, please see the diff --git a/conda/README.md b/conda/README.md index 9090210a..02214191 100644 --- a/conda/README.md +++ b/conda/README.md @@ -38,7 +38,7 @@ cmake \ -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX \ -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)') \ -DCMAKE_BUILD_TYPE=Release \ - -DENABLE_CUDA=FALSE \ + -DGPU_DEVICE=NONE \ .. cmake --build . 
--target install ``` @@ -65,7 +65,7 @@ cmake \ -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX \ -DCMAKE_PREFIX_PATH=$(python -c 'import torch;print(torch.utils.cmake_prefix_path)') \ -DCMAKE_BUILD_TYPE=Release \ - -DENABLE_CUDA=TRUE \ + -DGPU_DEVICE=CUDA \ -DCUDA_TOOLKIT_ROOT_DIR=$CONDA_PREFIX/targets/x86_64-linux \ -Dnvtx3_dir=$CONDA_PREFIX/targets/x86_64-linux/include/nvtx3 \ .. diff --git a/examples/1_SimpleNet/CMakeLists.txt b/examples/1_SimpleNet/CMakeLists.txt index 15e0b579..97c08698 100644 --- a/examples/1_SimpleNet/CMakeLists.txt +++ b/examples/1_SimpleNet/CMakeLists.txt @@ -33,9 +33,8 @@ if(CMAKE_BUILD_TESTS) # pt2ts.py script add_test( NAME pt2ts - COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py - ${PROJECT_BINARY_DIR} # Command line argument: filepath for saving - # the model + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --filepath + ${PROJECT_BINARY_DIR} WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) # 3. Check the model can be loaded from file and run in Python and that its @@ -43,8 +42,7 @@ if(CMAKE_BUILD_TESTS) add_test( NAME simplenet_infer_python COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/simplenet_infer_python.py - ${PROJECT_BINARY_DIR} # Command line argument: filepath to find the - # model + --filepath ${PROJECT_BINARY_DIR} WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) # 4. Check the model can be loaded from file and run in Fortran and that its diff --git a/examples/1_SimpleNet/pt2ts.py b/examples/1_SimpleNet/pt2ts.py index 161ea2e7..35c22350 100644 --- a/examples/1_SimpleNet/pt2ts.py +++ b/examples/1_SimpleNet/pt2ts.py @@ -1,7 +1,6 @@ """Load a PyTorch model and convert it to TorchScript.""" import os -import sys from typing import Optional # FPTLIB-TODO @@ -72,6 +71,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--device_type", + help="Device type to run the inference on", + type=str, + choices=["cpu", "cuda", "xpu", "mps"], + default="cpu", + ) + parser.add_argument( + "--filepath", + help="Path to the file containing the PyTorch model", + type=str, + default=os.path.dirname(__file__), + ) + parsed_args = parser.parse_args() + device_type = parsed_args.device_type + filepath = parsed_args.filepath + # ===================================================== # Load model and prepare for saving # ===================================================== @@ -97,12 +118,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # This example assumes one input of size (5) trained_model_dummy_input = torch.ones(5) - # FPTLIB-TODO - # Uncomment the following lines to save for inference on GPU (rather than CPU): - # device = torch.device('cuda') - # trained_model = trained_model.to(device) - # trained_model.eval() - # trained_model_dummy_input = trained_model_dummy_input.to(device) + # Transfer the model and inputs to GPU device, if appropriate + if device_type != "cpu": + device = torch.device(device_type) + trained_model = trained_model.to(device) + trained_model.eval() + trained_model_dummy_input = trained_model_dummy_input.to(device) # FPTLIB-TODO # Run model for dummy inputs @@ -117,7 +138,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # FPTLIB-TODO # Set the name of the file you want to save the torchscript model to: - saved_ts_filename = "saved_simplenet_model_cpu.pt" + 
saved_ts_filename = f"saved_simplenet_model_{device_type}.pt" # A filepath may also be provided. To do this, pass the filepath as an argument to # this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`. @@ -141,9 +162,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # Check model saved OK # ===================================================== - # Load torchscript and run model as a test - # FPTLIB-TODO - # Scale inputs as above and, if required, move inputs and mode to GPU + # Load torchscript and run model as a test, scaling inputs as above trained_model_dummy_input = 2.0 * trained_model_dummy_input trained_model_testing_outputs = trained_model( trained_model_dummy_input, @@ -169,7 +188,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod raise RuntimeError(model_error) # Check that the model file is created - filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1] if not os.path.exists(os.path.join(filepath, saved_ts_filename)): torchscript_file_error = ( f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} " diff --git a/examples/1_SimpleNet/simplenet_infer_python.py b/examples/1_SimpleNet/simplenet_infer_python.py index 338f83c7..2851cef9 100644 --- a/examples/1_SimpleNet/simplenet_infer_python.py +++ b/examples/1_SimpleNet/simplenet_infer_python.py @@ -1,7 +1,6 @@ """Load saved SimpleNet to TorchScript and run inference example.""" import os -import sys import torch @@ -49,7 +48,19 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor: if __name__ == "__main__": - filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1] + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--filepath", + help="Path to the file containing the PyTorch model", + type=str, + default=os.path.dirname(__file__), + ) + parsed_args = parser.parse_args() + filepath = parsed_args.filepath saved_model_file = os.path.join(filepath, "saved_simplenet_model_cpu.pt") device_to_run = "cpu" diff --git a/examples/2_ResNet18/CMakeLists.txt b/examples/2_ResNet18/CMakeLists.txt index a3f8a7c1..56e8b63f 100644 --- a/examples/2_ResNet18/CMakeLists.txt +++ b/examples/2_ResNet18/CMakeLists.txt @@ -35,9 +35,8 @@ if(CMAKE_BUILD_TESTS) # pt2ts.py script add_test( NAME pt2ts - COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py - ${PROJECT_BINARY_DIR} # Command line argument: filepath for saving - # the model + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --filepath + ${PROJECT_BINARY_DIR} WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) # 3. 
Check the model can be loaded from file and run in Fortran and that its diff --git a/examples/2_ResNet18/pt2ts.py b/examples/2_ResNet18/pt2ts.py index d04cb5c4..2db59bfb 100644 --- a/examples/2_ResNet18/pt2ts.py +++ b/examples/2_ResNet18/pt2ts.py @@ -1,7 +1,6 @@ """Load a PyTorch model and convert it to TorchScript.""" import os -import sys from typing import Optional # FPTLIB-TODO @@ -75,6 +74,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--device_type", + help="Device type to run the inference on", + type=str, + choices=["cpu", "cuda", "xpu", "mps"], + default="cpu", + ) + parser.add_argument( + "--filepath", + help="Path to the file containing the PyTorch model", + type=str, + default=os.path.dirname(__file__), + ) + parsed_args = parser.parse_args() + device_type = parsed_args.device_type + filepath = parsed_args.filepath + # ===================================================== # Load model and prepare for saving # ===================================================== @@ -103,12 +124,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # of resolution 244x244 in a batch size of 1. trained_model_dummy_input = torch.ones(1, 3, 224, 224) - # FPTLIB-TODO - # Uncomment the following lines to save for inference on GPU (rather than CPU): - # device = torch.device('cuda') - # trained_model = trained_model.to(device) - # trained_model.eval() - # trained_model_dummy_input = trained_model_dummy_input.to(device) + # Transfer the model and inputs to GPU device, if appropriate + if device_type != "cpu": + device = torch.device(device_type) + trained_model = trained_model.to(device) + trained_model.eval() + trained_model_dummy_input = trained_model_dummy_input.to(device) # FPTLIB-TODO # Run model for dummy inputs @@ -123,7 +144,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # FPTLIB-TODO # Set the name of the file you want to save the torchscript model to: - saved_ts_filename = "saved_resnet18_model_cpu.pt" + saved_ts_filename = f"saved_resnet18_model_{device_type}.pt" # A filepath may also be provided. To do this, pass the filepath as an argument to # this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`. 
@@ -147,9 +168,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # Check model saved OK # ===================================================== - # Load torchscript and run model as a test - # FPTLIB-TODO - # Scale inputs as above and, if required, move inputs and mode to GPU + # Load torchscript and run model as a test, scaling inputs as above trained_model_dummy_input = 2.0 * trained_model_dummy_input trained_model_testing_outputs = trained_model( trained_model_dummy_input, @@ -175,7 +194,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod raise RuntimeError(model_error) # Check that the model file is created - filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1] if not os.path.exists(os.path.join(filepath, saved_ts_filename)): torchscript_file_error = ( f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} " diff --git a/examples/2_ResNet18/resnet_infer_python.py b/examples/2_ResNet18/resnet_infer_python.py index c5590c7a..1834bb17 100644 --- a/examples/2_ResNet18/resnet_infer_python.py +++ b/examples/2_ResNet18/resnet_infer_python.py @@ -1,7 +1,6 @@ """Load ResNet-18 saved to TorchScript and run inference with an example image.""" import os -import sys from math import isclose import numpy as np @@ -78,11 +77,22 @@ def check_results(output: torch.Tensor) -> None: if __name__ == "__main__": - filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1] + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--filepath", + help="Path to the file containing the PyTorch model", + type=str, + default=os.path.dirname(__file__), + ) + parsed_args = parser.parse_args() + filepath = parsed_args.filepath saved_model_file = os.path.join(filepath, "saved_resnet18_model_cpu.pt") device_to_run = "cpu" - # device_to_run = "cuda" batch_size_to_run = 1 diff --git a/examples/3_MultiGPU/CMakeLists.txt b/examples/3_MultiGPU/CMakeLists.txt index 81bcb4e8..dea6b5ba 100644 --- a/examples/3_MultiGPU/CMakeLists.txt +++ b/examples/3_MultiGPU/CMakeLists.txt @@ -18,11 +18,13 @@ find_package(FTorch) message(STATUS "Building with Fortran PyTorch coupling") include(CheckLanguage) -check_language(CUDA) -if(CMAKE_CUDA_COMPILER) - enable_language(CUDA) -else() - message(ERROR "No CUDA support") +if("${GPU_DEVICE}" STREQUAL "CUDA") + check_language(CUDA) + if(CMAKE_CUDA_COMPILER) + enable_language(CUDA) + else() + message(ERROR "No CUDA support") + endif() endif() # Fortran example @@ -30,40 +32,116 @@ add_executable(multigpu_infer_fortran multigpu_infer_fortran.f90) target_link_libraries(multigpu_infer_fortran PRIVATE FTorch::ftorch) # Integration testing -if (CMAKE_BUILD_TESTS) +if(CMAKE_BUILD_TESTS) include(CTest) - # 1. Check the PyTorch model runs and its outputs meet expectations - add_test(NAME simplenet COMMAND ${Python_EXECUTABLE} - ${PROJECT_SOURCE_DIR}/simplenet.py) - - # 2. Check the model is saved to file in the expected location with the - # pt2ts.py script - add_test( - NAME pt2ts - COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py - ${PROJECT_BINARY_DIR} # Command line argument: filepath for saving - # the model - WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) - - # 3. 
Check the model can be loaded from file and run in Python and that its - # outputs meet expectations - add_test( - NAME multigpu_infer_python - COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/multigpu_infer_python.py - ${PROJECT_BINARY_DIR} # Command line argument: filepath to find the - # model - WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) - - # 4. Check the model can be loaded from file and run in Fortran and that its - # outputs meet expectations - add_test( - NAME multigpu_infer_fortran - COMMAND - multigpu_infer_fortran ${PROJECT_BINARY_DIR}/saved_multigpu_model_cuda.pt - # Command line argument: model file - WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) - set_tests_properties( - multigpu_infer_fortran PROPERTIES PASS_REGULAR_EXPRESSION - "MultiGPU example ran successfully") + if("${GPU_DEVICE}" STREQUAL "CUDA") + # 1a. Check the PyTorch model runs on a CUDA device and its outputs meet + # expectations + add_test(NAME simplenet + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/simplenet.py + --device_type cuda) + + # 2a. Check the model is saved to file in the expected location with the + # pt2ts.py script + add_test( + NAME pt2ts + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --device_type + cuda --filepath ${PROJECT_BINARY_DIR} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + + # 3a. Check the model can be loaded from file and run on two CUDA devices in + # Python and that its outputs meet expectations + add_test( + NAME multigpu_infer_python + COMMAND ${Python_EXECUTABLE} + ${PROJECT_SOURCE_DIR}/multigpu_infer_python.py --device_type cuda + --filepath ${PROJECT_BINARY_DIR} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + + # 4a. Check the model can be loaded from file and run on two CUDA devices in + # Fortran and that its outputs meet expectations + add_test( + NAME multigpu_infer_fortran + COMMAND multigpu_infer_fortran cuda + ${PROJECT_BINARY_DIR}/saved_multigpu_model_cuda.pt + # Command line arguments for device type and model file + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + set_tests_properties( + multigpu_infer_fortran PROPERTIES PASS_REGULAR_EXPRESSION + "MultiGPU example ran successfully") + endif() + + if("${GPU_DEVICE}" STREQUAL "XPU") + # 1b. Check the PyTorch model runs on an XPU device and its outputs meet + # expectations + add_test(NAME simplenet + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/simplenet.py + --device_type xpu) + + # 2b. Check the model is saved to file in the expected location with the + # pt2ts.py script + add_test( + NAME pt2ts + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --device_type + xpu --filepath ${PROJECT_BINARY_DIR} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + + # 3b. Check the model can be loaded from file and run on two XPU devices in + # Python and that its outputs meet expectations + add_test( + NAME multigpu_infer_python + COMMAND ${Python_EXECUTABLE} + ${PROJECT_SOURCE_DIR}/multigpu_infer_python.py --device_type xpu + --filepath ${PROJECT_BINARY_DIR} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + + # 4b. 
Check the model can be loaded from file and run on two XPU devices in + # Fortran and that its outputs meet expectations + add_test( + NAME multigpu_infer_fortran + COMMAND multigpu_infer_fortran xpu + ${PROJECT_BINARY_DIR}/saved_multigpu_model_xpu.pt + # Command line arguments for device type and model file + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + set_tests_properties( + multigpu_infer_fortran PROPERTIES PASS_REGULAR_EXPRESSION + "MultiGPU example ran successfully") + endif() + + if("${GPU_DEVICE}" STREQUAL "MPS") + # 1c. Check the PyTorch model runs on an MPS device and its outputs meet + # expectations + add_test(NAME simplenet + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/simplenet.py + --device_type mps) + # 2c. Check the model is saved to file in the expected location with the + # pt2ts.py script + add_test( + NAME pt2ts + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --device_type + mps --filepath ${PROJECT_BINARY_DIR} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + + # 3c. Check the model can be loaded from file and run on one MPS device in + # Python and that its outputs meet expectations + add_test( + NAME multigpu_infer_python + COMMAND ${Python_EXECUTABLE} + ${PROJECT_SOURCE_DIR}/multigpu_infer_python.py --device_type mps + --filepath ${PROJECT_BINARY_DIR} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + + # 4c. Check the model can be loaded from file and run on one MPS device in + # Fortran and that its outputs meet expectations + add_test( + NAME multigpu_infer_fortran + COMMAND multigpu_infer_fortran mps + ${PROJECT_BINARY_DIR}/saved_multigpu_model_mps.pt + # Command line arguments for device type and model file + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) + set_tests_properties( + multigpu_infer_fortran PROPERTIES PASS_REGULAR_EXPRESSION + "MultiGPU example ran successfully") + endif() endif() diff --git a/examples/3_MultiGPU/README.md b/examples/3_MultiGPU/README.md index 9c46d83c..7b0f2543 100644 --- a/examples/3_MultiGPU/README.md +++ b/examples/3_MultiGPU/README.md @@ -20,8 +20,8 @@ the TorchScript model in inference mode. To run this example requires: - CMake -- Two (or more) GPU devices that support CUDA and have it installed. -- FTorch (installed with CUDA enabled as described in main package) +- Two (or more) CUDA or XPU GPU devices (or a single MPS device). +- FTorch (installed with a GPU_DEVICE enabled as described in main package) - Python 3 ## Running @@ -36,30 +36,33 @@ pip install -r requirements.txt You can check that everything is working by running `simplenet.py`: ``` -python3 simplenet.py +python3 simplenet.py --device_type ``` +where `` is `cuda`/`xpu`/`mps` as appropriate for your device. + As before, this defines the network and runs it with an input tensor [0.0, 1.0, 2.0, 3.0, 4.0]. The difference is that the code will make use of the -default CUDA device (index 0) to produce the result: +default GPU device (index 0) to produce the result: ``` SimpleNet forward pass on CUDA device 0 tensor([[0, 2, 4, 6, 8]]) ``` +for CUDA, and similarly for other device types. To save the `SimpleNet` model to TorchScript run the modified version of the `pt2ts.py` tool: ``` -python3 pt2ts.py +python3 pt2ts.py --device_type ``` -which will generate `saved_multigpu_model_cuda.pt` - the TorchScript instance -of the network. The only difference with the earlier example is that the model -is built to be run using CUDA rather than on CPU. +which will generate `saved_multigpu_model_.pt` - the TorchScript +instance of the network. 
The only difference with the earlier example is that +the model is built to be run on GPU devices rather than on CPU. You can check that everything is working by running the `multigpu_infer_python.py` script. It's set up such that it loops over two GPU devices. Run with: ``` -python3 multigpu_infer_python.py +python3 multigpu_infer_python.py --device_type ``` This reads the model in from the TorchScript file and runs it with a different input tensor on each GPU device: [0.0, 1.0, 2.0, 3.0, 4.0], plus the device index in each @@ -68,6 +71,7 @@ entry. The result should be: Output on device 0: tensor([[0., 2., 4., 6., 8.]]) Output on device 1: tensor([[ 2., 4., 6., 8., 10.]]) ``` +Note that MPS will only use device 0. At this point we no longer require Python, so can deactivate the virtual environment: ``` @@ -89,9 +93,9 @@ cmake --build . and should match the compiler that was used to locally build FTorch.) To run the compiled code calling the saved `SimpleNet` TorchScript from -Fortran, run the executable with an argument of the saved model file: +Fortran, run the executable with arguments of device type and the saved model file: ``` -./multigpu_infer_fortran ../saved_multigpu_model_cuda.pt +./multigpu_infer_fortran ../saved_multigpu_model_.pt ``` This runs the model with the same inputs as described above and should produce (some @@ -102,6 +106,7 @@ input on device 1: [ 1.0, 2.0, 3.0, 4.0, 5.0] output on device 0: [ 0.0, 2.0, 4.0, 6.0, 8.0] output on device 1: [ 2.0, 4.0, 6.0, 8.0, 10.0] ``` +Again, note that MPS will only use device 0. Alternatively, we can use `make`, instead of CMake, copying the Makefile over from the first example: diff --git a/examples/3_MultiGPU/multigpu_infer_fortran.f90 b/examples/3_MultiGPU/multigpu_infer_fortran.f90 index 297844e4..d5d49980 100644 --- a/examples/3_MultiGPU/multigpu_infer_fortran.f90 +++ b/examples/3_MultiGPU/multigpu_infer_fortran.f90 @@ -4,10 +4,14 @@ program inference use, intrinsic :: iso_fortran_env, only : sp => real32 ! Import our library for interfacing with PyTorch - use ftorch, only : torch_model, torch_tensor, torch_kCUDA, torch_kCPU, & + use ftorch, only : torch_model, torch_tensor, & + torch_kCPU, torch_kCUDA, torch_kXPU, torch_kMPS, & torch_tensor_from_array, torch_model_load, torch_model_forward, & torch_delete + ! Import our tools module for testing utils + use ftorch_test_utils, only : assert_allclose + implicit none ! Set precision for reals @@ -19,6 +23,7 @@ program inference ! Set up Fortran data structures real(wp), dimension(5), target :: in_data real(wp), dimension(5), target :: out_data + real(wp), dimension(5) :: expected integer, parameter :: tensor_layout(1) = [1] ! Set up Torch data structures @@ -27,15 +32,30 @@ program inference type(torch_tensor), dimension(1) :: out_tensors ! Variables for multi-GPU setup - integer, parameter :: num_devices = 2 - integer :: device_index, i + integer :: num_devices = 2 + integer :: device_type, device_index, i + + ! Flag for testing + logical :: test_pass - ! Get TorchScript model file as a command line argument + ! Get device type as first command line argument and TorchScript model file as second command + ! 
line argument num_args = command_argument_count() allocate(args(num_args)) do ix = 1, num_args call get_command_argument(ix,args(ix)) end do + if (trim(args(1)) == "cuda") then + device_type = torch_kCUDA + else if (trim(args(1)) == "xpu") then + device_type = torch_kXPU + else if (trim(args(1)) == "mps") then + device_type = torch_kMPS + num_devices = 1 + else + write (*,*) "Error :: invalid device type", trim(args(1)) + stop 999 + end if do device_index = 0, num_devices-1 @@ -46,8 +66,8 @@ program inference ! Create Torch input tensor from the above array and assign it to the first (and only) ! element in the array of input tensors. - ! We use the torch_kCUDA device type with the given device index - call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, torch_kCUDA, & + ! We use the specified GPU device type with the given device index + call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, device_type, & device_index=device_index) ! Create Torch output tensor from the above array. @@ -57,7 +77,7 @@ program inference ! Load ML model. Ensure that the same device type and device index are used ! as for the input data. - call torch_model_load(model, args(1), torch_kCUDA, device_index=device_index) + call torch_model_load(model, args(2), device_type, device_index=device_index) ! Infer call torch_model_forward(model, in_tensors, out_tensors) @@ -66,11 +86,19 @@ program inference write (6, 200) device_index, out_data(:) 200 format("output on device ", i1,": [", 4(f5.1,","), f5.1,"]") + ! Check output tensor matches expected value + expected = [(2 * (device_index + i), i = 0, 4)] + test_pass = assert_allclose(out_data, expected, test_name="MultiGPU") + ! Cleanup call torch_delete(model) call torch_delete(in_tensors) call torch_delete(out_tensors) + if (.not. test_pass) then + stop 999 + end if + end do write (*,*) "MultiGPU example ran successfully" diff --git a/examples/3_MultiGPU/multigpu_infer_python.py b/examples/3_MultiGPU/multigpu_infer_python.py index e6075874..cc4c1c8e 100644 --- a/examples/3_MultiGPU/multigpu_infer_python.py +++ b/examples/3_MultiGPU/multigpu_infer_python.py @@ -1,5 +1,7 @@ """Load saved SimpleNet to TorchScript and run inference example.""" +import os + import torch @@ -28,39 +30,76 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor: # Load saved TorchScript model model = torch.jit.load(saved_model) # Inference - output = model.forward(input_tensor) - - elif device.startswith("cuda"): - # Add the device index to each tensor to make them differ - input_tensor += int(device.split(":")[-1] or 0) - - # All previously saved modules, no matter their device, are first - # loaded onto CPU, and then are moved to the devices they were saved - # from, so we don't need to manually transfer the model to the GPU - model = torch.jit.load(saved_model) - model = model.to(device) - input_tensor_gpu = input_tensor.to(torch.device(device)) - output_gpu = model.forward(input_tensor_gpu) - output = output_gpu.to(torch.device("cpu")) - + return model.forward(input_tensor) + + if device.startswith("cuda"): + pass + elif device.startswith("xpu"): + # XPU devices need to be initialised before use + torch.xpu.init() + elif device.startswith("mps"): + pass else: device_error = f"Device '{device}' not recognised." 
raise ValueError(device_error) + # Add the device index to each tensor to make them differ + input_tensor += int(device.split(":")[-1] or 0) + + # All previously saved modules, no matter their device, are first + # loaded onto CPU, and then are moved to the devices they were saved + # from. + # Since we are loading one saved model to multiple devices we explicitly + # transfer using `.to(device)` to ensure the model is on the correct index. + model = torch.jit.load(saved_model) + model = model.to(device) + input_tensor_gpu = input_tensor.to(torch.device(device)) + output_gpu = model.forward(input_tensor_gpu) + output = output_gpu.to(torch.device("cpu")) + return output if __name__ == "__main__": - saved_model_file = "saved_multigpu_model_cuda.pt" - - num_devices = 2 + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--filepath", + help="Path to the file containing the PyTorch model", + type=str, + default=os.path.dirname(__file__), + ) + parser.add_argument( + "--device_type", + help="Device type to run the inference on", + type=str, + choices=["cpu", "cuda", "xpu", "mps"], + default="cuda", + ) + parsed_args = parser.parse_args() + filepath = parsed_args.filepath + device_type = parsed_args.device_type + saved_model_file = os.path.join(filepath, f"saved_multigpu_model_{device_type}.pt") + + batch_size_to_run = 1 + + # Use 2 devices unless MPS for which there is only one + num_devices = 1 if device_type == "mps" else 2 for device_index in range(num_devices): - device_to_run = f"cuda:{device_index}" - - batch_size_to_run = 1 + device_to_run = f"{device_type}:{device_index}" with torch.no_grad(): result = deploy(saved_model_file, device_to_run, batch_size_to_run) print(f"Output on device {device_to_run}: {result}") + + expected = torch.Tensor([2 * (i + device_index) for i in range(5)]) + if not torch.allclose(result, expected): + result_error = ( + f"result:\n{result}\ndoes not match expected value:\n{expected}" + ) + raise ValueError(result_error) diff --git a/examples/3_MultiGPU/pt2ts.py b/examples/3_MultiGPU/pt2ts.py index 8242f691..823e69dd 100644 --- a/examples/3_MultiGPU/pt2ts.py +++ b/examples/3_MultiGPU/pt2ts.py @@ -1,7 +1,6 @@ """Load a PyTorch model and convert it to TorchScript.""" import os -import sys from typing import Optional # FPTLIB-TODO @@ -72,6 +71,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--filepath", + help="Path to the file containing the PyTorch model", + type=str, + default=os.path.dirname(__file__), + ) + parser.add_argument( + "--device_type", + help="Device type to run the inference on", + type=str, + choices=["cpu", "cuda", "xpu", "mps"], + default="cuda", + ) + parsed_args = parser.parse_args() + filepath = parsed_args.filepath + device_type = parsed_args.device_type + # ===================================================== # Load model and prepare for saving # ===================================================== @@ -97,12 +118,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # This example assumes one input of size (5) trained_model_dummy_input = torch.ones(5) - # FPTLIB-TODO - # Uncomment the following lines to save for inference on GPU (rather than CPU): - device = torch.device("cuda") - trained_model = 
trained_model.to(device) - trained_model.eval() - trained_model_dummy_input = trained_model_dummy_input.to(device) + # Transfer the model and inputs to GPU device, if appropriate + if device_type != "cpu": + device = torch.device(device_type) + trained_model = trained_model.to(device) + trained_model.eval() + trained_model_dummy_input = trained_model_dummy_input.to(device) # FPTLIB-TODO # Run model for dummy inputs @@ -117,7 +138,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # FPTLIB-TODO # Set the name of the file you want to save the torchscript model to: - saved_ts_filename = "saved_multigpu_model_cuda.pt" + saved_ts_filename = f"saved_multigpu_model_{device_type}.pt" # A filepath may also be provided. To do this, pass the filepath as an argument to # this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`. @@ -141,11 +162,8 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # Check model saved OK # ===================================================== - # Load torchscript and run model as a test - # FPTLIB-TODO - # Scale inputs as above and, if required, move inputs and mode to GPU + # Load torchscript and run model as a test, scaling inputs as above trained_model_dummy_input = 2.0 * trained_model_dummy_input - trained_model_dummy_input = trained_model_dummy_input.to("cuda") trained_model_testing_outputs = trained_model( trained_model_dummy_input, ) @@ -170,7 +188,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod raise RuntimeError(model_error) # Check that the model file is created - filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1] if not os.path.exists(os.path.join(filepath, saved_ts_filename)): torchscript_file_error = ( f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} " diff --git a/examples/3_MultiGPU/simplenet.py b/examples/3_MultiGPU/simplenet.py index b7c2e6b2..46f46550 100644 --- a/examples/3_MultiGPU/simplenet.py +++ b/examples/3_MultiGPU/simplenet.py @@ -42,13 +42,38 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": - model = SimpleNet().to(torch.device("cuda")) + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--device_type", + help="Device type to run the inference on", + type=str, + choices=["cpu", "cuda", "xpu", "mps"], + default="cuda", + ) + parsed_args = parser.parse_args() + device_type = parsed_args.device_type + + model = SimpleNet().to(torch.device(device_type)) model.eval() input_tensor = torch.Tensor([0.0, 1.0, 2.0, 3.0, 4.0]) - input_tensor_gpu = input_tensor.to(torch.device("cuda")) + input_tensor_gpu = input_tensor.to(torch.device(device_type)) - print(f"SimpleNet forward pass on CUDA device {input_tensor_gpu.get_device()}") + print( + f"SimpleNet forward pass on {device_type.capitalize()} device" + f" {input_tensor_gpu.get_device()}" + ) with torch.no_grad(): - output = model(input_tensor_gpu) - print(output) + output_tensor = model(input_tensor_gpu).to("cpu") + + print(output_tensor) + if not torch.allclose(output_tensor, 2 * input_tensor): + result_error = ( + f"result:\n{output_tensor}\ndoes not match expected value:\n" + f"{2 * input_tensor}" + ) + raise ValueError(result_error) diff --git a/examples/4_MultiIO/CMakeLists.txt b/examples/4_MultiIO/CMakeLists.txt index d8bef873..03b96dff 100644 --- a/examples/4_MultiIO/CMakeLists.txt +++ 
b/examples/4_MultiIO/CMakeLists.txt @@ -33,9 +33,8 @@ if(CMAKE_BUILD_TESTS) # pt2ts.py script add_test( NAME pt2ts - COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py - ${PROJECT_BINARY_DIR} # Command line argument: filepath for saving - # the model + COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/pt2ts.py --filepath + ${PROJECT_BINARY_DIR} WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) # 3. Check the model can be loaded from file and run in Python and that its @@ -44,7 +43,7 @@ if(CMAKE_BUILD_TESTS) NAME multiionet_infer_python COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/multiionet_infer_python.py - ${PROJECT_BINARY_DIR} # Command line argument: filepath to find the model + --filepath ${PROJECT_BINARY_DIR} WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) # 4. Check the model can be loaded from file and run in Fortran and that its diff --git a/examples/4_MultiIO/multiionet_infer_python.py b/examples/4_MultiIO/multiionet_infer_python.py index d5c9cf2d..a615289c 100644 --- a/examples/4_MultiIO/multiionet_infer_python.py +++ b/examples/4_MultiIO/multiionet_infer_python.py @@ -1,7 +1,6 @@ """Load saved MultIONet to TorchScript and run inference example.""" import os -import sys import torch @@ -52,7 +51,19 @@ def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor: if __name__ == "__main__": - filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1] + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--filepath", + help="Path to the file containing the PyTorch model", + type=str, + default=os.path.dirname(__file__), + ) + parsed_args = parser.parse_args() + filepath = parsed_args.filepath saved_model_file = os.path.join(filepath, "saved_multiio_model_cpu.pt") device_to_run = "cpu" diff --git a/examples/4_MultiIO/pt2ts.py b/examples/4_MultiIO/pt2ts.py index e7cd67e6..31066406 100644 --- a/examples/4_MultiIO/pt2ts.py +++ b/examples/4_MultiIO/pt2ts.py @@ -1,7 +1,6 @@ """Load a PyTorch model and convert it to TorchScript.""" import os -import sys from typing import Optional # FPTLIB-TODO @@ -72,6 +71,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--device_type", + help="Device type to run the inference on", + type=str, + choices=["cpu", "cuda", "xpu", "mps"], + default="cpu", + ) + parser.add_argument( + "--filepath", + help="Path to the file containing the PyTorch model", + type=str, + default=os.path.dirname(__file__), + ) + parsed_args = parser.parse_args() + device_type = parsed_args.device_type + filepath = parsed_args.filepath + # ===================================================== # Load model and prepare for saving # ===================================================== @@ -97,12 +118,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # This example assumes one input of size (5) trained_model_dummy_inputs = (torch.ones(4), torch.ones(4)) - # FPTLIB-TODO - # Uncomment the following lines to save for inference on GPU (rather than CPU): - # device = torch.device('cuda') - # trained_model = trained_model.to(device) - # trained_model.eval() - # trained_model_dummy_input = trained_model_dummy_input.to(device) + # Transfer the model and inputs to GPU device, if appropriate + if device_type != "cpu": + device = 
torch.device(device_type) + trained_model = trained_model.to(device) + trained_model.eval() + trained_model_dummy_inputs = tuple(t.to(device) for t in trained_model_dummy_inputs) # FPTLIB-TODO # Run model for dummy inputs @@ -117,7 +138,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # FPTLIB-TODO # Set the name of the file you want to save the torchscript model to: - saved_ts_filename = "saved_multiio_model_cpu.pt" + saved_ts_filename = f"saved_multiio_model_{device_type}.pt" # A filepath may also be provided. To do this, pass the filepath as an argument to # this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`. @@ -141,9 +162,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # Check model saved OK # ===================================================== - # Load torchscript and run model as a test - # FPTLIB-TODO - # Scale inputs as above and, if required, move inputs and mode to GPU + # Load torchscript and run model as a test, scaling inputs as above trained_model_dummy_inputs = ( 2.0 * trained_model_dummy_inputs[0], 3.0 * trained_model_dummy_inputs[1], ) @@ -172,7 +191,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod raise RuntimeError(model_error) # Check that the model file is created - filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1] if not os.path.exists(os.path.join(filepath, saved_ts_filename)): torchscript_file_error = ( f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} " diff --git a/examples/5_Looping/pt2ts.py b/examples/5_Looping/pt2ts.py index 114bb325..962903fd 100644 --- a/examples/5_Looping/pt2ts.py +++ b/examples/5_Looping/pt2ts.py @@ -1,7 +1,6 @@ """Load a PyTorch model and convert it to TorchScript.""" import os -import sys from typing import Optional # FPTLIB-TODO @@ -68,6 +67,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--device_type", + help="Device type to run the inference on", + type=str, + choices=["cpu", "cuda", "xpu", "mps"], + default="cpu", + ) + parser.add_argument( + "--filepath", + help="Path to the file containing the PyTorch model", + type=str, + default=os.path.dirname(__file__), + ) + parsed_args = parser.parse_args() + device_type = parsed_args.device_type + filepath = parsed_args.filepath + # ===================================================== # Load model and prepare for saving # ===================================================== @@ -93,12 +114,12 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # This example assumes two inputs of size (512x40) and (512x1) trained_model_dummy_input = torch.ones((5), dtype=torch.float32) - # FPTLIB-TODO - # Uncomment the following lines to save for inference on GPU (rather than CPU): - # device = torch.device('cuda') - # trained_model = trained_model.to(device) - # trained_model.eval() - # trained_model_dummy_input = trained_model_dummy_input.to(device) + # Transfer the model and inputs to GPU device, if appropriate + if device_type != "cpu": + device = torch.device(device_type) + trained_model = trained_model.to(device) + trained_model.eval() + trained_model_dummy_input = trained_model_dummy_input.to(device) # FPTLIB-TODO # Run model for dummy inputs @@ -113,7 +134,7 @@ def load_torchscript(filename: 
Optional[str] = "saved_model.pt") -> torch.nn.Mod # FPTLIB-TODO # Set the name of the file you want to save the torchscript model to: - saved_ts_filename = "saved_simplenet_model.pt" + saved_ts_filename = f"saved_simplenet_{device_type}.pt" # A filepath may also be provided. To do this, pass the filepath as an argument to # this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`. @@ -135,9 +156,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod # Check model saved OK # ===================================================== - # Load torchscript and run model as a test - # FPTLIB-TODO - # Scale inputs as above and, if required, move inputs and mode to GPU + # Load torchscript and run model as a test, scaling inputs as above trained_model_dummy_input = 2.0 * trained_model_dummy_input trained_model_testing_outputs = trained_model( trained_model_dummy_input, @@ -163,7 +182,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod raise RuntimeError(model_error) # Check that the model file is created - filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1] if not os.path.exists(os.path.join(filepath, saved_ts_filename)): torchscript_file_error = ( f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} " diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index c69812f8..b651a594 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,7 +1,7 @@ if(CMAKE_BUILD_TESTS) add_subdirectory(1_SimpleNet) add_subdirectory(2_ResNet18) - if(ENABLE_CUDA) + if(NOT "${GPU_DEVICE}" STREQUAL "NONE") add_subdirectory(3_MultiGPU) endif() add_subdirectory(4_MultiIO) diff --git a/pages/cmake.md b/pages/cmake.md index 90dc11c6..5a4ee565 100644 --- a/pages/cmake.md +++ b/pages/cmake.md @@ -62,24 +62,22 @@ The following CMake flags are available and can be passed as arguments through ` | [`CMAKE_INSTALL_PREFIX`](https://cmake.org/cmake/help/latest/variable/CMAKE_INSTALL_PREFIX.html) | `` | Location at which the library files should be installed. By default this is `/usr/local` | | [`CMAKE_BUILD_TYPE`](https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html) | `Release` / `Debug` | Specifies build type. The default is `Debug`, use `Release` for production code| | `CMAKE_BUILD_TESTS` | `TRUE` / `FALSE` | Specifies whether to compile FTorch's [test suite](testing.html) as part of the build. | -| `ENABLE_CUDA` | `TRUE` / `FALSE` | Specifies whether to check for and enable CUDA3 | +| `GPU_DEVICE` | `NONE` / `CUDA` / `XPU` / `MPS` | Specifies the target GPU architecture (if any) 3 | > 1 _On Windows this may need to be the full path to the compiler if CMake > cannot locate it by default._ -> +> > 2 _The path to the Torch installation needs to allow CMake to locate the relevant Torch CMake files. > If Torch has been [installed as LibTorch](https://pytorch.org/cppdocs/installing.html) > then this should be the absolute path to the unzipped LibTorch distribution. > If Torch has been installed as PyTorch in a Python [venv (virtual environment)](https://docs.python.org/3/library/venv.html), > e.g. with `pip install torch`, then this should be `lib/python<3.xx>/site-packages/torch/`._ -> -> 3 _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU or GPU only version is installed, e.g. 
-> `pip install torch --index-url https://download.pytorch.org/whl/cpu` -> or -> `pip install torch --index-url https://download.pytorch.org/whl/cu118` -> (for CUDA 11.8). URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._ +> +> 3 _This is often overridden by PyTorch. When installing with pip, the `index-url` flag can be used to ensure a CPU-only or GPU-enabled version is installed, e.g. +> `pip install torch --index-url https://download.pytorch.org/whl/cpu`. +> URLs for alternative versions can be found [here](https://pytorch.org/get-started/locally/)._ For example, to build on a unix system using the gnu compilers and install to `$HOME/FTorchbin/` we would need to run: diff --git a/pages/developer.md b/pages/developer.md index 13a77c30..b6b64bfd 100644 --- a/pages/developer.md +++ b/pages/developer.md @@ -85,6 +85,37 @@ files `src/ctorch.h` and `src/ctorch.cpp`, we refer to the Torch [C++ API documentation](https://pytorch.org/cppdocs/api/library_root.html) pages on the PyTorch website for details. +### GPU device handling + +GPU device-specific code is handled in FTorch using codes defined in the root +`CMakeLists.txt` file: +```cmake +set(GPU_DEVICE_NONE 0) +set(GPU_DEVICE_CUDA 1) +set(GPU_DEVICE_XPU 12) +set(GPU_DEVICE_MPS 13) +``` +These device codes are chosen to be consistent with the numbering used in +PyTorch (see +https://github.com/pytorch/pytorch/blob/main/c10/core/DeviceType.h). When a user +specifies `-DGPU_DEVICE=XPU` (for example) in the FTorch CMake build, this is +mapped to the appropriate device code (in this case 12). The chosen device code +and all other ones defined are passed to the C++ compiler in the following step: +```cmake +target_compile_definitions( + ${LIB_NAME} + PRIVATE ${COMPILE_DEFS} GPU_DEVICE=${GPU_DEVICE_CODE} + GPU_DEVICE_NONE=${GPU_DEVICE_NONE} GPU_DEVICE_CUDA=${GPU_DEVICE_CUDA} + GPU_DEVICE_XPU=${GPU_DEVICE_XPU} GPU_DEVICE_MPS=${GPU_DEVICE_MPS}) +``` +The chosen device code will enable the appropriate C pre-processor conditions in +the C++ source so that that the code relevant to that device type becomes +active. + +An example illustrating why this approach was taken is that if we removed the +device codes and pre-processor conditions and tried to build with a CPU-only or +CUDA LibTorch installation then compile errors would arise from the use of the +`torch::xpu` module in `src/ctorch.cpp`. ### git hook diff --git a/pages/gpu.md b/pages/gpu.md index ae84bcfe..7266f4c5 100644 --- a/pages/gpu.md +++ b/pages/gpu.md @@ -4,27 +4,41 @@ title: GPU Support ## GPU Support -In order to run a model on GPU, two main changes are required: +In order to run a model on GPU, three main changes are required: -1) When saving your TorchScript model, ensure that it is on the GPU. +1) When building FTorch, specify the target GPU architecture using the +`GPU_DEVICE` argument. That is, set +```sh +cmake .. -DGPU_DEVICE= +``` +as appropriate. The default setting is equivalent to +```sh +cmake .. -DGPU_DEVICE=NONE +``` +i.e., CPU-only. + +2) When saving your TorchScript model, ensure that it is on the GPU. For example, when using [`pt2ts.py`](https://github.com/Cambridge-ICCS/FTorch/blob/main/utils/pt2ts.py), -this can be done by uncommenting the following lines: - +this can be done by passing the `--device_type ` argument. 
This +sets the `device_type` variable, which has the effect of transferring the model +and any input arrays to the specified GPU device in the following lines: ```python -device_type = torch.device("cuda") -trained_model = trained_model.to(device_type) -trained_model.eval() -trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device_type) -trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device_type) +if device_type != "cpu": + trained_model = trained_model.to(device_type) + trained_model.eval() + trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device_type) + trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device_type) ``` -> Note: _This code also moves the dummy input tensors to the GPU. -> Whilst not necessary for saving the model, but the tensors must also be on the GPU -> to test that the models runs._ +> Note: _This code moves the dummy input tensors to the GPU, as well as the +> model. +> Whilst this is not necessary for saving the model, the tensors must be on +> the same GPU device to test that the model runs._ -2) When calling `torch_tensor_from_array` in Fortran, the device type for the input - tensor(s) should be set to `torch_kCUDA`, rather than `torch_kCPU`. +3) When calling `torch_tensor_from_array` in Fortran, the device type for the + input tensor(s) should be set to the relevant device type (`torch_kCUDA`, + `torch_kXPU`, or `torch_kMPS`) rather than `torch_kCPU`. This ensures that the inputs are on the same device type as the model. > Note: _You do **not** need to change the device type for the output tensors as we @@ -32,25 +46,28 @@ trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device_type) ### Multi-GPU runs -In the case of having multiple GPU devices, as well as setting `torch_kCUDA` as the -device type for any input tensors and models, you should also specify their device index -as the GPU device to be targeted. This argument is optional and will default to device -index 0 if unset. - -For example, the following code snippet sets up a Torch tensor with GPU device index 2: +In the case of having multiple GPU devices, as well as setting the device type +for any input tensors and models, you should also specify their device index +as the GPU device to be targeted. This argument is optional and will default to +device index 0 if unset. +For example, the following code snippet sets up a Torch tensor with CUDA GPU +device index 2: ```fortran device_index = 2 call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, & torch_kCUDA, device_index=device_index) ``` - -Whereas the following code snippet sets up a Torch tensor with (default) device index 0: - +Whereas the following code snippet sets up a Torch tensor with (default) CUDA +device index 0: ```fortran call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, & torch_kCUDA) ``` +Similarly for the XPU device type. + +> Note: The MPS device type does not currently support multiple devices, so the +> default device index should always be used. 
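For comparison, the Python half of the MultiGPU example targets device indices in the same way. The following condensed sketch is illustrative only; it assumes a CUDA build of PyTorch, two visible devices, and the `saved_multigpu_model_cuda.pt` file produced by `pt2ts.py` (the full version is `examples/3_MultiGPU/multigpu_infer_python.py`). It loops over the device indices, moving both the model and the input to each device in turn:

```python
import torch

# Load the TorchScript model saved by pt2ts.py (filename assumed for illustration)
model = torch.jit.load("saved_multigpu_model_cuda.pt")

for device_index in range(2):
    device = torch.device(f"cuda:{device_index}")
    # Offset the input by the device index so the outputs on each device differ
    input_tensor = torch.Tensor([0.0, 1.0, 2.0, 3.0, 4.0]) + device_index
    with torch.no_grad():
        output = model.to(device).forward(input_tensor.to(device))
    print(f"Output on device {device_index}: {output.to('cpu')}")
```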
See the [MultiGPU example](https://github.com/Cambridge-ICCS/FTorch/tree/main/examples/3_MultiGPU) diff --git a/src/ctorch.cpp b/src/ctorch.cpp index b5d5fd5f..cdc330dd 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -8,6 +8,10 @@ #include "ctorch.h" +#ifndef GPU_DEVICE +#define GPU_DEVICE GPU_DEVICE_NONE +#endif + // ============================================================================= // --- Constant expressions // ============================================================================= @@ -78,6 +82,7 @@ const auto get_libtorch_device(torch_device_t device_type, int device_index) { std::cerr << "[WARNING]: device index unused for CPU-only runs" << std::endl; } return torch::Device(torch::kCPU); +#if GPU_DEVICE == GPU_DEVICE_CUDA case torch_kCUDA: if (device_index == -1) { std::cerr << "[WARNING]: device index unset, defaulting to 0" << std::endl; @@ -90,6 +95,26 @@ const auto get_libtorch_device(torch_device_t device_type, int device_index) { << " for device count " << torch::cuda::device_count() << std::endl; exit(EXIT_FAILURE); } +#endif + case torch_kMPS: + if (device_index != -1 && device_index != 0) { + std::cerr << "[WARNING]: Only one device is available for MPS runs" << std::endl; + } + return torch::Device(torch::kMPS); +#if GPU_DEVICE == GPU_DEVICE_XPU + case torch_kXPU: + if (device_index == -1) { + std::cerr << "[WARNING]: device index unset, defaulting to 0" << std::endl; + device_index = 0; + } + if (device_index >= 0 && device_index < torch::xpu::device_count()) { + return torch::Device(torch::kXPU, device_index); + } else { + std::cerr << "[ERROR]: invalid device index " << device_index + << " for XPU device count " << torch::xpu::device_count() << std::endl; + exit(EXIT_FAILURE); + } +#endif default: std::cerr << "[WARNING]: unknown device type, setting to torch_kCPU" << std::endl; return torch::Device(torch::kCPU); @@ -103,6 +128,10 @@ const torch_device_t get_ftorch_device(torch::DeviceType device_type) { return torch_kCPU; case torch::kCUDA: return torch_kCUDA; + case torch::kXPU: + return torch_kXPU; + case torch::kMPS: + return torch_kMPS; default: std::cerr << "[ERROR]: device type " << device_type << " not implemented in FTorch" << std::endl; diff --git a/src/ctorch.h b/src/ctorch.h index 1123d6eb..8b029247 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -38,7 +38,13 @@ typedef enum { } torch_data_t; // Device types -typedef enum { torch_kCPU, torch_kCUDA } torch_device_t; +// NOTE: Defined in main CMakeLists and passed via preprocessor +typedef enum { + torch_kCPU = GPU_DEVICE_NONE, + torch_kCUDA = GPU_DEVICE_CUDA, + torch_kXPU = GPU_DEVICE_XPU, + torch_kMPS = GPU_DEVICE_MPS, +} torch_device_t; // ============================================================================= // --- Functions for constructing tensors diff --git a/src/ftorch.F90 b/src/ftorch.F90 index c3612a13..ba4939b1 100644 --- a/src/ftorch.F90 +++ b/src/ftorch.F90 @@ -52,9 +52,12 @@ module ftorch !| Enumerator for Torch devices ! From c_torch.h (torch_device_t) + ! NOTE: Defined in main CMakeLists and passed via preprocessor enum, bind(c) - enumerator :: torch_kCPU = 0 - enumerator :: torch_kCUDA = 1 + enumerator :: torch_kCPU = GPU_DEVICE_NONE + enumerator :: torch_kCUDA = GPU_DEVICE_CUDA + enumerator :: torch_kXPU = GPU_DEVICE_XPU + enumerator :: torch_kMPS = GPU_DEVICE_MPS end enum ! 
diff --git a/src/ftorch.fypp b/src/ftorch.fypp
index 0f57e0d8..ad99ba6e 100644
--- a/src/ftorch.fypp
+++ b/src/ftorch.fypp
@@ -71,9 +71,12 @@ module ftorch

   !| Enumerator for Torch devices
   ! From c_torch.h (torch_device_t)
+  ! NOTE: Defined in main CMakeLists and passed via preprocessor
   enum, bind(c)
-    enumerator :: torch_kCPU = 0
-    enumerator :: torch_kCUDA = 1
+    enumerator :: torch_kCPU = GPU_DEVICE_NONE
+    enumerator :: torch_kCUDA = GPU_DEVICE_CUDA
+    enumerator :: torch_kXPU = GPU_DEVICE_XPU
+    enumerator :: torch_kMPS = GPU_DEVICE_MPS
   end enum

   ! ============================================================================
diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt
index fc5705b0..fab05fd0 100644
--- a/test/unit/CMakeLists.txt
+++ b/test/unit/CMakeLists.txt
@@ -15,7 +15,7 @@ add_pfunit_ctest(test_tensor_interrogation
 add_pfunit_ctest(test_operator_overloads TEST_SOURCES test_tensor_operator_overloads.pf
                  LINK_LIBRARIES FTorch::ftorch)

-if(ENABLE_CUDA)
+if("${GPU_DEVICE}" STREQUAL "CUDA")
   check_language(CUDA)
   if(CMAKE_CUDA_COMPILER)
     enable_language(CUDA)
diff --git a/utils/README.md b/utils/README.md
index fffc9551..540b0dd5 100644
--- a/utils/README.md
+++ b/utils/README.md
@@ -4,13 +4,15 @@ This directory contains useful utilities for users of the library.

 ## `pt2ts.py`
-This is a python script that can take a PyTorch model and convert it to Torchscript.
+This is a Python script that can take a PyTorch model and convert it to
+TorchScript.
 It provides the user with the option to
 [jit.script](https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script) or
 [jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace)
 for both CPU and GPU.
 Dependencies:
 - PyTorch
 ### Usage
+
 1. Create and activate a virtual environment with PyTorch and any dependencies for your model.
 2. Place the `pt2ts.py` script in the same folder as your model files.
 3. Import your model into `pt2ts.py` and amend options as necessary (search for `FPTLIB-TODO`).
@@ -18,3 +20,14 @@ Dependencies:

 The Torchscript model will be saved locally in the same location from which
 the `pt2ts.py` script is being run.
+
+#### Command line arguments
+
+The `pt2ts.py` script is set up with the `argparse` Python module such that it
+accepts two command line arguments:
+* `--filepath`, which allows you to specify the path
+  to the directory in which the TorchScript model should be saved.
+* `--device_type`, which allows you to specify the CPU or GPU
+  device type with which to save the model. (To read and make use of a model on
+  a particular architecture, it's important that the model was targeted for that
+  architecture when it was saved.)
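Because the device type is recorded in the saved file name, and the model parameters themselves are serialised on that device, a quick sanity check after running, for example, `python pt2ts.py --device_type cuda` is to reload the TorchScript file and confirm where its parameters live. The following is a minimal sketch of such a check, assuming the two-input example model and the `saved_model_<device_type>.pt` naming used in `pt2ts.py`, and that the model returns a single tensor.

```python
import torch

device_type = "cuda"  # match the --device_type passed to pt2ts.py when saving
device = torch.device(device_type)

# Reload the TorchScript file written by pt2ts.py and check its parameter device
model = torch.jit.load(f"saved_model_{device_type}.pt", map_location=device)
print(next(model.parameters()).device)  # expect a device of the requested type

# Run the same dummy inputs used in pt2ts.py, placed on the same device
input_1 = torch.ones((512, 40), dtype=torch.float64, device=device)
input_2 = torch.ones((512, 1), dtype=torch.float64, device=device)
output = model(input_1, input_2)
print(output.shape)
```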
diff --git a/utils/pt2ts.py b/utils/pt2ts.py
index 8cea5fb3..dad047c5 100644
--- a/utils/pt2ts.py
+++ b/utils/pt2ts.py
@@ -1,7 +1,6 @@
 """Load a PyTorch model and convert it to TorchScript."""

 import os
-import sys
 from typing import Optional

 # FPTLIB-TODO
@@ -71,6 +70,28 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod


 if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--device_type",
+        help="Device type to run the inference on",
+        type=str,
+        choices=["cpu", "cuda", "xpu", "mps"],
+        default="cpu",
+    )
+    parser.add_argument(
+        "--filepath",
+        help="Path to the directory in which the TorchScript model should be saved",
+        type=str,
+        default=os.path.dirname(__file__),
+    )
+    parsed_args = parser.parse_args()
+    device_type = parsed_args.device_type
+    filepath = parsed_args.filepath
     # =====================================================
     # Load model and prepare for saving
     # =====================================================
@@ -97,13 +118,13 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
     trained_model_dummy_input_1 = torch.ones((512, 40), dtype=torch.float64)
     trained_model_dummy_input_2 = torch.ones((512, 1), dtype=torch.float64)

-    # FPTLIB-TODO
-    # Uncomment the following lines to save for inference on GPU (rather than CPU):
-    # device = torch.device('cuda')
-    # trained_model = trained_model.to(device)
-    # trained_model.eval()
-    # trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device)
-    # trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device)
+    # Transfer the model and inputs to the GPU device, if appropriate
+    if device_type != "cpu":
+        device = torch.device(device_type)
+        trained_model = trained_model.to(device)
+        trained_model.eval()
+        trained_model_dummy_input_1 = trained_model_dummy_input_1.to(device)
+        trained_model_dummy_input_2 = trained_model_dummy_input_2.to(device)

     # FPTLIB-TODO
     # Run model for dummy inputs
@@ -119,7 +140,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod

     # FPTLIB-TODO
     # Set the name of the file you want to save the torchscript model to:
-    saved_ts_filename = "saved_model.pt"
+    saved_ts_filename = f"saved_model_{device_type}.pt"
     # A filepath may also be provided. To do this, pass the filepath as an argument to
     # this script when it is run from the command line, i.e. `./pt2ts.py path/to/model`.

@@ -141,9 +162,7 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
     # Check model saved OK
     # =====================================================

-    # Load torchscript and run model as a test
-    # FPTLIB-TODO
-    # Scale inputs as above and, if required, move inputs and mode to GPU
+    # Load torchscript and run model as a test, scaling inputs as above
     trained_model_dummy_input_1 = 2.0 * trained_model_dummy_input_1
     trained_model_dummy_input_2 = 2.0 * trained_model_dummy_input_2
     trained_model_testing_outputs = trained_model(
@@ -172,7 +191,6 @@ def load_torchscript(filename: Optional[str] = "saved_model.pt") -> torch.nn.Mod
         raise RuntimeError(model_error)

     # Check that the model file is created
-    filepath = os.path.dirname(__file__) if len(sys.argv) == 1 else sys.argv[1]
     if not os.path.exists(os.path.join(filepath, saved_ts_filename)):
         torchscript_file_error = (
             f"Saved TorchScript file {os.path.join(filepath, saved_ts_filename)} "