Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions include/merlin/core_kernels/lookup.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,53 @@ struct LaunchPipelineLookupV2 {
}
};

// Small-dimension lookup kernel (no Table dependency).
//
// One cooperative-groups tile of TILE_SIZE (4) threads serves one key: the
// tile probes the key's bucket together through find_without_lock, then lane 0
// of the tile writes the outputs for that key.
//
// Launch layout: 1-D grid, 1-D block. blockDim.x must be a multiple of
// TILE_SIZE so that every lane of a tile maps to the same key index, and
// gridDim.x * blockDim.x / TILE_SIZE must be >= n. No shared memory is used.
//
// Parameters:
//   buckets, buckets_num  bucket array and its length, forwarded to
//                         get_key_position.
//   dim                   number of V elements per value.
//   keys                  n keys to look up.
//   values                output, n * dim elements; entry i is written only
//                         when key i is found.
//   scores                output, n elements, written only for found keys;
//                         may be nullptr, in which case scores are skipped.
//   founds                output, n flags; always written, must be non-null.
//   n                     number of keys.
//
// NOTE(review): reserved keys are not filtered before probing (the filter was
// commented out in the original) — confirm find_without_lock tolerates them.
template <typename K = uint64_t, typename V = float, typename S = uint32_t>
__global__ void lookup_kernel_tlp_small_dim(
    Bucket<K, V, S>* buckets, const size_t buckets_num, const uint32_t dim,
    const K* __restrict keys, V* __restrict values, S* __restrict scores,
    bool* __restrict founds, const size_t n) {
  constexpr int BUCKET_SIZE = 128;
  constexpr int TILE_SIZE = 4;
  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());

  // 64-bit index math: blockIdx.x * blockDim.x can exceed 32 bits.
  const size_t group_id =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) / TILE_SIZE;
  // The grid is ceil-divided, so tail tiles may have no key. All lanes of a
  // tile share group_id, hence the whole tile leaves together and the tile
  // collectives below never see a partial group.
  if (group_id >= n) return;
  const int rank = g.thread_rank();

  const K find_key = keys[group_id];

  size_t bkt_idx = 0;
  size_t start_idx = 0;
  Bucket<K, V, S>* bucket = get_key_position<K>(
      buckets, find_key, bkt_idx, start_idx, buckets_num, BUCKET_SIZE);

  int key_pos = -1;
  int src_lane = -1;
  // Align start index to TILE_SIZE boundary to improve coalesced digest
  // probing.
  const size_t start_idx_aligned =
      start_idx & ~(static_cast<size_t>(TILE_SIZE) - 1);
  const OccupyResult occupy_result = find_without_lock<K, V, S, TILE_SIZE>(
      g, bucket, find_key, static_cast<int>(start_idx_aligned), key_pos,
      src_lane, BUCKET_SIZE);

  // Only lane 0 publishes results; no tile collectives follow this point.
  if (rank != 0) return;

  // key_pos stays -1 when nothing matched; never index the bucket with it.
  const bool found =
      (occupy_result == OccupyResult::DUPLICATE) && (key_pos >= 0);
  founds[group_id] = found;
  if (!found) return;

  // Copy all dim elements of the value. Assumes slot i's value starts at
  // vectors + i * dim (identical to the original addressing when dim == 1)
  // — TODO confirm against the Bucket layout.
  const V* src = reinterpret_cast<const V*>(bucket->vectors) +
                 static_cast<size_t>(key_pos) * dim;
  V* dst = values + group_id * dim;
  for (uint32_t i = 0; i < dim; ++i) {
    dst[i] = __ldg(src + i);
  }

  // Callers may pass scores == nullptr when they do not want scores back.
  if (scores != nullptr) {
    scores[group_id] = *(bucket->scores(key_pos));
  }
}

// Value-buffer configuration selected by an architecture tag.
// Only declared here; SelectPipelineLookupKernelWithIO reads
// size_pipeline_v1 and size_pipeline_v2 from it, so specializations providing
// those members are presumably defined elsewhere — TODO confirm.
template <typename ArchTag>
struct LookupValueBufConfig;

Expand All @@ -755,19 +802,46 @@ template <typename K, typename V, typename S = uint64_t,
struct SelectPipelineLookupKernelWithIO {
using ValueBufConfig = LookupValueBufConfig<ArchTag>;

// Fallback overload of the small-dimension launch helper: accepts any
// parameter pack type, launches nothing, and reports "not handled" so the
// caller continues with the regular kernel selection.
template <typename ParamsT>
static inline bool small_dim_tlp_try_launch(ParamsT& /*params*/,
                                            uint32_t /*total_value_size*/,
                                            cudaStream_t& /*stream*/) {
  constexpr bool kHandled = false;  // this overload never launches a kernel
  return kHandled;
}

// Small-dimension launch helper for LookupKernelParams<KK, VV, SS>.
//
// When the value is at most 8 bytes (total_value_size <= 8) this launches
// lookup_kernel_tlp_small_dim on `stream` and returns true, meaning the
// request has been handled and the caller must not launch another kernel.
// Returns false without launching otherwise.
//
// params            lookup arguments; buckets, buckets_num, dim, keys, values,
//                   scores, found_functor.founds and n are forwarded to the
//                   kernel.
// total_value_size  size of one value in bytes (dim * sizeof(V) at the call
//                   site in select_kernel).
// stream            CUDA stream the kernel is enqueued on.
//
// NOTE(review): select_kernel calls this before its `params.scores == nullptr`
// branch, so the kernel must tolerate a null scores pointer — confirm.
template <typename KK, typename VV, typename SS>
static inline bool small_dim_tlp_try_launch(
    LookupKernelParams<KK, VV, SS>& params,
    const uint32_t total_value_size,
    cudaStream_t& stream) {
  constexpr uint32_t MAX_SMALL_VALUE_BYTES = 8;
  if (total_value_size > MAX_SMALL_VALUE_BYTES) return false;

  const size_t n = static_cast<size_t>(params.n);
  // An empty batch is trivially complete. Launching with a grid of 0 blocks
  // is an invalid configuration and would leave a CUDA error behind.
  if (n == 0) return true;

  constexpr int TILE_SIZE = 4;  // must match TILE_SIZE inside the kernel
  constexpr int BLOCK = 256;
  static_assert(BLOCK % TILE_SIZE == 0,
                "BLOCK must be a multiple of TILE_SIZE");
  constexpr size_t GROUPS_PER_BLOCK = BLOCK / TILE_SIZE;

  // Ceil-divide in size_t; the (n - 1) / G + 1 form cannot wrap for n > 0.
  const size_t grid = (n - 1) / GROUPS_PER_BLOCK + 1;

  lookup_kernel_tlp_small_dim<KK, VV, SS>
      <<<static_cast<unsigned int>(grid), BLOCK, 0, stream>>>(
          params.buckets, static_cast<size_t>(params.buckets_num),
          static_cast<uint32_t>(params.dim), params.keys, params.values,
          params.scores, params.found_functor.founds, n);
  return true;
}

// Returns ValueBufConfig::size_pipeline_v1, the byte threshold that
// select_kernel compares dim * sizeof(V) against for the V1 pipeline.
static inline uint32_t max_value_size() {
  constexpr uint32_t v1_buffer_bytes = ValueBufConfig::size_pipeline_v1;
  return v1_buffer_bytes;
}

template <template <typename, typename, typename> typename LookupKernelParams>
static void select_kernel(LookupKernelParams<K, V, S>& params,
cudaStream_t& stream) {
// Small-dimension direct TLP path: dim*sizeof(V) <= 8
const uint32_t total_value_size = static_cast<uint32_t>(params.dim * sizeof(V));
if (small_dim_tlp_try_launch(params, total_value_size, stream)) return;
constexpr int BUCKET_SIZE = 128;
constexpr uint32_t buf_size_v1 = ValueBufConfig::size_pipeline_v1;
constexpr uint32_t buf_size_v2 = ValueBufConfig::size_pipeline_v2;

uint32_t total_value_size = static_cast<uint32_t>(params.dim * sizeof(V));

if (params.scores == nullptr) {
using CopyScore = CopyScoreEmpty<S, K, BUCKET_SIZE>;
if (total_value_size <= buf_size_v1) {
Expand Down