|
| 1 | +# Copyright 2024 Advanced Micro Devices, Inc. |
| 2 | +# |
| 3 | +# Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +# See https://llvm.org/LICENSE.txt for license information. |
| 5 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +from pathlib import Path |
| 8 | +import iree.compiler |
| 9 | +import iree.runtime |
| 10 | +import os |
| 11 | +from iree.turbine.support.tools import iree_tool_prepare_input_args |
| 12 | + |
| 13 | +from .export import ( |
| 14 | + export_flux_transformer_from_hugging_face, |
| 15 | + flux_transformer_default_batch_sizes, |
| 16 | + iree_compile_flags, |
| 17 | +) |
| 18 | +from ...types import Dataset |
| 19 | +from .flux import FluxModelV1, FluxParams |
| 20 | +from ...utils.export_artifacts import ExportArtifacts |
| 21 | +from ...utils.iree import flatten_for_iree_signature |
| 22 | +from ...utils.benchmark import iree_benchmark_module |
| 23 | + |
| 24 | + |
def iree_benchmark_flux_dev_transformer(
    artifacts_dir: Path,
    iree_device: str,
    json_result_output_path: Path,
    caching: bool = False,
) -> str:
    """Export the FLUX.1-dev transformer if needed, then benchmark it.

    The MLIR and parameter files are looked up under ``artifacts_dir`` as
    ``model.mlir`` and ``parameters.irpa``. They are (re-)exported unless
    ``caching`` is set and both files are already present.

    Args:
        artifacts_dir: Directory holding the exported/compiled artifacts.
        iree_device: Device string forwarded to the benchmark.
        json_result_output_path: Where the benchmark JSON output is written.
        caching: When true, reuse artifacts that already exist on disk.

    Returns:
        Whatever ``iree_benchmark_flux_transformer`` returns for these
        artifacts.
    """
    exported_mlir = artifacts_dir / "model.mlir"
    exported_parameters = artifacts_dir / "parameters.irpa"

    # Reuse is only possible when caching is requested AND both exported
    # files are already on disk; otherwise export from scratch.
    can_reuse_export = (
        caching
        and os.path.exists(exported_mlir)
        and os.path.exists(exported_parameters)
    )
    if not can_reuse_export:
        export_flux_transformer_from_hugging_face(
            "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer",
            mlir_output_path=exported_mlir,
            parameters_output_path=exported_parameters,
        )

    return iree_benchmark_flux_transformer(
        artifacts_dir=artifacts_dir,
        mlir_path=exported_mlir,
        parameters_path=exported_parameters,
        iree_device=iree_device,
        json_result_output_path=json_result_output_path,
        caching=caching,
    )
| 51 | + |
| 52 | + |
def iree_benchmark_flux_transformer(
    artifacts_dir: Path,
    mlir_path: Path,
    parameters_path: Path,
    iree_device: str,
    json_result_output_path: Path,
    caching: bool = False,
) -> str:
    """Compile an exported Flux transformer with IREE and benchmark it.

    Sample inputs are generated from a model built off the parameter
    dataset, using the first entry of
    ``flux_transformer_default_batch_sizes`` as the batch size. The same
    batch size selects the ``forward_bs<N>`` function to benchmark.

    Args:
        artifacts_dir: Directory for the compiled module (``model.vmfb``)
            and for the prepared input arguments (prefix ``arg``).
        mlir_path: Path to the exported MLIR to compile.
        parameters_path: Path to the parameters file; loaded as a
            ``Dataset`` and also passed to the benchmark as ``model`` scope.
        iree_device: Device string forwarded via ``--device``.
        json_result_output_path: Target of ``--benchmark_out``.
        caching: When true, skip compilation if ``model.vmfb`` already
            exists in ``artifacts_dir``.

    Returns:
        The result of ``iree_benchmark_module`` for the assembled arguments.
    """
    # Build the model only to obtain representative sample inputs.
    params_dataset = Dataset.load(parameters_path)
    theta = params_dataset.root_theta
    flux_params = FluxParams.from_hugging_face_properties(
        params_dataset.properties
    )
    flux_model = FluxModelV1(theta=theta, params=flux_params)

    batch_size = flux_transformer_default_batch_sizes[0]
    sample_inputs = flux_model.sample_inputs(batch_size=batch_size)
    flat_inputs = flatten_for_iree_signature(sample_inputs)

    # NOTE(review): presumably this materializes the inputs on disk under
    # the given prefix and returns CLI-ready values -- confirm in turbine.
    prepared_inputs = iree_tool_prepare_input_args(
        flat_inputs, file_path_prefix=str(artifacts_dir / "arg")
    )
    input_flags = [f"--input={arg}" for arg in prepared_inputs]

    compiled_module = artifacts_dir / "model.vmfb"
    # Recompile unless caching is on and a compiled module is already there.
    reuse_compiled_module = caching and os.path.exists(compiled_module)
    if not reuse_compiled_module:
        iree.compiler.compile_file(
            mlir_path,
            output_file=compiled_module,
            extra_args=iree_compile_flags,
        )

    benchmark_flags = [
        f"--device={iree_device}",
        f"--module={compiled_module}",
        f"--parameters=model={parameters_path}",
        f"--function=forward_bs{batch_size}",
        "--benchmark_repetitions=30",
        "--benchmark_min_warmup_time=1.0",
        "--benchmark_out_format=json",
        f"--benchmark_out={json_result_output_path}",
    ]
    return iree_benchmark_module([*benchmark_flags, *input_flags])
0 commit comments