     export_encoder_iree_parameters,
 )
 from sharktank.utils.testing import (
+    assert_text_encoder_state_close,
     make_rand_torch,
     make_random_mask,
     TempDirTestBase,
@@ -107,17 +108,16 @@ def testXxlBf16AgainstFluxGolden(self):
         ) as f:
             reference_last_hidden_state = torch.load(f)
 
-        torch.testing.assert_close(
-            outputs["last_hidden_state"], reference_last_hidden_state
+        assert_text_encoder_state_close(
+            outputs["last_hidden_state"], reference_last_hidden_state, atol=1e-1
         )
 
     def runTestV1_1CompareTorchEagerHuggingFace(
         self,
         huggingface_repo_id: str,
         reference_dtype: torch.dtype,
         target_dtype: torch.dtype,
-        atol: Optional[float] = None,
-        rtol: Optional[float] = None,
+        atol: float,
     ):
         get_dataset(
             huggingface_repo_id,
@@ -146,17 +146,18 @@ def runTestV1_1CompareTorchEagerHuggingFace(
             lambda t: ops.to(t, dtype=reference_dtype), actual_outputs
         )
 
-        torch.testing.assert_close(
-            actual_outputs, expected_outputs, atol=atol, rtol=rtol
+        assert_text_encoder_state_close(
+            actual_outputs["last_hidden_state"],
+            expected_outputs["last_hidden_state"],
+            atol,
         )
 
     def runTestV1_1CompareTorchEagerAgainstHuggingFace(
         self,
         huggingface_repo_id: str,
         reference_dtype: torch.dtype,
         target_dtype: torch.dtype,
-        atol: Optional[float] = None,
-        rtol: Optional[float] = None,
+        atol: float,
     ):
         get_dataset(
             huggingface_repo_id,
@@ -199,8 +200,10 @@ def runTestV1_1CompareTorchEagerAgainstHuggingFace(
         )
 
         logger.info("Comparing outputs...")
-        torch.testing.assert_close(
-            actual_outputs, expected_outputs, atol=atol, rtol=rtol
+        assert_text_encoder_state_close(
+            actual_outputs["last_hidden_state"],
+            expected_outputs["last_hidden_state"],
+            atol,
         )
 
     @pytest.mark.xfail(
@@ -213,12 +216,31 @@ def runTestV1_1CompareTorchEagerAgainstHuggingFace(
     )
     @with_t5_data
     def testV1_1SmallCompareTorchEagerHuggingFaceBf16AgainstF32(self):
+        """Hugging Face model test to estimate a numerical error baseline for reference.
+        We don't want to run this test regularly, but we would like to keep it around
+        as a reference. It provides some baseline of what numerical error to expect.
+        """
         self.runTestV1_1CompareTorchEagerHuggingFace(
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            # The observed error is 0.05.
+            atol=1e-1,
+        )
+
+    @pytest.mark.skip
+    @with_t5_data
+    def testV1_1XxlCompareTorchEagerHuggingFaceBf16AgainstF32(self):
+        """Hugging Face model test to estimate a numerical error baseline for reference.
+        We don't want to run this test regularly, but we would like to keep it around
+        as a reference. It provides some baseline of what numerical error to expect.
+        """
+        self.runTestV1_1CompareTorchEagerHuggingFace(
+            "google/t5-v1_1-xxl",
+            reference_dtype=torch.float32,
+            target_dtype=torch.bfloat16,
+            # The observed error is 0.026.
+            atol=1e-1,
         )
 
     @with_t5_data
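Note: the "observed error" figures in the comments above are presumably a measured maximum absolute difference between the bf16 and f32 last hidden states, with atol then chosen with some headroom; the exact measurement is not part of this diff. A minimal sketch of how such a figure could be obtained (the helper and variable names below are hypothetical, not from the test suite):

import torch

def max_abs_error(actual: torch.Tensor, expected: torch.Tensor) -> float:
    """Largest elementwise absolute difference, computed in float32 to avoid bf16 rounding."""
    return (actual.to(torch.float32) - expected.to(torch.float32)).abs().max().item()

# Hypothetical usage against the tensors compared in the test above:
# observed = max_abs_error(actual["last_hidden_state"], expected["last_hidden_state"])
# atol is then picked with headroom, e.g. 1e-1 for an observed error of 0.05.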
@@ -227,24 +249,16 @@ def testV1_1SmallF32CompareTorchEagerAgainstHuggingFace(self):
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.float32,
+            atol=1e-5,
         )
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason=(
-            "The accuracy is bad, "
-            "but for XXL we get the same result as the Flux pipeline. "
-            "This need further investigation how Flux works at all like that."
-        ),
-    )
     @with_t5_data
     def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFaceF32(self):
         self.runTestV1_1CompareTorchEagerAgainstHuggingFace(
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            atol=1e-1,
         )
 
     @with_t5_data
@@ -253,6 +267,7 @@ def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFace(self):
             "google/t5-v1_1-small",
             reference_dtype=torch.bfloat16,
             target_dtype=torch.bfloat16,
+            atol=1e-1,
         )
 
     @with_t5_data
@@ -261,23 +276,16 @@ def testV1_1XxlF32CompareTorchEagerAgainstHuggingFace(self):
             "google/t5-v1_1-xxl",
             reference_dtype=torch.float32,
             target_dtype=torch.float32,
+            atol=1e-5,
        )
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason=(
-            "The accuracy is bad, but we get the same result as the Flux pipeline. "
-            "This need further investigation how Flux works at all like that."
-        ),
-    )
     @with_t5_data
     def testV1_1XxlBf16CompareTorchEagerAgainstHuggingFaceF32(self):
         self.runTestV1_1CompareTorchEagerAgainstHuggingFace(
             "google/t5-v1_1-xxl",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            atol=5e-2,
         )
 
@@ -293,8 +301,9 @@ def runTestV1_1CompareIreeAgainstTorchEager(
         huggingface_repo_id: str,
         reference_dtype: torch.dtype,
         target_dtype: torch.dtype,
-        atol: Optional[float] = None,
-        rtol: Optional[float] = None,
+        atol: float,
+        max_outliers_fraction: Optional[float] = None,
+        inlier_atol: Optional[float] = None,
     ):
         get_dataset(
             huggingface_repo_id,
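The max_outliers_fraction / inlier_atol pair added here suggests a comparison that, in addition to a loose absolute tolerance on every element, only allows a small fraction of elements to exceed a tighter one. The real assert_text_encoder_state_close lives in sharktank.utils.testing and is not shown in this diff; the sketch below is only an illustration of that idea under those assumptions, not the actual implementation.

from typing import Optional
import torch

def outlier_tolerant_assert_close(  # hypothetical stand-in, not sharktank's helper
    actual: torch.Tensor,
    expected: torch.Tensor,
    atol: float,
    max_outliers_fraction: Optional[float] = None,
    inlier_atol: Optional[float] = None,
):
    abs_error = (actual.to(torch.float32) - expected.to(torch.float32)).abs()
    # Every element must stay within the loose absolute tolerance.
    assert abs_error.max().item() <= atol, f"max abs error {abs_error.max().item()} > {atol}"
    if max_outliers_fraction is not None and inlier_atol is not None:
        # Only a small fraction of elements may exceed the tight tolerance.
        outlier_fraction = (abs_error > inlier_atol).float().mean().item()
        assert outlier_fraction <= max_outliers_fraction, (
            f"{outlier_fraction:.4f} of elements exceed inlier_atol={inlier_atol}"
        )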
@@ -386,34 +395,35 @@ def runTestV1_1CompareIreeAgainstTorchEager(
         ]
 
         logger.info("Comparing outputs...")
-        torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol)
+        reference_result_last_hidden_state = reference_result[0]
+        iree_result_last_hidden_state = iree_result[0]
+        assert_text_encoder_state_close(
+            iree_result_last_hidden_state,
+            reference_result_last_hidden_state,
+            atol=atol,
+            max_outliers_fraction=max_outliers_fraction,
+            inlier_atol=inlier_atol,
+        )
 
     @with_t5_data
     def testV1_1CompareSmallIreeF32AgainstTorchEagerF32(self):
         self.runTestV1_1CompareIreeAgainstTorchEager(
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.float32,
-            atol=1e-4,
-            rtol=2.0e-3,
+            atol=1e-5,
         )
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason=(
-            "The accuracy is bad, "
-            "but but it is no worse than the accuracy for of eager bfloat16. "
-            "This need further investigation how Flux works at all like that."
-        ),
-    )
     @with_t5_data
     def testV1_1CompareSmallIreeBf16AgainstTorchEagerF32(self):
         self.runTestV1_1CompareIreeAgainstTorchEager(
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            # The observed error is 0.12.
+            atol=0.2,
+            max_outliers_fraction=0.03,
+            inlier_atol=0.01,
         )
 
     @with_t5_data
@@ -422,26 +432,29 @@ def testV1_1CompareXxlIreeF32AgainstTorchEagerF32(self):
             "google/t5-v1_1-xxl",
             reference_dtype=torch.float32,
             target_dtype=torch.float32,
-            atol=1e-4,
-            rtol=2.0e-3,
+            atol=1e-5,
         )
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason=(
-            "The accuracy is bad, "
-            "but but it is no worse than the accuracy for of eager bfloat16. "
-            "This need further investigation how Flux works at all like that."
-        ),
-    )
     @with_t5_data
     def testV1_1CompareXxlIreeBf16AgainstTorchEagerF32(self):
+        """The observed absolute numerical error is 0.21.
+        Per-token cosine similarity metrics are
+        mean = 0.997
+        std dev = 0.018
+        min = 0.789
+
+        The error seems high, as it corresponds to a 38° angular difference.
+        For comparison, the bf16 Hugging Face small model exhibits a worst token error
+        of 0.05. Although the error here is worse, it may be reasonable as it comes
+        from a single token outlier. The majority of tokens have an error less than 0.01.
+        """
         self.runTestV1_1CompareIreeAgainstTorchEager(
             "google/t5-v1_1-xxl",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            atol=2.5e-1,
+            max_outliers_fraction=0.03,
+            inlier_atol=0.01,
         )
 
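The per-token cosine similarity figures quoted in the docstring above could be reproduced along these lines (a sketch assuming last hidden states shaped [batch, sequence, hidden]; the helper name is hypothetical, and the 38° figure is simply acos(0.789) expressed in degrees):

import math
import torch

def per_token_cosine_similarity(actual: torch.Tensor, expected: torch.Tensor) -> torch.Tensor:
    # One similarity value per token, taken over the hidden dimension.
    return torch.nn.functional.cosine_similarity(
        actual.to(torch.float32), expected.to(torch.float32), dim=-1
    )

# cos = per_token_cosine_similarity(iree_result_last_hidden_state, reference_result_last_hidden_state)
# cos.mean(), cos.std(), cos.min()  -> roughly 0.997, 0.018, 0.789 as quoted
# math.degrees(math.acos(0.789))    -> about 37.9 degrees for the worst token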