diff --git a/src/ir/op/tensor_ops/gather.cpp b/src/ir/op/tensor_ops/gather.cpp index 4a286261f..bdac22b46 100644 --- a/src/ir/op/tensor_ops/gather.cpp +++ b/src/ir/op/tensor_ops/gather.cpp @@ -46,15 +46,17 @@ TypePtr DeduceTensorGatherType(const std::vector& args, CHECK(input_type) << "The operator " << op_name << " requires input to be a TensorType, but got " << args[0]->GetType()->TypeName(); CHECK(input_type->dtype_ == DataType::FP16 || input_type->dtype_ == DataType::FP32 || - input_type->dtype_ == DataType::INT16 || input_type->dtype_ == DataType::INT32) - << "The operator " << op_name << " requires input dtype to be FP16, FP32, INT16, or INT32, but got " + input_type->dtype_ == DataType::INT8 || input_type->dtype_ == DataType::INT16 || + input_type->dtype_ == DataType::INT32) + << "The operator " << op_name + << " requires input dtype to be FP16, FP32, INT8, INT16, or INT32, but got " << input_type->dtype_.ToString(); auto index_type = As(args[1]->GetType()); CHECK(index_type) << "The operator " << op_name << " requires index to be a TensorType, but got " << args[1]->GetType()->TypeName(); - CHECK(index_type->dtype_ == DataType::INT32) - << "The operator " << op_name << " requires index dtype to be INT32, but got " + CHECK(index_type->dtype_ == DataType::INT32 || index_type->dtype_ == DataType::INT16) + << "The operator " << op_name << " requires index dtype to be INT16 or INT32, but got " << index_type->dtype_.ToString(); const int64_t rank = static_cast(input_type->shape_.size()); @@ -101,8 +103,8 @@ REGISTER_OP("tensor.gather") "Gather elements of input along the specified dimension using the index tensor " "(tensor-level). 
Supports rank>=2 and any dim; lowered via tile.transpose + " "tile.reshape + tile.gather by ConvertTensorToTileOps.") - .add_argument("input", "Input tensor (TensorType; FP16, FP32, INT16, or INT32)") - .add_argument("index", "Index tensor (TensorType, INT32, same shape as output)") + .add_argument("input", "Input tensor (TensorType; FP16, FP32, INT8, INT16, or INT32)") + .add_argument("index", "Index tensor (TensorType; INT16 or INT32, same shape as output)") .set_attr("dim") .f_deduce_type([](const std::vector& args, const std::vector>& kwargs) { diff --git a/src/ir/op/tile_ops/gather.cpp b/src/ir/op/tile_ops/gather.cpp index 58159d81c..0a7328e85 100644 --- a/src/ir/op/tile_ops/gather.cpp +++ b/src/ir/op/tile_ops/gather.cpp @@ -46,30 +46,31 @@ static TypePtr DeduceTileGatherType(const std::vector& args, CHECK(args.size() == 3) << "The operator " << op_name << " requires 3 arguments (src, indices, tmp), but got " << args.size(); - // First arg: src tile (f16, f32, i16, or i32) + // First arg: src tile auto src_type = As(args[0]->GetType()); CHECK(src_type) << "The operator " << op_name << " requires first argument to be a TileType, but got " << args[0]->GetType()->TypeName(); CHECK(src_type->dtype_ == DataType::FP16 || src_type->dtype_ == DataType::FP32 || - src_type->dtype_ == DataType::INT16 || src_type->dtype_ == DataType::INT32) - << "The operator " << op_name << " requires src dtype to be FP16, FP32, INT16, or INT32, but got " + src_type->dtype_ == DataType::INT8 || src_type->dtype_ == DataType::INT16 || + src_type->dtype_ == DataType::INT32) + << "The operator " << op_name << " requires src dtype to be FP16, FP32, INT8, INT16, or INT32, but got " << src_type->dtype_.ToString(); - // Second arg: indices tile (must be i32) + // Second arg: indices tile auto idx_type = As(args[1]->GetType()); CHECK(idx_type) << "The operator " << op_name << " requires second argument to be a TileType, but got " << args[1]->GetType()->TypeName(); - CHECK(idx_type->dtype_ == 
DataType::INT32) - << "The operator " << op_name << " requires indices dtype to be INT32, but got " + CHECK(idx_type->dtype_ == DataType::INT32 || idx_type->dtype_ == DataType::INT16) + << "The operator " << op_name << " requires indices dtype to be INT16 or INT32, but got " << idx_type->dtype_.ToString(); - // Third arg: tmp workspace tile (must be i32, same shape as indices) + // Third arg: tmp workspace tile (must match indices dtype, same shape rank) auto tmp_type = As(args[2]->GetType()); CHECK(tmp_type) << "The operator " << op_name << " requires third argument to be a TileType, but got " << args[2]->GetType()->TypeName(); - CHECK(tmp_type->dtype_ == DataType::INT32) - << "The operator " << op_name << " requires tmp dtype to be INT32, but got " - << tmp_type->dtype_.ToString(); + CHECK(tmp_type->dtype_ == idx_type->dtype_) + << "The operator " << op_name << " requires tmp dtype to match indices dtype (" + << idx_type->dtype_.ToString() << "), but got " << tmp_type->dtype_.ToString(); CHECK(tmp_type->shape_.size() == idx_type->shape_.size()) << "The operator " << op_name << " requires tmp shape rank to match indices rank (" << idx_type->shape_.size() << "), but got " << tmp_type->shape_.size(); @@ -88,9 +89,9 @@ static TypePtr DeduceTileGatherType(const std::vector& args, REGISTER_OP("tile.gather") .set_op_category("TileOp") .set_description("Gather elements by index (maps to pto.tgather)") - .add_argument("src", "Source tile (FP16, FP32, INT16, or INT32)") - .add_argument("indices", "Index tile (INT32)") - .add_argument("tmp", "Temporary workspace tile (INT32)") + .add_argument("src", "Source tile (FP16, FP32, INT8, INT16, or INT32)") + .add_argument("indices", "Index tile (INT16 or INT32)") + .add_argument("tmp", "Temporary workspace tile (dtype must match indices)") .set_input_memory(0, MemorySpace::Vec) .set_input_memory(1, MemorySpace::Vec) .set_input_memory(2, MemorySpace::Vec) diff --git a/src/ir/transforms/op_conversion_registry.cpp 
b/src/ir/transforms/op_conversion_registry.cpp index 4b4a2dab7..a308abbea 100644 --- a/src/ir/transforms/op_conversion_registry.cpp +++ b/src/ir/transforms/op_conversion_registry.cpp @@ -909,7 +909,20 @@ void OpConversionRegistry::RegisterSortOps() { // We always add a trailing tile.reshape so Phase 3 (RewriteReturnedAssemble- // LoopToStore) does not fire; we want the full-tile store path instead. // -// Four cases (by rank and norm_dim): +// Supported (rank, norm_dim) routes: +// +// Case 1 rank==2, dim==1 (last): nested-loop single-row gather +// Case 2 rank==3, dim==2 (last): nested-loop single-row gather +// Case 3 rank==3, dim==0 (first): emit_flat_index_gather +// Case 4 rank==3, dim==1 (middle): emit_flat_index_gather +// Case 5 rank==2, dim==0 (first): emit_flat_index_gather +// Case 6 rank>=4, any dim: emit_flat_index_gather +// +// Cases 1–2 use dedicated nested-loop lowerings (described below). Cases 3–6 +// share the generalized `emit_flat_index_gather` helper, which flattens every +// non-last index axis into the outer loop variable and emits a single-row +// gather per output row. See the helper-local block comment near the +// definition of `emit_flat_index_gather` for the per-iteration formula. // // Case 1 rank==2, dim==1 (last): // Loop over I0 rows: load [1,S1] and [1,K], single-row gather. @@ -920,22 +933,6 @@ void OpConversionRegistry::RegisterSortOps() { // Load [1,1,S2]→reshape[1,S2]; Load [1,1,K]→reshape[1,K]; gather [1,K]. // Inner acc [I1,K]; reshape→[1,I1*K]; outer acc [I0,I1*K]. // Final reshape [I0,I1*K]→[I0*I1,K]; tile.store at [0,0,0]. -// -// Case 3 rank==3, dim==0 (first): -// Flat-index gather: for each output row r = i0*I1+i1: -// inp_flat = inp[:, i1, :].flatten() → [1, S0*I2] -// idx_row = idx[i0, i1, :] → [1, I2] -// flat_idx = idx_row * I2 + [0..I2-1] → [1, I2] -// out_row = gather(inp_flat, flat_idx) → [1, I2] -// Accumulator [I0*I1, I2]; reshape→[I0*I1,I2]; tile.store at [0,0,0]. 
-// -// Case 4 rank==3, dim==1 (middle): -// Flat-index gather: for each output row r = i0*I1+i1: -// inp_flat = inp[i0, :, :].flatten() → [1, S1*I2] -// idx_row = idx[i0, i1, :] → [1, I2] -// flat_idx = idx_row * I2 + [0..I2-1] → [1, I2] -// out_row = gather(inp_flat, flat_idx) → [1, I2] -// Accumulator [I0*I1, I2]; reshape→[I0*I1,I2]; tile.store at [0,0,0]. // ============================================================================ void OpConversionRegistry::RegisterGatherOps() { @@ -968,20 +965,23 @@ void OpConversionRegistry::RegisterGatherOps() { << "tensor.gather conversion: dim out of range, got " << dim_val; DataType input_dtype = input_tensor_type->dtype_; + DataType idx_dtype = index_tensor_type->dtype_; auto make_idx = [&](int64_t value) -> ExprPtr { return std::make_shared(value, DataType::INDEX, span); }; - auto make_i32 = [&](int64_t value) -> ExprPtr { - return std::make_shared(value, DataType::INT32, span); + auto make_idx_scalar = [&](int64_t value) -> ExprPtr { + return std::make_shared(value, idx_dtype, span); }; auto zero = make_idx(0); auto one = make_idx(1); std::vector> load_kwargs = {{"target_memory", MemorySpace::Vec}, {"transpose", false}}; + // tile.gather requires tmp_dtype == idx_dtype; track the indices dtype so + // INT16 / INT32 indices share the same lowering path. std::vector> tmp_create_kwargs = { - {"dtype", DataType(DataType::INT32)}, {"target_memory", MemorySpace::Vec}}; + {"dtype", idx_dtype}, {"target_memory", MemorySpace::Vec}}; std::vector prologue; @@ -1122,130 +1122,160 @@ void OpConversionRegistry::RegisterGatherOps() { } // ================================================================ - // Case 3 rank==3, dim==0 (first dim) - // out[i0, i1, k] = inp[idx[i0, i1, k], i1, k] - // Result tile: [I0*I1, I2] where tile[i0*I1+i1, k] = output[i0, i1, k]. + // Cases 3, 4, 5, 6 — flat-index gather helper. 
// - // Uses flat-index gather to avoid intermediate tiles with I0 (potentially - // non-8-aligned) columns, which would violate hardware 32-byte row alignment. - // For each output row r = i0*I1+i1: - // inp_flat = inp[:, i1, :].flatten() → [1, S0*S2] - // idx_row = idx[i0, i1, :] → [1, I2] - // flat_idx = idx_row * S2 + [0..I2-1] → [1, I2] - // out_row = gather(inp_flat, flat_idx) → [1, I2] - // ================================================================ - if (rank == 3 && norm_dim == 0) { - int64_t S0 = get_const(input_shape[0], "input.shape[0]"); - int64_t S2 = get_const(input_shape[2], "input.shape[2]"); - int64_t I0 = get_const(index_shape[0], "index.shape[0]"); - int64_t I1 = get_const(index_shape[1], "index.shape[1]"); - int64_t I2 = get_const(index_shape[2], "index.shape[2]"); - int64_t I0I1 = I0 * I1; - int64_t S0S2 = S0 * S2; - - // Precompute constant range tile [0, 1, ..., I2-1] (shared across all loop iterations). - std::vector> ci_kw = {{"dtype", DataType(DataType::INT32)}}; - auto range_1d = emit("tile.ci", {make_i32(0), MakeShapeTuple({one, make_idx(I2)}, span)}, ci_kw, - "gather_range"); - - // Outer loop: r=0..I0*I1-1, accumulating [I0*I1, I2]. - auto result = make_loop( - prologue, "gather_main", make_idx(I0I1), I0I1, I2, input_dtype, - [&](const VarPtr& lv, const IterArgPtr& /*ia*/, std::vector& bs) -> VarPtr { - auto i0_expr = MakeFloorDiv(lv, make_idx(I1), span); - auto i1_expr = MakeFloorMod(lv, make_idx(I1), span); - - // Load inp[:, i1, :] → [S0, 1, I2] → [S0, I2] → [1, S0*I2]. 
- auto inp_ofs = std::make_shared(std::vector{zero, i1_expr, zero}, span); - auto inp_sh = MakeShapeTuple({input_shape[0], one, input_shape[2]}, span); - auto inp_raw = - emit_to(bs, "tile.load", {input, inp_ofs, inp_sh, inp_sh}, load_kwargs, "gather_inp_raw"); - auto inp_2d = reshape_to(bs, inp_raw, {input_shape[0], input_shape[2]}, "gather_inp_2d"); - auto inp_flat = reshape_to(bs, inp_2d, {one, make_idx(S0S2)}, "gather_inp_flat"); - - // Load idx[i0, i1, :] → [1, 1, I2] → [1, I2]. - auto idx_ofs = - std::make_shared(std::vector{i0_expr, i1_expr, zero}, span); - auto idx_sh = MakeShapeTuple({one, one, index_shape[2]}, span); - auto idx_raw = - emit_to(bs, "tile.load", {index, idx_ofs, idx_sh, idx_sh}, load_kwargs, "gather_idx_raw"); - auto idx_row = reshape_to(bs, idx_raw, {one, index_shape[2]}, "gather_idx_row"); - - // flat_idx[k] = idx_row[k] * S2 + k → selects inp_flat[flat_idx[k]]. - auto idx_sc = emit_to(bs, "tile.muls", {idx_row, make_i32(S2)}, {}, "gather_idx_s"); - auto flat_idx = emit_to(bs, "tile.add", {idx_sc, range_1d}, {}, "gather_fidx"); - - return single_row_gather(bs, inp_flat, flat_idx, I2, "gather_row"); - }); - // Reshape [I0*I1, I2] is already the correct 2D layout; prevents Phase 3 optimization. - auto out_2d = reshape_to(prologue, result, {make_idx(I0I1), make_idx(I2)}, "gather_out"); - return ConversionResult{std::move(prologue), out_2d}; - } - - // ================================================================ - // Case 4 rank==3, dim==1 (middle dim) - // out[i0, i1, k] = inp[i0, idx[i0, i1, k], k] - // Result tile: [I0*I1, I2] where tile[i0*I1+i1, k] = output[i0, i1, k]. + // Generalizes the flat-index lowering used by case 3 (rank=3 dim=0) + // and case 4 (rank=3 dim=1) to any (rank>=2, 0<=dim<rank) pair, adding + // rank==2 dim==0 (new) and rank>=4 any dim (new) + // + // For each output row r = lv ∈ [0, out_rows), where + // out_rows = ∏_{i!=last} index_shape[i]: + // 1. Decompose lv into per-axis outer indices (i_0, ..., i_{last-1}). + // 2. 
inp_offsets/shape: offset 0 with full extent at gather_dim and last; + // offset = outer index with size 1 at the remaining axes. + // 3. Reshape input slice to [1, flat_size]; flat_size = S_d * S_last when + // gather_dim != last, else S_last. + // 4. idx_row = idx[i_0, ..., i_{last-1}, :] → [1, K] where K=index_shape[last]. + // 5. When gather_dim != last, flat_idx[k] = idx_row[k] * S_last + k, built + // from a range(K) tile created in the index dtype. + // NOTE(review): the multiply/add appear to run in the index dtype, so + // with INT16 indices flat_size - 1 must fit in INT16 — confirm a guard exists. + // When gather_dim == last, idx_row directly indexes the flat input. + // 6. out_row = single-row gather(inp_flat, flat_idx) → [1, K]. // - // Uses flat-index gather to avoid intermediate tiles with I1 (potentially - // non-8-aligned) columns, which would violate hardware 32-byte row alignment. - // For each output row r = i0*I1+i1: - // inp_flat = inp[i0, :, :].flatten() → [1, S1*S2] - // idx_row = idx[i0, i1, :] → [1, I2] - // flat_idx = idx_row * S2 + [0..I2-1] → [1, I2] - // out_row = gather(inp_flat, flat_idx) → [1, I2] + // Output is a [out_rows, K] tile (already 2D); identity reshape preserves + // the layout against later optimization passes. // ================================================================ - CHECK(rank == 3 && norm_dim == 1) << "tensor.gather: unsupported (rank, dim) combination, " - << "got rank=" << rank << " norm_dim=" << norm_dim; + auto emit_flat_index_gather = [&](int gather_dim) -> ConversionResult { + const int rank_v = static_cast(rank); + const int last = rank_v - 1; + + // Resolve static dimensions. + std::vector input_dims(rank_v), index_dims(rank_v); + for (int i = 0; i < rank_v; ++i) { + input_dims[i] = get_const(input_shape[i], "input.shape"); + index_dims[i] = get_const(index_shape[i], "index.shape"); + } + const int64_t S_last = input_dims[last]; + const int64_t K = index_dims[last]; + const bool need_flat_idx = (gather_dim != last); + const int64_t flat_size = need_flat_idx ? 
input_dims[gather_dim] * S_last : S_last; + + int64_t out_rows = 1; + std::vector outer_dims; + outer_dims.reserve(last); + for (int i = 0; i < last; ++i) { + out_rows *= index_dims[i]; + outer_dims.push_back(index_dims[i]); + } - { - int64_t I0 = get_const(index_shape[0], "index.shape[0]"); - int64_t I1 = get_const(index_shape[1], "index.shape[1]"); - int64_t I2 = get_const(index_shape[2], "index.shape[2]"); - int64_t S1 = get_const(input_shape[1], "input.shape[1]"); - int64_t S2 = get_const(input_shape[2], "input.shape[2]"); - int64_t I0I1 = I0 * I1; - int64_t S1S2 = S1 * S2; + // Strides for mixed-radix decomposition of lv into (i_0, ..., i_{last-1}). + // stride[last-1] = 1, stride[i] = ∏_{j>i} outer_dims[j]. + std::vector stride(last); + if (last > 0) { + stride[last - 1] = 1; + for (int i = last - 2; i >= 0; --i) { + stride[i] = stride[i + 1] * outer_dims[i + 1]; + } + } - // Precompute constant range tile [0, 1, ..., I2-1] (shared across all loop iterations). - std::vector> ci_kw = {{"dtype", DataType(DataType::INT32)}}; - auto range_1d = emit("tile.ci", {make_i32(0), MakeShapeTuple({one, make_idx(I2)}, span)}, ci_kw, - "gather_range"); + // Constant range tile [0, 1, ..., K-1] in idx_dtype — shared across loop iterations. + VarPtr range_1d; + if (need_flat_idx) { + std::vector> ci_kw = {{"dtype", idx_dtype}}; + range_1d = emit("tile.ci", {make_idx_scalar(0), MakeShapeTuple({one, make_idx(K)}, span)}, ci_kw, + "gather_range"); + } - // Outer loop: r=0..I0*I1-1, accumulating [I0*I1, I2]. auto result = make_loop( - prologue, "gather_main", make_idx(I0I1), I0I1, I2, input_dtype, + prologue, "gather_main", make_idx(out_rows), out_rows, K, input_dtype, [&](const VarPtr& lv, const IterArgPtr& /*ia*/, std::vector& bs) -> VarPtr { - auto i0_expr = MakeFloorDiv(lv, make_idx(I1), span); - auto i1_expr = MakeFloorMod(lv, make_idx(I1), span); - - // Load inp[i0, :, :] → [1, S1, I2] → [S1, I2] → [1, S1*I2]. 
- auto inp_ofs = std::make_shared(std::vector{i0_expr, zero, zero}, span); - auto inp_sh = MakeShapeTuple({one, input_shape[1], input_shape[2]}, span); + // Decompose lv into per-axis outer indices. + std::vector outer_idx_exprs(last); + for (int i = 0; i < last; ++i) { + ExprPtr div_v; + if (stride[i] == 1) { + div_v = lv; + } else { + div_v = MakeFloorDiv(lv, make_idx(stride[i]), span); + } + outer_idx_exprs[i] = (i == 0) ? div_v : MakeFloorMod(div_v, make_idx(outer_dims[i]), span); + } + + // Build input offsets/shape. + std::vector inp_offsets(rank_v); + std::vector inp_shape_v(rank_v); + for (int i = 0; i < rank_v; ++i) { + if (i == gather_dim || i == last) { + inp_offsets[i] = zero; + inp_shape_v[i] = input_shape[i]; + } else { + inp_offsets[i] = outer_idx_exprs[i]; + inp_shape_v[i] = one; + } + } + auto inp_ofs = std::make_shared(inp_offsets, span); + auto inp_sh = MakeShapeTuple(inp_shape_v, span); auto inp_raw = emit_to(bs, "tile.load", {input, inp_ofs, inp_sh, inp_sh}, load_kwargs, "gather_inp_raw"); - auto inp_2d = reshape_to(bs, inp_raw, {input_shape[1], input_shape[2]}, "gather_inp_2d"); - auto inp_flat = reshape_to(bs, inp_2d, {one, make_idx(S1S2)}, "gather_inp_flat"); - // Load idx[i0, i1, :] → [1, 1, I2] → [1, I2]. - auto idx_ofs = - std::make_shared(std::vector{i0_expr, i1_expr, zero}, span); - auto idx_sh = MakeShapeTuple({one, one, index_shape[2]}, span); + // Reshape input to [1, flat_size]. For rank>2 isolate the gather plane. + VarPtr inp_flat; + if (need_flat_idx) { + VarPtr inp_2d = inp_raw; + if (rank_v > 2) { + inp_2d = reshape_to(bs, inp_raw, {make_idx(input_dims[gather_dim]), make_idx(S_last)}, + "gather_inp_2d"); + } + inp_flat = reshape_to(bs, inp_2d, {one, make_idx(flat_size)}, "gather_inp_flat"); + } else { + // gather along last axis: collapse leading 1-extent dims to a [1, S_last] row. + inp_flat = reshape_to(bs, inp_raw, {one, make_idx(flat_size)}, "gather_inp_flat"); + } + + // Build index offsets/shape. 
+ std::vector idx_offsets(rank_v); + std::vector idx_shape_v(rank_v, one); + for (int i = 0; i < last; ++i) idx_offsets[i] = outer_idx_exprs[i]; + idx_offsets[last] = zero; + idx_shape_v[last] = index_shape[last]; + auto idx_ofs = std::make_shared(idx_offsets, span); + auto idx_sh = MakeShapeTuple(idx_shape_v, span); auto idx_raw = emit_to(bs, "tile.load", {index, idx_ofs, idx_sh, idx_sh}, load_kwargs, "gather_idx_raw"); - auto idx_row = reshape_to(bs, idx_raw, {one, index_shape[2]}, "gather_idx_row"); - - // flat_idx[k] = idx_row[k] * S2 + k → selects inp_flat[flat_idx[k]]. - auto idx_sc = emit_to(bs, "tile.muls", {idx_row, make_i32(S2)}, {}, "gather_idx_s"); - auto flat_idx = emit_to(bs, "tile.add", {idx_sc, range_1d}, {}, "gather_fidx"); - - return single_row_gather(bs, inp_flat, flat_idx, I2, "gather_row"); + VarPtr idx_row = idx_raw; + if (rank_v > 2) { + idx_row = reshape_to(bs, idx_raw, {one, make_idx(K)}, "gather_idx_row"); + } + + VarPtr final_idx = idx_row; + if (need_flat_idx) { + auto idx_sc = + emit_to(bs, "tile.muls", {idx_row, make_idx_scalar(S_last)}, {}, "gather_idx_s"); + final_idx = emit_to(bs, "tile.add", {idx_sc, range_1d}, {}, "gather_fidx"); + } + + return single_row_gather(bs, inp_flat, final_idx, K, "gather_row"); }); - // Reshape [I0*I1, I2] is already the correct 2D layout; prevents Phase 3 optimization. - auto out_2d = reshape_to(prologue, result, {make_idx(I0I1), make_idx(I2)}, "gather_out"); + // Identity reshape preserves the 2D layout against later optimization passes. + auto out_2d = reshape_to(prologue, result, {make_idx(out_rows), make_idx(K)}, "gather_out"); return ConversionResult{std::move(prologue), out_2d}; - } + }; + + // Case 3: rank==3, dim==0 (first dim) — out[i0,i1,k] = inp[idx[i0,i1,k], i1, k]. + // Case 4: rank==3, dim==1 (middle dim) — out[i0,i1,k] = inp[i0, idx[i0,i1,k], k]. + // Case 5: rank==2, dim==0 — out[i0,k] = inp[idx[i0,k], k]. + // Case 6: rank>=4, any dim. 
+ if (rank == 3 && norm_dim == 0) return emit_flat_index_gather(0); + if (rank == 3 && norm_dim == 1) return emit_flat_index_gather(1); + if (rank == 2 && norm_dim == 0) return emit_flat_index_gather(0); + if (rank >= 4) return emit_flat_index_gather(norm_dim); + + CHECK(false) << "tensor.gather: unsupported (rank, dim) combination, " + << "got rank=" << rank << " norm_dim=" << norm_dim; + return ConversionResult{std::move(prologue), nullptr}; // unreachable }); } diff --git a/tests/st/harness/core/harness.py b/tests/st/harness/core/harness.py index 07a5b0105..439586c3d 100644 --- a/tests/st/harness/core/harness.py +++ b/tests/st/harness/core/harness.py @@ -89,6 +89,7 @@ class DataType(Enum): FP16 = "fp16" INT32 = "int32" UINT32 = "uint32" + INT8 = "int8" INT16 = "int16" UINT16 = "uint16" INT64 = "int64" @@ -103,6 +104,7 @@ def torch_dtype(self) -> torch.dtype: DataType.FP16: torch.float16, DataType.INT32: torch.int32, DataType.UINT32: torch.int32, # PyTorch has no uint32; use int32 (same bits) + DataType.INT8: torch.int8, DataType.INT16: torch.int16, DataType.UINT16: torch.int16, # PyTorch has limited uint16 support; use int16 (same bits) DataType.INT64: torch.int64, diff --git a/tests/st/runtime/test_gather.py b/tests/st/runtime/test_gather.py index 3cb29f509..49fc9229b 100644 --- a/tests/st/runtime/test_gather.py +++ b/tests/st/runtime/test_gather.py @@ -16,6 +16,10 @@ 3. Rank-3 + dim=-1 (collapses leading dims via ``tile.reshape``). 4. Rank-3 + dim=1 (middle axis — flat-index gather). 5. Rank-3 + dim=-3 (negative-dim normalization on the first axis). +6. Rank-2 + dim=0 (first axis — flat-index gather). +7. Rank-4 + dim=-1 (last axis — collapses leading dims via ``tile.reshape``). +8. Rank-4 + dim=2 (interior axis — flat-index gather, mixed-radix decomposition). +9. Rank-2 + dim=-1, INT8 src + INT32 idx (Ascend950). All cases are validated against a torch ``gather`` reference. 
""" @@ -137,6 +141,90 @@ def main( return output +@pl.program +class GatherRank2FirstDimProgram: + """Rank-2 + dim=0 (first axis) — flat-index gather. + + Last dim is 8 (8×4=32 bytes) to satisfy the hardware tile column + alignment requirement. + """ + + @pl.function(type=pl.FunctionType.Opaque) + def main( + self, + inp: pl.Tensor[[8, 8], pl.FP32], + idx: pl.Tensor[[3, 8], pl.INT32], + output: pl.Out[pl.Tensor[[3, 8], pl.FP32]], + ) -> pl.Tensor[[3, 8], pl.FP32]: + with pl.at(level=pl.Level.CORE_GROUP): + out = pl.tensor.gather(inp, dim=0, index=idx) + output = pl.assemble(output, out, [0, 0]) + return output + + +@pl.program +class GatherRank4LastDimProgram: + """Rank-4 + dim=-1: collapses leading dims via ``tile.reshape``. + + Last dim is 8 (8×4=32 bytes) to satisfy the hardware tile column + alignment requirement. + """ + + @pl.function(type=pl.FunctionType.Opaque) + def main( + self, + inp: pl.Tensor[[2, 2, 2, 16], pl.FP32], + idx: pl.Tensor[[2, 2, 2, 8], pl.INT32], + output: pl.Out[pl.Tensor[[2, 2, 2, 8], pl.FP32]], + ) -> pl.Tensor[[2, 2, 2, 8], pl.FP32]: + with pl.at(level=pl.Level.CORE_GROUP): + out = pl.tensor.gather(inp, dim=-1, index=idx) + output = pl.assemble(output, out, [0, 0, 0, 0]) + return output + + +@pl.program +class GatherRank4InteriorDimProgram: + """Rank-4 + dim=2 (interior axis) — flat-index gather with mixed-radix decomposition. + + Last dim is 8 (8×4=32 bytes) to satisfy the hardware tile column + alignment requirement. + """ + + @pl.function(type=pl.FunctionType.Opaque) + def main( + self, + inp: pl.Tensor[[2, 2, 4, 8], pl.FP32], + idx: pl.Tensor[[2, 2, 3, 8], pl.INT32], + output: pl.Out[pl.Tensor[[2, 2, 3, 8], pl.FP32]], + ) -> pl.Tensor[[2, 2, 3, 8], pl.FP32]: + with pl.at(level=pl.Level.CORE_GROUP): + out = pl.tensor.gather(inp, dim=2, index=idx) + output = pl.assemble(output, out, [0, 0, 0, 0]) + return output + + +@pl.program +class GatherRank2INT8Program: + """Rank-2 + dim=-1, INT8 src + INT32 idx (Ascend950). 
+ + INT8 elements are 1 byte wide, so cols must be a multiple of 32 to + satisfy the 32-byte tile column alignment requirement. + """ + + @pl.function(type=pl.FunctionType.Opaque) + def main( + self, + inp: pl.Tensor[[4, 64], pl.INT8], + idx: pl.Tensor[[4, 32], pl.INT32], + output: pl.Out[pl.Tensor[[4, 32], pl.INT8]], + ) -> pl.Tensor[[4, 32], pl.INT8]: + with pl.at(level=pl.Level.CORE_GROUP): + out = pl.tensor.gather(inp, dim=-1, index=idx) + output = pl.assemble(output, out, [0, 0]) + return output + + # --- Test cases --- @@ -146,9 +234,6 @@ class _GatherBaseTestCase(PTOTestCase): def get_strategy(self) -> OptimizationStrategy: return OptimizationStrategy.Default - def get_backend_type(self) -> BackendType: - return BackendType.Ascend910B - class GatherRank2LastDimTestCase(_GatherBaseTestCase): def get_name(self) -> str: @@ -169,6 +254,9 @@ def define_tensors(self) -> list[TensorSpec]: def get_program(self) -> Any: return GatherRank2LastDimProgram + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + def compute_expected(self, tensors, params=None): # torch.gather semantics: out[b, k] = inp[b, idx[b, k]] inp = tensors["inp"] @@ -195,6 +283,9 @@ def define_tensors(self) -> list[TensorSpec]: def get_program(self) -> Any: return GatherRank2SmallerLeadingProgram + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + def compute_expected(self, tensors, params=None): # torch's index broadcast along the non-gather axis must match the # PyPTO contract: rows of the input beyond index.shape[0] are unused. 
@@ -222,6 +313,9 @@ def define_tensors(self) -> list[TensorSpec]: def get_program(self) -> Any: return GatherRank3LastDimProgram + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + def compute_expected(self, tensors, params=None): inp = tensors["inp"] idx = tensors["idx"].to(torch.int64) @@ -247,6 +341,9 @@ def define_tensors(self) -> list[TensorSpec]: def get_program(self) -> Any: return GatherRank3MiddleDimProgram + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + def compute_expected(self, tensors, params=None): inp = tensors["inp"] idx = tensors["idx"].to(torch.int64) @@ -272,6 +369,9 @@ def define_tensors(self) -> list[TensorSpec]: def get_program(self) -> Any: return GatherRank3NegFirstDimProgram + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + def compute_expected(self, tensors, params=None): inp = tensors["inp"] idx = tensors["idx"].to(torch.int64) @@ -279,6 +379,123 @@ def compute_expected(self, tensors, params=None): tensors["output"][:] = torch.gather(inp, dim=0, index=idx) +class GatherRank2FirstDimTestCase(_GatherBaseTestCase): + def get_name(self) -> str: + return "gather_rank2_first_dim" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("inp", [8, 8], DataType.FP32, init_value=torch.randn), + TensorSpec( + "idx", + [3, 8], + DataType.INT32, + init_value=lambda: _rand_indices(0, 8, (3, 8)), + ), + TensorSpec("output", [3, 8], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return GatherRank2FirstDimProgram + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def compute_expected(self, tensors, params=None): + inp = tensors["inp"] + idx = tensors["idx"].to(torch.int64) + tensors["output"][:] = torch.gather(inp, dim=0, index=idx) + + +class GatherRank4LastDimTestCase(_GatherBaseTestCase): + def get_name(self) -> str: + return "gather_rank4_last_dim" + + def define_tensors(self) -> 
list[TensorSpec]: + return [ + TensorSpec("inp", [2, 2, 2, 16], DataType.FP32, init_value=torch.randn), + TensorSpec( + "idx", + [2, 2, 2, 8], + DataType.INT32, + init_value=lambda: _rand_indices(0, 16, (2, 2, 2, 8)), + ), + TensorSpec("output", [2, 2, 2, 8], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return GatherRank4LastDimProgram + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def compute_expected(self, tensors, params=None): + inp = tensors["inp"] + idx = tensors["idx"].to(torch.int64) + tensors["output"][:] = torch.gather(inp, dim=-1, index=idx) + + +class GatherRank4InteriorDimTestCase(_GatherBaseTestCase): + def get_name(self) -> str: + return "gather_rank4_interior_dim" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("inp", [2, 2, 4, 8], DataType.FP32, init_value=torch.randn), + TensorSpec( + "idx", + [2, 2, 3, 8], + DataType.INT32, + init_value=lambda: _rand_indices(0, 4, (2, 2, 3, 8)), + ), + TensorSpec("output", [2, 2, 3, 8], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return GatherRank4InteriorDimProgram + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def compute_expected(self, tensors, params=None): + inp = tensors["inp"] + idx = tensors["idx"].to(torch.int64) + tensors["output"][:] = torch.gather(inp, dim=2, index=idx) + + +class GatherRank2INT8TestCase(_GatherBaseTestCase): + def get_name(self) -> str: + return "gather_rank2_int8" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec( + "inp", + [4, 64], + DataType.INT8, + init_value=lambda: torch.randint(-128, 128, (4, 64), dtype=torch.int8), + ), + TensorSpec( + "idx", + [4, 32], + DataType.INT32, + init_value=lambda: _rand_indices(0, 64, (4, 32)), + ), + TensorSpec("output", [4, 32], DataType.INT8, is_output=True), + ] + + def get_program(self) -> Any: + return GatherRank2INT8Program + + def get_backend_type(self) -> 
BackendType: + return BackendType.Ascend950 + + def compute_expected(self, tensors, params=None): + inp = tensors["inp"] + idx = tensors["idx"].to(torch.int64) + tensors["output"][:] = torch.gather(inp, dim=-1, index=idx) + + # --- Tests --- @@ -286,31 +503,60 @@ class TestGather: """Verify ``pl.tensor.gather`` against a torch reference for the generalized rank/dim contract introduced by issue #676.""" + @pytest.mark.platforms("a2a3", "a2a3sim") @pytest.mark.parametrize("platform", PLATFORMS) def test_gather_rank2_last_dim(self, test_runner, platform): result = test_runner.run(GatherRank2LastDimTestCase(platform=platform)) assert result.passed, f"Test failed: {result.error}" + @pytest.mark.platforms("a2a3", "a2a3sim") @pytest.mark.parametrize("platform", PLATFORMS) def test_gather_rank2_smaller_leading(self, test_runner, platform): result = test_runner.run(GatherRank2SmallerLeadingTestCase(platform=platform)) assert result.passed, f"Test failed: {result.error}" + @pytest.mark.platforms("a2a3", "a2a3sim") @pytest.mark.parametrize("platform", PLATFORMS) def test_gather_rank3_last_dim(self, test_runner, platform): result = test_runner.run(GatherRank3LastDimTestCase(platform=platform)) assert result.passed, f"Test failed: {result.error}" + @pytest.mark.platforms("a2a3", "a2a3sim") @pytest.mark.parametrize("platform", PLATFORMS) def test_gather_rank3_middle_dim(self, test_runner, platform): result = test_runner.run(GatherRank3MiddleDimTestCase(platform=platform)) assert result.passed, f"Test failed: {result.error}" + @pytest.mark.platforms("a2a3", "a2a3sim") @pytest.mark.parametrize("platform", PLATFORMS) def test_gather_rank3_neg_first_dim(self, test_runner, platform): result = test_runner.run(GatherRank3NegFirstDimTestCase(platform=platform)) assert result.passed, f"Test failed: {result.error}" + @pytest.mark.platforms("a2a3", "a2a3sim") + @pytest.mark.parametrize("platform", PLATFORMS) + def test_gather_rank2_first_dim(self, test_runner, platform): + result = 
test_runner.run(GatherRank2FirstDimTestCase(platform=platform)) + assert result.passed, f"Test failed: {result.error}" + + @pytest.mark.platforms("a2a3", "a2a3sim") + @pytest.mark.parametrize("platform", PLATFORMS) + def test_gather_rank4_last_dim(self, test_runner, platform): + result = test_runner.run(GatherRank4LastDimTestCase(platform=platform)) + assert result.passed, f"Test failed: {result.error}" + + @pytest.mark.platforms("a2a3", "a2a3sim") + @pytest.mark.parametrize("platform", PLATFORMS) + def test_gather_rank4_interior_dim(self, test_runner, platform): + result = test_runner.run(GatherRank4InteriorDimTestCase(platform=platform)) + assert result.passed, f"Test failed: {result.error}" + + @pytest.mark.platforms("a5", "a5sim") + @pytest.mark.parametrize("platform", PLATFORMS) + def test_gather_rank2_int8(self, test_runner, platform): + result = test_runner.run(GatherRank2INT8TestCase(platform=platform)) + assert result.passed, f"Test failed: {result.error}" + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/ut/ir/operators/test_tensor_ops.py b/tests/ut/ir/operators/test_tensor_ops.py index ed0c1d5e6..97b22f3db 100644 --- a/tests/ut/ir/operators/test_tensor_ops.py +++ b/tests/ut/ir/operators/test_tensor_ops.py @@ -2474,14 +2474,35 @@ def test_tensor_gather_rejects_bad_dim(): def test_tensor_gather_rejects_non_int32_index(): - inp, idx = _make_gather_inputs(idx_dtype=DataType.INT16) - with pytest.raises(Exception, match=r"index dtype to be INT32"): + # IR accepts INT16/INT32 indices; reject anything else. 
+ inp, idx = _make_gather_inputs(idx_dtype=DataType.INT8) + with pytest.raises(Exception, match=r"index dtype to be INT16 or INT32"): ir.op.tensor.gather(inp, dim=-1, index=idx) +def test_tensor_gather_accepts_int16_index(): + """tensor.gather accepts INT16 indices (widened contract).""" + inp, idx = _make_gather_inputs(idx_dtype=DataType.INT16) + call = ir.op.tensor.gather(inp, dim=-1, index=idx) + assert call.op.name == "tensor.gather" + result_type = call.type + assert isinstance(result_type, ir.TensorType) + assert result_type.dtype == DataType.FP32 + + +def test_tensor_gather_accepts_int8_input(): + """tensor.gather accepts INT8 src dtype and preserves it on the output.""" + inp, idx = _make_gather_inputs(src_dtype=DataType.INT8, idx_dtype=DataType.INT32) + call = ir.op.tensor.gather(inp, dim=-1, index=idx) + assert call.op.name == "tensor.gather" + result_type = call.type + assert isinstance(result_type, ir.TensorType) + assert result_type.dtype == DataType.INT8 + + def test_tensor_gather_rejects_unsupported_input_dtype(): inp, idx = _make_gather_inputs(src_dtype=DataType.UINT32) - with pytest.raises(Exception, match=r"FP16, FP32, INT16, or INT32"): + with pytest.raises(Exception, match=r"FP16, FP32, INT8, INT16, or INT32"): ir.op.tensor.gather(inp, dim=-1, index=idx)