Skip to content

Commit dcdadef

Browse files
authored
Address memory leak in overloaded operators (#297)
* Make operator overload C bindings subroutines; create empty tensors inside * Set up destructor for torch_tensor * No need to call torch_tensor_delete in unit tests * Test manual/auto delete with fixture extension
1 parent 120f1e5 commit dcdadef

9 files changed

+469
-537
lines changed

src/ctorch.cpp

+32-48
Original file line numberDiff line numberDiff line change
@@ -334,83 +334,67 @@ void torch_tensor_delete(torch_tensor_t tensor) {
334334
// --- Operator overloads acting on tensors
335335
// =====================================================================================
336336

337-
torch_tensor_t torch_tensor_assign(const torch_tensor_t input) {
337+
void torch_tensor_assign(torch_tensor_t output, const torch_tensor_t input) {
338+
auto out = reinterpret_cast<torch::Tensor *>(output);
338339
auto in = reinterpret_cast<torch::Tensor *const>(input);
339340
torch::AutoGradMode enable_grad(in->requires_grad());
340-
torch::Tensor *output = nullptr;
341-
output = new torch::Tensor;
342-
*output = in->detach().clone();
343-
return output;
341+
*out = *in;
344342
}
345343

346-
torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1,
347-
const torch_tensor_t tensor2) {
344+
void torch_tensor_add(torch_tensor_t output, const torch_tensor_t tensor1,
345+
const torch_tensor_t tensor2) {
346+
auto out = reinterpret_cast<torch::Tensor *>(output);
348347
auto t1 = reinterpret_cast<torch::Tensor *const>(tensor1);
349348
auto t2 = reinterpret_cast<torch::Tensor *const>(tensor2);
350-
torch::Tensor *output = nullptr;
351-
output = new torch::Tensor;
352-
*output = *t1 + *t2;
353-
return output;
349+
*out = *t1 + *t2;
354350
}
355351

356-
torch_tensor_t torch_tensor_negative(const torch_tensor_t tensor) {
352+
void torch_tensor_negative(torch_tensor_t output, const torch_tensor_t tensor) {
353+
auto out = reinterpret_cast<torch::Tensor *>(output);
357354
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
358-
torch::Tensor *output = nullptr;
359-
output = new torch::Tensor;
360-
*output = -*t;
361-
return output;
355+
*out = -*t;
362356
}
363357

364-
torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1,
365-
const torch_tensor_t tensor2) {
358+
void torch_tensor_subtract(torch_tensor_t output, const torch_tensor_t tensor1,
359+
const torch_tensor_t tensor2) {
360+
auto out = reinterpret_cast<torch::Tensor *>(output);
366361
auto t1 = reinterpret_cast<torch::Tensor *const>(tensor1);
367362
auto t2 = reinterpret_cast<torch::Tensor *const>(tensor2);
368-
torch::Tensor *output = nullptr;
369-
output = new torch::Tensor;
370-
*output = *t1 - *t2;
371-
return output;
363+
*out = *t1 - *t2;
372364
}
373365

374-
torch_tensor_t torch_tensor_multiply(const torch_tensor_t tensor1,
375-
const torch_tensor_t tensor2) {
366+
void torch_tensor_multiply(torch_tensor_t output, const torch_tensor_t tensor1,
367+
const torch_tensor_t tensor2) {
368+
auto out = reinterpret_cast<torch::Tensor *>(output);
376369
auto t1 = reinterpret_cast<torch::Tensor *const>(tensor1);
377370
auto t2 = reinterpret_cast<torch::Tensor *const>(tensor2);
378-
torch::Tensor *output = nullptr;
379-
output = new torch::Tensor;
380-
*output = *t1 * *t2;
381-
return output;
371+
*out = *t1 * *t2;
382372
}
383373

384-
torch_tensor_t torch_tensor_divide(const torch_tensor_t tensor1,
385-
const torch_tensor_t tensor2) {
374+
void torch_tensor_divide(torch_tensor_t output, const torch_tensor_t tensor1,
375+
const torch_tensor_t tensor2) {
376+
auto out = reinterpret_cast<torch::Tensor *>(output);
386377
auto t1 = reinterpret_cast<torch::Tensor *const>(tensor1);
387378
auto t2 = reinterpret_cast<torch::Tensor *const>(tensor2);
388-
torch::Tensor *output = nullptr;
389-
output = new torch::Tensor;
390-
*output = *t1 / *t2;
391-
return output;
379+
*out = *t1 / *t2;
392380
}
393381

394-
torch_tensor_t torch_tensor_power_int(const torch_tensor_t tensor,
395-
const torch_int_t exponent) {
396-
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
382+
void torch_tensor_power_int(torch_tensor_t output, const torch_tensor_t tensor,
383+
const torch_int_t exponent) {
397384
// NOTE: The following cast will only work for integer exponents
385+
auto out = reinterpret_cast<torch::Tensor *>(output);
386+
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
398387
auto exp = reinterpret_cast<int *const>(exponent);
399-
torch::Tensor *output = nullptr;
400-
output = new torch::Tensor;
401-
*output = pow(*t, *exp);
402-
return output;
388+
*out = pow(*t, *exp);
403389
}
404390

405-
torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor,
406-
const torch_float_t exponent) {
407-
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
391+
void torch_tensor_power_float(torch_tensor_t output, const torch_tensor_t tensor,
392+
const torch_float_t exponent) {
408393
// NOTE: The following cast will only work for floating point exponents
394+
auto out = reinterpret_cast<torch::Tensor *>(output);
395+
auto t = reinterpret_cast<torch::Tensor *const>(tensor);
409396
auto exp = reinterpret_cast<float *const>(exponent);
410-
torch::Tensor *output = nullptr;
411-
output = new torch::Tensor;
412-
*output = pow(*t, *exp);
413-
return output;
397+
*out = pow(*t, *exp);
414398
}
415399

416400
// =============================================================================

src/ctorch.h

+23-22
Original file line numberDiff line numberDiff line change
@@ -181,72 +181,73 @@ EXPORT_C void torch_tensor_delete(torch_tensor_t tensor);
181181

182182
/**
183183
* Overloads the assignment operator for Torch Tensor
184+
* @param output Tensor
184185
* @param input Tensor
185-
* @return copy of input Tensor
186186
*/
187-
EXPORT_C torch_tensor_t torch_tensor_assign(const torch_tensor_t input);
187+
EXPORT_C void torch_tensor_assign(torch_tensor_t output, const torch_tensor_t input);
188188

189189
/**
190190
* Overloads the addition operator for two Torch Tensors
191+
* @param sum of the Tensors
191192
* @param first Tensor to be added
192193
* @param second Tensor to be added
193-
* @return sum of the Tensors
194194
*/
195-
EXPORT_C torch_tensor_t torch_tensor_add(const torch_tensor_t tensor1,
196-
const torch_tensor_t tensor2);
195+
EXPORT_C void torch_tensor_add(torch_tensor_t, const torch_tensor_t tensor1,
196+
const torch_tensor_t tensor2);
197197

198198
/**
199199
* Overloads the minus operator for a single Torch Tensor
200+
* @param the negative Tensor
200201
* @param Tensor to take the negative of
201-
* @return the negative Tensor
202202
*/
203-
EXPORT_C torch_tensor_t torch_tensor_negative(const torch_tensor_t tensor);
203+
EXPORT_C void torch_tensor_negative(torch_tensor_t output, const torch_tensor_t tensor);
204204

205205
/**
206206
* Overloads the subtraction operator for two Torch Tensors
207+
* @param output Tensor
207208
* @param first Tensor to be subtracted
208209
* @param second Tensor to be subtracted
209-
* @return difference of the Tensors
210210
*/
211-
EXPORT_C torch_tensor_t torch_tensor_subtract(const torch_tensor_t tensor1,
212-
const torch_tensor_t tensor2);
211+
EXPORT_C void torch_tensor_subtract(torch_tensor_t output, const torch_tensor_t tensor1,
212+
const torch_tensor_t tensor2);
213213

214214
/**
215215
* Overloads the multiplication operator for two Torch Tensors
216+
* @param output Tensor
216217
* @param first Tensor to be multiplied
217218
* @param second Tensor to be multiplied
218-
* @return product of the Tensors
219219
*/
220-
EXPORT_C torch_tensor_t torch_tensor_multiply(const torch_tensor_t tensor1,
221-
const torch_tensor_t tensor2);
220+
EXPORT_C void torch_tensor_multiply(torch_tensor_t output, const torch_tensor_t tensor1,
221+
const torch_tensor_t tensor2);
222222

223223
/**
224224
* Overloads the division operator for two Torch Tensors.
225+
* @param output Tensor
225226
* @param first Tensor to be divided
226227
* @param second Tensor to be divided
227-
* @return quotient of the Tensors
228228
*/
229-
EXPORT_C torch_tensor_t torch_tensor_divide(const torch_tensor_t tensor1,
230-
const torch_tensor_t tensor2);
229+
EXPORT_C void torch_tensor_divide(torch_tensor_t output, const torch_tensor_t tensor1,
230+
const torch_tensor_t tensor2);
231231

232232
/**
233233
* Overloads the exponentiation operator for a Torch Tensor and an integer exponent
234+
* @param output Tensor
234235
* @param Tensor to take the power of
235236
* @param integer exponent
236-
* @return power of the Tensor
237237
*/
238-
EXPORT_C torch_tensor_t torch_tensor_power_int(const torch_tensor_t tensor,
239-
const torch_int_t exponent);
238+
EXPORT_C void torch_tensor_power_int(torch_tensor_t output, const torch_tensor_t tensor,
239+
const torch_int_t exponent);
240240

241241
/**
242242
* Overloads the exponentiation operator for a Torch Tensor and a floating point
243243
* exponent
244+
* @param output Tensor
244245
* @param Tensor to take the power of
245246
* @param floating point exponent
246-
* @return power of the Tensor
247247
*/
248-
EXPORT_C torch_tensor_t torch_tensor_power_float(const torch_tensor_t tensor,
249-
const torch_float_t exponent);
248+
EXPORT_C void torch_tensor_power_float(torch_tensor_t output,
249+
const torch_tensor_t tensor,
250+
const torch_float_t exponent);
250251

251252
// =============================================================================
252253
// --- Torch model API

0 commit comments

Comments
 (0)