Skip to content

Commit adac4b1

Browse files
authored
Optimize multinomial sampling (InternLM#4056)
* optimize multinomial sampling kernel
* remove
* add comments
* optimize
* remove sync
* recovery
* remove print
* fix
* optimize output pipeline
1 parent 0c7c015 commit adac4b1

File tree

3 files changed

+77
-48
lines changed

3 files changed

+77
-48
lines changed

lmdeploy/pytorch/engine/engine.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,10 @@ def _make_infer_outputs(
831831
logits = batched_outputs.logits
832832
logprobs = batched_outputs.logprobs
833833

834+
if logprobs is not None:
835+
logprobs.vals = logprobs.vals.tolist()
836+
logprobs.indices = logprobs.indices.tolist()
837+
834838
seq_length = [seq.num_token_ids for seq in running]
835839
is_run = [seq.status == MessageStatus.LOCKED for seq in running]
836840
self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, is_decoding=is_decoding)
@@ -858,7 +862,7 @@ def _make_infer_outputs(
858862
num_logprobs = msg.sampling_param.num_logprobs
859863
cur_logprobs = None
860864
if num_logprobs >= 0:
861-
cur_logprobs = (logprobs.vals[idx, :num_logprobs + 1], logprobs.indices[idx, :num_logprobs + 1])
865+
cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1])
862866

863867
req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events)
864868
out = InferOutput(session_id=session_id,
@@ -953,15 +957,7 @@ def __log_resps(outputs: List[InferOutput]):
953957
def __send_resp(out: InferOutput):
954958
"""Send response."""
955959
resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)
956-
cur_logprobs = out.logprobs
957-
logprobs = None
958-
if cur_logprobs is not None:
959-
# logprobs to dict
960-
vals = cur_logprobs[0].tolist()
961-
indices = cur_logprobs[1].tolist()
962-
cur_logprobs = dict(zip(indices, vals))
963-
logprobs = [] if out.resp.data is None else out.resp.data.get('logprobs', [])
964-
logprobs = logprobs + [cur_logprobs]
960+
logprobs = None if out.resp.data is None else out.resp.data.get('logprobs', None)
965961
self._response(out.resp,
966962
resp_type,
967963
data=dict(token_ids=out.token_ids,
@@ -970,10 +966,33 @@ def __send_resp(out: InferOutput):
970966
req_metrics=out.req_metrics,
971967
logprobs=logprobs))
972968

969+
def __update_logprobs(step_outputs: List[InferOutput]):
970+
for out in step_outputs:
971+
cur_logprobs = out.logprobs
972+
if cur_logprobs is None:
973+
continue
974+
975+
if out.resp.data is None:
976+
out.resp.data = dict()
977+
out.resp.data.setdefault('logprobs', [])
978+
979+
# logprobs to dict
980+
vals = cur_logprobs[0]
981+
indices = cur_logprobs[1]
982+
cur_logprobs = dict(zip(indices, vals))
983+
logprobs = out.resp.data['logprobs']
984+
logprobs.append(cur_logprobs)
985+
973986
def __send_resps(step_outputs: List[InferOutput]):
974987
"""Send response callback."""
975988
__log_resps(step_outputs)
976-
for out in step_outputs:
989+
__update_logprobs(step_outputs)
990+
991+
is_done = set()
992+
for out in reversed(step_outputs):
993+
if out.session_id in is_done:
994+
continue
995+
is_done.add(out.session_id)
977996
__send_resp(out)
978997

979998
while True:

lmdeploy/pytorch/engine/logits_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
236236
if max_topk <= 0:
237237
max_topk = scores.size(1)
238238
if top_k is not None:
239-
top_k = torch.where(top_k <= 0, top_k.new_tensor(max_topk), top_k)
239+
top_k = torch.masked_fill(top_k, top_k <= 0, max_topk)
240240

241241
if top_k is not None:
242242
scores = _filter_topk_sorted_(scores, top_k)

lmdeploy/pytorch/kernels/cuda/multinomial_sampling.py

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,47 +6,60 @@
66

77
@triton.jit
88
def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs, stride_sb, stride_st, stride_ib, stride_it,
9-
num_batchs, num_tokens, BLOCK: tl.constexpr, BLOCK_N: tl.constexpr):
9+
num_tokens, BLOCK_N: tl.constexpr):
1010
"""Kernel."""
11-
batch_block_id = tl.program_id(0)
12-
13-
off = batch_block_id * BLOCK + tl.arange(0, BLOCK)
11+
batch_id = tl.program_id(0)
1412
n_off = tl.arange(0, BLOCK_N)
1513

16-
off_mask = off < num_batchs
17-
seed = tl.load(Seeds + off, mask=off_mask)
18-
offset = tl.load(Offsets + off, mask=off_mask).to(tl.int32)
19-
20-
samp = tl.rand(seed, offset)[:, None]
21-
acc = tl.zeros((BLOCK, ), dtype=tl.float32)
22-
output = tl.load(Indices + off * stride_ib, mask=off_mask)
23-
24-
for b_idx in range(0, num_tokens, BLOCK_N):
25-
s_off = b_idx + n_off
26-
s_mask = off_mask[:, None] & (s_off[None, :] < num_tokens)
27-
scores = tl.load(Scores + off[:, None] * stride_sb + s_off[None, :] * stride_st, mask=s_mask,
28-
other=0.0).to(tl.float32)
29-
c_scores = tl.cumsum(scores, 1)
30-
cum_scores = acc[:, None] + c_scores
31-
acc += tl.max(c_scores, 1)
32-
33-
pre_cum_scores = cum_scores - scores
34-
valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)
35-
found_mask = tl.sum(valid_mask, 1) > 0
36-
37-
valid_pos = b_idx + tl.argmax(valid_mask.to(tl.int32), 1)
38-
indices = tl.load(Indices + off * stride_ib + valid_pos * stride_it, mask=found_mask & off_mask, other=-1)
39-
output = tl.where(found_mask, indices, output)
40-
41-
tl.store(Outputs + off, output, mask=off_mask)
14+
# sampling random seed
15+
seed = tl.load(Seeds + batch_id)
16+
offset = tl.load(Offsets + batch_id).to(tl.int32)
17+
samp = tl.rand(seed, offset)
18+
19+
# initialize
20+
acc = 0.0
21+
score_ptr = Scores + batch_id * stride_sb + n_off * stride_st
22+
indice_ptr = Indices + batch_id * stride_ib
23+
output = tl.load(indice_ptr)
24+
25+
found_mask = False
26+
for b_idx in tl.range(0, num_tokens, BLOCK_N):
27+
# Triton has no break statement, so use the found_mask flag to skip the remaining computation
28+
if not found_mask:
29+
s_off = b_idx + n_off
30+
s_mask = (s_off < num_tokens)
31+
scores = tl.load(score_ptr, mask=s_mask, other=0.0).to(tl.float32)
32+
c_scores = tl.cumsum(scores, 0)
33+
cum_scores = acc + c_scores
34+
acc += tl.max(c_scores, 0)
35+
36+
pre_cum_scores = cum_scores - scores
37+
valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)
38+
found_mask = tl.sum(valid_mask, 0) > 0
39+
40+
if found_mask:
41+
valid_pos = tl.argmax(valid_mask.to(tl.int32), 0)
42+
indice = tl.load(indice_ptr + valid_pos * stride_it)
43+
output = indice
44+
score_ptr += stride_st * BLOCK_N
45+
indice_ptr += stride_it * BLOCK_N
46+
47+
tl.store(Outputs + batch_id, output)
4248

4349

4450
def multinomial_sampling(scores: torch.Tensor,
4551
seeds: torch.LongTensor,
4652
offsets: torch.LongTensor,
4753
indices: torch.Tensor = None):
48-
"""Multinomial sampling."""
54+
"""Multinomial sampling.
55+
56+
Note that this kernel assumes the input scores are already sorted in descending order.
4957
58+
scores: [batch_size, num_tokens], sorted softmax scores
59+
seeds: [batch_size]
60+
offsets: [batch_size]
61+
indices: [batch_size, num_tokens], original token indices before sorting
62+
"""
5063
assert scores.dim() == 2
5164
batch_size, num_tokens = scores.size()
5265
device = scores.device
@@ -63,10 +76,9 @@ def multinomial_sampling(scores: torch.Tensor,
6376

6477
outputs = indices[:, 0].clone()
6578

66-
BLOCK = 8
6779
BLOCK_N = 128
6880

69-
grid = [triton.cdiv(batch_size, BLOCK)]
81+
grid = [batch_size]
7082
_multinomial_sampling_kernel[grid](scores,
7183
seeds,
7284
offsets,
@@ -76,10 +88,8 @@ def multinomial_sampling(scores: torch.Tensor,
7688
stride_st=scores.stride(1),
7789
stride_ib=indices.stride(0),
7890
stride_it=indices.stride(1),
79-
num_batchs=batch_size,
8091
num_tokens=num_tokens,
81-
BLOCK=BLOCK,
8292
BLOCK_N=BLOCK_N,
83-
num_warps=8)
93+
num_warps=1)
8494

8595
return outputs

0 commit comments

Comments
 (0)