
[Example] Add layernorm example in tilelang #2168

Open
ighoshsubho wants to merge 3 commits into tile-ai:main from ighoshsubho:examples/layernorm

Conversation


ighoshsubho commented May 8, 2026

Adds examples/norm/layernorm.py — a TileLang LayerNorm with forward and backward kernels, wrapped in a torch.autograd.Function so it works as a drop-in replacement for torch.nn.functional.layer_norm.

Forward saves (mean, rstd) per row. Backward fuses dx, dgamma, dbeta in one kernel: each CTA processes blk_m=32 rows, accumulates dgamma/dbeta partials in registers, and atomic-adds them to the global tensors at exit.
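
For context, the autograd plumbing follows the usual torch.autograd.Function pattern. This is a condensed sketch of the wrapper only; the _layernorm_fwd/_layernorm_bwd callables stand in for the JIT-compiled TileLang kernels in the file:

```python
import torch

class LayerNormFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        # Forward kernel returns y plus the per-row (mean, rstd) needed by backward.
        y, mean, rstd = _layernorm_fwd(x, gamma, beta, eps)
        ctx.save_for_backward(x, gamma, mean, rstd)
        return y

    @staticmethod
    def backward(ctx, dy):
        x, gamma, mean, rstd = ctx.saved_tensors
        # One fused kernel: per-row dx, plus dgamma/dbeta accumulated
        # across CTAs with atomic adds.
        dx, dgamma, dbeta = _layernorm_bwd(dy.contiguous(), x, gamma, mean, rstd)
        return dx, dgamma, dbeta, None  # no gradient for eps

def layer_norm(x, gamma, beta, eps=1e-5):
    return LayerNormFn.apply(x, gamma, beta, eps)
```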

On B200 / bf16 / M=4096, N=8192:

  • fwd: ~1.7× faster than torch
  • bwd: ~1.2× faster than torch

Summary by CodeRabbit

  • New Features
    • Added a layer normalization example demonstrating JIT-compiled forward/backward kernels with autograd integration.
    • Includes a convenience wrapper for easy use, a reference implementation for correctness checks, built-in forward/backward validation, and performance benchmarking tools.


github-actions Bot commented May 8, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai Bot commented May 8, 2026

Review Change Stack
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 82c78014-9512-4ea0-a243-5ad28f83352c

📥 Commits

Reviewing files that changed from the base of the PR and between 4dbc176 and d35728a.

📒 Files selected for processing (1)
  • examples/norm/layernorm.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/norm/layernorm.py

📝 Walkthrough

Adds a TileLang-backed LayerNorm: JIT forward and backward kernels, a PyTorch autograd Function bridging them, a public layer_norm wrapper and ref_program, plus an executable correctness and benchmark harness.

Changes

LayerNorm Implementation (all in examples/norm/layernorm.py):

  • Forward Kernel: TileLang kernel computes per-row mean and reciprocal standard deviation, applies normalization and affine transformation with gamma/beta, and saves intermediates for backward.
  • Backward Kernel: TileLang kernel uses the saved Mean and Rstd to compute input gradients DX and parameter gradients DGamma and DBeta via feature-dimension reductions and atomic updates.
  • PyTorch Autograd Integration: LayerNormFn bridges the kernels to PyTorch: maps dtypes, launches the forward kernel, saves tensors, handles backward contiguity and allocation, and converts gradient dtypes.
  • Public API & Reference: layer_norm(...) wrapper and ref_program(...) reference using torch.nn.functional.layer_norm for testing.
  • Testing & Benchmarking: Creates bfloat16 test inputs, validates forward/backward against the reference with torch.testing.assert_close, and benchmarks via tilelang.profiler.do_bench.
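
A minimal usage sketch based on this walkthrough (layer_norm and ref_program are the names summarized above; the import path, shapes, and tolerances are illustrative assumptions taken from the PR description and test harness):

```python
import torch
import torch.nn.functional as F

from layernorm import layer_norm  # assumes examples/norm is on PYTHONPATH

M, N, eps = 4096, 8192, 1e-5
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
g = torch.randn(N, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(N, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# TileLang-backed path vs. the torch reference (ref_program wraps F.layer_norm).
y = layer_norm(x, g, b, eps)
y_ref = F.layer_norm(x, (N,), g, b, eps)
torch.testing.assert_close(y, y_ref, rtol=1e-2, atol=1e-2)

# Backward: autograd routes through the fused TileLang backward kernel.
y.backward(torch.randn_like(y))
```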

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰
I nibble bytes at break of dawn,
Mean and Rstd, my breakfast drawn.
Forward hums, backward bends,
Gamma, beta — joyful blends.
TileLang carrots guide my prawn!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title '[Example] Add layernorm example in tilelang' clearly and directly describes the main change: adding a LayerNorm example implementation to the TileLang repository.
  • Linked Issues Check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check ✅ Passed: Check skipped because no linked issues were found for this pull request.


coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
examples/norm/layernorm.py (1)

180-187: ⚡ Quick win

Check dgamma and dbeta before printing success.

The script only asserts dx, so it can print “All checks pass.” even if two of the three backward outputs are wrong. Since the fused kernel’s value proposition includes dgamma and dbeta, they should be validated too.

✅ Proposed addition

```diff
     y_ref = ref_program(x, g, b, eps)
     y_ref.backward(dy)
     torch.testing.assert_close(dx_ours, x.grad, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dg_ours, g.grad, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(db_ours, b.grad, rtol=1e-2, atol=1e-2)
     print("All checks pass.")
```
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/norm/layernorm.py`:
- Around lines 76-123: The kernel launches ceildiv(N, blk_m) CTAs, but the inner
loop for k in T.serial(blk_m) always processes blk_m rows, causing out-of-bounds
accesses. Guard the per-row work by checking row < N after computing row = bx *
blk_m + k, and skip all reads and writes for that iteration if the check fails.
Concretely, wrap the body that reads DY, X, Mean, and Rstd, writes DX, and
updates dgamma_acc/dbeta_acc (the code inside the for k in T.serial(blk_m) loop
that references row, DY, X, Mean, Rstd, DX, and the DX_smem copy) in an
if row < N: block (or an equivalent early continue) so only valid rows are
processed.
- Around lines 137-157: The backward/forward functions rely on the
_TORCH_DTYPE_TO_TL lookup without validation and assume x, gamma, and beta share
the same dtype, which can lead to a KeyError and wrong casts. Before calling
_layernorm_fwd/_layernorm_bwd, validate that x.dtype is present in
_TORCH_DTYPE_TO_TL and that gamma.dtype and beta.dtype match x.dtype (raise a
clear TypeError listing the supported dtypes if not), use gamma.dtype when
allocating dgamma/dbeta (don't hardcode torch.float32), ensure any intermediate
casts use the common dtype, and return gradients in the correct order (dx,
dgamma, dbeta, None) with dgamma/dbeta converted to the original parameter
dtype.

---

Nitpick comments:
In `@examples/norm/layernorm.py`:
- Around lines 180-187: The test currently asserts only dx (dx_ours vs x.grad)
before printing success. Also compare dg_ours and db_ours against the reference
gradients with the same torch.testing.assert_close checks (dg_ours vs g.grad
and db_ours vs b.grad, with the same rtol=1e-2, atol=1e-2) after calling
y_ref.backward(dy), so that all three gradients (dx_ours, dg_ours, db_ours) are
validated before the "All checks pass." print.
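
For the dtype finding above, a minimal sketch of the requested validation. _TORCH_DTYPE_TO_TL is the mapping named in the comment; the helper name _check_dtypes and the exact TileLang dtype strings are assumptions, not the PR's actual code:

```python
import torch

# Mapping name from the review comment; the dtype strings are assumed to
# match what the PR's TileLang kernels accept.
_TORCH_DTYPE_TO_TL = {
    torch.float16: "float16",
    torch.bfloat16: "bfloat16",
    torch.float32: "float32",
}

def _check_dtypes(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> str:
    """Validate dtypes before launching the kernels (hypothetical helper)."""
    if x.dtype not in _TORCH_DTYPE_TO_TL:
        supported = ", ".join(str(d) for d in _TORCH_DTYPE_TO_TL)
        raise TypeError(
            f"layer_norm: unsupported dtype {x.dtype}; supported: {supported}")
    if gamma.dtype != x.dtype or beta.dtype != x.dtype:
        raise TypeError(
            f"layer_norm: gamma ({gamma.dtype}) and beta ({beta.dtype}) "
            f"must match x ({x.dtype})")
    return _TORCH_DTYPE_TO_TL[x.dtype]
```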

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 050281ed-bab4-4bf1-bb1e-3d4c6bbb066b

📥 Commits

Reviewing files that changed from the base of the PR and between 2878f75 and 4dbc176.

📒 Files selected for processing (1)
  • examples/norm/layernorm.py

