feat(backend): Add per-backend dtype allowlist for gather op#1317
feat(backend): Add per-backend dtype allowlist for gather op#1317Little-oil wants to merge 2 commits into
Conversation
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds a BackendHandler::IsDtypeSupported hook, provides Ascend910B/950 allowlists, updates tensor.gather and tile.gather type checks to consult backends (wider index/input dtypes), refactors tensor.gather lowering to a generalized emit_flat_index_gather, and extends tests for extra ranks/dimensions. ChangesGather Backend Dtype Support
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a mechanism for backend-specific data type validation for operators, specifically applied to the gather operation. It generalizes the gather lowering logic to support arbitrary ranks and dimensions using a flat-index approach and expands the supported data types for the Ascend950 backend to include INT8 for sources and INT16 for indices. The review feedback correctly identifies that the lowering logic uses hardcoded INT32 constants for index arithmetic, which will cause type mismatches when INT16 indices are used; it suggests using the actual index tensor data type for these constants to maintain IR consistency.
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/ut/ir/operators/test_tensor_ops.py (1)
2397-2401: ⚡ Quick winAdd positive tests for newly accepted gather dtypes.
This update strengthens rejection paths, but it still doesn’t assert success for the newly allowed tensor-level cases (
index=INT16,src=INT8). Adding those keeps the widened contract from regressing silently.Suggested test additions
+def test_tensor_gather_accepts_int16_index(): + inp, idx = _make_gather_inputs(idx_dtype=DataType.INT16) + call = ir.op.tensor.gather(inp, dim=-1, index=idx) + assert call.op.name == "tensor.gather" + + +def test_tensor_gather_accepts_int8_input(): + inp, idx = _make_gather_inputs(src_dtype=DataType.INT8, idx_dtype=DataType.INT32) + call = ir.op.tensor.gather(inp, dim=-1, index=idx) + rt = call.type + assert isinstance(rt, ir.TensorType) + assert rt.dtype == DataType.INT8Also applies to: 2404-2407
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/ut/ir/operators/test_tensor_ops.py` around lines 2397 - 2401, Add positive assertions that exercise the newly-accepted gather dtypes: use _make_gather_inputs to create inputs with index dtype DataType.INT16 and source dtype DataType.INT8 and call ir.op.tensor.gather(inp, dim=-1, index=idx) expecting no exception (i.e., remove pytest.raises and let the call succeed), and mirror the same positive test for the adjacent/relevant gather test to ensure both newly-allowed cases are asserted as successful rather than only asserting rejections.src/ir/transforms/op_conversion_registry.cpp (1)
1148-1275: 💤 Low valueTop-of-section comment in this file now lists only 4 cases — consider refreshing for the new dispatch.
The block comment at lines 889–938 still enumerates exactly “Four cases (by rank and norm_dim)” (rank-2 dim=1, rank-3 dim=0/1/2). With this PR,
RegisterGatherOpsnow also dispatches rank-2 dim=0 and rank≥4 any-dim throughemit_flat_index_gather. The new inline comment at lines 1122–1147 describes the generalized helper well, but a reader scanning the file’s top-of-section overview will get a stale picture of supported cases. Worth a one-paragraph refresh to mention the rank-2 dim=0 and rank≥4 routes alongside the existing four.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/ir/transforms/op_conversion_registry.cpp` around lines 1148 - 1275, Update the top-of-section block comment that enumerates "Four cases (by rank and norm_dim)" to reflect the new dispatch paths: mention that emit_flat_index_gather now handles rank-2 dim=0 and any rank>=4 (in addition to the previously-listed rank-2 dim=1 and rank-3 dim=0/1/2 cases); locate the comment near RegisterGatherOps/emit_flat_index_gather and replace the outdated four-case enumeration with a short paragraph that lists the full set of dispatched routes (rank==2 dim==1, rank==2 dim==0, rank==3 dim==0/1/2, and rank>=4 any-dim) so the overview matches the actual dispatch logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/ir/op/tensor_ops/gather.cpp`:
- Around line 61-67: Update the tensor.gather op registration strings to reflect
the widened dtype contract: include INT8 as an allowed src dtype and allow index
to be INT16 or INT32. Locate the tensor.gather registration metadata (the
human-readable type description lines that currently list allowed src/index
dtypes) and modify them so they match the runtime checks (which use CHECK on
input_type->dtype_ and CheckBackendDtype for "src"); apply the same change to
the second registration occurrence referenced in the comment so both metadata
entries advertise the new INT8 for src and INT16|INT32 for index.
In `@src/ir/op/tile_ops/gather.cpp`:
- Around line 68-74: The operator argument documentation for tile.gather is out
of sync with the deduce logic: update the registered argument descriptions for
the gather op (the tile.gather registration/arg doc strings) to reflect that src
may be FP16|FP32|INT8|INT16|INT32 (per the CHECK on src_type and
CheckBackendDtype usage), indices may be INT16 or INT32, and tmp should be
documented as matching the indices dtype (tmp dtype == indices dtype); also
update the other duplicate doc block that mirrors these lines (the section
corresponding to the same registration around the other check). Ensure op_name,
src_type, CheckBackendDtype, and the indices/tmp arg descriptions are edited so
the docs match the new deduce rules.
---
Nitpick comments:
In `@src/ir/transforms/op_conversion_registry.cpp`:
- Around line 1148-1275: Update the top-of-section block comment that enumerates
"Four cases (by rank and norm_dim)" to reflect the new dispatch paths: mention
that emit_flat_index_gather now handles rank-2 dim=0 and any rank>=4 (in
addition to the previously-listed rank-2 dim=1 and rank-3 dim=0/1/2 cases);
locate the comment near RegisterGatherOps/emit_flat_index_gather and replace the
outdated four-case enumeration with a short paragraph that lists the full set of
dispatched routes (rank==2 dim==1, rank==2 dim==0, rank==3 dim==0/1/2, and
rank>=4 any-dim) so the overview matches the actual dispatch logic.
In `@tests/ut/ir/operators/test_tensor_ops.py`:
- Around line 2397-2401: Add positive assertions that exercise the
newly-accepted gather dtypes: use _make_gather_inputs to create inputs with
index dtype DataType.INT16 and source dtype DataType.INT8 and call
ir.op.tensor.gather(inp, dim=-1, index=idx) expecting no exception (i.e., remove
pytest.raises and let the call succeed), and mirror the same positive test for
the adjacent/relevant gather test to ensure both newly-allowed cases are
asserted as successful rather than only asserting rejections.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 52f38f37-e181-435f-bd61-e8e69d4eefaa
📒 Files selected for processing (10)
include/pypto/backend/910B/backend_910b_handler.hinclude/pypto/backend/950/backend_950_handler.hinclude/pypto/backend/common/backend_handler.hsrc/backend/910B/backend_910b_handler.cppsrc/backend/950/backend_950_handler.cppsrc/ir/op/tensor_ops/gather.cppsrc/ir/op/tile_ops/gather.cppsrc/ir/transforms/op_conversion_registry.cpptests/st/runtime/test_gather.pytests/ut/ir/operators/test_tensor_ops.py
- Use index_tensor_type->dtype_ for tile.ci range and tile.muls multiplier in tensor.gather lowering, so INT16 indices flow through without hardcoded INT32 mismatches (gemini review). - Update tensor.gather and tile.gather argument descriptions to reflect the per-backend dtype contract: INT8 src and INT16 indices are valid on Ascend950 (coderabbit review). - Add direct #include "pypto/core/dtype.h" to backend handler files so clang-tidy's misc-include-cleaner is satisfied.
Introduces BackendHandler::IsDtypeSupported(op_name, arg_role, dtype)
so gather's accepted dtypes can vary per backend:
- a2a3 (910B): src {FP16, FP32, INT16, INT32}, indices {INT32}
- a5 (950) : src adds INT8; indices adds INT16
The op-level type-deduction enforces the universal union (a2a3 ∪ a5)
and then narrows to the active backend via CheckBackendDtype when a
backend is configured. Generalises tensor.gather lowering and updates
the gather ST/UT suite accordingly.
fix(pr): resolve issues for hw-native-sys#1317
- Use index_tensor_type->dtype_ for tile.ci range and tile.muls multiplier
in tensor.gather lowering, so INT16 indices flow through without
hardcoded INT32 mismatches (gemini review).
- Update tensor.gather and tile.gather argument descriptions to reflect
the per-backend dtype contract: INT8 src and INT16 indices are valid on
Ascend950 (coderabbit review).
- Add direct #include "pypto/core/dtype.h" to backend handler files so
clang-tidy's misc-include-cleaner is satisfied.
test(gather): Add INT8 ST test case for Ascend950 backend
Extend gather ST coverage with an INT8 src + INT32 idx case targeting
Ascend950 (a5 dtype allowlist), and tag existing rank/dim coverage tests
with explicit a2a3 platform markers + Ascend910B backend type. Adds
INT8 to harness DataType enum.
删除多余的注释
提取CheckOpDtype函数
撤回不必要的类型区分,PTOAS会给出错误日志
dd57a76 to
0f1b5de
Compare
Introduces BackendHandler::IsDtypeSupported(op_name, arg_role, dtype)
so gather's accepted dtypes can vary per backend:
- a2a3 (910B): src {FP16, FP32, INT16, INT32}, indices {INT32}
- a5 (950) : src adds INT8; indices adds INT16
The op-level type-deduction enforces the universal union (a2a3 ∪ a5)
and then narrows to the active backend via CheckBackendDtype when a
backend is configured. Generalises tensor.gather lowering and updates
the gather ST/UT suite accordingly.
fix(pr): resolve issues for hw-native-sys#1317
- Use index_tensor_type->dtype_ for tile.ci range and tile.muls multiplier
in tensor.gather lowering, so INT16 indices flow through without
hardcoded INT32 mismatches (gemini review).
- Update tensor.gather and tile.gather argument descriptions to reflect
the per-backend dtype contract: INT8 src and INT16 indices are valid on
Ascend950 (coderabbit review).
- Add direct #include "pypto/core/dtype.h" to backend handler files so
clang-tidy's misc-include-cleaner is satisfied.
test(gather): Add INT8 ST test case for Ascend950 backend
Extend gather ST coverage with an INT8 src + INT32 idx case targeting
Ascend950 (a5 dtype allowlist), and tag existing rank/dim coverage tests
with explicit a2a3 platform markers + Ascend910B backend type. Adds
INT8 to harness DataType enum.
删除多余的注释
提取CheckOpDtype函数
撤回不必要的类型区分,PTOAS会给出错误日志
- Refresh the gather-section block comment in op_conversion_registry.cpp to list all six dispatch routes (was stale "Four cases"): rank=2 dim=0 and rank>=4 any-dim now go through emit_flat_index_gather alongside the existing rank=3 dim=0/1/2 cases. - Add positive tensor.gather tests for the newly accepted dtypes (INT16 index, INT8 src), complementing the existing rejection tests.
0f1b5de to
c5a4654
Compare
|
Addressed the two remaining CodeRabbit nitpicks from the latest review in
Branch was also rebased onto the latest |
Summary
This PR completes the
gatherop's index form and introduces a per-backend dtype allowlist mechanism so the same op can accept different dtypes on different platforms withoutif (BackendType == ...)branches in passes.1. Per-backend dtype allowlist (new mechanism)
BackendHandler::IsDtypeSupported(op_name, arg_role, dtype)— new virtual method on the backend handler interface (include/pypto/backend/common/backend_handler.h). Defaults tofalse; each backend opts in to whatever it actually accepts via a per-(op, arg_role)allowlist table.The pattern (using a hypothetical
tile.foo(src, idx)/tensor.foo(input, index)):CheckBackendDtype(op, arg_role, dtype)to narrow to the active backend when one is configured:src/backend/910B/backend_910b_handler.cppandsrc/backend/950/backend_950_handler.cppregister the real subset for each backend.No changes needed to existing public headers beyond the new virtual method override.
2. Gather op — applied to the new mechanism
Op type-deduction (
src/ir/op/{tile,tensor}_ops/gather.cpp):src ∈ {FP16, FP32, INT8, INT16, INT32},indices ∈ {INT16, INT32}.tile.gathertmpworkspace constraint relaxed from hardcodedINT32to "must match indices dtype".910B (a2a3) allowlist (
src/backend/910B/backend_910b_handler.cpp):tile.gather/tensor.gathersrc∈{FP16, FP32, INT16, INT32}tile.gather/tensor.gatherindices∈{INT32}950 (a5) allowlist (
src/backend/950/backend_950_handler.cpp):tile.gather/tensor.gathersrc∈{INT8, FP16, FP32, INT16, INT32}(a2a3 ∪ INT8)tile.gather/tensor.gatherindices∈{INT16, INT32}(a2a3 ∪ INT16)3. Generalized
tensor.gatherloweringsrc/ir/transforms/op_conversion_registry.cpp:emit_flat_index_gather(gather_dim)helper that uses mixed-radix decomposition of the loop variable.idx_dtypeinstead of hardcodedINT32, so INT16-indices paths share the same lowering once the codegen-side INT16 work lands.4. Tests
Unit tests (
tests/ut/ir/operators/test_tensor_ops.py):FP16, FP32, INT8, INT16, or INT32/INT16 or INT32).ST tests (
tests/st/runtime/test_gather.py):@pytest.mark.platforms("a2a3", "a2a3sim")/("a5", "a5sim")) added per test.BackendType.Ascend910Bfrom the base class so each test case picks its own backend.ST harness (
tests/st/harness/core/harness.py): AddedINT8to theDataTypeenum.Testing
Notes
range_1d/tmpnow use the indices dtype consistently. An end-to-end INT16-idx ST case is intentionally deferred to a follow-up PR pending PTOAS-side INT16 codegen verification.