Skip to content

Commit

Permalink
fixed bug in find_if()
Browse files Browse the repository at this point in the history
added initial draft of mismatch() and unit tests
  • Loading branch information
wnbell committed Apr 8, 2010
1 parent bd7bf10 commit 2431305
Show file tree
Hide file tree
Showing 13 changed files with 557 additions and 12 deletions.
62 changes: 62 additions & 0 deletions performance/find.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
PREAMBLE = \
"""
#include <thrust/find.h>
#include <thrust/reduce.h>
#include <thrust/extrema.h>

template <typename Vector>
void find_partial(const Vector& v)
{
thrust::find(v.begin(), v.end(), 1);
}

template <typename Vector>
void find_full(const Vector& v)
{
thrust::max_element(v.begin(), v.end());
}

template <typename Vector>
void reduce_full(const Vector& v)
{
thrust::max_element(v.begin(), v.end());
}
"""

INITIALIZE = \
"""
thrust::host_vector<$InputType> h_input($InputSize, 0);
thrust::device_vector<$InputType> d_input($InputSize, 0);

size_t pos = $Fraction * $InputSize;

if (pos < $InputSize)
{
h_input[pos] = 1;
d_input[pos] = 1;
}

size_t h_index = thrust::find(h_input.begin(), h_input.end(), 1) - h_input.begin();
size_t d_index = thrust::find(d_input.begin(), d_input.end(), 1) - d_input.begin();

ASSERT_EQUAL(h_index, d_index);
"""

TIME = \
"""
$Method(d_input);
"""

FINALIZE = \
"""
RECORD_TIME();
RECORD_BANDWIDTH(sizeof($InputType) * double($InputSize));
"""

InputTypes = ['int']
InputSizes = [2**23]
Fractions = [0.01, 0.99]
Methods = ['find_partial', 'find_full', 'reduce_full']

TestVariables = [('InputType', InputTypes), ('InputSize', InputSizes), ('Fraction', Fractions), ('Method', Methods)]

10 changes: 4 additions & 6 deletions testing/find.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@ void TestFindSimple(void)
vec[0] = 1;
vec[1] = 2;
vec[2] = 3;
vec[3] = 4;
vec[3] = 3;
vec[4] = 5;

ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 0) - vec.begin(), 5);
ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 1) - vec.begin(), 0);
ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 2) - vec.begin(), 1);
ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 3) - vec.begin(), 2);
ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 4) - vec.begin(), 3);
ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 4) - vec.begin(), 5);
ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 5) - vec.begin(), 4);
ASSERT_EQUAL(thrust::find(vec.begin(), vec.end(), 6) - vec.begin(), 5);
}
DECLARE_VECTOR_UNITTEST(TestFindSimple);

Expand All @@ -43,16 +42,15 @@ void TestFindIfSimple(void)
vec[0] = 1;
vec[1] = 2;
vec[2] = 3;
vec[3] = 4;
vec[3] = 3;
vec[4] = 5;

ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred<T>(0)) - vec.begin(), 5);
ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred<T>(1)) - vec.begin(), 0);
ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred<T>(2)) - vec.begin(), 1);
ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred<T>(3)) - vec.begin(), 2);
ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred<T>(4)) - vec.begin(), 3);
ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred<T>(4)) - vec.begin(), 5);
ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred<T>(5)) - vec.begin(), 4);
ASSERT_EQUAL(thrust::find_if(vec.begin(), vec.end(), equal_to_value_pred<T>(6)) - vec.begin(), 5);
}
DECLARE_VECTOR_UNITTEST(TestFindIfSimple);

Expand Down
29 changes: 29 additions & 0 deletions testing/mismatch.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <thrusttest/unittest.h>
#include <thrust/mismatch.h>

template <class Vector>
void TestMismatchSimple(void)
{
typedef typename Vector::value_type T;

Vector a(4); Vector b(4);
a[0] = 1; b[0] = 1;
a[1] = 2; b[1] = 2;
a[2] = 3; b[2] = 4;
a[3] = 4; b[3] = 3;

ASSERT_EQUAL(thrust::mismatch(a.begin(), a.end(), b.begin()).first - a.begin(), 2);
ASSERT_EQUAL(thrust::mismatch(a.begin(), a.end(), b.begin()).second - b.begin(), 2);

b[2] = 3;

ASSERT_EQUAL(thrust::mismatch(a.begin(), a.end(), b.begin()).first - a.begin(), 3);
ASSERT_EQUAL(thrust::mismatch(a.begin(), a.end(), b.begin()).second - b.begin(), 3);

b[3] = 4;

ASSERT_EQUAL(thrust::mismatch(a.begin(), a.end(), b.begin()).first - a.begin(), 4);
ASSERT_EQUAL(thrust::mismatch(a.begin(), a.end(), b.begin()).second - b.begin(), 4);
}
DECLARE_VECTOR_UNITTEST(TestMismatchSimple);

28 changes: 23 additions & 5 deletions thrust/detail/device/generic/find.inl
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@

#include <thrust/detail/device/reduce.h>

#include <thrust/functional.h>
#include <thrust/tuple.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/zip_iterator.h>

// Implementation of find_if() using short-circuiting
// Contributed by Erich Elsen

namespace thrust
Expand All @@ -34,6 +32,23 @@ namespace device
namespace generic
{

template <typename TupleType>
struct find_if_functor
{
__host__ __device__
TupleType operator()(const TupleType& lhs, const TupleType& rhs) const
{
// select the smallest index among true results
if (thrust::get<0>(lhs) && thrust::get<0>(rhs))
return TupleType(true, min(thrust::get<1>(lhs), thrust::get<1>(rhs)));
else if (thrust::get<0>(lhs))
return lhs;
else
return rhs;
}
};


template <typename InputIterator, typename Predicate>
InputIterator find_if(InputIterator first,
InputIterator last,
Expand All @@ -48,8 +63,11 @@ InputIterator find_if(InputIterator first,

const difference_type n = thrust::distance(first, last);

// this implementation breaks up the sequence into separate intervals
// in an attempt to early-out as soon as a value is found

// TODO incorporate sizeof(InputType) into interval_threshold and round to multiple of 32
const difference_type interval_threshold = n; //1 << 20; // XXX disabled until performance is sorted out
const difference_type interval_threshold = 1 << 20;
const difference_type interval_size = std::min(interval_threshold, n);

for(difference_type begin = 0; begin < n; begin += interval_size)
Expand All @@ -75,11 +93,11 @@ InputIterator find_if(InputIterator first,
)
) + end,
result_type(false, end),
thrust::maximum<result_type>()
find_if_functor<result_type>()
);

// see if we found something
if (thrust::get<1>(result) != end)
if (thrust::get<0>(result))
{
return first + thrust::get<1>(result);
}
Expand Down
47 changes: 47 additions & 0 deletions thrust/detail/device/generic/mismatch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2008-2010 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


/*! \file find.h
* \brief Search for differences between sequences [generic device].
*/

#pragma once

#include <thrust/pair.h>

namespace thrust
{
namespace detail
{
namespace device
{
namespace generic
{

template <typename InputIterator1, typename InputIterator2, typename BinaryPredicate>
thrust::pair<InputIterator1, InputIterator2> mismatch(InputIterator1 first1,
InputIterator1 last1,
InputIterator2 first2,
BinaryPredicate pred);

} // end namespace generic
} // end namespace device
} // end namespace detail
} // end namespace thrust

#include <thrust/detail/device/generic/mismatch.inl>

78 changes: 78 additions & 0 deletions thrust/detail/device/generic/mismatch.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright 2008-2010 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


#include <thrust/pair.h>
#include <thrust/distance.h>
#include <thrust/iterator/iterator_traits.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <thrust/detail/internal_functional.h>

#include <thrust/detail/device/find.h>

// Contributed by Erich Elsen

namespace thrust
{
namespace detail
{
namespace device
{
namespace generic
{

template <typename InputIterator1, typename InputIterator2, typename BinaryPredicate>
thrust::pair<InputIterator1, InputIterator2> mismatch(InputIterator1 first1,
InputIterator1 last1,
InputIterator2 first2,
BinaryPredicate pred)
{
typedef typename thrust::iterator_traits<InputIterator1>::difference_type difference_type;
typedef typename thrust::tuple<bool,difference_type> result_type;

const difference_type n = thrust::distance(first1, last1);

difference_type offset = thrust::detail::device::find_if
(
thrust::make_transform_iterator
(
thrust::make_zip_iterator(thrust::make_tuple(first1, first2)),
thrust::detail::tuple_equal_to<BinaryPredicate>(pred)
),
thrust::make_transform_iterator
(
thrust::make_zip_iterator(thrust::make_tuple(first1, first2)),
thrust::detail::tuple_equal_to<BinaryPredicate>(pred)
) + n,
thrust::detail::equal_to_value<bool>(false)
)
-
thrust::make_transform_iterator
(
thrust::make_zip_iterator(thrust::make_tuple(first1, first2)),
thrust::detail::tuple_equal_to<BinaryPredicate>(pred)
);

return thrust::make_pair(first1 + offset, first2 + offset);
}

} // end namespace generic
} // end namespace device
} // end namespace detail
} // end namespace thrust

47 changes: 47 additions & 0 deletions thrust/detail/device/mismatch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2008-2010 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


/*! \file mismatch.h
* \brief Search for differences between sequences [device].
*/

#pragma once

#include <thrust/pair.h>

#include <thrust/detail/device/generic/mismatch.h>

namespace thrust
{
namespace detail
{
namespace device
{

template <typename InputIterator1, typename InputIterator2, typename BinaryPredicate>
thrust::pair<InputIterator1, InputIterator2> mismatch(InputIterator1 first1,
InputIterator1 last1,
InputIterator2 first2,
BinaryPredicate pred)
{
return thrust::detail::device::generic::mismatch(first1, last1, first2, pred);
}

} // end namespace device
} // end namespace detail
} // end namespace thrust

2 changes: 1 addition & 1 deletion thrust/detail/dispatch/find.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/


/*! \file gather.h
/*! \file find.h
* \brief Dispatch layer of the find functions.
*/

Expand Down
Loading

0 comments on commit 2431305

Please sign in to comment.