66
@triton.jit
def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs,
                                 stride_sb, stride_st, stride_ib, stride_it,
                                 num_tokens, BLOCK_N: tl.constexpr):
    """Draw one multinomial sample per batch row.

    One program instance handles one row (grid size == batch size).  The
    row's scores are scanned in chunks of ``BLOCK_N``: a running cumulative
    sum ``acc`` is maintained across chunks, and the sample is the first
    position whose cumulative-score interval ``(pre_cum, cum]`` contains
    the random draw ``samp``.  The token id stored at that position in
    ``Indices`` is written to ``Outputs``.

    NOTE(review): correctness presumes each row of ``Scores`` sums to ~1
    (softmax output) — stated by the wrapper's docstring; confirm at call
    sites.  If rounding leaves ``samp`` above the final cumulative sum,
    the row's first index (loaded as the initial ``output``) is returned.

    Scores:  [batch, num_tokens] probabilities (row-major via stride_sb/st).
    Seeds:   [batch] per-row RNG seeds.
    Offsets: [batch] per-row RNG offsets (cast to int32 for tl.rand).
    Indices: [batch, num_tokens] token ids to gather from.
    Outputs: [batch] sampled token ids (written once per row).
    """
    batch_id = tl.program_id(0)
    n_off = tl.arange(0, BLOCK_N)

    # Per-row deterministic uniform draw in [0, 1).
    seed = tl.load(Seeds + batch_id)
    offset = tl.load(Offsets + batch_id).to(tl.int32)
    samp = tl.rand(seed, offset)

    # Running cumulative sum across chunks; pointers walk the row.
    acc = 0.0
    score_ptr = Scores + batch_id * stride_sb + n_off * stride_st
    indice_ptr = Indices + batch_id * stride_ib
    # Fallback result: the row's first index, overwritten once a match is found.
    output = tl.load(indice_ptr)

    found_mask = False
    for b_idx in tl.range(0, num_tokens, BLOCK_N):
        # Triton has no `break`; once found, later iterations skip all work.
        if not found_mask:
            s_off = b_idx + n_off
            s_mask = s_off < num_tokens
            scores = tl.load(score_ptr, mask=s_mask, other=0.0).to(tl.float32)
            # Inclusive cumsum within the chunk, offset by the running total.
            c_scores = tl.cumsum(scores, 0)
            cum_scores = acc + c_scores
            acc += tl.max(c_scores, 0)

            # Position i matches when samp falls in (cum[i] - score[i], cum[i]].
            pre_cum_scores = cum_scores - scores
            valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)
            found_mask = tl.sum(valid_mask, 0) > 0

            if found_mask:
                # argmax of the 0/1 mask -> first matching offset in this chunk;
                # indice_ptr still points at this chunk's start, so no b_idx add.
                valid_pos = tl.argmax(valid_mask.to(tl.int32), 0)
                indice = tl.load(indice_ptr + valid_pos * stride_it)
                output = indice
            score_ptr += stride_st * BLOCK_N
            indice_ptr += stride_it * BLOCK_N

    tl.store(Outputs + batch_id, output)
4248
4349
4450def multinomial_sampling (scores : torch .Tensor ,
4551 seeds : torch .LongTensor ,
4652 offsets : torch .LongTensor ,
4753 indices : torch .Tensor = None ):
48- """Multinomial sampling."""
54+ """Multinomial sampling.
55+
56+ Note that this kernel assumes the input scores are already sorted in descending order.
4957
58+ scores: [batch_size, num_tokens], sorted softmax scores
59+ seeds: [batch_size]
60+ offsets: [batch_size]
61+ indices: [batch_size, num_tokens], original token indices before sorting
62+ """
5063 assert scores .dim () == 2
5164 batch_size , num_tokens = scores .size ()
5265 device = scores .device
@@ -63,10 +76,9 @@ def multinomial_sampling(scores: torch.Tensor,
6376
6477 outputs = indices [:, 0 ].clone ()
6578
66- BLOCK = 8
6779 BLOCK_N = 128
6880
69- grid = [triton . cdiv ( batch_size , BLOCK ) ]
81+ grid = [batch_size ]
7082 _multinomial_sampling_kernel [grid ](scores ,
7183 seeds ,
7284 offsets ,
@@ -76,10 +88,8 @@ def multinomial_sampling(scores: torch.Tensor,
7688 stride_st = scores .stride (1 ),
7789 stride_ib = indices .stride (0 ),
7890 stride_it = indices .stride (1 ),
79- num_batchs = batch_size ,
8091 num_tokens = num_tokens ,
81- BLOCK = BLOCK ,
8292 BLOCK_N = BLOCK_N ,
83- num_warps = 8 )
93+ num_warps = 1 )
8494
8595 return outputs
0 commit comments