feat: auto-vectorize bf16/fp16 reduce with packed add2 intrinsics #2112

Open
kurisu6912 wants to merge 9 commits into tile-ai:main from kurisu6912:feature/packed-reduce-bf16-add2

Conversation


kurisu6912 (Collaborator) commented Apr 28, 2026

Summary

Auto-vectorize bf16/fp16 fragment reduce to emit packed add2 / max2 / min2 instead of scalar add.bf16. Enabled when the reduction extent is a compile-time constant divisible by the pack factor (2 for bf16/fp16), and nan_propagate is not set.

Generated code (bf16 reduce_sum, 4×256, dim=1)

single-warp (32 threads)

uint1 dst_clear_pack[4];      // packed accumulator (bf16x2 per element)
bfloat16_t dst_clear[4];

for (int i_1 = 0; i_1 < 4; ++i_1) {
    bfloat16_t init = bfloat16_t(0.0);
    dst_clear_pack[i_1] = make_uint1(__pack_nv_bfloat162(init, init));

    #pragma unroll
    for (int rv = 0; rv < 4; ++rv)
        dst_clear_pack[i_1] = tl::to_uint1(tl::add2(
            tl::from_uint1<__nv_bfloat162>(dst_clear_pack[i_1]),
            tl::from_uint1<__nv_bfloat162>(
                *(uint1*)(src + (i_1 * 8) + (rv * 2)))));

    dst_clear[i_1] = bfloat16_t(((nv_bfloat162*)(&dst_clear_pack[i_1]))->x)
                   + bfloat16_t(((nv_bfloat162*)(&dst_clear_pack[i_1]))->y);

    dst_clear[i_1] = tl::AllReduce<SumOp, 32, 1, 0,
                                   NamedBarrier<32>>::run(dst_clear[i_1]);
}

multi-warp (128 threads)

uint1 dst_pack[1];
bfloat16_t dst[1];

bfloat16_t init = bfloat16_t(0.0);
dst_pack[0] = make_uint1(__pack_nv_bfloat162(init, init));

#pragma unroll
for (int rv = 0; rv < 4; ++rv)
    dst_pack[0] = tl::to_uint1(tl::add2(
        tl::from_uint1<__nv_bfloat162>(dst_pack[0]),
        tl::from_uint1<__nv_bfloat162>(
            *(uint1*)(src + (rv * 2)))));

dst[0] = bfloat16_t(((nv_bfloat162*)(&dst_pack[0]))->x)
       + bfloat16_t(((nv_bfloat162*)(&dst_pack[0]))->y);

dst[0] = tl::AllReduce<SumOp, 32, 1, 0,
                       NamedBarrier<128>>::run(dst[0]);

batched AllReduce (batch=4, threads=256)

Local reduce is the same as above. After the horizontal reduce produces scalar values, pairs are packed and AllReduce uses SumOp_bf16x2:

// Phase 1 (same as single-warp above) → scalar B_local[0..7]

// Phase 2: pack pairs → batched AllReduce on bf16x2
B_local_pack_1[0] = uint1{__pack_nv_bfloat162(B_local[0], B_local[1])};
B_local_pack_1[1] = uint1{__pack_nv_bfloat162(B_local[2], B_local[3])};
tl::AllReduce<SumOp_bf16x2, 16, 1, 0,
              NamedBarrier<256>, 2, 16>::run_batch(&B_local_pack_1[0]);
B_local[0] = bfloat16_t(((nv_bfloat162*)&B_local_pack_1[0])->x);
B_local[1] = bfloat16_t(((nv_bfloat162*)&B_local_pack_1[0])->y);
B_local[2] = bfloat16_t(((nv_bfloat162*)&B_local_pack_1[1])->x);
B_local[3] = bfloat16_t(((nv_bfloat162*)&B_local_pack_1[1])->y);
// ... repeat for remaining pairs
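
For context, the packed functor named above (SumOp_bf16x2) is defined in src/tl_templates/cuda/reduce.h. A minimal sketch of its plausible shape follows; the _sketch suffix, the uint1 payload convention, and the exact qualifiers are assumptions for illustration, not the PR's literal code:

#include <cuda_bf16.h>

// Illustrative packed reducer: each call combines two bf16 lanes with one
// packed add. The pair travels through AllReduce as an opaque uint1.
struct SumOp_bf16x2_sketch {
  static __device__ __forceinline__ uint1 run(uint1 a, uint1 b) {
    __nv_bfloat162 x = *reinterpret_cast<__nv_bfloat162 *>(&a);
    __nv_bfloat162 y = *reinterpret_cast<__nv_bfloat162 *>(&b);
    __nv_bfloat162 r = __hadd2(x, y);  // two bf16 additions in one instruction
    return make_uint1(*reinterpret_cast<unsigned *>(&r));
  }
};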

Summary by CodeRabbit

Release Notes

  • New Features

    • Added vectorized reduction operations for improved performance on paired element operations
    • Added NaN-aware min/max reduction operators for better numerical correctness with floating-point data
    • Extended bfloat16 and float16 reduction support with packed operation variants
    • Introduced debug logging with filename-based verbosity control
  • Tests

    • Expanded reduction test coverage with additional bfloat16 and NaN-propagation validation

…l for TVM_LOG_CUSTOMIZE builds

When TVM_LOG_CUSTOMIZE=1, TVM's logging.cc skips compiling
TvmLogDebugSettings::ParseSpec and VerboseEnabledImpl (guarded by
#if TVM_LOG_CUSTOMIZE == 0). However libtilelang.so calls these
functions via the inline FromFlag(), causing a runtime symbol
lookup error. Add the missing implementations to tilelang's own
logging.cc.
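
For reference, a condensed sketch of the two definitions being added, modeled on TVM's TVM_LOG_CUSTOMIZE == 0 code path. The spec format shown and the simplified parsing are assumptions; the real patch validates entries and rejects malformed input rather than silently skipping it.

#include <cstdlib>
#include <string>
#include <unordered_map>

struct TvmLogDebugSettingsSketch {
  std::unordered_map<std::string, int> vlog_level_map_;

  // Parse a spec like "DEFAULT=1;src/op/reduce.cc=2" into per-file VLOG levels.
  static TvmLogDebugSettingsSketch ParseSpec(const char *opt_spec) {
    TvmLogDebugSettingsSketch settings;
    if (opt_spec == nullptr) return settings;
    std::string spec(opt_spec);
    size_t pos = 0;
    while (pos < spec.size()) {
      size_t eq = spec.find('=', pos);
      size_t end = spec.find(';', pos);
      if (end == std::string::npos) end = spec.size();
      if (eq == std::string::npos || eq > end) break;  // malformed entry
      std::string name = spec.substr(pos, eq - pos);
      int level = std::atoi(spec.substr(eq + 1, end - eq - 1).c_str());
      settings.vlog_level_map_[name] = level;
      pos = end + 1;
    }
    return settings;
  }

  // A VLOG(level) in `filename` fires if that file (or "DEFAULT") allows it.
  bool VerboseEnabledImpl(const std::string &filename, int level) const {
    auto it = vlog_level_map_.find(filename);
    if (it == vlog_level_map_.end()) it = vlog_level_map_.find("DEFAULT");
    return it != vlog_level_map_.end() && level <= it->second;
  }
};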
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai Bot commented Apr 28, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Make ReduceOpNode lane-size aware and add 2-lane packed reduction lowering for CUDA fp16/bf16 (thread-local and batched AllReduce), introduce packed TL functors and pack/shuffle helpers, update CUDA codegen for packed lanes, and add filename-scoped runtime VLOG parsing; tests expanded for bf16 reductions.

Changes

  • Reduce lowering (src/op/reduce.h, src/op/reduce.cc): add vsize-aware helpers (MakeInitValue(vsize), MakeReduce(vsize, ...) → std::optional, MakeCodegenReducer(vsize) → std::optional); implement conditional 2-lane packing for fragment→fragment reductions (packed identities, packed thread-local accumulation, lane-aware reducers) with a scalar fallback; adapt batched AllReduce to pack/unpack and use eff_batch. A sketch of the lane-aware dispatch follows this list.
  • CUDA TL templates, packed ops & shuffles (src/tl_templates/cuda/common.h, src/tl_templates/cuda/reduce.h): add a pack_half2 helper, max2_nan/min2_nan for __half2/__nv_bfloat162, and uint1 warp-shuffle specializations; add packed BF16x2/FP16x2 reduction functors, including NaN-aware variants.
  • CUDA codegen, pack/extract emission (src/target/codegen_cuda.cc): route min_nan/max_nan to tl::*2_nan, extend the packed tl::*2 fast path, emit tl::pack_half2 for packed fp16/bf16 construction, and implement ExtractElement for packed bf16x2/fp16x2 via reinterpret.
  • Builtins (src/op/builtin.cc, src/op/builtin.h): introduce the new packed intrinsics tl.max2_nan and tl.min2_nan (declarations and registrations).
  • Runtime logging VLOG gating (src/runtime/logging.cc): add TvmLogDebugSettings::ParseSpec(const char*) and VerboseEnabledImpl(const std::string&, int) to parse debug-spec strings and enable per-filename VLOG levels with a "DEFAULT" fallback.
  • Tests (testing/python/language/test_tilelang_language_reduce.py): expand reduce tests with bfloat16 packed fragment→fragment cases and CUDA-only NaN-propagation packed tests; relax a batch-size assertion and add runtime checks for NaN behavior.
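
As a rough illustration of the lane-aware dispatch named in the first bullet, here is a standalone toy model. The names follow the PR and the functor strings match the generated code shown earlier, but the logic is a simplification of the actual lowering, not the PR's code:

#include <optional>
#include <string>

enum class ReduceKind { kSum, kMax, kMin };

// Returns the TL functor name AllReduce should use, or nullopt when no
// packed form exists so the caller falls back to the scalar path.
std::optional<std::string> MakeCodegenReducerSketch(ReduceKind k, int vsize,
                                                    bool nan_propagate) {
  if (vsize == 1) {  // scalar reducers always exist
    switch (k) {
      case ReduceKind::kSum: return "SumOp";
      case ReduceKind::kMax: return "MaxOp";
      case ReduceKind::kMin: return "MinOp";
    }
  }
  // Packed pairs (bf16 shown): no packed NaN-propagating functors exist,
  // so nan_propagate forces the scalar fallback (see the commit notes later
  // in this thread).
  if (vsize == 2 && !nan_propagate) {
    switch (k) {
      case ReduceKind::kSum: return "SumOp_bf16x2";
      case ReduceKind::kMax: return "MaxOp_bf16x2";
      case ReduceKind::kMin: return "MinOp_bf16x2";
    }
  }
  return std::nullopt;
}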

Sequence Diagram(s)

sequenceDiagram
    participant Thread as Thread (warp/thread)
    participant PackBuf as Packed Buffer (uint1 x N)
    participant AllReduce as tl::AllReduce (external builtin)
    participant Clear as clear_buffer (scalar lanes)

    Thread->>PackBuf: allocate/init packed identity (vsize=2) or scalar fallback
    Thread->>PackBuf: thread-local accumulate (packed lanes)
    Thread->>AllReduce: pack/unpack & call run_batch on packed workspace
    AllReduce-->>Thread: packed reduced results
    Thread->>Clear: unpack/extract lanes -> write back to clear_buffer

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • bucket-xv
  • LeiWang1999

Poem

🐰 Two halves I tuck in velvet rows,
hop lanes together where the warp stream flows.
Pack, reduce, then gently unpack —
buffers snug for every stack.
Hoppity compute, two-lane glow.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 22.58%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title accurately describes the main change, adding auto-vectorization of bf16/fp16 reductions using packed add2 intrinsics, which is the core objective across multiple files.
  • Linked Issues Check ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check ✅ Passed: check skipped because no linked issues were found for this pull request.



coderabbitai Bot left a comment

Actionable comments posted: 5

🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_reduce.py (1)

78-79: Add a codegen assertion for the packed bf16 path.

These new cases only prove numerical correctness. A scalar fallback would still pass, so the optimization this PR is about can regress unnoticed unless at least one fragment→fragment bf16 case also asserts the generated source contains the packed reducer path (for example tl::add2 / packed bf16 pair handling).

Based on learnings, testing/python/transform/test_tilelang_transform_hoist_broadcast_values.py validates transforms by checking generated kernel source patterns.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/language/test_tilelang_language_reduce.py` around lines 78 -
79, The new bf16 fragment→fragment test cases only check numerical correctness;
add a codegen-level assertion in
testing/python/language/test_tilelang_language_reduce.py that when a reduce with
("sum", T.bfloat16, ..., "fragment", "fragment", ...) is generated the kernel
source contains the packed-bf16 reduction path (e.g. a pattern like "tl::add2"
or the packed bf16 pair handling code) so the packed reducer is actually
emitted; locate the fragment→fragment bf16 case and after generating the kernel
source assert the expected packed-bf16 pattern is present.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/op/reduce.cc`:
- Around line 175-177: The packed reduction branch currently treats isSum() and
isAbsSum() the same, dropping the absolute value for abssum; change the logic in
src/op/reduce.cc so isAbsSum() is handled separately: when type->isAbsSum()
construct the reduction as add2(acc, abs(b)) (i.e., wrap the packed element `b`
with the absolute operation before adding) instead of using plain add2; update
the conditional (replace the combined if (type->isSum() || type->isAbsSum())
branch) so type->isSum() uses Call(..., tl::add2(), {acc, b}) and
type->isAbsSum() uses Call(..., tl::add2(), {acc, Call(..., tl::abs(), {b})}).

In `@src/runtime/logging.cc`:
- Line 113: The code uses settings.vlog_level_map_.emplace(name, level_val)
which silently keeps the first value on duplicate keys; change this to
explicitly handle duplicates by either overwriting with
settings.vlog_level_map_.insert_or_assign(name, level_val) (or operator[]) so
later entries take effect, or detect an existing key (find on vlog_level_map_)
and fail fast/log an error before inserting; update the surrounding log message
to reflect the actual effective level based on the chosen behavior (overwrite or
error).
- Around line 111-112: The unconditional LOG(INFO) used when parsing
TVM_LOG_DEBUG (the statement emitting "TVM_LOG_DEBUG enables VLOG statements in
'...'" in src/runtime/logging.cc) should not print to stderr for every spec
entry; replace it with a gated log such as VLOG(1) (or DLOG(INFO)) or remove it
entirely so the message only appears when verbose logging is enabled; update the
single LOG(INFO) call accordingly to use VLOG(1) or drop the emission.
- Around line 104-106: The code currently calls strtol(level.c_str(),
&end_of_level, 10) and narrows the result into int level_val without range
checking; change the logic in the parsing branch around end_of_level/level_val
to parse into a long (keep strtol), then check that end_of_level points to the
end of the string, errno != ERANGE, and that the returned long is between
INT_MIN and INT_MAX before static_cast<int>ing to level_val; if any check fails,
treat the input as malformed (reject/log error) instead of silently casting to
an incorrect VLOG level.

In `@src/target/codegen_cuda.cc`:
- Around line 4369-4389: The ExtractElement handling uses
PrintExpr(op->vectors[0]) directly and takes its address, which can produce an
invalid address-to-temporary; change to materialize the vector into an SSA
variable (use SSAGetID or the existing pattern used for packed vectors) and use
that SSA name instead of PrintExpr(...) when forming the reinterpret cast; also
add a lane bounds check (ensure lane is 0 or 1) before selecting "x"/"y" to
mirror the validation used elsewhere; update both the bfloat16 path
(enable_bf16_) and float16 path (enable_fp16_) branches to use the SSA variable
and validated lane.

---

Nitpick comments:
In `@testing/python/language/test_tilelang_language_reduce.py`:
- Around line 78-79: The new bf16 fragment→fragment test cases only check
numerical correctness; add a codegen-level assertion in
testing/python/language/test_tilelang_language_reduce.py that when a reduce with
("sum", T.bfloat16, ..., "fragment", "fragment", ...) is generated the kernel
source contains the packed-bf16 reduction path (e.g. a pattern like "tl::add2"
or the packed bf16 pair handling code) so the packed reducer is actually
emitted; locate the fragment→fragment bf16 case and after generating the kernel
source assert the expected packed-bf16 pattern is present.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 62642205-e2fb-4189-9a22-c1215d2e0405

📥 Commits

Reviewing files that changed from the base of the PR and between 5d09b5d and 883a75b.

📒 Files selected for processing (5)
  • src/op/reduce.cc
  • src/op/reduce.h
  • src/runtime/logging.cc
  • src/target/codegen_cuda.cc
  • testing/python/language/test_tilelang_language_reduce.py

Comment thread src/op/reduce.cc Outdated
Comment thread src/runtime/logging.cc
Comment on lines +104 to +106
char *end_of_level = nullptr;
int level_val = static_cast<int>(strtol(level.c_str(), &end_of_level, 10));
if (end_of_level != level.c_str() + level.size()) {

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat -n src/runtime/logging.cc | sed -n '90,120p'

Repository: tile-ai/tilelang

Length of output: 1541


🏁 Script executed:

rg -A 5 -B 5 "vlog_level_map_" src/runtime/logging.cc

Repository: tile-ai/tilelang

Length of output: 761


🏁 Script executed:

rg "VLOG\|vlog.*level" src/runtime/logging.cc | head -20

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Check the broader function to understand the scope
sed -n '60,116p' src/runtime/logging.cc | cat -n

Repository: tile-ai/tilelang

Length of output: 2293


Add range validation before narrowing strtol result to int.

On lines 105–106, strtol parses the level string as a long and directly casts to int without checking whether the value is in range. If a user provides a value outside [INT_MIN, INT_MAX], the cast silently wraps to an incorrect VLOG level instead of rejecting the input as malformed.

Suggested fix
+    errno = 0;
+    long parsed_level = std::strtol(level.c_str(), &end_of_level, 10);
+    if (errno == ERANGE || parsed_level < std::numeric_limits<int>::min() ||
+        parsed_level > std::numeric_limits<int>::max()) {
+      LOG(FATAL) << "TVM_LOG_DEBUG ill-formed at position " << tell_pos(level)
+                 << ": level out of range: \"" << level << "\"";
+    }
-    int level_val = static_cast<int>(strtol(level.c_str(), &end_of_level, 10));
+    int level_val = static_cast<int>(parsed_level);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/runtime/logging.cc` around lines 104 - 106, The code currently calls
strtol(level.c_str(), &end_of_level, 10) and narrows the result into int
level_val without range checking; change the logic in the parsing branch around
end_of_level/level_val to parse into a long (keep strtol), then check that
end_of_level points to the end of the string, errno != ERANGE, and that the
returned long is between INT_MIN and INT_MAX before static_cast<int>ing to
level_val; if any check fails, treat the input as malformed (reject/log error)
instead of silently casting to an incorrect VLOG level.
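
A standalone version of the checked parse this comment asks for might look like the following; ParseVlogLevel is a hypothetical helper name, not code from the PR:

#include <cerrno>
#include <climits>
#include <cstdlib>
#include <optional>
#include <string>

// Reject empty input, trailing junk, ERANGE overflow, and values outside
// int's range; return nullopt instead of silently narrowing.
std::optional<int> ParseVlogLevel(const std::string &level) {
  if (level.empty()) return std::nullopt;
  errno = 0;
  char *end = nullptr;
  long v = std::strtol(level.c_str(), &end, 10);
  if (end != level.c_str() + level.size()) return std::nullopt;  // trailing junk
  if (errno == ERANGE || v < INT_MIN || v > INT_MAX) return std::nullopt;  // out of range
  return static_cast<int>(v);
}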

Comment thread src/runtime/logging.cc
Comment on lines +111 to +112
LOG(INFO) << "TVM_LOG_DEBUG enables VLOG statements in '" << name
<< "' up to level " << level;

⚠️ Potential issue | 🟡 Minor

Avoid unconditional INFO output while parsing TVM_LOG_DEBUG.

These lines print to stderr for every valid spec entry, even when the caller only wanted to configure VLOG gating. That makes the flag itself user-visible and can pollute tests or tooling that assert on stderr. Prefer VLOG/DLOG here, or drop the log entirely.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/runtime/logging.cc` around lines 111 - 112, The unconditional LOG(INFO)
used when parsing TVM_LOG_DEBUG (the statement emitting "TVM_LOG_DEBUG enables
VLOG statements in '...'" in src/runtime/logging.cc) should not print to stderr
for every spec entry; replace it with a gated log such as VLOG(1) (or
DLOG(INFO)) or remove it entirely so the message only appears when verbose
logging is enabled; update the single LOG(INFO) call accordingly to use VLOG(1)
or drop the emission.

Comment thread src/runtime/logging.cc
}
LOG(INFO) << "TVM_LOG_DEBUG enables VLOG statements in '" << name
<< "' up to level " << level;
settings.vlog_level_map_.emplace(name, level_val);

⚠️ Potential issue | 🟡 Minor

Handle duplicate file entries explicitly.

Line 113 uses emplace, so repeated keys keep the first level silently. That makes duplicates behave incorrectly and also leaves the preceding log message lying about the effective level. Either overwrite (insert_or_assign / operator[]) or fail fast on duplicates.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/runtime/logging.cc` at line 113, The code uses
settings.vlog_level_map_.emplace(name, level_val) which silently keeps the first
value on duplicate keys; change this to explicitly handle duplicates by either
overwriting with settings.vlog_level_map_.insert_or_assign(name, level_val) (or
operator[]) so later entries take effect, or detect an existing key (find on
vlog_level_map_) and fail fast/log an error before inserting; update the
surrounding log message to reflect the actual effective level based on the
chosen behavior (overwrite or error).
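
The pitfall flagged here is easy to demonstrate in isolation, with a toy map standing in for vlog_level_map_:

#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> levels;
  levels.emplace("reduce.cc", 1);
  levels.emplace("reduce.cc", 3);            // duplicate key: silently ignored
  std::cout << levels["reduce.cc"] << "\n";  // prints 1
  levels.insert_or_assign("reduce.cc", 3);   // later entry wins
  std::cout << levels["reduce.cc"] << "\n";  // prints 3
}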

Comment on lines +4369 to +4389
// Handle ExtractElement: extract a scalar lane from a bfloat16x2 / float16x2
// vector (produced by packed reduction, etc.). The vector is stored as an
// opaque uint1 in the lowered code, but semantically it is a packed pair.
DataType vec_t = op->vectors.size() == 1 ? op->vectors[0].dtype() : DataType();
bool vec_is_bf16x2 = vec_t.is_bfloat16() && vec_t.lanes() == 2;
bool vec_is_fp16x2 = vec_t.is_float16() && vec_t.lanes() == 2;
if ((vec_is_bf16x2 || vec_is_fp16x2) && op->vectors.size() == 1 &&
op->indices.size() == 1) {
int lane = Downcast<IntImm>(op->indices[0])->value;
std::string vec = PrintExpr(op->vectors[0]);
if (vec_is_bf16x2) {
enable_bf16_ = true;
os << "bfloat16_t(((nv_bfloat162*)(&(" << vec << ")))->"
<< (lane == 0 ? "x" : "y") << ")";
} else {
enable_fp16_ = true;
os << "half_t(((half2*)(&(" << vec << ")))->"
<< (lane == 0 ? "x" : "y") << ")";
}
return;
}

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Look at how reinterpret casting is typically done in this codegen
rg "reinterpret_cast" src/target/codegen_cuda.cc -B 2 -A 2 | head -40

Repository: tile-ai/tilelang

Length of output: 1877


Materialize the packed vector before reinterpret-casting.

The code at line 4378 passes PrintExpr(op->vectors[0]) directly into &(...). Since op->vectors[0] is an arbitrary Expr, PrintExpr may emit a function call or other expression that evaluates to a temporary. In C++, taking the address of a temporary is invalid. Materialize the expression using SSAGetID to ensure it's a valid lvalue (this pattern is already used at line 4245–4246 for the same packed vector types).

Also add a lane bounds check: for x2 vectors, valid lanes are 0 and 1 (this validation is already done at line 4248 for similar operations).

🛠️ Suggested fix
    int lane = Downcast<IntImm>(op->indices[0])->value;
+   ICHECK(lane == 0 || lane == 1)
+       << "ExtractElement on packed x2 expects lane 0 or 1, got " << lane;
-   std::string vec = PrintExpr(op->vectors[0]);
+   std::string vec = SSAGetID(PrintExpr(op->vectors[0]), vec_t);
📝 Committable suggestion


Suggested change
// Handle ExtractElement: extract a scalar lane from a bfloat16x2 / float16x2
// vector (produced by packed reduction, etc.). The vector is stored as an
// opaque uint1 in the lowered code, but semantically it is a packed pair.
DataType vec_t = op->vectors.size() == 1 ? op->vectors[0].dtype() : DataType();
bool vec_is_bf16x2 = vec_t.is_bfloat16() && vec_t.lanes() == 2;
bool vec_is_fp16x2 = vec_t.is_float16() && vec_t.lanes() == 2;
if ((vec_is_bf16x2 || vec_is_fp16x2) && op->vectors.size() == 1 &&
op->indices.size() == 1) {
int lane = Downcast<IntImm>(op->indices[0])->value;
std::string vec = PrintExpr(op->vectors[0]);
if (vec_is_bf16x2) {
enable_bf16_ = true;
os << "bfloat16_t(((nv_bfloat162*)(&(" << vec << ")))->"
<< (lane == 0 ? "x" : "y") << ")";
} else {
enable_fp16_ = true;
os << "half_t(((half2*)(&(" << vec << ")))->"
<< (lane == 0 ? "x" : "y") << ")";
}
return;
}
// Handle ExtractElement: extract a scalar lane from a bfloat16x2 / float16x2
// vector (produced by packed reduction, etc.). The vector is stored as an
// opaque uint1 in the lowered code, but semantically it is a packed pair.
DataType vec_t = op->vectors.size() == 1 ? op->vectors[0].dtype() : DataType();
bool vec_is_bf16x2 = vec_t.is_bfloat16() && vec_t.lanes() == 2;
bool vec_is_fp16x2 = vec_t.is_float16() && vec_t.lanes() == 2;
if ((vec_is_bf16x2 || vec_is_fp16x2) && op->vectors.size() == 1 &&
op->indices.size() == 1) {
int lane = Downcast<IntImm>(op->indices[0])->value;
ICHECK(lane == 0 || lane == 1)
<< "ExtractElement on packed x2 expects lane 0 or 1, got " << lane;
std::string vec = SSAGetID(PrintExpr(op->vectors[0]), vec_t);
if (vec_is_bf16x2) {
enable_bf16_ = true;
os << "bfloat16_t(((nv_bfloat162*)(&(" << vec << ")))->"
<< (lane == 0 ? "x" : "y") << ")";
} else {
enable_fp16_ = true;
os << "half_t(((half2*)(&(" << vec << ")))->"
<< (lane == 0 ? "x" : "y") << ")";
}
return;
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_cuda.cc` around lines 4369 - 4389, The ExtractElement
handling uses PrintExpr(op->vectors[0]) directly and takes its address, which
can produce an invalid address-to-temporary; change to materialize the vector
into an SSA variable (use SSAGetID or the existing pattern used for packed
vectors) and use that SSA name instead of PrintExpr(...) when forming the
reinterpret cast; also add a lane bounds check (ensure lane is 0 or 1) before
selecting "x"/"y" to mirror the validation used elsewhere; update both the
bfloat16 path (enable_bf16_) and float16 path (enable_fp16_) branches to use the
SSA variable and validated lane.

- Add tl::pack_half2 helper in common.h that returns uint1, avoiding taking
  address of a temporary __pack_half2 return value
- Update ShuffleNode codegen for fp16x2 to use pack_half2 instead of
  &(__pack_half2(...))
- Re-enable fp16 batch pack (remove bf16-only restriction in can_batch_pack)
- Add SumOp_fp16x2/MaxOp_fp16x2/MinOp_fp16x2 packed reducer structs
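
A plausible shape for the helper described in the first bullet; the real definition lives in src/tl_templates/cuda/common.h, and the parameter types and qualifiers here are assumptions. Returning uint1 by value gives codegen an rvalue it can pass around directly, instead of taking the address of __pack_half2(...)'s temporary result:

#include <cuda_fp16.h>

__device__ __forceinline__ uint1 pack_half2_sketch(__half x, __half y) {
  __half2 p = __halves2half2(x, y);  // pack two fp16 lanes into one register
  return make_uint1(*reinterpret_cast<unsigned *>(&p));
}

The new SumOp_fp16x2/MaxOp_fp16x2/MinOp_fp16x2 structs then presumably mirror the bf16 functor sketched earlier in this thread, with __half2 in place of __nv_bfloat162.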
coderabbitai Bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (2)
src/target/codegen_cuda.cc (1)

4375-4384: ⚠️ Potential issue | 🟠 Major

Materialize the packed vector before taking its address, and reject out-of-range lanes.

PrintExpr(op->vectors[0]) can expand to a temporary, so &(... ) here is not a safe lvalue. This path also silently maps any non-zero lane to .y instead of enforcing 0/1.

🛠️ Proposed fix
     int lane = Downcast<IntImm>(op->indices[0])->value;
-    std::string vec = PrintExpr(op->vectors[0]);
+    ICHECK(lane == 0 || lane == 1)
+        << "ExtractElement on packed x2 expects lane 0 or 1, got " << lane;
+    std::string vec = SSAGetID(PrintExpr(op->vectors[0]), vec_t);
     if (vec_is_bf16x2) {
       enable_bf16_ = true;
       os << "bfloat16_t(((nv_bfloat162*)(&(" << vec << ")))->"
#!/bin/bash
rg -n "VisitExpr_\\(const ShuffleNode \\*op" src/target/codegen_cuda.cc -A70 -B10
rg -n "SSAGetID\\(PrintExpr\\(" src/target/codegen_cuda.cc -C1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_cuda.cc` around lines 4375 - 4384, The code is taking the
address of PrintExpr(op->vectors[0]) which may be a temporary and therefore
unsafe, and it also silently treats any non-zero lane as .y; fix by
materializing op->vectors[0] into a local named temporary (use the same
SSA/PrintExpr helper used elsewhere, e.g., generate a tmp with
SSAGetID(PrintExpr(...)) or assign PrintExpr(...) to a local string temp
variable) and then take the address of that local when forming the
half2/nv_bfloat162 access; additionally validate the lane extracted from
Downcast<IntImm>(op->indices[0])->value to only allow 0 or 1 and
reject/out-of-range lanes (fail fast or emit a diagnostic), and preserve the
existing flags enable_bf16_ / enable_fp16_ in the bf16/float16 branches.
src/op/reduce.cc (1)

174-175: ⚠️ Potential issue | 🔴 Critical

Packed abssum still omits abs2.

This branch still treats sum and abssum the same, so packed abssum computes a plain sum instead of sum(abs(x)).

🐛 Suggested fix
-  if (type->isSum() || type->isAbsSum()) {
+  if (type->isSum()) {
     return Call(acc.dtype(), tl::add2(), {acc, b});
+  } else if (type->isAbsSum()) {
+    return Call(acc.dtype(), tl::add2(),
+                {acc, Call(acc.dtype(), tl::abs2(), {b})});
   } else if (type->isMax()) {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/reduce.cc` around lines 174 - 175, The branch in src/op/reduce.cc
treats isSum() and isAbsSum() the same, causing packed abssum to compute a plain
sum; update the branch so when type->isAbsSum() you wrap b with the abs2
operation before adding (e.g., replace the current return Call(acc.dtype(),
tl::add2(), {acc, b}) with a return that uses Call(acc.dtype(), tl::add2(),
{acc, Call(acc.dtype(), tl::abs2(), {b})}) when isAbsSum() so abssum computes
sum(abs(x)).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/op/reduce.cc`:
- Around line 442-481: The packed-path is forcing reinterpretation to
cfg->vec_dtype (see GetReducePackConfig, cfg->vec_dtype and src_writer->dtype)
which breaks mixed-precision reductions; instead guard enabling packing so it
only runs when source element dtype matches the accumulator/clear buffer dtype
(clear_buffer->dtype) or perform an explicit widening conversion after the
packed BufferLoad (src_load) rather than retyping the load in-place; update the
conditional around can_pack/need_pack_buffer (and places that set
clear_buffer_packed) to check dtype equality between src_buffer/
src_var_compressed.back()->dtype and clear_buffer->dtype, and if dtypes differ
either skip packing or insert an explicit vectorized cast from cfg->vec_dtype to
the accumulator dtype after src_load rather than setting src_writer->dtype
directly.

In `@src/target/codegen_cuda.cc`:
- Around line 4362-4363: The generated code wraps tl::pack_half2 in an extra
uint1{...} causing nested uint1 initializers; update the emitter in
codegen_cuda.cc that currently does os << "uint1{tl::pack_half2(" << e0 << ", "
<< e1 << ")}"; to instead emit os << "tl::pack_half2(" << e0 << ", " << e1 <<
")"; so the output uses the single uint1 returned by tl::pack_half2 (also check
similar pack emitters to match the __pack_half2 style).

---

Duplicate comments:
In `@src/op/reduce.cc`:
- Around line 174-175: The branch in src/op/reduce.cc treats isSum() and
isAbsSum() the same, causing packed abssum to compute a plain sum; update the
branch so when type->isAbsSum() you wrap b with the abs2 operation before adding
(e.g., replace the current return Call(acc.dtype(), tl::add2(), {acc, b}) with a
return that uses Call(acc.dtype(), tl::add2(), {acc, Call(acc.dtype(),
tl::abs2(), {b})}) when isAbsSum() so abssum computes sum(abs(x)).

In `@src/target/codegen_cuda.cc`:
- Around line 4375-4384: The code is taking the address of
PrintExpr(op->vectors[0]) which may be a temporary and therefore unsafe, and it
also silently treats any non-zero lane as .y; fix by materializing
op->vectors[0] into a local named temporary (use the same SSA/PrintExpr helper
used elsewhere, e.g., generate a tmp with SSAGetID(PrintExpr(...)) or assign
PrintExpr(...) to a local string temp variable) and then take the address of
that local when forming the half2/nv_bfloat162 access; additionally validate the
lane extracted from Downcast<IntImm>(op->indices[0])->value to only allow 0 or 1
and reject/out-of-range lanes (fail fast or emit a diagnostic), and preserve the
existing flags enable_bf16_ / enable_fp16_ in the bf16/float16 branches.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 00418980-17e6-4611-b343-942ca14bf682

📥 Commits

Reviewing files that changed from the base of the PR and between 883a75b and b68f250.

📒 Files selected for processing (4)
  • src/op/reduce.cc
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/common.h
  • src/tl_templates/cuda/reduce.h

Comment thread src/op/reduce.cc Outdated
Comment thread src/op/reduce.cc Outdated
Comment on lines +4362 to +4363
os << "uint1{tl::pack_half2(" << e0 << ", " << e1 << ")}";
}

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
rg -n "pack_half2|uint1\\{tl::pack_half2" src/target/codegen_cuda.cc src/tl_templates/cuda/common.h -C2

Repository: tile-ai/tilelang

Length of output: 2857


Remove the extra uint1{} around tl::pack_half2.

tl::pack_half2 at line 680 in src/tl_templates/cuda/common.h already returns uint1, so wrapping it again produces nested initialization uint1{uint1{...}} in the generated code. Other pack operations in the same file use __pack_half2 without extra wrapping.

Proposed fix
-      os << "uint1{tl::pack_half2(" << e0 << ", " << e1 << ")}";
+      os << "tl::pack_half2(" << e0 << ", " << e1 << ")";
📝 Committable suggestion


Suggested change
os << "uint1{tl::pack_half2(" << e0 << ", " << e1 << ")}";
}
os << "tl::pack_half2(" << e0 << ", " << e1 << ")";
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_cuda.cc` around lines 4362 - 4363, The generated code
wraps tl::pack_half2 in an extra uint1{...} causing nested uint1 initializers;
update the emitter in codegen_cuda.cc that currently does os <<
"uint1{tl::pack_half2(" << e0 << ", " << e1 << ")}"; to instead emit os <<
"tl::pack_half2(" << e0 << ", " << e1 << ")"; so the output uses the single
uint1 returned by tl::pack_half2 (also check similar pack emitters to match the
__pack_half2 style).

- Replace ReducePackConfig with simple GetPreferedVectorizedSize(DataType, Target) -> int
- Merge getBatchReducerSuffix into MakeCodegenReducer(vsize)
- MakeReduce(vsize, acc, b) -> optional<PrimExpr>
- MakeInitValue(vsize) handles Broadcast internally
- Guard nan_propagate in MakeCodegenReducer vsize>1 path (no packed nan ops exist)
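
The simplified query in the first bullet plausibly reduces to something like this; the signature comes from the commit message, while the body is an assumption based on the PR summary's pack factor of 2 for fp16/bf16 on CUDA:

#include <tvm/runtime/data_type.h>
#include <tvm/target/target.h>

int GetPreferedVectorizedSize(tvm::runtime::DataType dtype, tvm::Target target) {
  bool is_cuda = target->kind->name == "cuda";
  bool is_half = dtype.is_float16() || dtype.is_bfloat16();
  // Pack factor 2 for scalar fp16/bf16 on CUDA; otherwise stay scalar.
  return (is_cuda && is_half && dtype.lanes() == 1) ? 2 : 1;
}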
coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/op/reduce.h`:
- Around line 130-138: The header declares std::optional in MakeReduce and
MakeCodegenReducer but doesn’t include <optional>, relying on transitive
includes; add an explicit `#include` <optional> at the top of this header (before
the declarations of MakeInitValue, MakeReduce, and MakeCodegenReducer) so
MakeReduce and MakeCodegenReducer compile reliably without depending on other
headers.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b0dd0fb7-522b-4a7e-92b4-cf83fa26e1e5

📥 Commits

Reviewing files that changed from the base of the PR and between b68f250 and 224ba27.

📒 Files selected for processing (2)
  • src/op/reduce.cc
  • src/op/reduce.h
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/op/reduce.cc

Comment thread src/op/reduce.h Outdated
coderabbitai Bot left a comment

♻️ Duplicate comments (3)
src/op/reduce.cc (3)

624-629: ⚠️ Potential issue | 🔴 Critical

Batched packing still ignores nan_propagate.

Line 625 can enable packed batch AllReduce for fp16/bf16 max/min/absmax even though the packed local path already disables that case. From there, Lines 628-629 can synthesize tl::MaxOpNan_*x2 / tl::MinOpNan_*x2, but src/tl_templates/cuda/reduce.h:91-131 only defines packed SumOp_*x2, MaxOp_*x2, and MinOp_*x2.

🐛 Proposed fix
-        bool can_batch_pack = vsize > 1 && batch >= vsize && batch % vsize == 0;
+        bool can_batch_pack =
+            !nan_propagate && vsize > 1 && batch >= vsize &&
+            batch % vsize == 0;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/reduce.cc` around lines 624 - 629, The batched pack condition enables
packed reducers even when nan_propagate should disable them; update the logic
around GetPreferedVectorizedSize/GetPreferedVectorizedSize(...)->vsize and
can_batch_pack so that can_batch_pack is false when nan_propagate is set for
float16/bfloat16 and for max/min/absmax reducers (i.e., only allow vsize>1 when
nan propagation is not required), then compute eff_batch and call
MakeCodegenReducer with vsize only when can_batch_pack is true (otherwise pass
1) so MakeCodegenReducer does not synthesize tl::MaxOpNan_*x2 / tl::MinOpNan_*x2
for unsupported packed nan ops.

431-468: ⚠️ Potential issue | 🔴 Critical

Guard packed local lowering on matching source/accumulator dtypes.

The scalar path handles mixed precision by casting rhs, but Lines 466-468 force the source load to vec_dtype in place. If src_buffer->dtype != clear_buffer->dtype, this becomes a raw reinterpretation instead of a widening conversion.

🐛 Proposed fix
-      if (vsize > 1 && !src_var_compressed.empty() && !nan_propagate) {
+      if (vsize > 1 && !src_var_compressed.empty() && !nan_propagate &&
+          src_buffer->dtype == clear_buffer->dtype) {
         auto *ext = src_var_compressed.back()->dom->extent.as<IntImmNode>();
         if (ext && ext->value >= vsize && ext->value % vsize == 0) {
           can_pack = true;
           DataType vec_dtype = clear_buffer->dtype.with_lanes(vsize);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/reduce.cc` around lines 431 - 468, The code currently forces the
vectorized load by mutating src_writer->dtype to vec_dtype which reinterprets
bits when src_buffer->dtype != clear_buffer->dtype; instead, check if
src_buffer->dtype == clear_buffer->dtype before changing dtype, and otherwise
create an explicit widening Cast of the scalar load to vec_dtype (e.g., replace
mutating src_writer->dtype with a new expr like Cast(vec_dtype, src_load) or
produce a Cast of each lane) so mixed-precision is handled as a proper
conversion; update the use of src_load/src_writer to use the Cast when types
differ and only set src_writer->dtype when the underlying buffer dtypes already
match.

172-173: ⚠️ Potential issue | 🔴 Critical

Packed abssum still drops abs().

Line 172 folds isAbsSum() into the same packed reducer as isSum(), so any 2-lane packed abssum becomes a plain sum. The scalar path on Lines 144-145 already shows the intended semantics.

🐛 Proposed fix
-  if (type->isSum() || type->isAbsSum()) {
-    return Call(acc.dtype(), tl::add2(), {acc, b});
+  if (type->isSum()) {
+    return Call(acc.dtype(), tl::add2(), {acc, b});
+  } else if (type->isAbsSum()) {
+    return Call(acc.dtype(), tl::add2(),
+                {acc, Call(acc.dtype(), tl::abs2(), {b})});
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/reduce.cc` around lines 172 - 173, The packed reducer branch
incorrectly treats type->isAbsSum() like isSum(), causing packed abssum to drop
the abs; change the conditional so isAbsSum() is handled separately: when
type->isSum() keep using Call(acc.dtype(), tl::add2(), {acc, b}), but when
type->isAbsSum() compute the absolute of the incoming lane (e.g.
Call(acc.dtype(), tl::abs(), {b}) or equivalent) and then Call(acc.dtype(),
tl::add2(), {acc, abs_b}); update the branch that currently checks type->isSum()
|| type->isAbsSum() to distinguish the two cases using the unique symbols
type->isAbsSum(), acc, b, tl::add2(), and tl::abs() so packed abssum preserves
abs() semantics.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@src/op/reduce.cc`:
- Around line 624-629: The batched pack condition enables packed reducers even
when nan_propagate should disable them; update the logic around
GetPreferedVectorizedSize/GetPreferedVectorizedSize(...)->vsize and
can_batch_pack so that can_batch_pack is false when nan_propagate is set for
float16/bfloat16 and for max/min/absmax reducers (i.e., only allow vsize>1 when
nan propagation is not required), then compute eff_batch and call
MakeCodegenReducer with vsize only when can_batch_pack is true (otherwise pass
1) so MakeCodegenReducer does not synthesize tl::MaxOpNan_*x2 / tl::MinOpNan_*x2
for unsupported packed nan ops.
- Around line 431-468: The code currently forces the vectorized load by mutating
src_writer->dtype to vec_dtype which reinterprets bits when src_buffer->dtype !=
clear_buffer->dtype; instead, check if src_buffer->dtype == clear_buffer->dtype
before changing dtype, and otherwise create an explicit widening Cast of the
scalar load to vec_dtype (e.g., replace mutating src_writer->dtype with a new
expr like Cast(vec_dtype, src_load) or produce a Cast of each lane) so
mixed-precision is handled as a proper conversion; update the use of
src_load/src_writer to use the Cast when types differ and only set
src_writer->dtype when the underlying buffer dtypes already match.
- Around line 172-173: The packed reducer branch incorrectly treats
type->isAbsSum() like isSum(), causing packed abssum to drop the abs; change the
conditional so isAbsSum() is handled separately: when type->isSum() keep using
Call(acc.dtype(), tl::add2(), {acc, b}), but when type->isAbsSum() compute the
absolute of the incoming lane (e.g. Call(acc.dtype(), tl::abs(), {b}) or
equivalent) and then Call(acc.dtype(), tl::add2(), {acc, abs_b}); update the
branch that currently checks type->isSum() || type->isAbsSum() to distinguish
the two cases using the unique symbols type->isAbsSum(), acc, b, tl::add2(), and
tl::abs() so packed abssum preserves abs() semantics.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: dae5e142-9c61-4d47-ad56-0a561c0c5afd

📥 Commits

Reviewing files that changed from the base of the PR and between 224ba27 and ce2cb1c.

📒 Files selected for processing (1)
  • src/op/reduce.cc

kurisu6912 force-pushed the feature/packed-reduce-bf16-add2 branch from b081ac5 to b1f3298 on April 29, 2026 04:38
kurisu6912 (Collaborator, Author) commented:

@regression-perf

coderabbitai Bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (3)
src/target/codegen_cuda.cc (2)

4371-4371: ⚠️ Potential issue | 🟠 Major

Remove the extra uint1{} around tl::pack_half2.

tl::pack_half2 already returns uint1, so this emits a nested initializer (uint1{uint1{...}}) in generated CUDA.

Proposed fix
-      os << "uint1{tl::pack_half2(" << e0 << ", " << e1 << ")}";
+      os << "tl::pack_half2(" << e0 << ", " << e1 << ")";
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_cuda.cc` at line 4371, The generated CUDA currently emits
a nested initializer by wrapping tl::pack_half2(...) with an extra uint1{...};
remove that outer wrapper so the code emits tl::pack_half2(e0, e1) directly;
update the site that writes to the output stream (the os <<
"uint1{tl::pack_half2(" << e0 << ", " << e1 << ")}"; call in codegen_cuda.cc) to
instead emit just tl::pack_half2 with the same e0 and e1 expressions.

4384-4393: ⚠️ Potential issue | 🟠 Major

Materialize the packed value before extracting a lane.

This path takes &( PrintExpr(...) ), so any non-lvalue expression here generates an address-of-temporary in CUDA. It also needs the same 0/1 lane validation used by the other packed-lane extractors.

Proposed fix
     int lane = Downcast<IntImm>(op->indices[0])->value;
-    std::string vec = PrintExpr(op->vectors[0]);
+    ICHECK(lane == 0 || lane == 1)
+        << "ExtractElement on packed x2 expects lane 0 or 1, got " << lane;
+    std::string vec = SSAGetID(PrintExpr(op->vectors[0]), vec_t);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/target/codegen_cuda.cc` around lines 4384 - 4393, The code is taking the
address of PrintExpr(op->vectors[0]) which can produce an address-of-temporary;
refactor the branch handling vec_is_bf16x2 / else to first materialize the
packed vector expression into a temporary lvalue (same approach used by other
packed-lane extractors) and then take the address of that temp when building the
half2/nv_bfloat162 access; also add the same lane validation (ensure
op->indices[0] is 0 or 1) used by the other packed-lane extractors before
selecting "x" or "y", and ensure enable_bf16_/enable_fp16_ are still set
accordingly.
src/op/reduce.cc (1)

422-459: ⚠️ Potential issue | 🔴 Critical

Guard packed lowering on matching source and accumulator dtypes.

This path still force-retypes src_load to vec_dtype derived from clear_buffer->dtype. For mixed-precision reductions that becomes a raw vector reinterpretation of the source buffer instead of a conversion.

Proposed fix
-      if (vsize > 1 && !src_var_compressed.empty()) {
+      if (vsize > 1 && !src_var_compressed.empty() &&
+          src_buffer->dtype == clear_buffer->dtype) {
         auto *ext = src_var_compressed.back()->dom->extent.as<IntImmNode>();
         if (ext && ext->value >= vsize && ext->value % vsize == 0) {
           can_pack = true;
           DataType vec_dtype = clear_buffer->dtype.with_lanes(vsize);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/reduce.cc` around lines 422 - 459, The code currently forces
reinterpretation of the source load to vec_dtype (derived from
clear_buffer->dtype) which breaks mixed-precision reductions; restrict the
packed-lowering path so it only enables when the source element dtype matches
the accumulator/clear_buffer dtype. Concretely, augment the can_pack condition
(around GetPreferedVectorizedSize(...) and src_var_compressed usage) with a
dtype equality check (compare the source buffer element dtype — e.g., via
src_buffer->dtype or src_var_compressed.back()->dtype — against
clear_buffer->dtype) and only then create clear_buffer_packed, set
need_pack_buffer, and assign src_writer->dtype = vec_dtype; if dtypes differ,
skip the packed path (or perform an explicit conversion instead of retyping) so
src_load is not naively reinterpreted.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/op/reduce.cc`:
- Around line 132-158: The vsize==1 branch of ReduceOpNode::MakeReduce no longer
handles scalar bitwise reducers (bitand/bitor/bitxor), causing LOG(FATAL) when
scalar path calls MakeReduce(1); fix by adding cases for type->isBitAnd(),
type->isBitOr(), and type->isBitXor() in the vsize==1 block of
ReduceOpNode::MakeReduce: ensure rhs is cast to acc.dtype (as already done),
then return the appropriate bitwise expression (acc & rhs, acc | rhs, acc ^ rhs)
and only apply them for integer dtypes (or assert/handle otherwise), mirroring
how Sum/Max/Min are handled so scalar reductions succeed.

In `@testing/python/language/test_tilelang_language_reduce.py`:
- Around line 310-312: The section header string "nan_propagate tests – packed
(vsize=2) path for bf16/fp16" contains a Unicode en-dash (–) that triggers Ruff
RUF003; replace it with a normal ASCII hyphen (-) so the header reads
"nan_propagate tests - packed (vsize=2) path for bf16/fp16" to clear the lint
warning.
- Around line 319-329: The helper _make_nan_reduce_kernel should accept and
thread a batch parameter so the test can exercise the batched reducer path: add
a batch argument (default 1) to _make_nan_reduce_kernel, use T.Kernel(batch,
threads=threads) in the inner kernel, allocate dst and B with shape (M, batch)
(and dst fragment as (M, batch)), and pass batch=batch through to the reduce_fn
call (i.e., reduce_fn(src, dst, dim=1, nan_propagate=nan_propagate,
batch=batch)) so the batched reduction path is emitted.

---

Duplicate comments:
In `@src/op/reduce.cc`:
- Around line 422-459: The code currently forces reinterpretation of the source
load to vec_dtype (derived from clear_buffer->dtype) which breaks
mixed-precision reductions; restrict the packed-lowering path so it only enables
when the source element dtype matches the accumulator/clear_buffer dtype.
Concretely, augment the can_pack condition (around
GetPreferedVectorizedSize(...) and src_var_compressed usage) with a dtype
equality check (compare the source buffer element dtype — e.g., via
src_buffer->dtype or src_var_compressed.back()->dtype — against
clear_buffer->dtype) and only then create clear_buffer_packed, set
need_pack_buffer, and assign src_writer->dtype = vec_dtype; if dtypes differ,
skip the packed path (or perform an explicit conversion instead of retyping) so
src_load is not naively reinterpreted.

In `@src/target/codegen_cuda.cc`:
- Line 4371: The generated CUDA currently emits a nested initializer by wrapping
tl::pack_half2(...) with an extra uint1{...}; remove that outer wrapper so the
code emits tl::pack_half2(e0, e1) directly; update the site that writes to the
output stream (the os << "uint1{tl::pack_half2(" << e0 << ", " << e1 << ")}";
call in codegen_cuda.cc) to instead emit just tl::pack_half2 with the same e0
and e1 expressions.
- Around line 4384-4393: The code is taking the address of
PrintExpr(op->vectors[0]) which can produce an address-of-temporary; refactor
the branch handling vec_is_bf16x2 / else to first materialize the packed vector
expression into a temporary lvalue (same approach used by other packed-lane
extractors) and then take the address of that temp when building the
half2/nv_bfloat162 access; also add the same lane validation (ensure
op->indices[0] is 0 or 1) used by the other packed-lane extractors before
selecting "x" or "y", and ensure enable_bf16_/enable_fp16_ are still set
accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4db743dd-b757-4804-85ab-b89071be61c1

📥 Commits

Reviewing files that changed from the base of the PR and between ce2cb1c and b1f3298.

📒 Files selected for processing (7)
  • src/op/builtin.cc
  • src/op/builtin.h
  • src/op/reduce.cc
  • src/target/codegen_cuda.cc
  • src/tl_templates/cuda/common.h
  • src/tl_templates/cuda/reduce.h
  • testing/python/language/test_tilelang_language_reduce.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tl_templates/cuda/reduce.h

Comment thread src/op/reduce.cc Outdated
Comment on lines +132 to +158
std::optional<PrimExpr> ReduceOpNode::MakeReduce(int vsize, const PrimExpr &acc,
const PrimExpr &b) const {
if (vsize == 1) {

PrimExpr rhs = b;
if (acc->dtype != rhs->dtype) {
rhs = Cast(acc->dtype, rhs);
}
const bool use_nan_op = nan_propagate && (acc.dtype().is_float16() ||
acc.dtype().is_bfloat16());
if (type->isSum()) {
return acc + rhs;
} else if (type->isAbsSum()) {
return acc + Max(rhs, -rhs);
} else if (type->isMax()) {
return use_nan_op ? Call(acc.dtype(), tl::max_nan(), {acc, rhs})
: PrimExpr(Max(acc, rhs));
} else if (type->isMin()) {
return use_nan_op ? Call(acc.dtype(), tl::min_nan(), {acc, rhs})
: PrimExpr(Min(acc, rhs));
} else if (type->isAbsMax()) {
auto abs_rhs = Max(rhs, -rhs);
return use_nan_op ? Call(acc.dtype(), tl::max_nan(), {acc, abs_rhs})
: PrimExpr(Max(acc, abs_rhs));
}
LOG(FATAL) << "Unsupported reduce type: " << type->type;
return std::nullopt;

⚠️ Potential issue | 🔴 Critical

Restore scalar bitwise reducers in MakeReduce.

The new vsize == 1 branch no longer handles bitand/bitor/bitxor, but the scalar fragment path below still calls MakeReduce(1).value(). Those reductions now hit LOG(FATAL) during lowering.

Proposed fix
     if (type->isSum()) {
       return acc + rhs;
     } else if (type->isAbsSum()) {
       return acc + Max(rhs, -rhs);
+    } else if (type->isBitAnd()) {
+      return bitwise_and(acc, rhs);
+    } else if (type->isBitOr()) {
+      return bitwise_or(acc, rhs);
+    } else if (type->isBitXor()) {
+      return bitwise_xor(acc, rhs);
     } else if (type->isMax()) {
       return use_nan_op ? Call(acc.dtype(), tl::max_nan(), {acc, rhs})
                         : PrimExpr(Max(acc, rhs));
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/op/reduce.cc` around lines 132 - 158, The vsize==1 branch of
ReduceOpNode::MakeReduce no longer handles scalar bitwise reducers
(bitand/bitor/bitxor), causing LOG(FATAL) when scalar path calls MakeReduce(1);
fix by adding cases for type->isBitAnd(), type->isBitOr(), and type->isBitXor()
in the vsize==1 block of ReduceOpNode::MakeReduce: ensure rhs is cast to
acc.dtype (as already done), then return the appropriate bitwise expression (acc
& rhs, acc | rhs, acc ^ rhs) and only apply them for integer dtypes (or
assert/handle otherwise), mirroring how Sum/Max/Min are handled so scalar
reductions succeed.

Comment on lines +310 to +312
# ---------------------------------------------------------------------------
# nan_propagate tests – packed (vsize=2) path for bf16/fp16
# ---------------------------------------------------------------------------

⚠️ Potential issue | 🟡 Minor

Replace the Unicode dash in the section header.

Ruff flags the en dash here (RUF003). Swapping it for an ASCII hyphen (-) will clear the lint warning.

🧰 Tools
🪛 Ruff (0.15.12)

[warning] 311-311: Comment contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/language/test_tilelang_language_reduce.py` around lines 310 -
312, The section header string "nan_propagate tests – packed (vsize=2) path for
bf16/fp16" contains a Unicode en-dash (–) that triggers Ruff RUF003; replace it
with a normal ASCII hyphen (-) so the header reads "nan_propagate tests - packed
(vsize=2) path for bf16/fp16" to clear the lint warning.

Comment on lines +319 to +329
def _make_nan_reduce_kernel(reduce_fn, M, N, dtype, threads, *, nan_propagate):
@T.prim_func
def kernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M,), dtype)):
with T.Kernel(1, threads=threads):
src = T.alloc_fragment((M, N), dtype)
dst = T.alloc_fragment((M,), dtype)
T.copy(A, src)
reduce_fn(src, dst, dim=1, nan_propagate=nan_propagate)
T.copy(dst, B)

return kernel

⚠️ Potential issue | 🟠 Major

Thread batch through the nan-reduce test helper.

The helper always emits the scalar reduction call, so test_reduce_packed_max_nan_batch_runtime below never hits run_batch or the packed batched reducer path. A batched NaN regression would still pass this suite.

Proposed fix
-def _make_nan_reduce_kernel(reduce_fn, M, N, dtype, threads, *, nan_propagate):
+def _make_nan_reduce_kernel(
+    reduce_fn, M, N, dtype, threads, *, nan_propagate, batch=1
+):
     @T.prim_func
     def kernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M,), dtype)):
         with T.Kernel(1, threads=threads):
             src = T.alloc_fragment((M, N), dtype)
             dst = T.alloc_fragment((M,), dtype)
             T.copy(A, src)
-            reduce_fn(src, dst, dim=1, nan_propagate=nan_propagate)
+            kwargs = {"nan_propagate": nan_propagate}
+            if batch != 1:
+                kwargs["batch"] = batch
+            reduce_fn(src, dst, dim=1, **kwargs)
             T.copy(dst, B)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/language/test_tilelang_language_reduce.py` around lines 319 -
329, The helper _make_nan_reduce_kernel should accept and thread a batch
parameter so the test can exercise the batched reducer path: add a batch
argument (default 1) to _make_nan_reduce_kernel, use T.Kernel(batch,
threads=threads) in the inner kernel, allocate dst and B with shape (M, batch)
(and dst fragment as (M, batch)), and pass batch=batch through to the reduce_fn
call (i.e., reduce_fn(src, dst, dim=1, nan_propagate=nan_propagate,
batch=batch)) so the batched reduction path is emitted.

@github-actions

Performance Regression Test Report

Triggered by: @kurisu6912
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/25091296491

Results

File Original Latency Current Latency Speedup
example_group_per_split_token_cast_to_fp8 0.0103424 0.010519 0.983205
example_gemv 0.284854 0.288213 0.988346
topk_selector 0.053421 0.0538961 0.991184
example_fusedmoe_tilelang 0.132378 0.133105 0.994541
example_tilelang_sparse_gqa_decode_varlen_mask 0.0175243 0.017603 0.995528
block_sparse_attn_tilelang 0.00878019 0.00881936 0.995559
example_tilelang_nsa_fwd 0.00700546 0.00703471 0.995842
example_tilelang_nsa_decode 0.00681653 0.00684105 0.996415
example_linear_attn_fwd 0.0363162 0.0364388 0.996635
example_linear_attn_bwd 0.152772 0.153281 0.996678
fp8_lighting_indexer 0.0322893 0.0323952 0.996732
example_tilelang_sparse_gqa_decode_varlen_indice 0.0159181 0.0159673 0.99692
example_blocksparse_gemm 0.0190732 0.0191319 0.996935
sparse_mla_fwd 0.125487 0.125858 0.997053
example_dequant_gemv_fp16xint4 0.0282745 0.0283544 0.997184
example_gqa_sink_bwd_bhsd_sliding_window 0.0252039 0.0252692 0.997418
example_dequant_gemm_fp4_hopper 1.0277 1.03034 0.997434
example_per_token_cast_to_fp8 0.00735772 0.00737516 0.997635
example_gemm 0.0222669 0.0223197 0.997636
example_gemm_autotune 0.0224693 0.0225214 0.997687
sparse_mla_fwd_pipelined 0.0898107 0.0900185 0.997692
example_tilelang_block_sparse_attn 0.00862908 0.00864896 0.997701
example_gqa_bwd_tma_reduce_varlen 0.0462759 0.0463649 0.99808
example_mha_sink_fwd_bhsd 0.0151842 0.0152118 0.998182
example_tilelang_gemm_splitk_vectorize_atomicadd 1.01086 1.01263 0.998257
example_tilelang_gemm_splitk 1.00907 1.01078 0.998302
example_gqa_sink_bwd_bhsd 0.0427081 0.0427773 0.998383
sparse_mla_bwd 0.292978 0.293442 0.998417
example_mha_fwd_bshd 0.0247594 0.0247978 0.998454
example_mha_fwd_varlen 0.0443891 0.0444533 0.998555
example_warp_specialize_gemm_barrierpipe_stage2 0.040349 0.0403999 0.998741
example_mha_fwd_bhsd 0.0108109 0.010824 0.998796
example_dequant_gemm_bf16_mxfp4_hopper 0.513872 0.514348 0.999075
example_mha_sink_bwd_bhsd_sliding_window 0.0434821 0.0435212 0.999102
example_elementwise_add 0.115307 0.115409 0.999112
example_mha_bwd_bshd 0.0402553 0.0402899 0.999141
example_gemm_intrinsics 0.0348401 0.0348696 0.999152
example_dynamic 0.637518 0.63799 0.999261
example_gqa_fwd_bshd 0.0689708 0.0690167 0.999334
example_mha_bwd_bhsd 0.0392132 0.0392332 0.99949
example_convolution_autotune 0.979853 0.980344 0.9995
example_warp_specialize_gemm_softpipe_stage2 0.0275625 0.0275762 0.999503
example_tilelang_gemm_fp8_2xAcc 0.133124 0.133181 0.999573
example_gqa_bwd 0.0465154 0.0465314 0.999656
example_dequant_gemm_bf16_fp4_hopper 0.557054 0.557236 0.999675
example_tilelang_gemm_fp8 0.304967 0.305039 0.999765
example_topk 0.0111128 0.0111141 0.999887
example_mhc_post 0.109687 0.109697 0.999904
example_vertical_slash_sparse_attn 0.227567 0.227589 0.999904
example_mla_decode 0.45116 0.451179 0.999958
example_dequant_gemm_w4a8 5.57995 5.58016 0.999962
example_tilelang_gemm_fp8_intrinsic 0.842027 0.842051 0.999972
tilelang_example_sparse_tensorcore 0.0146469 0.0146472 0.999976
example_convolution 1.23641 1.23629 1.00009
example_mha_sink_bwd_bhsd 0.0645465 0.0645382 1.00013
example_warp_specialize_gemm_copy_1_gemm_0 0.0275796 0.0275755 1.00015
example_mha_sink_fwd_bhsd_sliding_window 0.0151302 0.0151273 1.00019
example_warp_specialize_gemm_copy_0_gemm_1 0.0373696 0.0373377 1.00085
example_mhc_pre 0.152453 0.152309 1.00095

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

…re/packed-reduce-bf16-add2

# Conflicts:
#	src/op/reduce.cc
#	src/op/reduce.h
#	src/runtime/logging.cc


Development

Successfully merging this pull request may close these issues.

[Feature Request] Vectorize float/bfloat16/float16 into float2/bfloat162/float162 in TileOPs
