Skip to content

Commit

Permalink
Treat floating point counting_iterators special when computing their …
Browse files Browse the repository at this point in the history
…difference_type and comparing them.

Add additional testing for these changes.

Reported by Ollie on thrust-users

Fixes NVIDIA#248
  • Loading branch information
jaredhoberock authored and vamatya committed Feb 22, 2013
1 parent 0754340 commit 57b3a9c
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 23 deletions.
66 changes: 66 additions & 0 deletions testing/counting_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,72 @@ void TestCountingIteratorComparison(void)
DECLARE_UNITTEST(TestCountingIteratorComparison);


void TestCountingIteratorFloatComparison(void)
{
thrust::counting_iterator<float> iter1(0);
thrust::counting_iterator<float> iter2(0);

ASSERT_EQUAL(iter1 - iter2, 0);
ASSERT_EQUAL(iter1 == iter2, true);
ASSERT_EQUAL(iter1 < iter2, false);
ASSERT_EQUAL(iter2 < iter1, false);

iter1++;

ASSERT_EQUAL(iter1 - iter2, 1);
ASSERT_EQUAL(iter1 == iter2, false);
ASSERT_EQUAL(iter2 < iter1, true);
ASSERT_EQUAL(iter1 < iter2, false);

iter2++;

ASSERT_EQUAL(iter1 - iter2, 0);
ASSERT_EQUAL(iter1 == iter2, true);
ASSERT_EQUAL(iter1 < iter2, false);
ASSERT_EQUAL(iter2 < iter1, false);

iter1 += 100;
iter2 += 100;

ASSERT_EQUAL(iter1 - iter2, 0);
ASSERT_EQUAL(iter1 == iter2, true);
ASSERT_EQUAL(iter1 < iter2, false);
ASSERT_EQUAL(iter2 < iter1, false);


thrust::counting_iterator<float> iter3(0);
thrust::counting_iterator<float> iter4(0.5);

ASSERT_EQUAL(iter3 - iter4, 0);
ASSERT_EQUAL(iter3 == iter4, true);
ASSERT_EQUAL(iter3 < iter4, false);
ASSERT_EQUAL(iter4 < iter3, false);

iter3++; // iter3 = 1.0, iter4 = 0.5

ASSERT_EQUAL(iter3 - iter4, 0);
ASSERT_EQUAL(iter3 == iter4, true);
ASSERT_EQUAL(iter3 < iter4, false);
ASSERT_EQUAL(iter4 < iter3, false);

iter4++; // iter3 = 1.0, iter4 = 1.5

ASSERT_EQUAL(iter3 - iter4, 0);
ASSERT_EQUAL(iter3 == iter4, true);
ASSERT_EQUAL(iter3 < iter4, false);
ASSERT_EQUAL(iter4 < iter3, false);

iter4++; // iter3 = 1.0, iter4 = 2.5

ASSERT_EQUAL(iter3 - iter4, -1);
ASSERT_EQUAL(iter4 - iter3, 1);
ASSERT_EQUAL(iter3 == iter4, false);
ASSERT_EQUAL(iter3 < iter4, true);
ASSERT_EQUAL(iter4 < iter3, false);
}
DECLARE_UNITTEST(TestCountingIteratorFloatComparison);


void TestCountingIteratorDistance(void)
{
thrust::counting_iterator<int> iter1(0);
Expand Down
21 changes: 21 additions & 0 deletions testing/reduce.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <unittest/unittest.h>
#include <thrust/reduce.h>
#include <thrust/iterator/counting_iterator.h>
#include <limits>

template<typename T>
struct plus_mod_10
Expand Down Expand Up @@ -187,3 +189,22 @@ void TestReduceWithIndirection(void)
}
DECLARE_VECTOR_UNITTEST(TestReduceWithIndirection);

template<typename T>
void TestReduceCountingIterator(size_t n)
{
// be careful not to generate a range larger than we can represent
n = thrust::min<size_t>(n, std::numeric_limits<T>::max());

thrust::counting_iterator<T, thrust::host_system_tag> h_first = thrust::make_counting_iterator<T>(0);
thrust::counting_iterator<T, thrust::device_system_tag> d_first = thrust::make_counting_iterator<T>(0);

T init = 13;

T h_result = thrust::reduce(h_first, h_first + n, init);
T d_result = thrust::reduce(d_first, d_first + n, init);

// we use ASSERT_ALMOST_EQUAL because we're testing floating point types
ASSERT_ALMOST_EQUAL(h_result, d_result);
}
DECLARE_VARIABLE_UNITTEST(TestReduceCountingIterator);

46 changes: 24 additions & 22 deletions testing/transform.cu
Original file line number Diff line number Diff line change
Expand Up @@ -739,41 +739,43 @@ void TestTransformIfBinaryToDiscardIterator(const size_t n)
DECLARE_VARIABLE_UNITTEST(TestTransformIfBinaryToDiscardIterator);


template <class Vector>
void TestTransformUnaryCountingIterator(void)
template <class T>
void TestTransformUnaryCountingIterator(size_t n)
{
typedef typename Vector::value_type T;
// be careful not to generate a range larger than we can represent
n = thrust::min<size_t>(n, std::numeric_limits<T>::max());

thrust::counting_iterator<T> first(1);
thrust::counting_iterator<T, thrust::host_system_tag> h_first = thrust::make_counting_iterator<T>(0);
thrust::counting_iterator<T, thrust::device_system_tag> d_first = thrust::make_counting_iterator<T>(0);

Vector output(3);
thrust::host_vector<T> h_result(n);
thrust::device_vector<T> d_result(n);

thrust::transform(first, first + 3, output.begin(), thrust::identity<T>());
thrust::transform(h_first, h_first + n, h_result.begin(), thrust::identity<T>());
thrust::transform(d_first, d_first + n, d_result.begin(), thrust::identity<T>());

Vector result(3);
result[0] = 1; result[1] = 2; result[2] = 3;

ASSERT_EQUAL(output, result);
ASSERT_EQUAL(h_result, d_result);
}
DECLARE_VECTOR_UNITTEST(TestTransformUnaryCountingIterator);
DECLARE_VARIABLE_UNITTEST(TestTransformUnaryCountingIterator);

template <class Vector>
void TestTransformBinaryCountingIterator(void)
template <typename T>
void TestTransformBinaryCountingIterator(size_t n)
{
typedef typename Vector::value_type T;
// be careful not to generate a range larger than we can represent
n = thrust::min<size_t>(n, std::numeric_limits<T>::max());

thrust::counting_iterator<T> first(1);
thrust::counting_iterator<T, thrust::host_system_tag> h_first = thrust::make_counting_iterator<T>(0);
thrust::counting_iterator<T, thrust::device_system_tag> d_first = thrust::make_counting_iterator<T>(0);

Vector output(3);
thrust::host_vector<T> h_result(n);
thrust::device_vector<T> d_result(n);

thrust::transform(first, first + 3, first, output.begin(), thrust::plus<T>());
thrust::transform(h_first, h_first + n, h_first, h_result.begin(), thrust::plus<T>());
thrust::transform(d_first, d_first + n, d_first, d_result.begin(), thrust::plus<T>());

Vector result(3);
result[0] = 2; result[1] = 4; result[2] = 6;

ASSERT_EQUAL(output, result);
ASSERT_EQUAL(h_result, d_result);
}
DECLARE_VECTOR_UNITTEST(TestTransformBinaryCountingIterator);
DECLARE_VARIABLE_UNITTEST(TestTransformBinaryCountingIterator);


template <typename T>
Expand Down
9 changes: 9 additions & 0 deletions thrust/iterator/counting_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,15 @@ template<typename Incrementable,
return this->base_reference();
}

// note that we implement equal specially for floating point counting_iterator
template <typename OtherIncrementable, typename OtherSystem, typename OtherTraversal, typename OtherDifference>
__host__ __device__
bool equal(counting_iterator<OtherIncrementable, OtherSystem, OtherTraversal, OtherDifference> const& y) const
{
typedef thrust::detail::counting_iterator_equal<difference_type,Incrementable,OtherIncrementable> e;
return e::equal(this->base(), y.base());
}

template <class OtherIncrementable>
__host__ __device__
difference_type
Expand Down
42 changes: 41 additions & 1 deletion thrust/iterator/detail/counting_iterator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/detail/numeric_traits.h>
#include <thrust/detail/type_traits.h>
#include <cstddef>

namespace thrust
{
Expand Down Expand Up @@ -49,11 +51,17 @@ template <typename Incrementable, typename System, typename Traversal, typename
>
>::type traversal;

// unlike Boost, we explicitly use std::ptrdiff_t as the difference type
// for floating point counting_iterators
typedef typename thrust::experimental::detail::ia_dflt_help<
Difference,
thrust::detail::eval_if<
thrust::detail::is_numeric<Incrementable>::value,
thrust::detail::numeric_difference<Incrementable>,
thrust::detail::eval_if<
thrust::detail::is_integral<Incrementable>::value,
thrust::detail::numeric_difference<Incrementable>,
thrust::detail::identity_<std::ptrdiff_t>
>,
thrust::iterator_difference<Incrementable>
>
>::type difference;
Expand Down Expand Up @@ -97,6 +105,38 @@ template<typename Difference, typename Incrementable1, typename Incrementable2>
};


template<typename Difference, typename Incrementable1, typename Incrementable2, typename Enable = void>
struct counting_iterator_equal
{
__host__ __device__
static bool equal(Incrementable1 x, Incrementable2 y)
{
return x == y;
}
};


// specialization for floating point equality
template<typename Difference, typename Incrementable1, typename Incrementable2>
struct counting_iterator_equal<
Difference,
Incrementable1,
Incrementable2,
typename thrust::detail::enable_if<
thrust::detail::is_floating_point<Incrementable1>::value ||
thrust::detail::is_floating_point<Incrementable2>::value
>::type
>
{
__host__ __device__
static bool equal(Incrementable1 x, Incrementable2 y)
{
typedef number_distance<Difference,Incrementable1,Incrementable2> d;
return d::distance(x,y) == 0;
}
};


} // end detail
} // end thrust

0 comments on commit 57b3a9c

Please sign in to comment.