
Commit fec3869

Merge branch 'main' into dev/wdziurdz/test-matmul-6

2 parents: 6267a5c + ca6fb8a

2 files changed (+21 −13 lines)


benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 21 additions & 13 deletions
@@ -263,9 +263,10 @@ def perf_report(benchmarks):
 class MarkArgs:
     reports: str = ""
     n_runs: int = 1
+    brief: bool = False
 
-    @classmethod
-    def _get_parser(cls) -> argparse.ArgumentParser:
+    @staticmethod
+    def load_cli_args() -> MarkArgs:
         """Parses arguments via CLI, allows save_path overloading to `reports`."""
         parser = argparse.ArgumentParser()
         parser.add_argument(
@@ -280,12 +281,14 @@ def _get_parser(cls) -> argparse.ArgumentParser:
             default=1,
             help="Number of runs for this benchmark. The default is one.",
         )
-        return parser
-
-    @classmethod
-    def from_args(cls) -> "MarkArgs":
-        args = cls._get_parser().parse_args()
-        return MarkArgs(args.reports, args.n_runs)
+        parser.add_argument(
+            "--brief",
+            "-b",
+            action="store_true",
+            help="Print only mean values without min, max, CV.",
+        )
+        args = parser.parse_args()
+        return MarkArgs(args.reports, args.n_runs, args.brief)
 
 
 class Mark:
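For context, this hunk collapses the old `_get_parser`/`from_args` pair into a single `load_cli_args` entry point that both parses the CLI and constructs the args object. A minimal, self-contained sketch of the resulting flow, assuming `MarkArgs` is a dataclass; the `--reports` and `--n_runs` definitions below are illustrative stand-ins, not copied from the file:

from __future__ import annotations

import argparse
from dataclasses import dataclass


@dataclass
class MarkArgs:
    reports: str = ""
    n_runs: int = 1
    brief: bool = False

    @staticmethod
    def load_cli_args() -> MarkArgs:
        """Parses arguments via CLI, allows save_path overloading to `reports`."""
        parser = argparse.ArgumentParser()
        # The real file defines --reports and --n_runs with fuller help text;
        # these two definitions are abbreviated stand-ins.
        parser.add_argument("--reports", type=str, default="")
        parser.add_argument("--n_runs", type=int, default=1,
                            help="Number of runs for this benchmark. The default is one.")
        parser.add_argument("--brief", "-b", action="store_true",
                            help="Print only mean values without min, max, CV.")
        args = parser.parse_args()
        return MarkArgs(args.reports, args.n_runs, args.brief)

Switching from `@classmethod` to `@staticmethod` also removes the awkward call pattern `run()` used before: the caller no longer has to build a throwaway `MarkArgs()` instance just to parse the CLI.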
@@ -295,8 +298,8 @@ def __init__(self, fn, benchmarks):
         self.benchmarks = benchmarks
 
     # pylint: disable=too-many-branches
-    def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, run_counter=0,
-             save_precision=6, **kwargs):
+    def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, mark_args, diff_col=False,
+             run_counter=0, save_precision=6, **kwargs):
         y_vals = []
         for label in bench.ylabel:
             y_mean = [f"{x}-{label}" for x in bench.line_names]
@@ -373,13 +376,17 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b
 
         if print_data:
             print(bench.plot_name + ":")
-            print(df.to_string())
+            if mark_args.brief:
+                print(df[[c for c in df.columns if not any(map(c.endswith, ("min", "max", "CV")))]].to_string())
+            else:
+                print(df.to_string())
+
         if save_path:
             df.to_csv(os.path.join(save_path, f"{filename}.csv"), float_format=f"%.{save_precision}f", index=False)
         return df
 
     def run(self, show_plots=False, print_data=False, return_df=False, save_precision=6, mark_args=None, **kwargs):
-        args = MarkArgs().from_args() if mark_args is None else mark_args
+        args = mark_args or MarkArgs.load_cli_args()
 
         has_single_bench = isinstance(self.benchmarks, Benchmark)
         benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
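The `--brief` filter above keeps every column whose name does not end in `min`, `max`, or `CV`; note that `map(c.endswith, (...))` tests the column name against each suffix in turn. A small sketch of the behavior on a toy frame, with made-up column names:

import pandas as pd

# Column names are illustrative; the real report columns are derived from
# bench.line_names and bench.ylabel.
df = pd.DataFrame({
    "N": [1024],
    "Triton-GB/s": [812.3],
    "Triton-GB/s-min": [790.1],
    "Triton-GB/s-max": [820.7],
    "Triton-GB/s-CV": [0.013],
})

# Same filter as the diff: drop columns ending in min, max, or CV.
brief_cols = [c for c in df.columns if not any(map(c.endswith, ("min", "max", "CV")))]
print(df[brief_cols].to_string())  # prints only N and Triton-GB/s

One caveat of this approach: the suffix test is a plain string match, so any future column whose name happens to end in one of those suffixes would also be hidden in brief mode.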
@@ -392,7 +399,8 @@ def run(self, show_plots=False, print_data=False, return_df=False, save_precisio
         for bench in benchmarks:
             benchmark_dfs = []
             for run_counter in range(args.n_runs):
-                df = self._run(bench, args.reports, show_plots, print_data, run_counter=run_counter, **kwargs)
+                df = self._run(bench, args.reports, show_plots, print_data, mark_args=args, run_counter=run_counter,
+                               **kwargs)
                 df["datetime"] = datetime.datetime.now()
                 df["run_counter"] = run_counter + 1
                 benchmark_dfs.append(df)
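One detail of the fallback change in `run()`: `mark_args or MarkArgs.load_cli_args()` parses the CLI whenever `mark_args` is falsy, which for a default dataclass instance (always truthy) is equivalent in practice to the old `is None` check. A hypothetical usage sketch, assuming `mark` is the `Mark` object produced by the `perf_report` decorator:

# Programmatic use: supply mark_args explicitly and bypass the CLI.
args = MarkArgs(reports="/tmp/reports", n_runs=3, brief=True)
mark.run(print_data=True, mark_args=args)

# CLI use: omit mark_args and let load_cli_args() pick up the flags, e.g.
#     python my_benchmark.py --n_runs 3 --brief
mark.run(print_data=True)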
Second changed file: binary, −34.8 KB (contents not shown).
