Commit d6c778e

grimoireRunningLeon authored and committed
optimize int8
1 parent 6a6f9f1 commit d6c778e

File tree

lmdeploy/pytorch/backends/cuda/qmodules.py
lmdeploy/pytorch/engine/engine.py
lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
lmdeploy/pytorch/nn/norm.py

4 files changed: +182 −73 lines

lmdeploy/pytorch/backends/cuda/qmodules.py

Lines changed: 15 additions & 8 deletions
@@ -29,15 +29,23 @@ def forward(self,
                 weight: torch.Tensor,
                 residual: torch.Tensor = None):
         """forward."""
-        if residual is not None:
-            x = x + residual
-            residual = x
-        hidden_states_quant, rms_scale = rms_norm_dynamic_quant(
-            x, weight, self.eps, quant_dtype=self.quant_dtype)
-        x = QTensor(hidden_states_quant, rms_scale)
         if residual is None:
+            (x,
+             rms_scale) = rms_norm_dynamic_quant(x,
+                                                 weight,
+                                                 self.eps,
+                                                 quant_dtype=self.quant_dtype)
+            x = QTensor(x, rms_scale)
             return x
-        return x, residual
+        else:
+            (x, rms_scale,
+             residual) = rms_norm_dynamic_quant(x,
+                                                weight,
+                                                self.eps,
+                                                residual=residual,
+                                                quant_dtype=self.quant_dtype)
+            x = QTensor(x, rms_scale)
+            return x, residual
 
 
 class TritonRMSNormBuilder(RMSNormW8A8Builder):
@@ -70,7 +78,6 @@ def forward(self,
                 all_reduce: bool = False):
         """forward."""
         if isinstance(x, torch.Tensor):
-            x = x.contiguous()
             input_quant, input_scale = per_token_quant_int8(
                 x, 1e-7, quant_dtype=self.quant_dtype)
         else:
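For context, the rewritten forward folds the residual add into the quantizing kernel, so each branch makes a single fused call. A minimal usage sketch of the updated `rms_norm_dynamic_quant` wrapper (the kernel change itself is in w8a8_triton_kernels.py below); shapes are illustrative and a CUDA device is assumed:

import torch

from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \
    rms_norm_dynamic_quant

x = torch.randn(4, 4096, dtype=torch.float16, device='cuda')
w = torch.ones(4096, dtype=torch.float16, device='cuda')
residual = torch.randn_like(x)

# no residual: returns int8 activations plus one float32 scale per token
x_q, rms_scale = rms_norm_dynamic_quant(x, w, 1e-6)

# fused path: residual add + RMSNorm + quantization in one launch; the
# updated residual is returned for the next decoder layer
x_q, rms_scale, new_residual = rms_norm_dynamic_quant(
    x, w, 1e-6, residual=residual)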

lmdeploy/pytorch/engine/engine.py

Lines changed: 21 additions & 1 deletion
@@ -1003,9 +1003,29 @@ def __send_resps(step_outputs: Dict[int, InferOutput]):
             for out in step_outputs.values():
                 __send_resp(out)
 
+        def __do_prefill():
+            # keep decoding if nothing is waiting
+            if not self.scheduler.has_waiting():
+                return False
+
+            num_running = len(self.scheduler.running)
+            num_waiting = len(self.scheduler.waiting)
+            max_batches = self.scheduler_config.max_batches
+
+            # prefill if too many requests are waiting
+            if num_waiting >= 4:
+                return True
+
+            # prefill if not enough requests are running
+            if num_running < max_batches * 0.5:
+                return True
+
+            # otherwise keep decoding
+            return False
+
         async def __step():
             """step decoding."""
-            prefill = self.scheduler.has_waiting()
+            prefill = __do_prefill()
             schedule_output = self.scheduler.schedule(
                 is_prefill=prefill, prealloc_size=prefill_interval)
             # schedule decoding if no valid prefill reqs.
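The new __do_prefill heuristic makes prefill less eager: instead of prefilling whenever anything is waiting, decoding is only interrupted when at least 4 requests are queued or the running batch is less than half of max_batches. A standalone sketch of the same decision rule with made-up numbers (the helper below is illustrative, not part of lmdeploy's API):

def do_prefill(num_waiting: int, num_running: int, max_batches: int) -> bool:
    """Illustrative mirror of the scheduling heuristic above."""
    if num_waiting == 0:
        return False  # nothing to prefill, keep decoding
    if num_waiting >= 4:
        return True   # enough queued requests to justify a prefill batch
    if num_running < max_batches * 0.5:
        return True   # running batch is under-filled, top it up
    return False      # otherwise keep decoding


# e.g. with max_batches=128: two waiting requests do not interrupt a
# well-filled decoding batch, but they do when only 30 requests are running
assert do_prefill(num_waiting=2, num_running=100, max_batches=128) is False
assert do_prefill(num_waiting=2, num_running=30, max_batches=128) is True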

lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py

Lines changed: 145 additions & 63 deletions
@@ -60,8 +60,10 @@ def per_channel_quant(x: torch.Tensor, n_bits: int, dtype: torch.dtype):
             num_warps=4)
     ],
     key=['N', 'K'],
+    warmup=5,
+    rep=20,
 )
-@triton.jit
+@triton.jit(do_not_specialize=['M'])
 def _linear(
     A,
     B,
@@ -142,8 +144,10 @@ def _linear(
             num_warps=4)
     ],
     key=['N', 'K'],
+    warmup=5,
+    rep=20,
 )
-@triton.jit
+@triton.jit(do_not_specialize=['M'])
 def _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak,
                 stride_bk, stride_bn, stride_cm, stride_cn,
                 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
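Two tuning changes here: warmup/rep shorten the autotuner's benchmarking of each config, and do_not_specialize=['M'] keeps Triton from baking the token count M into the compiled GEMM, which would otherwise trigger a recompile whenever the batch's token count changes. A self-contained toy sketch of the same pattern (not lmdeploy code; assumes a recent Triton that accepts these arguments and a CUDA device):

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 1024}, num_warps=4),
        triton.Config({'BLOCK': 2048}, num_warps=8),
    ],
    key=['N'],   # re-tune only when the feature size N changes
    warmup=5,    # milliseconds of warmup per config
    rep=20,      # milliseconds of measurement per config
)
@triton.jit(do_not_specialize=['M'])  # no recompile when M varies per step
def _double_kernel(X, Out, M, N, BLOCK: tl.constexpr):
    """Toy elementwise kernel over a contiguous (M, N) tensor."""
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < M * N
    x = tl.load(X + offs, mask=mask)
    tl.store(Out + offs, x * 2.0, mask=mask)


M, N = 8, 4096
x = torch.randn(M, N, device='cuda')
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(M * N, meta['BLOCK']), )
_double_kernel[grid](x, out, M, N)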
@@ -281,7 +285,8 @@ def _per_token_quant_int8(
     y_ptr,
     y_q_ptr,
     y_s_ptr,
-    y_stride,
+    y_stride: tl.constexpr,
+    yq_stride: tl.constexpr,
     N,  # number of columns in X
     eps: tl.constexpr,  # epsilon to avoid division by zero
     BLOCK: tl.constexpr,
@@ -296,7 +301,7 @@ def _per_token_quant_int8(
     # Map the program id to the row of X and Y it should compute.
     row = tl.program_id(0)
     y_ptr += row * y_stride
-    y_q_ptr += row * y_stride
+    y_q_ptr += row * yq_stride
     y_s_ptr += row
 
     cols = tl.arange(0, BLOCK)  # N <= BLOCK
@@ -333,15 +338,20 @@ def per_token_quant_int8(x, eps, quant_dtype=torch.int8):
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
+
+    if x.dim() > 2:
+        x = x.flatten(0, -2)
+    assert x.stride(-1) == 1
     # enqueue kernel
     kernel_meta = get_kernel_meta(x)
     _per_token_quant_int8[(M, )](
         x,
         x_q,
         x_s,
-        x.stride(-2),
-        N,
-        eps,
+        y_stride=x.stride(-2),
+        yq_stride=x_q.stride(-2),
+        N=N,
+        eps=eps,
         BLOCK=BLOCK,
         Q_MAX=q_max,
         IS_FLOATING_POINT=quant_dtype.is_floating_point,
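The kernel now takes separate input and output row strides, and the wrapper flattens any leading batch dimensions before launch. Per token it performs symmetric dynamic quantization: one float32 scale max(|x|) / Q_MAX per row, then round and clamp to the quantized range. An eager re-statement for reference (hypothetical helper with simplified eps handling, int8 case only):

import torch


def per_token_quant_int8_ref(x: torch.Tensor, eps: float = 1e-7):
    """Eager sketch of per-token symmetric int8 quantization."""
    xf = x.float()
    absmax = xf.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / 127.0  # Q_MAX for int8
    q = torch.clamp(torch.round(xf / scale), -128, 127).to(torch.int8)
    return q, scale


q, s = per_token_quant_int8_ref(torch.randn(4, 4096))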
@@ -352,46 +362,98 @@ def per_token_quant_int8(x, eps, quant_dtype=torch.int8):
 
 
 @triton.jit
-def _rms_norm_fwd_fused_dynamic_symmetric(
-    X,  # pointer to the input
-    Y,  # pointer to the output
-    W,  # pointer to the weights
-    Scale,  # pointer to the scales of the output activation
-    stride,  # how much to increase the pointer when moving by 1 row
-    N,  # number of columns in X
-    eps: tl.constexpr,  # epsilon to avoid division by zero
-    BLOCK_SIZE: tl.constexpr,
+def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
+    """compute rms norm."""
+    xf = x.to(tl.float32)
+
+    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
+    out = xf * tl.math.rsqrt(var + eps)
+    out = (w * out).to(x.dtype)
+    return out
+
+
+@triton.jit
+def rms_norm_quant_kernel(
+    input,
+    weight,
+    output,
+    out_scale,
+    input_row_stride: tl.constexpr,
+    eps: tl.constexpr,
+    N_COLS: tl.constexpr,
+    BLOCK_N: tl.constexpr,
     Q_MIN: tl.constexpr,
     Q_MAX: tl.constexpr,
     IS_FLOATING_POINT: tl.constexpr,
 ):
-    """A Triton kernel that calculates Root Mean Square (RMS) normalization
-    with fused dynamic symmetric quantization."""
-    row = tl.program_id(0)
-    Y += row * stride
-    X += row * stride
+    """rms norm kernel."""
+    prog_id = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_N)
 
-    cols = tl.arange(0, BLOCK_SIZE)
-    mask = cols < N
-    x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
-    _var = x * x
-    var = tl.sum(_var, axis=0) / N
-    rstd = tl.math.rsqrt(var + eps)
-
-    w = tl.load(W + cols, mask=mask)
-    x_hat = x * rstd
-    y = x_hat * w
-
-    scale = tl.max(tl.abs(y)).to(tl.float32) / Q_MAX
-    tl.store(Scale + row, scale)
-    y = y / scale
+    w = tl.load(weight + offsets, mask=offsets < N_COLS)
+
+    x_ptr = input + prog_id * input_row_stride
+    x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
+    out = _compute_rms_norm(x, w, eps, N_COLS)
+
+    scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
+    out_s_ptr = out_scale + prog_id
+    tl.store(out_s_ptr, scale)
+    out = out / scale
+    if not IS_FLOATING_POINT:
+        out = tl_round(out)
+    out = tl.clamp(out, Q_MIN, Q_MAX)
+    out_ptr = output + prog_id * input_row_stride
+    tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
+
+
+@triton.jit
+def add_rms_norm_quant_kernel(
+    input,
+    weight,
+    residual,
+    output,
+    out_scale,
+    out_residual,
+    input_row_stride: tl.constexpr,
+    residual_row_stride: tl.constexpr,
+    eps: tl.constexpr,
+    N_COLS: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    Q_MIN: tl.constexpr,
+    Q_MAX: tl.constexpr,
+    IS_FLOATING_POINT: tl.constexpr,
+):
+    """rms norm kernel."""
+    prog_id = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_N)
+
+    w = tl.load(weight + offsets, mask=offsets < N_COLS)
+
+    x_ptr = input + prog_id * input_row_stride
+    x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
+
+    res_ptr = residual + prog_id * residual_row_stride
+    res = tl.load(res_ptr + offsets, mask=offsets < N_COLS)
+
+    new_x = x + res
+    out_res_ptr = out_residual + prog_id * residual_row_stride
+    tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS)
+
+    out = _compute_rms_norm(new_x, w, eps, N_COLS)
+
+    scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
+    out_s_ptr = out_scale + prog_id
+    tl.store(out_s_ptr, scale)
+    out = out / scale
     if not IS_FLOATING_POINT:
-        y = tl_round(y)
-    y = tl.clamp(y, Q_MIN, Q_MAX)
-    tl.store(Y + cols, y, mask=mask)
+        out = tl_round(out)
+    out = tl.clamp(out, Q_MIN, Q_MAX)
+    out_ptr = output + prog_id * input_row_stride
+    tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
 
 
-def rms_norm_dynamic_quant(x, w, eps, quant_dtype=torch.int8):
+def rms_norm_dynamic_quant(x, w, eps, residual=None, quant_dtype=torch.int8):
     """Performs RMS normalization with dynamic quantization.
@@ -401,32 +463,52 @@ def rms_norm_dynamic_quant(x, w, eps, quant_dtype=torch.int8):
     qdtype_info = torch.finfo(
         quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(
             quant_dtype)
-    x_arg = x.flatten(0, -2)
     y = torch.empty_like(x, dtype=quant_dtype)
-    M, K = x_arg.shape
-    MAX_FUSED_SIZE = 65536 // x.element_size()
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K))
-    if K > BLOCK_SIZE:
-        raise RuntimeError(
-            "This rms norm doesn't support feature dim >= 64KB.")
-    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
     scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32)
-    kernel_meta = get_kernel_meta(x_arg)
-    _rms_norm_fwd_fused_dynamic_symmetric[(M, )](
-        x_arg,
-        y,
-        w,
-        scale,
-        x_arg.stride(0),
-        K,
-        eps,
-        BLOCK_SIZE=BLOCK_SIZE,
-        Q_MIN=qdtype_info.min,
-        Q_MAX=qdtype_info.max,
-        IS_FLOATING_POINT=quant_dtype.is_floating_point,
-        num_warps=num_warps,
-        **kernel_meta)
-    return y, scale
+
+    feat_size = w.shape[0]
+    seq_len = x.numel() // x.size(-1)
+    input_stride = x.stride(-2)
+    BLOCK_N = triton.next_power_of_2(feat_size)
+    grid = (seq_len, )
+
+    if residual is None:
+        rms_norm_quant_kernel[grid](
+            x,
+            w,
+            y,
+            scale,
+            input_row_stride=input_stride,
+            eps=eps,
+            N_COLS=feat_size,
+            BLOCK_N=BLOCK_N,
+            Q_MIN=qdtype_info.min,
+            Q_MAX=qdtype_info.max,
+            IS_FLOATING_POINT=quant_dtype.is_floating_point,
+            num_warps=4,
+            num_stages=2)
+        return y, scale
+    else:
+        out_residual = torch.empty_like(x)
+        res_stride = residual.stride(-2)
+        add_rms_norm_quant_kernel[grid](
+            x,
+            w,
+            residual,
+            y,
+            scale,
+            out_residual,
+            input_row_stride=input_stride,
+            residual_row_stride=res_stride,
+            eps=eps,
+            N_COLS=feat_size,
+            BLOCK_N=BLOCK_N,
+            Q_MIN=qdtype_info.min,
+            Q_MAX=qdtype_info.max,
+            IS_FLOATING_POINT=quant_dtype.is_floating_point,
+            num_warps=4,
+            num_stages=2)
+        return y, scale, out_residual
 
 
 def test_rms_and_linear(x,
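Taken together, rms_norm_dynamic_quant now dispatches to rms_norm_quant_kernel, or to add_rms_norm_quant_kernel when a residual is passed, with the latter also writing the pre-norm sum back out as the new residual. An eager PyTorch sketch of the assumed semantics (hypothetical helper; the Triton kernels do all of this in a single pass per row):

import torch


def add_rms_norm_quant_ref(x, w, eps, residual=None, q_max=127.0):
    """Eager sketch of fused (add +) RMSNorm + dynamic int8 quantization."""
    if residual is not None:
        x = x + residual  # this sum is also the residual for the next layer
    xf = x.float()
    var = xf.pow(2).mean(dim=-1, keepdim=True)
    out = (w * (xf * torch.rsqrt(var + eps))).to(x.dtype)
    scale = out.abs().amax(dim=-1, keepdim=True).float() / q_max
    q = torch.clamp(torch.round(out / scale), -q_max - 1, q_max).to(torch.int8)
    if residual is None:
        return q, scale
    return q, scale, x

Comparing this against the Triton output with torch.testing.assert_close (using a relaxed tolerance for the fp16 arithmetic) is one way to sanity-check the fused path.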

lmdeploy/pytorch/nn/norm.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def _is_w8a8(quant_config: Any):
         return False
     else:
         quant_method = quant_config['quant_method']
-        if quant_method == 'w8a8':
+        if quant_method == 'smooth_quant':
             return True
         else:
             return False
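With this change the Triton W8A8 modules are selected when the checkpoint's quantization config reports quant_method 'smooth_quant' rather than the old 'w8a8' tag (presumably the value written by the smooth-quant conversion tooling). An illustrative re-statement of the check against a dict-style config (hypothetical helper, not the lmdeploy function itself):

def is_w8a8_like(quant_config) -> bool:
    """Illustrative check mirroring _is_w8a8 above."""
    if quant_config is None:
        return False
    return quant_config.get('quant_method') == 'smooth_quant'


assert is_w8a8_like({'quant_method': 'smooth_quant'}) is True
assert is_w8a8_like({'quant_method': 'awq'}) is False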
