Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions examples/advanced/gemm_eltwise.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand All @@ -11,7 +11,7 @@
output = matmul(attn_out, wo) + hidden_states

Stage 0 (matmul: attn_out x wo) and Stage 1 (residual add) can be:
- Fused: single pl.at block with chunked_loop_optimizer (mix mode)
- Fused: single pl.at block with auto_chunk (mix mode)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment is being updated to use auto_chunk, but auto_chunk is deprecated and its usage is being removed from the code in this PR. Since the implementation has migrated to explicit chunk loops, the documentation should reflect this change to avoid confusion.

Suggested change
- Fused: single pl.at block with auto_chunk (mix mode)
- Fused: single pl.at block with explicit chunk loops (mix mode)

- Split: separate pl.at blocks for each stage (split mode)
Comment on lines +14 to 15
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update mix-mode docs to match explicit chunk loops.

Line 14 and Line 39 still mention auto_chunk, but this implementation now uses explicit outer pl.parallel + inner pl.range chunking.

✏️ Suggested doc fix
-  - Fused: single pl.at block with auto_chunk (mix mode)
+  - Fused: single pl.at block with explicit chunk loops (mix mode)
@@
-    """Build fused matmul + elementwise program with auto_chunk."""
+    """Build fused matmul + elementwise program with explicit chunk loops."""

Also applies to: 39-39

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/advanced/gemm_eltwise.py` around lines 14 - 15, Update the mix-mode
documentation that still references `auto_chunk` to describe the current
implementation’s explicit chunking: replace mentions of `auto_chunk` with a
description that the fused/mix mode uses an outer `pl.parallel(...)` to split
work into chunks and an inner `pl.range(...)` loop for per-chunk iteration
(i.e., explicit outer parallel + inner range chunking). Edit the comment blocks
in the example that describe "Fused: single pl.at block with auto_chunk (mix
mode)" so they instead mention the outer `pl.parallel` + inner `pl.range`
chunking pattern (also update the analogous phrasing later in the file where
`auto_chunk` appears).


Input and hidden_states are BF16; wo is BF16; output is FP32.
Expand All @@ -36,7 +36,7 @@
batch_tile: int = BATCH_TILE,
chunk: int = 4,
):
"""Build fused matmul + elementwise program with chunked_loop_optimizer."""
"""Build fused matmul + elementwise program with auto_chunk."""
k_blocks = hidden // k_chunk
n_blocks = hidden // n_chunk

Expand All @@ -50,26 +50,27 @@
wo: pl.Tensor[[hidden, hidden], pl.BF16],
resid: pl.Out[pl.Tensor[[batch, hidden], pl.FP32]],
) -> pl.Tensor[[batch, hidden], pl.FP32]:
with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)]):
for nb in pl.parallel(0, n_blocks, chunk=chunk):
n0 = nb * n_chunk
# First K-tile: initialize accumulator via matmul
a_chunk_0 = pl.slice(attn_out, [batch_tile, k_chunk], [0, 0])
w_chunk_0 = pl.slice(wo, [k_chunk, n_chunk], [0, n0])
acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32)

# Remaining K-tiles: accumulate via matmul_acc
for kb in pl.range(1, k_blocks):
k0 = kb * k_chunk
a_chunk = pl.slice(attn_out, [batch_tile, k_chunk], [0, k0])
w_chunk = pl.slice(wo, [k_chunk, n_chunk], [k0, n0])
acc = pl.matmul_acc(acc, a_chunk, w_chunk)

# Elementwise residual addition
hidden_chunk = pl.slice(hidden_states, [batch_tile, n_chunk], [0, n0])
hidden_chunk_f32 = pl.cast(hidden_chunk, target_type=pl.FP32)
resid_sum = pl.add(acc, hidden_chunk_f32)
resid = pl.assemble(resid, resid_sum, [0, n0])
for nb_chunk in pl.parallel(0, n_blocks, 1 * chunk):
with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.split(pl.SplitMode.UP_DOWN)]):
for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The inner loop pl.range(nb_chunk, nb_chunk + 1 * chunk, 1) does not account for cases where n_blocks is not a multiple of chunk. In the final iteration of the outer pl.parallel loop, nb will exceed n_blocks, leading to out-of-bounds memory accesses when slicing tensors (e.g., wo at line 59). Use pl.min to clamp the stop condition.

Suggested change
for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1):
for nb in pl.range(nb_chunk, pl.min(nb_chunk + 1 * chunk, n_blocks), 1):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The inner loop upper bound nb_chunk + 1 * chunk does not account for cases where n_blocks is not a multiple of chunk. This can lead to out-of-bounds access in the last chunk. It is safer to use pl.min to clamp the bound.

Suggested change
for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1):
for nb in pl.range(nb_chunk, pl.min(n_blocks, nb_chunk + 1 * chunk), 1):

n0 = nb * n_chunk
# First K-tile: initialize accumulator via matmul
a_chunk_0 = pl.slice(attn_out, [batch_tile, k_chunk], [0, 0])
w_chunk_0 = pl.slice(wo, [k_chunk, n_chunk], [0, n0])
acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32)

# Remaining K-tiles: accumulate via matmul_acc
for kb in pl.range(1, k_blocks):
k0 = kb * k_chunk
a_chunk = pl.slice(attn_out, [batch_tile, k_chunk], [0, k0])
w_chunk = pl.slice(wo, [k_chunk, n_chunk], [k0, n0])
acc = pl.matmul_acc(acc, a_chunk, w_chunk)

# Elementwise residual addition
hidden_chunk = pl.slice(hidden_states, [batch_tile, n_chunk], [0, n0])
hidden_chunk_f32 = pl.cast(hidden_chunk, target_type=pl.FP32)
resid_sum = pl.add(acc, hidden_chunk_f32)
resid = pl.assemble(resid, resid_sum, [0, n0])

return resid

Expand Down
11 changes: 6 additions & 5 deletions examples/beginner/hello_world.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -38,11 +38,12 @@
a: pl.Scalar[pl.FP32],
y: pl.Out[pl.Tensor[[rows, cols], pl.FP32]],
) -> pl.Tensor[[rows, cols], pl.FP32]:
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for r in pl.parallel(0, rows, 1, chunk=row_chunk):
tile_x = pl.slice(x, [1, cols], [r, 0])
tile_y = pl.add(tile_x, a)
y = pl.assemble(y, tile_y, [r, 0])
for r_chunk in pl.parallel(0, rows, 1 * row_chunk):
with pl.at(level=pl.Level.CORE_GROUP):
for r in pl.range(r_chunk, r_chunk + 1 * row_chunk, 1):
tile_x = pl.slice(x, [1, cols], [r, 0])
tile_y = pl.add(tile_x, a)
y = pl.assemble(y, tile_y, [r, 0])

return y

Expand Down
16 changes: 9 additions & 7 deletions examples/beginner/matmul.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -46,13 +46,15 @@
b: pl.Tensor[[k, n], pl.FP32],
c: pl.Out[pl.Tensor[[m, n], pl.FP32]],
) -> pl.Tensor[[m, n], pl.FP32]:
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for mb in pl.parallel(0, m, m_tile, chunk=m_chunk):
for nb in pl.parallel(0, n, n_tile, chunk=n_chunk):
tile_a = pl.slice(a, [m_tile, k], [mb, 0])
tile_b = pl.slice(b, [k, n_tile], [0, nb])
tile_c = pl.matmul(tile_a, tile_b)
c = pl.assemble(c, tile_c, [mb, nb])
for mb_chunk in pl.parallel(0, m, m_tile * m_chunk):
for nb_chunk in pl.parallel(0, n, n_tile * n_chunk):
with pl.at(level=pl.Level.CORE_GROUP):
for mb in pl.range(mb_chunk, mb_chunk + m_tile * m_chunk, m_tile):
for nb in pl.range(nb_chunk, nb_chunk + n_tile * n_chunk, n_tile):
Comment on lines +52 to +53
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

These inner loops can exceed the dimensions m and n if they are not multiples of the total chunk size (m_tile * m_chunk and n_tile * n_chunk). This will result in out-of-bounds slicing of tensors a, b, and c. You should use pl.min to ensure the loop indices stay within the valid range.

Suggested change
for mb in pl.range(mb_chunk, mb_chunk + m_tile * m_chunk, m_tile):
for nb in pl.range(nb_chunk, nb_chunk + n_tile * n_chunk, n_tile):
for mb in pl.range(mb_chunk, pl.min(mb_chunk + m_tile * m_chunk, m), m_tile):
for nb in pl.range(nb_chunk, pl.min(nb_chunk + n_tile * n_chunk, n), n_tile):

tile_a = pl.slice(a, [m_tile, k], [mb, 0])
tile_b = pl.slice(b, [k, n_tile], [0, nb])
tile_c = pl.matmul(tile_a, tile_b)
c = pl.assemble(c, tile_c, [mb, nb])

return c

Expand Down
34 changes: 18 additions & 16 deletions examples/intermediate/gemm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -51,22 +51,24 @@
b: pl.Tensor[[k, n], pl.FP32],
c: pl.Out[pl.Tensor[[m, n], pl.FP32]],
) -> pl.Tensor[[m, n], pl.FP32]:
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for mb in pl.parallel(0, m, m_tile, chunk=m_chunk):
for nb in pl.parallel(0, n, n_tile, chunk=n_chunk):
# First K-tile: initialize accumulator via matmul
tile_a = pl.slice(a, [m_tile, k_tile], [mb, 0])
tile_b = pl.slice(b, [k_tile, n_tile], [0, nb])
acc = pl.matmul(tile_a, tile_b)

# Remaining K-tiles: accumulate via matmul_acc
for kb in pl.range(1, k_blocks):
k0 = kb * k_tile
tile_a_i = pl.slice(a, [m_tile, k_tile], [mb, k0])
tile_b_i = pl.slice(b, [k_tile, n_tile], [k0, nb])
acc = pl.matmul_acc(acc, tile_a_i, tile_b_i)

c = pl.assemble(c, acc, [mb, nb])
for mb_chunk in pl.parallel(0, m, m_tile * m_chunk):
for nb_chunk in pl.parallel(0, n, n_tile * n_chunk):
with pl.at(level=pl.Level.CORE_GROUP):
for mb in pl.range(mb_chunk, mb_chunk + m_tile * m_chunk, m_tile):
for nb in pl.range(nb_chunk, nb_chunk + n_tile * n_chunk, n_tile):
# First K-tile: initialize accumulator via matmul
tile_a = pl.slice(a, [m_tile, k_tile], [mb, 0])
tile_b = pl.slice(b, [k_tile, n_tile], [0, nb])
acc = pl.matmul(tile_a, tile_b)

# Remaining K-tiles: accumulate via matmul_acc
for kb in pl.range(1, k_blocks):
k0 = kb * k_tile
tile_a_i = pl.slice(a, [m_tile, k_tile], [mb, k0])
tile_b_i = pl.slice(b, [k_tile, n_tile], [k0, nb])
acc = pl.matmul_acc(acc, tile_a_i, tile_b_i)

c = pl.assemble(c, acc, [mb, nb])

return c

Expand Down
55 changes: 28 additions & 27 deletions examples/intermediate/layer_norm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -42,33 +42,34 @@
beta: pl.Tensor[[1, hidden], pl.FP32],
y: pl.Out[pl.Tensor[[rows, hidden], pl.FP32]],
) -> pl.Tensor[[rows, hidden], pl.FP32]:
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for r in pl.parallel(0, rows, row_chunk, chunk=1):
tile_x = pl.slice(x, [row_chunk, hidden], [r, 0])
gamma_tile = pl.slice(gamma, [1, hidden], [0, 0])
beta_tile = pl.slice(beta, [1, hidden], [0, 0])

# Step 1: row mean — pre-scale before row_sum, no reshape
mean = pl.row_sum(pl.mul(tile_x, hidden_inv))

# Step 2: row variance + eps — pre-scale and pre-add
centred = pl.row_expand_sub(tile_x, mean)
var_eps = pl.row_sum(
pl.mul(pl.add(pl.mul(centred, centred), eps), hidden_inv)
)

# Step 3: normalise — single reshape pair for sqrt
std = pl.reshape(
pl.sqrt(pl.reshape(var_eps, [1, row_chunk])),
[row_chunk, 1],
)
normed = pl.row_expand_div(centred, std)

# Step 4: apply gamma scale and beta offset
scaled = pl.col_expand_mul(normed, gamma_tile)
ones = pl.add(pl.sub(tile_x, tile_x), 1.0)
result = pl.add(scaled, pl.col_expand_mul(ones, beta_tile))
y = pl.assemble(y, result, [r, 0])
for r_chunk in pl.parallel(0, rows, row_chunk * 1):
with pl.at(level=pl.Level.CORE_GROUP):
for r in pl.range(r_chunk, r_chunk + row_chunk * 1, row_chunk):
tile_x = pl.slice(x, [row_chunk, hidden], [r, 0])
gamma_tile = pl.slice(gamma, [1, hidden], [0, 0])
beta_tile = pl.slice(beta, [1, hidden], [0, 0])

# Step 1: row mean — pre-scale before row_sum, no reshape
mean = pl.row_sum(pl.mul(tile_x, hidden_inv))

# Step 2: row variance + eps — pre-scale and pre-add
centred = pl.row_expand_sub(tile_x, mean)
var_eps = pl.row_sum(
pl.mul(pl.add(pl.mul(centred, centred), eps), hidden_inv)
)

# Step 3: normalise — single reshape pair for sqrt
std = pl.reshape(
pl.sqrt(pl.reshape(var_eps, [1, row_chunk])),
[row_chunk, 1],
)
normed = pl.row_expand_div(centred, std)

# Step 4: apply gamma scale and beta offset
scaled = pl.col_expand_mul(normed, gamma_tile)
ones = pl.add(pl.sub(tile_x, tile_x), 1.0)
result = pl.add(scaled, pl.col_expand_mul(ones, beta_tile))
y = pl.assemble(y, result, [r, 0])

return y

Expand Down
53 changes: 27 additions & 26 deletions examples/intermediate/rms_norm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -48,32 +48,33 @@
gamma: pl.Tensor[[1, hidden], pl.FP32],
y: pl.Out[pl.Tensor[[rows, hidden], pl.FP32]],
) -> pl.Tensor[[rows, hidden], pl.FP32]:
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for r in pl.parallel(0, rows, row_chunk, chunk=1):
# Pass 1: accumulate sum(x^2) across hidden chunks
# row_sum produces [row_chunk, 1] col_major; scalar ops
# need row_major, so accumulate in [1, row_chunk] shape.
sq_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
sq_sum = pl.mul(sq_sum, 0.0)
for hb in pl.range(hidden_blocks):
h0 = hb * hidden_chunk
x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0])
rs = pl.row_sum(pl.mul(x_chunk, x_chunk))
sq_sum = pl.add(sq_sum, pl.reshape(rs, [1, row_chunk]))

# inv_rms = 1 / sqrt(mean(x^2) + eps)
inv_rms_T = pl.rsqrt(pl.add(pl.mul(sq_sum, hidden_inv), eps))
inv_rms = pl.reshape(inv_rms_T, [row_chunk, 1])

# Pass 2: normalise and apply gamma weight
for hb in pl.range(hidden_blocks):
h0 = hb * hidden_chunk
x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0])
gamma_chunk = pl.slice(gamma, [1, hidden_chunk], [0, h0])
normed = pl.col_expand_mul(
pl.row_expand_mul(x_chunk, inv_rms), gamma_chunk
)
y = pl.assemble(y, normed, [r, h0])
for r_chunk in pl.parallel(0, rows, row_chunk * 1):
with pl.at(level=pl.Level.CORE_GROUP):
for r in pl.range(r_chunk, r_chunk + row_chunk * 1, row_chunk):
# Pass 1: accumulate sum(x^2) across hidden chunks
# row_sum produces [row_chunk, 1] col_major; scalar ops
# need row_major, so accumulate in [1, row_chunk] shape.
sq_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32)
sq_sum = pl.mul(sq_sum, 0.0)
for hb in pl.range(hidden_blocks):
h0 = hb * hidden_chunk
x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0])
rs = pl.row_sum(pl.mul(x_chunk, x_chunk))
sq_sum = pl.add(sq_sum, pl.reshape(rs, [1, row_chunk]))

# inv_rms = 1 / sqrt(mean(x^2) + eps)
inv_rms_T = pl.rsqrt(pl.add(pl.mul(sq_sum, hidden_inv), eps))
inv_rms = pl.reshape(inv_rms_T, [row_chunk, 1])

# Pass 2: normalise and apply gamma weight
for hb in pl.range(hidden_blocks):
h0 = hb * hidden_chunk
x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0])
gamma_chunk = pl.slice(gamma, [1, hidden_chunk], [0, h0])
normed = pl.col_expand_mul(
pl.row_expand_mul(x_chunk, inv_rms), gamma_chunk
)
y = pl.assemble(y, normed, [r, h0])

return y

Expand Down
47 changes: 24 additions & 23 deletions examples/intermediate/rope.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -57,29 +57,30 @@
sin: pl.Tensor[[1, head_dim], pl.FP32],
y: pl.Out[pl.Tensor[[total_rows, head_dim], pl.FP32]],
) -> pl.Tensor[[total_rows, head_dim], pl.FP32]:
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for b in pl.parallel(0, batch, 1, chunk=batch_chunk):
# Slice cos/sin lo/hi halves directly from tensor
# so each becomes a separate tile.load (no textract).
cos_lo = pl.slice(cos, [1, half_dim], [0, 0])
cos_hi = pl.slice(cos, [1, half_dim], [0, half_dim])
sin_lo = pl.slice(sin, [1, half_dim], [0, 0])
sin_hi = pl.slice(sin, [1, half_dim], [0, half_dim])

base = b * num_heads
x_lo = pl.slice(x, [num_heads, half_dim], [base, 0])
x_hi = pl.slice(x, [num_heads, half_dim], [base, half_dim])

rot_lo = pl.sub(
pl.col_expand_mul(x_lo, cos_lo),
pl.col_expand_mul(x_hi, sin_lo),
)
rot_hi = pl.add(
pl.col_expand_mul(x_hi, cos_hi),
pl.col_expand_mul(x_lo, sin_hi),
)
y = pl.assemble(y, rot_lo, [base, 0])
y = pl.assemble(y, rot_hi, [base, half_dim])
for b_chunk in pl.parallel(0, batch, 1 * batch_chunk):
with pl.at(level=pl.Level.CORE_GROUP):
for b in pl.range(b_chunk, b_chunk + 1 * batch_chunk, 1):
# Slice cos/sin lo/hi halves directly from tensor
# so each becomes a separate tile.load (no textract).
cos_lo = pl.slice(cos, [1, half_dim], [0, 0])
cos_hi = pl.slice(cos, [1, half_dim], [0, half_dim])
sin_lo = pl.slice(sin, [1, half_dim], [0, 0])
sin_hi = pl.slice(sin, [1, half_dim], [0, half_dim])

base = b * num_heads
x_lo = pl.slice(x, [num_heads, half_dim], [base, 0])
x_hi = pl.slice(x, [num_heads, half_dim], [base, half_dim])

rot_lo = pl.sub(
pl.col_expand_mul(x_lo, cos_lo),
pl.col_expand_mul(x_hi, sin_lo),
)
rot_hi = pl.add(
pl.col_expand_mul(x_hi, cos_hi),
pl.col_expand_mul(x_lo, sin_hi),
)
y = pl.assemble(y, rot_lo, [base, 0])
y = pl.assemble(y, rot_hi, [base, half_dim])

return y

Expand Down
29 changes: 15 additions & 14 deletions examples/intermediate/softmax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -36,26 +36,27 @@
x: pl.Tensor[[rows, cols], pl.FP32],
y: pl.Out[pl.Tensor[[rows, cols], pl.FP32]],
) -> pl.Tensor[[rows, cols], pl.FP32]:
with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer):
for r in pl.parallel(0, rows, row_chunk, chunk=1):
tile_x = pl.slice(x, [row_chunk, cols], [r, 0])
for r_chunk in pl.parallel(0, rows, row_chunk * 1):
with pl.at(level=pl.Level.CORE_GROUP):
for r in pl.range(r_chunk, r_chunk + row_chunk * 1, row_chunk):
tile_x = pl.slice(x, [row_chunk, cols], [r, 0])

# Step 1: row-wise max for numerical stability
row_max = pl.row_max(tile_x)
# Step 1: row-wise max for numerical stability
row_max = pl.row_max(tile_x)

# Step 2: subtract row max: x - max(x)
shifted = pl.row_expand_sub(tile_x, row_max)
# Step 2: subtract row max: x - max(x)
shifted = pl.row_expand_sub(tile_x, row_max)

# Step 3: exp(x - max(x))
exp_shifted = pl.exp(shifted)
# Step 3: exp(x - max(x))
exp_shifted = pl.exp(shifted)

# Step 4: row-wise sum of exp values
row_sum = pl.row_sum(exp_shifted)
# Step 4: row-wise sum of exp values
row_sum = pl.row_sum(exp_shifted)

# Step 5: divide each row by its sum
result = pl.row_expand_div(exp_shifted, row_sum)
# Step 5: divide each row by its sum
result = pl.row_expand_div(exp_shifted, row_sum)

y = pl.assemble(y, result, [r, 0])
y = pl.assemble(y, result, [r, 0])

return y

Expand Down
Loading
Loading