Skip to content

Commit

Permalink
Strip mine in CUDA implementation of merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredhoberock committed Oct 26, 2010
1 parent 8611520 commit 7ebd80f
Showing 1 changed file with 86 additions and 67 deletions.
153 changes: 86 additions & 67 deletions thrust/detail/device/cuda/merge.inl
Original file line number Diff line number Diff line change
Expand Up @@ -92,51 +92,19 @@ template<unsigned int block_size,
typename RandomAccessIterator3,
typename RandomAccessIterator4,
typename RandomAccessIterator5,
typename StrictWeakOrdering>
typename StrictWeakOrdering,
typename Size>
__launch_bounds__(block_size, 1)
__global__ void merge_kernel(RandomAccessIterator1 first1,
RandomAccessIterator1 last1,
RandomAccessIterator2 first2,
RandomAccessIterator2 last2,
__global__ void merge_kernel(const RandomAccessIterator1 first1,
const RandomAccessIterator1 last1,
const RandomAccessIterator2 first2,
const RandomAccessIterator2 last2,
RandomAccessIterator3 splitter_ranks1,
RandomAccessIterator4 splitter_ranks2,
RandomAccessIterator5 result,
StrictWeakOrdering comp)
const RandomAccessIterator5 result,
StrictWeakOrdering comp,
Size num_merged_partitions)
{
const unsigned int partition_idx = blockIdx.x;

// advance iterators
splitter_ranks1 += partition_idx;
splitter_ranks2 += partition_idx;

// find the end of the input if this is not the last block
// the end of merged partition i is at splitter_ranks1[i] + splitter_ranks2[i]
if(partition_idx != gridDim.x - 1)
{
RandomAccessIterator3 temp1 = splitter_ranks1;
RandomAccessIterator4 temp2 = splitter_ranks2;

last1 = first1 + dereference(temp1);
last2 = first2 + dereference(temp2);
}

// find the beginning of the input and output if this is not the first block
// merged partition i begins at splitter_ranks1[i-1] + splitter_ranks2[i-1]
if(partition_idx != 0)
{
RandomAccessIterator3 temp1 = splitter_ranks1;
--temp1;
RandomAccessIterator4 temp2 = splitter_ranks2;
--temp2;

// advance the input to point to the beginning
first1 += dereference(temp1);
first2 += dereference(temp2);

// advance the result to point to the beginning of the output
result += dereference(temp1);
result += dereference(temp2);
}

typedef typename thrust::iterator_value<RandomAccessIterator5>::type value_type;

Expand All @@ -150,40 +118,87 @@ __global__ void merge_kernel(RandomAccessIterator1 first1,
value_type *s_input2 = reinterpret_cast<value_type*>(_shared2);
value_type *s_result = reinterpret_cast<value_type*>(_result);

if(first1 < last1 && first2 < last2)
// advance splitter iterators
splitter_ranks1 += blockIdx.x;
splitter_ranks2 += blockIdx.x;

for(Size partition_idx = blockIdx.x;
partition_idx < num_merged_partitions;
partition_idx += gridDim.x,
splitter_ranks1 += gridDim.x,
splitter_ranks2 += gridDim.x)
{
typedef typename thrust::iterator_difference<RandomAccessIterator1>::type difference1;
RandomAccessIterator1 input_begin1 = first1;
RandomAccessIterator1 input_end1 = last1;
RandomAccessIterator2 input_begin2 = first2;
RandomAccessIterator2 input_end2 = last2;

typedef typename thrust::iterator_difference<RandomAccessIterator2>::type difference2;
RandomAccessIterator5 output_begin = result;

// load the first segment
difference1 s_input1_size = thrust::min<difference1>(block_size, last1 - first1);
// find the end of the input if this is not the last block
// the end of merged partition i is at splitter_ranks1[i] + splitter_ranks2[i]
if(partition_idx != num_merged_partitions - 1)
{
RandomAccessIterator3 rank1 = splitter_ranks1;
RandomAccessIterator4 rank2 = splitter_ranks2;

block::copy(first1, first1 + s_input1_size, s_input1);
first1 += s_input1_size;
input_end1 = first1 + dereference(rank1);
input_end2 = first2 + dereference(rank2);
}

// load the second segment
difference2 s_input2_size = thrust::min<difference2>(block_size, last2 - first2);
// find the beginning of the input and output if this is not the first partition
// merged partition i begins at splitter_ranks1[i-1] + splitter_ranks2[i-1]
if(partition_idx != 0)
{
RandomAccessIterator3 rank1 = splitter_ranks1;
--rank1;
RandomAccessIterator4 rank2 = splitter_ranks2;
--rank2;

// advance the input to point to the beginning
input_begin1 += dereference(rank1);
input_begin2 += dereference(rank2);

// advance the result to point to the beginning of the output
output_begin += dereference(rank1);
output_begin += dereference(rank2);
}

block::copy(first2, first2 + s_input2_size, s_input2);
first2 += s_input2_size;
if(input_begin1 < input_end1 && input_begin2 < input_end2)
{
typedef typename thrust::iterator_difference<RandomAccessIterator1>::type difference1;

__syncthreads();
typedef typename thrust::iterator_difference<RandomAccessIterator2>::type difference2;

block::merge(s_input1, s_input1 + s_input1_size,
s_input2, s_input2 + s_input2_size,
s_result,
comp);
// load the first segment
difference1 s_input1_size = thrust::min<difference1>(block_size, input_end1 - input_begin1);

__syncthreads();
block::copy(input_begin1, input_begin1 + s_input1_size, s_input1);
input_begin1 += s_input1_size;

// store to gmem
result = block::copy(s_result, s_result + s_input1_size + s_input2_size, result);
}
// load the second segment
difference2 s_input2_size = thrust::min<difference2>(block_size, input_end2 - input_begin2);

// simply copy any remaining input
block::copy(first2, last2, block::copy(first1, last1, result));
}
block::copy(input_begin2, input_begin2 + s_input2_size, s_input2);
input_begin2 += s_input2_size;

__syncthreads();

block::merge(s_input1, s_input1 + s_input1_size,
s_input2, s_input2 + s_input2_size,
s_result,
comp);

__syncthreads();

// store to gmem
output_begin = block::copy(s_result, s_result + s_input1_size + s_input2_size, output_begin);
}

// simply copy any remaining input
block::copy(input_begin2, input_end2, block::copy(input_begin1, input_end1, output_begin));
} // end for partition
} // end merge_kernel


template<typename T>
Expand Down Expand Up @@ -340,7 +355,7 @@ RandomAccessIterator3 merge(RandomAccessIterator1 first1,
typedef typename thrust::iterator_value<RandomAccessIterator1>::type value_type;

// XXX vary block_size dynamically
const size_t block_size = 128;
const size_t block_size = 512;
const size_t partition_size = block_size;

const difference1 num_partitions = ceil_div(num_elements1, partition_size);
Expand Down Expand Up @@ -399,13 +414,17 @@ RandomAccessIterator3 merge(RandomAccessIterator1 first1,
splitters_begin, splitters_end,
splitter_ranks1.begin(), strong_compare<Compare>(comp));

merge_detail::merge_kernel<block_size><<<num_merged_partitions, (unsigned int) block_size >>>(
// maximize the number of blocks we can launch
size_t num_blocks = thrust::min(num_merged_partitions, 64000u);

merge_detail::merge_kernel<block_size><<<num_blocks, (unsigned int) block_size >>>(
first1, last1,
first2, last2,
splitter_ranks1.begin(),
splitter_ranks2.begin(),
result,
comp);
comp,
num_merged_partitions);
synchronize_if_enabled("merge_kernel");

return result + num_elements1 + num_elements2;
Expand Down

0 comments on commit 7ebd80f

Please sign in to comment.