Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ set(SOURCES
src/ops/mean.cc
src/ops/mean_cpu.cc
src/ops/median_filter.cc
src/ops/median_filter_cpu.cc
src/ops/median_filter_gpu.cu
src/ops/min_max.cc
src/ops/mul.cc
src/ops/multinomial.cc
Expand Down
5 changes: 3 additions & 2 deletions include/ctranslate2/ops/median_filter.h
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
#pragma once

#include "op.h"

namespace ctranslate2 {
namespace ops {

class MedianFilter : public Op {
public:
MedianFilter(const dim_t width);
explicit MedianFilter(dim_t width);
void operator()(const StorageView& input, StorageView& output) const;

private:
const dim_t _width;
template <Device D, typename T>
void compute(const StorageView& input, const dim_t axis_size, StorageView& output) const;
};

}
Expand Down
48 changes: 8 additions & 40 deletions src/ops/median_filter.cc
Original file line number Diff line number Diff line change
@@ -1,57 +1,25 @@
#include "ctranslate2/ops/median_filter.h"

#include <algorithm>

#include "cpu/parallel.h"
#include "dispatch.h"

namespace ctranslate2 {
namespace ops {

MedianFilter::MedianFilter(const dim_t width)
MedianFilter::MedianFilter(dim_t width)
: _width(width)
{
}
{
}

void MedianFilter::operator()(const StorageView& input, StorageView& output) const {
PROFILE("MedianFilter");

if (input.device() != Device::CPU)
throw std::invalid_argument("MedianFilter currently only supports CPU execution");
const dim_t axis = input.rank() - 1;
const dim_t axis_size = input.dim(axis);

output.resize_as(input);

const dim_t depth = input.dim(-1);
const dim_t batch_size = input.size() / depth;
const dim_t rank = _width / 2;

if (depth <= rank)
return;

const auto* src = input.data<float>();
auto* dst = output.data<float>();

cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
StorageView window_storage({_width}, DataType::FLOAT32);
auto* window = window_storage.data<float>();

for (dim_t i = begin; i < end; ++i) {
const dim_t offset = i * depth;
const auto* in = src + offset;
auto* out = dst + offset;

for (dim_t j = 0; j < depth; ++j) {
for (dim_t k = -rank; k <= rank; ++k) {
dim_t read = std::abs(j + k);
if (read >= depth)
read = depth - (read - depth) - 2;
window[k + rank] = in[read];
}

std::nth_element(window, window + rank, window + _width);
out[j] = window[rank];
}
}
});
DEVICE_AND_FLOAT_DISPATCH("MedianFilter", input.device(), input.dtype(),
(compute<D, T>(input, axis_size, output)));
}

}
Expand Down
60 changes: 60 additions & 0 deletions src/ops/median_filter_cpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "ctranslate2/ops/median_filter.h"

#include <iostream>

#include <algorithm>
#include "cpu/parallel.h"
#include "type_dispatch.h"

namespace ctranslate2 {
namespace ops {

template <Device D, typename T>
void MedianFilter::compute(const StorageView& input,
const dim_t axis_size,
StorageView& output) const {
const auto* src = input.data<T>();
auto* dst = output.data<T>();


const dim_t depth = axis_size;
const dim_t batch_size = input.size() / depth;
const dim_t rank = _width / 2;

if (depth <= rank)
return;

cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
StorageView window_storage({_width}, DataType::FLOAT32);
auto* window = window_storage.data<float>();

for (dim_t i = begin; i < end; ++i) {
const dim_t offset = i * depth;
const auto* in = src + offset;
auto* out = dst + offset;

for (dim_t j = 0; j < depth; ++j) {
for (dim_t k = -rank; k <= rank; ++k) {
dim_t read = std::abs(j + k);
if (read >= depth)
read = depth - (read - depth) - 2;
window[k + rank] = in[read];
}

std::nth_element(window, window + rank, window + _width);
out[j] = window[rank];
}
}
});
}

#define DECLARE_IMPL(T) \
template void \
MedianFilter::compute<Device::CPU, T>(const StorageView& input, \
const dim_t axis_size, \
StorageView& output) const;

DECLARE_IMPL(float)

}
}
148 changes: 148 additions & 0 deletions src/ops/median_filter_gpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#include "ctranslate2/ops/median_filter.h"

#include <cuda_fp16.h>
#ifdef CUDA_BF16_AVAILABLE
#include <cuda_bf16.h>
#endif

#include "type_dispatch.h"
#include "cuda/helpers.h"
#include <type_traits>

namespace ctranslate2 {
namespace ops {

constexpr dim_t num_threads = 256;

// Conversion helpers
__device__ __forceinline__ float to_float(float v) { return v; }
__device__ __forceinline__ float to_float(const half v) { return __half2float(v); }
#ifdef CUDA_BF16_AVAILABLE
__device__ __forceinline__ float to_float(const __nv_bfloat16 v) { return __bfloat162float(v); }
#endif

__device__ __forceinline__ float from_float(float v) { return v; }
__device__ __forceinline__ half from_float_half(float v) { return __float2half(v); }
#ifdef CUDA_BF16_AVAILABLE
__device__ __forceinline__ __nv_bfloat16 from_float_bf16(float v) { return __float2bfloat16(v); }
#endif

namespace {
constexpr int kMaxWindow = 129; // supports window widths up to 129 (rank 64)
}

template <typename DeviceT, int kMax>
__global__ void sliding_median_lastdim_kernel(const DeviceT* input,
DeviceT* output,
int rows,
int depth,
int width) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int total = rows * depth;
if (tid >= total) return;

int row = tid / depth;
int col = tid % depth;
const int rank = width / 2;

if (depth <= rank) {
output[tid] = input[tid];
return;
}
if (width > kMax) {
output[tid] = input[tid];
return;
}

float window[kMax];

const int row_offset = row * depth;
// Reflection gather.
for (int k = -rank; k <= rank; ++k) {
int read = col + k;
if (read < 0) read = -read;
if (read >= depth) read = 2 * depth - read - 2;
window[k + rank] = to_float(input[row_offset + read]);
}

// Insertion sort (width is small: <= kMax, typically < 129).
for (int i = 1; i < width; ++i) {
float key = window[i];
int j = i - 1;
while (j >= 0 && window[j] > key) {
window[j + 1] = window[j];
--j;
}
window[j + 1] = key;
}
float median = window[rank];

if constexpr (std::is_same<DeviceT, float>::value) {
output[tid] = median;
} else if constexpr (std::is_same<DeviceT, half>::value) {
output[tid] = from_float_half(median);
#ifdef CUDA_BF16_AVAILABLE
} else if constexpr (std::is_same<DeviceT, __nv_bfloat16>::value) {
output[tid] = from_float_bf16(median);
#endif
}
}

template <Device D, typename T>
void MedianFilter::compute(const StorageView& input,
const dim_t axis_size,
StorageView& output) const {
output.resize_as(input);
const int depth = static_cast<int>(axis_size);
const int rows = static_cast<int>(input.size() / depth);
const int width = static_cast<int>(_width);
const int rank = width / 2;

// Host-side guards and fallbacks.
if (width <= 1) {
if (&output != &input)
output.copy_from(input);
return;
}
if ((width & 1) == 0)
throw std::invalid_argument("MedianFilter width must be odd");
if (width > kMaxWindow)
throw std::invalid_argument("MedianFilter width exceeds supported GPU max (" + std::to_string(kMaxWindow) + ")");
if (depth <= rank) {
if (&output != &input)
output.copy_from(input);
return;
}

// Grid configuration
const int total = rows * depth;
int blocks = (total + num_threads - 1) / num_threads;
if (blocks > cuda::max_blocks) {
blocks = cuda::max_blocks;
}

using device_t = cuda::device_type<T>;
const device_t* in_ptr = cuda::device_cast(input.data<T>());
device_t* out_ptr = cuda::device_cast(output.data<T>());
sliding_median_lastdim_kernel<device_t, kMaxWindow><<<blocks, num_threads, 0, cuda::get_cuda_stream()>>>(
in_ptr,
out_ptr,
rows,
depth,
width);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaDeviceSynchronize());
}

#define DECLARE_IMPL(T) \
template void \
MedianFilter::compute<Device::CUDA, T>(const StorageView& input, \
const dim_t axis_size, \
StorageView& output) const;

DECLARE_IMPL(float)
DECLARE_IMPL(float16_t)
DECLARE_IMPL(bfloat16_t)

}
}
Loading