42 changes: 11 additions & 31 deletions thunder/benchmarks/benchmark_inference.py
@@ -22,17 +22,15 @@
import warnings
from typing import Any
from collections.abc import Callable
from looseversion import LooseVersion

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel, ColwiseParallel
from tqdm import tqdm
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.cache_utils import HybridChunkedCache, StaticCache
from transformers.cache_utils import StaticCache
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
from torch.distributed.tensor.placement_types import Shard
from torch.distributed.tensor import DTensor
@@ -335,36 +333,20 @@ def _load_model(self) -> torch.nn.Module:

return model

def generate_batch(self) -> tuple[torch.Tensor, HybridChunkedCache]:
def generate_batch(self) -> tuple[torch.Tensor, StaticCache]:
"""Generate a batch of input tokens"""
batch_size = self.config.batch_size
input_length = self.config.input_length

input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
if LooseVersion(transformers.__version__) >= LooseVersion("4.55"):
# Transformers deprecated HybridChunkedCache in favour of static in 4.55.x
past_key_values = StaticCache(
config=self.hf_config,
max_batch_size=input_ids.shape[0],

Collaborator:
Looking at the error here, I think max_batch_size is required.

Collaborator Author:
Thank you for running it with transformers version 4.55.4! I was running with the latest release. Need to update the requirements pin first before merging this change.

max_cache_len=input_ids.shape[1] + self.config.output_length,
device=DEVICE,
dtype=torch.bfloat16,

Collaborator:
Also, device and dtype seem necessary -

from transformers.cache_utils import StaticCache
from transformers import AutoConfig, AutoModelForCausalLM
import torch

model_id = "meta-llama/Llama-4-Maverick-17B-128E"
config = AutoConfig.from_pretrained(model_id)

if hasattr(config, "text_config"):
    config = config.text_config

config.num_hidden_layers = 2

past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256)

print(past_key_values.layers[0].keys.dtype)  # torch.float32
print(past_key_values.layers[0].keys.device)  # cpu

past_key_values = StaticCache(config=config, max_batch_size=1, max_cache_len=256, dtype=torch.bfloat16, device="cuda")

print(past_key_values.layers[0].keys.dtype)  # torch.bfloat16
print(past_key_values.layers[0].keys.device)  # cuda:0
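
One way to avoid hard-coding the dtype and device would be to read them off the instantiated model's parameters. This is just an alternative sketch, not what the PR does; it reuses config and the StaticCache import from the snippet above and assumes model is the already-loaded HF model:

# Derive cache placement from the model itself rather than hard-coding bfloat16 / "cuda".
param = next(model.parameters())
past_key_values = StaticCache(
    config=config,
    max_batch_size=1,
    max_cache_len=256,
    device=param.device,
    dtype=param.dtype,
)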

)
else:
past_key_values = HybridChunkedCache(
self.hf_config, input_ids.shape[0], input_ids.shape[1] + self.config.output_length
)
for layer_idx in range(self.hf_config.num_hidden_layers):
# key_states.shape[1] is used to retrieve the number of key value heads, all other dimensions can be 1 and ignored
# https://github.com/huggingface/transformers/blob/9300728665aaeb0ebf4db99f9d9fbce916b4a183/src/transformers/cache_utils.py#L1822
dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)

Collaborator:
We also need to preserve hf_config.num_key_value_heads // WORLD_SIZE for the distributed setting.

The patch can be something like the following:

diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 212f5f8e..13af8175 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -339,9 +339,15 @@ class InferenceBenchmark:
         input_length = self.config.input_length
 
         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
+        import copy
+        hf_config = copy.copy(self.hf_config)
+        hf_config.num_key_value_heads //= WORLD_SIZE
         past_key_values = StaticCache(
-            config=self.hf_config,
+            config=hf_config,
             max_cache_len=input_ids.shape[1] + self.config.output_length,
+            max_batch_size=batch_size,
+            dtype=torch.bfloat16,
+            device=DEVICE,
         )
 
         return input_ids, past_key_values
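
Pulling the suggestions in this thread together, a minimal self-contained sketch of the cache construction could look like the following. This is only an illustration: build_static_cache is a hypothetical helper (not part of the benchmark), WORLD_SIZE and DEVICE stand in for the benchmark's module-level constants, and transformers >= 4.55 is assumed.

import copy

import torch
from transformers.cache_utils import StaticCache

def build_static_cache(hf_config, batch_size, input_length, output_length, world_size, device):
    # Shallow-copy the config so the per-rank key/value head count does not mutate the original config object.
    per_rank_config = copy.copy(hf_config)
    per_rank_config.num_key_value_heads //= world_size
    # max_batch_size, device and dtype are passed explicitly, per the review comments above.
    return StaticCache(
        config=per_rank_config,
        max_batch_size=batch_size,
        max_cache_len=input_length + output_length,
        device=device,
        dtype=torch.bfloat16,
    )
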

past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
past_key_values = StaticCache(
config=self.hf_config,

Collaborator:
Suggested change:
-            config=self.hf_config,
+            config=self.hf_config,
+            max_batch_size=input_ids.shape[0],

max_cache_len=input_ids.shape[1] + self.config.output_length,

Collaborator:
Also device and dtype seem to be required:

RuntimeError: Expected all tensors to be on the same device, but got mat2 is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_bmm)
Suggested change:
-            max_cache_len=input_ids.shape[1] + self.config.output_length,
+            max_cache_len=input_ids.shape[1] + self.config.output_length,
+            device=DEVICE,
+            dtype=torch.bfloat16,
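
As a quick guard against this class of error, the cache placement could be sanity-checked right after construction. This is only a sketch, not part of the PR, and it assumes the transformers >= 4.55 cache layout shown in the earlier example, where each entry of past_key_values.layers exposes key tensors via .keys:

import torch

def check_cache_placement(past_key_values, expected_device_type="cuda", expected_dtype=torch.bfloat16):
    # Fail right after the cache is built instead of deep inside a CUDA bmm during the first forward pass.
    for idx, layer in enumerate(past_key_values.layers):
        assert layer.keys.device.type == expected_device_type, f"layer {idx} keys on {layer.keys.device}"
        assert layer.keys.dtype == expected_dtype, f"layer {idx} keys have dtype {layer.keys.dtype}"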

)

return input_ids, past_key_values

def get_next_token(
self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache | StaticCache
) -> torch.Tensor:
def get_next_token(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
start_pos = past_key_values.get_seq_length()
cache_position = start_pos + torch.arange(0, input_ids.shape[1], device=start_pos.device, dtype=start_pos.dtype)
with torch.no_grad():
@@ -376,7 +358,7 @@ def get_next_token(
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
return next_token

def prefill(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor:
def prefill(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
"""
Prefill phase: Process the entire input prompt at once.
Returns the next token.
@@ -385,7 +367,7 @@ def prefill(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache)
"""
return self.get_next_token(input_ids, past_key_values)

def decode_one_token(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor:
def decode_one_token(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
"""
Decode phase: Generate a single token given the current sequence.
Returns the next token.
@@ -401,9 +383,7 @@ def decode_one_token(self, input_ids: torch.Tensor, past_key_values: HybridChunk
# [rank1]: ~^^^^^
# [rank1]: RuntimeError: Cannot set version_counter for inference tensor
# @torch.inference_mode()
def generate(
self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: HybridChunkedCache
) -> dict[str, Any]:
def generate(self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: StaticCache) -> dict[str, Any]:
"""
Generate tokens using separate prefill and decode phases.
Returns detailed metrics for both phases.
@@ -431,7 +411,7 @@ def generate(
}

def measure_inference_step(
self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache, max_new_tokens: int
self, input_ids: torch.Tensor, past_key_values: StaticCache, max_new_tokens: int
) -> dict[str, float]:
"""Measure a single inference step with detailed timing using separate prefill/decode"""
# Generate tokens with separate prefill/decode tracking