Skip to content

Commit a40a486

Browse files
authored
Add more logging to sharktank data tests. (#917)
Logging more information as the tests run will help with the triage process for issues like #888.
Logs before: https://github.com/nod-ai/shark-ai/actions/runs/13150995160/job/36698277502#step:6:27
Logs after: https://github.com/nod-ai/shark-ai/actions/runs/13168068091/job/36752792808?pr=917#step:6:29
1 parent 3cccc20 commit a40a486

File tree

5 files changed

+32
-3
lines changed

5 files changed

+32
-3
lines changed

.github/workflows/ci-sharktank.yml

+2
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,8 @@ jobs:
137137
run: |
138138
source ${VENV_DIR}/bin/activate
139139
pytest \
140+
-v \
141+
--log-cli-level=info \
140142
--with-clip-data \
141143
--with-flux-data \
142144
--with-t5-data \

sharktank/sharktank/tools/import_hf_dataset.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ def import_hf_dataset(
6767
if output_irpa_file is None:
6868
return dataset
6969

70-
dataset.save(output_irpa_file, io_report_callback=logger.info)
70+
dataset.save(output_irpa_file, io_report_callback=logger.debug)
7171

7272

7373
def main(argv: list[str]):

sharktank/tests/models/clip/clip_test.py

+9
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from pathlib import Path
1212
from parameterized import parameterized
1313
from copy import copy
14+
import logging
1415
import pytest
1516
import torch
1617
from torch.utils._pytree import tree_map
@@ -72,6 +73,8 @@
7273

7374
with_clip_data = pytest.mark.skipif("not config.getoption('with_clip_data')")
7475

76+
logger = logging.getLogger(__name__)
77+
7578

7679
@pytest.mark.usefixtures("path_prefix")
7780
class ClipTextIreeTest(TempDirTestBase):
@@ -163,6 +166,7 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
163166
batch_size = input_ids.shape[0]
164167
mlir_path = f"{target_model_path_prefix}.mlir"
165168

169+
logger.info("Exporting clip text model to MLIR...")
166170
export_clip_text_model_iree_test_data(
167171
reference_model=reference_model,
168172
target_dtype=target_dtype,
@@ -172,12 +176,14 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
172176
)
173177

174178
iree_module_path = f"{target_model_path_prefix}.vmfb"
179+
logger.info("Compiling MLIR file...")
175180
iree.compiler.compile_file(
176181
mlir_path,
177182
output_file=iree_module_path,
178183
extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
179184
)
180185

186+
logger.info("Invoking reference torch function...")
181187
reference_result_dict = call_torch_module_function(
182188
module=reference_model,
183189
function_name="forward",
@@ -187,6 +193,7 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
187193
expected_outputs = flatten_for_iree_signature(reference_result_dict)
188194

189195
iree_devices = get_iree_devices(driver="hip", device_count=1)
196+
logger.info("Loading IREE module...")
190197
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
191198
module_path=iree_module_path,
192199
devices=iree_devices,
@@ -195,6 +202,7 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
195202
iree_args = prepare_iree_module_function_args(
196203
args=flatten_for_iree_signature(input_args), devices=iree_devices
197204
)
205+
logger.info("Invoking IREE function...")
198206
iree_result = iree_to_torch(
199207
*run_iree_module_function(
200208
module=iree_module,
@@ -213,6 +221,7 @@ def runTestCompareIreeAgainstTorchEagerWithInputTokens(
213221
actual_last_hidden_state = actual_outputs[0]
214222
expected_last_hidden_state = expected_outputs[0]
215223

224+
logger.info("Comparing outputs...")
216225
assert_text_encoder_state_close(
217226
actual_last_hidden_state, expected_last_hidden_state, atol
218227
)

sharktank/tests/models/flux/flux_test.py

+9-2
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
from sharktank.types import Dataset, Theta
4141

4242
logging.basicConfig(level=logging.DEBUG)
43+
logger = logging.getLogger(__name__)
4344
with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')")
4445

4546
iree_compile_flags = [
@@ -102,6 +103,7 @@ def runCompareIreeAgainstTorchEager(
102103
parameters_path = self._temp_dir / "parameters.irpa"
103104
batch_size = 1
104105
batch_sizes = [batch_size]
106+
logger.info("Exporting flux transformer to MLIR...")
105107
export_flux_transformer(
106108
target_torch_model,
107109
mlir_output_path=mlir_path,
@@ -110,9 +112,10 @@ def runCompareIreeAgainstTorchEager(
110112
)
111113

112114
iree_module_path = self._temp_dir / "model.vmfb"
115+
logger.info("Compiling MLIR file...")
113116
iree.compiler.compile_file(
114-
mlir_path,
115-
output_file=iree_module_path,
117+
str(mlir_path),
118+
output_file=str(iree_module_path),
116119
extra_args=iree_compile_flags,
117120
)
118121

@@ -136,6 +139,7 @@ def runCompareIreeAgainstTorchEager(
136139
for k, t in target_input_kwargs.items()
137140
)
138141

142+
logger.info("Invoking reference torch function...")
139143
reference_result_dict = call_torch_module_function(
140144
module=reference_model,
141145
function_name="forward",
@@ -145,6 +149,7 @@ def runCompareIreeAgainstTorchEager(
145149
expected_outputs = flatten_for_iree_signature(reference_result_dict)
146150

147151
iree_devices = get_iree_devices(driver="hip", device_count=1)
152+
logger.info("Loading IREE module...")
148153
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
149154
module_path=iree_module_path,
150155
devices=iree_devices,
@@ -155,6 +160,7 @@ def runCompareIreeAgainstTorchEager(
155160
devices=iree_devices,
156161
)
157162

163+
logger.info("Invoking IREE function...")
158164
iree_result = iree_to_torch(
159165
*run_iree_module_function(
160166
module=iree_module,
@@ -168,6 +174,7 @@ def runCompareIreeAgainstTorchEager(
168174
ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
169175
for i in range(len(expected_outputs))
170176
]
177+
logger.info("Comparing outputs...")
171178
torch.testing.assert_close(actual_outputs, expected_outputs, atol=atol, rtol=0)
172179

173180
def runTestCompareDevIreeAgainstHuggingFace(

sharktank/tests/models/t5/t5_test.py

+11
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from typing import Optional
1919
import os
2020
from collections import OrderedDict
21+
import logging
2122
import pytest
2223
import torch
2324
from torch.utils._pytree import tree_map, tree_unflatten, tree_flatten_with_path
@@ -61,6 +62,8 @@
6162

6263
with_t5_data = pytest.mark.skipif("not config.getoption('with_t5_data')")
6364

65+
logger = logging.getLogger(__name__)
66+
6467

6568
@pytest.mark.usefixtures("get_model_artifacts")
6669
class T5EncoderEagerTest(TestCase):
@@ -184,15 +187,18 @@ def runTestV1_1CompareTorchEagerAgainstHuggingFace(
184187
pad_to_multiple_of=config.context_length_padding_block_size,
185188
).input_ids
186189

190+
logger.info("Invoking Torch eager model...")
187191
model = T5Encoder(theta=dataset.root_theta, config=config)
188192
model.eval()
189193

194+
logger.info("Invoking reference HuggingFace model...")
190195
expected_outputs = reference_model(input_ids=input_ids)
191196
actual_outputs = model(input_ids=input_ids)
192197
actual_outputs = tree_map(
193198
lambda t: ops.to(t, dtype=reference_dtype), actual_outputs
194199
)
195200

201+
logger.info("Comparing outputs...")
196202
torch.testing.assert_close(
197203
actual_outputs, expected_outputs, atol=atol, rtol=rtol
198204
)
@@ -340,18 +346,21 @@ def runTestV1_1CompareIreeAgainstTorchEager(
340346

341347
mlir_path = f"{target_model_path_prefix}.mlir"
342348
if not self.caching or not os.path.exists(mlir_path):
349+
logger.info("Exporting T5 encoder model to MLIR...")
343350
export_encoder_mlir(
344351
parameters_path, batch_sizes=[batch_size], mlir_output_path=mlir_path
345352
)
346353
iree_module_path = f"{target_model_path_prefix}.vmfb"
347354
if not self.caching or not os.path.exists(iree_module_path):
355+
logger.info("Compiling MLIR file...")
348356
iree.compiler.compile_file(
349357
mlir_path,
350358
output_file=iree_module_path,
351359
extra_args=["--iree-hal-target-device=hip", "--iree-hip-target=gfx942"],
352360
)
353361

354362
iree_devices = get_iree_devices(driver="hip", device_count=1)
363+
logger.info("Loading IREE module...")
355364
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
356365
module_path=iree_module_path,
357366
devices=iree_devices,
@@ -360,6 +369,7 @@ def runTestV1_1CompareIreeAgainstTorchEager(
360369
iree_args = prepare_iree_module_function_args(
361370
args=flatten_for_iree_signature(input_args), devices=iree_devices
362371
)
372+
logger.info("Invoking IREE function...")
363373
iree_result = iree_to_torch(
364374
*run_iree_module_function(
365375
module=iree_module,
@@ -375,6 +385,7 @@ def runTestV1_1CompareIreeAgainstTorchEager(
375385
for i in range(len(reference_result))
376386
]
377387

388+
logger.info("Comparing outputs...")
378389
torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol)
379390

380391
@with_t5_data

0 commit comments

Comments
 (0)