@@ -49,42 +49,43 @@ def setUp(self):
49
49
hf_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
50
50
hf_hub_download (
51
51
repo_id = hf_model_id ,
52
- local_dir = "{self._temp_dir}" ,
52
+ local_dir = f "{ self ._temp_dir } " ,
53
53
local_dir_use_symlinks = False ,
54
54
revision = "main" ,
55
55
filename = "vae/config.json" ,
56
56
)
57
57
hf_hub_download (
58
58
repo_id = hf_model_id ,
59
- local_dir = "{self._temp_dir}" ,
59
+ local_dir = f "{ self ._temp_dir } " ,
60
60
local_dir_use_symlinks = False ,
61
61
revision = "main" ,
62
62
filename = "vae/diffusion_pytorch_model.safetensors" ,
63
63
)
64
64
hf_hub_download (
65
65
repo_id = "amd-shark/sdxl-quant-models" ,
66
- local_dir = "{self._temp_dir}" ,
66
+ local_dir = f "{ self ._temp_dir } " ,
67
67
local_dir_use_symlinks = False ,
68
68
revision = "main" ,
69
69
filename = "vae/vae.safetensors" ,
70
70
)
71
71
torch .manual_seed (12345 )
72
72
f32_dataset = import_hf_dataset (
73
- "{self._temp_dir}/vae/config.json" ,
74
- ["{self._temp_dir}/vae/diffusion_pytorch_model.safetensors" ],
73
+ f "{ self ._temp_dir } /vae/config.json" ,
74
+ [f "{ self ._temp_dir } /vae/diffusion_pytorch_model.safetensors" ],
75
75
)
76
- f32_dataset .save ("{self._temp_dir}/vae_f32.irpa" , io_report_callback = print )
76
+ f32_dataset .save (f "{ self ._temp_dir } /vae_f32.irpa" , io_report_callback = print )
77
77
f16_dataset = import_hf_dataset (
78
- "{self._temp_dir}/vae/config.json" , ["{self._temp_dir}/vae/vae.safetensors" ]
78
+ f"{ self ._temp_dir } /vae/config.json" ,
79
+ [f"{ self ._temp_dir } /vae/vae.safetensors" ],
79
80
)
80
- f16_dataset .save ("{self._temp_dir}/vae_f16.irpa" , io_report_callback = print )
81
+ f16_dataset .save (f "{ self ._temp_dir } /vae_f16.irpa" , io_report_callback = print )
81
82
82
83
def testCompareF32EagerVsHuggingface (self ):
83
84
dtype = getattr (torch , "float32" )
84
85
inputs = get_random_inputs (dtype = dtype , device = "cpu" , bs = 1 )
85
- ref_results = run_torch_vae ("{self._temp_dir}" , inputs )
86
+ ref_results = run_torch_vae (f "{ self ._temp_dir } " , inputs )
86
87
87
- ds = Dataset .load ("{self._temp_dir}/vae_f32.irpa" , file_type = "irpa" )
88
+ ds = Dataset .load (f "{ self ._temp_dir } /vae_f32.irpa" , file_type = "irpa" )
88
89
model = VaeDecoderModel .from_dataset (ds ).to (device = "cpu" )
89
90
90
91
results = model .forward (inputs )
@@ -95,9 +96,9 @@ def testCompareF32EagerVsHuggingface(self):
95
96
def testCompareF16EagerVsHuggingface (self ):
96
97
dtype = getattr (torch , "float32" )
97
98
inputs = get_random_inputs (dtype = dtype , device = "cpu" , bs = 1 )
98
- ref_results = run_torch_vae ("{self._temp_dir}" , inputs )
99
+ ref_results = run_torch_vae (f "{ self ._temp_dir } " , inputs )
99
100
100
- ds = Dataset .load ("{self._temp_dir}/vae_f16.irpa" , file_type = "irpa" )
101
+ ds = Dataset .load (f "{ self ._temp_dir } /vae_f16.irpa" , file_type = "irpa" )
101
102
model = VaeDecoderModel .from_dataset (ds ).to (device = "cpu" )
102
103
103
104
results = model .forward (inputs .to (torch .float16 ))
@@ -107,10 +108,10 @@ def testCompareF16EagerVsHuggingface(self):
107
108
def testVaeIreeVsHuggingFace (self ):
108
109
dtype = getattr (torch , "float32" )
109
110
inputs = get_random_inputs (dtype = dtype , device = "cpu" , bs = 1 )
110
- ref_results = run_torch_vae ("{self._temp_dir}" , inputs )
111
+ ref_results = run_torch_vae (f "{ self ._temp_dir } " , inputs )
111
112
112
- ds_f16 = Dataset .load ("{self._temp_dir}/vae_f16.irpa" , file_type = "irpa" )
113
- ds_f32 = Dataset .load ("{self._temp_dir}/vae_f32.irpa" , file_type = "irpa" )
113
+ ds_f16 = Dataset .load (f "{ self ._temp_dir } /vae_f16.irpa" , file_type = "irpa" )
114
+ ds_f32 = Dataset .load (f "{ self ._temp_dir } /vae_f32.irpa" , file_type = "irpa" )
114
115
115
116
model_f16 = VaeDecoderModel .from_dataset (ds_f16 ).to (device = "cpu" )
116
117
model_f32 = VaeDecoderModel .from_dataset (ds_f32 ).to (device = "cpu" )
@@ -119,8 +120,8 @@ def testVaeIreeVsHuggingFace(self):
119
120
module_f16 = export_vae (model_f16 , inputs .to (torch .float16 ), True )
120
121
module_f32 = export_vae (model_f32 , inputs , True )
121
122
122
- module_f16 .save_mlir ("{self._temp_dir}/vae_f16.mlir" )
123
- module_f32 .save_mlir ("{self._temp_dir}/vae_f32.mlir" )
123
+ module_f16 .save_mlir (f "{ self ._temp_dir } /vae_f16.mlir" )
124
+ module_f32 .save_mlir (f "{ self ._temp_dir } /vae_f32.mlir" )
124
125
extra_args = [
125
126
"--iree-hal-target-backends=rocm" ,
126
127
"--iree-hip-target=gfx942" ,
@@ -137,22 +138,22 @@ def testVaeIreeVsHuggingFace(self):
137
138
]
138
139
139
140
iree .compiler .compile_file (
140
- "{self._temp_dir}/vae_f16.mlir" ,
141
- output_file = "{self._temp_dir}/vae_f16.vmfb" ,
141
+ f "{ self ._temp_dir } /vae_f16.mlir" ,
142
+ output_file = f "{ self ._temp_dir } /vae_f16.vmfb" ,
142
143
extra_args = extra_args ,
143
144
)
144
145
iree .compiler .compile_file (
145
- "{self._temp_dir}/vae_f32.mlir" ,
146
- output_file = "{self._temp_dir}/vae_f32.vmfb" ,
146
+ f "{ self ._temp_dir } /vae_f32.mlir" ,
147
+ output_file = f "{ self ._temp_dir } /vae_f32.vmfb" ,
147
148
extra_args = extra_args ,
148
149
)
149
150
150
151
iree_devices = get_iree_devices (driver = "hip" , device_count = 1 )
151
152
152
153
iree_module , iree_vm_context , iree_vm_instance = load_iree_module (
153
- module_path = "{self._temp_dir}/vae_f16.vmfb" ,
154
+ module_path = f "{ self ._temp_dir } /vae_f16.vmfb" ,
154
155
devices = iree_devices ,
155
- parameters_path = "{self._temp_dir}/vae_f16.irpa" ,
156
+ parameters_path = f "{ self ._temp_dir } /vae_f16.irpa" ,
156
157
)
157
158
158
159
input_args = OrderedDict ([("inputs" , inputs .to (torch .float16 ))])
@@ -178,9 +179,9 @@ def testVaeIreeVsHuggingFace(self):
178
179
)
179
180
180
181
iree_module , iree_vm_context , iree_vm_instance = load_iree_module (
181
- module_path = "{self._temp_dir}/vae_f32.vmfb" ,
182
+ module_path = f "{ self ._temp_dir } /vae_f32.vmfb" ,
182
183
devices = iree_devices ,
183
- parameters_path = "{self._temp_dir}/vae_f32.irpa" ,
184
+ parameters_path = f "{ self ._temp_dir } /vae_f32.irpa" ,
184
185
)
185
186
186
187
input_args = OrderedDict ([("inputs" , inputs )])
@@ -209,30 +210,32 @@ def setUp(self):
209
210
hf_model_id = "black-forest-labs/FLUX.1-dev"
210
211
hf_hub_download (
211
212
repo_id = hf_model_id ,
212
- local_dir = "{self._temp_dir}/flux_vae/" ,
213
+ local_dir = f "{ self ._temp_dir } /flux_vae/" ,
213
214
local_dir_use_symlinks = False ,
214
215
revision = "main" ,
215
216
filename = "vae/config.json" ,
216
217
)
217
218
hf_hub_download (
218
219
repo_id = hf_model_id ,
219
- local_dir = "{self._temp_dir}/flux_vae/" ,
220
+ local_dir = f "{ self ._temp_dir } /flux_vae/" ,
220
221
local_dir_use_symlinks = False ,
221
222
revision = "main" ,
222
223
filename = "vae/diffusion_pytorch_model.safetensors" ,
223
224
)
224
225
torch .manual_seed (12345 )
225
226
dataset = import_hf_dataset (
226
- "{self._temp_dir}/flux_vae/vae/config.json" ,
227
- ["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors" ],
227
+ f "{ self ._temp_dir } /flux_vae/vae/config.json" ,
228
+ [f "{ self ._temp_dir } /flux_vae/vae/diffusion_pytorch_model.safetensors" ],
228
229
)
229
- dataset .save ("{self._temp_dir}/flux_vae_bf16.irpa" , io_report_callback = print )
230
+ dataset .save (f "{ self ._temp_dir } /flux_vae_bf16.irpa" , io_report_callback = print )
230
231
dataset_f32 = import_hf_dataset (
231
- "{self._temp_dir}/flux_vae/vae/config.json" ,
232
- ["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors" ],
232
+ f "{ self ._temp_dir } /flux_vae/vae/config.json" ,
233
+ [f "{ self ._temp_dir } /flux_vae/vae/diffusion_pytorch_model.safetensors" ],
233
234
target_dtype = torch .float32 ,
234
235
)
235
- dataset_f32 .save ("{self._temp_dir}/flux_vae_f32.irpa" , io_report_callback = print )
236
+ dataset_f32 .save (
237
+ f"{ self ._temp_dir } /flux_vae_f32.irpa" , io_report_callback = print
238
+ )
236
239
237
240
def testCompareBF16EagerVsHuggingface (self ):
238
241
dtype = torch .bfloat16
@@ -241,7 +244,7 @@ def testCompareBF16EagerVsHuggingface(self):
241
244
"black-forest-labs/FLUX.1-dev" , inputs , 1024 , 1024 , True , dtype
242
245
)
243
246
244
- ds = Dataset .load ("{self._temp_dir}/flux_vae_bf16.irpa" , file_type = "irpa" )
247
+ ds = Dataset .load (f "{ self ._temp_dir } /flux_vae_bf16.irpa" , file_type = "irpa" )
245
248
model = VaeDecoderModel .from_dataset (ds ).to (device = "cpu" )
246
249
247
250
results = model .forward (inputs )
@@ -255,7 +258,7 @@ def testCompareF32EagerVsHuggingface(self):
255
258
"black-forest-labs/FLUX.1-dev" , inputs , 1024 , 1024 , True , dtype
256
259
)
257
260
258
- ds = Dataset .load ("{self._temp_dir}/flux_vae_f32.irpa" , file_type = "irpa" )
261
+ ds = Dataset .load (f "{ self ._temp_dir } /flux_vae_f32.irpa" , file_type = "irpa" )
259
262
model = VaeDecoderModel .from_dataset (ds ).to (device = "cpu" , dtype = dtype )
260
263
261
264
results = model .forward (inputs )
@@ -270,8 +273,8 @@ def testVaeIreeVsHuggingFace(self):
270
273
"black-forest-labs/FLUX.1-dev" , inputs , 1024 , 1024 , True , torch .float32
271
274
)
272
275
273
- ds = Dataset .load ("{self._temp_dir}/flux_vae_bf16.irpa" , file_type = "irpa" )
274
- ds_f32 = Dataset .load ("{self._temp_dir}/flux_vae_f32.irpa" , file_type = "irpa" )
276
+ ds = Dataset .load (f "{ self ._temp_dir } /flux_vae_bf16.irpa" , file_type = "irpa" )
277
+ ds_f32 = Dataset .load (f "{ self ._temp_dir } /flux_vae_f32.irpa" , file_type = "irpa" )
275
278
276
279
model = VaeDecoderModel .from_dataset (ds ).to (device = "cpu" )
277
280
model_f32 = VaeDecoderModel .from_dataset (ds_f32 ).to (device = "cpu" )
@@ -280,8 +283,8 @@ def testVaeIreeVsHuggingFace(self):
280
283
module = export_vae (model , inputs .to (dtype = dtype ), True )
281
284
module_f32 = export_vae (model_f32 , inputs , True )
282
285
283
- module .save_mlir ("{self._temp_dir}/flux_vae_bf16.mlir" )
284
- module_f32 .save_mlir ("{self._temp_dir}/flux_vae_f32.mlir" )
286
+ module .save_mlir (f "{ self ._temp_dir } /flux_vae_bf16.mlir" )
287
+ module_f32 .save_mlir (f "{ self ._temp_dir } /flux_vae_f32.mlir" )
285
288
286
289
extra_args = [
287
290
"--iree-hal-target-backends=rocm" ,
@@ -299,22 +302,22 @@ def testVaeIreeVsHuggingFace(self):
299
302
]
300
303
301
304
iree .compiler .compile_file (
302
- "{self._temp_dir}/flux_vae_bf16.mlir" ,
303
- output_file = "{self._temp_dir}/flux_vae_bf16.vmfb" ,
305
+ f "{ self ._temp_dir } /flux_vae_bf16.mlir" ,
306
+ output_file = f "{ self ._temp_dir } /flux_vae_bf16.vmfb" ,
304
307
extra_args = extra_args ,
305
308
)
306
309
iree .compiler .compile_file (
307
- "{self._temp_dir}/flux_vae_f32.mlir" ,
308
- output_file = "{self._temp_dir}/flux_vae_f32.vmfb" ,
310
+ f "{ self ._temp_dir } /flux_vae_f32.mlir" ,
311
+ output_file = f "{ self ._temp_dir } /flux_vae_f32.vmfb" ,
309
312
extra_args = extra_args ,
310
313
)
311
314
312
315
iree_devices = get_iree_devices (driver = "hip" , device_count = 1 )
313
316
314
317
iree_module , iree_vm_context , iree_vm_instance = load_iree_module (
315
- module_path = "{self._temp_dir}/flux_vae_bf16.vmfb" ,
318
+ module_path = f "{ self ._temp_dir } /flux_vae_bf16.vmfb" ,
316
319
devices = iree_devices ,
317
- parameters_path = "{self._temp_dir}/flux_vae_bf16.irpa" ,
320
+ parameters_path = f "{ self ._temp_dir } /flux_vae_bf16.irpa" ,
318
321
)
319
322
320
323
input_args = OrderedDict ([("inputs" , inputs .to (dtype = dtype ))])
@@ -339,9 +342,9 @@ def testVaeIreeVsHuggingFace(self):
339
342
)
340
343
341
344
iree_module , iree_vm_context , iree_vm_instance = load_iree_module (
342
- module_path = "{self._temp_dir}/flux_vae_f32.vmfb" ,
345
+ module_path = f "{ self ._temp_dir } /flux_vae_f32.vmfb" ,
343
346
devices = iree_devices ,
344
- parameters_path = "{self._temp_dir}/flux_vae_f32.irpa" ,
347
+ parameters_path = f "{ self ._temp_dir } /flux_vae_f32.irpa" ,
345
348
)
346
349
347
350
input_args = OrderedDict ([("inputs" , inputs )])
0 commit comments