4 changes: 2 additions & 2 deletions modelopt/torch/export/unified_export_hf.py
@@ -536,8 +536,8 @@ def _export_hf_checkpoint(
quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"],
)
# Export the quantized weights
for weight_name in ["gate_up_proj", "down_proj"]:
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
for weight_name in ["gate_up_proj", "down_proj"]:
_export_quantized_weight(sub_module, dtype, weight_name)

if accelerator is not None:
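For readers skimming the diff: the hunk above hoists the fsdp2_aware_weight_update context out of the per-weight loop, so the FSDP bookkeeping for sub_module is done once and both MoE projections are exported under the same context entry. A minimal sketch of that control-flow change, using a hypothetical stand-in context manager rather than the repo's helper:

# Illustrative only: tracked_update is a hypothetical stand-in for
# modelopt's fsdp2_aware_weight_update; it just logs enter/exit so the
# before/after ordering of the hunk above is visible.
from contextlib import contextmanager

@contextmanager
def tracked_update(tag):
    print(f"enter weight-update context for {tag}")      # e.g. unshard / collect FSDPParams
    try:
        yield
    finally:
        print(f"exit weight-update context for {tag}")   # e.g. rebuild FSDPParam list, reshard

# After the change: one context entry covers both exported weights.
with tracked_update("sub_module"):
    for weight_name in ["gate_up_proj", "down_proj"]:
        print(f"export quantized weight: {weight_name}")  # _export_quantized_weight(...)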
100 changes: 56 additions & 44 deletions modelopt/torch/quantization/utils.py
@@ -577,7 +577,7 @@ def get_prefixed_param_names(parent_model, target_module):


def create_fsdp_param_mapping(fsdp_param_list, model):
"""Builds a mapping from module name to their corresponding FSDPParam.
"""Builds a mapping from full parameter name to their corresponding FSDPParam.

Args:
fsdp_param_list (list): List of FSDPParam.
@@ -586,10 +586,16 @@ def create_fsdp_param_mapping(fsdp_param_list, model):
Returns:
dict: Full parameter name → FSDP parameter.
"""
return {
get_prefixed_param_names(model, param._module_info.module): param
for param in fsdp_param_list
}
mapping = {}
for param in fsdp_param_list:
# Get the module name
module_name = get_prefixed_param_names(model, param._module_info.module)
if module_name is not None:
# Get the parameter name from _module_info and construct full param name
param_name = param._module_info.param_name
full_param_name = f"{module_name}.{param_name}"
mapping[full_param_name] = param
return mapping


@contextmanager
@@ -706,9 +712,15 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
# Assert that all the modules in the module list are present in this fsdp_param_group
if len(modules_to_update) > 1:
for module in modules_to_update:
name = _get_module_name(module, root_model)
assert name in fsdp_param_mapping, (
f"Module {module} not found in fsdp_param_mapping"
module_name = _get_module_name(module, root_model)
# Check if any parameter from this module is in the mapping
module_params_in_mapping = any(
f"{module_name}.{n}" in fsdp_param_mapping
for n, _ in module.named_parameters()
)
assert module_params_in_mapping, (
f"Module {module} with name '{module_name}' not found in fsdp_param_mapping. "
f"Available keys: {list(fsdp_param_mapping.keys())}"
)
# Yields for necessary weight updates/processing
yield
@@ -718,44 +730,44 @@
if isinstance(root_model, FSDPModule):
# Update FSDPParam list
for module in modules_to_update:
name = _get_module_name(module, root_model)
if name not in fsdp_param_mapping:
continue

old_fsdp_param = fsdp_param_mapping[name]

# Update mp policy to reflect the new dtype
new_mp_policy = MixedPrecisionPolicy(
param_dtype=module.weight.dtype,
reduce_dtype=None,
output_dtype=None,
cast_forward_inputs=False,
)

with no_requires_grad():
# Create a new QFSDPParam or FSDPParam based on weight type
param_class = (
QFSDPParam if isinstance(module.weight, QTensorWrapper) else FSDPParam
)

new_param = param_class(
module.weight,
old_fsdp_param._module_info,
old_fsdp_param.mesh_info,
old_fsdp_param.post_forward_mesh_info,
old_fsdp_param.device,
None,
new_mp_policy,
None,
for param_name, param in module.named_parameters():
name = _get_module_name(module, root_model)
name = f"{name}.{param_name}"
if name not in fsdp_param_mapping:
continue

old_fsdp_param = fsdp_param_mapping[name]

# Update mp policy to reflect the new dtype
new_mp_policy = MixedPrecisionPolicy(
param_dtype=param.dtype,
reduce_dtype=None,
output_dtype=None,
cast_forward_inputs=False,
)
if not isinstance(new_param, QFSDPParam):
new_param.init_dtype_attrs(new_mp_policy)

# Update the FSDPParam mapping to keep track of the new FSDPParam
fsdp_param_mapping[name] = new_param

# Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
old_fsdp_param._post_load_hook_handle.remove()
with no_requires_grad(), enable_fake_quant(module):
# Create a new QFSDPParam or FSDPParam based on weight type
param_class = QFSDPParam if isinstance(param, QTensorWrapper) else FSDPParam

new_param = param_class(
param,
old_fsdp_param._module_info,
old_fsdp_param.mesh_info,
old_fsdp_param.post_forward_mesh_info,
old_fsdp_param.device,
None,
new_mp_policy,
None,
)
if not isinstance(new_param, QFSDPParam):
new_param.init_dtype_attrs(new_mp_policy)

# Update the FSDPParam mapping to keep track of the new FSDPParam
fsdp_param_mapping[name] = new_param

# Remove the post_load_hook_handle to allow gc to collect the old FSDPParam
old_fsdp_param._post_load_hook_handle.remove()

# Update FSDPParam list with new compressed weights
fsdp_param_group.fsdp_params = list(fsdp_param_mapping.values())
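The net effect of the create_fsdp_param_mapping change is that the mapping is now keyed by full parameter name ("module.param") rather than by module name, so a module contributes one entry per parameter (weight and bias) instead of a single entry. A minimal sketch of that key format on a plain, non-sharded module; the real function maps these keys to FSDPParam objects, which are omitted here:

# Sketch of the new key format only; a plain nn.Module stands in for the
# FSDP-wrapped model and the values are the tensors themselves, not FSDPParams.
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4, bias=True))

mapping = {}
for module_name, module in model.named_modules():
    if module_name == "":  # skip the root container
        continue
    for param_name, param in module.named_parameters(recurse=False):
        mapping[f"{module_name}.{param_name}"] = param  # "0.weight", "0.bias"

print(sorted(mapping))  # ['0.bias', '0.weight'] -- one entry per parameter, not per module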
12 changes: 6 additions & 6 deletions tests/_test_utils/torch/export/utils.py
@@ -55,18 +55,18 @@ def forward(self, x):


class SmallQKVModel(torch.nn.Module):
def __init__(self, dim=4, device="cuda", apply_embed=False):
def __init__(self, dim=4, device="cuda", apply_embed=False, bias=False):
super().__init__()
self.embedding = torch.nn.Embedding(2, dim)
self.q_proj = torch.nn.Linear(dim, dim, bias=False)
self.k_proj = torch.nn.Linear(dim, dim, bias=False)
self.v_proj = torch.nn.Linear(dim, dim, bias=False)
self.o_proj = torch.nn.Linear(dim, dim, bias=False)
self.q_proj = torch.nn.Linear(dim, dim, bias=bias)
self.k_proj = torch.nn.Linear(dim, dim, bias=bias)
self.v_proj = torch.nn.Linear(dim, dim, bias=bias)
self.o_proj = torch.nn.Linear(dim, dim, bias=bias)
self.device = device
self.config = None
self.apply_embed = apply_embed
# TODO: Debug why fsdp2 modifies bias of layernorm for awq
self.input_layernorm = torch.nn.LayerNorm(dim, bias=False)
self.input_layernorm = torch.nn.LayerNorm(dim, bias=bias)

def forward(self, x):
if self.apply_embed:
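With the new bias flag, SmallQKVModel(..., bias=True) grows a .bias parameter on every projection and on the LayerNorm, which is exactly the per-parameter case the mapping change above has to handle. A quick way to see that, assuming the tests/_test_utils directory is on sys.path as it is when running the repo's test suite:

# The expected output is shown as a comment; adjust the import path if running
# outside the test suite.
from _test_utils.torch.export.utils import SmallQKVModel

model = SmallQKVModel(dim=4, device="cpu", bias=True)
print([n for n, _ in model.named_parameters() if n.endswith(".bias")])
# ['q_proj.bias', 'k_proj.bias', 'v_proj.bias', 'o_proj.bias', 'input_layernorm.bias']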
22 changes: 12 additions & 10 deletions tests/gpu/torch/export/test_fsdp2_export.py
@@ -118,11 +118,11 @@ def _compare_parameters_and_buffers(model1, model2):
)


def _fuse_layers(rank, size, quant_config):
def _fuse_layers(rank, size, quant_config, bias):
with patch_fsdp_mp_dtypes():
# Initialize model
model = SmallQKVModel(dim=32).to("cuda")
non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
model = SmallQKVModel(dim=32, bias=bias).to("cuda")
non_fsdp_model = SmallQKVModel(dim=32, bias=bias).to("cuda")
non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
model.eval()
non_fsdp_model.eval()
@@ -159,15 +159,15 @@ def calib_fn(x):
_compare_parameters_and_buffers(model, non_fsdp_model)


def _export_quantized_weight_test(rank, size, quant_config):
def _export_quantized_weight_test(rank, size, quant_config, bias):
import copy

from torch.distributed._composable.fsdp import fully_shard

with patch_fsdp_mp_dtypes():
# Initialize model
model = SmallQKVModel(dim=32).to("cuda")
non_fsdp_model = SmallQKVModel(dim=32).to("cuda")
model = SmallQKVModel(dim=32, bias=bias).to("cuda")
non_fsdp_model = SmallQKVModel(dim=32, bias=bias).to("cuda")
non_fsdp_model.load_state_dict(copy.deepcopy(model.state_dict()))
model.eval()
non_fsdp_model.eval()
@@ -247,10 +247,11 @@ def test_fsdp2_weight_update_context_for_export(device_count):
],
)
@pytest.mark.parametrize("device_count", get_device_counts())
def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config):
@pytest.mark.parametrize("bias", [True, False])
def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config, bias):
spawn_multiprocess_job(
size=device_count,
job=partial(_fuse_layers, quant_config=quant_config),
job=partial(_fuse_layers, quant_config=quant_config, bias=bias),
backend="nccl",
)

@@ -270,9 +270,10 @@ def test_fsdp2_weight_update_context_for_fuse_layers(device_count, quant_config)
],
)
@pytest.mark.parametrize("device_count", get_device_counts())
def test_fsdp2_weight_update_context_for_export_quantized_weight(device_count, quant_config):
@pytest.mark.parametrize("bias", [True, False])
def test_fsdp2_weight_update_context_for_export_quantized_weight(device_count, quant_config, bias):
spawn_multiprocess_job(
size=device_count,
job=partial(_export_quantized_weight_test, quant_config=quant_config),
job=partial(_export_quantized_weight_test, quant_config=quant_config, bias=bias),
backend="nccl",
)