diff --git a/performance/reduce_by_key.test b/performance/reduce_by_key.test new file mode 100644 index 000000000..10aee8091 --- /dev/null +++ b/performance/reduce_by_key.test @@ -0,0 +1,61 @@ +PREAMBLE = \ + """ + #include + #include + """ + +INITIALIZE = \ + """ + thrust::host_vector<$ValueType> h_values = unittest::random_integers<$ValueType>($InputSize); + thrust::device_vector<$ValueType> d_values = h_values; + + thrust::host_vector<$KeyType> h_keys_result($InputSize); + thrust::host_vector<$ValueType> h_values_result($InputSize); + + thrust::device_vector<$KeyType> d_keys_result($InputSize); + thrust::device_vector<$ValueType> d_values_result($InputSize); + + thrust::default_random_engine rng(13); + thrust::host_vector<$KeyType> h_keys($InputSize); + for(size_t i = 0, k = 0; i < $InputSize; i++) + { + h_keys[i] = k; + if(rng() % 50 == 0) + k++; + } + thrust::device_vector<$KeyType> d_keys = h_keys; + + thrust::pair< + thrust::host_vector<$KeyType>::iterator, + thrust::host_vector<$ValueType>::iterator + > h_end = thrust::reduce_by_key(h_keys.begin(), h_keys.end(), h_values.begin(), h_keys_result.begin(), h_values_result.begin()); + h_keys_result.erase(h_end.first, h_keys_result.end()); + + thrust::pair< + thrust::device_vector<$KeyType>::iterator, + thrust::device_vector<$ValueType>::iterator + > d_end = thrust::reduce_by_key(d_keys.begin(), d_keys.end(), d_values.begin(), d_keys_result.begin(), d_values_result.begin()); + d_keys_result.erase(d_end.first, d_keys_result.end()); + + ASSERT_EQUAL(h_keys_result, d_keys_result); + ASSERT_EQUAL(h_values_result, d_values_result); + """ + +TIME = \ + """ + thrust::reduce_by_key(d_keys.begin(), d_keys.end(), d_values.begin(), d_keys_result.begin(), d_values_result.begin()); + """ + +FINALIZE = \ + """ + RECORD_TIME(); + RECORD_THROUGHPUT(double($InputSize)); + RECORD_BANDWIDTH(sizeof($KeyType) * double(d_keys.size() + d_keys_result.size()) + sizeof($ValueType) * double(d_values.size() + d_values_result.size())); + """ + +KeyTypes = ['int'] #SignedIntegerTypes +ValueTypes = SignedIntegerTypes +InputSizes = [2**24] #StandardSizes + +TestVariables = [('KeyType', KeyTypes), ('ValueType', ValueTypes),('InputSize', InputSizes)] +