@@ -40,35 +40,21 @@ def elapsed_time(self, other_event):
         return abs(other_event.event_time - self.event_time) * 1000
 
 
-def device_timer(device):
-    if "cuda" in device:
-        return torch.cuda.Event(enable_timing=True)
-    elif "xpu" in device:
-        return torch.xpu.Event(enable_timing=True)
+def device_timer(device: str):
+    if device in ["cuda", "xpu"]:
+        return torch.Event(enable_timing=True)
     elif ("cpu" in device) or ("mps" in device):
         return HostEvent()
     else:
         print(f"device={device} is not yet supported")
 
 
-def device_sync(device):
-    if "cuda" in device:
-        torch.cuda.synchronize(device)
-    elif "xpu" in device:
-        torch.xpu.synchronize(device)
-    elif ("cpu" in device) or ("mps" in device):
-        pass
-    else:
-        print(f"device={device} is not yet supported")
+def device_sync(device: str):
+    if torch.accelerator.is_available():
+        torch.accelerator.synchronize(device)
 
 
-default_device = (
-    "cuda"
-    if torch.cuda.is_available()
-    else "xpu"
-    if torch.xpu.is_available()
-    else "cpu"
-)
+default_device = acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
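
For reference, the device-agnostic pattern this hunk introduces can be exercised on its own. A minimal sketch, assuming PyTorch 2.7+ (where `torch.accelerator.current_accelerator` accepts a `check_available` argument); the tensor shapes are illustrative:

```python
import torch

# Pick the device the same way the new default_device line does:
# current_accelerator(check_available=True) returns a torch.device
# (e.g. cuda/xpu/mps) or None when no accelerator is usable.
acc = torch.accelerator.current_accelerator(check_available=True)
device = acc.type if acc else "cpu"

x = torch.randn(1024, 1024, device=device)
y = x @ x  # launched asynchronously on accelerator backends
if torch.accelerator.is_available():
    torch.accelerator.synchronize()  # the same call device_sync() now relies on
print(device, y.sum().item())
```
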
@@ -160,10 +146,10 @@ def generate(
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool = False,
-    prefill_start_event: Optional[torch.cuda.Event] = None,
-    prefill_end_event: Optional[torch.cuda.Event] = None,
-    decode_start_event: Optional[torch.cuda.Event] = None,
-    decode_end_event: Optional[torch.cuda.Event] = None,
+    prefill_start_event: Optional[torch.Event] = None,
+    prefill_end_event: Optional[torch.Event] = None,
+    decode_start_event: Optional[torch.Event] = None,
+    decode_end_event: Optional[torch.Event] = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -281,8 +267,8 @@ def main(
     compile_prefill: bool = False,
     profile: Optional[Path] = None,
     memory_profile: Optional[Path] = None,
-    device=default_device,
-    precision=torch.bfloat16,
+    device: str = default_device,
+    precision=torch.bfloat16,
     write_result: Optional[Path] = None,
     output_json_path: Optional[Path] = None,
     output_json_local: bool = False,
@@ -606,7 +592,7 @@ def ffn_or_attn_only(mod, fqn):
             prepare_inputs_for_model,
             False,  # pad_calibration_inputs
             model.config.vocab_size,
-            device="cuda",
+            device=device,
         )
         .record_inputs(
             ["wikitext"],
@@ -616,7 +602,7 @@ def ffn_or_attn_only(mod, fqn):
         .values[0]
     )
     inputs = prepare_inputs_for_model(inputs)
-    with torch.device("cuda"):
+    with torch.device(device):
         model.setup_caches(
             max_batch_size=1, max_seq_length=calibration_seq_length
         )
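
`torch.device` has worked as a context manager since PyTorch 2.0: factory functions called inside the block allocate on that device by default, which is why `setup_caches` no longer needs a hardcoded "cuda". A small illustration, using "cpu" as a stand-in for the script's `device` argument:

```python
import torch

with torch.device("cpu"):    # the script passes its `device` argument here
    t = torch.zeros(4, 4)    # no explicit device= needed inside the block
assert t.device.type == "cpu"
```
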
@@ -883,10 +869,7 @@ def ffn_or_attn_only(mod, fqn):
 
     for i in range(start, num_samples):
         if i == 0:
-            if device == "cuda":
-                torch.cuda.reset_peak_memory_stats()  # MKG
-            elif device == "xpu":
-                torch.xpu.reset_peak_memory_stats()  # MKG
+            torch.accelerator.reset_peak_memory_stats()  # MKG
         device_sync(device=device)  # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
@@ -1016,14 +999,10 @@ def callback(x):
         torch.tensor(aggregate_metrics["decode_tokens_per_sec"])
     ).item()
     bandwidth = model_size * tokpersec
-    mem = torch.cuda.max_memory_reserved() / 1e9
     print(f"Average overall tokens/sec: {tokpersec:.2f}")
     print(f"Average decode tokens/sec: {decode_tokpersec:.04f}s")
     print(f"Average TTFT: {ttft:.04f}s")
-    if device == "cuda":
-        mem = torch.cuda.max_memory_reserved() / 1e9
-    elif device == "xpu":
-        mem = torch.xpu.max_memory_reserved() / 1e9
+    mem = torch.accelerator.max_memory_reserved() / 1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
     if batch_size > 1:
         print(f"Average tokens/sec including batches {batch_size * tokpersec:.2f}")
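
Both memory hunks lean on the unified `torch.accelerator` memory statistics used in this commit, which exist only in recent PyTorch builds (older releases expose them per backend as `torch.cuda.*`/`torch.xpu.*`). A minimal sketch under that assumption:

```python
import torch

if torch.accelerator.is_available():
    torch.accelerator.reset_peak_memory_stats()  # clear the high-water mark
    x = torch.randn(4096, 4096, device=torch.accelerator.current_accelerator())
    mem_gb = torch.accelerator.max_memory_reserved() / 1e9  # bytes -> GB
    print(f"peak reserved: {mem_gb:.2f} GB")
```
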