Commit 7a8f360

Make the T5 model tests use cosine similarity (#895)
There were several xfail tests that relied on a poorly chosen metric. Cosine similarity is a better metric for language embeddings. The comparison between bf16 and f32 exhibits a small fraction of outlier tokens with a higher per-token numerical error than the majority of tokens. To account for that, the testing metric is expanded to check separate inlier and outlier absolute tolerances.
1 parent f12ed07 commit 7a8f360
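
For illustration (not part of the commit), a minimal sketch of the metric this change adopts: cosine similarity is computed per token over the embedding dimension and compared against 1.0, with a tight tolerance for the bulk of tokens and a small budget for outlier tokens. The tensors, shapes, and tolerance values below are invented for the example.

    import torch

    torch.manual_seed(0)

    # Hypothetical encoder states of shape (batch, tokens, embedding_dim).
    actual = torch.randn(2, 8, 16)
    expected = actual + 1e-3 * torch.randn_like(actual)

    # One similarity value per token, reduced over the embedding dimension.
    cosine_similarity_per_token = torch.nn.functional.cosine_similarity(
        actual, expected, dim=-1
    )
    error = (cosine_similarity_per_token - 1.0).abs()

    atol = 0.2                    # no token may deviate more than this
    inlier_atol = 0.01            # most tokens must be at least this close
    max_outliers_fraction = 0.03  # budget for tokens exceeding inlier_atol

    assert error.max() <= atol
    outliers_fraction = (error > inlier_atol).count_nonzero() / error.numel()
    assert outliers_fraction <= max_outliers_fraction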

File tree: 2 files changed (+126 -63)

sharktank/sharktank/utils/testing.py (+55 -5)
@@ -239,8 +239,56 @@ def assert_iterables_equal(
         ), f"Iterables not equal at index {i} for elements {v1} and {v2}"
 
 
+def assert_tensor_close(
+    actual: torch.Tensor,
+    expected: torch.Tensor,
+    atol: float,
+    max_outliers_fraction: Optional[float] = None,
+    inlier_atol: Optional[float] = None,
+):
+    if (max_outliers_fraction is None and inlier_atol is not None) or (
+        max_outliers_fraction is not None and inlier_atol is None
+    ):
+        raise ValueError(
+            "max_outliers_fraction and inlier_atol must both be provided or both be None."
+        )
+
+    try:
+        torch.testing.assert_close(
+            actual,
+            expected,
+            atol=atol,
+            rtol=0,
+        )
+
+        if inlier_atol is not None:
+            outliers = (actual - expected).abs() > inlier_atol
+            outliers_fraction = outliers.count_nonzero() / outliers.numel()
+            if outliers_fraction > max_outliers_fraction:
+                raise AssertionError(
+                    f"The fraction of outliers {outliers_fraction:%} is above the allowed "
+                    f"{max_outliers_fraction:%}. Inlier atol={inlier_atol}."
+                )
+    except AssertionError as ex:
+        diff = actual - expected
+        std, mean = torch.std_mean(diff)
+        msg = (
+            "Difference (actual - expected):\n"
+            f"mean = {mean}\n"
+            f"median = {diff.median()}\n"
+            f"std dev = {std}\n"
+            f"min = {diff.min()}\n"
+            f"max = {diff.max()}\n"
+        )
+        raise AssertionError(msg) from ex
+
+
 def assert_text_encoder_state_close(
-    actual: torch.Tensor, expected: torch.Tensor, atol: float
+    actual: torch.Tensor,
+    expected: torch.Tensor,
+    atol: float,
+    max_outliers_fraction: Optional[float] = None,
+    inlier_atol: Optional[float] = None,
 ):
     """The cosine similarity has been suggested to compare encoder states.
 
@@ -261,11 +309,13 @@ def assert_text_encoder_state_close(
         expected,
         dim=-1,
     )
-    torch.testing.assert_close(
-        cosine_similarity_per_token,
-        torch.ones_like(cosine_similarity_per_token),
+
+    assert_tensor_close(
+        actual=cosine_similarity_per_token,
+        expected=torch.ones_like(cosine_similarity_per_token),
         atol=atol,
-        rtol=0,
+        max_outliers_fraction=max_outliers_fraction,
+        inlier_atol=inlier_atol,
     )
 
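For context, a hedged usage sketch of the helper added above. The encoder states are invented; the call mirrors the new assert_text_encoder_state_close signature, which tolerates a bounded fraction of outlier tokens.

    import torch

    from sharktank.utils.testing import assert_text_encoder_state_close

    # Invented encoder states of shape (batch, tokens, embedding_dim).
    actual = torch.ones(1, 4, 8)
    expected = actual.clone()
    expected[0, 0, 0] += 0.5  # skew one token's direction to create a single outlier

    # Passes: every token's cosine similarity is within atol of 1.0, and only
    # 1 of 4 tokens (25%) deviates by more than inlier_atol, within the 30% budget.
    assert_text_encoder_state_close(
        actual,
        expected,
        atol=0.2,
        max_outliers_fraction=0.3,
        inlier_atol=0.01,
    )
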
sharktank/tests/models/t5/t5_test.py (+71 -58)
@@ -41,6 +41,7 @@
     export_encoder_iree_parameters,
 )
 from sharktank.utils.testing import (
+    assert_text_encoder_state_close,
     make_rand_torch,
     make_random_mask,
     TempDirTestBase,
@@ -107,17 +108,16 @@ def testXxlBf16AgainstFluxGolden(self):
         ) as f:
             reference_last_hidden_state = torch.load(f)
 
-        torch.testing.assert_close(
-            outputs["last_hidden_state"], reference_last_hidden_state
+        assert_text_encoder_state_close(
+            outputs["last_hidden_state"], reference_last_hidden_state, atol=1e-1
         )
 
     def runTestV1_1CompareTorchEagerHuggingFace(
         self,
         huggingface_repo_id: str,
         reference_dtype: torch.dtype,
         target_dtype: torch.dtype,
-        atol: Optional[float] = None,
-        rtol: Optional[float] = None,
+        atol: float,
     ):
         get_dataset(
             huggingface_repo_id,
@@ -146,17 +146,18 @@ def runTestV1_1CompareTorchEagerHuggingFace(
             lambda t: ops.to(t, dtype=reference_dtype), actual_outputs
         )
 
-        torch.testing.assert_close(
-            actual_outputs, expected_outputs, atol=atol, rtol=rtol
+        assert_text_encoder_state_close(
+            actual_outputs["last_hidden_state"],
+            expected_outputs["last_hidden_state"],
+            atol,
         )
 
     def runTestV1_1CompareTorchEagerAgainstHuggingFace(
         self,
         huggingface_repo_id: str,
         reference_dtype: torch.dtype,
         target_dtype: torch.dtype,
-        atol: Optional[float] = None,
-        rtol: Optional[float] = None,
+        atol: float,
     ):
         get_dataset(
             huggingface_repo_id,
@@ -199,8 +200,10 @@ def runTestV1_1CompareTorchEagerAgainstHuggingFace(
         )
 
         logger.info("Comparing outputs...")
-        torch.testing.assert_close(
-            actual_outputs, expected_outputs, atol=atol, rtol=rtol
+        assert_text_encoder_state_close(
+            actual_outputs["last_hidden_state"],
+            expected_outputs["last_hidden_state"],
+            atol,
         )
 
     @pytest.mark.xfail(
@@ -213,12 +216,31 @@ def runTestV1_1CompareTorchEagerAgainstHuggingFace(
     )
     @with_t5_data
    def testV1_1SmallCompareTorchEagerHuggingFaceBf16AgainstF32(self):
+        """Hugging Face model test to estimate a numerical error baseline for reference.
+        We don't want to run this test regularly, but we would like to keep it around
+        as a reference. It provides some baseline of what numerical error to expect.
+        """
         self.runTestV1_1CompareTorchEagerHuggingFace(
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            # The observed error is 0.05.
+            atol=1e-1,
+        )
+
+    @pytest.mark.skip
+    @with_t5_data
+    def testV1_1XxlCompareTorchEagerHuggingFaceBf16AgainstF32(self):
+        """Hugging Face model test to estimate a numerical error baseline for reference.
+        We don't want to run this test regularly, but we would like to keep it around
+        as a reference. It provides some baseline of what numerical error to expect.
+        """
+        self.runTestV1_1CompareTorchEagerHuggingFace(
+            "google/t5-v1_1-xxl",
+            reference_dtype=torch.float32,
+            target_dtype=torch.bfloat16,
+            # The observed error is 0.026.
+            atol=1e-1,
         )
 
     @with_t5_data
@@ -227,24 +249,16 @@ def testV1_1SmallF32CompareTorchEagerAgainstHuggingFace(self):
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.float32,
+            atol=1e-5,
         )
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason=(
-            "The accuracy is bad, "
-            "but for XXL we get the same result as the Flux pipeline. "
-            "This need further investigation how Flux works at all like that."
-        ),
-    )
     @with_t5_data
     def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFaceF32(self):
         self.runTestV1_1CompareTorchEagerAgainstHuggingFace(
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            atol=1e-1,
         )
 
     @with_t5_data
@@ -253,6 +267,7 @@ def testV1_1SmallBf16CompareTorchEagerAgainstHuggingFace(self):
             "google/t5-v1_1-small",
             reference_dtype=torch.bfloat16,
             target_dtype=torch.bfloat16,
+            atol=1e-1,
         )
 
     @with_t5_data
@@ -261,23 +276,16 @@ def testV1_1XxlF32CompareTorchEagerAgainstHuggingFace(self):
             "google/t5-v1_1-xxl",
             reference_dtype=torch.float32,
             target_dtype=torch.float32,
+            atol=1e-5,
         )
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason=(
-            "The accuracy is bad, but we get the same result as the Flux pipeline. "
-            "This need further investigation how Flux works at all like that."
-        ),
-    )
     @with_t5_data
     def testV1_1XxlBf16CompareTorchEagerAgainstHuggingFaceF32(self):
         self.runTestV1_1CompareTorchEagerAgainstHuggingFace(
             "google/t5-v1_1-xxl",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            atol=5e-2,
         )
 
@@ -293,8 +301,9 @@ def runTestV1_1CompareIreeAgainstTorchEager(
         huggingface_repo_id: str,
         reference_dtype: torch.dtype,
         target_dtype: torch.dtype,
-        atol: Optional[float] = None,
-        rtol: Optional[float] = None,
+        atol: float,
+        max_outliers_fraction: Optional[float] = None,
+        inlier_atol: Optional[float] = None,
     ):
         get_dataset(
             huggingface_repo_id,
@@ -386,34 +395,35 @@ def runTestV1_1CompareIreeAgainstTorchEager(
         ]
 
         logger.info("Comparing outputs...")
-        torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol)
+        reference_result_last_hidden_state = reference_result[0]
+        iree_result_last_hidden_state = iree_result[0]
+        assert_text_encoder_state_close(
+            iree_result_last_hidden_state,
+            reference_result_last_hidden_state,
+            atol=atol,
+            max_outliers_fraction=max_outliers_fraction,
+            inlier_atol=inlier_atol,
+        )
 
     @with_t5_data
     def testV1_1CompareSmallIreeF32AgainstTorchEagerF32(self):
         self.runTestV1_1CompareIreeAgainstTorchEager(
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.float32,
-            atol=1e-4,
-            rtol=2.0e-3,
+            atol=1e-5,
         )
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason=(
-            "The accuracy is bad, "
-            "but but it is no worse than the accuracy for of eager bfloat16. "
-            "This need further investigation how Flux works at all like that."
-        ),
-    )
     @with_t5_data
     def testV1_1CompareSmallIreeBf16AgainstTorchEagerF32(self):
         self.runTestV1_1CompareIreeAgainstTorchEager(
             "google/t5-v1_1-small",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            # The observed error is 0.12.
+            atol=0.2,
+            max_outliers_fraction=0.03,
+            inlier_atol=0.01,
         )
 
     @with_t5_data
@@ -422,26 +432,29 @@ def testV1_1CompareXxlIreeF32AgainstTorchEagerF32(self):
             "google/t5-v1_1-xxl",
             reference_dtype=torch.float32,
             target_dtype=torch.float32,
-            atol=1e-4,
-            rtol=2.0e-3,
+            atol=1e-5,
         )
 
-    @pytest.mark.xfail(
-        raises=AssertionError,
-        reason=(
-            "The accuracy is bad, "
-            "but but it is no worse than the accuracy for of eager bfloat16. "
-            "This need further investigation how Flux works at all like that."
-        ),
-    )
     @with_t5_data
     def testV1_1CompareXxlIreeBf16AgainstTorchEagerF32(self):
+        """The observed absolute numerical error is 0.21.
+        Per-token cosine similarity metrics are
+        mean = 0.997
+        std dev = 0.018
+        min = 0.789
+
+        The error seems high, as it corresponds to a 38° angular difference.
+        For comparison, the bf16 Hugging Face small model exhibits a worst-token
+        error of 0.05. Although the error here is worse, it may be reasonable, as it
+        comes from a single token outlier. The majority of tokens have an error of
+        less than 0.01.
+        """
         self.runTestV1_1CompareIreeAgainstTorchEager(
             "google/t5-v1_1-xxl",
             reference_dtype=torch.float32,
             target_dtype=torch.bfloat16,
-            atol=1e-2,
-            rtol=1.6e-2,
+            atol=2.5e-1,
+            max_outliers_fraction=0.03,
+            inlier_atol=0.01,
         )
 
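As a sanity check on the 38° figure in the docstring above: a worst-token cosine similarity of 0.789 corresponds to an angle of arccos(0.789) between the actual and expected embedding vectors.

    import math

    # Worst per-token cosine similarity reported for the bf16 XXL IREE test.
    worst_cosine_similarity = 0.789
    angle = math.degrees(math.acos(worst_cosine_similarity))
    print(f"{angle:.1f}")  # ~37.9, matching the roughly 38 degrees cited above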