Add more logging to sharktank data tests #1017

Draft · wants to merge 4 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/ci-llama-large-tests.yaml
@@ -68,7 +68,7 @@ jobs:
- name: Run llama tests
run: |
source ${VENV_DIR}/bin/activate
pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-nightly-llama-tests --iree-hip-target=gfx942 --iree-device=hip://0 --html=out/llm/llama/benchmark/index.html
pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v --log-cli-level=info -s --run-nightly-llama-tests --iree-hip-target=gfx942 --iree-device=hip://0 --html=out/llm/llama/benchmark/index.html

- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
2 changes: 1 addition & 1 deletion .github/workflows/ci-llama-quick-tests.yaml
@@ -67,7 +67,7 @@ jobs:
- name: Run llama 8b f16 decomposed test
run: |
source ${VENV_DIR}/bin/activate
pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --iree-device=hip://0 --run-quick-llama-test
pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v --log-cli-level=info -s --iree-hip-target=gfx942 --iree-device=hip://0 --run-quick-llama-test

- name: Upload llama executable files
uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
52 changes: 52 additions & 0 deletions sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -20,6 +20,8 @@
IreeCompileException,
)

logger = logging.getLogger(__name__)

is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")
skipif_run_quick_llama_test = pytest.mark.skipif(
'config.getoption("run-quick-llama-test") and not config.getoption("run-nightly-llama-tests")',
@@ -106,6 +108,8 @@ def save_benchmarks(

@is_mi300x
class BenchmarkLlama3_1_8B(BaseBenchmarkTest):
logger.info("Testing BenchmarkLlama3_1_8B...")
Review comment (Member) on lines 110 to +111:
I would remove these class-scoped log lines and instead rely on the -v option already included on the pytest commands in the workflow files.


I don't think logs at the class scope are doing what you want / are useful: https://github.com/nod-ai/shark-ai/actions/runs/13667985379/job/38212871488?pr=1017#step:7:22

snippet:

collecting ... 
----------------------------- live log collection ------------------------------
INFO:tests.models.llama.benchmark_amdgpu_test BenchmarkLlama3_1_8B...
INFO:tests.models.llama.benchmark_amdgpu_test Testing Benchmark8B_f16_Non_Decomposed_Input_Len_128...
INFO:tests.models.llama.benchmark_amdgpu_test Testing Benchmark405B_fp8_TP8_Non_Decomposed...
collected 11 items

Pytest has a few phases while running tests. This is the "collection" phase, where pytest discovers test cases from the directory tree based on any conftest.py files, settings/configuration files, and command line options. These log lines run during that collection phase, when you really want them when the tests are actually starting. Only after collection completes are individual tests actually run; then there are a few other phases.

I can't find docs on the phases, but https://docs.pytest.org/en/stable/reference/reference.html#hooks is close enough to get the point across:

  1. bootstrapping
  2. initialization
  3. collection
  4. test running
  5. reporting

However, pytest can already log when a test starts. See https://docs.pytest.org/en/stable/how-to/output.html and the -v option in particular, which changes the output from

=========================== test session starts ============================
collected 4 items

test_verbosity_example.py .FFF                                       [100%]

================================= FAILURES =================================

to

=========================== test session starts ============================
collecting ... collected 4 items

test_verbosity_example.py::test_ok PASSED                            [ 25%]
test_verbosity_example.py::test_words_fail FAILED                    [ 50%]
test_verbosity_example.py::test_numbers_fail FAILED                  [ 75%]
test_verbosity_example.py::test_long_text_fail FAILED                [100%]

================================= FAILURES =================================
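The per-test logging the reviewer describes can also be sketched as a small conftest.py hook. `pytest_runtest_logstart` is a real pytest hook that fires just before each test item runs (i.e. in the run phase, not collection); the logger name and message below are assumptions for illustration, not code from this PR:

```python
# conftest.py -- sketch: log when each test actually starts (run phase),
# instead of module/class-scoped statements that fire at collection time.
import logging

logger = logging.getLogger(__name__)


def pytest_runtest_logstart(nodeid, location):
    # pytest calls this hook right before running each test item,
    # so the message appears next to that test's live log output.
    logger.info("Starting %s", nodeid)
```

With `--log-cli-level=info` enabled, each test then announces itself at the point it starts rather than during collection.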


def setUp(self):
super().setUp()
# TODO: add numpy files to Azure and download from it
@@ -191,6 +195,7 @@ def setUp(self):
]

def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_128(self):
logger.info("Testing Benchmark8B_f16_TP1_Non_Decomposed_Input_Len_128...")
output_file_name = self.dir_path_8b / "f16_torch_128_tp1"
output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
suffix=".mlir", prefix=output_file_name
@@ -208,6 +213,7 @@ def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_128(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb(
Review comment (Member) on lines +216 to 217:
This helper already does some logging, but the extra logs don't hurt. What we could do in these places is provide context for what or why we're compiling/running. The "IREE Benchmark Prefill..." logs you added are a good example of that. Just seeing iree-run-module ... in the logs doesn't immediately provide such context, but "benchmark prefill" does.

INFO:tests.models.llama.benchmark_amdgpu_test Compiling MLIR file...

INFO:eval      Launching compile command:
cd /home/runner/_work/shark-ai/shark-ai && iree-compile /home/runner/_work/shark-ai/shark-ai/2025-03-05/llama-8b/f16_torch_128.mlir --iree-hip-target=gfx942 -o=/home/runner/_work/shark-ai/shark-ai/2025-03-05/llama-8b/f16_torch_128.vmfb --iree-hal-target-device=hip --iree-hal-dump-executable-files-to=/home/runner/_work/shark-ai/shark-ai/2025-03-05/llama-8b/f16_torch_128/files --iree-dispatch-creation-enable-aggressive-fusion=true --iree-global-opt-propagate-transposes=true --iree-opt-aggressively-propagate-transposes=true --iree-opt-data-tiling=false --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' --iree-stream-resource-memory-model=discrete --iree-hal-indirect-command-buffers=true --iree-hal-memoization=true --iree-opt-strip-assertions

INFO:eval      compile_to_vmfb: 16.77 secs
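The reviewer's point — say *what* or *why* rather than just "Compiling MLIR file..." — could be factored into a tiny wrapper. This is a hypothetical helper, not the sharktank API; the stage strings are illustrative:

```python
# Sketch of a contextual logging wrapper (hypothetical, not sharktank code):
# name the stage before delegating, so the raw iree-compile /
# iree-run-module command logs that follow have human-readable context.
import logging

logger = logging.getLogger(__name__)


def log_stage(stage, fn, *args, **kwargs):
    # Emit one contextual line, then run the existing helper unchanged.
    logger.info("%s...", stage)
    return fn(*args, **kwargs)
```

Usage would look something like `log_stage("IREE benchmark prefill (8b f16, len 128)", artifacts.iree_benchmark_vmfb, ...)`, assuming the artifact helpers from these test classes.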

mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -216,6 +222,7 @@ def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_128(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -225,6 +232,7 @@ def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_128(self):
cwd=self.repo_root,
)
# benchmark decode
logger.info("IREE Benchmark Decode...")
self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -236,6 +244,7 @@ def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_128(self):

@skipif_run_quick_llama_test
def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
logger.info("Testing Benchmark8B_f16_TP1_Non_Decomposed_Input_Len_2048...")
output_file_name = self.dir_path_8b / "f16_torch_2048_tp1"
output_mlir = self.llama8b_f16_torch_sdpa_artifacts.create_file(
suffix=".mlir", prefix=output_file_name
@@ -253,6 +262,7 @@ def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama8b_f16_torch_sdpa_artifacts.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -261,6 +271,7 @@ def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -270,6 +281,7 @@ def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
cwd=self.repo_root,
)
# benchmark decode
logger.info("IREE Benchmark Decode...")
self.llama8b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -281,6 +293,7 @@ def testBenchmark8B_f16_TP1_Non_Decomposed_Input_Len_2048(self):

@skipif_run_quick_llama_test
def testBenchmark8B_fp8_TP1_Non_Decomposed(self):
logger.info("Testing Benchmark8B_fp8_TP1_Non_Decomposed...")
output_file_name = self.dir_path_8b / "fp8_torch_tp1"
output_mlir = self.llama8b_fp8_torch_sdpa_artifacts.create_file(
suffix=".mlir", prefix=output_file_name
@@ -298,6 +311,7 @@ def testBenchmark8B_fp8_TP1_Non_Decomposed(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama8b_fp8_torch_sdpa_artifacts.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -306,6 +320,7 @@ def testBenchmark8B_fp8_TP1_Non_Decomposed(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama8b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -315,6 +330,7 @@ def testBenchmark8B_fp8_TP1_Non_Decomposed(self):
cwd=self.repo_root,
)
# benchmark decode
logger.info("IREE Benchmark Decode...")
self.llama8b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -328,6 +344,9 @@ def testBenchmark8B_fp8_TP1_Non_Decomposed(self):
@is_mi300x
@skipif_run_quick_llama_test
class BenchmarkLlama3_1_70B(BaseBenchmarkTest):

logger.info("Testing BenchmarkLlama3_1_70B...")

def setUp(self):
super().setUp()
# TODO: add numpy files to Azure and download from it
@@ -450,6 +469,7 @@ def setUp(self):
]

def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_128(self):
logger.info("Testing Benchmark70B_f16_TP1_Non_Decomposed_Input_Len_128...")
output_file_name = self.dir_path_70b / "f16_torch_128_tp1"
output_mlir = self.llama70b_f16_torch_sdpa_artifacts_tp1.create_file(
suffix=".mlir", prefix=output_file_name
@@ -467,6 +487,7 @@ def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_128(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama70b_f16_torch_sdpa_artifacts_tp1.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -475,6 +496,7 @@ def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_128(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama70b_f16_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -484,6 +506,7 @@ def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_128(self):
cwd=self.repo_root,
)
# benchmark decode
logger.info("IREE Benchmark Decode...")
self.llama70b_f16_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -494,6 +517,7 @@ def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_128(self):
)

def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
logger.info("Testing Benchmark70B_f16_TP1_Non_Decomposed_Input_Len_2048...")
output_file_name = self.dir_path_70b / "f16_torch_2048_tp1"
output_mlir = self.llama70b_f16_torch_sdpa_artifacts_tp1.create_file(
suffix=".mlir", prefix=output_file_name
@@ -511,6 +535,7 @@ def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama70b_f16_torch_sdpa_artifacts_tp1.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -519,6 +544,7 @@ def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama70b_f16_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -528,6 +554,7 @@ def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
cwd=self.repo_root,
)
# benchmark decode
logger.info("IREE Benchmark Decode...")
self.llama70b_f16_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -541,6 +568,7 @@ def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
)
def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_128(self):
logger.info("Testing Benchmark70B_f16_TP8_Non_Decomposed_Input_Len_128...")
output_file_name = self.dir_path_70b / "f16_torch_128_tp8"
output_mlir = self.llama70b_f16_torch_sdpa_artifacts_tp8.create_file(
suffix=".mlir", prefix=output_file_name
@@ -563,6 +591,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_128(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama70b_f16_torch_sdpa_artifacts_tp8.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -571,6 +600,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_128(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
Review comment (Member) on lines 602 to +603:
nit: when comments and logs are saying the same thing, you can remove the comments

self.llama70b_f16_torch_sdpa_artifacts_tp8.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -579,6 +609,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_128(self):
cwd=self.repo_root,
)
# benchmark decode
logger.info("IREE Benchmark Decode...")
self.llama70b_f16_torch_sdpa_artifacts_tp8.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -591,6 +622,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_128(self):
reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
)
def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
logger.info("Testing Benchmark70B_f16_TP8_Non_Decomposed_Input_Len_2048...")
output_file_name = self.dir_path_70b / "f16_torch_2048_tp8"
output_mlir = self.llama70b_f16_torch_sdpa_artifacts_tp8.create_file(
suffix=".mlir", prefix=output_file_name
@@ -613,6 +645,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama70b_f16_torch_sdpa_artifacts_tp8.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -621,6 +654,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama70b_f16_torch_sdpa_artifacts_tp8.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -629,6 +663,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
cwd=self.repo_root,
)
# benchmark decode
logger.info("IREE Benchmark Decode...")
self.llama70b_f16_torch_sdpa_artifacts_tp8.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -641,6 +676,7 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
reason="70b fp8 irpa does not exist", strict=True, raises=ExportMlirException
)
def testBenchmark70B_fp8_TP1_Non_Decomposed(self):
logger.info("Testing Benchmark70B_fp8_TP1_Non_Decomposed...")
output_file_name = self.dir_path_70b / "fp8_torch_tp1"
output_mlir = self.llama70b_fp8_torch_sdpa_artifacts_tp1.create_file(
suffix=".mlir", prefix=output_file_name
@@ -655,6 +691,7 @@ def testBenchmark70B_fp8_TP1_Non_Decomposed(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama70b_fp8_torch_sdpa_artifacts_tp1.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -663,6 +700,7 @@ def testBenchmark70B_fp8_TP1_Non_Decomposed(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama70b_fp8_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -671,6 +709,7 @@ def testBenchmark70B_fp8_TP1_Non_Decomposed(self):
cwd=self.repo_root,
)
# benchmark decode
logger.info("IREE Benchmark Decode...")
self.llama70b_fp8_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -683,6 +722,9 @@ def testBenchmark70B_fp8_TP1_Non_Decomposed(self):
@is_mi300x
@skipif_run_quick_llama_test
class BenchmarkLlama3_1_405B(BaseBenchmarkTest):

logger.info("Testing BenchmarkLlama3_1_405B...")

def setUp(self):
super().setUp()
# TODO: add numpy files to Azure and download from it
@@ -773,6 +815,7 @@ def setUp(self):
reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
)
def testBenchmark405B_f16_TP8_Non_Decomposed_Input_Len_128(self):
logger.info("Testing Benchmark405B_f16_TP8_Non_Decomposed_Input_Len_128...")
output_file_name = self.dir_path_405b / "f16_torch_128"
output_mlir = self.llama405b_f16_torch_sdpa_artifacts.create_file(
suffix=".mlir", prefix=output_file_name
@@ -793,6 +836,7 @@ def testBenchmark405B_f16_TP8_Non_Decomposed_Input_Len_128(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama405b_f16_torch_sdpa_artifacts.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -801,6 +845,7 @@ def testBenchmark405B_f16_TP8_Non_Decomposed_Input_Len_128(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama405b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -814,6 +859,7 @@ def testBenchmark405B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
)
def testBenchmark405B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
logger.info("Testing Benchmark405B_f16_TP8_Non_Decomposed_Input_Len_2048...")
output_file_name = self.dir_path_405b / "f16_torch_2048"
output_mlir = self.llama405b_f16_torch_sdpa_artifacts.create_file(
suffix=".mlir", prefix=output_file_name
@@ -834,6 +880,7 @@ def testBenchmark405B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama405b_f16_torch_sdpa_artifacts.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -842,6 +889,7 @@ def testBenchmark405B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama405b_f16_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -855,6 +903,7 @@ def testBenchmark405B_fp8_TP8_Non_Decomposed(self):
reason="KeyError in theta.py", strict=True, raises=ExportMlirException
)
def testBenchmark405B_fp8_TP8_Non_Decomposed(self):
logger.info("Testing Benchmark405B_fp8_TP8_Non_Decomposed...")
output_file_name = self.dir_path_405b / "fp8_torch"
output_mlir = self.llama405b_fp8_torch_sdpa_artifacts.create_file(
suffix=".mlir", prefix=output_file_name
@@ -875,6 +924,7 @@ def testBenchmark405B_fp8_TP8_Non_Decomposed(self):
mlir_path=output_mlir,
json_path=output_json,
)
logger.info("Compiling MLIR file...")
self.llama405b_fp8_torch_sdpa_artifacts.compile_to_vmfb(
mlir_path=str(output_mlir),
vmfb_path=output_vmfb,
@@ -883,6 +933,7 @@ def testBenchmark405B_fp8_TP8_Non_Decomposed(self):
args=self.compile_args,
)
# benchmark prefill
logger.info("IREE Benchmark Prefill...")
self.llama405b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,
@@ -891,6 +942,7 @@ def testBenchmark405B_fp8_TP8_Non_Decomposed(self):
cwd=self.repo_root,
)
# benchmark decode
logger.info("IREE Benchmark Decode...")
self.llama405b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
hip_device_id=self.iree_device,
vmfb_name=output_vmfb,