program inference

  ! Import precision info from iso
  use, intrinsic :: iso_fortran_env, only : sp => real32, stdout => output_unit

  ! Import our library for interfacing with PyTorch
  use ftorch, only : torch_model, torch_tensor, torch_kCPU, torch_delete, &
                     torch_tensor_from_array, torch_model_load, torch_model_forward

  ! Import our tools module for testing utils
  use ftorch_test_utils, only : assert_allclose

  ! Import MPI
  use mpi, only : mpi_comm_rank, mpi_comm_size, mpi_comm_world, mpi_finalize, mpi_float, &
                  mpi_gather, mpi_init

  implicit none

  ! Set working precision for reals
  integer, parameter :: wp = sp

  integer :: num_args, ix
  character(len=128), dimension(:), allocatable :: args

  ! Set up Fortran data structures
  real(wp), dimension(5), target :: in_data
  real(wp), dimension(5), target :: out_data
  real(wp), dimension(5), target :: expected
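  ! Layout for the 1D tensors (Fortran dimension 1 maps onto tensor dimension 1)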
  integer, parameter :: tensor_layout(1) = [1]

  ! Set up Torch data structures
  ! The model, a vector of input tensors (here just one), and a vector of output tensors (also just one)
  type(torch_model) :: model
  type(torch_tensor), dimension(1) :: in_tensors
  type(torch_tensor), dimension(1) :: out_tensors

  ! Flag for testing
  logical :: test_pass

  ! MPI configuration
  integer :: rank, size, ierr, i

  ! Variables for testing
  real(wp), allocatable, dimension(:,:) :: recvbuf
  real(wp), dimension(5) :: result_chk
  integer :: rank_chk

  call mpi_init(ierr)
  call mpi_comm_rank(mpi_comm_world, rank, ierr)
  call mpi_comm_size(mpi_comm_world, size, ierr)

  ! Check MPI was configured correctly
  if (size == 1) then
    write(*,*) "MPI communicator size is 1, indicating that it is not configured correctly"
    write(*,*) "(assuming you specified more than one rank)"
    call clean_up()
    stop 999
  end if

  ! Get TorchScript model file as a command line argument
  num_args = command_argument_count()
  allocate(args(num_args))
  do ix = 1, num_args
    call get_command_argument(ix, args(ix))
  end do

  ! Initialise data and print the values used on each MPI rank
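  ! (rank r uses the values [r, r+1, r+2, r+3, r+4])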
  in_data = [(rank + i, i = 0, 4)]
  write(unit=stdout, fmt="('input on rank ',i1,': ')", advance="no") rank
  write(unit=stdout, fmt=100) in_data(:)
  100 format('[',4(f5.1,','),f5.1,']')

  ! Create Torch input/output tensors from the above arrays
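  ! (they wrap in_data/out_data directly, so the forward pass below is assumed to write
  !  its results straight into out_data)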
  call torch_tensor_from_array(in_tensors(1), in_data, tensor_layout, torch_kCPU)
  call torch_tensor_from_array(out_tensors(1), out_data, tensor_layout, torch_kCPU)

  ! Load ML model
  call torch_model_load(model, args(1), torch_kCPU)

  ! Run inference on each MPI rank
  call torch_model_forward(model, in_tensors, out_tensors)

  ! Print the values computed on each MPI rank
  write(unit=stdout, fmt="('output on rank ',i1,': ')", advance="no") rank
  write(unit=stdout, fmt=100) out_data(:)

  ! Gather the outputs onto rank 0
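  ! recvbuf(:,r+1) receives the five output values from rank r; mpi_float is assumed to
  ! match the 32-bit working precision (wp = real32) of out_data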
  allocate(recvbuf(5,size))
  call mpi_gather(out_data, 5, mpi_float, recvbuf, 5, mpi_float, 0, mpi_comm_world, ierr)

  ! Check that the correct values were attained
  if (rank == 0) then

    ! Check output tensor matches expected value
    do rank_chk = 0, size-1
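      ! The example network is assumed to double its input, so rank r should return 2*(r+i)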
      expected = [(2 * (rank_chk + i), i = 0, 4)]
      result_chk(:) = recvbuf(:,rank_chk+1)
      test_pass = assert_allclose(result_chk, expected, test_name="MPI")
      if (.not. test_pass) then
        write(unit=stdout, fmt="('rank ',i1,' result: ')") rank_chk
        write(unit=stdout, fmt=100) result_chk(:)
        write(unit=stdout, fmt="('does not match expected value')")
        write(unit=stdout, fmt=100) expected(:)
        call clean_up()
        stop 999
      end if
    end do

    write(*,*) "MPI Fortran example ran successfully"
  end if

  call clean_up()

contains

  subroutine clean_up()
    call torch_delete(model)
    call torch_delete(in_tensors)
    call torch_delete(out_tensors)
    call mpi_finalize(ierr)
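    ! clean_up can be called before recvbuf is allocated (e.g. from the size check above),
    ! so only deallocate it if it was actually allocated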
    if (allocated(recvbuf)) deallocate(recvbuf)
  end subroutine clean_up

end program inference