@@ -334,83 +334,67 @@ void torch_tensor_delete(torch_tensor_t tensor) {
334
334
// --- Operator overloads acting on tensors
335
335
// =====================================================================================
336
336
337
- torch_tensor_t torch_tensor_assign (const torch_tensor_t input) {
337
+ void torch_tensor_assign (torch_tensor_t output, const torch_tensor_t input) {
338
+ auto out = reinterpret_cast <torch::Tensor *>(output);
338
339
auto in = reinterpret_cast <torch::Tensor *const >(input);
339
340
torch::AutoGradMode enable_grad (in->requires_grad ());
340
- torch::Tensor *output = nullptr ;
341
- output = new torch::Tensor;
342
- *output = in->detach ().clone ();
343
- return output;
341
+ *out = *in;
344
342
}
345
343
346
- torch_tensor_t torch_tensor_add (const torch_tensor_t tensor1,
347
- const torch_tensor_t tensor2) {
344
+ void torch_tensor_add (torch_tensor_t output, const torch_tensor_t tensor1,
345
+ const torch_tensor_t tensor2) {
346
+ auto out = reinterpret_cast <torch::Tensor *>(output);
348
347
auto t1 = reinterpret_cast <torch::Tensor *const >(tensor1);
349
348
auto t2 = reinterpret_cast <torch::Tensor *const >(tensor2);
350
- torch::Tensor *output = nullptr ;
351
- output = new torch::Tensor;
352
- *output = *t1 + *t2;
353
- return output;
349
+ *out = *t1 + *t2;
354
350
}
355
351
356
- torch_tensor_t torch_tensor_negative (const torch_tensor_t tensor) {
352
+ void torch_tensor_negative (torch_tensor_t output, const torch_tensor_t tensor) {
353
+ auto out = reinterpret_cast <torch::Tensor *>(output);
357
354
auto t = reinterpret_cast <torch::Tensor *const >(tensor);
358
- torch::Tensor *output = nullptr ;
359
- output = new torch::Tensor;
360
- *output = -*t;
361
- return output;
355
+ *out = -*t;
362
356
}
363
357
364
- torch_tensor_t torch_tensor_subtract (const torch_tensor_t tensor1,
365
- const torch_tensor_t tensor2) {
358
+ void torch_tensor_subtract (torch_tensor_t output, const torch_tensor_t tensor1,
359
+ const torch_tensor_t tensor2) {
360
+ auto out = reinterpret_cast <torch::Tensor *>(output);
366
361
auto t1 = reinterpret_cast <torch::Tensor *const >(tensor1);
367
362
auto t2 = reinterpret_cast <torch::Tensor *const >(tensor2);
368
- torch::Tensor *output = nullptr ;
369
- output = new torch::Tensor;
370
- *output = *t1 - *t2;
371
- return output;
363
+ *out = *t1 - *t2;
372
364
}
373
365
374
- torch_tensor_t torch_tensor_multiply (const torch_tensor_t tensor1,
375
- const torch_tensor_t tensor2) {
366
+ void torch_tensor_multiply (torch_tensor_t output, const torch_tensor_t tensor1,
367
+ const torch_tensor_t tensor2) {
368
+ auto out = reinterpret_cast <torch::Tensor *>(output);
376
369
auto t1 = reinterpret_cast <torch::Tensor *const >(tensor1);
377
370
auto t2 = reinterpret_cast <torch::Tensor *const >(tensor2);
378
- torch::Tensor *output = nullptr ;
379
- output = new torch::Tensor;
380
- *output = *t1 * *t2;
381
- return output;
371
+ *out = *t1 * *t2;
382
372
}
383
373
384
- torch_tensor_t torch_tensor_divide (const torch_tensor_t tensor1,
385
- const torch_tensor_t tensor2) {
374
+ void torch_tensor_divide (torch_tensor_t output, const torch_tensor_t tensor1,
375
+ const torch_tensor_t tensor2) {
376
+ auto out = reinterpret_cast <torch::Tensor *>(output);
386
377
auto t1 = reinterpret_cast <torch::Tensor *const >(tensor1);
387
378
auto t2 = reinterpret_cast <torch::Tensor *const >(tensor2);
388
- torch::Tensor *output = nullptr ;
389
- output = new torch::Tensor;
390
- *output = *t1 / *t2;
391
- return output;
379
+ *out = *t1 / *t2;
392
380
}
393
381
394
- torch_tensor_t torch_tensor_power_int (const torch_tensor_t tensor,
395
- const torch_int_t exponent) {
396
- auto t = reinterpret_cast <torch::Tensor *const >(tensor);
382
+ void torch_tensor_power_int (torch_tensor_t output, const torch_tensor_t tensor,
383
+ const torch_int_t exponent) {
397
384
// NOTE: The following cast will only work for integer exponents
385
+ auto out = reinterpret_cast <torch::Tensor *>(output);
386
+ auto t = reinterpret_cast <torch::Tensor *const >(tensor);
398
387
auto exp = reinterpret_cast <int *const >(exponent);
399
- torch::Tensor *output = nullptr ;
400
- output = new torch::Tensor;
401
- *output = pow (*t, *exp );
402
- return output;
388
+ *out = pow (*t, *exp );
403
389
}
404
390
405
- torch_tensor_t torch_tensor_power_float (const torch_tensor_t tensor,
406
- const torch_float_t exponent) {
407
- auto t = reinterpret_cast <torch::Tensor *const >(tensor);
391
+ void torch_tensor_power_float (torch_tensor_t output, const torch_tensor_t tensor,
392
+ const torch_float_t exponent) {
408
393
// NOTE: The following cast will only work for floating point exponents
394
+ auto out = reinterpret_cast <torch::Tensor *>(output);
395
+ auto t = reinterpret_cast <torch::Tensor *const >(tensor);
409
396
auto exp = reinterpret_cast <float *const >(exponent);
410
- torch::Tensor *output = nullptr ;
411
- output = new torch::Tensor;
412
- *output = pow (*t, *exp );
413
- return output;
397
+ *out = pow (*t, *exp );
414
398
}
415
399
416
400
// =============================================================================
0 commit comments