From b64b0ae36a733615ac42f326027b4949892d29c2 Mon Sep 17 00:00:00 2001 From: Nathan Bell Date: Tue, 8 Feb 2011 08:05:22 -0500 Subject: [PATCH] implemented arch-specific tuning for cuda::fill() resolves issue #286 P4 Change 8755144 on 2011/02/08 08:05:22 by nbell --- .LATEST_P4_CHANGELIST | 2 +- performance/fill_optimization.test | 1 + testing/arch.cu | 22 ++++++++++++++++++++++ thrust/detail/device/cuda/arch.h | 9 +++++++++ thrust/detail/device/cuda/arch.inl | 19 +++++++++++++++---- thrust/detail/device/cuda/fill.inl | 27 ++++++++++++++++++++------- 6 files changed, 68 insertions(+), 12 deletions(-) diff --git a/.LATEST_P4_CHANGELIST b/.LATEST_P4_CHANGELIST index f1c581661..96442d696 100644 --- a/.LATEST_P4_CHANGELIST +++ b/.LATEST_P4_CHANGELIST @@ -1 +1 @@ -8750234 \ No newline at end of file +8755144 \ No newline at end of file diff --git a/performance/fill_optimization.test b/performance/fill_optimization.test index fc3a4debd..3b03fad9e 100644 --- a/performance/fill_optimization.test +++ b/performance/fill_optimization.test @@ -9,6 +9,7 @@ PREAMBLE = \ T x; constant_functor(T x) : x(x) {} + __host__ __device__ T operator()(void) const {return x;} }; diff --git a/testing/arch.cu b/testing/arch.cu index e41ce4bf9..ebc89ba64 100644 --- a/testing/arch.cu +++ b/testing/arch.cu @@ -56,6 +56,28 @@ void set_func_attributes(cudaFuncAttributes& attributes, attributes.sharedSizeBytes = sharedSizeBytes; } +void TestComputeCapability(void) +{ + cudaDeviceProp properties; + + set_compute_capability(properties, 1, 0); + ASSERT_EQUAL(compute_capability(properties), 10); + + set_compute_capability(properties, 1, 1); + ASSERT_EQUAL(compute_capability(properties), 11); + + set_compute_capability(properties, 1, 3); + ASSERT_EQUAL(compute_capability(properties), 13); + + set_compute_capability(properties, 2, 0); + ASSERT_EQUAL(compute_capability(properties), 20); + + set_compute_capability(properties, 2, 1); + ASSERT_EQUAL(compute_capability(properties), 21); +} +DECLARE_UNITTEST(TestComputeCapability); + + void TestMaxActiveThreads(void) { cudaDeviceProp properties; diff --git a/thrust/detail/device/cuda/arch.h b/thrust/detail/device/cuda/arch.h index 266500d9b..999e207b0 100644 --- a/thrust/detail/device/cuda/arch.h +++ b/thrust/detail/device/cuda/arch.h @@ -49,6 +49,15 @@ namespace cuda namespace arch { + +/*! This function returns the compute capability of a device. + * For example, returns 10 for sm_10 and 21 for sm_21 + * \return The compute capability as an integer + */ + +inline size_t compute_capability(const cudaDeviceProp &properties); +inline size_t compute_capability(void); + /*! This function returns the number of streaming * multiprocessors available for processing. * \return The number of SMs available. diff --git a/thrust/detail/device/cuda/arch.inl b/thrust/detail/device/cuda/arch.inl index 690f1c6ce..406ecded1 100644 --- a/thrust/detail/device/cuda/arch.inl +++ b/thrust/detail/device/cuda/arch.inl @@ -43,7 +43,7 @@ namespace arch namespace detail { -inline void checked_get_current_device_properties(cudaDeviceProp &props) +inline void checked_get_current_device_properties(cudaDeviceProp &properties) { int current_device = -1; @@ -66,7 +66,7 @@ inline void checked_get_current_device_properties(cudaDeviceProp &props) if(iter == properties_map.end()) { // the properties weren't found, ask the runtime to generate them - error = cudaGetDeviceProperties(&props, current_device); + error = cudaGetDeviceProperties(&properties, current_device); if(error) { @@ -74,12 +74,12 @@ inline void checked_get_current_device_properties(cudaDeviceProp &props) } // insert the new entry - properties_map[current_device] = props; + properties_map[current_device] = properties; } // end if else { // use the cached value - props = iter->second; + properties = iter->second; } // end else } // end checked_get_current_device_properties() @@ -119,6 +119,10 @@ void checked_get_function_attributes(cudaFuncAttributes& attributes, KernelFunct } // end detail +size_t compute_capability(const cudaDeviceProp &properties) +{ + return 10 * properties.major + properties.minor; +} // end compute_capability() size_t num_multiprocessors(const cudaDeviceProp& properties) { @@ -187,6 +191,13 @@ size_t max_active_blocks_per_multiprocessor(const cudaDeviceProp& properties, // Functions that query the runtime for device properties +size_t compute_capability(void) +{ + cudaDeviceProp properties; + detail::checked_get_current_device_properties(properties); + return compute_capability(properties); +} // end compute_capability() + size_t num_multiprocessors(void) { diff --git a/thrust/detail/device/cuda/fill.inl b/thrust/detail/device/cuda/fill.inl index 72d3f4dc3..3fe84a8bb 100644 --- a/thrust/detail/device/cuda/fill.inl +++ b/thrust/detail/device/cuda/fill.inl @@ -28,6 +28,8 @@ #include #include +#include + namespace thrust { namespace detail @@ -39,7 +41,7 @@ namespace cuda namespace detail { -template +template Pointer wide_fill_n(Pointer first, Size n, const T &value) @@ -48,9 +50,6 @@ template size_t ALIGNMENT_BOUNDARY = 128; // begin copying blocks at this byte boundary - // type used to pack the OutputTypes - typedef unsigned long long WideType; - WideType wide_exemplar; OutputType narrow_exemplars[sizeof(WideType) / sizeof(OutputType)]; @@ -98,11 +97,25 @@ template if ( thrust::detail::util::is_aligned(thrust::raw_pointer_cast(&*first)) ) { - wide_fill_n(&*first, n, value); + if (arch::compute_capability() < 20) + { + // 32-bit writes are faster on G80 and GT200 + typedef unsigned int WideType; + wide_fill_n(&*first, n, value); + } + else + { + // 64-bit writes are faster on Fermi + typedef unsigned long long WideType; + wide_fill_n(&*first, n, value); + } + return first + n; } - - return fill_n(first, n, value, thrust::detail::false_type()); + else + { + return fill_n(first, n, value, thrust::detail::false_type()); + } } } // end detail