Commit 57478ea: Unify llama generate utils

Parent: 17867e6


torchao/_models/llama/generate.py

Lines changed: 17 additions & 38 deletions
@@ -40,35 +40,21 @@ def elapsed_time(self, other_event):
         return abs(other_event.event_time - self.event_time) * 1000
 
 
-def device_timer(device):
-    if "cuda" in device:
-        return torch.cuda.Event(enable_timing=True)
-    elif "xpu" in device:
-        return torch.xpu.Event(enable_timing=True)
+def device_timer(device: str):
+    if device in ["cuda", "xpu"]:
+        return torch.Event(enable_timing=True)
     elif ("cpu" in device) or ("mps" in device):
         return HostEvent()
     else:
         print(f"device={device} is not yet suppported")
 
 
-def device_sync(device):
-    if "cuda" in device:
-        torch.cuda.synchronize(device)
-    elif "xpu" in device:
-        torch.xpu.synchronize(device)
-    elif ("cpu" in device) or ("mps" in device):
-        pass
-    else:
-        print(f"device={device} is not yet suppported")
+def device_sync(device: str):
+    if torch.accelerator.is_available():
+        torch.accelerator.synchronize(device)
 
 
-default_device = (
-    "cuda"
-    if torch.cuda.is_available()
-    else "xpu"
-    if torch.xpu.is_available()
-    else "cpu"
-)
+default_device = acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
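
Note (not part of the commit): the replacement code relies on the torch.accelerator namespace available in recent PyTorch releases (2.6+). A minimal sketch of how the unified detection and sync behave:

import torch

# current_accelerator(check_available=True) returns a torch.device such as
# device(type='cuda') only when the backend is compiled in and a device is
# actually present; otherwise it returns None, so the walrus expression
# falls back to "cpu" just like the old cuda/xpu ternary chain did.
default_device = acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
print(default_device)  # e.g. "cuda", "xpu", "mps", or "cpu"

# The new device_sync() collapses the per-backend branches into one call;
# on a CPU-only build is_available() is False and nothing happens.
if torch.accelerator.is_available():
    torch.accelerator.synchronize(default_device)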
@@ -160,10 +146,10 @@ def generate(
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool = False,
-    prefill_start_event: Optional[torch.cuda.Event] = None,
-    prefill_end_event: Optional[torch.cuda.Event] = None,
-    decode_start_event: Optional[torch.cuda.Event] = None,
-    decode_end_event: Optional[torch.cuda.Event] = None,
+    prefill_start_event: Optional[torch.Event] = None,
+    prefill_end_event: Optional[torch.Event] = None,
+    decode_start_event: Optional[torch.Event] = None,
+    decode_end_event: Optional[torch.Event] = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -281,8 +267,8 @@ def main(
     compile_prefill: bool = False,
     profile: Optional[Path] = None,
     memory_profile: Optional[Path] = None,
-    device=default_device,
-    precision=torch.bfloat16,
+    device: str = default_device,
+    precision = torch.bfloat16,
     write_result: Optional[Path] = None,
     output_json_path: Optional[Path] = None,
     output_json_local: bool = False,
@@ -606,7 +592,7 @@ def ffn_or_attn_only(mod, fqn):
                 prepare_inputs_for_model,
                 False,  # pad_calibration_inputs
                 model.config.vocab_size,
-                device="cuda",
+                device=device,
             )
             .record_inputs(
                 ["wikitext"],
@@ -616,7 +602,7 @@ def ffn_or_attn_only(mod, fqn):
             .values[0]
         )
         inputs = prepare_inputs_for_model(inputs)
-        with torch.device("cuda"):
+        with torch.device(device):
             model.setup_caches(
                 max_batch_size=1, max_seq_length=calibration_seq_length
             )
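
Aside (not part of the commit): torch.device(...) used as a context manager, as in the hunk above, makes that device the default for tensor construction inside the block, which is why model.setup_caches() allocates its caches on the right device without further changes. A tiny illustration:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"  # illustrative fallback
with torch.device(device):
    kv = torch.zeros(2, 16, 2048, 64)  # lands on `device` with no explicit arg
print(kv.device)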
@@ -883,10 +869,7 @@ def ffn_or_attn_only(mod, fqn):
 
     for i in range(start, num_samples):
         if i == 0:
-            if device == "cuda":
-                torch.cuda.reset_peak_memory_stats()  # MKG
-            elif device == "xpu":
-                torch.xpu.reset_peak_memory_stats()  # MKG
+            torch.accelerator.reset_peak_memory_stats()  # MKG
         device_sync(device=device)  # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
@@ -1016,14 +999,10 @@ def callback(x):
         torch.tensor(aggregate_metrics["decode_tokens_per_sec"])
     ).item()
     bandwidth = model_size * tokpersec
-    mem = torch.cuda.max_memory_reserved() / 1e9
     print(f"Average overall tokens/sec: {tokpersec:.2f}")
     print(f"Average decode tokens/sec: {decode_tokpersec:.04f} s")
     print(f"Average TTFT: {ttft:.04f} s")
-    if device == "cuda":
-        mem = torch.cuda.max_memory_reserved() / 1e9
-    elif device == "xpu":
-        mem = torch.xpu.max_memory_reserved() / 1e9
+    mem = torch.accelerator.max_memory_reserved() / 1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
     if batch_size > 1:
         print(f"Average tokens/sec including batches {batch_size * tokpersec:.2f}")
