From 57b3a9cdbb69d79755853d41b83fe29f02d7404c Mon Sep 17 00:00:00 2001 From: Jared Hoberock Date: Wed, 7 Nov 2012 16:03:02 -0800 Subject: [PATCH] Treat floating point counting_iterators special when computing their difference_type and comparing them. Add additional testing for these changes. Reported by Ollie on thrust-users Fixes #248 --- testing/counting_iterator.cu | 66 ++++++++++++++++++++ testing/reduce.cu | 21 +++++++ testing/transform.cu | 46 +++++++------- thrust/iterator/counting_iterator.h | 9 +++ thrust/iterator/detail/counting_iterator.inl | 42 ++++++++++++- 5 files changed, 161 insertions(+), 23 deletions(-) diff --git a/testing/counting_iterator.cu b/testing/counting_iterator.cu index 32ddc7122..8c7c0fec9 100644 --- a/testing/counting_iterator.cu +++ b/testing/counting_iterator.cu @@ -80,6 +80,72 @@ void TestCountingIteratorComparison(void) DECLARE_UNITTEST(TestCountingIteratorComparison); +void TestCountingIteratorFloatComparison(void) +{ + thrust::counting_iterator iter1(0); + thrust::counting_iterator iter2(0); + + ASSERT_EQUAL(iter1 - iter2, 0); + ASSERT_EQUAL(iter1 == iter2, true); + ASSERT_EQUAL(iter1 < iter2, false); + ASSERT_EQUAL(iter2 < iter1, false); + + iter1++; + + ASSERT_EQUAL(iter1 - iter2, 1); + ASSERT_EQUAL(iter1 == iter2, false); + ASSERT_EQUAL(iter2 < iter1, true); + ASSERT_EQUAL(iter1 < iter2, false); + + iter2++; + + ASSERT_EQUAL(iter1 - iter2, 0); + ASSERT_EQUAL(iter1 == iter2, true); + ASSERT_EQUAL(iter1 < iter2, false); + ASSERT_EQUAL(iter2 < iter1, false); + + iter1 += 100; + iter2 += 100; + + ASSERT_EQUAL(iter1 - iter2, 0); + ASSERT_EQUAL(iter1 == iter2, true); + ASSERT_EQUAL(iter1 < iter2, false); + ASSERT_EQUAL(iter2 < iter1, false); + + + thrust::counting_iterator iter3(0); + thrust::counting_iterator iter4(0.5); + + ASSERT_EQUAL(iter3 - iter4, 0); + ASSERT_EQUAL(iter3 == iter4, true); + ASSERT_EQUAL(iter3 < iter4, false); + ASSERT_EQUAL(iter4 < iter3, false); + + iter3++; // iter3 = 1.0, iter4 = 0.5 + + ASSERT_EQUAL(iter3 - iter4, 0); + ASSERT_EQUAL(iter3 == iter4, true); + ASSERT_EQUAL(iter3 < iter4, false); + ASSERT_EQUAL(iter4 < iter3, false); + + iter4++; // iter3 = 1.0, iter4 = 1.5 + + ASSERT_EQUAL(iter3 - iter4, 0); + ASSERT_EQUAL(iter3 == iter4, true); + ASSERT_EQUAL(iter3 < iter4, false); + ASSERT_EQUAL(iter4 < iter3, false); + + iter4++; // iter3 = 1.0, iter4 = 2.5 + + ASSERT_EQUAL(iter3 - iter4, -1); + ASSERT_EQUAL(iter4 - iter3, 1); + ASSERT_EQUAL(iter3 == iter4, false); + ASSERT_EQUAL(iter3 < iter4, true); + ASSERT_EQUAL(iter4 < iter3, false); +} +DECLARE_UNITTEST(TestCountingIteratorFloatComparison); + + void TestCountingIteratorDistance(void) { thrust::counting_iterator iter1(0); diff --git a/testing/reduce.cu b/testing/reduce.cu index 47b8162a1..48002d081 100644 --- a/testing/reduce.cu +++ b/testing/reduce.cu @@ -1,5 +1,7 @@ #include #include +#include +#include template struct plus_mod_10 @@ -187,3 +189,22 @@ void TestReduceWithIndirection(void) } DECLARE_VECTOR_UNITTEST(TestReduceWithIndirection); +template + void TestReduceCountingIterator(size_t n) +{ + // be careful not to generate a range larger than we can represent + n = thrust::min(n, std::numeric_limits::max()); + + thrust::counting_iterator h_first = thrust::make_counting_iterator(0); + thrust::counting_iterator d_first = thrust::make_counting_iterator(0); + + T init = 13; + + T h_result = thrust::reduce(h_first, h_first + n, init); + T d_result = thrust::reduce(d_first, d_first + n, init); + + // we use ASSERT_ALMOST_EQUAL because we're testing floating point types + ASSERT_ALMOST_EQUAL(h_result, d_result); +} +DECLARE_VARIABLE_UNITTEST(TestReduceCountingIterator); + diff --git a/testing/transform.cu b/testing/transform.cu index f8593e228..723068f45 100644 --- a/testing/transform.cu +++ b/testing/transform.cu @@ -739,41 +739,43 @@ void TestTransformIfBinaryToDiscardIterator(const size_t n) DECLARE_VARIABLE_UNITTEST(TestTransformIfBinaryToDiscardIterator); -template -void TestTransformUnaryCountingIterator(void) +template + void TestTransformUnaryCountingIterator(size_t n) { - typedef typename Vector::value_type T; + // be careful not to generate a range larger than we can represent + n = thrust::min(n, std::numeric_limits::max()); - thrust::counting_iterator first(1); + thrust::counting_iterator h_first = thrust::make_counting_iterator(0); + thrust::counting_iterator d_first = thrust::make_counting_iterator(0); - Vector output(3); + thrust::host_vector h_result(n); + thrust::device_vector d_result(n); - thrust::transform(first, first + 3, output.begin(), thrust::identity()); + thrust::transform(h_first, h_first + n, h_result.begin(), thrust::identity()); + thrust::transform(d_first, d_first + n, d_result.begin(), thrust::identity()); - Vector result(3); - result[0] = 1; result[1] = 2; result[2] = 3; - - ASSERT_EQUAL(output, result); + ASSERT_EQUAL(h_result, d_result); } -DECLARE_VECTOR_UNITTEST(TestTransformUnaryCountingIterator); +DECLARE_VARIABLE_UNITTEST(TestTransformUnaryCountingIterator); -template -void TestTransformBinaryCountingIterator(void) +template + void TestTransformBinaryCountingIterator(size_t n) { - typedef typename Vector::value_type T; + // be careful not to generate a range larger than we can represent + n = thrust::min(n, std::numeric_limits::max()); - thrust::counting_iterator first(1); + thrust::counting_iterator h_first = thrust::make_counting_iterator(0); + thrust::counting_iterator d_first = thrust::make_counting_iterator(0); - Vector output(3); + thrust::host_vector h_result(n); + thrust::device_vector d_result(n); - thrust::transform(first, first + 3, first, output.begin(), thrust::plus()); + thrust::transform(h_first, h_first + n, h_first, h_result.begin(), thrust::plus()); + thrust::transform(d_first, d_first + n, d_first, d_result.begin(), thrust::plus()); - Vector result(3); - result[0] = 2; result[1] = 4; result[2] = 6; - - ASSERT_EQUAL(output, result); + ASSERT_EQUAL(h_result, d_result); } -DECLARE_VECTOR_UNITTEST(TestTransformBinaryCountingIterator); +DECLARE_VARIABLE_UNITTEST(TestTransformBinaryCountingIterator); template diff --git a/thrust/iterator/counting_iterator.h b/thrust/iterator/counting_iterator.h index e2f857016..5df03fd3d 100644 --- a/thrust/iterator/counting_iterator.h +++ b/thrust/iterator/counting_iterator.h @@ -190,6 +190,15 @@ templatebase_reference(); } + // note that we implement equal specially for floating point counting_iterator + template + __host__ __device__ + bool equal(counting_iterator const& y) const + { + typedef thrust::detail::counting_iterator_equal e; + return e::equal(this->base(), y.base()); + } + template __host__ __device__ difference_type diff --git a/thrust/iterator/detail/counting_iterator.inl b/thrust/iterator/detail/counting_iterator.inl index 06c815da5..1822dac73 100644 --- a/thrust/iterator/detail/counting_iterator.inl +++ b/thrust/iterator/detail/counting_iterator.inl @@ -19,6 +19,8 @@ #include #include #include +#include +#include namespace thrust { @@ -49,11 +51,17 @@ template >::type traversal; + // unlike Boost, we explicitly use std::ptrdiff_t as the difference type + // for floating point counting_iterators typedef typename thrust::experimental::detail::ia_dflt_help< Difference, thrust::detail::eval_if< thrust::detail::is_numeric::value, - thrust::detail::numeric_difference, + thrust::detail::eval_if< + thrust::detail::is_integral::value, + thrust::detail::numeric_difference, + thrust::detail::identity_ + >, thrust::iterator_difference > >::type difference; @@ -97,6 +105,38 @@ template }; +template + struct counting_iterator_equal +{ + __host__ __device__ + static bool equal(Incrementable1 x, Incrementable2 y) + { + return x == y; + } +}; + + +// specialization for floating point equality +template + struct counting_iterator_equal< + Difference, + Incrementable1, + Incrementable2, + typename thrust::detail::enable_if< + thrust::detail::is_floating_point::value || + thrust::detail::is_floating_point::value + >::type + > +{ + __host__ __device__ + static bool equal(Incrementable1 x, Incrementable2 y) + { + typedef number_distance d; + return d::distance(x,y) == 0; + } +}; + + } // end detail } // end thrust