diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index d43533b5a70..c9eecd4cb4c 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -29,6 +29,12 @@ class ForwardOptions(TypedDict, total=False):
     # When provided, the attention layer skips its own K/V projection
     # and reuses the donor's K/V instead.
     shared_kv: Optional[Tuple[torch.Tensor, torch.Tensor]]
+    # Per-call KV cache override. Used by `MultimodalTransformer` when
+    # `transformer_block_repeat_config` repeats a TransformerBlock so that each
+    # *occurrence* of the layer in the schedule writes to its own KV cache
+    # rather than sharing the layer's `self.kv_cache`. When None or absent,
+    # the attention falls back to `self.kv_cache`.
+    kv_cache_override: Optional["KVCache"]


 class Attention(nn.Module, ABC):
@@ -276,7 +282,7 @@ def __init__(
         [0, 1, 2, 3, 4, NA, NA, NA]
         After cache update we would have
         [8, 1, 2, 3, 4, 5, 6, 7]. We kicked out the token at pos = 0. However, the current step still
         has access to [pos - sliding_window_size, pos] tokens.
-        
+
         To make sure we don't over-attend, i.e. we don't have pos = 5 to attend to
         pos = 1, the mask calculation has to account for the sliding window size.
@@ -573,8 +579,12 @@ def forward(
         q, k, v = self._prepare_qkv(q, x, bsz, seqlen, freqs_cos, freqs_sin)

         if self.use_kv_cache:
+            # Per-call KV cache override (used when a TransformerBlock is invoked
+            # multiple times via `transformer_block_repeat_config` so each
+            # occurrence has its own KV cache). Falls back to `self.kv_cache`.
+            active_kv_cache = kwargs.get("kv_cache_override") or self.kv_cache
             assert input_pos is not None
-            is_ring_buffer = getattr(self.kv_cache, "is_ring_buffer", False)
+            is_ring_buffer = getattr(active_kv_cache, "is_ring_buffer", False)

             if is_ring_buffer:
                 # Ring buffer models compute their own mask after KV cache
@@ -594,14 +604,14 @@ def forward(

             # Only update KV cache for non-shared layers
             if shared_kv is None:
-                assert self.kv_cache is not None, (
+                assert active_kv_cache is not None, (
                     "kv_cache is required when shared_kv is not provided. "
                     "This layer may be a YOCO shared layer that requires shared_kv from a donor."
                 )
-                k, v = self.kv_cache.update(input_pos, k, v)
+                k, v = active_kv_cache.update(input_pos, k, v)

             if is_ring_buffer:
-                attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer(
+                attn_mask = active_kv_cache.create_causal_mask_for_ring_buffer(
                     input_pos[0].item(), seqlen
                 )
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index ed661c75517..67e07134cca 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from functools import partial
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 import torch.nn.functional as F

@@ -182,6 +182,12 @@ class ModelArgs:
     use_ffn_learnable_scales: bool = False
     output_soft_cap_temp: Optional[float] = None

+    # Block repetition: repeat contiguous ranges of transformer layers.
+    # List of {"start": int, "end": int, "count": int} dicts where start/end
+    # are layer indices (both inclusive) and count is the total number of passes
+    # (1 = normal, 2 = run the block twice, etc.). Blocks must not overlap.
+    transformer_block_repeat_config: Optional[list] = None
+
     def __post_init__(self):  # noqa: C901
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
@@ -224,3 +230,50 @@ def find_multiple(n: int, k: int) -> int:
         # Convert string act_fn to enum if needed
         if isinstance(self.act_fn, str):
             self.act_fn = ActFn.from_string(self.act_fn)
+
+        self.validate_block_repeat_config()
+
+    def validate_block_repeat_config(self) -> None:
+        """Validate transformer_block_repeat_config field.
+
+        Called from __post_init__ and should also be called after setting
+        transformer_block_repeat_config post-construction.
+        """
+        if self.transformer_block_repeat_config is None:
+            return
+        for i, block in enumerate(self.transformer_block_repeat_config):
+            assert (
+                "start" in block and "end" in block and "count" in block
+            ), f"transformer_block_repeat_config[{i}] must have 'start', 'end', and 'count' keys"
+            assert 0 <= block["start"] <= block["end"] < self.n_layers, (
+                f"transformer_block_repeat_config[{i}]: invalid range [{block['start']}, {block['end']}] "
+                f"for {self.n_layers} layers"
+            )
+            assert (
+                block["count"] >= 1
+            ), f"transformer_block_repeat_config[{i}]: count must be >= 1"
+        # Check for overlapping blocks (end is inclusive, so next start must be > prev end)
+        sorted_blocks = sorted(
+            self.transformer_block_repeat_config, key=lambda b: b["start"]
+        )
+        for i in range(1, len(sorted_blocks)):
+            assert sorted_blocks[i]["start"] > sorted_blocks[i - 1]["end"], (
+                f"transformer_block_repeat_config: blocks {sorted_blocks[i-1]} and "
+                f"{sorted_blocks[i]} overlap"
+            )
+
+    @staticmethod
+    def normalize_block_repeat_config(
+        config: Optional[List[Dict[str, int]]],
+    ) -> Optional[List[Dict[str, int]]]:
+        """Drop entries with `count == 1`; return None if nothing remains.
+
+        A block-repeat entry with count=1 visits its layers exactly once --
+        the same as if the entry were omitted. Stripping these no-ops at
+        assignment time lets every downstream consumer assume each entry is
+        a genuine repeat (count > 1). Pure function: does not mutate input.
+        """
+        if not config:
+            return None
+        normalized = [b for b in config if b.get("count", 1) > 1]
+        return normalized if normalized else None
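
Illustration (not part of the patch): how the per-call `kv_cache_override` added to `ForwardOptions` is meant to behave. The `FakeKVCache` and `FakeAttention` stand-ins below are hypothetical and exist only to keep the sketch self-contained; the only pieces taken from the patch are the "kv_cache_override" kwarg name and the `kwargs.get("kv_cache_override") or self.kv_cache` fallback.

# Minimal, hedged sketch of the override semantics with toy stand-ins.
from typing import Any


class FakeKVCache:
    """Hypothetical stand-in for the llama KVCache; records a label only."""

    def __init__(self, name: str) -> None:
        self.name = name


class FakeAttention:
    """Hypothetical stand-in for AttentionMHA, reduced to cache selection."""

    def __init__(self) -> None:
        self.kv_cache = FakeKVCache("layer-owned")  # the layer's usual cache

    def forward(self, x: Any, **kwargs: Any) -> str:
        # Same selection rule as in the patch: a per-call override wins;
        # None or an absent key falls back to the layer's own cache.
        active_kv_cache = kwargs.get("kv_cache_override") or self.kv_cache
        return active_kv_cache.name


attn = FakeAttention()
assert attn.forward("tok") == "layer-owned"  # normal single-pass behavior
second_pass_cache = FakeKVCache("occurrence-1")
assert attn.forward("tok", kv_cache_override=second_pass_cache) == "occurrence-1"

The point of the `or` fallback is that existing single-pass callers need no changes: only a repeat-aware caller (per the comment, `MultimodalTransformer`) has to allocate one cache per occurrence and pass it in.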
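
Illustration (not part of the patch): the intended semantics of `transformer_block_repeat_config`. The `expand_layer_schedule` helper below is hypothetical (the real consumer is `MultimodalTransformer`, which this patch does not touch); it only shows how a list of {"start", "end", "count"} dicts expands into a per-pass layer schedule, assuming the entries have already passed `validate_block_repeat_config`.

# Hedged sketch: expand a validated, non-overlapping repeat config into the
# ordered list of layer indices to execute. Hypothetical helper, not from the patch.
from typing import Dict, List, Optional


def expand_layer_schedule(
    n_layers: int, config: Optional[List[Dict[str, int]]]
) -> List[int]:
    schedule: List[int] = []
    next_layer = 0
    for block in sorted(config or [], key=lambda b: b["start"]):
        schedule.extend(range(next_layer, block["start"]))  # layers before the block
        for _ in range(block["count"]):  # one pass per count, start/end inclusive
            schedule.extend(range(block["start"], block["end"] + 1))
        next_layer = block["end"] + 1
    schedule.extend(range(next_layer, n_layers))  # layers after the last block
    return schedule


# With 6 layers and one block repeated twice, layers 2-3 are visited twice:
assert expand_layer_schedule(6, [{"start": 2, "end": 3, "count": 2}]) == [
    0, 1, 2, 3, 2, 3, 4, 5,
]
# A count == 1 entry is a no-op; ModelArgs.normalize_block_repeat_config strips
# such entries (returning None if nothing remains) before any consumer sees them.

This also shows why each occurrence of a repeated layer needs its own KV cache (the `kv_cache_override` mechanism above): layer 2's second pass would otherwise overwrite the K/V it wrote during its first pass.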