4 changes: 2 additions & 2 deletions modelopt/torch/export/unified_export_hf.py
@@ -536,8 +536,8 @@ def _export_hf_checkpoint(
quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"],
)
# Export the quantized weights
for weight_name in ["gate_up_proj", "down_proj"]:
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
for weight_name in ["gate_up_proj", "down_proj"]:
_export_quantized_weight(sub_module, dtype, weight_name)

if accelerator is not None:
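The reordering above hoists `fsdp2_aware_weight_update` out of the per-weight loop, so the FSDP2 weight-update bookkeeping for `sub_module` runs once while both `gate_up_proj` and `down_proj` are exported inside it. Below is a minimal, self-contained sketch of that pattern only; `weight_update_context` and the print calls are stand-ins, not the project's helpers.

```python
from contextlib import contextmanager

@contextmanager
def weight_update_context(module_name):
    # Stand-in for fsdp2_aware_weight_update: prepare the module's weights once on
    # entry (e.g. unshard) and finalize once on exit.
    print(f"prepare {module_name}")
    yield
    print(f"finalize {module_name}")

# One enter/exit pair now covers both weight exports.
with weight_update_context("mlp"):
    for weight_name in ["gate_up_proj", "down_proj"]:
        print(f"export {weight_name}")  # stand-in for _export_quantized_weight(...)
```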
100 changes: 56 additions & 44 deletions modelopt/torch/quantization/utils.py
@@ -577,7 +577,7 @@ def get_prefixed_param_names(parent_model, target_module):


def create_fsdp_param_mapping(fsdp_param_list, model):
"""Builds a mapping from module name to their corresponding FSDPParam.
"""Builds a mapping from full parameter name to their corresponding FSDPParam.

Args:
fsdp_param_list (list): List of FSDPParam.
@@ -586,10 +586,16 @@ def create_fsdp_param_mapping(fsdp_param_list, model):
Returns:
dict: Full parameter name → FSDP parameter.
"""
return {
get_prefixed_param_names(model, param._module_info.module): param
for param in fsdp_param_list
}
mapping = {}
for param in fsdp_param_list:
# Get the module name
module_name = get_prefixed_param_names(model, param._module_info.module)
if module_name is not None:
# Get the parameter name from _module_info and construct full param name
param_name = param._module_info.param_name
full_param_name = f"{module_name}.{param_name}"
mapping[full_param_name] = param
return mapping


@contextmanager
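The rewritten `create_fsdp_param_mapping` keys the mapping on full parameter names (`<module>.<param>`) instead of module names, so a module whose weight and bias are both FSDP-managed gets one entry per parameter. A small, self-contained sketch of that naming scheme in plain PyTorch (no FSDP; the `q_proj` name is illustrative):

```python
import torch

# One module with both a weight and a bias parameter.
model = torch.nn.ModuleDict({"q_proj": torch.nn.Linear(4, 4, bias=True)})

# Keying on full parameter names gives weight and bias separate entries,
# mirroring the keys the rewritten create_fsdp_param_mapping produces.
mapping = {name: param for name, param in model.named_parameters()}
print(sorted(mapping))  # ['q_proj.bias', 'q_proj.weight']
```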
@@ -706,9 +712,15 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
# Assert that all the modules in the module list are present in this fsdp_param_group
if len(modules_to_update) > 1:
for module in modules_to_update:
name = _get_module_name(module, root_model)
assert name in fsdp_param_mapping, (
f"Module {module} not found in fsdp_param_mapping"
module_name = _get_module_name(module, root_model)
# Check if any parameter from this module is in the mapping
module_params_in_mapping = any(
f"{module_name}.{n}" in fsdp_param_mapping
for n, _ in module.named_parameters()
)
assert module_params_in_mapping, (
f"Module {module} with name '{module_name}' not found in fsdp_param_mapping. "
f"Available keys: {list(fsdp_param_mapping.keys())}"
)
# Yields for necessary weight updates/processing
yield
@@ -718,44 +730,44 @@
if isinstance(root_model, FSDPModule):
# Update FSDPParam list
for module in modules_to_update:
name = _get_module_name(module, root_model)
if name not in fsdp_param_mapping:
continue

old_fsdp_param = fsdp_param_mapping[name]

# Update mp policy to reflect the new dtype
new_mp_policy = MixedPrecisionPolicy(
param_dtype=module.weight.dtype,
reduce_dtype=None,
output_dtype=None,
cast_forward_inputs=False,
)

with no_requires_grad():
# Create a new QFSDPParam or FSDPParam based on weight type
param_class = (
QFSDPParam if isinstance(module.weight, QTensorWrapper) else FSDPParam
)

new_param = param_class(
module.weight,
old_fsdp_param._module_info,
old_fsdp_param.mesh_info,
old_fsdp_param.post_forward_mesh_info,
old_fsdp_param.device,
None,
new_mp_policy,
None,
for n, p in module.named_parameters():
name = _get_module_name(module, root_model)
name = f"{name}.{n}"
if name not in fsdp_param_mapping:
continue

old_fsdp_param = fsdp_param_mapping[name]

# Update mp policy to reflect the new dtype
new_mp_policy = MixedPrecisionPolicy(
param_dtype=p.dtype,
reduce_dtype=None,
output_dtype=None,
cast_forward_inputs=False,
)
if not isinstance(new_param, QFSDPParam):
new_param.init_dtype_attrs(new_mp_policy)

# Update the FSDPParam mapping to keep track of the new FSDPParam
fsdp_param_mapping[name] = new_param

# Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
old_fsdp_param._post_load_hook_handle.remove()
with no_requires_grad(), enable_fake_quant(module):
# Create a new QFSDPParam or FSDPParam based on weight type
param_class = QFSDPParam if isinstance(p, QTensorWrapper) else FSDPParam

new_param = param_class(
p,
old_fsdp_param._module_info,
old_fsdp_param.mesh_info,
old_fsdp_param.post_forward_mesh_info,
old_fsdp_param.device,
None,
new_mp_policy,
None,
)
if not isinstance(new_param, QFSDPParam):
new_param.init_dtype_attrs(new_mp_policy)

# Update the FSDPParam mapping to keep track of the new FSDPParam
fsdp_param_mapping[name] = new_param

# Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
old_fsdp_param._post_load_hook_handle.remove()

# Update FSDPParam list with new compressed weights
fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
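In `fsdp2_aware_weight_update`, the rebuild loop now walks every parameter of each updated module, looks up its old entry by full name, and replaces it with a freshly constructed `FSDPParam`/`QFSDPParam` before `fsdp_param_group.fsdp_params` is refreshed from the mapping's values. A minimal sketch of that per-parameter bookkeeping with the FSDP internals mocked out; `refresh_entries` and `rebuild` are hypothetical helpers, not project API.

```python
import torch

def refresh_entries(module_name, module, mapping, rebuild):
    # Overwrite each parameter's mapping entry with a rebuilt one; parameters that
    # are not FSDP-managed (i.e. not in the mapping) are skipped.
    for param_name, param in module.named_parameters():
        full_name = f"{module_name}.{param_name}"
        if full_name in mapping:
            mapping[full_name] = rebuild(param, mapping[full_name])

layer = torch.nn.Linear(4, 4, bias=True)
mapping = {"q_proj.weight": "old weight entry", "q_proj.bias": "old bias entry"}
refresh_entries("q_proj", layer, mapping, rebuild=lambda p, old: f"rebuilt({old})")
print(mapping)  # both the weight and the bias entries were replaced
```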
12 changes: 6 additions & 6 deletions tests/_test_utils/torch/export/utils.py
@@ -55,18 +55,18 @@ def forward(self, x):


class SmallQKVModel(torch.nn.Module):
def __init__(self, dim=4, device="cuda", apply_embed=False):
def __init__(self, dim=4, device="cuda", apply_embed=False, bias=False):
super().__init__()
self.embedding = torch.nn.Embedding(2, dim)
self.q_proj = torch.nn.Linear(dim, dim, bias=False)
self.k_proj = torch.nn.Linear(dim, dim, bias=False)
self.v_proj = torch.nn.Linear(dim, dim, bias=False)
self.o_proj = torch.nn.Linear(dim, dim, bias=False)
self.q_proj = torch.nn.Linear(dim, dim, bias=bias)
self.k_proj = torch.nn.Linear(dim, dim, bias=bias)
self.v_proj = torch.nn.Linear(dim, dim, bias=bias)
self.o_proj = torch.nn.Linear(dim, dim, bias=bias)
self.device = device
self.config = None
self.apply_embed = apply_embed
# TODO: Debug why fsdp2 modifies bias of layernorm for awq
self.input_layernorm = torch.nn.LayerNorm(dim, bias=False)
self.input_layernorm = torch.nn.LayerNorm(dim, bias=bias)

def forward(self, x):
if self.apply_embed:
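With the new `bias` flag, `SmallQKVModel` can exercise the FSDP2 export paths both with and without bias parameters on the projections and the layernorm. A usage sketch; the import path is assumed from the tests below and the `dim` value is arbitrary.

```python
from _test_utils.torch.export.utils import SmallQKVModel  # import path assumed

model = SmallQKVModel(dim=8, device="cpu", bias=True)
assert model.q_proj.bias is not None            # projections now carry a bias
assert model.input_layernorm.bias is not None   # the layernorm bias follows the same flag
```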
22 changes: 12 additions & 10 deletions tests/gpu/torch/export/test_fsdp2_export.py
@@ -118,11 +118,11 @@ def _compare_parameters_and_buffers(model1, model2):
)


def _fuse_layers(rank, size, quant_config):
def _fuse_layers(rank, size, quant_config, bias):
with patch_fsdp_mp_dtypes():
# Initialize model
model = SmallQKVModel(dim=32).to("cuda")
non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
model = SmallQKVModel(dim=32, bias=bias).to("cuda")
non_fsdp_model = SmallQKVModel(dim=32, bias=bias).to("cuda")
non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
model.eval()
non_fsdp_model.eval()
@@ -159,15 +159,15 @@ def calib_fn(x):
_compare_parameters_and_buffers(model, non_fsdp_model)


def _export_quantized_weight_test(rank, size, quant_config):
def _export_quantized_weight_test(rank, size, quant_config, bias):
import copy

from torch.distributed._composable.fsdp import fully_shard

with patch_fsdp_mp_dtypes():
# Initialize model
model = SmallQKVModel(dim=32).to("cuda")
non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
model = SmallQKVModel(dim=32, bias=bias).to("cuda")
non_fsdp_model = SmallQKVModel(dim=32, bias=bias).to("cuda")
non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
model.eval()
non_fsdp_model.eval()
@@ -247,10 +247,11 @@ def test_fsdp2_weight_update_context_for_export(device_count):
],
)
@pytest.mark.parametrize("device_count", get_device_counts())
def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config):
@pytest.mark.parametrize("bias", [True, False])
def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config, bias):
spawn_multiprocess_job(
size=device_count,
job=partial(_fuse_layers, quant_config=quant_config),
job=partial(_fuse_layers, quant_config=quant_config, bias=bias),
backend="nccl",
)

@@ -270,9 +270,10 @@ def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config)
],
)
@pytest.mark.parametrize("device_count", get_device_counts())
def test_fsdp2_weight_update_context_for_export_quantized_weight(device_count, quant_config):
@pytest.mark.parametrize("bias", [True, False])
def test_fsdp2_weight_update_context_for_export_quantized_weight(device_count, quant_config, bias):
spawn_multiprocess_job(
size=device_count,
job=partial(_export_quantized_weight_test, quant_config=quant_config),
job=partial(_export_quantized_weight_test, quant_config=quant_config, bias=bias),
backend="nccl",
)