Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/ir/op/tensor_ops/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,17 @@ TypePtr DeduceTensorGatherType(const std::vector<ExprPtr>& args,
CHECK(input_type) << "The operator " << op_name << " requires input to be a TensorType, but got "
<< args[0]->GetType()->TypeName();
CHECK(input_type->dtype_ == DataType::FP16 || input_type->dtype_ == DataType::FP32 ||
input_type->dtype_ == DataType::INT16 || input_type->dtype_ == DataType::INT32)
<< "The operator " << op_name << " requires input dtype to be FP16, FP32, INT16, or INT32, but got "
input_type->dtype_ == DataType::INT8 || input_type->dtype_ == DataType::INT16 ||
input_type->dtype_ == DataType::INT32)
<< "The operator " << op_name
<< " requires input dtype to be FP16, FP32, INT8, INT16, or INT32, but got "
<< input_type->dtype_.ToString();

auto index_type = As<TensorType>(args[1]->GetType());
CHECK(index_type) << "The operator " << op_name << " requires index to be a TensorType, but got "
<< args[1]->GetType()->TypeName();
CHECK(index_type->dtype_ == DataType::INT32)
<< "The operator " << op_name << " requires index dtype to be INT32, but got "
CHECK(index_type->dtype_ == DataType::INT32 || index_type->dtype_ == DataType::INT16)
<< "The operator " << op_name << " requires index dtype to be INT16 or INT32, but got "
<< index_type->dtype_.ToString();

const int64_t rank = static_cast<int64_t>(input_type->shape_.size());
Expand Down Expand Up @@ -101,8 +103,8 @@ REGISTER_OP("tensor.gather")
"Gather elements of input along the specified dimension using the index tensor "
"(tensor-level). Supports rank>=2 and any dim; lowered via tile.transpose + "
"tile.reshape + tile.gather by ConvertTensorToTileOps.")
.add_argument("input", "Input tensor (TensorType; FP16, FP32, INT16, or INT32)")
.add_argument("index", "Index tensor (TensorType, INT32, same shape as output)")
.add_argument("input", "Input tensor (TensorType; FP16, FP32, INT8, INT16, or INT32)")
.add_argument("index", "Index tensor (TensorType; INT16 or INT32, same shape as output)")
.set_attr<int>("dim")
.f_deduce_type([](const std::vector<ExprPtr>& args,
const std::vector<std::pair<std::string, std::any>>& kwargs) {
Expand Down
27 changes: 14 additions & 13 deletions src/ir/op/tile_ops/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,30 +46,31 @@ static TypePtr DeduceTileGatherType(const std::vector<ExprPtr>& args,
CHECK(args.size() == 3) << "The operator " << op_name
<< " requires 3 arguments (src, indices, tmp), but got " << args.size();

// First arg: src tile (f16, f32, i16, or i32)
// First arg: src tile
auto src_type = As<TileType>(args[0]->GetType());
CHECK(src_type) << "The operator " << op_name << " requires first argument to be a TileType, but got "
<< args[0]->GetType()->TypeName();
CHECK(src_type->dtype_ == DataType::FP16 || src_type->dtype_ == DataType::FP32 ||
src_type->dtype_ == DataType::INT16 || src_type->dtype_ == DataType::INT32)
<< "The operator " << op_name << " requires src dtype to be FP16, FP32, INT16, or INT32, but got "
src_type->dtype_ == DataType::INT8 || src_type->dtype_ == DataType::INT16 ||
src_type->dtype_ == DataType::INT32)
<< "The operator " << op_name << " requires src dtype to be FP16, FP32, INT8, INT16, or INT32, but got "
<< src_type->dtype_.ToString();

Comment thread
coderabbitai[bot] marked this conversation as resolved.
// Second arg: indices tile (must be i32)
// Second arg: indices tile
auto idx_type = As<TileType>(args[1]->GetType());
CHECK(idx_type) << "The operator " << op_name << " requires second argument to be a TileType, but got "
<< args[1]->GetType()->TypeName();
CHECK(idx_type->dtype_ == DataType::INT32)
<< "The operator " << op_name << " requires indices dtype to be INT32, but got "
CHECK(idx_type->dtype_ == DataType::INT32 || idx_type->dtype_ == DataType::INT16)
<< "The operator " << op_name << " requires indices dtype to be INT16 or INT32, but got "
<< idx_type->dtype_.ToString();

// Third arg: tmp workspace tile (must be i32, same shape as indices)
// Third arg: tmp workspace tile (must match indices dtype, same shape rank)
auto tmp_type = As<TileType>(args[2]->GetType());
CHECK(tmp_type) << "The operator " << op_name << " requires third argument to be a TileType, but got "
<< args[2]->GetType()->TypeName();
CHECK(tmp_type->dtype_ == DataType::INT32)
<< "The operator " << op_name << " requires tmp dtype to be INT32, but got "
<< tmp_type->dtype_.ToString();
CHECK(tmp_type->dtype_ == idx_type->dtype_)
<< "The operator " << op_name << " requires tmp dtype to match indices dtype ("
<< idx_type->dtype_.ToString() << "), but got " << tmp_type->dtype_.ToString();
CHECK(tmp_type->shape_.size() == idx_type->shape_.size())
<< "The operator " << op_name << " requires tmp shape rank to match indices rank ("
<< idx_type->shape_.size() << "), but got " << tmp_type->shape_.size();
Expand All @@ -88,9 +89,9 @@ static TypePtr DeduceTileGatherType(const std::vector<ExprPtr>& args,
REGISTER_OP("tile.gather")
.set_op_category("TileOp")
.set_description("Gather elements by index (maps to pto.tgather)")
.add_argument("src", "Source tile (FP16, FP32, INT16, or INT32)")
.add_argument("indices", "Index tile (INT32)")
.add_argument("tmp", "Temporary workspace tile (INT32)")
.add_argument("src", "Source tile (FP16, FP32, INT8, INT16, or INT32)")
.add_argument("indices", "Index tile (INT16 or INT32)")
.add_argument("tmp", "Temporary workspace tile (dtype must match indices)")
.set_input_memory(0, MemorySpace::Vec)
.set_input_memory(1, MemorySpace::Vec)
.set_input_memory(2, MemorySpace::Vec)
Expand Down
Loading
Loading