@@ -21,20 +21,20 @@
 from sharktank.models.flux.testing import (
     convert_flux_transformer_input_for_hugging_face_model,
     export_dev_random_single_layer,
-    make_dev_single_layer_config,
+    make_toy_config,
     make_random_theta,
 )
 from sharktank.models.flux.flux import FluxModelV1, FluxParams
-from sharktank.utils.testing import TempDirTestBase
+from sharktank.utils.testing import TempDirTestBase, skip, is_mi300x
 from sharktank.utils.iree import (
-    get_iree_devices,
     load_iree_module,
     run_iree_module_function,
     prepare_iree_module_function_args,
     call_torch_module_function,
     flatten_for_iree_signature,
     iree_to_torch,
 )
+from sharktank.utils.logging import format_tensor_statistics
 from sharktank import ops
 from sharktank.transforms.dataset import set_float_dtype
 from sharktank.types import Dataset, Theta
@@ -44,8 +44,6 @@
 with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')")
 
 iree_compile_flags = [
-    "--iree-hal-target-device=hip",
-    "--iree-hip-target=gfx942",
     "--iree-opt-const-eval=false",
     "--iree-opt-strip-assertions=true",
     "--iree-global-opt-propagate-transposes=true",
@@ -74,6 +72,15 @@ def convert_dtype_if_dtype(
     return t
 
 
+def convert_input_dtype(input: dict[str, torch.Tensor], dtype: torch.dtype):
+    always_float32_input_arg_names = set(["img_ids", "txt_ids"])
+    return OrderedDict(
+        (k, t if k in always_float32_input_arg_names else t.to(dtype=dtype))
+        for k, t in input.items()
+    )
+
+
+@pytest.mark.usefixtures("path_prefix", "get_iree_flags")
 class FluxTest(TempDirTestBase):
     def setUp(self):
         super().setUp()
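
The new `convert_input_dtype` helper casts every sample input to the requested dtype except `img_ids` and `txt_ids`, which are kept in float32. A minimal standalone sketch of the behavior (the tensor shapes below are made up for illustration):

```python
from collections import OrderedDict
import torch

def convert_input_dtype(input: dict[str, torch.Tensor], dtype: torch.dtype):
    # img_ids and txt_ids carry positional indices and stay float32
    # regardless of the requested model dtype.
    always_float32_input_arg_names = set(["img_ids", "txt_ids"])
    return OrderedDict(
        (k, t if k in always_float32_input_arg_names else t.to(dtype=dtype))
        for k, t in input.items()
    )

inputs = OrderedDict(img=torch.zeros(1, 1024, 64), img_ids=torch.zeros(1, 1024, 3))
converted = convert_input_dtype(inputs, dtype=torch.bfloat16)
assert converted["img"].dtype == torch.bfloat16
assert converted["img_ids"].dtype == torch.float32  # left untouched
```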
@@ -96,6 +103,7 @@ def runCompareIreeAgainstTorchEager(
         target_theta = reference_model.theta.transform(
             functools.partial(set_float_dtype, dtype=target_dtype)
         )
+
         target_torch_model = FluxModelV1(
             theta=target_theta,
             params=reference_model.params,
@@ -115,30 +123,22 @@ def runCompareIreeAgainstTorchEager(
 
         iree_module_path = self._temp_dir / "model.vmfb"
         logger.info("Compiling MLIR file...")
+        compile_flags = iree_compile_flags + [
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
+            f"--iree-hip-target={self.iree_hip_target}",
+        ]
         iree.compiler.compile_file(
             str(mlir_path),
             output_file=str(iree_module_path),
-            extra_args=iree_compile_flags,
+            extra_args=compile_flags,
         )
 
-        target_input_args, target_input_kwargs = target_torch_model.sample_inputs(
+        reference_input_args, reference_input_kwargs = reference_model.sample_inputs(
             batch_size
         )
-
-        reference_input_args = [
-            convert_dtype_if_dtype(
-                t, source_dtype=target_dtype, target_dtype=reference_model.dtype
-            )
-            for t in target_input_args
-        ]
-        reference_input_kwargs = OrderedDict(
-            (
-                k,
-                convert_dtype_if_dtype(
-                    t, source_dtype=target_dtype, target_dtype=reference_model.dtype
-                ),
-            )
-            for k, t in target_input_kwargs.items()
+        assert len(reference_input_args) == 0
+        target_input_kwargs = convert_input_dtype(
+            reference_input_kwargs, dtype=target_dtype
         )
 
         logger.info("Invoking reference torch function...")
@@ -150,15 +150,15 @@ def runCompareIreeAgainstTorchEager(
         )
         expected_outputs = flatten_for_iree_signature(reference_result_dict)
 
-        iree_devices = get_iree_devices(driver="hip", device_count=1)
+        iree_devices = [iree.runtime.get_device(self.iree_device)]
         logger.info("Loading IREE module...")
         iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
             module_path=iree_module_path,
             devices=iree_devices,
             parameters_path=parameters_path,
         )
         iree_args = prepare_iree_module_function_args(
-            args=flatten_for_iree_signature([target_input_args, target_input_kwargs]),
+            args=flatten_for_iree_signature(target_input_kwargs),
             devices=iree_devices,
         )
 
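
Device selection now resolves a URI supplied by the test flags instead of hard-coding the HIP driver. For reference, a minimal sketch of the underlying IREE runtime call (the "local-task" URI is only an example; in the test, `self.iree_device` comes from the `get_iree_flags` fixture):

```python
import iree.runtime

# Resolve a HAL device from a URI string: "local-task" is the portable
# CPU driver; a HIP GPU device would be e.g. "hip://0".
device = iree.runtime.get_device("local-task")
```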
@@ -177,9 +177,14 @@ def runCompareIreeAgainstTorchEager(
             for i in range(len(expected_outputs))
         ]
         logger.info("Comparing outputs...")
+        logger.info(f"Expected output {format_tensor_statistics(expected_outputs[0])}")
+        abs_diff = (actual_outputs[0] - expected_outputs[0]).abs()
+        logger.info(
+            f"Actual vs expected abs diff {format_tensor_statistics(abs_diff[0])}"
+        )
         torch.testing.assert_close(actual_outputs, expected_outputs, atol=atol, rtol=0)
 
-    def runTestCompareDevIreeAgainstHuggingFace(
+    def runTestCompareDevIreeAgainstEager(
         self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float
     ):
         parameters_output_path = self._temp_dir / "parameters.irpa"
@@ -211,21 +216,12 @@ def runTestCompareTorchEagerAgainstHuggingFace(
     ):
         target_input_args, target_input_kwargs = target_model.sample_inputs()
 
-        reference_input_args = [
-            convert_dtype_if_dtype(
-                t, source_dtype=target_model.dtype, target_dtype=reference_dtype
-            )
-            for t in target_input_args
-        ]
-        reference_input_kwargs = OrderedDict(
-            (
-                k,
-                convert_dtype_if_dtype(
-                    t, source_dtype=target_model.dtype, target_dtype=reference_dtype
-                ),
-            )
-            for k, t in target_input_kwargs.items()
+        assert len(target_input_args) == 0
+        reference_input_args = []
+        reference_input_kwargs = convert_input_dtype(
+            target_input_kwargs, dtype=reference_dtype
         )
+
         reference_input_kwargs = convert_flux_transformer_input_for_hugging_face_model(
             *reference_input_args, **reference_input_kwargs
         )
@@ -238,18 +234,55 @@ def runTestCompareTorchEagerAgainstHuggingFace(
 
         torch.testing.assert_close(target_output, reference_output, atol=atol, rtol=0)
 
+    def runTestCompareToyIreeAgainstEager(
+        self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float
+    ):
+        config = make_toy_config()
+        reference_theta = make_random_theta(config, dtype=reference_dtype)
+        reference_model = FluxModelV1(theta=reference_theta, params=config)
+        self.runCompareIreeAgainstTorchEager(
+            reference_model=reference_model, target_dtype=target_dtype, atol=atol
+        )
+
+    @is_mi300x
+    def testCompareToyIreeF32AgainstEagerF64(self):
+        """atol is apparently high because the expected output range is large.
+        Its absolute maximum is 3915. Observed atol is 0.036."""
+        self.runTestCompareToyIreeAgainstEager(
+            reference_dtype=torch.float64, target_dtype=torch.float32, atol=1e-1
+        )
+
+    @skip(
+        reason=(
+            "Sporadic segmentation fault during buffer destruction."
+            " See https://github.com/nod-ai/shark-ai/issues/1050"
+        )
+    )
+    @is_mi300x
+    def testCompareToyIreeBf16AgainstEagerF64(self):
+        """atol is apparently high because the expected output range is large.
+        Its absolute maximum is 3915. Observed atol is 260.6.
+        This is consistent with the expectation that bf16 atol should be worse by ~10^4
+        compared to f32. f32 can represent ~7 digits and bf16 can represent ~3."""
+        self.runTestCompareToyIreeAgainstEager(
+            reference_dtype=torch.float64, target_dtype=torch.bfloat16, atol=5e2
+        )
+
     @with_flux_data
-    def testCompareDevIreeF32AgainstHuggingFaceF32(self):
-        self.runTestCompareDevIreeAgainstHuggingFace(
+    def testCompareDevIreeF32AgainstEagerF32(self):
+        self.runTestCompareDevIreeAgainstEager(
             reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-2
         )
 
-    @pytest.mark.skip(
-        reason="Segmentation fault during output comparison. See https://github.com/nod-ai/shark-ai/issues/1050"
+    @skip(
+        reason=(
+            "Sporadic segmentation fault during buffer destruction."
+            " See https://github.com/nod-ai/shark-ai/issues/1050"
+        )
     )
     @with_flux_data
-    def testCompareDevIreeBf16AgainstHuggingFaceF32(self):
-        self.runTestCompareDevIreeAgainstHuggingFace(
+    def testCompareDevIreeBf16AgainstEagerF32(self):
+        self.runTestCompareDevIreeAgainstEager(
             reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1
         )
 
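
To make the tolerance reasoning in the bf16 docstring concrete, a quick back-of-the-envelope check (the ~3915 output magnitude is taken from the docstring above; the epsilons follow from the dtype definitions):

```python
import torch

eps_f32 = torch.finfo(torch.float32).eps    # ~1.19e-07 (24-bit significand)
eps_bf16 = torch.finfo(torch.bfloat16).eps  # 7.8125e-03 (8-bit significand)

# A single rounding step errs by roughly |x| * eps; with outputs up to ~3915:
print(3915 * eps_f32)      # ~4.7e-4, so the observed 0.036 reflects error
                           # accumulated across many ops
print(3915 * eps_bf16)     # ~30.6, so an accumulated ~260.6 fits atol=5e2
print(eps_bf16 / eps_f32)  # 65536.0 == 2**16, i.e. roughly 10^4 worse precision
```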