Skip to content

Commit 43ffb5e

Browse files
authored
Initial migration of iree regression tests (#77)
Migrating [threshold tests](https://github.com/iree-org/iree/tree/main/experimental/regression_suite/shark-test-suite-models) and [benchmark tests](https://github.com/iree-org/iree/tree/main/experimental/benchmarks/sdxl) to this repository and also making improvements Improvements: - Made a generalized python script for both threshold tests and benchmark tests, so for current and future models/submodels, you just need to add/update a JSON file instead of duplicating code - Resolving [issue 19984](iree-org/iree#19984) by specifying file path [here](https://github.com/geomin12/iree-test-suites/blob/da44d58a316018c2a87da182609ca4819e2cd50a/sharktank_models/test_suite/regression_tests/sdxl/punet_int8_fp8.json#L69) and correctly getting the right path [here](https://github.com/geomin12/iree-test-suites/blob/da44d58a316018c2a87da182609ca4819e2cd50a/sharktank_models/test_suite/regression_tests/test_model_threshold.py#L93) - Improving / creating local development script to run these tests Things left to do in future PR(s): - Creating github action script that runs everything (in a midnight run or PR checks) and runs specific selected models through an action workflow_dispatch - Integrating Hugging Face model retrieval, adding Google's benchmark script and onboarding flux, according to this [PR](nod-ai/shark-ai#870 (comment)) --------- Signed-off-by: Min <[email protected]>
1 parent f8f2771 commit 43ffb5e

33 files changed

+2792
-20
lines changed

.github/workflows/test_sharktank_models.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ jobs:
6565
--durations=0 \
6666
--log-cli-level=info \
6767
--html=${HTML_REPORT_PATH} \
68-
--self-contained-html
68+
--self-contained-html \
69+
--ignore=sharktank_models/test_suite
6970
7071
- name: Upload HTML report
7172
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0

sharktank_models/clip/test_clip.py

+32-19
Original file line numberDiff line numberDiff line change
@@ -17,47 +17,51 @@
1717

1818
THIS_DIR = pathlib.Path(__file__).parent
1919

20+
2021
def load_tensor_from_irpa(path: PathLike) -> np.ndarray:
2122
index = iree.runtime.ParameterIndex()
2223
index.load(str(path))
2324
index_entry: iree.runtime.ParameterIndexEntry = index.items()[0][1]
2425
return iree.runtime.parameter_index_entry_as_numpy_ndarray(index_entry)
2526

27+
2628
@pytest.fixture(
27-
params=[
28-
pytest.param("local-task", marks=pytest.mark.target_cpu),
29-
pytest.param("hip", marks=pytest.mark.target_hip),
30-
]
29+
params=[
30+
pytest.param("local-task", marks=pytest.mark.target_cpu),
31+
pytest.param("hip", marks=pytest.mark.target_hip),
32+
]
3133
)
3234
def device_id(request: pytest.FixtureRequest) -> str:
3335
return request.param
3436

3537

36-
@pytest.fixture(
37-
params=["bf16", "f32"]
38-
)
38+
@pytest.fixture(params=["bf16", "f32"])
3939
def model_variant(request: pytest.FixtureRequest) -> str:
4040
return request.param
4141

4242

4343
mlir_path = {
4444
"bf16": THIS_DIR / "assets/text_model/toy/bf16.mlir",
45-
"f32": THIS_DIR / "assets/text_model/toy/f32.mlir"
45+
"f32": THIS_DIR / "assets/text_model/toy/f32.mlir",
4646
}
4747

4848
parameters_path = {
4949
"bf16": THIS_DIR / "assets/text_model/toy/bf16_parameters.irpa",
50-
"f32": THIS_DIR / "assets/text_model/toy/f32_parameters.irpa"
50+
"f32": THIS_DIR / "assets/text_model/toy/f32_parameters.irpa",
5151
}
5252

5353
function_arg0_path = THIS_DIR / "assets/text_model/toy/forward_bs4_arg0_input_ids.irpa"
54-
function_expected_result0 = THIS_DIR / "assets/text_model/toy/forward_bs4_expected_result0_last_hidden_state_f32.irpa"
54+
function_expected_result0 = (
55+
THIS_DIR
56+
/ "assets/text_model/toy/forward_bs4_expected_result0_last_hidden_state_f32.irpa"
57+
)
5558

5659
absolute_tolerance = {
5760
"bf16": 1e-3,
58-
"f32" : 1e-5,
61+
"f32": 1e-5,
5962
}
6063

64+
6165
def compiler_args(device_id: str) -> list[str]:
6266
if device_id == "local-task":
6367
return ["--iree-hal-target-device=llvm-cpu", "--iree-llvmcpu-target-cpu=host"]
@@ -70,16 +74,21 @@ def compiler_args(device_id: str) -> list[str]:
7074

7175
raise KeyError(f"Compiler args for {device_id} not found")
7276

73-
def compile_and_run(mlir_path: str, compiler_args: list[str], function: str, args: list[np.ndarray]) -> list[np.ndarray]:
77+
78+
def compile_and_run(
79+
mlir_path: str, compiler_args: list[str], function: str, args: list[np.ndarray]
80+
) -> list[np.ndarray]:
7481
iree.compiler.compile_file(
7582
mlir_path,
7683
extra_args=compiler_args,
7784
)
7885

86+
7987
@pytest.fixture(scope="session")
8088
def iree_module(model_variant, device_id) -> iree.runtime.VmModule:
8189
compiler_arguments = compiler_args(device_id)
8290

91+
8392
def device_array_to_host(device_array: iree.runtime.DeviceArray) -> np.ndarray:
8493
def reinterpret_hal_buffer_view_element_type(
8594
buffer_view: iree.runtime.HalBufferView,
@@ -157,11 +166,12 @@ def assert_text_encoder_state_close(
157166
rtol=0,
158167
)
159168

169+
160170
def test_results_close(model_variant, device_id):
161171
module_buffer = iree.compiler.compile_file(
162-
str(mlir_path[model_variant]),
163-
extra_args=compiler_args(device_id),
164-
)
172+
str(mlir_path[model_variant]),
173+
extra_args=compiler_args(device_id),
174+
)
165175

166176
vm_instance = iree.runtime.VmInstance()
167177
paramIndex = iree.runtime.ParameterIndex()
@@ -173,13 +183,16 @@ def test_results_close(model_variant, device_id):
173183
device = iree.runtime.get_device(device_id)
174184
hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=[device])
175185
vm_module = iree.runtime.VmModule.from_buffer(vm_instance, module_buffer)
176-
config=iree.runtime.Config(device=device)
177-
bound_modules = iree.runtime.load_vm_modules(hal_module, parameters_module, vm_module,
178-
config=config)
186+
config = iree.runtime.Config(device=device)
187+
bound_modules = iree.runtime.load_vm_modules(
188+
hal_module, parameters_module, vm_module, config=config
189+
)
179190
module = bound_modules[-1]
180191
result = module.forward_bs4(load_tensor_from_irpa(function_arg0_path))[0]
181192

182193
expected_result = load_tensor_from_irpa(function_expected_result0)
183194
result = device_array_to_host(result).astype(dtype=expected_result.dtype)
184195

185-
assert_text_encoder_state_close(result, expected_result, absolute_tolerance[model_variant])
196+
assert_text_encoder_state_close(
197+
result, expected_result, absolute_tolerance[model_variant]
198+
)

sharktank_models/requirements.txt

+5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
# Baseline requirements for running the test suite.
22
# * See requirements-iree.txt for using IREE packages.
33

4+
azure-storage-blob
45
ml_dtypes
56
numpy
67
pytest
8+
pytest-check
9+
pytest-dependency
710
pytest-html
811
pytest-reportlog
12+
pytest-retry
913
pytest-timeout
1014
pytest-xdist
15+
tabulate

sharktank_models/test_suite/README.md

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Regression Test Suite
2+
3+
details to come!
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
## Benchmark tests
2+
3+
### Adding your own model
4+
5+
- To add your own model, create a directory under `benchmarks` and add JSON files that correspond to the submodels and chip. Please follow the [JSON file schema in this README file](#required-and-optional-fields-for-the-json-model-file)
6+
7+
### How to run
8+
9+
```
10+
python sharktank_models/test_suite/benchmarks/run_benchmarks.py --model=sdxl --filename=*
11+
12+
python sharktank_models/test_suite/benchmarks/run_benchmarks.py --model=sdxl --filename=clip_rocm
13+
```
14+
15+
Argument options for the script
16+
17+
| Argument Name | Default value | Description |
18+
| ------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------ |
19+
| --model | sdxl | Runs benchmark tests for a specific model |
20+
| --filename | \* | If specified, the benchmark tests will run for a specific filename (ex: `--filename clip`). If not specified, it will run tests on all filenames |
21+
| --sku | mi300 | The benchmark tests will run on this sku and retrieve golden values from the specified sku |
22+
| --rocm-chip | gfx942 | The benchmark tests will run on this ROCM chip |
23+
24+
### Required and optional fields for the JSON model file
25+
26+
| Field Name | Required | Type | Description |
27+
| -------------------------------- | -------- | ------- | ---------------------------------------------------------------------------------------------------------------------------- |
28+
| inputs | required | array | An array of input strings for the benchmark module (ex: `["1xi64, 1xf16]`) |
29+
| compilation_required | optional | boolean | If true, this will let the benchmark test know that it needs to compile a file |
30+
| compiled_file_name | optional | string | When the compilation occurs, this will be the file name |
31+
| compile_flags | optional | array | An array of compiler flag options |
32+
| mlir_file_path | optional | string | Path to where the mlir file to compile is |
33+
| modules | optional | array | Specific to e2e, add modules here to include in the benchmarking test |
34+
| function_run | required | string | The function that the `iree-benchmark-module` will run adnd benchmark |
35+
| benchmark_repetitions | required | float | The number of times the benchmark tests will repeat |
36+
| benchmark_min_warmup_time | required | float | The minimum warm up time for the benchmark test |
37+
| device | required | string | The device that the benchmark tests are running |
38+
| golden_time_tolerance_multiplier | optional | object | An object of tolerance multipliers, where the key is the sku and the value is the multiplier, (ex: `{"mi250": 1.3}`) |
39+
| golden_time_ms | optional | object | An object of golden times, where the key is the sku and the value is the golden time in ms, (ex: `{"mi250": 100}`) |
40+
| golden_dispatch | optional | object | An object of golden dispatches, where the key is the sku and the value is the golden dispatch count, (ex: `{"mi250": 1602}`) |
41+
| golden_size | optional | object | An object of golden sizes, where the key is the sku and the value is the golden size in bytes, (ex: `{"mi250": 2000000}`) |
42+
| specific_chip_to_ignore | optional | array | An array of chip values, where the benchmark tests will ignore the chips specified |
43+
| real_weights_file_name | optional | string | If real weights is a different file name, specify it here in order to get the correct real weights file |
44+
45+
Please feel free to look at any JSON examples under a model directory (ex: sdxl)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module @sdxl_compiled_pipeline {
2+
func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor<i64>) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"}
3+
func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf16>, %arg1: tensor<2x64x2048xf16>, %arg2: tensor<2x1280xf16>, %arg3: tensor<2x6xf16>, %arg4: tensor<1xf16>, %arg5: tensor<1xi64>) -> tensor<1x4x128x128xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"}
4+
func.func private @compiled_clip.encode_prompts(%arg0: tensor<1x64xi64>, %arg1: tensor<1x64xi64>, %arg2: tensor<1x64xi64>, %arg3: tensor<1x64xi64>) -> (tensor<2x64x2048xf16>, tensor<2x1280xf16>) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"}
5+
func.func private @compiled_vae.main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"}
6+
7+
func.func @tokens_to_image(%sample: tensor<1x4x128x128xf16>, %guidance_scale: tensor<1xf16>, %t_ids_1: tensor<1x64xi64>, %t_ids_2: tensor<1x64xi64>, %u_ids_1: tensor<1x64xi64>, %u_ids_2: tensor<1x64xi64>) -> tensor<1x3x1024x1024xf16> {
8+
%p_embeds, %t_embeds = func.call @compiled_clip.encode_prompts(%t_ids_1, %t_ids_2, %u_ids_1, %u_ids_2) : (tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>) -> (tensor<2x64x2048xf16>, tensor<2x1280xf16>)
9+
%noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor<i64>)
10+
%c0 = arith.constant 0 : index
11+
%c1 = arith.constant 1 : index
12+
%steps_int = tensor.extract %steps[] : tensor<i64>
13+
%n_steps = arith.index_cast %steps_int: i64 to index
14+
%res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<1x4x128x128xf16>) {
15+
%step_64 = arith.index_cast %arg0 : index to i64
16+
%this_step = tensor.from_elements %step_64 : tensor<1xi64>
17+
%inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<1x4x128x128xf16>, tensor<2x64x2048xf16>, tensor<2x1280xf16>, tensor<2x6xf16>, tensor<1xf16>, tensor<1xi64>) -> tensor<1x4x128x128xf16>
18+
scf.yield %inner : tensor<1x4x128x128xf16>
19+
}
20+
%image = func.call @compiled_vae.main(%res): (tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16>
21+
return %image : tensor<1x3x1024x1024xf16>
22+
}
23+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2025 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import subprocess
8+
import os
9+
from pathlib import Path
10+
import argparse
11+
import sys
12+
13+
def main():
14+
parser = argparse.ArgumentParser()
15+
parser.add_argument("--model", type=str, default="sdxl")
16+
parser.add_argument("--filename", type=str, default="*")
17+
parser.add_argument("--sku", type=str, default="mi300")
18+
parser.add_argument("--rocm-chip", type=str, default="gfx942")
19+
args = parser.parse_args()
20+
model = args.model
21+
filename = args.filename
22+
sku = args.sku
23+
rocm_chip = args.rocm_chip
24+
25+
os.environ["BENCHMARK_MODEL"] = model
26+
os.environ["BENCHMARK_FILE_NAME"] = filename
27+
os.environ["SKU"] = sku
28+
os.environ["ROCM_CHIP"] = rocm_chip
29+
30+
THIS_DIR = Path(__file__).parent
31+
32+
command = [
33+
"pytest",
34+
THIS_DIR / "test_model_benchmark.py",
35+
"--log-cli-level=info",
36+
"--timeout=600",
37+
"--retries=7",
38+
]
39+
subprocess.run(command)
40+
return 0
41+
42+
43+
if __name__ == "__main__":
44+
sys.exit(main())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"inputs": [
3+
"1x64xi64",
4+
"1x64xi64",
5+
"1x64xi64",
6+
"1x64xi64"
7+
],
8+
"function_run": "encode_prompts",
9+
"benchmark_flags": [
10+
"--benchmark_repetitions=10",
11+
"--benchmark_min_warmup_time=3.0",
12+
"--device_allocator=caching"
13+
],
14+
"device": "hip",
15+
"golden_time_tolerance_multiplier": {
16+
"mi250": 1.3,
17+
"mi300": 1.1,
18+
"mi308": 1.1
19+
},
20+
"golden_time_ms": {
21+
"mi250": 14.5,
22+
"mi300": 15.0,
23+
"mi308": 15.0
24+
},
25+
"golden_dispatch": {
26+
"mi250": 1139,
27+
"mi300": 1139,
28+
"mi308": 1139
29+
},
30+
"golden_size": {
31+
"mi250": 860000,
32+
"mi300": 860000,
33+
"mi308": 860000
34+
}
35+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
{
2+
"inputs": [
3+
"1x4x128x128xf16",
4+
"1xf16",
5+
"1x64xi64",
6+
"1x64xi64",
7+
"1x64xi64",
8+
"1x64xi64"
9+
],
10+
"compilation_required": true,
11+
"compiled_file_name": "sdxl_full_pipeline_fp16_rocm",
12+
"compile_flags": [
13+
"--iree-global-opt-propagate-transposes=true",
14+
"--iree-codegen-llvmgpu-use-vector-distribution",
15+
"--iree-codegen-gpu-native-math-precision=true",
16+
"--iree-hip-waves-per-eu=2",
17+
"--iree-opt-outer-dim-concat=true",
18+
"--iree-llvmgpu-enable-prefetch",
19+
"--iree-hal-target-backends=rocm"
20+
],
21+
"mlir_file_path": "external_test_files/sdxl_pipeline_bench_f16.mlir",
22+
"modules": [
23+
"sdxl_clip",
24+
"sdxl_unet_fp16",
25+
"sdxl_vae"
26+
],
27+
"function_run": "tokens_to_image",
28+
"benchmark_flags": [
29+
"--benchmark_repetitions=10",
30+
"--benchmark_min_warmup_time=3.0",
31+
"--device_allocator=caching"
32+
],
33+
"device": "hip",
34+
"golden_time_tolerance_multiplier": {
35+
"mi250": 1.3,
36+
"mi300": 1.1,
37+
"mi308": 1.1
38+
},
39+
"golden_time_ms": {
40+
"mi250": 1100,
41+
"mi300": 325,
42+
"mi308": 800
43+
}
44+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"inputs": [
3+
"1x4x128x128xf16",
4+
"1xf16",
5+
"2x64x2048xf16",
6+
"2x1280xf16",
7+
"2x6xf16",
8+
"1xf16"
9+
],
10+
"function_run": "main",
11+
"benchmark_flags": [
12+
"--benchmark_repetitions=10",
13+
"--benchmark_min_warmup_time=3.0",
14+
"--device_allocator=caching"
15+
],
16+
"device": "hip",
17+
"golden_time_tolerance_multiplier": {
18+
"mi300": 1.1,
19+
"mi308": 1.1
20+
},
21+
"golden_time_ms": {
22+
"mi300": 50,
23+
"mi308": 140
24+
},
25+
"golden_dispatch": {
26+
"mi300": 1424,
27+
"mi308": 1424
28+
},
29+
"golden_size": {
30+
"mi300": 2560000,
31+
"mi308": 2560000
32+
},
33+
"specific_chip_to_ignore": ["gfx90a"],
34+
"real_weights_file_name": "punet_weights.irpa"
35+
}

0 commit comments

Comments
 (0)