Skip to content

Commit faaf9f4

Browse files
jingyu-mlclaude
authored andcommitted
Skip ComfyUI safetensors post-processing unless opted in (fix sharded FLUX export) (#1794)
### What does this PR do? **Type of change:** Bug fix Fixes diffusers HuggingFace export (`export_hf_checkpoint`) failing with `NotImplementedError: Post-processing sharded safetensors is not supported` for large quantized diffusion models (e.g. the FP8 FLUX.1-dev transformer, ~12 GB, which exceeds the default 10 GB `max_shard_size` and is split into multiple `.safetensors` shards). **Root cause:** `_postprocess_safetensors` runs for every quantized diffusers component. It exists only to build single-file deployment checkpoints (e.g. ComfyUI) — merging with a base checkpoint, NVFP4 padding/swizzling, and embedding quant metadata in the safetensors header — none of which ModelOpt reads back (the diffusers reload uses `config.json`). But it embedded the header metadata **by default** (`enable_layerwise_quant_metadata=True`), and that default-on path hard-raised for any sharded checkpoint, so a plain `--format fp8` FLUX export failed at export time. (Introduced in #1195 / #911; not a transformers/diffusers API change.) **Fix — make the post-processing opt-in:** `_postprocess_safetensors` now returns immediately (no-op) unless the caller opts into one of `merged_base_safetensor_path` / `padding_strategy` / `enable_swizzle_layout` / `enable_layerwise_quant_metadata` (the last default flipped `True → False`). A plain quantized diffusers export — including a sharded one — is left untouched and no longer fails; `config.json` still carries `quantization_config` for the diffusers-native reload. Behavior for callers that opt in is unchanged (sharded + merge/metadata remains unsupported, which is out of scope for this fix). ### Usage ```bash # Default export — now succeeds (no ComfyUI post-processing): python examples/diffusers/quantization/quantize.py --model flux-dev --format fp8 \ --quantized-torch-ckpt-save-path flux-dev-fp8.pt --hf-ckpt-dir flux-dev-fp8 \ --collect-method default --calib-size 128 --quantize-mha \ --model-dtype BFloat16 --trt-high-precision-dtype BFloat16 # Opt into the ComfyUI single-file header metadata: # ... --extra-param enable_layerwise_quant_metadata=true ``` ### Testing - CPU unit tests (`tests/unit/torch/export/test_export_diffusers.py`): - `test_postprocess_default_is_noop` — default (incl. a sharded checkpoint) injects nothing and does not raise. - `test_postprocess_single_file_metadata_when_opted_in` — opt-in single-file injection still works. - `test_postprocess_sharded_opt_in_raises` — an explicit opt-in on a sharded checkpoint still raises (documents the existing, out-of-scope limitation). - GPU test (`tests/gpu/torch/export/test_export_diffusers.py::test_export_diffusers_sharded_default_no_header_metadata`) — real FP8 tiny-FLUX, forced sharding, default → succeeds with a clean header. - Verified on a GB200 in the dev container: 10/10 unit + 11/11 GPU diffusers export tests pass; the regression test fails on the original code (`NotImplementedError`) and passes with the fix. `ruff check` / `ruff format` clean. ### Before your PR is "*Ready for review*" - Is this change backward compatible?: ✅ for the diffusers-native reload (quant config still in `config.json`). ⚠️ Behavior change: the ComfyUI **safetensors-header** metadata is no longer written for a plain export — it is now opt-in (`enable_layerwise_quant_metadata=true`, or automatically when using merge/swizzle/padding). - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A - Did you write any new necessary tests?: ✅ - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: N/A — scoped as a minor diffusers export-path fix; no changelog entry - Did you get Claude approval on this PR?: ❌ — will run `/claude review`. ### Additional Information Reported against ModelOpt 0.45.0rc0 / TRT-LLM 1.3.0rc17 with FLUX.1-dev FP8 export on RTX 6000 Ada. The safetensors post-processing was added in #1195 (ComfyUI single-file origin in #911). <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Changes** * Quantization metadata injection in safetensors exports is now opt-in rather than automatic, with the default behavior changed to exclude metadata. * Sharded FP8 model exports no longer include quantization header metadata by default. * **Tests** * Added regression and unit tests validating quantization metadata opt-in behavior for safetensors exports. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jingyu Xin <jingyux@nvidia.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent bb53066 commit faaf9f4

3 files changed

Lines changed: 172 additions & 7 deletions

File tree

modelopt/torch/export/unified_export_hf.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ def _postprocess_safetensors(
168168
``_quantization_metadata`` so inference runtimes can detect and handle
169169
quantized layers.
170170
171+
All of these target single-file deployment runtimes (e.g. ComfyUI) and are
172+
opt-in; ModelOpt itself reads the quant config from ``config.json`` on reload. If
173+
the caller passes none of ``merged_base_safetensor_path``, ``padding_strategy``,
174+
``enable_swizzle_layout``, or ``enable_layerwise_quant_metadata``, this function
175+
does nothing and leaves the standard exported checkpoint untouched.
176+
171177
Args:
172178
export_dir: Directory containing the saved ``.safetensors`` file(s).
173179
pipe: The diffusion pipeline / model. Used to infer the model type
@@ -181,11 +187,11 @@ def _postprocess_safetensors(
181187
file to produce a single-file checkpoint compatible with ComfyUI.
182188
Value should be the path to a full base model ``.safetensors``
183189
file (e.g. ``"path/to/ltx-2-19b-dev.safetensors"``).
184-
enable_layerwise_quant_metadata (bool, optional): When True
185-
(default), includes per-layer ``_quantization_metadata`` in the
186-
checkpoint metadata so that inference runtimes (e.g., ComfyUI)
187-
can identify which layers are quantized and in what format. Set
188-
to False to skip.
190+
enable_layerwise_quant_metadata (bool, optional): When True, embeds
191+
``quantization_config`` and per-layer ``_quantization_metadata`` in the
192+
safetensors header so single-file runtimes (e.g., ComfyUI) can identify
193+
which layers are quantized and in what format. Defaults to False (no
194+
header metadata; this alone leaves the export untouched).
189195
enable_swizzle_layout (bool, optional): When True, rearranges NVFP4
190196
block scales from ModelOpt's flat layout to cuBLAS 2-D tiled
191197
layout. Required for runtimes that consume cuBLAS block-scaled
@@ -198,10 +204,23 @@ def _postprocess_safetensors(
198204
199205
"""
200206
merged_base_safetensor_path: str | None = kwargs.get("merged_base_safetensor_path")
201-
enable_layerwise_quant_metadata: bool = kwargs.get("enable_layerwise_quant_metadata", True)
207+
enable_layerwise_quant_metadata: bool = kwargs.get("enable_layerwise_quant_metadata", False)
202208
enable_swizzle_layout: bool = kwargs.get("enable_swizzle_layout", False)
203209
padding_strategy: str | None = kwargs.get("padding_strategy")
204210

211+
# This post-processing only produces single-file deployment checkpoints (e.g.
212+
# ComfyUI): merging with a base checkpoint, NVFP4 padding/swizzling, and embedding
213+
# quant metadata in the safetensors header. None of it is read back by ModelOpt
214+
# (the diffusers reload uses ``config.json``), so if the user has not opted into any
215+
# of these options there is nothing to do — leave the exported checkpoint untouched.
216+
if not (
217+
merged_base_safetensor_path is not None
218+
or padding_strategy is not None
219+
or enable_swizzle_layout
220+
or enable_layerwise_quant_metadata
221+
):
222+
return
223+
205224
safetensor_files = sorted(export_dir.glob("*.safetensors"))
206225
if not safetensor_files:
207226
return

tests/gpu/torch/export/test_export_diffusers.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pytest
1919
from _test_utils.torch.diffusers_models import get_tiny_dit, get_tiny_flux, get_tiny_unet
20+
from safetensors import safe_open
2021

2122
import modelopt.torch.quantization as mtq
2223
from modelopt.torch.export.diffusers_utils import generate_diffusion_dummy_inputs
@@ -28,6 +29,13 @@ def _load_config(config_path):
2829
return json.load(file)
2930

3031

32+
def _calib_with_dummy_inputs(m):
33+
param = next(m.parameters())
34+
dummy_inputs = generate_diffusion_dummy_inputs(m, param.device, param.dtype)
35+
assert dummy_inputs is not None
36+
m(**dummy_inputs)
37+
38+
3139
@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux])
3240
@pytest.mark.parametrize(
3341
("config_id", "quant_cfg"),
@@ -78,3 +86,33 @@ def _calib_fn(m):
7886

7987
config_data = _load_config(config_path)
8088
assert "quantization_config" in config_data
89+
90+
91+
def test_export_diffusers_sharded_default_no_header_metadata(tmp_path):
92+
"""A default (non-opt-in) sharded FP8 export succeeds and writes no header metadata.
93+
94+
Regression test for the FLUX FP8 export crash (NotImplementedError on sharded
95+
safetensors). A tiny max_shard_size forces the tiny model to split into multiple
96+
shards (+ index.json), reproducing the large-model path. With the header quant
97+
metadata off by default, post-processing is a no-op: the export must succeed and
98+
leave a clean safetensors header (the ComfyUI metadata is opt-in).
99+
"""
100+
model = get_tiny_flux()
101+
export_dir = tmp_path / "export_flux_fp8_sharded_default"
102+
103+
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=_calib_with_dummy_inputs)
104+
105+
# Tiny shard size forces sharding even for this tiny model.
106+
export_hf_checkpoint(model, export_dir=export_dir, max_shard_size="1KB")
107+
108+
assert list(export_dir.glob("*.safetensors.index.json")), (
109+
"expected a sharded export (index.json) with a tiny max_shard_size"
110+
)
111+
shard_files = sorted(export_dir.glob("*.safetensors"))
112+
assert len(shard_files) >= 2, "expected the model to split across multiple shards"
113+
114+
for shard in shard_files:
115+
with safe_open(str(shard), framework="pt") as f:
116+
md = f.metadata() or {}
117+
assert "quantization_config" not in md, f"unexpected header metadata in {shard.name}"
118+
assert "_quantization_metadata" not in md, f"unexpected header metadata in {shard.name}"

tests/unit/torch/export/test_export_diffusers.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,46 @@
2626

2727
pytest.importorskip("diffusers")
2828

29+
from safetensors import safe_open
30+
from safetensors.torch import save_file
31+
2932
import modelopt.torch.export.unified_export_hf as unified_export_hf
3033
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
3134
from modelopt.torch.export.diffusers_utils import generate_diffusion_dummy_inputs
32-
from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
35+
from modelopt.torch.export.unified_export_hf import _postprocess_safetensors, export_hf_checkpoint
3336

3437

3538
def _load_config(config_path):
3639
with open(config_path) as file:
3740
return json.load(file)
3841

3942

43+
def _write_sharded_checkpoint(export_dir, shards):
44+
"""Write ``shards`` (list of state-dict chunks) as sharded safetensors + index.json.
45+
46+
Mimics the layout produced by ``save_pretrained`` when a component is split across
47+
multiple files because it exceeds ``max_shard_size``.
48+
"""
49+
export_dir.mkdir(parents=True, exist_ok=True)
50+
total = len(shards)
51+
weight_map = {}
52+
total_size = 0
53+
for i, shard in enumerate(shards, start=1):
54+
filename = f"diffusion_pytorch_model-{i:05d}-of-{total:05d}.safetensors"
55+
save_file(shard, str(export_dir / filename))
56+
for key, tensor in shard.items():
57+
weight_map[key] = filename
58+
total_size += tensor.numel() * tensor.element_size()
59+
index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
60+
with open(export_dir / "diffusion_pytorch_model.safetensors.index.json", "w") as file:
61+
json.dump(index, file)
62+
63+
64+
def _read_safetensors_metadata(path):
65+
with safe_open(str(path), framework="pt") as file:
66+
return dict(file.metadata() or {})
67+
68+
4069
@pytest.mark.parametrize(
4170
"model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux, get_tiny_flux2]
4271
)
@@ -117,3 +146,82 @@ def test_flux2_dummy_inputs_shape():
117146

118147
# guidance_embeds defaults to True for Flux2
119148
assert "guidance" in inputs
149+
150+
151+
@pytest.mark.parametrize(
152+
"opt_in_kwargs",
153+
[
154+
{"enable_layerwise_quant_metadata": True},
155+
{"merged_base_safetensor_path": "/tmp/base.safetensors"},
156+
],
157+
)
158+
def test_postprocess_sharded_opt_in_raises(tmp_path, opt_in_kwargs):
159+
"""Opting into ComfyUI post-processing on a sharded checkpoint is unsupported.
160+
161+
Documents the existing limitation (out of scope for this fix). The bug fix is the
162+
default no-op path (see ``test_postprocess_default_is_noop``); only an explicit
163+
opt-in reaches this guard.
164+
"""
165+
export_dir = tmp_path / "sharded_opt_in"
166+
_write_sharded_checkpoint(
167+
export_dir,
168+
[
169+
{"layer_a.weight": torch.zeros(4, 4), "layer_a.weight_scale": torch.ones(1)},
170+
{"layer_b.weight": torch.zeros(4, 4), "layer_b.weight_scale": torch.ones(1)},
171+
],
172+
)
173+
174+
with pytest.raises(NotImplementedError, match="sharded safetensors"):
175+
_postprocess_safetensors(
176+
export_dir,
177+
hf_quant_config={"quant_algo": "FP8"},
178+
**opt_in_kwargs,
179+
)
180+
181+
182+
def test_postprocess_single_file_metadata_when_opted_in(tmp_path):
183+
"""With the opt-in flag, a non-sharded export injects quant config + per-layer metadata."""
184+
export_dir = tmp_path / "single_file"
185+
export_dir.mkdir(parents=True, exist_ok=True)
186+
save_file(
187+
{"layer_a.weight": torch.zeros(4, 4), "layer_a.weight_scale": torch.ones(1)},
188+
str(export_dir / "diffusion_pytorch_model.safetensors"),
189+
)
190+
191+
_postprocess_safetensors(
192+
export_dir,
193+
hf_quant_config={"quant_algo": "FP8"},
194+
enable_layerwise_quant_metadata=True,
195+
)
196+
197+
metadata = _read_safetensors_metadata(export_dir / "diffusion_pytorch_model.safetensors")
198+
assert "quantization_config" in metadata
199+
assert json.loads(metadata["_quantization_metadata"])["layers"] == {
200+
"layer_a": {"format": "fp8"}
201+
}
202+
203+
204+
def test_postprocess_default_is_noop(tmp_path):
205+
"""By default (no opt-in) nothing is written to the safetensors header.
206+
207+
The header quant metadata is a single-file deployment (e.g. ComfyUI) feature, so a
208+
plain export must leave the checkpoint untouched. This no-op default is also what
209+
keeps a default *sharded* export from reaching the unsupported-sharded path that
210+
caused the original FP8 FLUX crash.
211+
"""
212+
export_dir = tmp_path / "default_noop"
213+
_write_sharded_checkpoint(
214+
export_dir,
215+
[
216+
{"layer_a.weight": torch.zeros(4, 4), "layer_a.weight_scale": torch.ones(1)},
217+
{"layer_b.weight": torch.zeros(4, 4), "layer_b.weight_scale": torch.ones(1)},
218+
],
219+
)
220+
221+
# No opt-in kwargs: must not raise (even though sharded) and must inject nothing.
222+
_postprocess_safetensors(export_dir, hf_quant_config={"quant_algo": "FP8"})
223+
224+
for shard in sorted(export_dir.glob("*.safetensors")):
225+
metadata = _read_safetensors_metadata(shard)
226+
assert "quantization_config" not in metadata
227+
assert "_quantization_metadata" not in metadata

0 commit comments

Comments
 (0)