utils.py
import time
import torch
import functools
from transformers import AutoModelForCausalLM, AutoTokenizer


def time_decorator(func):
    """Return a wrapper that times ``func`` and returns ``(result, exec_time)``."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        # Wall-clock duration of the call, in seconds.
        exec_time = end_time - start_time
        return result, exec_time

    return wrapper


def memory_decorator(func):
    """Report peak CUDA memory (in GB) used by ``func``; expects ``func`` to
    return a ``(result, exec_time)`` tuple, as ``time_decorator`` produces."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Clear the allocator cache and reset the peak counter so the
        # measurement reflects this call only.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        result, exec_time = func(*args, **kwargs)
        # Peak bytes allocated during the call, converted to gigabytes.
        peak_mem = torch.cuda.max_memory_allocated()
        peak_mem_consumption = peak_mem / 1e9
        return peak_mem_consumption, exec_time, result

    return wrapper


@memory_decorator
@time_decorator
def generate_output(
    prompt: str, model: AutoModelForCausalLM, tokenizer: AutoTokenizer
) -> torch.Tensor:
    """Generate up to 500 tokens for ``prompt``; the stacked decorators make
    the call return ``(peak_mem_gb, exec_time, output_ids)``."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to("cuda")
    outputs = model.generate(input_ids, max_length=500)
    return outputs
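

# --- Illustrative usage (not part of the original file) ---------------------
# A minimal sketch of how the decorated generate_output might be called.
# Assumptions: "gpt2" is an arbitrary example checkpoint (substitute any
# causal LM), and a CUDA device is available, as the decorators require.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

    # The decorator stack turns the plain generate_output call into one that
    # returns (peak_mem_gb, exec_time, output_ids).
    peak_mem_gb, exec_time, output_ids = generate_output(
        "Hello, world!", model, tokenizer
    )
    print(f"Peak GPU memory: {peak_mem_gb:.2f} GB | time: {exec_time:.2f} s")
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))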