Remove ChunkedHybridCache from benchmark_inference.py #2733
base: main

thunder/benchmarks/benchmark_inference.py

@@ -22,17 +22,15 @@
 import warnings
 from typing import Any
 from collections.abc import Callable
-from looseversion import LooseVersion

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel, ColwiseParallel
 from tqdm import tqdm
-import transformers
 from transformers import AutoConfig, AutoModelForCausalLM
-from transformers.cache_utils import HybridChunkedCache, StaticCache
+from transformers.cache_utils import StaticCache
 from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
 from torch.distributed.tensor.placement_types import Shard
 from torch.distributed.tensor import DTensor

@@ -335,36 +333,20 @@ def _load_model(self) -> torch.nn.Module:

         return model

-    def generate_batch(self) -> tuple[torch.Tensor, HybridChunkedCache]:
+    def generate_batch(self) -> tuple[torch.Tensor, StaticCache]:
         """Generate a batch of input tokens"""
         batch_size = self.config.batch_size
         input_length = self.config.input_length

         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
-        if LooseVersion(transformers.__version__) >= LooseVersion("4.55"):
-            # Transformers deprecated HybridChunkedCache in favour of static in 4.55.x
-            past_key_values = StaticCache(
-                config=self.hf_config,
-                max_batch_size=input_ids.shape[0],
-                max_cache_len=input_ids.shape[1] + self.config.output_length,
-                device=DEVICE,
-                dtype=torch.bfloat16,
Collaborator:

Also, StaticCache allocates float32 tensors on the CPU unless dtype and device are passed explicitly:

from transformers.cache_utils import StaticCache
from transformers import AutoConfig, AutoModelForCausalLM
import torch

model_id = "meta-llama/Llama-4-Maverick-17B-128E"
config = AutoConfig.from_pretrained(model_id)
if hasattr(config, "text_config"):
    config = config.text_config
config.num_hidden_layers = 2

past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256)
print(past_key_values.layers[0].keys.dtype)   # torch.float32
print(past_key_values.layers[0].keys.device)  # cpu

past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256, dtype=torch.bfloat16, device="cuda")
print(past_key_values.layers[0].keys.dtype)   # torch.bfloat16
print(past_key_values.layers[0].keys.device)  # cuda:0
-            )
-        else:
-            past_key_values = HybridChunkedCache(
-                self.hf_config, input_ids.shape[0], input_ids.shape[1] + self.config.output_length
-            )
-            for layer_idx in range(self.hf_config.num_hidden_layers):
-                # key_states.shape[1] is used to retrieve the number of key value heads, all other dimensions can be 1 and ignored
-                # https://github.com/huggingface/transformers/blob/9300728665aaeb0ebf4db99f9d9fbce916b4a183/src/transformers/cache_utils.py#L1822
-                dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)
Collaborator:

We also need to preserve the num_key_value_heads // WORLD_SIZE sharding. The patch can be something like the following:

diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 212f5f8e..13af8175 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -339,9 +339,15 @@ class InferenceBenchmark:
         input_length = self.config.input_length

         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
+        import copy
+        hf_config = copy.copy(self.hf_config)
+        hf_config.num_key_value_heads //= WORLD_SIZE
         past_key_values = StaticCache(
-            config=self.hf_config,
+            config=hf_config,
             max_cache_len=input_ids.shape[1] + self.config.output_length,
+            max_batch_size=batch_size,
+            dtype=torch.bfloat16,
+            device=DEVICE,
         )

         return input_ids, past_key_values
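
As a sanity check on the patch above, here is a minimal, self-contained sketch (not part of the PR). It substitutes a small LlamaConfig for the benchmark's Llama-4 text config, hard-codes WORLD_SIZE = 2 as an example tensor-parallel size, and assumes a transformers release with the layered cache API (4.55.x, as in the earlier comment), so that StaticCache sizes its per-layer key tensors from the sharded num_key_value_heads on the copied config:

import copy

import torch
from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

WORLD_SIZE = 2  # example value; the benchmark derives this from its distributed setup

# Stand-in for self.hf_config: 8 attention heads, 4 KV heads, head_dim = 256 // 8 = 32.
hf_config = LlamaConfig(hidden_size=256, num_attention_heads=8, num_key_value_heads=4, num_hidden_layers=2)

sharded_config = copy.copy(hf_config)
sharded_config.num_key_value_heads //= WORLD_SIZE  # 4 -> 2 KV heads per rank

past_key_values = StaticCache(
    config=sharded_config,
    max_batch_size=1,
    max_cache_len=256,
    dtype=torch.bfloat16,
)

# The static cache layout is (max_batch_size, num_key_value_heads, max_cache_len, head_dim),
# so the head dimension should reflect the per-rank sharding and the requested dtype.
print(past_key_values.layers[0].keys.shape)  # expected: torch.Size([1, 2, 256, 32])
print(past_key_values.layers[0].keys.dtype)  # expected: torch.bfloat16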
-                past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
+        past_key_values = StaticCache(
+            config=self.hf_config,
+            max_cache_len=input_ids.shape[1] + self.config.output_length,
Collaborator:

Also, device and dtype seem to be required:

RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_bmm)
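
The quoted error is PyTorch's generic device mismatch, presumably because a cache built without device= and dtype= stays in float32 on the CPU while the model's matmuls run on the GPU. A small, self-contained illustration of that failure mode (an assumption-laden sketch, not part of the PR; plain tensors stand in for the attention query and the cache):

import torch

if torch.cuda.is_available():
    # Query activations live on the GPU in bf16, like the benchmark's model.
    query = torch.randn(1, 8, 16, device="cuda", dtype=torch.bfloat16)
    # A cache allocated with the defaults ends up as float32 on the CPU.
    cached_keys = torch.zeros(1, 16, 8)
    try:
        torch.bmm(query, cached_keys)
    except RuntimeError as err:
        print(err)  # Expected all tensors to be on the same device, but got mat2 is on cpu ...
else:
    print("CUDA not available; the mismatch only appears when the model itself runs on the GPU.")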
+        )

         return input_ids, past_key_values

-    def get_next_token(
-        self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache | StaticCache
-    ) -> torch.Tensor:
+    def get_next_token(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
         start_pos = past_key_values.get_seq_length()
         cache_position = start_pos + torch.arange(0, input_ids.shape[1], device=start_pos.device, dtype=start_pos.dtype)
         with torch.no_grad():

@@ -376,7 +358,7 @@ def get_next_token(
         next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
         return next_token

-    def prefill(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor:
+    def prefill(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
         """
         Prefill phase: Process the entire input prompt at once.
         Returns the next token.

@@ -385,7 +367,7 @@ def prefill(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache)
         """
         return self.get_next_token(input_ids, past_key_values)

-    def decode_one_token(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor:
+    def decode_one_token(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
         """
         Decode phase: Generate a single token given the current sequence.
         Returns the next token.

@@ -401,9 +383,7 @@ def decode_one_token(self, input_ids: torch.Tensor, past_key_values: HybridChunk
     # [rank1]: ~^^^^^
     # [rank1]: RuntimeError: Cannot set version_counter for inference tensor
     # @torch.inference_mode()
-    def generate(
-        self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: HybridChunkedCache
-    ) -> dict[str, Any]:
+    def generate(self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: StaticCache) -> dict[str, Any]:
         """
         Generate tokens using separate prefill and decode phases.
         Returns detailed metrics for both phases.

@@ -431,7 +411,7 @@ def generate(
         }

     def measure_inference_step(
-        self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache, max_new_tokens: int
+        self, input_ids: torch.Tensor, past_key_values: StaticCache, max_new_tokens: int
     ) -> dict[str, float]:
         """Measure a single inference step with detailed timing using separate prefill/decode"""
         # Generate tokens with separate prefill/decode tracking

Looking at the error here, I think max_batch_size is required.

Thank you for running it with transformers version 4.55.4! I was running with the latest release. Need to update the requirements pin first before merging this change.
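
Since the arguments StaticCache requires appear to differ between the pinned and the latest transformers releases, one quick way to see what the installed version expects (a sketch, not part of the PR) is to print the constructor signature:

import inspect

import transformers
from transformers.cache_utils import StaticCache

# Print the installed transformers version and the StaticCache constructor
# signature, e.g. to check whether max_batch_size is still a required argument.
print(transformers.__version__)
print(inspect.signature(StaticCache.__init__))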