Implement torch_tensor_backward #286

Open
wants to merge 34 commits into base: main
Changes from all commits (34 commits)
0ffec4f
Enable requires_grad in autograd example
jwallwork23 Jan 20, 2025
2135968
Implement torch_tensor_backward
jwallwork23 Jan 20, 2025
24559fe
Simplify autograd example
jwallwork23 Jan 20, 2025
be86846
Setup requires_grad properly; Use TensorOptions in tensor constructors
jwallwork23 Jan 20, 2025
10d2499
Implement get_gradient
jwallwork23 Jan 20, 2025
b124da0
Finish autograd example
jwallwork23 Jan 20, 2025
586de97
Unit test for gradient of assignment
jwallwork23 Jan 20, 2025
9cc506f
Unit test for gradient of addition
jwallwork23 Jan 20, 2025
13d28e8
Unit test for gradient of subtraction
jwallwork23 Jan 21, 2025
13d8e40
Unit test for gradient of negative
jwallwork23 Jan 21, 2025
d8f40b2
Unit test for gradient of multiplication
jwallwork23 Jan 21, 2025
abb303e
Unit test for gradient of division
jwallwork23 Jan 21, 2025
0c9a59a
Unit test for gradient of square
jwallwork23 Jan 21, 2025
f68d647
Unit test for gradient of square root
jwallwork23 Jan 21, 2025
f100c97
Unit test for gradient of scalar multiplication - FIXME
jwallwork23 Jan 21, 2025
55d3849
Unit test for gradient of scalar division - FIXME
jwallwork23 Jan 21, 2025
37e6873
Rename get_gradient and provide method
jwallwork23 Feb 17, 2025
bd28159
Drop unnecessary c_loc use
jwallwork23 Feb 17, 2025
c815301
FIXME- backward needs intent(inout)
jwallwork23 Feb 17, 2025
7e350e0
Drop unused imports
jwallwork23 Feb 17, 2025
3bc60dc
Set up external gradient of ones for now
jwallwork23 Feb 17, 2025
7e3fe43
Scalar multiplication and division using rank-1 tensors
jwallwork23 Feb 18, 2025
d42f3ed
Fix static analysis
jwallwork23 Feb 18, 2025
2c5e8cb
Apply cmake-format
jwallwork23 Feb 18, 2025
ef3e22a
Rework scalar multiplication and division
jwallwork23 Feb 18, 2025
390c016
Update autograd example
jwallwork23 Feb 18, 2025
5835d3a
Add notes to autograd page on working with scalars
jwallwork23 Feb 18, 2025
7e31545
Merge branch 'rework-scalar-ops' into 152_tensor-backward
jwallwork23 Mar 10, 2025
c7e7f37
Update gradient approach to avoid memory leaks
jwallwork23 Mar 10, 2025
50c4b09
Update tests
jwallwork23 Mar 10, 2025
e3c890d
Merge branch 'main' into 152_tensor-backward
jwallwork23 Mar 12, 2025
0db20f2
Reinstate expected array in autograd example
jwallwork23 Mar 12, 2025
e2a1a0a
Better ordering of tensors in example
jwallwork23 Mar 12, 2025
883654a
Update autograd docs
jwallwork23 Mar 12, 2025
40 changes: 30 additions & 10 deletions examples/7_Autograd/autograd.f90
@@ -5,7 +5,8 @@ program example

! Import our library for interfacing with PyTorch's Autograd module
use ftorch, only: assignment(=), operator(+), operator(-), operator(*), operator(/), &
operator(**), torch_kCPU, torch_tensor, torch_tensor_from_array
operator(**), torch_kCPU, torch_tensor, torch_tensor_backward, &
torch_tensor_from_array

! Import our tools module for testing utils
use ftorch_test_utils, only : assert_allclose
@@ -15,25 +16,24 @@ program example
! Set working precision for reals
integer, parameter :: wp = sp

! Set up Fortran data structures
integer, parameter :: ndims = 1
integer, parameter :: n = 2
real(wp), dimension(n), target :: out_data
real(wp), dimension(n), target :: out_data1, out_data2, out_data3
real(wp), dimension(n) :: expected
integer :: tensor_layout(ndims) = [1]

! Flag for testing
logical :: test_pass

! Set up Torch data structures
type(torch_tensor) :: a, b, Q, multiplier, divisor
type(torch_tensor) :: a, b, Q, multiplier, divisor, dQda, dQdb

! Initialise Torch Tensors from input arrays as in Python example
call torch_tensor_from_array(a, [2.0_wp, 3.0_wp], tensor_layout, torch_kCPU, requires_grad=.true.)
call torch_tensor_from_array(b, [6.0_wp, 4.0_wp], tensor_layout, torch_kCPU, requires_grad=.true.)

! Initialise Torch Tensor from array used for output
call torch_tensor_from_array(Q, out_data, tensor_layout, torch_kCPU)
call torch_tensor_from_array(Q, out_data1, tensor_layout, torch_kCPU)

! Scalar multiplication and division are not currently implemented in FTorch. However, you can
! achieve the same thing by defining a rank-1 tensor with a single entry, as follows:
@@ -42,17 +42,37 @@

! Compute the same mathematical expression as in the Python example
Q = multiplier * (a**3 - b * b / divisor)
write (*,*) "Q = 3 * (a^3 - b*b/3) = 3*a^3 - b^2 = ", out_data(:)
write (*,*) "Q = 3 * (a^3 - b*b/3) = 3*a^3 - b^2 = ", out_data1(:)

! Check output tensor matches expected value
expected(:) = [-12.0_wp, 65.0_wp]
if (.not. assert_allclose(out_data, expected, test_name="autograd_Q")) then
print *, "Error :: value of Q does not match expected value"
if (.not. assert_allclose(out_data1, expected, test_name="autograd_Q")) then
write(*,*) "Error :: value of Q does not match expected value"
stop 999
end if

! Back-propagation
! TODO: Requires API extension
! Run the back-propagation operator
call torch_tensor_backward(Q)

! Create tensors based off output arrays for the gradients and then retrieve them
call torch_tensor_from_array(dQda, out_data2, tensor_layout, torch_kCPU)
call torch_tensor_from_array(dQdb, out_data3, tensor_layout, torch_kCPU)
dQda = a%grad()
dQdb = b%grad()

! Check the gradients take expected values
write(*,*) "dQda = 9*a^2 = ", out_data2
expected(:) = [36.0_wp, 81.0_wp]
if (.not. assert_allclose(out_data2, expected, test_name="autograd_dQda")) then
write(*,*) "Error :: value of dQda does not match expected value"
stop 999
end if
write(*,*) "dQdb = - 2*b = ", out_data3
expected(:) = [-12.0_wp, -8.0_wp]
if (.not. assert_allclose(out_data3, expected, test_name="autograd_dQdb")) then
write(*,*) "Error :: value of dQdb does not match expected value"
stop 999
end if

write (*,*) "Autograd example ran successfully"

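For reference, the hard-coded `expected` arrays in the example above follow directly from differentiating the expression being computed:

$$
Q = 3\left(a^3 - \frac{b \cdot b}{3}\right) = 3a^3 - b^2, \qquad
\frac{\partial Q}{\partial a} = 9a^2, \qquad
\frac{\partial Q}{\partial b} = -2b,
$$

so with $a = (2, 3)$ and $b = (6, 4)$ the checks expect $Q = (-12, 65)$, $\partial Q/\partial a = (36, 81)$ and $\partial Q/\partial b = (-12, -8)$.
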
41 changes: 40 additions & 1 deletion pages/autograd.md
@@ -45,6 +45,45 @@ Torch tensors, see the associated
For Tensors that you would like to differentiate with respect to, be sure to
set the `requires_grad` optional argument to `.true.` when you construct it.

### The `backward` operator
### Back-propagation

Having defined some tensors with the `requires_grad` property set to `.true.`
and computed another tensor in terms of an expression involving these, we can
compute gradients of that tensor with respect to those that it depends on. This
is achieved using the `torch_tensor_backward` subroutine. For example, for
input tensors `a` and `b` and an output tensor `Q`:

```fortran
call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU, &
requires_grad=.true.)
call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU, &
requires_grad=.true.)
call torch_tensor_from_array(Q, out_data1, tensor_layout, torch_kCPU)

Q = a * b

call torch_tensor_backward(Q)
```

In the example code above, we can extract the gradients of `Q` with respect to `a`
and/or `b`. To do this, we can use either the `torch_tensor_get_gradient` function
or its alias, the `grad` method of the `torch_tensor` class. That is, for tensors
`dQda` and `dQdb`:

```fortran
! Function approach
call torch_tensor_from_array(dQda, out_data2, tensor_layout, torch_kCPU)
dQda = torch_tensor_get_gradient(a)

! Method approach
call torch_tensor_from_array(dQdb, out_data3, tensor_layout, torch_kCPU)
dQdb = b%grad()
```
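
Because `dQda` and `dQdb` are constructed from the Fortran arrays `out_data2` and
`out_data3`, the gradient values can be read back through those arrays once the
assignment has been made. The following is a minimal, self-contained sketch that
combines the snippets above (array names such as `Q_data` are illustrative; the
expected gradients follow from `Q = a * b` with the ones seed gradient used by
`torch_tensor_backward`):

```fortran
program autograd_sketch
  ! Minimal sketch of the back-propagation workflow described above.
  ! Names such as Q_data/dQda_data are illustrative, not part of the API.
  use ftorch, only: assignment(=), operator(*), torch_kCPU, torch_tensor, &
                    torch_tensor_backward, torch_tensor_from_array
  use, intrinsic :: iso_fortran_env, only: sp => real32
  implicit none

  integer, parameter :: wp = sp
  integer :: tensor_layout(1) = [1]
  real(wp), dimension(2), target :: Q_data, dQda_data, dQdb_data
  type(torch_tensor) :: a, b, Q, dQda, dQdb

  ! Inputs we want to differentiate with respect to
  call torch_tensor_from_array(a, [2.0_wp, 3.0_wp], tensor_layout, torch_kCPU, &
                               requires_grad=.true.)
  call torch_tensor_from_array(b, [6.0_wp, 4.0_wp], tensor_layout, torch_kCPU, &
                               requires_grad=.true.)

  ! Output tensor backed by a Fortran array so its values can be inspected
  call torch_tensor_from_array(Q, Q_data, tensor_layout, torch_kCPU)
  Q = a * b

  ! Back-propagate, then pull the gradients into array-backed tensors
  call torch_tensor_backward(Q)
  call torch_tensor_from_array(dQda, dQda_data, tensor_layout, torch_kCPU)
  call torch_tensor_from_array(dQdb, dQdb_data, tensor_layout, torch_kCPU)
  dQda = a%grad()
  dQdb = b%grad()

  ! For Q = a * b with a ones seed, dQ/da = b and dQ/db = a elementwise
  write(*,*) "dQ/da = ", dQda_data  ! expect [6.0, 4.0]
  write(*,*) "dQ/db = ", dQdb_data  ! expect [2.0, 3.0]
end program autograd_sketch
```

The same pattern applies whichever of the two access routes is used.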

### Optimisation

*Not yet implemented.*

### Loss functions

*Not yet implemented.*
17 changes: 17 additions & 0 deletions src/ctorch.cpp
@@ -379,6 +379,23 @@ void torch_tensor_power_float(torch_tensor_t output, const torch_tensor_t tensor
*out = pow(*t, *exp);
}

// =============================================================================
// --- Functions related to automatic differentiation functionality for tensors
// =============================================================================

void torch_tensor_backward(const torch_tensor_t tensor,
const torch_tensor_t external_gradient) {
auto t = reinterpret_cast<torch::Tensor *>(tensor);
auto g = reinterpret_cast<torch::Tensor *const>(external_gradient);
t->backward(*g);
}

void torch_tensor_get_gradient(torch_tensor_t gradient, const torch_tensor_t tensor) {
auto g = reinterpret_cast<torch::Tensor *>(gradient);
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
*g = t->grad();
}

// =============================================================================
// --- Torch model API
// =============================================================================
21 changes: 21 additions & 0 deletions src/ctorch.h
@@ -247,6 +247,27 @@ EXPORT_C void torch_tensor_power_float(torch_tensor_t output,
const torch_tensor_t tensor,
const torch_float_t exponent);

// =============================================================================
// --- Functions related to automatic differentiation functionality for tensors
// =============================================================================

/**
* Function to perform back-propagation on a Torch Tensor.
* Note that the Tensor must have the requires_grad attribute set to true.
* @param tensor Tensor to perform back-propagation on
* @param external_gradient Tensor with an external gradient to supply for the back-propagation
*/
EXPORT_C void torch_tensor_backward(const torch_tensor_t tensor,
const torch_tensor_t external_gradient);

/**
* Function to return the grad attribute of a Torch Tensor.
* @param gradient Tensor in which to store the gradient
* @param tensor Tensor to get the gradient of
*/
EXPORT_C void torch_tensor_get_gradient(torch_tensor_t gradient,
const torch_tensor_t tensor);

// =============================================================================
// --- Torch model API
// =============================================================================
56 changes: 56 additions & 0 deletions src/ftorch.F90
@@ -36,6 +36,7 @@ module ftorch
procedure :: get_dtype => torch_tensor_get_dtype
procedure :: get_device_type => torch_tensor_get_device_type
procedure :: get_device_index => torch_tensor_get_device_index
procedure :: grad => torch_tensor_get_gradient
procedure :: requires_grad => torch_tensor_requires_grad
final :: torch_tensor_delete
end type torch_tensor
@@ -1815,6 +1816,61 @@ end subroutine torch_tensor_power_float_c
end function torch_tensor_power_real64


! ============================================================================
! --- Procedures related to automatic differentiation functionality for tensors
! ============================================================================

!> Performs back-propagation on a Torch Tensor, using an external gradient of ones.
subroutine torch_tensor_backward(tensor)
type(torch_tensor), intent(in) :: tensor
type(torch_tensor) :: external_gradient

interface
subroutine torch_tensor_backward_c(tensor_c, external_gradient_c) &
bind(c, name = 'torch_tensor_backward')
use, intrinsic :: iso_c_binding, only : c_ptr
implicit none
type(c_ptr), value, intent(in) :: tensor_c
type(c_ptr), value, intent(in) :: external_gradient_c
end subroutine torch_tensor_backward_c
end interface

! External gradient to provide to the back-propagation consisting of a tensor of ones
! TODO: Accept other external gradients as an optional argument
call torch_tensor_ones(external_gradient, tensor%get_rank(), tensor%get_shape(), &
tensor%get_dtype(), tensor%get_device_type(), &
device_index=tensor%get_device_index())

! Call back-propagation with the provided external gradient
call torch_tensor_backward_c(tensor%p, external_gradient%p)

! Delete the external gradient tensor
call torch_tensor_delete(external_gradient)
end subroutine torch_tensor_backward

!> Retrieves the gradient with respect to a Torch Tensor.
function torch_tensor_get_gradient(tensor) result(gradient)
class(torch_tensor), intent(in) :: tensor ! Tensor to compute the gradient with respect to
type(torch_tensor) :: gradient ! Tensor holding the gradient

interface
subroutine torch_tensor_get_gradient_c(gradient_c, tensor_c) &
bind(c, name = 'torch_tensor_get_gradient')
use, intrinsic :: iso_c_binding, only : c_ptr
implicit none
type(c_ptr), value, intent(in) :: gradient_c
type(c_ptr), value, intent(in) :: tensor_c
end subroutine torch_tensor_get_gradient_c
end interface

if (.not. c_associated(gradient%p)) then
call torch_tensor_empty(gradient, tensor%get_rank(), tensor%get_shape(), tensor%get_dtype(), &
tensor%get_device_type(), device_index=tensor%get_device_index(), &
requires_grad=tensor%requires_grad())
end if
call torch_tensor_get_gradient_c(gradient%p, tensor%p)
end function torch_tensor_get_gradient

! ============================================================================
! --- Torch Model API
! ============================================================================
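Note that the `torch_tensor_backward` wrapper above seeds the back-propagation with a
tensor of ones (see the TODO about accepting other seeds as an optional argument). In
autograd terms, back-propagation with a seed $v$ accumulates a vector-Jacobian product
into each leaf tensor's gradient,

$$
g_j \mathrel{+}= \sum_i v_i \, \frac{\partial Q_i}{\partial a_j},
$$

where $g$ denotes `a%grad()`, so with $v_i = 1$ and the elementwise expressions used in
the example each gradient entry is simply the elementwise derivative $\partial Q_j / \partial a_j$.
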
56 changes: 56 additions & 0 deletions src/ftorch.fypp
@@ -55,6 +55,7 @@ module ftorch
procedure :: get_dtype => torch_tensor_get_dtype
procedure :: get_device_type => torch_tensor_get_device_type
procedure :: get_device_index => torch_tensor_get_device_index
procedure :: grad => torch_tensor_get_gradient
procedure :: requires_grad => torch_tensor_requires_grad
final :: torch_tensor_delete
end type torch_tensor
@@ -836,6 +837,61 @@ contains

#:endfor

! ============================================================================
! --- Procedures related to automatic differentiation functionality for tensors
! ============================================================================

!> Performs back-propagation on a Torch Tensor, using an external gradient of ones.
subroutine torch_tensor_backward(tensor)
type(torch_tensor), intent(in) :: tensor
type(torch_tensor) :: external_gradient

interface
subroutine torch_tensor_backward_c(tensor_c, external_gradient_c) &
bind(c, name = 'torch_tensor_backward')
use, intrinsic :: iso_c_binding, only : c_ptr
implicit none
type(c_ptr), value, intent(in) :: tensor_c
type(c_ptr), value, intent(in) :: external_gradient_c
end subroutine torch_tensor_backward_c
end interface

! External gradient to provide to the back-propagation consisting of a tensor of ones
! TODO: Accept other external gradients as an optional argument
call torch_tensor_ones(external_gradient, tensor%get_rank(), tensor%get_shape(), &
tensor%get_dtype(), tensor%get_device_type(), &
device_index=tensor%get_device_index())

! Call back-propagation with the provided external gradient
call torch_tensor_backward_c(tensor%p, external_gradient%p)

! Delete the external gradient tensor
call torch_tensor_delete(external_gradient)
end subroutine torch_tensor_backward

!> Retrieves the gradient with respect to a Torch Tensor.
function torch_tensor_get_gradient(tensor) result(gradient)
class(torch_tensor), intent(in) :: tensor ! Tensor to compute the gradient with respect to
type(torch_tensor) :: gradient ! Tensor holding the gradient

interface
subroutine torch_tensor_get_gradient_c(gradient_c, tensor_c) &
bind(c, name = 'torch_tensor_get_gradient')
use, intrinsic :: iso_c_binding, only : c_ptr
implicit none
type(c_ptr), value, intent(in) :: gradient_c
type(c_ptr), value, intent(in) :: tensor_c
end subroutine torch_tensor_get_gradient_c
end interface

if (.not. c_associated(gradient%p)) then
call torch_tensor_empty(gradient, tensor%get_rank(), tensor%get_shape(), tensor%get_dtype(), &
tensor%get_device_type(), device_index=tensor%get_device_index(), &
requires_grad=tensor%requires_grad())
end if
call torch_tensor_get_gradient_c(gradient%p, tensor%p)
end function torch_tensor_get_gradient

! ============================================================================
! --- Torch Model API
! ============================================================================
21 changes: 14 additions & 7 deletions test/unit/CMakeLists.txt
@@ -1,7 +1,10 @@
cmake_minimum_required(VERSION 3.15...3.31)
cmake_policy (SET CMP0076 NEW)
cmake_policy(SET CMP0076 NEW)

project("FTorch unit tests" VERSION 1.0.0 LANGUAGES Fortran)
project(
"FTorch unit tests"
VERSION 1.0.0
LANGUAGES Fortran)

find_package(FTorch)
message(STATUS "Building with Fortran PyTorch coupling")
@@ -13,8 +16,12 @@ add_pfunit_ctest(
test_tensor_constructors_destructors.pf LINK_LIBRARIES FTorch::ftorch)
add_pfunit_ctest(test_tensor_interrogation
TEST_SOURCES test_tensor_interrogation.pf LINK_LIBRARIES FTorch::ftorch)
add_pfunit_ctest(test_operator_overloads
TEST_SOURCES test_tensor_operator_overloads.pf LINK_LIBRARIES FTorch::ftorch)
add_pfunit_ctest(
test_operator_overloads TEST_SOURCES test_tensor_operator_overloads.pf
LINK_LIBRARIES FTorch::ftorch)
add_pfunit_ctest(
test_operator_overloads_autograd TEST_SOURCES
test_tensor_operator_overloads_autograd.pf LINK_LIBRARIES FTorch::ftorch)

if("${GPU_DEVICE}" STREQUAL "CUDA")
check_language(CUDA)
@@ -23,7 +30,7 @@ if("${GPU_DEVICE}" STREQUAL "CUDA")
else()
message(ERROR "No CUDA support")
endif()
add_pfunit_ctest(test_tensor_interrogation_cuda
TEST_SOURCES test_tensor_interrogation_cuda.pf
LINK_LIBRARIES FTorch::ftorch)
add_pfunit_ctest(
test_tensor_interrogation_cuda TEST_SOURCES
test_tensor_interrogation_cuda.pf LINK_LIBRARIES FTorch::ftorch)
endif()
7 changes: 4 additions & 3 deletions test/unit/test_tensor_operator_overloads.pf
@@ -8,7 +8,7 @@ module test_tensor_operator_overloads
use funit
use ftorch, only: assignment(=), torch_kCPU, torch_kFloat32, torch_tensor, torch_tensor_from_array
use ftorch_test_utils, only: assert_allclose
use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t
use, intrinsic :: iso_c_binding, only : c_int64_t

implicit none

@@ -225,7 +225,6 @@ contains
torch_tensor, torch_tensor_from_array
use ftorch_test_utils, only: assert_allclose
use, intrinsic :: iso_fortran_env, only: sp => real32
use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t

! Set working precision for reals
integer, parameter :: wp = sp
@@ -454,6 +453,9 @@ contains
! Create an arbitrary input array
in_data(:,:) = reshape([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3])

! Create a single valued rank-1 tensor based off the scalar
call torch_tensor_from_array(divisor, [scalar], [1], device_type)

! Create a tensor based off the input array
call torch_tensor_from_array(tensor1, in_data, tensor_layout, device_type)

@@ -542,7 +544,6 @@ contains
subroutine test_torch_tensor_sqrt(this)
use ftorch, only: operator(**)
use, intrinsic :: iso_fortran_env, only: sp => real32
use, intrinsic :: iso_c_binding, only : c_associated, c_int64_t

! Set working precision for reals
integer, parameter :: wp = sp