From c2bf2dbb025073eb95d445587a8a6530ea03ec94 Mon Sep 17 00:00:00 2001 From: Aditya Gupta Date: Sat, 15 Nov 2025 21:56:24 -0800 Subject: [PATCH] Internal. PiperOrigin-RevId: 832867038 --- .../nn/tests/preprocess_input_benchmarks.py | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py index 88d60e44..b7ac4e53 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py @@ -26,17 +26,14 @@ The --benchmark_filter flag uses a regex to select benchmarks. For parameterized benchmarks, the name is typically formatted as: -`[benchmark_name]/[param1]:[value1]/[param2]:[value2]`. Boolean parameters are often represented as 0 for False and 1 for True. +`[benchmark_name]/[param1]:[value1]`. For example, to run only the `sparse_coo` benchmarks: `--benchmark_filter=preprocess_input_benchmark_sparse_coo` -To run only the `sparse_coo` benchmark where `has_leading_dimension` is `False`: -`--benchmark_filter='preprocess_input_benchmark_sparse_coo/has_leading_dimension:0'` - -To run all benchmarks across all suites where `has_leading_dimension` is `False`: -`--benchmark_filter='/has_leading_dimension:0'` +To run only the ragged benchmark with ragged=True: +`--benchmark_filter='preprocess_input_benchmark/ragged:1'` To upload the profile to pprof: pprof -flame /tmp/preprocess.prof @@ -220,12 +217,12 @@ def apply_fdo_stats( @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) -@google_benchmark.option.arg_names(["ragged", "has_leading_dimension"]) -@google_benchmark.option.args_product([[False, True], [False, True]]) +@google_benchmark.option.arg_names(["ragged"]) +@google_benchmark.option.args_product([[False, True]]) @google_benchmark.option.iterations(100) def preprocess_input_benchmark(state: google_benchmark.State): """Benchmark for preprocessing input for sparse-dense matmul.""" - ragged, has_leading_dimension = state.range(0), state.range(1) + ragged = state.range(0) if ragged: features, feature_weights = _GLOBAL_RAGGED_FEATURES, _GLOBAL_RAGGED_WEIGHTS else: @@ -241,7 +238,6 @@ def preprocess_input_benchmark(state: google_benchmark.State): local_device_count=4, global_device_count=16, num_sc_per_device=4, - has_leading_dimension=has_leading_dimension, batch_number=batch_num, allow_id_dropping=batch_num == 0, ) @@ -253,12 +249,9 @@ def preprocess_input_benchmark(state: google_benchmark.State): @google_benchmark.register @google_benchmark.option.unit(google_benchmark.kMillisecond) -@google_benchmark.option.arg_name("has_leading_dimension") -@google_benchmark.option.args_product([[False, True]]) @google_benchmark.option.iterations(100) def preprocess_input_benchmark_sparse_coo(state: google_benchmark.State): """Benchmark for preprocessing input for sparse-dense matmul.""" - has_leading_dimension = state.range(0) batch_num = 0 while state: if batch_num == 0: @@ -272,7 +265,6 @@ def preprocess_input_benchmark_sparse_coo(state: google_benchmark.State): local_device_count=4, global_device_count=16, num_sc_per_device=4, - has_leading_dimension=has_leading_dimension, batch_number=batch_num, allow_id_dropping=batch_num == 0, ) @@ -312,7 +304,6 @@ def worker(host_id: int, batch_number: int): local_device_count=4, global_device_count=16, num_sc_per_device=4, - has_leading_dimension=False, enable_minibatching=True, all_reduce_interface=all_reduce_interfaces[host_id], batch_number=batch_number,