Skip to content

Commit

Permalink
Hoist compare augmentation out of get_set_operation_splitter_ranks an…
Browse files Browse the repository at this point in the history
…d into get_merge_splitter_ranks

We only need one compare augmentation instead of two.
  • Loading branch information
jaredhoberock committed Nov 1, 2010
1 parent 3a0ed0d commit b6233e2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,6 @@ struct mult_by
}
};

// this predicate tests two two-element tuples
// we first use a Compare for the first element
// if the first elements are equivalent, we use
// < for the second elements
template<typename Compare>
struct compare_first_less_second
{
compare_first_less_second(Compare c)
: comp(c) {}

template<typename Tuple>
__host__ __device__
bool operator()(Tuple lhs, Tuple rhs)
{
return comp(lhs.get<0>(), rhs.get<0>()) || (!comp(rhs.get<0>(), lhs.get<0>()) && lhs.get<1>() < rhs.get<1>());
}

Compare comp;
}; // end compare_first_less_second

template<typename Iterator1, typename Iterator2, typename Compare>
struct select_functor
{
Expand Down Expand Up @@ -158,29 +138,6 @@ template<typename RandomAccessIterator, typename Size>
return thrust::make_permutation_iterator(iter, make_leapfrog_iterator<difference>(0, split_size));
} // end make_splitter_iterator()


template<typename Compare>
struct strong_compare
{
strong_compare(Compare c)
: comp(c) {}

// T1 and T2 are tuples
template<typename T1, typename T2>
__host__ __device__
bool operator()(T1 lhs, T2 rhs)
{
if(comp(lhs.get<0>(), rhs.get<0>()))
{
return true;
}

return lhs.get<1>() < rhs.get<1>();
}

Compare comp;
};

} // end get_set_operation_splitter_ranks_detail

template<typename RandomAccessIterator1,
Expand Down Expand Up @@ -216,24 +173,22 @@ template<typename RandomAccessIterator1,
splitter_iterator2 splitters2_begin = make_splitter_iterator(first2, partition_size) + 1;
splitter_iterator2 splitters2_end = splitters2_begin + num_splitters_from_each_range;

typedef compare_first_less_second<Compare> splitter_compare;

typedef typename merge_iterator<splitter_iterator1,splitter_iterator2,splitter_compare>::type merge_iterator;
typedef typename merge_iterator<splitter_iterator1,splitter_iterator2,Compare>::type merge_iterator;

// "merge" the splitters
merge_iterator splitters_begin = make_merge_iterator(splitters1_begin, splitters1_end,
splitters2_begin, splitters2_end,
splitter_compare(comp));
comp);
merge_iterator splitters_end = splitters_begin + 2 * num_splitters_from_each_range;

// find the rank of each splitter in the other range
thrust::lower_bound(first2, last2,
splitters_begin, splitters_end,
splitter_ranks2, strong_compare<Compare>(comp));
splitter_ranks2, comp);

thrust::upper_bound(first1, last1,
splitters_begin, splitters_end,
splitter_ranks1, strong_compare<Compare>(comp));
splitter_ranks1, comp);
} // end get_set_operation_splitter_ranks()

} // end detail
Expand Down
27 changes: 26 additions & 1 deletion thrust/detail/device/cuda/merge.inl
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,28 @@ __global__ void merge_kernel(const RandomAccessIterator1 first1,
} // end merge_kernel


// this predicate tests two two-element tuples
// we first use a Compare for the first element
// if the first elements are equivalent, we use
// < for the second elements
template<typename Compare>
struct compare_first_less_second
{
compare_first_less_second(Compare c)
: comp(c) {}

template<typename T1, typename T2>
__host__ __device__
bool operator()(T1 lhs, T2 rhs)
{
return comp(lhs.get<0>(), rhs.get<0>()) || (!comp(rhs.get<0>(), lhs.get<0>()) && lhs.get<1>() < rhs.get<1>());
}

Compare comp;
}; // end compare_first_less_second



template<typename RandomAccessIterator1,
typename RandomAccessIterator2,
typename RandomAccessIterator3,
Expand Down Expand Up @@ -234,12 +256,15 @@ template<typename RandomAccessIterator1,
thrust::make_zip_iterator(thrust::make_tuple(first2, thrust::make_counting_iterator<difference2>(num_elements1)));
iterator_and_counter2 last_and_counter2 = first_and_counter2 + num_elements2;

// take into account the tuples when comparing
typedef compare_first_less_second<Compare> splitter_compare;

using namespace thrust::detail::device::cuda::detail;
return get_set_operation_splitter_ranks(first_and_counter1, last_and_counter1,
first_and_counter2, last_and_counter2,
splitter_ranks1,
splitter_ranks2,
comp,
splitter_compare(comp),
partition_size,
num_splitters_from_each_range);
} // end get_merge_splitter_ranks()
Expand Down

0 comments on commit b6233e2

Please sign in to comment.