diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index a98e455db..ccfc01200 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -536,8 +536,8 @@ def _export_hf_checkpoint(
                 quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"],
             )
             # Export the quantized weights
-            for weight_name in ["gate_up_proj", "down_proj"]:
-                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+            with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+                for weight_name in ["gate_up_proj", "down_proj"]:
                     _export_quantized_weight(sub_module, dtype, weight_name)

     if accelerator is not None:
diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py
index c8be1d014..97f287723 100644
--- a/modelopt/torch/quantization/utils.py
+++ b/modelopt/torch/quantization/utils.py
@@ -577,7 +577,7 @@ def get_prefixed_param_names(parent_model, target_module):


 def create_fsdp_param_mapping(fsdp_param_list, model):
-    """Builds a mapping from module name to their corresponding FSDPParam.
+    """Builds a mapping from full parameter name to its corresponding FSDPParam.

     Args:
         fsdp_param_list (list): List of FSDPParam.
@@ -586,10 +586,16 @@ def create_fsdp_param_mapping(fsdp_param_list, model):
     Returns:
         dict: Full parameter name → FSDP parameter.
     """
-    return {
-        get_prefixed_param_names(model, param._module_info.module): param
-        for param in fsdp_param_list
-    }
+    mapping = {}
+    for param in fsdp_param_list:
+        # Get the module name
+        module_name = get_prefixed_param_names(model, param._module_info.module)
+        if module_name is not None:
+            # Get the parameter name from _module_info and construct the full param name
+            param_name = param._module_info.param_name
+            full_param_name = f"{module_name}.{param_name}"
+            mapping[full_param_name] = param
+    return mapping


 @contextmanager
@@ -706,9 +712,15 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
             # Assert that all the modules in the module list are present in this fsdp_param_group
             if len(modules_to_update) > 1:
                 for module in modules_to_update:
-                    name = _get_module_name(module, root_model)
-                    assert name in fsdp_param_mapping, (
-                        f"Module {module} not found in fsdp_param_mapping"
+                    module_name = _get_module_name(module, root_model)
+                    # Check if any parameter from this module is in the mapping
+                    module_params_in_mapping = any(
+                        f"{module_name}.{n}" in fsdp_param_mapping
+                        for n, _ in module.named_parameters()
+                    )
+                    assert module_params_in_mapping, (
+                        f"Module {module} with name '{module_name}' not found in fsdp_param_mapping. "
+                        f"Available keys: {list(fsdp_param_mapping.keys())}"
                     )
         # Yields for necessary weight updates/processing
         yield
@@ -718,44 +730,44 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
         if isinstance(root_model, FSDPModule):
             # Update FSDPParam list
             for module in modules_to_update:
-                name = _get_module_name(module, root_model)
-                if name not in fsdp_param_mapping:
-                    continue
-
-                old_fsdp_param = fsdp_param_mapping[name]
-
-                # Update mp policy to reflect the new dtype
-                new_mp_policy = MixedPrecisionPolicy(
-                    param_dtype=module.weight.dtype,
-                    reduce_dtype=None,
-                    output_dtype=None,
-                    cast_forward_inputs=False,
-                )
-
-                with no_requires_grad():
-                    # Create a new QFSDPParam or FSDPParam based on weight type
-                    param_class = (
-                        QFSDPParam if isinstance(module.weight, QTensorWrapper) else FSDPParam
-                    )
-
-                    new_param = param_class(
-                        module.weight,
-                        old_fsdp_param._module_info,
-                        old_fsdp_param.mesh_info,
-                        old_fsdp_param.post_forward_mesh_info,
-                        old_fsdp_param.device,
-                        None,
-                        new_mp_policy,
-                        None,
+                for param_name, param in module.named_parameters():
+                    name = _get_module_name(module, root_model)
+                    name = f"{name}.{param_name}"
+                    if name not in fsdp_param_mapping:
+                        continue
+
+                    old_fsdp_param = fsdp_param_mapping[name]
+
+                    # Update mp policy to reflect the new dtype
+                    new_mp_policy = MixedPrecisionPolicy(
+                        param_dtype=param.dtype,
+                        reduce_dtype=None,
+                        output_dtype=None,
+                        cast_forward_inputs=False,
                     )
-                    if not isinstance(new_param, QFSDPParam):
-                        new_param.init_dtype_attrs(new_mp_policy)
-
-                    # Update the FSDPParam mapping to keep track of the new FSDPParam
-                    fsdp_param_mapping[name] = new_param
-
-                # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
-                old_fsdp_param._post_load_hook_handle.remove()
+
+                    with no_requires_grad(), enable_fake_quant(module):
+                        # Create a new QFSDPParam or FSDPParam based on weight type
+                        param_class = QFSDPParam if isinstance(param, QTensorWrapper) else FSDPParam
+
+                        new_param = param_class(
+                            param,
+                            old_fsdp_param._module_info,
+                            old_fsdp_param.mesh_info,
+                            old_fsdp_param.post_forward_mesh_info,
+                            old_fsdp_param.device,
+                            None,
+                            new_mp_policy,
+                            None,
+                        )
+                        if not isinstance(new_param, QFSDPParam):
+                            new_param.init_dtype_attrs(new_mp_policy)
+
+                        # Update the FSDPParam mapping to keep track of the new FSDPParam
+                        fsdp_param_mapping[name] = new_param
+
+                        # Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
+                        old_fsdp_param._post_load_hook_handle.remove()

             # Update FSDPParam list with new compressed weights
             fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
diff --git a/tests/_test_utils/torch/export/utils.py b/tests/_test_utils/torch/export/utils.py
index 8d2d88608..8d4bf2032 100644
--- a/tests/_test_utils/torch/export/utils.py
+++ b/tests/_test_utils/torch/export/utils.py
@@ -55,18 +55,18 @@ def forward(self, x):


 class SmallQKVModel(torch.nn.Module):
-    def __init__(self, dim=4, device="cuda", apply_embed=False):
+    def __init__(self, dim=4, device="cuda", apply_embed=False, bias=False):
         super().__init__()
         self.embedding = torch.nn.Embedding(2, dim)
-        self.q_proj = torch.nn.Linear(dim, dim, bias=False)
-        self.k_proj = torch.nn.Linear(dim, dim, bias=False)
-        self.v_proj = torch.nn.Linear(dim, dim, bias=False)
-        self.o_proj = torch.nn.Linear(dim, dim, bias=False)
+        self.q_proj = torch.nn.Linear(dim, dim, bias=bias)
+        self.k_proj = torch.nn.Linear(dim, dim, bias=bias)
+        self.v_proj = torch.nn.Linear(dim, dim, bias=bias)
+        self.o_proj = torch.nn.Linear(dim, dim, bias=bias)
         self.device = device
         self.config = None
         self.apply_embed = apply_embed
         # TODO: Debug why fsdp2 modifies bias of layernorm for awq
-        self.input_layernorm = torch.nn.LayerNorm(dim, bias=False)
+        self.input_layernorm = torch.nn.LayerNorm(dim, bias=bias)

     def forward(self, x):
         if self.apply_embed:
diff --git a/tests/gpu/torch/export/test_fsdp2_export.py b/tests/gpu/torch/export/test_fsdp2_export.py
index 0c3496dec..524444eee 100644
--- a/tests/gpu/torch/export/test_fsdp2_export.py
+++ b/tests/gpu/torch/export/test_fsdp2_export.py
@@ -118,11 +118,11 @@ def _compare_parameters_and_buffers(model1, model2):
     )


-def _fuse_layers(rank, size, quant_config):
+def _fuse_layers(rank, size, quant_config, bias):
     with patch_fsdp_mp_dtypes():
         # Initialize model
-        model = SmallQKVModel(dim=32).to("cuda")
-        non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
+        model = SmallQKVModel(dim=32, bias=bias).to("cuda")
+        non_fsdp_model = SmallQKVModel(dim=32, bias=bias).to("cuda")
         non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
         model.eval()
         non_fsdp_model.eval()
@@ -159,15 +159,15 @@ def calib_fn(x):
     _compare_parameters_and_buffers(model, non_fsdp_model)


-def _export_quantized_weight_test(rank, size, quant_config):
+def _export_quantized_weight_test(rank, size, quant_config, bias):
     import copy

     from torch.distributed._composable.fsdp import fully_shard

     with patch_fsdp_mp_dtypes():
         # Initialize model
-        model = SmallQKVModel(dim=32).to("cuda")
-        non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
+        model = SmallQKVModel(dim=32, bias=bias).to("cuda")
+        non_fsdp_model = SmallQKVModel(dim=32, bias=bias).to("cuda")
         non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
         model.eval()
         non_fsdp_model.eval()
@@ -247,10 +247,11 @@ def test_fsdp2_weight_update_context_for_export(device_count):
     ],
 )
 @pytest.mark.parametrize("device_count", get_device_counts())
-def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config):
+@pytest.mark.parametrize("bias", [True, False])
+def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config, bias):
     spawn_multiprocess_job(
         size=device_count,
-        job=partial(_fuse_layers, quant_config=quant_config),
+        job=partial(_fuse_layers, quant_config=quant_config, bias=bias),
         backend="nccl",
     )

@@ -270,9 +271,10 @@ def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config)
     ],
 )
 @pytest.mark.parametrize("device_count", get_device_counts())
-def test_fsdp2_weight_update_context_for_export_quantized_weight(device_count, quant_config):
+@pytest.mark.parametrize("bias", [True, False])
+def test_fsdp2_weight_update_context_for_export_quantized_weight(device_count, quant_config, bias):
     spawn_multiprocess_job(
         size=device_count,
-        job=partial(_export_quantized_weight_test, quant_config=quant_config),
+        job=partial(_export_quantized_weight_test, quant_config=quant_config, bias=bias),
         backend="nccl",
     )
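For context on the change above: create_fsdp_param_mapping now keys its result by full parameter name (module name plus parameter name) rather than by module name alone, which is what lets fsdp2_aware_weight_update rebuild an FSDPParam for every parameter of a module, biases included, and why the tests are now parametrized over bias. A minimal sketch of the new key format, using a plain nn.Module and named_parameters() in place of real FSDPParam objects (the "q_proj" name is made up for illustration):

import torch

# Toy stand-in for a module covered by the mapping; "q_proj" is a hypothetical name.
model = torch.nn.Sequential()
model.add_module("q_proj", torch.nn.Linear(4, 4, bias=True))

# Old keying (conceptual): one entry per module, e.g. "q_proj".
# New keying: one entry per parameter, mirroring f"{module_name}.{param_name}"
# in create_fsdp_param_mapping, so weight and bias get separate entries.
mapping = dict(model.named_parameters())
print(list(mapping.keys()))  # ['q_proj.weight', 'q_proj.bias']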