Skip to content

Commit a590393

Browse files
committed
Fixed memory overhead and enabled Flux with Mutable Module
1 parent 1f1e903 commit a590393

File tree

6 files changed

+45
-84
lines changed

6 files changed

+45
-84
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ def compile(
413413
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
414414
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
415415
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
416+
offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
416417
**kwargs: Any,
417418
) -> torch.fx.GraphModule:
418419
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -662,6 +663,7 @@ def compile(
662663
"immutable_weights": immutable_weights,
663664
"enable_cross_compile_for_windows": False,
664665
"enable_weight_streaming": enable_weight_streaming,
666+
"offload_module_to_cpu": offload_module_to_cpu,
665667
}
666668

667669
settings = CompilationSettings(**compilation_options)
@@ -673,7 +675,8 @@ def compile(
673675

674676
gm = exported_program.module()
675677
# TODO: Memory control prototyping. Under discussion
676-
exported_program.module().to("cpu")
678+
if offload_module_to_cpu:
679+
exported_program.module().to("cpu")
677680
logger.debug("Input graph: " + str(gm.graph))
678681

679682
# Apply lowering on the graph module

py/torch_tensorrt/dynamo/_defaults.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
ENABLE_WEIGHT_STREAMING = False
4848
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
4949
USE_AOT_JOINT_EXPORT = True
50+
OFFLOAD_MODULE_TO_CPU = True
5051

5152

5253
def default_device() -> Device:

py/torch_tensorrt/dynamo/_refit.py

+14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections.abc
44
import copy
5+
import gc
56
import logging
67
from typing import Any, List, Optional, Sequence, Tuple
78

@@ -307,6 +308,10 @@ def refit_module_weights(
307308
get_decompositions(settings.enable_experimental_decompositions)
308309
)
309310
new_gm = new_weight_module.module()
311+
# TODO: Memory control prototyping. Under discussion
312+
if settings.offload_module_to_cpu:
313+
new_weight_module.module().to("cpu")
314+
310315
logger.debug("Input graph: " + str(new_gm.graph))
311316
# Apply lowering on the graph module
312317

@@ -462,12 +467,21 @@ def refit_module_weights(
462467
settings=settings,
463468
weight_name_map=None,
464469
)
470+
# TODO: Memory control prototyping. Under discussion
471+
if settings.offload_module_to_cpu:
472+
del new_submodule
473+
gc.collect()
474+
torch.cuda.empty_cache()
465475

466476
# clear EXCLUDE_WEIGHTS flag
467477
serialization_config = engine.create_serialization_config()
468478
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
469479
serialized_engine = engine.serialize_with_config(serialization_config)
470480

481+
del engine
482+
gc.collect()
483+
torch.cuda.empty_cache()
484+
471485
if isinstance(
472486
compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)
473487
):

py/torch_tensorrt/dynamo/_settings.py

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
MAX_AUX_STREAMS,
2525
MIN_BLOCK_SIZE,
2626
NUM_AVG_TIMING_ITERS,
27+
OFFLOAD_MODULE_TO_CPU,
2728
OPTIMIZATION_LEVEL,
2829
PASS_THROUGH_BUILD_FAILURES,
2930
REFIT_IDENTICAL_ENGINE_WEIGHTS,
@@ -134,6 +135,7 @@ class CompilationSettings:
134135
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
135136
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
136137
use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
138+
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
137139

138140

139141
_SETTINGS_TO_BE_ENGINE_INVARIANT = (

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -715,9 +715,10 @@ def run(
715715
builder_config, self.compilation_settings.timing_cache_path
716716
)
717717
# TODO: Memory control prototyping. Under discussion
718-
del self.module
719-
gc.collect()
720-
torch.cuda.empty_cache()
718+
if self.compilation_settings.offload_module_to_cpu:
719+
del self.module
720+
gc.collect()
721+
torch.cuda.empty_cache()
721722
serialized_engine = self.builder.build_serialized_network(
722723
self.ctx.net, builder_config
723724
)

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

+20-80
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,14 @@
22
import logging
33
from copy import deepcopy
44
from enum import Enum, auto
5-
from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union
5+
from typing import Any, Dict, Iterator, Optional, Union
66

77
import numpy as np
88
import torch
9-
from torch.fx.node import Target
109
from torch_tensorrt._Device import Device
11-
from torch_tensorrt._enums import EngineCapability, dtype
1210
from torch_tensorrt.dynamo import _defaults
1311
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
1412
from torch_tensorrt.dynamo._refit import refit_module_weights
15-
from torch_tensorrt.dynamo._settings import CompilationSettings
1613
from torch_tensorrt.dynamo.utils import (
1714
check_output_equal,
1815
to_torch_device,
@@ -63,35 +60,8 @@ def __init__(
6360
pytorch_model: torch.nn.Module,
6461
*,
6562
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
66-
disable_tf32: bool = _defaults.DISABLE_TF32,
67-
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
68-
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
69-
enabled_precisions: Set[
70-
Union[torch.dtype, dtype]
71-
] = _defaults.ENABLED_PRECISIONS,
72-
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
73-
immutable_weights: bool = False,
74-
debug: bool = _defaults.DEBUG,
75-
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
76-
workspace_size: int = _defaults.WORKSPACE_SIZE,
77-
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
78-
dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
79-
dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
80-
truncate_double: bool = _defaults.TRUNCATE_DOUBLE,
81-
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
82-
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
83-
torch_executed_ops: Optional[Collection[Target]] = None,
84-
torch_executed_modules: Optional[List[str]] = None,
85-
pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES,
86-
max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
87-
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
88-
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
8963
use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
90-
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
91-
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
92-
dryrun: bool = _defaults.DRYRUN,
93-
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
94-
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
64+
immutable_weights: bool = False,
9565
**kwargs: Any,
9666
) -> None:
9767
"""
@@ -154,50 +124,15 @@ def __init__(
154124
self.exp_program: Any = None
155125
self.arg_inputs: tuple[Any, ...] = tuple()
156126
self.kwarg_inputs: dict[str, Any] = {}
157-
device = to_torch_tensorrt_device(device)
158-
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
127+
self.additional_settings = kwargs
128+
self.use_python_runtime = use_python_runtime
129+
self.trt_device = to_torch_tensorrt_device(device)
159130
assert (
160131
not immutable_weights
161-
), "`immutable_weights` has to be False for a MutableTorchTensorRTModule."
162-
compilation_options = {
163-
"enabled_precisions": (
164-
enabled_precisions
165-
if enabled_precisions
166-
else _defaults.ENABLED_PRECISIONS
167-
),
168-
"debug": debug,
169-
"device": device,
170-
"assume_dynamic_shape_support": assume_dynamic_shape_support,
171-
"workspace_size": workspace_size,
172-
"min_block_size": min_block_size,
173-
"torch_executed_ops": (
174-
torch_executed_ops if torch_executed_ops is not None else set()
175-
),
176-
"pass_through_build_failures": pass_through_build_failures,
177-
"max_aux_streams": max_aux_streams,
178-
"version_compatible": version_compatible,
179-
"optimization_level": optimization_level,
180-
"use_python_runtime": use_python_runtime,
181-
"truncate_double": truncate_double,
182-
"use_fast_partitioner": use_fast_partitioner,
183-
"num_avg_timing_iters": num_avg_timing_iters,
184-
"enable_experimental_decompositions": enable_experimental_decompositions,
185-
"require_full_compilation": require_full_compilation,
186-
"disable_tf32": disable_tf32,
187-
"sparse_weights": sparse_weights,
188-
"immutable_weights": immutable_weights,
189-
"engine_capability": engine_capability,
190-
"dla_sram_size": dla_sram_size,
191-
"dla_local_dram_size": dla_local_dram_size,
192-
"dla_global_dram_size": dla_global_dram_size,
193-
"dryrun": dryrun,
194-
"hardware_compatible": hardware_compatible,
195-
"timing_cache_path": timing_cache_path,
196-
}
132+
), "`immutable_weights has to be False for a MutableTorchTensorRTModule"
133+
197134
self.arg_dynamic_shapes: Optional[tuple[Any]] = None
198135
self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
199-
200-
self.settings = CompilationSettings(**compilation_options)
201136
self.run_info: Optional[tuple[Any, ...]] = None
202137
self.state_dict_metadata: dict[str, torch.Size] = {}
203138
self._store_state_dict_metadata()
@@ -293,7 +228,7 @@ def update_refit_condition(self) -> None:
293228
# to determine whether refit/recompilation is needed. If the output is the same, no further process needed.
294229
if self.run_info:
295230
args, kwargs, result = self.run_info
296-
self.original_model.to(to_torch_device(self.settings.device))
231+
self.original_model.to(to_torch_device(self.trt_device))
297232
new_result = self.original_model(*args, **kwargs)
298233
self.original_model.cpu()
299234
torch.cuda.empty_cache()
@@ -325,7 +260,7 @@ def refit_gm(self) -> None:
325260
MutableTorchTensorRTModule automatically catches weight value updates and call this function to refit the module.
326261
If it fails to catch the changes, please call this function manually to update the TRT graph module.
327262
"""
328-
self.original_model.to(to_torch_device(self.settings.device))
263+
self.original_model.to(to_torch_device(self.trt_device))
329264
if self.exp_program is None:
330265
self.exp_program = torch.export.export(
331266
self.original_model, self.arg_inputs, kwargs=self.kwarg_inputs
@@ -356,25 +291,30 @@ def compile(self) -> None:
356291
If it fails to catch the changes, please call this function manually to recompile the TRT graph module.
357292
"""
358293
# Export the module
359-
self.original_model.to(to_torch_device(self.settings.device))
360-
self.exp_program = torch.export.export(
294+
self.original_model.to(to_torch_device(self.trt_device))
295+
self.exp_program = torch.export._trace._export(
361296
self.original_model,
362297
self.arg_inputs,
363298
kwargs=self.kwarg_inputs,
364299
dynamic_shapes=self._get_total_dynamic_shapes(),
300+
strict=False,
301+
allow_complex_guards_as_runtime_asserts=True,
302+
# **self.additional_settings
365303
)
366304
self.gm = dynamo_compile(
367305
self.exp_program,
368306
arg_inputs=self.arg_inputs,
369307
kwarg_inputs=self.kwarg_inputs,
370-
**self.settings.__dict__,
308+
immutable_weights=False,
309+
use_python_runtime=self.use_python_runtime,
310+
**self.additional_settings,
371311
)
372312
self.original_model.cpu()
373313
torch.cuda.empty_cache()
374314

375315
def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
376316

377-
if not self.arg_inputs:
317+
if not self.arg_inputs and not self.kwarg_inputs:
378318
logger.info("First time compilation initiated. This may take some time.")
379319
self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
380320
self._store_inputs(args, kwargs)
@@ -628,7 +568,7 @@ def _check_tensor_shapes_with_dynamic_shapes(
628568
def save(module: Any, path: str) -> None:
629569
# Cast the object back to MutableTorchTensorRTModule to save
630570
assert (
631-
not module.settings.use_python_runtime
571+
not module.use_python_runtime
632572
), "Python runtime does not support serialization. Save failed."
633573
module.init_finished = False
634574
module.__class__ = MutableTorchTensorRTModule
@@ -658,7 +598,7 @@ def load(path: str) -> Any:
658598
module.pytorch_model = _make_refit_change_trigger(
659599
module.original_model, module.refit_state
660600
)
661-
module.original_model.to(to_torch_device(module.settings.device))
601+
module.original_model.to(to_torch_device(module.device))
662602
module.exp_program = torch.export.export(
663603
module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs
664604
)

0 commit comments

Comments
 (0)