Skip to content

Commit f8648aa

Browse files
benchmark_inference: Allow passing attn-implementation (#2672)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b3a94d1 commit f8648aa

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

thunder/benchmarks/benchmark_inference.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ class InferenceBenchmarkConfig:
154154
enable_nv_linear: bool
155155
mode: str
156156
disable_moe_replacement: bool
157+
attn_implementation: str | None
157158
profile: bool
158159

159160

@@ -319,7 +320,9 @@ def _load_model(self) -> torch.nn.Module:
319320
self.hf_config = config
320321

321322
with torch.device("meta"):
322-
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
323+
model = AutoModelForCausalLM.from_config(
324+
config, torch_dtype=torch.bfloat16, attn_implementation=self.config.attn_implementation
325+
)
323326

324327
return model
325328

@@ -676,6 +679,7 @@ def parse_args() -> argparse.Namespace:
676679

677680
parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
678681
parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
682+
parser.add_argument("--attn-implementation", type=str, default=None, help="Attention implementation")
679683

680684
args = parser.parse_args()
681685
return args
@@ -707,6 +711,7 @@ def main():
707711
fx_report_folder=args.fx_report_folder,
708712
enable_nv_linear=args.enable_nv_linear,
709713
disable_moe_replacement=args.disable_moe_replacement,
714+
attn_implementation=args.attn_implementation,
710715
profile=args.profile,
711716
)
712717
benchmark = InferenceBenchmark(config)

0 commit comments

Comments
 (0)