
Commit 864b0eb

Hook up requires_grad (#288)
* Enable requires_grad in autograd example
* Setup requires_grad properly; Use TensorOptions in tensor constructors
* Implement torch_tensor_requires_grad and test
* Add note on requires_grad on autograd page
* Introduce fixtures in interrogation tests
* Pass requires_grad argument when creating new tensors
1 parent 3870e1b · commit 864b0eb

8 files changed: +283 −132 lines changed

examples/6_Autograd/autograd.f90 (+2 −3)

@@ -34,9 +34,8 @@ program example
   ! Initialise Torch Tensors from input arrays as in Python example
   in_data1(:,1) = [2.0_wp, 3.0_wp]
   in_data2(:,1) = [6.0_wp, 4.0_wp]
-  ! TODO: Implement requires_grad=.true.
-  call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU)
-  call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU)
+  call torch_tensor_from_array(a, in_data1, tensor_layout, torch_kCPU, requires_grad=.true.)
+  call torch_tensor_from_array(b, in_data2, tensor_layout, torch_kCPU, requires_grad=.true.)

   ! Initialise Torch Tensor from array used for output
   call torch_tensor_from_array(Q, out_data, tensor_layout, torch_kCPU)
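
For reference, here is a minimal LibTorch C++ sketch of the computation the Fortran example performs, assuming (per its "as in Python example" comment) that it follows the PyTorch autograd tutorial, where Q = 3a³ − b². This is illustrative only, not code from this commit:

```cpp
#include <iostream>
#include <torch/torch.h>

int main() {
  // Inputs match the Fortran example; requires_grad is set at construction,
  // which is what the new requires_grad=.true. argument requests
  auto opts = torch::TensorOptions().dtype(torch::kFloat64).requires_grad(true);
  torch::Tensor a = torch::tensor({2.0, 3.0}, opts);
  torch::Tensor b = torch::tensor({6.0, 4.0}, opts);

  // Q = 3a^3 - b^2, as in the PyTorch autograd tutorial
  torch::Tensor Q = 3 * a.pow(3) - b * b;

  // Q is a vector, so backward() needs an explicit gradient seed
  Q.backward(torch::ones_like(Q));

  std::cout << a.grad() << std::endl;  // dQ/da = 9a^2 -> [36, 81]
  std::cout << b.grad() << std::endl;  // dQ/db = -2b  -> [-12, -8]
  return 0;
}
```

Setting `requires_grad` when the tensors are created makes `a` and `b` autograd leaves, so their `.grad()` fields are populated by the backward pass.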

pages/autograd.md (+2 −1)

@@ -35,7 +35,8 @@ Torch tensors, see the associated

 ### The `requires_grad` property

-*Not yet implemented.*
+For Tensors that you would like to differentiate with respect to, be sure to
+set the `requires_grad` optional argument to `.true.` when you construct it.

 ### The `backward` operator

src/ctorch.cpp (+25 −9)

@@ -151,9 +151,12 @@ torch_tensor_t torch_empty(int ndim, const int64_t *shape, torch_data_t dtype,
   try {
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
+    auto options = torch::TensorOptions()
+                       .dtype(get_libtorch_dtype(dtype))
+                       .device(get_libtorch_device(device_type, device_index))
+                       .requires_grad(requires_grad);
     tensor = new torch::Tensor;
-    *tensor = torch::empty(vshape, torch::dtype(get_libtorch_dtype(dtype)))
-                  .to(get_libtorch_device(device_type, device_index));
+    *tensor = torch::empty(vshape, options);
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;

@@ -174,9 +177,12 @@ torch_tensor_t torch_zeros(int ndim, const int64_t *shape, torch_data_t dtype,
   try {
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
+    auto options = torch::TensorOptions()
+                       .dtype(get_libtorch_dtype(dtype))
+                       .device(get_libtorch_device(device_type, device_index))
+                       .requires_grad(requires_grad);
     tensor = new torch::Tensor;
-    *tensor = torch::zeros(vshape, torch::dtype(get_libtorch_dtype(dtype)))
-                  .to(get_libtorch_device(device_type, device_index));
+    *tensor = torch::zeros(vshape, options);
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;

@@ -197,9 +203,12 @@ torch_tensor_t torch_ones(int ndim, const int64_t *shape, torch_data_t dtype,
   try {
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
+    auto options = torch::TensorOptions()
+                       .dtype(get_libtorch_dtype(dtype))
+                       .device(get_libtorch_device(device_type, device_index))
+                       .requires_grad(requires_grad);
     tensor = new torch::Tensor;
-    *tensor = torch::ones(vshape, torch::dtype(get_libtorch_dtype(dtype)))
-                  .to(get_libtorch_device(device_type, device_index));
+    *tensor = torch::ones(vshape, options);
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;

@@ -225,10 +234,12 @@ torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     c10::IntArrayRef vstrides(strides, ndim);
+    auto options = torch::TensorOptions()
+                       .dtype(get_libtorch_dtype(dtype))
+                       .device(get_libtorch_device(device_type, device_index))
+                       .requires_grad(requires_grad);
     tensor = new torch::Tensor;
-    *tensor = torch::from_blob(data, vshape, vstrides,
-                               torch::dtype(get_libtorch_dtype(dtype)))
-                  .to(get_libtorch_device(device_type, device_index));
+    *tensor = torch::from_blob(data, vshape, vstrides, options);

   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;

@@ -283,6 +294,11 @@ int torch_tensor_get_device_index(const torch_tensor_t tensor) {
   return t->device().index();
 }

+bool torch_tensor_requires_grad(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
+  return t->requires_grad();
+}
+
 // =====================================================================================
 // --- Functions for deallocating tensors
 // =====================================================================================
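
All four constructors now share one pattern: collect dtype, device, and requires_grad into a single `torch::TensorOptions` value and pass it to the factory call, rather than constructing with only a dtype and then moving the tensor with `.to()`. A minimal standalone sketch of that pattern (plain LibTorch, not FTorch code):

```cpp
#include <torch/torch.h>

int main() {
  // One options object carries dtype, device, and requires_grad together,
  // so the tensor is created directly in its final state
  auto options = torch::TensorOptions()
                     .dtype(torch::kFloat32)
                     .device(torch::kCPU)
                     .requires_grad(true);
  torch::Tensor t = torch::zeros({2, 3}, options);

  // requires_grad is set on the freshly created leaf tensor; the old
  // create-then-.to() pattern never set it at all, and calling .to() on
  // a tensor that does require grad returns a non-leaf result
  TORCH_CHECK(t.requires_grad());
  return 0;
}
```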

src/ctorch.h (+7 −0)

@@ -156,6 +156,13 @@ EXPORT_C torch_device_t torch_tensor_get_device_type(const torch_tensor_t tensor
  */
 EXPORT_C int torch_tensor_get_device_index(const torch_tensor_t tensor);

+/**
+ * Function to determine whether a Torch Tensor requires the autograd module
+ * @param Torch Tensor to interrogate
+ * @return whether the Torch Tensor requires autograd
+ */
+EXPORT_C bool torch_tensor_requires_grad(const torch_tensor_t tensor);
+
 // =============================================================================
 // --- Functions for deallocating tensors
 // =============================================================================
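
A hypothetical caller-side check of the new accessor. The `torch_ones` prototype is inferred from the truncated signature in the diff above, and `torch_kFloat32`, `torch_kCPU`, and `torch_tensor_delete` are assumed names from the surrounding API, so treat this as a sketch rather than verbatim usage of ctorch.h:

```cpp
#include <cassert>
#include <cstdint>

#include "ctorch.h"

void check_requires_grad() {
  const int64_t shape[2] = {2, 3};
  // Assumed prototype: ndim, shape, dtype, device type, device index,
  // requires_grad (the trailing parameters are inferred from the diff)
  torch_tensor_t t = torch_ones(2, shape, torch_kFloat32, torch_kCPU, 0,
                                /*requires_grad=*/true);
  // The new accessor reports the flag set at construction
  assert(torch_tensor_requires_grad(t));
  torch_tensor_delete(t);
}
```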
