|
2 | 2 | import logging
|
3 | 3 | from copy import deepcopy
|
4 | 4 | from enum import Enum, auto
|
5 |
| -from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union |
| 5 | +from typing import Any, Dict, Iterator, Optional, Union |
6 | 6 |
|
7 | 7 | import numpy as np
|
8 | 8 | import torch
|
9 |
| -from torch.fx.node import Target |
10 | 9 | from torch_tensorrt._Device import Device
|
11 |
| -from torch_tensorrt._enums import EngineCapability, dtype |
12 | 10 | from torch_tensorrt.dynamo import _defaults
|
13 | 11 | from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
|
14 | 12 | from torch_tensorrt.dynamo._refit import refit_module_weights
|
15 |
| -from torch_tensorrt.dynamo._settings import CompilationSettings |
16 | 13 | from torch_tensorrt.dynamo.utils import (
|
17 | 14 | check_output_equal,
|
18 | 15 | to_torch_device,
|
@@ -63,35 +60,8 @@ def __init__(
|
63 | 60 | pytorch_model: torch.nn.Module,
|
64 | 61 | *,
|
65 | 62 | device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
|
66 |
| - disable_tf32: bool = _defaults.DISABLE_TF32, |
67 |
| - assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT, |
68 |
| - sparse_weights: bool = _defaults.SPARSE_WEIGHTS, |
69 |
| - enabled_precisions: Set[ |
70 |
| - Union[torch.dtype, dtype] |
71 |
| - ] = _defaults.ENABLED_PRECISIONS, |
72 |
| - engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, |
73 |
| - immutable_weights: bool = False, |
74 |
| - debug: bool = _defaults.DEBUG, |
75 |
| - num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, |
76 |
| - workspace_size: int = _defaults.WORKSPACE_SIZE, |
77 |
| - dla_sram_size: int = _defaults.DLA_SRAM_SIZE, |
78 |
| - dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE, |
79 |
| - dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE, |
80 |
| - truncate_double: bool = _defaults.TRUNCATE_DOUBLE, |
81 |
| - require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, |
82 |
| - min_block_size: int = _defaults.MIN_BLOCK_SIZE, |
83 |
| - torch_executed_ops: Optional[Collection[Target]] = None, |
84 |
| - torch_executed_modules: Optional[List[str]] = None, |
85 |
| - pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES, |
86 |
| - max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS, |
87 |
| - version_compatible: bool = _defaults.VERSION_COMPATIBLE, |
88 |
| - optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL, |
89 | 63 | use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
|
90 |
| - use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER, |
91 |
| - enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS, |
92 |
| - dryrun: bool = _defaults.DRYRUN, |
93 |
| - hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, |
94 |
| - timing_cache_path: str = _defaults.TIMING_CACHE_PATH, |
| 64 | + immutable_weights: bool = False, |
95 | 65 | **kwargs: Any,
|
96 | 66 | ) -> None:
|
97 | 67 | """
|
@@ -154,50 +124,15 @@ def __init__(
|
154 | 124 | self.exp_program: Any = None
|
155 | 125 | self.arg_inputs: tuple[Any, ...] = tuple()
|
156 | 126 | self.kwarg_inputs: dict[str, Any] = {}
|
157 |
| - device = to_torch_tensorrt_device(device) |
158 |
| - enabled_precisions = {dtype._from(p) for p in enabled_precisions} |
| 127 | + self.additional_settings = kwargs |
| 128 | + self.use_python_runtime = use_python_runtime |
| 129 | + self.trt_device = to_torch_tensorrt_device(device) |
159 | 130 | assert (
|
160 | 131 | not immutable_weights
|
161 |
| - ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule." |
162 |
| - compilation_options = { |
163 |
| - "enabled_precisions": ( |
164 |
| - enabled_precisions |
165 |
| - if enabled_precisions |
166 |
| - else _defaults.ENABLED_PRECISIONS |
167 |
| - ), |
168 |
| - "debug": debug, |
169 |
| - "device": device, |
170 |
| - "assume_dynamic_shape_support": assume_dynamic_shape_support, |
171 |
| - "workspace_size": workspace_size, |
172 |
| - "min_block_size": min_block_size, |
173 |
| - "torch_executed_ops": ( |
174 |
| - torch_executed_ops if torch_executed_ops is not None else set() |
175 |
| - ), |
176 |
| - "pass_through_build_failures": pass_through_build_failures, |
177 |
| - "max_aux_streams": max_aux_streams, |
178 |
| - "version_compatible": version_compatible, |
179 |
| - "optimization_level": optimization_level, |
180 |
| - "use_python_runtime": use_python_runtime, |
181 |
| - "truncate_double": truncate_double, |
182 |
| - "use_fast_partitioner": use_fast_partitioner, |
183 |
| - "num_avg_timing_iters": num_avg_timing_iters, |
184 |
| - "enable_experimental_decompositions": enable_experimental_decompositions, |
185 |
| - "require_full_compilation": require_full_compilation, |
186 |
| - "disable_tf32": disable_tf32, |
187 |
| - "sparse_weights": sparse_weights, |
188 |
| - "immutable_weights": immutable_weights, |
189 |
| - "engine_capability": engine_capability, |
190 |
| - "dla_sram_size": dla_sram_size, |
191 |
| - "dla_local_dram_size": dla_local_dram_size, |
192 |
| - "dla_global_dram_size": dla_global_dram_size, |
193 |
| - "dryrun": dryrun, |
194 |
| - "hardware_compatible": hardware_compatible, |
195 |
| - "timing_cache_path": timing_cache_path, |
196 |
| - } |
| 132 | + ), "`immutable_weights has to be False for a MutableTorchTensorRTModule" |
| 133 | + |
197 | 134 | self.arg_dynamic_shapes: Optional[tuple[Any]] = None
|
198 | 135 | self.kwarg_dynamic_shapes: Optional[dict[Any, Any]] = None
|
199 |
| - |
200 |
| - self.settings = CompilationSettings(**compilation_options) |
201 | 136 | self.run_info: Optional[tuple[Any, ...]] = None
|
202 | 137 | self.state_dict_metadata: dict[str, torch.Size] = {}
|
203 | 138 | self._store_state_dict_metadata()
|
@@ -293,7 +228,7 @@ def update_refit_condition(self) -> None:
|
293 | 228 | # to determine whether refit/recompilation is needed. If the output is the same, no further process needed.
|
294 | 229 | if self.run_info:
|
295 | 230 | args, kwargs, result = self.run_info
|
296 |
| - self.original_model.to(to_torch_device(self.settings.device)) |
| 231 | + self.original_model.to(to_torch_device(self.trt_device)) |
297 | 232 | new_result = self.original_model(*args, **kwargs)
|
298 | 233 | self.original_model.cpu()
|
299 | 234 | torch.cuda.empty_cache()
|
@@ -325,7 +260,7 @@ def refit_gm(self) -> None:
|
325 | 260 | MutableTorchTensorRTModule automatically catches weight value updates and call this function to refit the module.
|
326 | 261 | If it fails to catch the changes, please call this function manually to update the TRT graph module.
|
327 | 262 | """
|
328 |
| - self.original_model.to(to_torch_device(self.settings.device)) |
| 263 | + self.original_model.to(to_torch_device(self.trt_device)) |
329 | 264 | if self.exp_program is None:
|
330 | 265 | self.exp_program = torch.export.export(
|
331 | 266 | self.original_model, self.arg_inputs, kwargs=self.kwarg_inputs
|
@@ -356,25 +291,30 @@ def compile(self) -> None:
|
356 | 291 | If it fails to catch the changes, please call this function manually to recompile the TRT graph module.
|
357 | 292 | """
|
358 | 293 | # Export the module
|
359 |
| - self.original_model.to(to_torch_device(self.settings.device)) |
360 |
| - self.exp_program = torch.export.export( |
| 294 | + self.original_model.to(to_torch_device(self.trt_device)) |
| 295 | + self.exp_program = torch.export._trace._export( |
361 | 296 | self.original_model,
|
362 | 297 | self.arg_inputs,
|
363 | 298 | kwargs=self.kwarg_inputs,
|
364 | 299 | dynamic_shapes=self._get_total_dynamic_shapes(),
|
| 300 | + strict=False, |
| 301 | + allow_complex_guards_as_runtime_asserts=True, |
| 302 | + # **self.additional_settings |
365 | 303 | )
|
366 | 304 | self.gm = dynamo_compile(
|
367 | 305 | self.exp_program,
|
368 | 306 | arg_inputs=self.arg_inputs,
|
369 | 307 | kwarg_inputs=self.kwarg_inputs,
|
370 |
| - **self.settings.__dict__, |
| 308 | + immutable_weights=False, |
| 309 | + use_python_runtime=self.use_python_runtime, |
| 310 | + **self.additional_settings, |
371 | 311 | )
|
372 | 312 | self.original_model.cpu()
|
373 | 313 | torch.cuda.empty_cache()
|
374 | 314 |
|
375 | 315 | def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
|
376 | 316 |
|
377 |
| - if not self.arg_inputs: |
| 317 | + if not self.arg_inputs and not self.kwarg_inputs: |
378 | 318 | logger.info("First time compilation initiated. This may take some time.")
|
379 | 319 | self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
|
380 | 320 | self._store_inputs(args, kwargs)
|
@@ -628,7 +568,7 @@ def _check_tensor_shapes_with_dynamic_shapes(
|
628 | 568 | def save(module: Any, path: str) -> None:
|
629 | 569 | # Cast the object back to MutableTorchTensorRTModule to save
|
630 | 570 | assert (
|
631 |
| - not module.settings.use_python_runtime |
| 571 | + not module.use_python_runtime |
632 | 572 | ), "Python runtime does not support serialization. Save failed."
|
633 | 573 | module.init_finished = False
|
634 | 574 | module.__class__ = MutableTorchTensorRTModule
|
@@ -658,7 +598,7 @@ def load(path: str) -> Any:
|
658 | 598 | module.pytorch_model = _make_refit_change_trigger(
|
659 | 599 | module.original_model, module.refit_state
|
660 | 600 | )
|
661 |
| - module.original_model.to(to_torch_device(module.settings.device)) |
| 601 | + module.original_model.to(to_torch_device(module.device)) |
662 | 602 | module.exp_program = torch.export.export(
|
663 | 603 | module.original_model, module.arg_inputs, kwargs=module.kwarg_inputs
|
664 | 604 | )
|
|
0 commit comments