Commit d3ba035

Merge branch 'main' of github.com:NVIDIA/TileGym into tilegym_ci_init
2 parents 63c6639 + 42538cd commit d3ba035

2 files changed

Lines changed: 7 additions & 4 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ TileGym is a CUDA Tile kernel library that provides a rich collection of kernel
 ## Overview
 
 This repository aims to provide helpful kernel tutorials and examples for tile-based GPU programming. TileGym is a playground for experimenting with CUDA Tile, where you can learn how to build efficient GPU kernels and explore their integration into real-world large language models such as Llama 3.1 and DeepSeek V2. Whether you're learning tile-based GPU programming or looking to optimize your LLM implementations, TileGym offers practical examples and comprehensive guidance.
+<img width="90%" alt="TileGym_repo" src="https://github.com/user-attachments/assets/1d8741f0-f15c-49ff-ad5c-32d1ae6ec71e" />
 
 ## Features
 

src/tilegym/ops/cutile/silu_and_mul.py

Lines changed: 6 additions & 4 deletions
@@ -36,7 +36,6 @@ def silu_and_mul_kernel_row_wise(
     input,
     output,
     TILE_SIZE: ConstInt,
-    n_elements: ConstInt,
     hidden_size: ConstInt,
 ):
     bid = ct.bid(0)  # this gives us our row
@@ -47,7 +46,6 @@ def silu_and_mul_kernel_row_wise(
     row_idx = bid
     a_col_idx = offsets  # First half: [0, hidden_size)
     b_col_idx = offsets + hidden_size  # Second half: [hidden_size, 2*hidden_size)
-    out_offsets = bid * hidden_size + offsets
 
     # Load tiles using gather with 2D indices
     # gather broadcasts (scalar, tile) to (tile,)
@@ -95,7 +93,6 @@ def silu_and_mul(
     # Flatten input to 2D: (batch_size, 2 * hidden_size)
     input_flat = input.view(-1, original_shape[-1])
     batch_size = input_flat.shape[0]
-    n_elements = batch_size * hidden_size  # Total elements to process in output
 
     # Get final output shape
     output_shape = list(original_shape)
@@ -123,6 +120,11 @@ def silu_and_mul(
         torch.cuda.current_stream(),
         grid,
         silu_and_mul_kernel_row_wise,
-        (input_flat, output, TILE_SIZE, n_elements, hidden_size),
+        (
+            input_flat,
+            output,
+            TILE_SIZE,
+            hidden_size
+        ),
     )
     return output.reshape(*output_shape)
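
For readers checking this change, the row-wise split described in the kernel comments can be sketched in plain Python. This is a hypothetical reference, not code from TileGym: the name `silu_and_mul_reference` is made up for illustration, and it assumes the usual convention that the first half of each row goes through SiLU and is multiplied elementwise by the second half (the diff shows the index layout but not the arithmetic).

```python
import math


def silu(x: float) -> float:
    """SiLU activation, x * sigmoid(x), written to avoid overflow in exp()."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def silu_and_mul_reference(rows, hidden_size):
    """Hypothetical reference for a row-wise SiLU-and-multiply.

    rows: list of rows, each holding 2 * hidden_size floats
          (the flattened (batch_size, 2 * hidden_size) layout from the diff).
    Returns a list of rows, each holding hidden_size floats.
    Assumption: output = silu(first half) * second half.
    """
    out = []
    for row in rows:  # one row per block id in the kernel
        if len(row) != 2 * hidden_size:
            raise ValueError("each row must hold 2 * hidden_size values")
        a = row[:hidden_size]   # first half: [0, hidden_size)
        b = row[hidden_size:]   # second half: [hidden_size, 2 * hidden_size)
        out.append([silu(x) * y for x, y in zip(a, b)])
    return out


result = silu_and_mul_reference([[0.0, 1.0, 2.0, 3.0]], hidden_size=2)
```

In this sketch the output size follows from the number of rows and `hidden_size` alone, which is consistent with the commit dropping the separate `n_elements` argument.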