@@ -154,6 +154,7 @@ class InferenceBenchmarkConfig:
     enable_nv_linear: bool
     mode: str
     disable_moe_replacement: bool
+    attn_implementation: str | None
     profile: bool

@@ -319,7 +320,9 @@ def _load_model(self) -> torch.nn.Module:
         self.hf_config = config

         with torch.device("meta"):
-            model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
+            model = AutoModelForCausalLM.from_config(
+                config, torch_dtype=torch.bfloat16, attn_implementation=self.config.attn_implementation
+            )

         return model

@@ -676,6 +679,7 @@ def parse_args() -> argparse.Namespace:

     parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
     parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
+    parser.add_argument("--attn-implementation", type=str, default=None, help="Attention implementation")

     args = parser.parse_args()
     return args
@@ -707,6 +711,7 @@ def main():
         fx_report_folder=args.fx_report_folder,
         enable_nv_linear=args.enable_nv_linear,
         disable_moe_replacement=args.disable_moe_replacement,
+        attn_implementation=args.attn_implementation,
         profile=args.profile,
     )
     benchmark = InferenceBenchmark(config)
0 commit comments