Commit 01e24fd

Selective mode saving (#596)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Filter the KD state out of the ModelOpt state list when saving. This allows the KD mode to be applied again after a modelopt checkpoint restore, without modelopt complaining that the mode was already applied previously.

## Usage

A usage sketch follows below.

## Testing

Updated the distillation, HF-patching, and mode-chaining unit tests to assert that KD state is absent after save/restore.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

---

Signed-off-by: Asha Anoosheh <[email protected]>
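A minimal usage sketch of the behavior this change enables, following the updated unit tests; `teacher_model`, `student_model`, and `fresh_student` are hypothetical models assumed to be defined elsewhere:

```python
import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto

# Wrap a student in a DistillationModel and save a modelopt checkpoint.
# With this change, the "kd_loss" mode is filtered out of the saved state.
kd_config = {
    "teacher_model": teacher_model,
    "criterion": mtd.LogitsDistillationLoss(),
}
distill_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)])
mto.save(distill_model, "ckpt.pt")

# Restoring yields a plain student model with no KD state attached ...
restored = mto.restore(fresh_student, "ckpt.pt")

# ... so the KD mode can be applied again without modelopt complaining
# that it was already applied previously.
distill_model = mtd.convert(restored, mode=[("kd_loss", kd_config)])
```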
1 parent 38550b0 commit 01e24fd

File tree: 6 files changed (+48, −35 lines)

- modelopt/torch/distill/mode.py
- modelopt/torch/opt/conversion.py
- modelopt/torch/opt/mode.py
- tests/unit/torch/distill/test_distill.py
- tests/unit/torch/opt/plugins/test_hf_patching.py
- tests/unit/torch/opt/test_chaining.py

modelopt/torch/distill/mode.py

Lines changed: 11 additions & 20 deletions
```diff
@@ -78,17 +78,17 @@ def convert(self) -> ConvertEntrypoint:
     @property
     def restore(self) -> RestoreEntrypoint:
         """The mode's entrypoint for restoring a model."""
-        return _restore_kd_model
+        raise NotImplementedError(f"{self.name} mode does not support restore.")
 
     @property
     def update_for_new_mode(self) -> UpdateEntrypoint:
         """The mode's entrypoint for updating the models state for adding new mode."""
         return _reset_kd_state_config
 
     @property
-    def update_for_save(self) -> UpdateEntrypoint:
-        """The mode's entrypoint for updating the models state before saving."""
-        return _reset_kd_state_config
+    def save_mode_in_state(self) -> bool:
+        """Whether the mode should be saved into the modelopt state."""
+        return False
 
 
 @DistillModeRegistry.register_mode
@@ -121,7 +121,12 @@ def convert(self) -> ConvertEntrypoint:
     @property
     def restore(self) -> RestoreEntrypoint:
         """The mode's entrypoint for restoring a model."""
-        return _restore_exported_student
+        raise NotImplementedError(f"{self.name} mode does not support restore.")
+
+    @property
+    def save_mode_in_state(self) -> bool:
+        """Whether the mode should be saved into the modelopt state."""
+        return False
 
 
 def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType:
@@ -174,12 +179,6 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType:
     return distillation_model, metadata
 
 
-def _restore_kd_model(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> nn.Module:
-    """Function for restoring a previously convert model to a distillation meta-model."""
-    # NOTE: DistillationModel will purposely remain unrestored
-    return model
-
-
 def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
     """Function for resetting the state's config."""
     config.teacher_model = nn.Module
@@ -206,16 +205,8 @@ def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertReturnType:
         student_model,
         warn=True,
         msg=(
-            f"The student model is wrapped into {type(student_model).__name__}. Unwrapping and"
-            " exporting it ..."
+            f"The student model is wrapped into {type(student_model).__name__}. Unwrapping and exporting it ..."
         ),
     )
 
     return student_model, {}
-
-
-def _restore_exported_student(
-    model: nn.Module, config: ExportStudentConfig, metadata: MetadataDict
-) -> nn.Module:
-    # NOTE: DistillationModel was unrestored so this does nothing
-    return model
```
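Note: because `save_mode_in_state` now returns `False` for both KD modes, they are never written into a checkpoint, so their `restore` entrypoints should be unreachable via `mto.restore`; the `NotImplementedError` presumably exists to fail loudly on an older checkpoint that still contains KD state.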

modelopt/torch/opt/conversion.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -476,11 +476,16 @@ def modelopt_state(model: nn.Module) -> dict[str, Any]:
     # update metadata of current mode as needed
     manager.update_last_state_before_save(model)
 
+    # filter out modes that should not be saved in the state
+    skip_idx = []
+    for i, (m, _, _) in enumerate(manager.modes_with_states()):
+        if not m.save_mode_in_state:
+            skip_idx.append(i)
+    state_dict = [state for i, state in enumerate(manager.state_dict()) if i not in skip_idx]
+
     # construct state dict and return it
     objs = {
-        "modelopt_state_dict": (
-            manager.state_dict()
-        ),  # empty state_dict is okay (saving regular models)
+        "modelopt_state_dict": state_dict,  # empty state_dict is okay (saving regular models)
         "modelopt_version": __version__,
     }
     return objs
```
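The filter above drops the state entries of any mode whose descriptor opts out of persistence. A standalone sketch of the idea, using hypothetical mode names and empty configs:

```python
# Each mode contributes one entry, in order, to the manager's state list.
# Entries whose descriptor has save_mode_in_state == False are dropped
# before the state is serialized.
save_flags = {"quantize": True, "kd_loss": False, "export_student": False}
state_dict = [("quantize", {}), ("kd_loss", {}), ("export_student", {})]

saved = [entry for entry in state_dict if save_flags[entry[0]]]
assert saved == [("quantize", {})]
```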

modelopt/torch/opt/mode.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -242,15 +242,25 @@ def require_model_like(self) -> bool:
         """
         return False
 
+    @property
+    def save_mode_in_state(self) -> bool:
+        """Whether the mode should be saved into the modelopt state.
+
+        This is useful if the mode is intended to be manually re-applied every time it's used.
+
+        Returns:
+            True
+        """
+        return True
+
     def assert_compatibility_as_next_mode_of(self, other_mode: "ModeDescriptor | str") -> None:
         """Assert that this mode is compatible as a next mode of the other mode."""
         if isinstance(other_mode, str):
             other_mode = _ModeRegistryCls.get_from_any(other_mode)
 
         if other_mode.next_modes is not None:
             assert str(self) in other_mode.next_modes, (
-                f"Cannot add {self} after {other_mode}! Next modes of {other_mode} are"
-                f" {other_mode.next_modes}."
+                f"Cannot add {self} after {other_mode}! Next modes of {other_mode} are {other_mode.next_modes}."
             )
 
         if other_mode.next_prohibited_modes is not None:
```
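With the default now defined on the base descriptor, a downstream mode opts out of persistence with a single property override. A minimal sketch, assuming a hypothetical `TransientModeDescriptor`; the remaining abstract members of `ModeDescriptor` (e.g. `config_class`, `convert`) are elided:

```python
from modelopt.torch.opt.mode import ModeDescriptor


class TransientModeDescriptor(ModeDescriptor):
    """Hypothetical mode meant to be manually re-applied each run, never persisted."""

    @property
    def name(self) -> str:
        return "transient_mode"

    @property
    def save_mode_in_state(self) -> bool:
        """Skip this mode when serializing the modelopt state."""
        return False

    # ... remaining required entrypoints (config_class, convert, ...) elided ...
```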

tests/unit/torch/distill/test_distill.py

Lines changed: 12 additions & 7 deletions
```diff
@@ -147,16 +147,21 @@ def test_distillation_save_restore(distillation_model, tmp_path):
     new_student = tiny_mobilenet()
     distillation_model_new = mto.restore(new_student, tmp_path / "ckpt.pt")
 
-    # Ensure state config was reset
+    # Ensure state is not actually restored
     manager = mto.ModeloptStateManager(distillation_model_new)
-    cfg = manager._state[-1][1]["config"]
-    assert cfg["teacher_model"] == nn.Module
-    assert isinstance(next(iter(cfg["criterion"].values())), Loss)
-    assert cfg["loss_balancer"] is None
-
-    # Should not have restored anything
+    assert not manager.has_state
     assert isinstance(distillation_model_new, type(new_student))
 
+    # Subsequent convert should behave normally
+    config = {
+        "teacher_model": distillation_model.teacher_model,
+        "criterion": mtd.LogitsDistillationLoss(),
+    }
+    distillation_model_newer = mtd.convert(new_student, mode=[("kd_loss", config)])
+    manager = mto.ModeloptStateManager(distillation_model_newer)
+    assert manager.has_state
+    assert isinstance(distillation_model_newer, mtd.DistillationModel)
+
 
 def test_distillation_export(distillation_model, tmp_path):
     model_exported = mtd.export(distillation_model)
```

tests/unit/torch/opt/plugins/test_hf_patching.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -53,5 +53,5 @@ def test_nested_model_save_restore(tmp_path, model_cls, teacher_model_type):
     model_test = model_cls.from_pretrained(tiny_llama_dir / "modelopt_model")
 
     tf_output_tester(model, model_test)
-    # since distill model contains loss function, we compare state of model manually
-    assert mto.modelopt_state(model.model) == mto.modelopt_state(model_test.model)
+    # KD state is not saved and it should be empty
+    assert not mto.ModeloptStateManager(model_test).has_state
```

tests/unit/torch/opt/test_chaining.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -88,7 +88,9 @@ def test_chained_save_restore(mode):
     # compare serialized version since some configs may be objected...
     manager = mto.ModeloptStateManager(model)
     manager2 = mto.ModeloptStateManager(model2)
-    assert torch.equal(_serialize(manager.state_dict()), _serialize(manager2.state_dict()))
+    # NOTE: KD modes are skipped during restore and thus won't exist
+    state_minus_kd = [s for s in manager.state_dict() if s[0] not in ("kd_loss", "export_student")]
+    assert torch.equal(_serialize(state_minus_kd), _serialize(manager2.state_dict()))
 
     # run comparison in eval mode since there might be model randomization in train mode
     model.eval()
```
