
Postprocessing to share lm_head weights to embedding #1461

Open

jiafatom wants to merge 4 commits into main from tie

Conversation

@jiafatom (Contributor) commented May 8, 2025

This PR is based on #1437, but we don't need to convert MatMul here.

Comment on lines +94 to +109
# Subtract zero point from casted weights
sub_node = helper.make_node(
'Sub',
inputs=["casted_quant_weights", "casted_zero_point"],
outputs=["centered_weights"],
name='/model/embed_tokens/SubtractZeroPoint'
)

# Multiply by scale
dequantized_output = "dequantized_embeddings"
mul_node = helper.make_node(
'Mul',
inputs=["centered_weights", "gathered_scales"],
outputs=[dequantized_output],
name='/model/embed_tokens/MultiplyByScale'
)
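
For context, the Sub/Mul pair above implements (quantized - zero_point) * scale. A hedged sketch of the upstream Cast nodes that could produce its inputs follows; the tensor names, node names, and dtype here are assumptions for illustration, not part of the diff:

from onnx import TensorProto, helper

# Hypothetical upstream nodes: cast the integer quantized weights and zero points
# to float so the Sub/Mul pair runs in the compute dtype.
cast_weights_node = helper.make_node(
    "Cast",
    inputs=["quant_weights"],
    outputs=["casted_quant_weights"],
    name="/model/embed_tokens/CastWeights",
    to=TensorProto.FLOAT,
)
cast_zero_point_node = helper.make_node(
    "Cast",
    inputs=["zero_point"],
    outputs=["casted_zero_point"],
    name="/model/embed_tokens/CastZeroPoint",
    to=TensorProto.FLOAT,
)
# "gathered_scales" would similarly come from a Gather over the per-block scales.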
@tianleiwu commented May 8, 2025


@jiafatom (Contributor Author)

Could you please elaborate more on how to construct it?

@tianleiwu

Use helper.make_node to create a DequantizeLinear node, and feed the quantized lm_head weight and the same scales and zero points used in the last MatMulNBits node into DequantizeLinear. That gives you the dequantized weights, which you can then Gather based on input_ids.
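
A minimal sketch of that construction, assuming the quantized weight, scales, and zero points already exist as initializers; all tensor names, node names, and the block size below are hypothetical placeholders, not taken from this PR:

from onnx import helper

# Hypothetical sketch: dequantize the lm_head weight once, then gather rows by token id.
# "quant_lm_weight", "lm_scales", and "lm_zero_points" stand in for the initializers
# already consumed by the last MatMulNBits node.
dequant_node = helper.make_node(
    "DequantizeLinear",
    inputs=["quant_lm_weight", "lm_scales", "lm_zero_points"],
    outputs=["dequantized_lm_weight"],
    name="/model/embed_tokens/DequantizeLinear",
    axis=-1,        # assumes block quantization along the last axis
    block_size=32,  # assumption: must match the MatMulNBits group size
)
gather_node = helper.make_node(
    "Gather",
    inputs=["dequantized_lm_weight", "input_ids"],
    outputs=["inputs_embeds"],
    name="/model/embed_tokens/Gather",
    axis=0,         # assumes weight layout [vocab_size, hidden_size]
)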

A contributor commented:

I agree that we should use DequantizeLinear. It is already constructed in the model builder.

def make_dequantize_linear(self, dequantize_name, quantized_op):
    # Input weights are quantized, save quantized MatMul weights for onnx model
    qweight = dequantize_name[1:].replace("/", ".") + ".qweight"
    qweight_npy = quantized_op.qweight.detach().cpu()
    qweight_npy = qweight_npy.reshape(*qweight_npy.shape[:-2], qweight_npy.shape[-2] * qweight_npy.shape[-1])
    self.make_external_tensor(qweight_npy.contiguous(), qweight, True)

    scales = dequantize_name[1:].replace("/", ".") + ".scales"
    scales_npy = quantized_op.scales.detach().cpu().to(self.to_torch_dtype[self.io_dtype])
    scales_npy = scales_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] * 2 // quantized_op.group_size)
    self.make_external_tensor(scales_npy.contiguous(), scales)

    dequantize_inputs = [qweight, scales]
    if hasattr(quantized_op, "qzeros") and quantized_op.qzeros is not None:
        zeros = dequantize_name[1:].replace("/", ".") + ".qzeros"
        zeros_npy = quantized_op.qzeros.detach().cpu()
        zeros_npy = zeros_npy.reshape(*qweight_npy.shape[:-1], qweight_npy.shape[-1] // quantized_op.group_size)
        self.make_external_tensor(zeros_npy.contiguous(), zeros, True)
        dequantize_inputs.append(zeros)

    dequantize_output = f"{dequantize_name}/output_0"
    self.make_node("DequantizeLinear", inputs=dequantize_inputs, outputs=[dequantize_output], name=dequantize_name, block_size=quantized_op.group_size, axis=-1)
    self.make_value_info(dequantize_output, self.io_dtype, shape=[*scales_npy.shape[:-1], scales_npy.shape[-1] * quantized_op.group_size])
    return dequantize_output
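
For illustration, a hypothetical call site inside the model builder (the node name and "lm_head_op" argument are assumptions, not actual model builder identifiers):

# Hypothetical usage: build the DequantizeLinear subgraph for the lm_head weights
# so the embedding path can reuse its output.
dequant_output = self.make_dequantize_linear("/model/lm_head/DequantizeLinear", lm_head_op)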

It will also be easier to construct the temporary subgraph for GatherBlockQuantized in the model builder directly.
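
For what that could look like, a hedged sketch: GatherBlockQuantized is an onnxruntime contrib op, and the tensor names, axes, and block size below are assumptions for illustration only:

from onnx import helper

# Hypothetical sketch: gather embedding rows directly from the block-quantized
# lm_head weight, without materializing the full dequantized vocabulary.
# Initializer names follow the make_dequantize_linear convention above but are illustrative.
gather_node = helper.make_node(
    "GatherBlockQuantized",
    inputs=["model.lm_head.qweight", "input_ids", "model.lm_head.scales", "model.lm_head.qzeros"],
    outputs=["inputs_embeds"],
    name="/model/embed_tokens/GatherBlockQuantized",
    domain="com.microsoft",  # contrib op, so it lives in the Microsoft domain
    gather_axis=0,           # gather vocabulary rows by token id
    quantize_axis=1,         # assumes layout [vocab_size, hidden_size], quantized along hidden
    block_size=32,           # assumption: must match the MatMulNBits group size
)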

@jiafatom force-pushed the tie branch 3 times, most recently from 207b13a to 4fefaad on May 9, 2025 02:58
# Input A and the scale have the same type, but the scale is stored in external data,
# so we can only get the type from A here.
scale_value_type = get_tensor_type_from_graph(graph, matmul_node.input[0])
if scale_value_type:
scale_value_type = scale_value_type.elem_type

Check notice (Code scanning / CodeQL): Unused local variable. Variable scale_value_type is not used.

import onnx
import numpy as np
from onnx import helper, numpy_helper, version_converter

Check notice (Code scanning / CodeQL): Unused import. Import of 'version_converter' is not used.