@@ -151,9 +151,12 @@ torch_tensor_t torch_empty(int ndim, const int64_t *shape, torch_data_t dtype,
151
151
try {
152
152
// This doesn't throw if shape and dimensions are incompatible
153
153
c10::IntArrayRef vshape (shape, ndim);
154
+ auto options = torch::TensorOptions ()
155
+ .dtype (get_libtorch_dtype (dtype))
156
+ .device (get_libtorch_device (device_type, device_index))
157
+ .requires_grad (requires_grad);
154
158
tensor = new torch::Tensor;
155
- *tensor = torch::empty (vshape, torch::dtype (get_libtorch_dtype (dtype)))
156
- .to (get_libtorch_device (device_type, device_index));
159
+ *tensor = torch::empty (vshape, options);
157
160
} catch (const torch::Error &e) {
158
161
std::cerr << " [ERROR]: " << e.msg () << std::endl;
159
162
delete tensor;
@@ -174,9 +177,12 @@ torch_tensor_t torch_zeros(int ndim, const int64_t *shape, torch_data_t dtype,
174
177
try {
175
178
// This doesn't throw if shape and dimensions are incompatible
176
179
c10::IntArrayRef vshape (shape, ndim);
180
+ auto options = torch::TensorOptions ()
181
+ .dtype (get_libtorch_dtype (dtype))
182
+ .device (get_libtorch_device (device_type, device_index))
183
+ .requires_grad (requires_grad);
177
184
tensor = new torch::Tensor;
178
- *tensor = torch::zeros (vshape, torch::dtype (get_libtorch_dtype (dtype)))
179
- .to (get_libtorch_device (device_type, device_index));
185
+ *tensor = torch::zeros (vshape, options);
180
186
} catch (const torch::Error &e) {
181
187
std::cerr << " [ERROR]: " << e.msg () << std::endl;
182
188
delete tensor;
@@ -197,9 +203,12 @@ torch_tensor_t torch_ones(int ndim, const int64_t *shape, torch_data_t dtype,
197
203
try {
198
204
// This doesn't throw if shape and dimensions are incompatible
199
205
c10::IntArrayRef vshape (shape, ndim);
206
+ auto options = torch::TensorOptions ()
207
+ .dtype (get_libtorch_dtype (dtype))
208
+ .device (get_libtorch_device (device_type, device_index))
209
+ .requires_grad (requires_grad);
200
210
tensor = new torch::Tensor;
201
- *tensor = torch::ones (vshape, torch::dtype (get_libtorch_dtype (dtype)))
202
- .to (get_libtorch_device (device_type, device_index));
211
+ *tensor = torch::ones (vshape, options);
203
212
} catch (const torch::Error &e) {
204
213
std::cerr << " [ERROR]: " << e.msg () << std::endl;
205
214
delete tensor;
@@ -225,10 +234,12 @@ torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
225
234
// This doesn't throw if shape and dimensions are incompatible
226
235
c10::IntArrayRef vshape (shape, ndim);
227
236
c10::IntArrayRef vstrides (strides, ndim);
237
+ auto options = torch::TensorOptions ()
238
+ .dtype (get_libtorch_dtype (dtype))
239
+ .device (get_libtorch_device (device_type, device_index))
240
+ .requires_grad (requires_grad);
228
241
tensor = new torch::Tensor;
229
- *tensor = torch::from_blob (data, vshape, vstrides,
230
- torch::dtype (get_libtorch_dtype (dtype)))
231
- .to (get_libtorch_device (device_type, device_index));
242
+ *tensor = torch::from_blob (data, vshape, vstrides, options);
232
243
233
244
} catch (const torch::Error &e) {
234
245
std::cerr << " [ERROR]: " << e.msg () << std::endl;
@@ -283,6 +294,11 @@ int torch_tensor_get_device_index(const torch_tensor_t tensor) {
283
294
return t->device ().index ();
284
295
}
285
296
297
+ bool torch_tensor_requires_grad (const torch_tensor_t tensor) {
298
+ auto t = reinterpret_cast <torch::Tensor *>(tensor);
299
+ return t->requires_grad ();
300
+ }
301
+
286
302
// =====================================================================================
287
303
// --- Functions for deallocating tensors
288
304
// =====================================================================================
0 commit comments