18
18
from typing import Optional
19
19
import os
20
20
from collections import OrderedDict
21
+ import logging
21
22
import pytest
22
23
import torch
23
24
from torch .utils ._pytree import tree_map , tree_unflatten , tree_flatten_with_path
61
62
62
63
with_t5_data = pytest .mark .skipif ("not config.getoption('with_t5_data')" )
63
64
65
+ logger = logging .getLogger (__name__ )
66
+
64
67
65
68
@pytest .mark .usefixtures ("get_model_artifacts" )
66
69
class T5EncoderEagerTest (TestCase ):
@@ -184,15 +187,18 @@ def runTestV1_1CompareTorchEagerAgainstHuggingFace(
184
187
pad_to_multiple_of = config .context_length_padding_block_size ,
185
188
).input_ids
186
189
190
+ logger .info ("Invoking Torch eager model..." )
187
191
model = T5Encoder (theta = dataset .root_theta , config = config )
188
192
model .eval ()
189
193
194
+ logger .info ("Invoking reference HuggingFace model..." )
190
195
expected_outputs = reference_model (input_ids = input_ids )
191
196
actual_outputs = model (input_ids = input_ids )
192
197
actual_outputs = tree_map (
193
198
lambda t : ops .to (t , dtype = reference_dtype ), actual_outputs
194
199
)
195
200
201
+ logger .info ("Comparing outputs..." )
196
202
torch .testing .assert_close (
197
203
actual_outputs , expected_outputs , atol = atol , rtol = rtol
198
204
)
@@ -340,18 +346,21 @@ def runTestV1_1CompareIreeAgainstTorchEager(
340
346
341
347
mlir_path = f"{ target_model_path_prefix } .mlir"
342
348
if not self .caching or not os .path .exists (mlir_path ):
349
+ logger .info ("Exporting T5 encoder model to MLIR..." )
343
350
export_encoder_mlir (
344
351
parameters_path , batch_sizes = [batch_size ], mlir_output_path = mlir_path
345
352
)
346
353
iree_module_path = f"{ target_model_path_prefix } .vmfb"
347
354
if not self .caching or not os .path .exists (iree_module_path ):
355
+ logger .info ("Compiling MLIR file..." )
348
356
iree .compiler .compile_file (
349
357
mlir_path ,
350
358
output_file = iree_module_path ,
351
359
extra_args = ["--iree-hal-target-device=hip" , "--iree-hip-target=gfx942" ],
352
360
)
353
361
354
362
iree_devices = get_iree_devices (driver = "hip" , device_count = 1 )
363
+ logger .info ("Loading IREE module..." )
355
364
iree_module , iree_vm_context , iree_vm_instance = load_iree_module (
356
365
module_path = iree_module_path ,
357
366
devices = iree_devices ,
@@ -360,6 +369,7 @@ def runTestV1_1CompareIreeAgainstTorchEager(
360
369
iree_args = prepare_iree_module_function_args (
361
370
args = flatten_for_iree_signature (input_args ), devices = iree_devices
362
371
)
372
+ logger .info ("Invoking IREE function..." )
363
373
iree_result = iree_to_torch (
364
374
* run_iree_module_function (
365
375
module = iree_module ,
@@ -375,6 +385,7 @@ def runTestV1_1CompareIreeAgainstTorchEager(
375
385
for i in range (len (reference_result ))
376
386
]
377
387
388
+ logger .info ("Comparing outputs..." )
378
389
torch .testing .assert_close (reference_result , iree_result , atol = atol , rtol = rtol )
379
390
380
391
@with_t5_data
0 commit comments