
Commit 913e9c0
Committed Jan 27, 2025

Add Flux transformer benchmarking

Also adds some general functionality for checking benchmark results against baseline results. This requires the Google Benchmark compare.py tool, which is not part of the pip package; that is why the google/benchmark repo is added as a git submodule. The tool does a statistical comparison between benchmark runs with a proper p-value calculation, so there is no need to roll our own. Also adds a new nightly CI job meant to hold nightly tests and benchmarks that do not have a dedicated workflow of their own, the way Llama currently does.
1 parent: 3cecd77
16 files changed: +1032 −24 lines
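The baseline check described in the commit message relies on compare.py from the google/benchmark repository, vendored below as the third_party/benchmark submodule (it is not shipped in the google-benchmark pip package added to the test requirements). Below is a minimal sketch of how a nightly result could be checked against a stored baseline; the wrapper function and file paths are hypothetical, and only the `compare.py benchmarks <baseline> <contender>` invocation comes from upstream Google Benchmark.

```python
# Hypothetical helper (not part of this commit): run Google Benchmark's
# compare.py, vendored via the third_party/benchmark submodule, to compare a
# nightly benchmark JSON against a stored baseline. compare.py prints a
# per-benchmark statistical comparison (with p-values) between the two runs.
import subprocess
import sys
from pathlib import Path

# Assumed repository layout: this script runs from the repo root.
COMPARE_PY = Path("third_party/benchmark/tools/compare.py")


def compare_to_baseline(baseline_json: Path, contender_json: Path) -> str:
    result = subprocess.run(
        [sys.executable, str(COMPARE_PY), "benchmarks",
         str(baseline_json), str(contender_json)],
        check=True,        # only fails if compare.py itself errors out
        capture_output=True,
        text=True,
    )
    return result.stdout   # the printed comparison table


if __name__ == "__main__":
    # File names are illustrative only.
    print(compare_to_baseline(
        Path("baseline/flux_transformer.json"),
        Path("out/benchmark/flux_transformer.json"),
    ))
```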
 
New file, +87 lines (GitHub Actions workflow):

# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: Sharktank Nightly Tests

on:
  workflow_dispatch:
  schedule:
    # Weekdays at 10:00 AM UTC = 02:00 AM PST / 03:00 AM PDT
    - cron: "0 10 * * 1-5"

concurrency:
  # A PR number if a pull request and otherwise the commit hash. This cancels
  # queued and in-progress runs for the same PR (presubmit) or commit
  # (postsubmit). The workflow name is prepended to avoid conflicts between
  # different workflows.
  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
  cancel-in-progress: true

jobs:
  nightly-mi300x:
    if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
    name: "Nightly tests and benchmarks"
    strategy:
      matrix:
        version: [3.11]
      fail-fast: false
    runs-on: llama-mi300x-3
    defaults:
      run:
        shell: bash
    env:
      VENV_DIR: ${{ github.workspace }}/.venv
      HF_HOME: "/data/huggingface"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Get Current Date
        id: date
        run: echo "::set-output name=date::$(date +'%Y-%m-%d')"

      - name: "Setting up Python"
        id: setup_python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{matrix.version}}
      - name: Create Python venv
        run: python -m venv ${VENV_DIR}

      - name: Install pip deps
        run: |
          source ${VENV_DIR}/bin/activate
          python -m pip install --no-compile --upgrade pip

          # Note: We install in three steps in order to satisfy requirements
          # from non default locations first.
          pip install --no-compile -r pytorch-cpu-requirements.txt
          pip install -r requirements-iree-unpinned.txt
          pip install --no-compile \
            -r sharktank/requirements-tests.txt \
            -e sharktank/

          pip freeze

      - name: Run benchmarks
        run: |
          source ${VENV_DIR}/bin/activate
          pytest \
            --verbose \
            --capture=no \
            --iree-hip-target=gfx942 \
            --iree-device=hip://6 \
            --with-flux-data \
            -m="benchmark and expensive" \
            --html=out/benchmark/index.html \
            sharktank/tests

      - name: Deploy to GitHub Pages
        uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
        with:
          github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
          publish_dir: ./out/benchmark
          destination_dir: ./benchmark
          keep_files: true
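The benchmark step above enables FLUX tests with `--with-flux-data` and selects tests via `-m="benchmark and expensive"`. The conftest change that registers this flag is not part of this diff; the fragment below is a hypothetical sketch of how such a pytest option is typically declared, with everything other than the `--with-flux-data` flag itself being an assumption.

```python
# Hypothetical conftest.py fragment (not included in this diff): registers the
# --with-flux-data flag that the nightly workflow passes and that the skipif
# marker added to testing.py (further down) reads via
# config.getoption("with_flux_data").
def pytest_addoption(parser):
    parser.addoption(
        "--with-flux-data",
        action="store_true",
        default=False,
        help="Run tests that require FLUX model data from Hugging Face.",
    )
```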

‎.github/workflows/ci-sharktank.yml

+5 −1 lines

@@ -85,7 +85,11 @@ jobs:
       - name: Run sharktank tests
         if: ${{ !cancelled() }}
         run: |
-          pytest -n 4 sharktank/ --durations=10
+          pytest \
+            -n 4 \
+            --durations=10 \
+            -m "not expensive" \
+            sharktank/
 
 
   test_with_data:

‎.gitmodules

+3 lines (new file)

[submodule "third_party/benchmark"]
    path = third_party/benchmark
    url = https://github.com/google/benchmark

‎sharktank/pyproject.toml

+1 line

@@ -44,6 +44,7 @@ addopts = [
     "-m=unit",
 ]
 markers = [
+    "benchmark: model benchmarks",
     "expensive: tests that are very expensive",
     "export: tests that require export from torch",
     "golden: tests that compare to some golden values",

‎sharktank/requirements-tests.txt

+1 line

@@ -6,6 +6,7 @@ accelerate
 
 datasets==3.0.0
 diffusers
+google-benchmark
 parameterized
 protobuf
 pytest==8.0.0
New file, +92 lines (Flux transformer IREE benchmark module):

# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from pathlib import Path
import iree.compiler
import iree.runtime
import os
from iree.turbine.support.tools import iree_tool_prepare_input_args

from .export import (
    export_flux_transformer_from_hugging_face,
    flux_transformer_default_batch_sizes,
    iree_compile_flags,
)
from ...types import Dataset
from .flux import FluxModelV1, FluxParams
from ...utils.export_artifacts import ExportArtifacts
from ...utils.iree import flatten_for_iree_signature
from ...utils.benchmark import iree_benchmark_module


def iree_benchmark_flux_dev_transformer(
    artifacts_dir: Path,
    iree_device: str,
    json_result_output_path: Path,
    caching: bool = False,
) -> str:
    mlir_path = artifacts_dir / "model.mlir"
    parameters_path = artifacts_dir / "parameters.irpa"
    if (
        not caching
        or not os.path.exists(mlir_path)
        or not os.path.exists(parameters_path)
    ):
        export_flux_transformer_from_hugging_face(
            "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer",
            mlir_output_path=mlir_path,
            parameters_output_path=parameters_path,
        )
    return iree_benchmark_flux_transformer(
        mlir_path=mlir_path,
        parameters_path=parameters_path,
        artifacts_dir=artifacts_dir,
        iree_device=iree_device,
        json_result_output_path=json_result_output_path,
        caching=caching,
    )


def iree_benchmark_flux_transformer(
    artifacts_dir: Path,
    mlir_path: Path,
    parameters_path: Path,
    iree_device: str,
    json_result_output_path: Path,
    caching: bool = False,
) -> str:
    dataset = Dataset.load(parameters_path)
    model = FluxModelV1(
        theta=dataset.root_theta,
        params=FluxParams.from_hugging_face_properties(dataset.properties),
    )
    input_args = flatten_for_iree_signature(
        model.sample_inputs(batch_size=flux_transformer_default_batch_sizes[0])
    )
    cli_input_args = iree_tool_prepare_input_args(
        input_args, file_path_prefix=f"{artifacts_dir / 'arg'}"
    )
    cli_input_args = [f"--input={v}" for v in cli_input_args]

    iree_module_path = artifacts_dir / "model.vmfb"
    if not caching or not os.path.exists(iree_module_path):
        iree.compiler.compile_file(
            mlir_path,
            output_file=iree_module_path,
            extra_args=iree_compile_flags,
        )

    iree_benchmark_args = [
        f"--device={iree_device}",
        f"--module={iree_module_path}",
        f"--parameters=model={parameters_path}",
        f"--function=forward_bs{flux_transformer_default_batch_sizes[0]}",
        "--benchmark_repetitions=30",
        "--benchmark_min_warmup_time=1.0",
        "--benchmark_out_format=json",
        f"--benchmark_out={json_result_output_path}",
    ] + cli_input_args
    return iree_benchmark_module(iree_benchmark_args)
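A sketch of how the entry point above might be driven, for example from a test. The new module's import path is not shown in this view, so the benchmark function is passed in rather than imported; the artifact directory, result file name, and helper name are assumptions, and the device string matches the one used in the nightly workflow.

```python
# Illustrative driver (not part of this commit) for the benchmark entry point
# above; the function is passed in because the new module's path is not shown
# in this diff view.
from pathlib import Path
from typing import Callable


def run_flux_dev_benchmark(
    benchmark_fn: Callable[..., str],  # e.g. iree_benchmark_flux_dev_transformer
    artifacts_dir: Path,
) -> str:
    result_json = artifacts_dir / "flux_transformer_benchmark.json"
    # Returns the captured iree-benchmark-module output; result_json can later
    # be fed to the compare.py baseline check sketched near the commit message.
    return benchmark_fn(
        artifacts_dir=artifacts_dir,
        iree_device="hip://6",  # matches the device targeted by the nightly workflow
        json_result_output_path=result_json,
        caching=True,  # reuse exported MLIR/IRPA and the compiled VMFB when present
    )
```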

‎sharktank/sharktank/models/flux/export.py

+20 lines

@@ -17,6 +17,26 @@
 
 flux_transformer_default_batch_sizes = [1]
 
+iree_compile_flags = [
+    "--iree-hal-target-device=hip",
+    "--iree-hip-target=gfx942",
+    "--iree-opt-const-eval=false",
+    "--iree-opt-strip-assertions=true",
+    "--iree-global-opt-propagate-transposes=true",
+    "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true",
+    "--iree-dispatch-creation-enable-aggressive-fusion=true",
+    "--iree-opt-aggressively-propagate-transposes=true",
+    "--iree-opt-outer-dim-concat=true",
+    "--iree-vm-target-truncate-unsupported-floats",
+    "--iree-llvmgpu-enable-prefetch=true",
+    "--iree-opt-data-tiling=false",
+    "--iree-codegen-gpu-native-math-precision=true",
+    "--iree-codegen-llvmgpu-use-vector-distribution",
+    "--iree-hip-waves-per-eu=2",
+    "--iree-execution-model=async-external",
+    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)",
+]
+
 
 def export_flux_transformer_model_mlir(
     model: FluxModelV1,
‎sharktank/sharktank/models/flux/testing.py

+3 lines

@@ -7,6 +7,7 @@
 import torch
 from os import PathLike
 from collections import OrderedDict
+import pytest
 
 from .flux import FluxParams, FluxModelV1
 from .export import export_flux_transformer, flux_transformer_default_batch_sizes
@@ -17,6 +18,8 @@
     make_mmdit_single_block_random_theta,
 )
 
+with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')")
+
 
 def convert_flux_transformer_input_for_hugging_face_model(
     img: torch.Tensor,
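The `with_flux_data` decorator added above skips a test unless pytest was invoked with `--with-flux-data`, as the nightly workflow does. A brief illustrative use follows; the test itself is not part of this commit.

```python
# Illustrative only: a test gated on the --with-flux-data pytest option via the
# with_flux_data marker defined above.
@with_flux_data
def test_flux_transformer_against_hugging_face_reference():
    ...  # would exercise the exported FLUX transformer against the HF model
```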
