Skip to content

Commit

Permalink
Merge pull request #414 from gchanan/thrustalloc
Browse files (browse the repository at this point in the history)
Re-route thrust memory allocation to THCudaMalloc / THCudaFree in cunn.
  • Loading branch information
soumith authored Jan 11, 2017
2 parents 349df42 + 1f8292f commit 49d1c06
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 6 deletions.
1 change: 1 addition & 0 deletions lib/THCUNN/LookupTable.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "THCUNN.h"
#include "common.h"

+ #include "THCThrustAllocator.cuh"
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
Expand Down
1 change: 1 addition & 0 deletions lib/THCUNN/MSECriterion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "common.h"
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
+ #include "THCThrustAllocator.cuh"

#include <thrust/fill.h>
#include <thrust/functional.h>
Expand Down
1 change: 1 addition & 0 deletions lib/THCUNN/SmoothL1Criterion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "common.h"
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
+ #include "THCThrustAllocator.cuh"

#include <thrust/fill.h>
#include <thrust/functional.h>
Expand Down
5 changes: 3 additions & 2 deletions lib/THCUNN/generic/LookupTable.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ void THNN_(LookupTable_accGradParameters)(
THCIndexTensor_(resizeAs)(state, count, input);
count_data = THCIndexTensor_(data)(state, count);

+ THCThrustAllocator thrustAlloc(state);
thrust::device_ptr<THCIndex_t> sorted_ptr(sorted_data);
thrust::device_ptr<THCIndex_t> count_ptr(count_data);

Expand All @@ -72,7 +73,7 @@ void THNN_(LookupTable_accGradParameters)(
// count: 1 1 2 3 1 2 1 1 2
thrust::inclusive_scan_by_key(
#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
sorted_ptr,
sorted_ptr + numel,
Expand All @@ -85,7 +86,7 @@ void THNN_(LookupTable_accGradParameters)(
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
thrust::make_reverse_iterator(sorted_ptr + numel),
thrust::make_reverse_iterator(sorted_ptr),
Expand Down
6 changes: 4 additions & 2 deletions lib/THCUNN/generic/MSECriterion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ void THNN_(MSECriterion_updateOutput)(
input = THCTensor_(newContiguous)(state, input);
target = THCTensor_(newContiguous)(state, target);

+ THCThrustAllocator thrustAlloc(state);
thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
accreal sum = thrust::inner_product(
#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
input_data, input_data+size, target_data, (accreal) 0,
thrust::plus<accreal>(), mse_functor<real, accreal>());
Expand Down Expand Up @@ -54,13 +55,14 @@ void THNN_(MSECriterion_updateGradInput)(

THCTensor_(resizeAs)(state, gradInput, input);

+ THCThrustAllocator thrustAlloc(state);
thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
thrust::device_ptr<real> gradInput_data(THCTensor_(data)(state, gradInput));

thrust::transform(
#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
input_data, input_data+size, target_data, gradInput_data,
mse_updateGradInput_functor<real, accreal>(norm));
Expand Down
6 changes: 4 additions & 2 deletions lib/THCUNN/generic/SmoothL1Criterion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ void THNN_(SmoothL1Criterion_updateOutput)(
input = THCTensor_(newContiguous)(state, input);
target = THCTensor_(newContiguous)(state, target);

+ THCThrustAllocator thrustAlloc(state);
thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
accreal sum = thrust::inner_product(
#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
input_data, input_data+size, target_data, (accreal) 0,
thrust::plus<accreal>(), smoothl1_functor<real, accreal>()
Expand Down Expand Up @@ -63,13 +64,14 @@ void THNN_(SmoothL1Criterion_updateGradInput)(

THCTensor_(resizeAs)(state, gradInput, input);

+ THCThrustAllocator thrustAlloc(state);
thrust::device_ptr<real> input_data(THCTensor_(data)(state, input));
thrust::device_ptr<real> target_data(THCTensor_(data)(state, target));
thrust::device_ptr<real> gradInput_data(THCTensor_(data)(state, gradInput));

thrust::transform(
#if CUDA_VERSION >= 7000
- thrust::cuda::par.on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
input_data, input_data+size, target_data, gradInput_data,
smoothl1_updateGradInput_functor<real>(norm)
Expand Down

0 comments on commit 49d1c06

Please sign in to comment.