[Example] Add layernorm example in tilelang#2168
ighoshsubho wants to merge 3 commits into tile-ai:main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
Adds a TileLang-backed LayerNorm: JIT forward and backward kernels, a PyTorch autograd Function bridging them, and a public entry point.
Changes: LayerNorm implementation (examples/norm/layernorm.py)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Actionable comments posted: 2
🧹 Nitpick comments (1)
examples/norm/layernorm.py (1)
180-187: ⚡ Quick win — check `dgamma` and `dbeta` before printing success.

The script only asserts `dx`, so it can print "All checks pass." even if two of the three backward outputs are wrong. Since the fused kernel's value proposition includes `dgamma` and `dbeta`, they should be validated too.

✅ Proposed addition:

```diff
 y_ref = ref_program(x, g, b, eps)
 y_ref.backward(dy)
 torch.testing.assert_close(dx_ours, x.grad, rtol=1e-2, atol=1e-2)
+torch.testing.assert_close(dg_ours, g.grad, rtol=1e-2, atol=1e-2)
+torch.testing.assert_close(db_ours, b.grad, rtol=1e-2, atol=1e-2)
 print("All checks pass.")
```
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/norm/layernorm.py`:
- Around line 76-123: The kernel launched ceildiv(N, blk_m) CTAs but the inner
loop for k in T.serial(blk_m) always processes blk_m rows, causing out-of-bounds
accesses; guard the per-row work by checking row < N after computing row = bx *
blk_m + k and skip all reads/writes for that iteration if false. Concretely,
wrap the body that reads DY, X, Mean, Rstd, writes DX, and updates
dgamma_acc/dbeta_acc (the code inside the for k in T.serial(blk_m) loop that
references row, DY, X, Mean, Rstd, DX, DX_smem, DX_smem copy) with an if row <
N: ... (or equivalent early continue) so only valid rows are processed.
- Around line 137-157: The backward/forward functions rely on _TORCH_DTYPE_TO_TL
lookup without validation and assume x, gamma, beta share the same dtype which
leads to KeyError and wrong casts; before calling _layernorm_fwd/_layernorm_bwd
validate that x.dtype is present in _TORCH_DTYPE_TO_TL and that gamma.dtype and
beta.dtype match x.dtype (raise a clear TypeError with supported dtype list if
not), use gamma.dtype when allocating dgamma/dbeta (don't hardcode
torch.float32), ensure any intermediate casts use the common dtype, and return
gradients in the correct order (dx, dgamma, dbeta, None) with dgamma/dbeta
converted to the original parameter dtype.
---
Nitpick comments:
In `@examples/norm/layernorm.py`:
- Around line 180-187: The test currently only asserts dx (dx_ours vs x.grad)
before printing success; also compare dg_ours and db_ours to the reference
gradients by running the same torch.testing.assert_close checks for dg_ours vs
g.grad and db_ours vs b.grad (use the same rtol=1e-2, atol=1e-2) after calling
y_ref.backward(dy) so all three gradients (dx_ours, dg_ours, db_ours) are
validated before the "All checks pass." print; locate the comparison area around
dx_ours, dg_ours, db_ours, ref_program, y_ref.backward and add the two
additional assert_close calls.
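For reference, here is a minimal host-side Python model of the indexing fix the first inline comment asks for: the grid launches ceildiv(N, blk_m) CTAs, so the last CTA's iterations past row N-1 must be skipped. The loop structure and names (bx, blk_m, row) are taken from the review text; this is an illustrative sketch, not the PR's actual kernel code.

```python
import math

N, blk_m = 100, 32                 # N deliberately not a multiple of blk_m
num_ctas = math.ceil(N / blk_m)    # ceildiv(N, blk_m) CTAs -> 4
for bx in range(num_ctas):         # one iteration per CTA
    for k in range(blk_m):         # mirrors `for k in T.serial(blk_m)`
        row = bx * blk_m + k
        if row >= N:               # the missing guard: the last CTA covers
            continue               # rows 96..127, but only 96..99 exist
        # per-row body: reads of DY/X/Mean/Rstd, the write of DX, and the
        # dgamma_acc/dbeta_acc register updates all belong inside this guard
```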
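And a hedged sketch of the dtype validation the second inline comment describes. The contents of the _TORCH_DTYPE_TO_TL mapping and the helper name _check_dtypes are assumptions for illustration, not the PR's code:

```python
import torch

# Assumed mapping; the PR's actual _TORCH_DTYPE_TO_TL may list different dtypes.
_TORCH_DTYPE_TO_TL = {
    torch.float16: "float16",
    torch.bfloat16: "bfloat16",
    torch.float32: "float32",
}

def _check_dtypes(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> None:
    """Validate dtypes before dispatching to _layernorm_fwd/_layernorm_bwd."""
    if x.dtype not in _TORCH_DTYPE_TO_TL:
        supported = ", ".join(str(k) for k in _TORCH_DTYPE_TO_TL)
        raise TypeError(f"layernorm: unsupported dtype {x.dtype}; supported: {supported}")
    if gamma.dtype != x.dtype or beta.dtype != x.dtype:
        raise TypeError(
            f"layernorm: gamma ({gamma.dtype}) and beta ({beta.dtype}) "
            f"must match x ({x.dtype})")
```

Per the comment, backward would then allocate dgamma/dbeta with gamma.dtype instead of hardcoding torch.float32, and return the gradients as (dx, dgamma, dbeta, None).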
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 050281ed-bab4-4bf1-bb1e-3d4c6bbb066b
📒 Files selected for processing (1)
examples/norm/layernorm.py
Adds examples/norm/layernorm.py — a TileLang LayerNorm with forward and backward kernels, wrapped in a torch.autograd.Function so it works as a drop-in replacement for torch.nn.functional.layer_norm.

Forward saves (mean, rstd) per row. Backward fuses dx, dgamma, and dbeta in one kernel: each CTA processes blk_m=32 rows, accumulates dgamma/dbeta partials in registers, and atomic-adds them to the global tensors at exit.

On B200 / bf16 / M=4096, N=8192:
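For readers unfamiliar with the pattern, a minimal sketch of such an autograd bridge follows. The kernel entry points _layernorm_fwd/_layernorm_bwd are named after the review comments above, and their signatures here are assumptions rather than the PR's actual interface:

```python
import torch

class _LayerNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        # assumed signature: returns the output plus the per-row (mean, rstd)
        y, mean, rstd = _layernorm_fwd(x, gamma, beta, eps)
        ctx.save_for_backward(x, gamma, mean, rstd)
        return y

    @staticmethod
    def backward(ctx, dy):
        x, gamma, mean, rstd = ctx.saved_tensors
        # assumed signature: one fused kernel producing all three gradients
        dx, dgamma, dbeta = _layernorm_bwd(dy, x, gamma, mean, rstd)
        return dx, dgamma, dbeta, None   # None matches the non-tensor eps arg

def layernorm(x, gamma, beta, eps=1e-5):
    """Drop-in for torch.nn.functional.layer_norm over the last dimension."""
    return _LayerNorm.apply(x, gamma, beta, eps)
```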