Skip to content

Commit e8e7850

Browse files
authored
[sharktank] Fix VAE string formatting in tests (#1014)
Add missing format prefix f.
1 parent e905798 commit e8e7850

File tree

1 file changed

+50
-47
lines changed

1 file changed

+50
-47
lines changed

sharktank/tests/models/vae/vae_test.py

+50-47
Original file line numberDiff line numberDiff line change
@@ -49,42 +49,43 @@ def setUp(self):
4949
hf_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
5050
hf_hub_download(
5151
repo_id=hf_model_id,
52-
local_dir="{self._temp_dir}",
52+
local_dir=f"{self._temp_dir}",
5353
local_dir_use_symlinks=False,
5454
revision="main",
5555
filename="vae/config.json",
5656
)
5757
hf_hub_download(
5858
repo_id=hf_model_id,
59-
local_dir="{self._temp_dir}",
59+
local_dir=f"{self._temp_dir}",
6060
local_dir_use_symlinks=False,
6161
revision="main",
6262
filename="vae/diffusion_pytorch_model.safetensors",
6363
)
6464
hf_hub_download(
6565
repo_id="amd-shark/sdxl-quant-models",
66-
local_dir="{self._temp_dir}",
66+
local_dir=f"{self._temp_dir}",
6767
local_dir_use_symlinks=False,
6868
revision="main",
6969
filename="vae/vae.safetensors",
7070
)
7171
torch.manual_seed(12345)
7272
f32_dataset = import_hf_dataset(
73-
"{self._temp_dir}/vae/config.json",
74-
["{self._temp_dir}/vae/diffusion_pytorch_model.safetensors"],
73+
f"{self._temp_dir}/vae/config.json",
74+
[f"{self._temp_dir}/vae/diffusion_pytorch_model.safetensors"],
7575
)
76-
f32_dataset.save("{self._temp_dir}/vae_f32.irpa", io_report_callback=print)
76+
f32_dataset.save(f"{self._temp_dir}/vae_f32.irpa", io_report_callback=print)
7777
f16_dataset = import_hf_dataset(
78-
"{self._temp_dir}/vae/config.json", ["{self._temp_dir}/vae/vae.safetensors"]
78+
f"{self._temp_dir}/vae/config.json",
79+
[f"{self._temp_dir}/vae/vae.safetensors"],
7980
)
80-
f16_dataset.save("{self._temp_dir}/vae_f16.irpa", io_report_callback=print)
81+
f16_dataset.save(f"{self._temp_dir}/vae_f16.irpa", io_report_callback=print)
8182

8283
def testCompareF32EagerVsHuggingface(self):
8384
dtype = getattr(torch, "float32")
8485
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
85-
ref_results = run_torch_vae("{self._temp_dir}", inputs)
86+
ref_results = run_torch_vae(f"{self._temp_dir}", inputs)
8687

87-
ds = Dataset.load("{self._temp_dir}/vae_f32.irpa", file_type="irpa")
88+
ds = Dataset.load(f"{self._temp_dir}/vae_f32.irpa", file_type="irpa")
8889
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")
8990

9091
results = model.forward(inputs)
@@ -95,9 +96,9 @@ def testCompareF32EagerVsHuggingface(self):
9596
def testCompareF16EagerVsHuggingface(self):
9697
dtype = getattr(torch, "float32")
9798
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
98-
ref_results = run_torch_vae("{self._temp_dir}", inputs)
99+
ref_results = run_torch_vae(f"{self._temp_dir}", inputs)
99100

100-
ds = Dataset.load("{self._temp_dir}/vae_f16.irpa", file_type="irpa")
101+
ds = Dataset.load(f"{self._temp_dir}/vae_f16.irpa", file_type="irpa")
101102
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")
102103

103104
results = model.forward(inputs.to(torch.float16))
@@ -107,10 +108,10 @@ def testCompareF16EagerVsHuggingface(self):
107108
def testVaeIreeVsHuggingFace(self):
108109
dtype = getattr(torch, "float32")
109110
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
110-
ref_results = run_torch_vae("{self._temp_dir}", inputs)
111+
ref_results = run_torch_vae(f"{self._temp_dir}", inputs)
111112

112-
ds_f16 = Dataset.load("{self._temp_dir}/vae_f16.irpa", file_type="irpa")
113-
ds_f32 = Dataset.load("{self._temp_dir}/vae_f32.irpa", file_type="irpa")
113+
ds_f16 = Dataset.load(f"{self._temp_dir}/vae_f16.irpa", file_type="irpa")
114+
ds_f32 = Dataset.load(f"{self._temp_dir}/vae_f32.irpa", file_type="irpa")
114115

115116
model_f16 = VaeDecoderModel.from_dataset(ds_f16).to(device="cpu")
116117
model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu")
@@ -119,8 +120,8 @@ def testVaeIreeVsHuggingFace(self):
119120
module_f16 = export_vae(model_f16, inputs.to(torch.float16), True)
120121
module_f32 = export_vae(model_f32, inputs, True)
121122

122-
module_f16.save_mlir("{self._temp_dir}/vae_f16.mlir")
123-
module_f32.save_mlir("{self._temp_dir}/vae_f32.mlir")
123+
module_f16.save_mlir(f"{self._temp_dir}/vae_f16.mlir")
124+
module_f32.save_mlir(f"{self._temp_dir}/vae_f32.mlir")
124125
extra_args = [
125126
"--iree-hal-target-backends=rocm",
126127
"--iree-hip-target=gfx942",
@@ -137,22 +138,22 @@ def testVaeIreeVsHuggingFace(self):
137138
]
138139

139140
iree.compiler.compile_file(
140-
"{self._temp_dir}/vae_f16.mlir",
141-
output_file="{self._temp_dir}/vae_f16.vmfb",
141+
f"{self._temp_dir}/vae_f16.mlir",
142+
output_file=f"{self._temp_dir}/vae_f16.vmfb",
142143
extra_args=extra_args,
143144
)
144145
iree.compiler.compile_file(
145-
"{self._temp_dir}/vae_f32.mlir",
146-
output_file="{self._temp_dir}/vae_f32.vmfb",
146+
f"{self._temp_dir}/vae_f32.mlir",
147+
output_file=f"{self._temp_dir}/vae_f32.vmfb",
147148
extra_args=extra_args,
148149
)
149150

150151
iree_devices = get_iree_devices(driver="hip", device_count=1)
151152

152153
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
153-
module_path="{self._temp_dir}/vae_f16.vmfb",
154+
module_path=f"{self._temp_dir}/vae_f16.vmfb",
154155
devices=iree_devices,
155-
parameters_path="{self._temp_dir}/vae_f16.irpa",
156+
parameters_path=f"{self._temp_dir}/vae_f16.irpa",
156157
)
157158

158159
input_args = OrderedDict([("inputs", inputs.to(torch.float16))])
@@ -178,9 +179,9 @@ def testVaeIreeVsHuggingFace(self):
178179
)
179180

180181
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
181-
module_path="{self._temp_dir}/vae_f32.vmfb",
182+
module_path=f"{self._temp_dir}/vae_f32.vmfb",
182183
devices=iree_devices,
183-
parameters_path="{self._temp_dir}/vae_f32.irpa",
184+
parameters_path=f"{self._temp_dir}/vae_f32.irpa",
184185
)
185186

186187
input_args = OrderedDict([("inputs", inputs)])
@@ -209,30 +210,32 @@ def setUp(self):
209210
hf_model_id = "black-forest-labs/FLUX.1-dev"
210211
hf_hub_download(
211212
repo_id=hf_model_id,
212-
local_dir="{self._temp_dir}/flux_vae/",
213+
local_dir=f"{self._temp_dir}/flux_vae/",
213214
local_dir_use_symlinks=False,
214215
revision="main",
215216
filename="vae/config.json",
216217
)
217218
hf_hub_download(
218219
repo_id=hf_model_id,
219-
local_dir="{self._temp_dir}/flux_vae/",
220+
local_dir=f"{self._temp_dir}/flux_vae/",
220221
local_dir_use_symlinks=False,
221222
revision="main",
222223
filename="vae/diffusion_pytorch_model.safetensors",
223224
)
224225
torch.manual_seed(12345)
225226
dataset = import_hf_dataset(
226-
"{self._temp_dir}/flux_vae/vae/config.json",
227-
["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
227+
f"{self._temp_dir}/flux_vae/vae/config.json",
228+
[f"{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
228229
)
229-
dataset.save("{self._temp_dir}/flux_vae_bf16.irpa", io_report_callback=print)
230+
dataset.save(f"{self._temp_dir}/flux_vae_bf16.irpa", io_report_callback=print)
230231
dataset_f32 = import_hf_dataset(
231-
"{self._temp_dir}/flux_vae/vae/config.json",
232-
["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
232+
f"{self._temp_dir}/flux_vae/vae/config.json",
233+
[f"{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
233234
target_dtype=torch.float32,
234235
)
235-
dataset_f32.save("{self._temp_dir}/flux_vae_f32.irpa", io_report_callback=print)
236+
dataset_f32.save(
237+
f"{self._temp_dir}/flux_vae_f32.irpa", io_report_callback=print
238+
)
236239

237240
def testCompareBF16EagerVsHuggingface(self):
238241
dtype = torch.bfloat16
@@ -241,7 +244,7 @@ def testCompareBF16EagerVsHuggingface(self):
241244
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, dtype
242245
)
243246

244-
ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
247+
ds = Dataset.load(f"{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
245248
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")
246249

247250
results = model.forward(inputs)
@@ -255,7 +258,7 @@ def testCompareF32EagerVsHuggingface(self):
255258
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, dtype
256259
)
257260

258-
ds = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
261+
ds = Dataset.load(f"{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
259262
model = VaeDecoderModel.from_dataset(ds).to(device="cpu", dtype=dtype)
260263

261264
results = model.forward(inputs)
@@ -270,8 +273,8 @@ def testVaeIreeVsHuggingFace(self):
270273
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, torch.float32
271274
)
272275

273-
ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
274-
ds_f32 = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
276+
ds = Dataset.load(f"{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
277+
ds_f32 = Dataset.load(f"{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
275278

276279
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")
277280
model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu")
@@ -280,8 +283,8 @@ def testVaeIreeVsHuggingFace(self):
280283
module = export_vae(model, inputs.to(dtype=dtype), True)
281284
module_f32 = export_vae(model_f32, inputs, True)
282285

283-
module.save_mlir("{self._temp_dir}/flux_vae_bf16.mlir")
284-
module_f32.save_mlir("{self._temp_dir}/flux_vae_f32.mlir")
286+
module.save_mlir(f"{self._temp_dir}/flux_vae_bf16.mlir")
287+
module_f32.save_mlir(f"{self._temp_dir}/flux_vae_f32.mlir")
285288

286289
extra_args = [
287290
"--iree-hal-target-backends=rocm",
@@ -299,22 +302,22 @@ def testVaeIreeVsHuggingFace(self):
299302
]
300303

301304
iree.compiler.compile_file(
302-
"{self._temp_dir}/flux_vae_bf16.mlir",
303-
output_file="{self._temp_dir}/flux_vae_bf16.vmfb",
305+
f"{self._temp_dir}/flux_vae_bf16.mlir",
306+
output_file=f"{self._temp_dir}/flux_vae_bf16.vmfb",
304307
extra_args=extra_args,
305308
)
306309
iree.compiler.compile_file(
307-
"{self._temp_dir}/flux_vae_f32.mlir",
308-
output_file="{self._temp_dir}/flux_vae_f32.vmfb",
310+
f"{self._temp_dir}/flux_vae_f32.mlir",
311+
output_file=f"{self._temp_dir}/flux_vae_f32.vmfb",
309312
extra_args=extra_args,
310313
)
311314

312315
iree_devices = get_iree_devices(driver="hip", device_count=1)
313316

314317
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
315-
module_path="{self._temp_dir}/flux_vae_bf16.vmfb",
318+
module_path=f"{self._temp_dir}/flux_vae_bf16.vmfb",
316319
devices=iree_devices,
317-
parameters_path="{self._temp_dir}/flux_vae_bf16.irpa",
320+
parameters_path=f"{self._temp_dir}/flux_vae_bf16.irpa",
318321
)
319322

320323
input_args = OrderedDict([("inputs", inputs.to(dtype=dtype))])
@@ -339,9 +342,9 @@ def testVaeIreeVsHuggingFace(self):
339342
)
340343

341344
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
342-
module_path="{self._temp_dir}/flux_vae_f32.vmfb",
345+
module_path=f"{self._temp_dir}/flux_vae_f32.vmfb",
343346
devices=iree_devices,
344-
parameters_path="{self._temp_dir}/flux_vae_f32.irpa",
347+
parameters_path=f"{self._temp_dir}/flux_vae_f32.irpa",
345348
)
346349

347350
input_args = OrderedDict([("inputs", inputs)])

0 commit comments

Comments
 (0)