@@ -383,6 +383,8 @@ def grouped_matmul_tlx_kernel(
383383 num_m_tiles = tl .cdiv (gm , BLOCK_SIZE_M )
384384 num_n_tiles = tl .cdiv (gn , BLOCK_SIZE_N )
385385 num_tiles = num_m_tiles * num_n_tiles
386+ num_k_tiles = tl .cdiv (gk , BLOCK_SIZE_K )
387+
386388 if (
387389 tile_idx >= last_problem_end
388390 and tile_idx < last_problem_end + num_tiles
@@ -413,7 +415,6 @@ def grouped_matmul_tlx_kernel(
413415 tile_idx >= last_problem_end
414416 and tile_idx < last_problem_end + num_tiles
415417 ):
416- k = gk
417418 # figure out tile coordinates
418419 tile_idx_in_gemm = tile_idx - last_problem_end
419420 tile_m_idx = tile_idx_in_gemm // num_n_tiles
@@ -423,7 +424,7 @@ def grouped_matmul_tlx_kernel(
423424 offs_am = tile_m_idx * BLOCK_SIZE_M
424425 offs_bn = tile_n_idx * BLOCK_SIZE_N
425426
426- for kk in range (0 , tl . cdiv ( k , BLOCK_SIZE_K ) ):
427+ for kk in range (0 , num_k_tiles ):
427428 buf , phase = _get_bufidx_phase (accum_cnt , NUM_SMEM_BUFFERS )
428429 tlx .barrier_wait (smem_empty_bars [buf ], phase ^ 1 )
429430 tlx .barrier_expect_bytes (
@@ -447,6 +448,12 @@ def grouped_matmul_tlx_kernel(
447448 # go to the next tile by advancing NUM_SMS
448449 tile_idx += NUM_SMS
449450
451+ # Wait for the last pair of TMA loads to complete before doing
452+ # the TMA desc update for the next gemm problem.
453+ if num_k_tiles > 0 :
454+ buf , phase = _get_bufidx_phase (accum_cnt - 1 , NUM_SMEM_BUFFERS )
455+ tlx .barrier_wait (smem_full_bars [buf ], phase )
456+
450457 # get ready to go to the next gemm problem
451458 last_problem_end = last_problem_end + num_tiles
452459
0 commit comments