Commit 18918ff

Supported LoRA
1 parent a590393 commit 18918ff

File tree

4 files changed

+176 -178 lines changed


examples/apps/flux-demo.py

+143
@@ -0,0 +1,143 @@
import os

import gradio as gr
import torch
import torch_tensorrt
from diffusers import FluxPipeline, StableDiffusionPipeline
from torch.export._trace import _export

DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.to(DEVICE).to(torch.float16)
backbone = pipe.transformer


batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=8)

# These particular min and max values for the img_id input are recommended by torch dynamo during the export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {},
    "img_ids": {},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}

settings = {
    "strict": False,
    "allow_complex_guards_as_runtime_asserts": True,
    "enabled_precisions": {torch.float32},
    "truncate_double": True,
    "min_block_size": 1,
    "use_fp32_acc": True,
    "use_explicit_typing": True,
    "debug": False,
    "use_python_runtime": True,
    "immutable_weights": False,
}

trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
pipe.transformer = trt_gm


def generate_image(prompt, inference_step, batch_size=1):
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=inference_step,
        num_images_per_prompt=batch_size,
    ).images
    return image


generate_image(["A golden retriever holding a sign to code"], 2)


def model_change(model):
    if model == "Torch Model":
        pipe.transformer = backbone
        backbone.to(DEVICE)
    else:
        backbone.to("cpu")
        pipe.transformer = trt_gm
        torch.cuda.empty_cache()


def load_lora(path):
    pipe.load_lora_weights(
        path,
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()
    print("LoRA loaded!")


# Create Gradio interface
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
    gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

    with gr.Row():
        with gr.Column():
            # Input components
            prompt_input = gr.Textbox(
                label="Prompt", placeholder="Enter your prompt here...", lines=3
            )
            model_dropdown = gr.Dropdown(
                choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
                value="Torch-TensorRT Accelerated Model",
                label="Model Variant",
            )

            lora_upload_path = gr.Textbox(
                label="LoRA Path",
                placeholder="/home/TensorRT/examples/apps/NGRVNG.safetensors",
                lines=2,
            )
            num_steps = gr.Slider(
                minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
            )
            batch_size = gr.Slider(
                minimum=1, maximum=8, value=1, step=1, label="Batch Size"
            )

            generate_btn = gr.Button("Generate Image")
            load_lora_btn = gr.Button("Load LoRA")

        with gr.Column():
            # Output component
            output_image = gr.Gallery(label="Generated Image")

    # Connect the button to the generation function
    model_dropdown.change(model_change, inputs=[model_dropdown])
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt_input,
            num_steps,
            batch_size,
        ],
        outputs=output_image,
    )
    load_lora_btn.click(
        fn=load_lora,
        inputs=[
            lora_upload_path,
        ],
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()
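For reference, the LoRA flow can also be exercised without the Gradio UI. The sketch below is illustrative only (the .safetensors path, adapter name, and step count are placeholders): because pipe.transformer is the MutableTorchTensorRTModule, the weight change produced by fuse_lora() should be picked up on the next pipeline call and handled by refitting the existing TRT engine rather than recompiling it.

# Illustrative only: the path and adapter name are placeholders.
lora_path = "/path/to/your_lora.safetensors"

pipe.load_lora_weights(lora_path, adapter_name="lora1")
pipe.set_adapters(["lora1"], adapter_weights=[1])
pipe.fuse_lora()            # fuses the LoRA deltas into the backbone weights
pipe.unload_lora_weights()  # drops the now-fused adapter state

# The next call runs through trt_gm; the mutated weights trigger a refit, not a rebuild.
images = generate_image(["A golden retriever holding a sign to code"], 30)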

examples/apps/flux_demo.py

-158
This file was deleted.

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+6 -8
@@ -469,14 +469,12 @@ def _save_weight_mapping(self) -> None:
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
         gm_is_on_cuda = get_model_device(self.module).type == "cuda"
-        if not gm_is_on_cuda:
-            # If the model original position is on CPU, move it GPU
-            sd = {
-                k: v.reshape(-1).to(torch_device)
-                for k, v in self.module.state_dict().items()
-            }
-        else:
-            sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
+        # If the model original position is on CPU, move it GPU
+        sd = {
+            k: v.reshape(-1).to(torch_device)
+            for k, v in self.module.state_dict().items()
+        }
+
         weight_name_map: dict[str, Any] = {}
         np_map = {}
         constant_mapping = {}
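A side note on the change above (not part of the commit): the CPU/CUDA branch can be collapsed because Tensor.to returns the original tensor when the target device and dtype already match, so unconditionally calling .to(torch_device) costs nothing for weights that are already on the GPU. A quick illustration, assuming a CUDA device is available:

import torch

t = torch.randn(4, device="cuda")
# No copy is made when the tensor is already on the requested device.
assert t.to(t.device).data_ptr() == t.data_ptr()

c = torch.randn(4)  # CPU tensor
assert c.to("cuda").device.type == "cuda"  # still moved to the GPU, as before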

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

+27 -12
@@ -62,6 +62,8 @@ def __init__(
         device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
         use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
         immutable_weights: bool = False,
+        strict: bool = True,
+        allow_complex_guards_as_runtime_asserts: bool = False,
         **kwargs: Any,
     ) -> None:
         """
@@ -125,6 +127,10 @@ def __init__(
         self.arg_inputs: tuple[Any, ...] = tuple()
         self.kwarg_inputs: dict[str, Any] = {}
         self.additional_settings = kwargs
+        self.strict = strict
+        self.allow_complex_guards_as_runtime_asserts = (
+            allow_complex_guards_as_runtime_asserts
+        )
         self.use_python_runtime = use_python_runtime
         self.trt_device = to_torch_tensorrt_device(device)
         assert (
@@ -262,9 +268,7 @@ def refit_gm(self) -> None:
         """
         self.original_model.to(to_torch_device(self.trt_device))
         if self.exp_program is None:
-            self.exp_program = torch.export.export(
-                self.original_model, self.arg_inputs, kwargs=self.kwarg_inputs
-            )
+            self.exp_program = self.get_exported_program()
         else:
             self.exp_program._state_dict = (
                 MutableTorchTensorRTModule._transform_state_dict(
@@ -283,6 +287,25 @@ def refit_gm(self) -> None:
         self.original_model.cpu()
         torch.cuda.empty_cache()
 
+    def get_exported_program(self) -> torch.export.ExportedProgram:
+        if self.allow_complex_guards_as_runtime_asserts:
+            return torch.export._trace._export(
+                self.original_model,
+                self.arg_inputs,
+                kwargs=self.kwarg_inputs,
+                dynamic_shapes=self._get_total_dynamic_shapes(),
+                strict=self.strict,
+                allow_complex_guards_as_runtime_asserts=self.allow_complex_guards_as_runtime_asserts,
+            )
+        else:
+            return torch.export.export(
+                self.original_model,
+                self.arg_inputs,
+                kwargs=self.kwarg_inputs,
+                dynamic_shapes=self._get_total_dynamic_shapes(),
+                strict=self.strict,
+            )
+
     def compile(self) -> None:
         """
         (Re)compile the TRT graph module using the PyTorch module.
@@ -292,15 +315,7 @@ def compile(self) -> None:
         """
         # Export the module
         self.original_model.to(to_torch_device(self.trt_device))
-        self.exp_program = torch.export._trace._export(
-            self.original_model,
-            self.arg_inputs,
-            kwargs=self.kwarg_inputs,
-            dynamic_shapes=self._get_total_dynamic_shapes(),
-            strict=False,
-            allow_complex_guards_as_runtime_asserts=True,
-            # **self.additional_settings
-        )
+        self.exp_program = self.get_exported_program()
         self.gm = dynamo_compile(
             self.exp_program,
             arg_inputs=self.arg_inputs,