// -----------------------------------------------------------------------------
// NOTE(review): this text was a unified diff that had been collapsed onto two
// lines and had every `<...>` template-argument list stripped. The added code
// is restated below as formatted CUDA C++. Template-argument lists and
// static_cast target types are RECONSTRUCTED and must be confirmed against the
// real header before merging.
//
// Patch metadata preserved from the original text:
//   diff --git a/include/merlin/core_kernels/lookup.cuh
//              b/include/merlin/core_kernels/lookup.cuh
//   index 9c443659..e3a223a6 100644
//   --- a/include/merlin/core_kernels/lookup.cuh
//   +++ b/include/merlin/core_kernels/lookup.cuh
//   @@ -734,6 +734,53 @@ struct LaunchPipelineLookupV2 {
//     (leading context: closing "} };" of the enclosing definition)
// -----------------------------------------------------------------------------

// Small-dimension TLP lookup kernel (no Table dependency).
//
// One cooperative-groups tile of TILE_SIZE (= 4) threads cooperates on one key:
// the tile probes the bucket together and lane 0 writes the outputs.
//
// Launch layout: 1-D grid, 1-D block. blockDim.x must be a multiple of
// TILE_SIZE so that every lane of a tile maps to the same key index; the grid
// must provide at least ceil(n / (blockDim.x / TILE_SIZE)) blocks. Surplus
// tiles exit immediately.
//
// Parameters:
//   buckets, buckets_num  bucket array handed to get_key_position().
//   dim                   number of V elements stored per key.
//   keys                  n keys to look up.
//   values                output, dim elements per key at values[i * dim].
//   scores                output score per key; may be nullptr to skip it.
//   founds                output flag per key; must be non-null.
//   n                     number of keys.
//
// For a key that is not found only founds[i] = false is written; values and
// scores for that key are left untouched.
//
// NOTE(review): SOURCE carries a commented-out reserved-key filter
// (IS_RESERVED_KEY). It is still not applied here because the macro's exact
// form is not visible from this patch - confirm whether reserved keys can
// reach this kernel.
template <typename K, typename V, typename S>
__global__ void lookup_kernel_tlp_small_dim(
    Bucket<K, V, S>* buckets, const size_t buckets_num, const uint32_t dim,
    const K* __restrict keys, V* __restrict values, S* __restrict scores,
    bool* __restrict founds, const size_t n) {
  constexpr int BUCKET_SIZE = 128;
  constexpr int TILE_SIZE = 4;
  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());

  // 64-bit index math: blockIdx.x * blockDim.x can exceed 32 bits.
  const size_t group_id =
      (static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x) / g.size();
  // Tail guard. All lanes of a tile share group_id (blockDim.x is a multiple
  // of TILE_SIZE), so the whole tile leaves together and no tile-wide
  // collective below is entered by a partial tile.
  if (group_id >= n) return;
  const int rank = g.thread_rank();

  const K find_key = keys[group_id];

  size_t bkt_idx = 0;
  size_t start_idx = 0;
  Bucket<K, V, S>* bucket = get_key_position<K>(
      buckets, find_key, bkt_idx, start_idx, buckets_num, BUCKET_SIZE);

  int key_pos = -1;
  int src_lane = -1;
  // Align start index to TILE_SIZE boundary to improve coalesced digest probing
  const size_t start_idx_aligned =
      start_idx & ~(static_cast<size_t>(TILE_SIZE) - 1);
  const OccupyResult occupy_result = find_without_lock<K, V, S, TILE_SIZE>(
      g, bucket, find_key, static_cast<uint32_t>(start_idx_aligned), key_pos,
      src_lane, BUCKET_SIZE);

  // Only lane 0 publishes results; the tile-wide probe is already complete.
  if (rank != 0) return;

  // key_pos stays -1 on a miss, so bucket storage is only touched on a hit.
  const bool found =
      (occupy_result == OccupyResult::DUPLICATE) && (key_pos >= 0);
  founds[group_id] = found;
  if (!found) return;

  // Copy all `dim` elements of the value; both the bucket slot and the output
  // row are strided by dim.
  const V* src = reinterpret_cast<const V*>(bucket->vectors) +
                 static_cast<size_t>(key_pos) * dim;
  V* dst = values + group_id * dim;
  for (uint32_t i = 0; i < dim; ++i) {
    dst[i] = __ldg(src + i);
  }

  if (scores != nullptr) {
    const S s_r = *(bucket->scores(key_pos));
    scores[group_id] = s_r;
  }
}

// -----------------------------------------------------------------------------
// Diff context preserved from the original text (unchanged lines, shown with
// their template-argument lists already stripped):
//   template struct LookupValueBufConfig;
//   @@ -755,6 +802,32 @@ template ;
// The two helpers below are added inside the enclosing struct whose header is
// not visible in this patch fragment.
// -----------------------------------------------------------------------------

  // Helper overloads: try small-dimension TLP launch only for
  // LookupKernelParams (FoundFunctorV1).

  // Fallback overload for every other parameter type: launches nothing and
  // returns false so the caller continues with its regular kernel selection.
  template <typename ParamsT>
  static inline bool small_dim_tlp_try_launch(ParamsT&, uint32_t,
                                              cudaStream_t&) {
    return false;  // default: not applicable
  }

  // Launches lookup_kernel_tlp_small_dim when total_value_size <= 8.
  //
  // Returns true when the request was handled here (kernel enqueued on
  // `stream`, or nothing to do because params.n == 0), false when the value is
  // too large and the caller must pick another kernel.
  //
  // NOTE(review): the template-argument list of LookupKernelParams was lost in
  // SOURCE; <KeyT, ValueT, ScoreT, FoundFunctorV1> is reconstructed from the
  // comment above and from the use of params.found_functor.founds - confirm.
  template <typename KeyT, typename ValueT, typename ScoreT>
  static inline bool small_dim_tlp_try_launch(
      LookupKernelParams<KeyT, ValueT, ScoreT, FoundFunctorV1>& params,
      const uint32_t total_value_size, cudaStream_t& stream) {
    if (total_value_size > 8) {
      return false;
    }

    const size_t n = static_cast<size_t>(params.n);
    if (n == 0) {
      // A zero-sized grid is an invalid launch configuration; there is
      // nothing to look up, so report the request as handled.
      return true;
    }

    // TILE_SIZE must match the constant of the same name inside the kernel.
    constexpr int TILE_SIZE = 4;
    constexpr int BLOCK = 256;
    static_assert(BLOCK % TILE_SIZE == 0,
                  "BLOCK must be a multiple of TILE_SIZE");
    constexpr size_t GROUPS_PER_BLOCK = BLOCK / TILE_SIZE;

    // Ceil-div in size_t so a large n is not truncated before the division.
    const unsigned int grid = static_cast<unsigned int>(
        (n + GROUPS_PER_BLOCK - 1) / GROUPS_PER_BLOCK);

    lookup_kernel_tlp_small_dim<KeyT, ValueT, ScoreT>
        <<<grid, BLOCK, 0, stream>>>(
            params.buckets, static_cast<size_t>(params.buckets_num),
            static_cast<uint32_t>(params.dim), params.keys, params.values,
            params.scores, params.found_functor.founds, n);
    return true;
  }

// -----------------------------------------------------------------------------
// Diff context preserved from the original text (unchanged lines):
//   static inline uint32_t max_value_size() {
//     return ValueBufConfig::size_pipeline_v1;
//   }
//   @@ -762,12 +835,13 @@ struct SelectPipelineLookupKernelWithIO { template
//   (the remainder of this hunk lies outside the visible text)
// -----------------------------------------------------------------------------