From d476985199e270a1cbb2bb189ac493d6afae32d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?= Date: Thu, 25 Sep 2025 14:05:45 +0800 Subject: [PATCH 1/5] Create median_filter_gpu.cu --- src/ops/median_filter_gpu.cu | 148 +++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 src/ops/median_filter_gpu.cu diff --git a/src/ops/median_filter_gpu.cu b/src/ops/median_filter_gpu.cu new file mode 100644 index 000000000..8dcc61031 --- /dev/null +++ b/src/ops/median_filter_gpu.cu @@ -0,0 +1,148 @@ +#include "ctranslate2/ops/median_filter.h" + +#include +#ifdef CUDA_BF16_AVAILABLE +#include +#endif + +#include "type_dispatch.h" +#include "cuda/helpers.h" +#include + +namespace ctranslate2 { + namespace ops { + + constexpr dim_t num_threads = 256; + + // Conversion helpers + __device__ __forceinline__ float to_float(float v) { return v; } + __device__ __forceinline__ float to_float(const half v) { return __half2float(v); } +#ifdef CUDA_BF16_AVAILABLE + __device__ __forceinline__ float to_float(const __nv_bfloat16 v) { return __bfloat162float(v); } +#endif + + __device__ __forceinline__ float from_float(float v) { return v; } + __device__ __forceinline__ half from_float_half(float v) { return __float2half(v); } +#ifdef CUDA_BF16_AVAILABLE + __device__ __forceinline__ __nv_bfloat16 from_float_bf16(float v) { return __float2bfloat16(v); } +#endif + + namespace { + constexpr int kMaxWindow = 129; // supports window widths up to 129 (rank 64) + } + + template + __global__ void sliding_median_lastdim_kernel(const DeviceT* input, + DeviceT* output, + int rows, + int depth, + int width) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const int total = rows * depth; + if (tid >= total) return; + + int row = tid / depth; + int col = tid % depth; + const int rank = width / 2; + + if (depth <= rank) { + output[tid] = input[tid]; + return; + } + if (width > kMax) { + output[tid] = input[tid]; + return; + } + + float window[kMax]; + + const int row_offset = row * depth; + // Reflection gather. + for (int k = -rank; k <= rank; ++k) { + int read = col + k; + if (read < 0) read = -read; + if (read >= depth) read = 2 * depth - read - 2; + window[k + rank] = to_float(input[row_offset + read]); + } + + // Insertion sort (width is small: <= kMax, typically < 129). + for (int i = 1; i < width; ++i) { + float key = window[i]; + int j = i - 1; + while (j >= 0 && window[j] > key) { + window[j + 1] = window[j]; + --j; + } + window[j + 1] = key; + } + float median = window[rank]; + + if constexpr (std::is_same::value) { + output[tid] = median; + } else if constexpr (std::is_same::value) { + output[tid] = from_float_half(median); +#ifdef CUDA_BF16_AVAILABLE + } else if constexpr (std::is_same::value) { + output[tid] = from_float_bf16(median); +#endif + } + } + + template + void MedianFilter::compute(const StorageView& input, + const dim_t axis_size, + StorageView& output) const { + output.resize_as(input); + const int depth = static_cast(axis_size); + const int rows = static_cast(input.size() / depth); + const int width = static_cast(_width); + const int rank = width / 2; + + // Host-side guards and fallbacks. + if (width <= 1) { + if (&output != &input) + output.copy_from(input); + return; + } + if ((width & 1) == 0) + throw std::invalid_argument("MedianFilter width must be odd"); + if (width > kMaxWindow) + throw std::invalid_argument("MedianFilter width exceeds supported GPU max (" + std::to_string(kMaxWindow) + ")"); + if (depth <= rank) { + if (&output != &input) + output.copy_from(input); + return; + } + + // Grid configuration + const int total = rows * depth; + int blocks = (total + num_threads - 1) / num_threads; + if (blocks > cuda::max_blocks) { + blocks = cuda::max_blocks; + } + + using device_t = cuda::device_type; + const device_t* in_ptr = cuda::device_cast(input.data()); + device_t* out_ptr = cuda::device_cast(output.data()); + sliding_median_lastdim_kernel<<>>( + in_ptr, + out_ptr, + rows, + depth, + width); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaDeviceSynchronize()); + } + +#define DECLARE_IMPL(T) \ + template void \ + MedianFilter::compute(const StorageView& input, \ + const dim_t axis_size, \ + StorageView& output) const; + + DECLARE_IMPL(float) + DECLARE_IMPL(float16_t) + DECLARE_IMPL(bfloat16_t) + + } +} From d51d66143b5f3a4701e1fa590c6220e958fc3cf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?= Date: Thu, 25 Sep 2025 14:06:15 +0800 Subject: [PATCH 2/5] Create median_filter_cpu.cc --- src/ops/median_filter_cpu.cc | 60 ++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/ops/median_filter_cpu.cc diff --git a/src/ops/median_filter_cpu.cc b/src/ops/median_filter_cpu.cc new file mode 100644 index 000000000..e92f5c95d --- /dev/null +++ b/src/ops/median_filter_cpu.cc @@ -0,0 +1,60 @@ +#include "ctranslate2/ops/median_filter.h" + +#include + +#include +#include "cpu/parallel.h" +#include "type_dispatch.h" + +namespace ctranslate2 { + namespace ops { + + template + void MedianFilter::compute(const StorageView& input, + const dim_t axis_size, + StorageView& output) const { + const auto* src = input.data(); + auto* dst = output.data(); + + + const dim_t depth = axis_size; + const dim_t batch_size = input.size() / depth; + const dim_t rank = _width / 2; + + if (depth <= rank) + return; + + cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) { + StorageView window_storage({_width}, DataType::FLOAT32); + auto* window = window_storage.data(); + + for (dim_t i = begin; i < end; ++i) { + const dim_t offset = i * depth; + const auto* in = src + offset; + auto* out = dst + offset; + + for (dim_t j = 0; j < depth; ++j) { + for (dim_t k = -rank; k <= rank; ++k) { + dim_t read = std::abs(j + k); + if (read >= depth) + read = depth - (read - depth) - 2; + window[k + rank] = in[read]; + } + + std::nth_element(window, window + rank, window + _width); + out[j] = window[rank]; + } + } + }); + } + +#define DECLARE_IMPL(T) \ + template void \ + MedianFilter::compute(const StorageView& input, \ + const dim_t axis_size, \ + StorageView& output) const; + + DECLARE_IMPL(float) + + } +} From f35fb6911498b757ad934c0d7936a700e7091595 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?= Date: Thu, 25 Sep 2025 14:07:20 +0800 Subject: [PATCH 3/5] Update median_filter.h to contain CPU and GPU compute call --- include/ctranslate2/ops/median_filter.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/include/ctranslate2/ops/median_filter.h b/include/ctranslate2/ops/median_filter.h index ba5e87f0c..b875fa67f 100644 --- a/include/ctranslate2/ops/median_filter.h +++ b/include/ctranslate2/ops/median_filter.h @@ -1,5 +1,4 @@ #pragma once - #include "op.h" namespace ctranslate2 { @@ -7,11 +6,13 @@ namespace ctranslate2 { class MedianFilter : public Op { public: - MedianFilter(const dim_t width); + explicit MedianFilter(dim_t width); void operator()(const StorageView& input, StorageView& output) const; private: const dim_t _width; + template + void compute(const StorageView& input, const dim_t axis_size, StorageView& output) const; }; } From 9cd7f17f57d8a9b8c9c2395959332a961794988b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?= Date: Thu, 25 Sep 2025 14:08:30 +0800 Subject: [PATCH 4/5] Add CPU and GPU of median_filter operator --- CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 62b99d136..0bf0fccca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -170,6 +170,8 @@ set(SOURCES src/ops/mean.cc src/ops/mean_cpu.cc src/ops/median_filter.cc + src/ops/median_filter_cpu.cc + src/ops/median_filter_gpu.cu src/ops/min_max.cc src/ops/mul.cc src/ops/multinomial.cc From 08e7dc07c6047cd336e287102f6cc1b05af8afd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?= Date: Thu, 25 Sep 2025 14:15:47 +0800 Subject: [PATCH 5/5] Update median_filter.cc --- src/ops/median_filter.cc | 48 +++++++--------------------------------- 1 file changed, 8 insertions(+), 40 deletions(-) diff --git a/src/ops/median_filter.cc b/src/ops/median_filter.cc index d83c06e11..ab022faf8 100644 --- a/src/ops/median_filter.cc +++ b/src/ops/median_filter.cc @@ -1,57 +1,25 @@ #include "ctranslate2/ops/median_filter.h" -#include - -#include "cpu/parallel.h" +#include "dispatch.h" namespace ctranslate2 { namespace ops { - MedianFilter::MedianFilter(const dim_t width) + MedianFilter::MedianFilter(dim_t width) : _width(width) - { - } + { + } void MedianFilter::operator()(const StorageView& input, StorageView& output) const { PROFILE("MedianFilter"); - if (input.device() != Device::CPU) - throw std::invalid_argument("MedianFilter currently only supports CPU execution"); + const dim_t axis = input.rank() - 1; + const dim_t axis_size = input.dim(axis); output.resize_as(input); - const dim_t depth = input.dim(-1); - const dim_t batch_size = input.size() / depth; - const dim_t rank = _width / 2; - - if (depth <= rank) - return; - - const auto* src = input.data(); - auto* dst = output.data(); - - cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) { - StorageView window_storage({_width}, DataType::FLOAT32); - auto* window = window_storage.data(); - - for (dim_t i = begin; i < end; ++i) { - const dim_t offset = i * depth; - const auto* in = src + offset; - auto* out = dst + offset; - - for (dim_t j = 0; j < depth; ++j) { - for (dim_t k = -rank; k <= rank; ++k) { - dim_t read = std::abs(j + k); - if (read >= depth) - read = depth - (read - depth) - 2; - window[k + rank] = in[read]; - } - - std::nth_element(window, window + rank, window + _width); - out[j] = window[rank]; - } - } - }); + DEVICE_AND_FLOAT_DISPATCH("MedianFilter", input.device(), input.dtype(), + (compute(input, axis_size, output))); } }