Skip to content

Commit 6deab9a

Browse files
authored
Checks for inconsistent device types (#272)
* Output array needs to be created empty in autograd example * Do not interrogate tensors that haven't been constructed * Drop `torch_tensor_to_array` (#303) * Drop torch_tensor_to_<array/blob> * Remove torch_tensor_to_array from autograd example * Use std::move for assignment * Update constructors/destructors tests to avoid to_array * Update summary of autograd example * Update overloads tests to avoid to_array * Fix previous API change note
1 parent dcdadef commit 6deab9a

9 files changed

+546
-1441
lines changed

examples/6_Autograd/autograd.f90

+10-21
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ program example
55

66
! Import our library for interfacing with PyTorch's Autograd module
77
use ftorch, only: assignment(=), operator(+), operator(-), operator(*), &
8-
operator(/), operator(**), torch_kCPU, torch_tensor, torch_tensor_delete, &
9-
torch_tensor_from_array, torch_tensor_to_array
8+
operator(/), operator(**), torch_kCPU, torch_kFloat32, torch_tensor, torch_tensor_delete, &
9+
torch_tensor_empty, torch_tensor_from_array
1010

1111
! Import our tools module for testing utils
1212
use ftorch_test_utils, only : assert_allclose
@@ -17,60 +17,49 @@ program example
1717
integer, parameter :: wp = sp
1818

1919
! Set up Fortran data structures
20+
integer, parameter :: ndims = 2
2021
integer, parameter :: n=2, m=1
2122
real(wp), dimension(n,m), target :: in_data1
2223
real(wp), dimension(n,m), target :: in_data2
23-
real(wp), dimension(:,:), pointer :: out_data
24+
real(wp), dimension(n,m), target :: out_data
2425
real(wp), dimension(n,m) :: expected
25-
integer :: tensor_layout(2) = [1, 2]
26+
integer :: tensor_layout(ndims) = [1, 2]
2627

2728
! Flag for testing
2829
logical :: test_pass
2930

3031
! Set up Torch data structures
3132
type(torch_tensor) :: a, b, Q
3233

33-
! Initialise input arrays as in Python example
34+
! Initialise Torch Tensors from input arrays as in Python example
3435
in_data1(:,1) = [2.0_wp, 3.0_wp]
3536
in_data2(:,1) = [6.0_wp, 4.0_wp]
36-
37-
! Construct a Torch Tensor from a Fortran array
3837
! TODO: Implement requires_grad=.true.
3938
call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU)
4039
call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU)
4140

41+
! Initialise Torch Tensor from array used for output
42+
call torch_tensor_from_array(Q, out_data, tensor_layout, torch_kCPU)
43+
4244
! Check arithmetic operations work for torch_tensors
4345
write (*,*) "a = ", in_data1(:,1)
4446
write (*,*) "b = ", in_data2(:,1)
4547
Q = 3 * (a**3 - b * b / 3)
4648

4749
! Extract a Fortran array from a Torch tensor
48-
call torch_tensor_to_array(Q, out_data, shape(in_data1))
4950
write (*,*) "Q = 3 * (a ** 3 - b * b / 3) =", out_data(:,1)
5051

5152
! Check output tensor matches expected value
5253
expected(:,1) = [-12.0_wp, 65.0_wp]
53-
test_pass = assert_allclose(out_data, expected, test_name="torch_tensor_to_array", rtol=1e-5)
54+
test_pass = assert_allclose(out_data, expected, test_name="autograd_Q")
5455
if (.not. test_pass) then
55-
call clean_up()
5656
print *, "Error :: out_data does not match expected value"
5757
stop 999
5858
end if
5959

6060
! Back-propagation
6161
! TODO: Requires API extension
6262

63-
call clean_up()
6463
write (*,*) "Autograd example ran successfully"
6564

66-
contains
67-
68-
! Subroutine for freeing memory and nullifying pointers used in the example
69-
subroutine clean_up()
70-
nullify(out_data)
71-
call torch_tensor_delete(a)
72-
call torch_tensor_delete(b)
73-
call torch_tensor_delete(Q)
74-
end subroutine clean_up
75-
7665
end program example

pages/examples.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -183,5 +183,5 @@ different input multiple times in the same workflow.
183183
[This worked example](https://github.com/Cambridge-ICCS/FTorch/tree/main/examples/6_Autograd)
184184
is currently under development. Eventually, it will demonstrate how to perform
185185
automatic differentiation in FTorch by leveraging PyTorch's Autograd module.
186-
Currently, it just demonstrates how to use `torch_tensor_to_array` and compute
187-
mathematical expressions involving Torch tensors.
186+
Currently, it just demonstrates how to compute mathematical expressions
187+
involving Torch tensors.

pages/updates.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ title: Recent API Changes
33
## February 2025
44

55
If you use a version of FTorch from before commit
6-
[f7fbebf](f7fbebfdad2a4801f57742a2bb12bc21e70881ff)
6+
[c85185e](c85185e6c261606c212dd11fee734663d610b695)
77
(February 2025) you will notice that the main `CMakeLists.txt` file has moved
88
from `src/` to the root level of the FTorch repository. This move was mainly to
99
simplify the development experience, such that the examples could be built as

src/ctorch.cpp

+5-39
Original file line numberDiff line numberDiff line change
@@ -246,44 +246,6 @@ torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
246246
// --- Functions for interrogating tensors
247247
// =====================================================================================
248248

249-
void *torch_to_blob(const torch_tensor_t tensor, const torch_data_t dtype) {
250-
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
251-
void *raw_ptr;
252-
switch (dtype) {
253-
case torch_kUInt8:
254-
std::cerr << "[WARNING]: uint8 not supported" << std::endl;
255-
exit(EXIT_FAILURE);
256-
case torch_kInt8:
257-
raw_ptr = (void *)t->data_ptr<int8_t>();
258-
break;
259-
case torch_kInt16:
260-
raw_ptr = (void *)t->data_ptr<int16_t>();
261-
break;
262-
case torch_kInt32:
263-
raw_ptr = (void *)t->data_ptr<int32_t>();
264-
break;
265-
case torch_kInt64:
266-
raw_ptr = (void *)t->data_ptr<int64_t>();
267-
break;
268-
case torch_kFloat16:
269-
std::cerr << "[WARNING]: float16 not supported" << std::endl;
270-
// NOTE: std::float16_t is available but only with C++23
271-
exit(EXIT_FAILURE);
272-
case torch_kFloat32:
273-
raw_ptr = (void *)t->data_ptr<float>();
274-
// NOTE: std::float32_t is available but only with C++23
275-
break;
276-
case torch_kFloat64:
277-
raw_ptr = (void *)t->data_ptr<double>();
278-
// NOTE: std::float64_t is available but only with C++23
279-
break;
280-
default:
281-
std::cerr << "[WARNING]: unknown data type" << std::endl;
282-
exit(EXIT_FAILURE);
283-
}
284-
return raw_ptr;
285-
}
286-
287249
void torch_tensor_print(const torch_tensor_t tensor) {
288250
auto t = reinterpret_cast<torch::Tensor *>(tensor);
289251
std::cout << *t << std::endl;
@@ -338,7 +300,11 @@ void torch_tensor_assign(torch_tensor_t output, const torch_tensor_t input) {
338300
auto out = reinterpret_cast<torch::Tensor *>(output);
339301
auto in = reinterpret_cast<torch::Tensor *const>(input);
340302
torch::AutoGradMode enable_grad(in->requires_grad());
341-
*out = *in;
303+
// NOTE: The following line ensures that the output tensor continues to point to a
304+
// Fortran array if it was set up to do so using torch_tensor_from_array. If
305+
// it's removed then the Fortran array keeps its original value and is no
306+
longer pointed to.
307+
std::move(*out) = *in;
342308
}
343309

344310
void torch_tensor_add(torch_tensor_t output, const torch_tensor_t tensor1,

src/ctorch.h

-9
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,6 @@ EXPORT_C torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *sha
111111
// --- Functions for interrogating tensors
112112
// =============================================================================
113113

114-
/**
115-
* Function to extract a C-array from a Torch Tensor's data.
116-
*
117-
* @param the Torch Tensor
118-
* @param data type of the elements of the Tensor
119-
* @return pointer to the Tensor in memory
120-
*/
121-
EXPORT_C void *torch_to_blob(const torch_tensor_t tensor, const torch_data_t dtype);
122-
123114
/**
124115
* Function to print out a Torch Tensor
125116
* @param Torch Tensor to print

0 commit comments

Comments
 (0)