Postprocessing to share lm_head weights to embedding #1461
Conversation
# Subtract zero point from casted weights
sub_node = helper.make_node(
    'Sub',
    inputs=["casted_quant_weights", "casted_zero_point"],
    outputs=["centered_weights"],
    name='/model/embed_tokens/SubtractZeroPoint'
)

# Multiply by scale
dequantized_output = "dequantized_embeddings"
mul_node = helper.make_node(
    'Mul',
    inputs=["centered_weights", "gathered_scales"],
    outputs=[dequantized_output],
    name='/model/embed_tokens/MultiplyByScale'
)
Can we use the DequantizeLinear op?
https://onnx.ai/onnx/operators/onnx__DequantizeLinear.html
Could you please elaborate on how to construct it?
Use helper.make_node to create a DequantizeLinear node, and feed the quantized lm_head weight, together with the same scales and zero points used by the last MatMulNBits node, into DequantizeLinear. That gives you the dequantized weights, and you can then Gather from them based on input_ids.
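A minimal sketch of that construction with onnx.helper. The tensor names ("lm_head.qweight", "lm_head.scales", "lm_head.qzeros", "input_ids") and the group size of 32 are illustrative assumptions, not the names used in this PR; the real ones come from the initializers already feeding the last MatMulNBits node.

from onnx import helper

# Dequantize the shared lm_head weight (block-wise DequantizeLinear needs opset 21+).
dequant_node = helper.make_node(
    "DequantizeLinear",
    inputs=["lm_head.qweight", "lm_head.scales", "lm_head.qzeros"],
    outputs=["dequantized_lm_head_weights"],
    name="/model/embed_tokens/DequantizeLinear",
    block_size=32,  # assumed group size; must match the MatMulNBits group size
    axis=-1,
)

# Gather the embedding rows for the current input_ids from the dequantized weight matrix.
gather_node = helper.make_node(
    "Gather",
    inputs=["dequantized_lm_head_weights", "input_ids"],
    outputs=["inputs_embeds"],
    name="/model/embed_tokens/Gather",
    axis=0,
)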
I agree that we should use DequantizeLinear. It is already constructed in the model builder.
onnxruntime-genai/src/python/py/models/builder.py
Lines 861 to 886 in 8eed730
def make_dequantize_linear(self, dequantize_name, quantized_op):
    # Input weights are quantized, save quantized MatMul weights for onnx model
    qweight = dequantize_name[1:].replace("/", ".") + ".qweight"
    qweight_npy = quantized_op.qweight.detach().cpu()
    qweight_npy = qweight_npy.reshape(*qweight_npy.shape[:-2], qweight_npy.shape[-2] * qweight_npy.shape[-1])
    self.make_external_tensor(qweight_npy.contiguous(), qweight, True)

    scales = dequantize_name[1:].replace("/", ".") + ".scales"
    scales_npy = quantized_op.scales.detach().cpu().to(self.to_torch_dtype[self.io_dtype])
    scales_npy = scales_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] * 2 // quantized_op.group_size)
    self.make_external_tensor(scales_npy.contiguous(), scales)

    dequantize_inputs = [qweight, scales]

    if hasattr(quantized_op, "qzeros") and quantized_op.qzeros is not None:
        zeros = dequantize_name[1:].replace("/", ".") + ".qzeros"
        zeros_npy = quantized_op.qzeros.detach().cpu()
        zeros_npy = zeros_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] // quantized_op.group_size)
        self.make_external_tensor(zeros_npy.contiguous(), zeros, True)
        dequantize_inputs.append(zeros)

    dequantize_output = f"{dequantize_name}/output_0"
    self.make_node("DequantizeLinear", inputs=dequantize_inputs, outputs=[dequantize_output], name=dequantize_name, block_size=quantized_op.group_size, axis=-1)
    self.make_value_info(dequantize_output, self.io_dtype, shape=[*scales_npy.shape[:-1], scales_npy.shape[-1] * quantized_op.group_size])

    return dequantize_output
It will also be easier to construct the temporary subgraph for GatherBlockQuantized in the model builder directly.
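As a rough sketch of that direction, assuming the com.microsoft GatherBlockQuantized contrib op; the tensor names, attribute values, and group size below are illustrative assumptions, not taken from this PR or the builder:

from onnx import helper

# GatherBlockQuantized gathers and dequantizes in one step, so the full weight
# matrix never has to be materialized in io_dtype.
gbq_node = helper.make_node(
    "GatherBlockQuantized",
    inputs=["lm_head.qweight", "input_ids", "lm_head.scales", "lm_head.qzeros"],
    outputs=["inputs_embeds"],
    name="/model/embed_tokens/GatherBlockQuantized",
    domain="com.microsoft",
    gather_axis=0,    # gather along the vocab dimension
    quantize_axis=1,  # weights are block-quantized along the hidden dimension
    block_size=32,    # assumed group size; must match the MatMulNBits group size
)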
Force-pushed from 207b13a to 4fefaad.
# Input A and the scale have the same type, but the scale is stored in external data, so we can only get the type from A here.
scale_value_type = get_tensor_type_from_graph(graph, matmul_node.input[0])
if scale_value_type:
    scale_value_type = scale_value_type.elem_type
Check notice (Code scanning / CodeQL): Unused local variable
This PR is based on #1437, but we don't need to convert MatMul here.