diff --git a/src/csrc/kernels/attention.metal b/src/csrc/kernels/attention.metal new file mode 100644 index 00000000..e69de29b diff --git a/src/csrc/kv_cache.cpp b/src/csrc/kv_cache.cpp new file mode 100644 index 00000000..56e4a3fe --- /dev/null +++ b/src/csrc/kv_cache.cpp @@ -0,0 +1,427 @@ + + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace py = pybind11; + +// Paged and Sparse KV Cache + +constexpr size_t BLOCK_SIZE = 128; // tokens +constexpr size_t MAX_BLOCKS = 2 << 14; + +// Lock-free Treiber list implementation +// From : https://people.csail.mit.edu/shanir/publications/Lock_Free.pdf +struct Node { + int block_id = -1; + std::atomic next{nullptr}; +}; + +struct KVBlock { + int physical_idx = -1; +}; + +// Use large single buffers for performance +class PageAllocator { + private: + MTL::Device* m_device; + MTL::Buffer* k_cache; + MTL::Buffer* v_cache; + std::vector node_pool; + std::atomic free_node_idx{-1}; + std::atomic HEAD{nullptr}; + size_t bytesize = 0; + size_t nr_kv_heads = 0; + size_t block_size = 0; + size_t head_dim = 0; + + public: + PageAllocator(size_t nr_heads, size_t kv_bytes, size_t head_dim) + : nr_kv_heads(nr_heads), head_dim(head_dim), + bytesize(kv_bytes) + { + m_device = MTL::CreateSystemDefaultDevice(); + if(!m_device) { + throw std::runtime_error("Failed to create Metal device."); + } + k_cache = m_device->newBuffer(MAX_BLOCKS*bytesize, MTL::ResourceStorageModeShared); + v_cache = m_device->newBuffer(MAX_BLOCKS*bytesize, MTL::ResourceStorageModeShared); + + // node pool + node_pool.reserve(nr_heads * MAX_BLOCKS); + for(int i=0; i< nr_heads*MAX_BLOCKS; ++i) { + node_pool.emplace_back(new Node()); + } + free_node_idx.store(nr_heads*MAX_BLOCKS - 1); + + HEAD.store(nullptr, std::memory_order_relaxed); + for(int i=0; i< nr_heads*MAX_BLOCKS; ++i) { + push(i); + } + } + + ~PageAllocator() { + k_cache->release(); + v_cache->release(); + for(auto* node: node_pool) { + delete node; + } + } + + int allocate_block() { + Node* old_head; + do { + old_head = HEAD.load(std::memory_order_acquire); + if(!old_head) { + throw std::runtime_error("No free blocks"); + } + } while(!HEAD.compare_exchange_weak(old_head, + old_head->next.load(std::memory_order_relaxed), + std::memory_order_release)); + int block = old_head->block_id; + recycle_node(old_head); + return block; + } + + int free_block(int block) { + Node* node = get_node_from_pool(); + if(!node) { + throw std::runtime_error("Node pool exhausted"); + } + node->block_id = block; + Node* old_head; + do { + old_head = HEAD.load(std::memory_order_acquire); + node->next.store(old_head, std::memory_order_relaxed); + } while(!HEAD.compare_exchange_weak(old_head, node, std::memory_order_release)); + } + + MTL::Device* device() { return m_device; } + MTL::Buffer* get_k_cache() { return k_cache; } + MTL::Buffer* get_v_cache() { return v_cache; } + + // Debug helpers + std::vector get_block(bool key, size_t b, size_t h, size_t bid) { + //std::lock_guard lock(mux); + MTL::Buffer* cache = key ? 
k_cache : v_cache; + size_t offset = (b*nr_kv_heads*MAX_BLOCKS + h*MAX_BLOCKS + bid); + simd_float16* data = reinterpret_cast(cache->contents()) + offset; + return std::vector(data, data+block_size); + } + + private: + Node* get_node_from_pool() { + int idx = free_node_idx.fetch_sub(1, std::memory_order_relaxed); + if(idx<0) { + free_node_idx.fetch_add(1, std::memory_order_relaxed); + return nullptr; + } + return node_pool[idx]; + } + + void recycle_node(Node* node) { + int idx = free_node_idx.fetch_add(1, std::memory_order_relaxed) + 1; + if(idx > (int)MAX_BLOCKS) { + std::cerr << "Node pool overflow." << std::endl; + return; + } + node_pool[idx] = node; + node->next.store(nullptr, std::memory_order_relaxed); + } + + // Grab new block and move HEAD + void push(int block) { + Node* n_node = get_node_from_pool(); + n_node->block_id = block; + Node* old_head; + do { + old_head = HEAD.load(std::memory_order_acquire); + n_node->next.store(old_head, std::memory_order_relaxed); + } while(!HEAD.compare_exchange_weak(old_head, n_node, + std::memory_order_release, + std::memory_order_relaxed)); + } +}; + +class SparseKVCache { + private: + std::mutex mux; + PageAllocator* alloc; + MTL::CommandQueue* queue; + std::vector>> page_table; // [B, head, block_id] + MTL::Buffer* page_table_buffer = nullptr; + size_t bytesize = 0; // item byte size (default f16 - 4) + std::vector seq_len; // Sequence lengths of each batch + std::vector sq_offs; // Sequence offsets + size_t batch_size = 0; + size_t block_size = 0; + size_t num_kv_heads = 0; + size_t head_dim = 0; + size_t max_blocks = MAX_BLOCKS; + + public: + SparseKVCache(PageAllocator* allocator, size_t bytes, size_t block_size, + size_t batch_size, size_t head_size, size_t num_heads) + : alloc(allocator), batch_size(batch_size), block_size(block_size), + head_dim(head_size), num_kv_heads(num_heads), + seq_len(batch_size, 0), bytesize(bytes), + page_table(batch_size, std::vector>(num_heads)) + { + queue = alloc->device()->newCommandQueue(); + seq_len.reserve(batch_size); + sq_offs.reserve(batch_size); + + for(int b=0; ballocate_block(); + if(n_blk.physical_idx == -1) { + throw std::runtime_error("Failed to allocate new physical block."); + } + pt.push_back(n_blk); + } + } + + // Non-blocking append function + void append(MTL::Buffer* new_k, + MTL::Buffer* new_v, + std::vector num_new_tokens) + { + std::lock_guard lock(mux); + + if(num_new_tokens.size() != batch_size) { + throw std::runtime_error("New token counts don't match known batch size."); + } + + MTL::CommandBuffer* cmd_buf = queue->commandBuffer(); + MTL::BlitCommandEncoder* blit = cmd_buf->blitCommandEncoder(); + + // new_k, new_v of shape [B, 1, Hkv, D] + for (size_t b=0; bcopyFromBuffer(new_k, src_off, alloc->get_k_cache(), dst_off, cpy_bytes); + blit->copyFromBuffer(new_v, src_off, alloc->get_v_cache(), dst_off, cpy_bytes); + + // Set block metadata + size_t current_tok = seq_len[b]; + size_t remaining = num_new_tokens[b]; + size_t pos = 0; + while(remaining > 0) { + size_t num_blocks = current_tok / BLOCK_SIZE; + size_t block_offset = current_tok % BLOCK_SIZE; + size_t to_copy = std::min(remaining, BLOCK_SIZE - block_offset); + for(size_t h=0; hendEncoding(); + cmd_buf->commit(); + } + + // Pack working physical block indices according to sparsity mask + void pack_table_buffer(MTL::Buffer* page_table_buff, MTL::Buffer* mask) { + std::lock_guard lock(mux); + + size_t max_blocks = 0; + for(size_t b=0; blength() < batch_size*num_kv_heads*max_blocks*sizeof(unsigned char)) { + throw 
std::runtime_error("Sparsity mask is too small for KV Cache."); + } + + unsigned char* mask_data = reinterpret_cast(mask->contents()); + + // Temp buffer that stores all indices + offset and count metadata + // of blocks that are active in the sparse_mask + + // [total_blocks * sizeof(int)] indices of active physical blocks + // [batch_size * num_kv_heads * sizeof(unsigned int)] offsets (starting point of indices array) + // [batch_size * num_kv_heads * sizeof(unsigned int)] counts (number of active blocks for b and h) + size_t max_buffer_size = batch_size*num_kv_heads*max_blocks*sizeof(int) + 2*batch_size*num_kv_heads*sizeof(unsigned int); + MTL::Buffer* temp = alloc->device()->newBuffer(max_buffer_size, MTL::ResourceStorageModeShared); + int* temp_idx = reinterpret_cast(temp->contents()); + unsigned int* temp_off = reinterpret_cast(temp_idx + batch_size * num_kv_heads * max_blocks); + unsigned int* temp_counts = temp_off + batch_size*num_kv_heads; + + size_t total_blocks = 0; + std::vector block_counts(batch_size*num_kv_heads, 0); + for(size_t b=0; blength() != real_buffer_size) { + page_table_buff->release(); + page_table_buff = nullptr; + } + if(!page_table_buff) { + page_table_buff = alloc->device()->newBuffer(real_buffer_size, MTL::ResourceStorageModeShared); + } + + // Copy to final buffer + MTL::CommandBuffer* cmd_buf = queue->commandBuffer(); + MTL::BlitCommandEncoder* blit = cmd_buf->blitCommandEncoder(); + blit->copyFromBuffer(temp, 0, page_table_buff, 0, real_buffer_size); + blit->endEncoding(); + cmd_buf->commit(); + cmd_buf->waitUntilCompleted(); + temp->release(); + + page_table_buffer = page_table_buff; + } + + std::pair + read(MTL::Buffer* sparse_mask) { + std::lock_guard lock(mux); + + MTL::Buffer* sparse_block_idxs; + pack_table_buffer(sparse_block_idxs, sparse_mask); + int* temp_indices = reinterpret_cast(sparse_block_idxs->contents()); + unsigned int* temp_off = reinterpret_cast(temp_indices + batch_size*num_kv_heads*max_blocks); + unsigned int* temp_counts = temp_off + batch_size*num_kv_heads; + + // Get output sizes + size_t max_tokens = 0; + std::vector b_sizes(batch_size, 0); + for(size_t b=0; bdevice()->newBuffer(max_size, MTL::ResourceStorageModeShared); + MTL::Buffer* v_out = alloc->device()->newBuffer(max_size, MTL::ResourceStorageModeShared); + + MTL::CommandBuffer* cmd_buf = queue->commandBuffer(); + MTL::BlitCommandEncoder* blit = cmd_buf->blitCommandEncoder(); + + // Gather values from indices + for(size_t b=0; bcopyFromBuffer(alloc->get_k_cache(), src_offset, k_out, dst_offset, copy_size); + blit->copyFromBuffer(alloc->get_v_cache(), src_offset, v_out, dst_offset, copy_size); + out_pos += tokens_in_block; + } + } + } + blit->endEncoding(); + cmd_buf->commit(); + cmd_buf->waitUntilCompleted(); + sparse_block_idxs->release(); + return {k_out, v_out}; + } + + // LRU block eviction from sparse mask history + void evict_block(size_t b, size_t h, size_t block) {} + + // Debug helpers + std::vector* get_seq_offsets() { return &sq_offs; } + std::vector* get_seq_len() { return &seq_len; } + + std::vector>> get_page_table() { + std::vector>> ret(batch_size, std::vector>(num_kv_heads)); + for(size_t b=0; b get_data(bool key, size_t b, size_t h, size_t block) { + if(b >= batch_size || h >= num_kv_heads || block >= MAX_BLOCKS) { + throw std::runtime_error("Invalid request for 'get_block': {b}, {h}, {block}"); + } + return alloc->get_block(key, b, h, block); + } +}; + + +// Bind to python +PYBIND11_MODULE(kv_cache, m) { + py::class_>(m, "PageAllocator") + 
.def(py::init()); + + py::class_>(m, "SparseKVCache") + .def(py::init()) + .def("append", &SparseKVCache::append) + .def("pack_table_buffer", &SparseKVCache::pack_table_buffer) + .def("read", &SparseKVCache::read) + .def("get_seq_len", &SparseKVCache::get_seq_len) + .def("get_seq_offset", &SparseKVCache::get_seq_offsets) + .def("get_page_table", &SparseKVCache::get_page_table) + .def("data", &SparseKVCache::get_data); +} diff --git a/src/dnet/ring/model/qwen3_sparse.py b/src/dnet/ring/model/qwen3_sparse.py new file mode 100644 index 00000000..f4655248 --- /dev/null +++ b/src/dnet/ring/model/qwen3_sparse.py @@ -0,0 +1,442 @@ + + +from typing import Any, Dict, List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.qwen3 import ModelArgs, MLP +from src.runtime.sparse_attention import SparseAttention, FlexPrefillSparseAttention + +# Sparse attention transformer block +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.hidden_size = args.hidden_size + self.num_attention_heads = args.num_attention_heads + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.self_attn = SparseAttention(args, FlexPrefillSparseAttention()) + + def __call__( + self, + x:mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(x)) + return h + r + + +from .base import BaseRingModel +import logging + +logger = logging.getLogger(__name__) + + +class Qwen3RingModel(BaseRingModel): + model_type = "qwen3" + + def __init__( + self, + model_config: Any, + assigned_layers: Optional[List[int]] = None, + is_api_layer: bool = False, + ): + super().__init__() + + if is_api_layer and assigned_layers: + raise RuntimeError("API layer doesn't handle layers") + + self.model_config = model_config + self.is_api_layer = is_api_layer + self.config = config = ModelArgs.from_dict(model_config) + + logger.info( + f"Initializing Qwen3RingModel: is_api_layer={is_api_layer}, assigned_layers={assigned_layers}" + ) + logger.info( + f"Config: hidden_size={config.hidden_size}, num_heads={config.num_attention_heads}, num_kv_heads={config.num_key_value_heads}" + ) + + # The API layer handles embedding and normalization + if is_api_layer: + # Start with regular Embedding, will be converted if needed when loading quantized weights + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # Qwen3 can tie embeddings; add head only if not tied + if not config.tie_word_embeddings: + self.lm_head = nn.Linear( + config.hidden_size, config.vocab_size, bias=False + ) + + # Other layers: local zero-based list and abs->local map + self.layers: List[nn.Module] = [] + self.abs_to_local: Dict[int, int] = {} + + # Check if model is quantized from config + self.is_quantized = "quantization" in model_config + if self.is_quantized: + self.quantization_config = model_config["quantization"] + logger.info(f"Model is quantized with config: {self.quantization_config}") + + # For now, create regular TransformerBlocks + # They will be converted to quantized on first weight load if needed + for i, layer in 
enumerate(sorted(assigned_layers or [])): + self.layers.append(TransformerBlock(config)) + self.abs_to_local[layer] = i + + # For shard layers (non-API), enable quantization upfront if configured. + # This ensures QuantizedLinear modules are in place before loading + # per-layer int weights/scales from the weight cache. + if not is_api_layer and self.is_quantized: + logger.info("Applying quantization for shard layers") + self.apply_quantization() + elif not is_api_layer: + logger.info("Not applying quantization - model is not quantized") + + logger.info(f"Created {len(self.layers)} sparse TransformerBlocks layers") + #logger.info(f"Created {len(self.layers)} TransformerBlock layers") + logger.info(f"abs_to_local mapping: {self.abs_to_local}") + + # Flag to track if layers have been converted to quantized + self._converted_to_quantized = False + + @staticmethod + def class_predicate(p, m): + return hasattr(m, "to_quantized") + + def apply_quantization(self): + """Apply quantization after weights are loaded""" + # Skip if this is the API layer + if self.is_api_layer: + return + + # Check if model is already quantized by checking if any Linear layers are QuantizedLinear + from mlx.nn.layers.quantized import QuantizedLinear + + for layer in self.layers: + for module in layer.modules(): + if isinstance(module, QuantizedLinear): + # Model is already quantized, skip re-quantization + return + + # Only quantize if not already quantized and quantization config exists + if "quantization" in self.model_config: + quant_config = self.model_config["quantization"].copy() + nn.quantize( + self, + **quant_config, + class_predicate=Qwen3RingModel.class_predicate, + ) + + def embed(self, x: mx.array) -> mx.array: + return self.embed_tokens(x) if self.is_api_layer else x + + def normalize(self, x: mx.array) -> mx.array: + return self.norm(x) if self.is_api_layer else x + + def lm_project(self, x: mx.array) -> mx.array: + if self.is_api_layer: + if self.config.tie_word_embeddings: + # For tied embeddings, use embed_tokens as linear projection + return self.embed_tokens.as_linear(x) + return self.lm_head(x) + return x + + def forward(self, x: mx.array, cache: Optional[List[Any]] = None) -> mx.array: + # Create attention mask + mask = create_attention_mask(x, cache) + + if cache is None: + cache = [None] * len(self.layers) + + # Apply in local order 0..len-1 + for i, layer in enumerate(self.layers): + x = layer(x, mask, cache[i] if i < len(cache) else None) + + return x + + def apply_single_layer( + self, layer_idx: int, x: mx.array, cache: Optional[List[Any]] = None + ) -> mx.array: + if layer_idx not in self.abs_to_local: + raise RuntimeError(f"Layer {layer_idx} not hosted on this model instance") + + mask = create_attention_mask(x, cache) + local_idx = self.abs_to_local[layer_idx] + + logger.info( + f"apply_single_layer: layer_idx={layer_idx}, local_idx={local_idx}, input shape={x.shape}" + ) + + # Log the layer's weight shapes + layer = self.layers[local_idx] + if hasattr(layer, "self_attn"): + if hasattr(layer.self_attn, "q_proj"): + if hasattr(layer.self_attn.q_proj, "weight"): + logger.info( + f"Layer {layer_idx} q_proj weight shape: {layer.self_attn.q_proj.weight.shape}" + ) + else: + logger.info(f"Layer {layer_idx} q_proj has no weight attribute") + if hasattr(layer.self_attn, "k_proj"): + if hasattr(layer.self_attn.k_proj, "weight"): + logger.info( + f"Layer {layer_idx} k_proj weight shape: {layer.self_attn.k_proj.weight.shape}" + ) + else: + logger.info(f"Layer {layer_idx} k_proj has no weight 
attribute") + + c = None + if cache is not None and local_idx < len(cache): + c = cache[local_idx] + + result = self.layers[local_idx](x, mask, c) + logger.info(f"Layer {layer_idx} output shape: {result.shape}") + return result + + def _convert_layers_to_quantized(self, weights): + """Convert Linear layers to QuantizedLinear based on weight structure""" + if self._converted_to_quantized: + return + + # Check if weights are quantized by looking for .scales and .biases + weight_keys = [k for k, _ in weights] + has_scales = any(".scales" in k for k in weight_keys) + has_biases = any(".biases" in k for k in weight_keys) + + if not (has_scales and has_biases): + logger.info("Weights are not quantized, keeping Linear layers") + return + + logger.info( + "Detected quantized weights, converting layers to quantized versions" + ) + + from mlx.nn.layers.quantized import QuantizedLinear, QuantizedEmbedding + + # Infer quantization parameters from weight shapes + # Default to common values if not in config + group_size = 64 + bits = 8 + if hasattr(self, "quantization_config"): + group_size = self.quantization_config.get("group_size", 64) + bits = self.quantization_config.get("bits", 8) + + logger.info(f"Using quantization: bits={bits}, group_size={group_size}") + + # Convert embedding layer for API layer + if self.is_api_layer and hasattr(self, "embed_tokens"): + if has_scales and "embed_tokens.scales" in weight_keys: + # Get the actual dimensions from the quantized weights + embed_weight = next( + (v for k, v in weights if k == "embed_tokens.weight"), None + ) + if embed_weight is not None: + vocab_size = embed_weight.shape[0] + # The quantized weight shape is (vocab_size, compressed_dim) + # We need to infer the original hidden_size + hidden_size = self.config.hidden_size + + # Use the same group_size as the quantized weights + # Infer from scales shape: scales.shape[1] * group_size = hidden_size + embed_scales = next( + (v for k, v in weights if k == "embed_tokens.scales"), None + ) + if embed_scales is not None: + num_groups = embed_scales.shape[1] + embed_group_size = hidden_size // num_groups + else: + embed_group_size = group_size + + # Replace with QuantizedEmbedding + self.embed_tokens = QuantizedEmbedding( + vocab_size, hidden_size, group_size=embed_group_size, bits=bits + ) + logger.info( + f"Converted embed_tokens to QuantizedEmbedding: {vocab_size}x{hidden_size}, group_size={embed_group_size}" + ) + + # Convert each layer's Linear modules to QuantizedLinear + for layer_idx, layer in enumerate(self.layers): + if hasattr(layer, "self_attn"): + # Convert attention layers + for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]: + if hasattr(layer.self_attn, proj_name): + linear = getattr(layer.self_attn, proj_name) + if isinstance(linear, nn.Linear) and not isinstance( + linear, QuantizedLinear + ): + # Create QuantizedLinear with same dimensions + in_features = linear.weight.shape[1] + out_features = linear.weight.shape[0] + ql = QuantizedLinear( + in_features, + out_features, + bias=False, + group_size=group_size, + bits=bits, + ) + setattr(layer.self_attn, proj_name, ql) + logger.debug( + f"Converted layer {layer_idx} self_attn.{proj_name} to QuantizedLinear" + ) + + if hasattr(layer, "mlp"): + # Convert MLP layers + for proj_name in ["gate_proj", "up_proj", "down_proj"]: + if hasattr(layer.mlp, proj_name): + linear = getattr(layer.mlp, proj_name) + if isinstance(linear, nn.Linear) and not isinstance( + linear, QuantizedLinear + ): + in_features = linear.weight.shape[1] + out_features = 
linear.weight.shape[0] + ql = QuantizedLinear( + in_features, + out_features, + bias=False, + group_size=group_size, + bits=bits, + ) + setattr(layer.mlp, proj_name, ql) + logger.debug( + f"Converted layer {layer_idx} mlp.{proj_name} to QuantizedLinear" + ) + + self._converted_to_quantized = True + logger.info("Successfully converted all Linear layers to QuantizedLinear") + + def load_weights(self, weights, strict=False): + """Load weights into the model""" + logger.info(f"load_weights called with {len(weights)} weights") + logger.info(f"First few weight keys: {[k for k, _ in weights[:5]]}") + logger.info(f"abs_to_local mapping: {self.abs_to_local}") + + # Convert layers to quantized if loading quantized weights + self._convert_layers_to_quantized(weights) + + # Filter weights to only include what this shard needs + shard_weights = {} + + for key, value in weights: + # Accept both bare 'layers.*' and 'model.layers.*' and remap abs->local + if key.startswith("model.layers.") or key.startswith("layers."): + parts = key.split(".") + idx_pos = 2 if parts[0] == "model" else 1 + try: + abs_idx = int(parts[idx_pos]) + except Exception: + continue + if abs_idx not in self.abs_to_local: + # Skip layers not assigned to this shard + continue + local_idx = self.abs_to_local[abs_idx] + # Keep the "layers" prefix and remap index + parts[idx_pos] = str(local_idx) + # Remove "model." prefix if present + if parts[0] == "model": + parts = parts[1:] + new_key = ".".join(parts) + logger.debug(f"Mapping weight {key} (shape {value.shape}) -> {new_key}") + shard_weights[new_key] = value + elif self.is_api_layer: + # API layer needs embed_tokens, norm, and lm_head + # Weights come as "embed_tokens.weight", "norm.weight", etc. + if key.startswith("embed_tokens"): + shard_weights[key] = value + logger.info(f"API layer: loading {key}, shape={value.shape}") + elif key.startswith("norm"): + shard_weights[key] = value + logger.info(f"API layer: loading {key}, shape={value.shape}") + elif key.startswith("lm_head") and not self.config.tie_word_embeddings: + shard_weights[key] = value + logger.info(f"API layer: loading {key}, shape={value.shape}") + + logger.info(f"Loading {len(shard_weights)} weights into model") + + if shard_weights: + # Log the first weight being loaded to check dimensions + first_key = list(shard_weights.keys())[0] + logger.info( + f"First weight to load: {first_key} with shape {shard_weights[first_key].shape}" + ) + + # Check what layer this is for + if "layers." in first_key: + layer_idx = first_key.split(".")[1] + logger.info(f"Loading into local layer {layer_idx}") + logger.info(f"Number of layers in model: {len(self.layers)}") + if int(layer_idx) < len(self.layers): + layer = self.layers[int(layer_idx)] + # Log the current layer structure + if hasattr(layer, "self_attn"): + if hasattr(layer.self_attn, "q_proj"): + logger.info( + f"Current q_proj type: {type(layer.self_attn.q_proj)}" + ) + if hasattr(layer.self_attn.q_proj, "weight"): + logger.info( + f"Current q_proj weight shape: {layer.self_attn.q_proj.weight.shape}" + ) + + # Load the filtered weights using parent class method + try: + super().load_weights(list(shard_weights.items()), strict=strict) + logger.info("Successfully loaded weights") + + # Verify weights were actually loaded (not just shape but values) + if shard_weights and "layers." 
in first_key: + layer_idx = first_key.split(".")[1] + if int(layer_idx) < len(self.layers): + layer = self.layers[int(layer_idx)] + if hasattr(layer, "self_attn") and hasattr( + layer.self_attn, "q_proj" + ): + if hasattr(layer.self_attn.q_proj, "weight"): + weight = layer.self_attn.q_proj.weight + logger.info( + f"After loading - q_proj weight stats: shape={weight.shape}, mean={mx.mean(weight).item():.6f}, std={mx.std(weight).item():.6f}" + ) + # Check if weights are reasonable (not all zeros or random) + if ( + mx.abs(mx.mean(weight)).item() < 1e-6 + and mx.std(weight).item() < 1e-6 + ): + logger.warning( + "WARNING: q_proj weights appear to be all zeros!" + ) + elif mx.std(weight).item() > 1.0: + logger.warning( + "WARNING: q_proj weights have very high std dev, might be uninitialized!" + ) + except Exception as e: + logger.error(f"Failed to load weights: {e}") + logger.error(f"Weight keys: {list(shard_weights.keys())}") + raise + + # Don't apply quantization for pre-quantized models + # Pre-quantized models already have QuantizedLinear layers from weight loading + + @property + def decoding_layers(self): + return self.layers + + @property + def head_dim(self) -> Tuple[int, int]: + # Qwen3 uses the same head_dim for both Q and V + return (self.config.head_dim, self.config.head_dim) + + @property + def n_kv_heads(self) -> int: + return self.config.num_key_value_heads + + @property + def num_layers(self) -> int: + # Number of local decoding layers hosted on this shard model + return len(self.layers) diff --git a/src/dnet/runtime/sparse_attention.py b/src/dnet/runtime/sparse_attention.py new file mode 100644 index 00000000..fb43d026 --- /dev/null +++ b/src/dnet/runtime/sparse_attention.py @@ -0,0 +1,719 @@ + +import time +import math +import inspect +import logging +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union +#from ..util import logger + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.qwen3 import ModelArgs +from mlx_lm.models.rope_utils import initialize_rope +from mlx_lm.models.base import scaled_dot_product_attention + +logger = logging.getLogger(__name__) + +BLOCK_SIZE = 128 +#REDUCTION_KERNEL_SIZE = 2 + +# NOTE Look into AdaFlash for adaptive fine-grained + +# NOTE: Rely on mx.matmul for dense tile sgemm for now, until we have metal kernels +# Compute sparsity of packed matrices per block +def sparse_dot_product_attention_blocked( + Q: mx.array, # (B, Hq, Lq, D) + K_sel: mx.array, # (B, Hkv, Lk_sel_max, D) (packed) + V_sel: mx.array, # (B, Hkv, Lk_sel_max, D) (packed) + scale: float, + mask: Optional[mx.array] = None, # (B, Hq, Lq, Lk_sel_max) + padding_mask: Optional[mx.array] = None, # (B, Hq, 1, 1) + selected_counts: Optional[mx.array] = None, + group_size: int = 64, + bits: int = 8 +) -> mx.array: + + K = K_sel + V = V_sel + B, n_q_heads, L, D = Q.shape + n_kv_heads = K.shape[-3] + n_repeats = n_q_heads // n_kv_heads + + #print(Q.shape, K.shape, V.shape) + Q *= scale + + if n_repeats > 1: + Q = mx.reshape(Q, (B, n_kv_heads, n_repeats, L, D)) + else: + Q = mx.expand_dims(K, axis=-3) + K = mx.expand_dims(K, axis=-3) + V = mx.expand_dims(V, axis=-3) + + scores = mx.matmul(Q, K.transpose(0, 1, 2, 4, 3)) + #print(scores) + if mask is not None: + """ # Use for running non-sparse attention + if padding_mask is not None: + padding_mask = padding_mask.reshape(1, -1) + padding_mask = mx.repeat(padding_mask, L, axis=0) + scores = mx.where(padding_mask, scores, mx.finfo(scores.dtype).min) + """ + + Lq, Lkv = scores.shape[-2:] + if Lq > 
1: + q_idx = mx.arange(Lq) + k_idx = mx.arange(Lkv) + mask = q_idx.reshape(-1, 1) >= k_idx.reshape(1, -1) + #if mask.ndim > 2: + # mask = mask.reshape(B, n_kv_heads, 1, mask.shape[2], mask.shape[3]) + scores = mx.where(mask, scores, mx.finfo(scores.dtype).min) + + + #print(mask) + scores = mx.softmax(scores, axis=-1, precise=True) + out = mx.matmul(scores, V) + + if n_repeats > 1: + out = mx.reshape(out, (B, n_q_heads, L, D)) + + return out.transpose(0, 2, 1, 3).reshape(B, L, -1) + + +@dataclass +class StrategyInput: + Q: mx.array + K: mx.array + mask: Optional[mx.array] + block_size: int # pooling block size + gqa_interleave: bool = False # mapping style + last_q: Optional[mx.array] = None # (B, block_size, Hq, D) + gamma: float = 0.9 # top-p value [0, 1] + min_budget: int = 1 # min blocks nr + max_budget: int = 2147483647 # max blocks nr + tau: float = 0.0 # JSD threashold + update_frequency: int = 10 # # cycles before we refresh JSD metrics + is_token_level: bool = False # Token-level sparsity indices + + +@dataclass +class Strategy: + blocks: mx.array + tokens: mx.array + block_size: int + selected_counts: Optional[mx.array] + min_budget: int + max_budget: int + update_frequency: int = 10 + # TODO: Debug strategy metadata + # Potentially create different budgets? + + +class AbstractSparseStrategy: + """ Strategies decide the block indices of attention we compute """ + + def __init__(self, strc: Dict[str, Any] = None) -> None: + pass + + def reset(self) -> None: + pass + + def __call__(self, input: StrategyInput) -> Strategy: + pass + + def update(self, input: StrategyInput, prev_blocks: mx.array, prev_Lkv: int) -> Strategy: + """ Update the indices every decode loop for new tokens """ + pass + + +# https://arxiv.org/abs/2502.20766 +class FlexPrefillSparseAttention(AbstractSparseStrategy): + def __init__(self, strc: Dict[str, Any] = None): + super().__init__() + self.blocks = None + self.score_cache = None # for incremental update + self.pattern_type = None # Query-Aware (0) or Vertical-Slash (1) + + def reset(self): + self.blocks = None + self.score_cache = None + self.pattern_type = None + + def kullback_leibler(self, d0, d1): + return mx.sum(d0 * mx.log2(d0 / d1)) + + def jensen_shannon_metric(self, d0, d1): + median = 0.5*(d0 + d1) + div = 0.5*self.kullback_leibler(d0, median) + 0.5*self.kullback_leibler(d1, median) + return div**0.5 + + def top_p(self, x, p:float): + cum_probs = mx.cumsum(x, axis=1) + top = mx.where(cum_probs > 1-p, x, mx.zeros_like(x)) + return top + + # Q: (B, Lq, Hkv, D) + # K/V: (B, Lkv, Hkv, D) + # Computes a low-resolution map of sparse indices depending on the input query + def query_aware_search(self, Q, K, gamma, head_dim, min_budget, max_budget, prev_scores=None, prev_Lkv=None): + assert min_budget >= 1 and max_budget >= 1, "budgets must be at least 1" + assert min_budget <= max_budget, "min_budget must be smaller then max_budget" + pool = nn.AvgPool1d(kernel_size=BLOCK_SIZE, stride=BLOCK_SIZE) + + D = head_dim + B, Lq, Hq = Q.shape[:3] + Lk = K.shape[1] + Hkv = K.shape[2] + n_repeats = Hq // Hkv + assert Hq % Hkv == 0, "query heads must be a multiple of KV heads for GQA" + num_blocks = math.ceil(Lk / BLOCK_SIZE) + + # Group Q for GQA + if Lq < BLOCK_SIZE: + padding = BLOCK_SIZE - Lq + Q = mx.pad(Q, [(0,0),(0,padding), (0,0), (0,0)], constant_values=0) + Q_gqa = Q[:, -BLOCK_SIZE:, :, :] + Q_gqa = Q_gqa.reshape(B, BLOCK_SIZE, n_repeats, Hkv, head_dim) + Q_gqa = mx.mean(Q_gqa, axis=2) # (B, BLOCK_SIZE, Hkv, head_dim) + + # TODO: Handle the step % 
repeat_count condition for recomputation too + if prev_scores is not None and prev_Lkv is not None: # Only compute new blocks + new_blocks = num_blocks - math.ceil(prev_Lkv / BLOCK_SIZE) + if new_blocks > 0: + K_new = K[:, prev_Lkv:, :] + L_new = Lk - prev_Lkv + assert L_new > 0 + K_hat_new = pool(K_new.reshape(B*Hkv, L_new, D)).reshape(B, Hkv, -1, D) + A_hat_new = nn.softmax(mx.matmul(Q_gqa.reshape(B*Hkv, Lq, D), K_hat_new.transpose(0, 1, 3, 2)) / mx.sqrt(D), axis=-1) + A_hat_new = (A_hat_new / mx.sum(A_hat_new, axis=-1, keepdims=True)).reshape(B, Hkv, -1) #(B, Hkv, new_blocks) + A_hat = mx.concatenate([prev_scores, A_hat_new], axis=-1) + assert A_hat.shape[-1] == num_blocks, "Invalid number of blocks selected in A_hat." + + else: # No new blocks + A_hat = prev_scores + + else: # Full compute + Q_hat = pool(Q_gqa.reshape(B*Hkv, BLOCK_SIZE, D)).reshape(B, Hkv, -1, D) + K_hat = pool(K.reshape(B*Hkv, Lk, D)).reshape(B, Hkv, -1, D) + A_hat = nn.softmax(mx.matmul(Q_hat, K_hat.transpose(0, 1, 3, 2)) / mx.sqrt(D), axis=-1) + A_hat = A_hat / mx.sum(A_hat, axis=-1, keepdims=True) + A_hat = mx.mean(A_hat, axis=-2) + assert A_hat.shape[-1] == num_blocks, "Invalid number of blocks selected in A_hat." + + # Pluck selected values + # Use argmax to get the index of the first True value, where the threshold was met + # then handle edge cases where the threshold is never met by selecting all blocks + self.score_cache = A_hat + I_a = mx.argsort(A_hat, axis=-1)[..., ::-1] # (B, Hkv, num_blocks) + active_blocks = mx.take_along_axis(A_hat, I_a, axis=-1) + csum = mx.cumsum(active_blocks, axis=-1) # (B, Hkv, num_blocks) + + target = gamma * mx.sum(A_hat, axis=-1, keepdims=True) if 0 <= gamma <= 1.0 else gamma + selected_counts = mx.argmax(csum >= target, axis=-1) + 1 # (B, Hkv), reduce over num_blocks + full_sum = mx.sum(A_hat, axis=-1) # (B, Hkv) + selected_counts = mx.where(full_sum < target.squeeze(-1), num_blocks, selected_counts) + #TODO: FIX target degrading due to softmax + + # Apply min/max budget per head + selected_counts = mx.maximum(mx.minimum(selected_counts, max_budget), min_budget) + selected_counts = mx.minimum(selected_counts, num_blocks) + + # Pad and format to top selected_counts indices per head (sorted) + #max_sel = max_budget + max_sel = int(selected_counts.max()) + blocks = mx.full([B, Hkv, max_sel], -1, dtype=mx.int32) + for b in range(B): + for h in range(Hkv): + count = int(selected_counts[b, h]) + selected_idxs = I_a[b, h, :count] + selected_idxs = mx.sort(selected_idxs) + actual_count = min(selected_idxs.shape[0], max_sel) + blocks[b, h, :actual_count] = selected_idxs[:actual_count] + + return blocks, selected_counts + + + # VS Index search + # Token-level sparsity + def vertical_slash_search(self, Q, K, gamma, D, min_budget, max_budget, prev_scores=None, prev_Lkv=None): + assert min_budget >= 1 and max_budget >= 1, "budgets must be at least 1" + assert min_budget <= max_budget, "min_budget must be smaller then max_budget" + B, Lq, Hq = Q.shape[:3] + Hkv = K.shape[2] + n_repeats = Hq // Hkv + Lkv = K.shape[1] + + # Group Q for GQA + Q_gqa = Q[:, -BLOCK_SIZE:, :, :] + Q_gqa = Q.reshape(B, Lq, n_repeats, Hkv, D) + Q_gqa = mx.mean(Q_gqa, axis=2) # (B, Lq, Hkv, D) + #print(Q_gqa.shape) + + + # Compute a_v vertical sum and a_s diagonal sum + # NOTE: Maintain batch and head independence for each sum + # NOTE: Normalization happens after the raw values are partially or fully computed + if prev_scores is not None and prev_Lkv is not None: # Only compute the update to sequence length + 
new_cols = Lkv - prev_Lkv + if new_cols > 0: + prev_a_v, prev_a_s = prev_scores + new_K = K[:, prev_Lkv:, :] + # TODO: Broadcasting on A_hat_new is most likely broken. Check full compute. + A_hat_new = nn.softmax(mx.matmul(Q_gqa, new_K.transpose(0, 1, 3, 2)) / mx.sqrt(D), axis=-1) + new_a_v = mx.sum(A_hat_new, axis=1) + new_a_s = mx.zeros_like(prev_a_s) + prev_a_s + + offset_start = -(Lq-1) + num_diags = Lq + Lkv -1 + assert new_a_s.shape[-1] == Lq + prev_Lkv - 1 + if num_diags > new_a_s.sape[-1]: + new_a_s = mx.pad(new_a_s, [(0,0), (0,0), (0,num_diags - new_a_s.shape[-1])], constant_values=0) + + for b in range(B): + for h in range(Hkv): + for off in range(prev_Lkv, Lkv): + if off >= prev_Lkv - Lq + 1: + diag_val = mx.trace(A_hat_new[b, :, h, :], offset=off-prev_Lkv) + new_a_s[b, h, off - offset_start] = diag_val + + a_v = mx.concatenate([prev_a_v, new_a_v], axis=-1) + a_s = mx.concatenate([prev_a_s, new_a_s], axis=-1) + else: + a_v, a_s = prev_scores + + else: # Full compute + # Collapse BxHkv because MLX broadcasting sucks + Q_flat = Q_gqa.reshape(B*Hkv, Lq, D) + K_flat = K.reshape(B*Hkv, Lkv, D) + A_hat = nn.softmax(mx.matmul(Q_flat, K_flat.transpose(0, 2, 1)) / mx.sqrt(D), axis=-1) + A_hat = A_hat.reshape(B, Hkv, Lq, Lkv).transpose(0, 2, 1, 3) + + a_v = mx.sum(A_hat, axis=1) + a_s = mx.zeros([B, Hkv, Lq + Lkv -1]) # Diagonal + offset_start = -(A_hat.shape[1]-1) + for b in range(B): + for h in range(Hkv): + for off in range(offset_start, A_hat.shape[-1]): + diag_val = mx.trace(A_hat[b, :, h, :], offset=off) + a_s[b, h, off - offset_start] = diag_val + + # Store cache before normalization + self.score_cache = (a_v, a_s) + + # Normalize + total = mx.sum(a_v, axis=-1, keepdims=True) # Already normalized over 1(Lq) + a_v = a_v / mx.maximum(total, 1e-6) # Avoid div by 0 + a_s = a_s / mx.maximum(mx.sum(a_s, axis=-1, keepdims=True), 1e-6) + + # Per head selection + I_v = mx.argsort(a_v, axis=-1)[::-1] + I_s = mx.argsort(a_s, axis=-1)[::-1] + #print(I_v, I_s) + plucked_v = mx.cumsum(mx.take_along_axis(a_v, I_v, axis=-1), axis=-1) + plucked_s = mx.cumsum(mx.take_along_axis(a_s, I_s, axis=-1), axis=-1) + + target_v = gamma * mx.sum(a_v, axis=-1, keepdims=True) if 0 <= gamma <= 1.0 else gamma + target_s = gamma * mx.sum(a_s, axis=-1, keepdims=True) if 0 <= gamma <= 1.0 else gamma + K_v = mx.argmax(plucked_v >= target_v, axis=-1) + 1 + K_s = mx.argmax(plucked_s >= target_s, axis=-1) + 1 + full_sum_v = mx.sum(a_v, axis=-1) # (B, Hkv) + full_sum_s = mx.sum(a_s, axis=-1) # (B, Hkv) + K_v = mx.where(full_sum_v < target_v.squeeze(-1), a_v.shape[-1], K_v) + K_s = mx.where(full_sum_s < target_s.squeeze(-1), a_s.shape[-1], K_s) + + b# Concat, sort, unique per head + max_sel = min( int((K_v + K_s).max()), max_budget) # Worst case number + blocks = mx.full([B, Hkv, max_budget], -1, dtype=mx.int32) + for b in range(B): + for h in range(Hkv): + sel_v = I_v[b, h, :int(K_v[b, h].item())] + sel_s = I_s[b, h, :int(K_s[b, h].item())] + selected = mx.sort(mx.concatenate([sel_v, sel_s])) + + # WARNING: VERY SLOW + # No unique op in mlx, because we are merging both sel_v and sel_s we need it + # TODO: When MLX adds support for boolean indices change this. 
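+                # NOTE (illustrative alternative, host-side only): a set round-trip
+                # would also dedup without a unique op, at the cost of a device sync
+                # per (b, h), e.g.:
+                #   selected = mx.array(sorted(set(selected.tolist())), dtype=mx.int32)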
+ if selected.size > 0: + diffs = mx.concatenate([mx.array([True]), selected[1:] != selected[:-1]]) # Filter duplicates + unique_idxs = [0] + for i in range(1, selected.size): + if selected[i] != selected[unique_idxs[-1]]: + unique_idxs.append(i) + selected = selected[mx.array(unique_idxs, dtype=mx.int32)] + num_selected = max(min(selected.size, max_budget), min_budget) + num_selected = min(num_selected, selected.size) + blocks[b, h, :num_selected] = selected[:num_selected] + + return blocks + + + # 0 - query specific, 1 - vertical slash + # NOTE: Causal mask is not applied when comparing distributions + def sparse_pattern_search(self, Q, K, tau, BLOCK_SIZE): + assert K.shape[1] % BLOCK_SIZE == 0 + pool = nn.AvgPool1d(kernel_size=BLOCK_SIZE, stride=BLOCK_SIZE) + B, Lq, Hq, D = Q.shape + Lkv = K.shape[1] + Hkv = K.shape[2] + n_repeats = Hq // Hkv # GQA repeats + + # Pad repersentative subset + if Lq < BLOCK_SIZE: + padding = BLOCK_SIZE - Lq + rep = Q[:, -Lq:, :, :] + rep = mx.pad(rep, [(0,0), (0, padding), (0,0), (0,0)]) + Lq = BLOCK_SIZE + else: + rep = Q[:, -BLOCK_SIZE:, :, :] + + if n_repeats > 1: # mean over n_repeats (rep size becomes Hkv) + rep = rep.reshape(B, BLOCK_SIZE, n_repeats, Hkv, D) + rep_group = mx.mean(rep, axis=2).transpose(0, 2, 1, 3) # (B, Hkv, BLOCK_SIZE, D) + else: + rep_group = rep + + Q_est = pool(rep_group.reshape(B*Hkv, BLOCK_SIZE, D)).reshape(B, Hkv, -1, D) + K_est = pool(K.reshape(B*Hkv, K.shape[1], D)).reshape(B, Hkv, -1, D) + a_est = nn.softmax(mx.matmul(Q_est, K_est.transpose(0, 1, 3, 2)) / mx.sqrt(D), axis=-1) + + if n_repeats > 1: # (B, Hkv, n_repeat, BLOCK_SIZE, D) @ (B, Hkv, D, Lq) + K = K.reshape(B, Lkv, Hkv, 1, D) + a_true = mx.matmul(rep.transpose(0, 3, 2, 1, 4), K.transpose(0, 2, 3, 4, 1)) / mx.sqrt(D) + a_true = mx.mean(a_true, axis=2) + a_true = nn.softmax(a_true, axis=-1) + a_true = pool(a_true.reshape(B*Hkv, BLOCK_SIZE, -1)).reshape(B, Hkv, Lkv, -1 ) + else: + a_true = nn.softmax(mx.matmul(rep.transpose(0, 2, 1, 3), K.transpose(0, 2, 3, 1)) / mx.sqrt(D), axis=-1) + a_true = pool(a_true.reshape(B*Hkv, BLOCK_SIZE, -1)).reshape(B, Hkv, -1, ) + + djs = mx.zeros([B, Hkv]) + for b in range(B): + for h in range(Hkv): + djs_bh = self.jensen_shannon_metric(a_est[b, h], a_true[b, h]) + djs[b, h] = mx.mean(djs_bh) # If multi-dim + return mx.where(djs < tau, 1, 0) # 1 for vertical-slash, 0 for query-aware + + # Analyze heads and create sparse strategy + def __call__(self, i: StrategyInput) -> Strategy: + Q, K, gamma, tau, block_size = i.Q, i.K, i.gamma, i.tau, i.block_size + min_budget, max_budget = i.min_budget, i.max_budget + B, Lq, Hq, D = Q.shape + _, Lkv, Hkv, _ = K.shape + self.pattern_type = self.sparse_pattern_search(Q, K, tau, block_size) + B, Hkv = self.pattern_type.shape + + """ + for b in range(B): + for h in range(Hkv): + #if self.pattern_type[b, h] == 0: + if True: # Hardcode query_aware for now + sel, selected_counts = self.query_aware_search(Q, K, gamma, D, i.min_budget, i.max_budget) + blocks[b, h, :sel.shape[-1]] = sel[b, h] + else: + sel = self.vertical_shash_search(Q, K, gamma, D, i.min_budget, i.max_budget) + blocks[b, h, :sel.shape[0]] = sel + """ + + selected_counts = mx.zeros([B, Hkv]) + #pttn = self.pattern_type[b][h].item() + pttn = 0 + if pttn == 0: + sel, selected_counts = self.query_aware_search(Q, K, gamma, D, min_budget, max_budget) + else: + sel = self.vertical_shash_search(Q, K, gamma, head_dim, min_budget, max_budget) + + max_sel = int(mx.max(selected_counts).item()) + blocks = mx.full([B, Hkv, max_sel], -1, dtype=mx.int32) + 
for b in range(B): + for h in range(Hkv): + count = int(selected_counts[b][h].item()) + blocks[b, h, :count] = sel[b][h][:count] + + return Strategy(blocks=blocks, + tokens=None, + block_size=BLOCK_SIZE, + selected_counts=selected_counts, + max_budget=i.max_budget, + min_budget=i.min_budget) + + + def update(self, i: StrategyInput, prev_S: Strategy, prev_Lkv:int, cycles) -> Strategy: + Q, K, gamma, tau, block_size = i.Q, i.K, i.gamma, i.tau, i.block_size + min_budget, max_budget = prev_S.min_budget, prev_S.max_budget + B, Lq, Hkv, D = Q.shape + Lk = K.shape[1] + + if cycles % prev_S.update_frequency != 0: + return prev_S + + # Recompute patterns only on new blocks (maybe add a set repeat pattern) + new_blocks = math.ceil(Lk / block_size) - math.ceil(prev_Lkv / block_size) + if new_blocks > 0 or self.pattern_type is None: + self.pattern_type = self.sparse_pattern_search(Q, K, tau, block_size) + + # Update per head (recompute all for now) + selected_counts = mx.zeros([B, Hkv]) + #pttn = self.pattern_type[b][h].item() + pttn = 0 + if pttn == 0: + sel, selected_counts = self.query_aware_search(Q, K, gamma, D, min_budget, max_budget) + else: + sel = self.vertical_shash_search(Q, K, gamma, head_dim, min_budget, max_budget) + + max_sel = int(mx.max(selected_counts).item()) + blocks = mx.full([B, Hkv, max_sel], -1, dtype=mx.int32) + for b in range(B): + for h in range(Hkv): + count = int(selected_counts[b][h].item()) + blocks[b, h, :count] = sel[b][h][:count] + + return Strategy( + blocks=blocks, + tokens=None, + block_size=BLOCK_SIZE, + selected_counts=selected_counts, + max_budget=max_budget, + min_budget=min_budget) + +class SparseAttention(nn.Module): + def __init__(self, args: ModelArgs, sparse_strategy: AbstractSparseStrategy, prefill=False): + super().__init__() + + h = args.hidden_size + self.n_heads = args.num_attention_heads + self.n_kv_heads = args.num_key_value_heads + self.head_dim = args.head_dim + self.scale = args.head_dim**-0.5 + self.S = None # Keep a cached S and only update K/V L_kv dimension on decode + self.compute_strat = sparse_strategy + self.kv_cache = None + self.sparse_kv_cache = None + self.cycles = 0 + + self.q_proj = nn.Linear(h, self.n_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(h, self.n_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(h, self.n_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.n_heads*self.head_dim, h, bias=False) + + self.q_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) + self.k_norm = nn.RMSNorm(self.head_dim, eps=args.rms_norm_eps) + self.rope = initialize_rope( + self.head_dim, + base=args.rope_theta, + traditional=False, + scaling_config=None, + max_position_embeddings=40960, + ) + + # Returns a flat, sorted list of absolute token indices for the blocks. + # Expects KV cache in (B, Hkv, Nblk, blk, D) format. 
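+    # Example: block_ids=[0, 2], blk=128, Lk=300 -> positions 0..127 and 256..299
+    # (the trailing block is clipped to the true sequence length Lk).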
+ def blocks_to_positions(self, block_ids, blk, Lk): + pos = [] + for id in block_ids: + start = id*blk + end = min(start + blk, Lk) + if start < end: + pos.extend(range(start, end)) + return pos + + # K_bh: (Nblk, blk, D) + # Return: (Lk_sel, D) + def gather_blocks(self,K_bh, blocks, blk): + out = [] + for id in blocks: + slab = K_bh[id] + out.append(slab) + return mx.concatenate(out, axis=0) + + # K, V: (B, Hkv, Nblk, blk, D) + # S: (B, Hkv) (sorted) + # NOTE: Also computes the causal mask + def pack_kv_blocks(self, K, V, S, Lkv, Lq, uniform=True, padding_mask=None): + _, _, sel_max = S.blocks.shape + B, Hkv, Lkv_total, D = K.shape + #print(Lkv_total, Lkv) + assert Lkv_total == Lkv, "Cache length mismatch" + num_blocks = math.ceil(Lkv / S.block_size) + K_sel = mx.full([B, Hkv, sel_max, S.block_size, D], 0.0) + V_sel = mx.full([B, Hkv, sel_max, S.block_size, D], 0.0) + + if padding_mask is None: + padding_mask = mx.full([Lkv], 0.0, dtype=K.dtype) + + K = K.reshape(B, Hkv, num_blocks, S.block_size, D) + V = V.reshape(B, Hkv, num_blocks, S.block_size, D) + padding_mask = padding_mask.reshape(num_blocks, S.block_size, 1) + mask = mx.zeros([B, Hkv, Lq, sel_max*S.block_size], dtype=mx.bool_) + + for b in range(B): + for h in range(Hkv): + K_sel_b = mx.full([sel_max, S.block_size, D], 0.0) + V_sel_b = mx.full([sel_max, S.block_size, D], 0.0) + mask_b = mx.zeros([Lq, sel_max * S.block_size], dtype=mx.bool_) + blocks = S.blocks[b,h] # (Lk_selected) + for i, idx in enumerate(blocks.tolist()): + if idx == -1: continue + K_sel_b[i] = K[b, h, idx, :, :] + V_sel_b[i] = V[b, h, idx, :, :] + + #K_sel_b[i] = mx.where(padding_mask[idx, :], K_sel_b[1], mx.finfo(K.dtype).min) + #V_sel_b[i] = mx.where(padding_mask[idx, :], V_sel_b[i], mx.finfo(K.dtype).min) + + start = idx*S.block_size + stop = min(start + S.block_size, Lkv) + k_pos = mx.arange(start, stop) + q_pos = mx.arange(Lq).reshape(-1, 1) + pk_start = i*(stop - start) + pk_end = pk_start+(stop-start) + mask_b[:, pk_start:pk_end] = q_pos >= k_pos.reshape(1, -1) + K_sel[b, h] = K_sel_b + V_sel[b, h] = V_sel_b + mask[b, h] = mask_b + + """ + for b in range(B): + for h in range(Hkv): + for blk in range(sel_max): + for dim in range(S.block_size): + print(f"\nBatch {b}, Head {h}, block {blk}, dim {dim}: ", end="") + print(K_sel[b, h, blk, dim]) + """ + #print(K_sel, V_sel) + #print(mask) + K_sel = K_sel.reshape(B, Hkv, sel_max*S.block_size, D) + V_sel = V_sel.reshape(B, Hkv, sel_max*S.block_size, D) + return K_sel, V_sel, mask + + + def __call__(self, x: mx.array, mask: mx.array, cache: Optional[tuple[mx.array, mx.array]]): + start_g = time.perf_counter() + self.cycles += 1 + B, Lq, _ = x.shape + + #start_t = time.perf_counter() + Q = self.q_proj(x).reshape(B, Lq, self.n_heads, self.head_dim) + Q = self.q_norm(Q).transpose(0, 2, 1, 3) + + if (self.cycles != 1 + and self.cycles % self.S.update_frequency != 0 + and self.sparse_kv_cache[0].shape[2] >= self.kv_cache[0].shape[1] + self.cycles % self.S.update_frequency): # Quick decode using cached packed values + + # No strategy, packing or padding the full tensors + x = x[:, -1, :] + K, V, padding_mask = self.sparse_kv_cache + Lkv = self.kv_cache[0].shape[1] + 1 + + new_V = self.v_proj(x).reshape(self.n_kv_heads, self.head_dim) + V[-1, :, Lkv] = new_V + self.kv_cache[1] = mx.concatenate([self.kv_cache[1], new_V.reshape(1, 1, self.n_kv_heads, self.head_dim)], axis=1) + + new_K = self.k_proj(x).reshape(B, -1, self.n_kv_heads, self.head_dim) + new_K = self.k_norm(new_K).transpose(0, 2, 1, 3) + new_K = 
self.rope(new_K, offset=Lkv) + new_K = new_K.reshape(self.n_kv_heads, self.head_dim) + K[-1, :, Lkv, :] = new_K + self.kv_cache[0] = mx.concatenate([self.kv_cache[0], new_K.reshape(1, 1, self.n_kv_heads, self.head_dim)], axis=1) + Q = self.rope(Q, offset=Lkv) # offset rope + + # TODO: Add to normal cache too + padding_mask[-1, Lkv, :, :] = False + self.sparse_kv_cache = (K, V, padding_mask) + + #logger.info(f"[PROFILE] Quick decode: {(time.perf_counter() - start_t)*1000:0.5f}ms") + + else: + if self.cycles == 1: # Prefill + #start_t = time.perf_counter() + Lkv = Lq + V = self.v_proj(x).reshape(B, -1, self.n_kv_heads, self.head_dim) + K = self.k_proj(x).reshape(B, -1, self.n_kv_heads, self.head_dim) + K = self.k_norm(K).transpose(0, 2, 1, 3) + K = self.rope(K) + Q = self.rope(Q) + #print(Q.dtype, K.dtype, V.dtype) + + # Transpose back for S calculation + K = K.transpose(0, 2, 1, 3) + self.kv_cache = [K, V] + #logger.info(f"[PROFILE] Prefill Projections: {(time.perf_counter() - start_t)*1000:0.5f}ms") + + else: # Full Decode: Recompute sparsity indices and masks + #start_t = time.perf_counter() + K_cache, V_cache = self.kv_cache + Q = self.rope(Q, offset=K_cache.shape[1]) # offset rope + assert self.kv_cache is not None, "KV Cache is needed to compute decode phase." + + new_V = self.v_proj(x[:, -1, :]).reshape(B, -1, self.n_kv_heads, self.head_dim) + V = mx.concatenate([V_cache, new_V], axis=1) + + new_K = self.k_proj(x[:, -1, :]).reshape(B, -1, self.n_kv_heads, self.head_dim) + new_K = self.k_norm(new_K).transpose(0, 2, 1, 3) + new_K = self.rope(new_K, offset=K_cache.shape[1]) + new_K = new_K.transpose(0, 2, 1, 3) + K = mx.concatenate([K_cache, new_K], axis=1) + #logger.info(f"[PROFILE] Long Decode Projections: {(time.perf_counter() - start_t)*1000:0.5f}ms") + + self.kv_cache = [K, V] + + Q = Q.transpose(0, 2, 1, 3) + + # Pad to block size + #start_t = time.perf_counter() + _, Lkv, Hkv, _ = K.shape + final_lkv = math.ceil(Lkv / BLOCK_SIZE) * BLOCK_SIZE + padding = final_lkv - Lkv + if padding > 0: + K = mx.pad(K, [(0,0), (0,padding), (0,0), (0,0)], constant_values=0.0) + V = mx.pad(V, [(0,0), (0,padding), (0,0), (0,0)], constant_values=0.0) + padding_indices = mx.arange(final_lkv)[None, :, None, None] + padding_mask = mx.where(padding_indices < Lkv, True, False) + else: + padding_mask = None + #print(f"Strategy: {time.perf_counter() - start_t}s") + #logger.info(f"[PROFILE] Padding {(time.perf_counter() - start_t)*1000:0.5f}ms") + + V = V.transpose(0, 2, 1, 3) + + #start_t = time.perf_counter() + if not self.S: + strat_in = StrategyInput( + Q, K, mask, BLOCK_SIZE, gqa_interleave=False, + last_q=None, gamma=0.1, min_budget=1, max_budget=128, + tau=0.1, is_token_level=False) + self.S = self.compute_strat(strat_in) + elif self.cycles % self.S.update_frequency == 0: + strat_in = StrategyInput( + Q, K, mask, BLOCK_SIZE, gqa_interleave=False, + last_q=None, gamma=0.1, min_budget=1, max_budget=128, + tau=0.1, is_token_level=False) + self.S = self.compute_strat.update(strat_in, prev_S=self.S, prev_Lkv=Lkv-1, cycles=self.cycles) + #logger.info(f"[PROFILE] Strategy compute: {(time.perf_counter() - start_t)*1000:0.5f}ms") + + + Q = Q.transpose(0, 2, 1, 3) + K = K.transpose(0, 2, 1, 3) + + self.sparse_kv_cache = (K, V, padding_mask) + + # Packing the selected sparse blocks into dense buffers + #start_t = time.perf_counter() + if self.S.blocks.shape[2] != 1: + K, V, sparse_mask = self.pack_kv_blocks(K, V, self.S, final_lkv, Lq, padding_mask=padding_mask) + """ + if not isinstance(mask, str): + mask += 
sparse_mask + else: + mask = sparse_mask + """ + mask = sparse_mask + #logger.info(f"[PROFILE] Packing and masking: {(time.perf_counter()-start_t)*1000:0.5f}ms") + + #print(Q.shape, K.shape, V.shape) + #start_t = time.perf_counter() + out = sparse_dot_product_attention_blocked(Q, K, V, self.scale, mask, padding_mask=padding_mask, selected_counts=None) + #logger.info(f"[PROFILE] sdpa {(time.perf_counter() - start_t)*1000:0.5f}ms") + out = self.o_proj(out) + logger.info(f"[PROFILE] Global attention runtime: {(time.perf_counter() - start_g)*1000:0.5f}ms") + + return out diff --git a/test/test_kv_cache.py b/test/test_kv_cache.py new file mode 100644 index 00000000..00a2f2da --- /dev/null +++ b/test/test_kv_cache.py @@ -0,0 +1,82 @@ + + +import os +import sys +import ctypes +import pathlib +import mlx.nn as nn +import mlx.core as mx + +import pathlib +import traceback +import importlib, importlib.machinery, importlib.util +#so = pathlib.Path("../lib/kv_cache/").resolve().glob("kv_cache.dylib") +so = pathlib.Path(__file__).resolve().parents[1] / "lib" / "kv_cache" / "kv_cache.dylib" +ldr = importlib.machinery.ExtensionFileLoader("kv_cache", str(so)) +spec = importlib.util.spec_from_loader("kv_cache", ldr) +kv_cache = importlib.util.module_from_spec(spec) + +try: + ldr.exec_module(kv_cache) +except Exception: + traceback.print_exec() + +import pytest + +@pytest.fixture +def kv_cache_setup(): + num_kv_heads = 8 + block_size = 128 + max_blocks = 256 + batch_size = 1 + head_dim = 128 + bytesize = 2 + + allocator = kv_cache.PageAllocator(num_kv_heads, block_size, head_dim) + cache = kv_cache.SparseKVCache(allocator, bytesize, batch_size, block_size, head_dim, num_kv_heads) + + yield cache, num_kv_heads, block_size, head_dim, batch_size, max_blocks + +# Append data to cache +@pytest.fixture(params=[1, 284, 8491]) +def appended_cache(kv_cache_setup, request): + cache, num_kv_heads, block_size, head_dim, batch_size, _ = kv_cache_setup + num_new_tokens = request.param + k_data = mx.random.normal([1, num_new_tokens, num_kv_heads, head_dim]) + v_data = mx.random.normal([1, num_new_tokens, num_kv_heads, head_dim]) + cache.append(k_data, v_data, 1) + return cache, num_new_tokens + +# Test sequence size and offsets after appending +def test_append(kv_cache_setup, appended_cache): + _, num_kv_heads, _, _, batch_size, _ = kv_cache_setup + cache, num_new_tokens = appended_cache + seq_len = cache.get_seq_len() + seq_offset = cache.get_seq_offset() + for b in range(batch_size): + assert seq_len[b] == num_new_tokens, f"Batch {b} seq_len: {cache.seq_len[b]} != {num_new_tokens}" + assert seq_offset[b] == num_new_tokens * num_kv_heads, \ + f"Batch {b} seq_offset: {seq_offset[b]} != {num_new_tokens*num_kv_heads}" + +# Ensure page table is sane +def test_page_table(kv_cache_setup, appended_cache): + _, num_kv_heads, block_size, _, batch_size, _ = kv_cache_setup + cache, num_new_tokens = appended_cache + page_table = cache.get_page_table() + for b in range(batch_size): + num_blocks = (num_new_tokens + block_size - 1) // block_size + for h in range(num_kv_heads): + assert len(page_table[b][h] == num_blocks), \ + f"Batch {b} Head {h}: page table len mispatch: {len(page_table[b][h])} : {num_blocks}" + for i in range(num_blocks): + assert page_table[b][h][i] >=0, f"Batch {b}, Head {h}, Block{i} is invalid" + +# Test packing from sparsity mask +def test_pack_table_buffer(): + pass + +def test_read(kv_cache_setup, appended_cache): + pass + + + diff --git a/test/test_sparse_attention.py b/test/test_sparse_attention.py new file 
mode 100644 index 00000000..f004e66d --- /dev/null +++ b/test/test_sparse_attention.py @@ -0,0 +1,121 @@ + + +import mlx.core as mx +import mlx.nn as nn +import pytest +import math + +from src.runtime.sparse_attention import FlexPrefillSparseAttention, SparseAttention, StrategyInput, BLOCK_SIZE + +#@pytest.fixture +def fixed_inputs(): + B = 1 + Lq = BLOCK_SIZE*2 + Lkv = BLOCK_SIZE*4 + Hq = 32 + Hkv = 16 # GQA n_repeats=2 + D = 64 + return B, Lq, Lkv, Hq, Hkv, D + +# Test deterministic arange values +#@pytest.mark.parametrize("gamma", [0.3, 0.5, 0.8]) +#@pytest.mark.parametrize("Lk", [BLOCK_SIZE*8, BLOCK_SIZE*16]) +def test_query_aware_search(fixed_inputs, gamma, Lk): + B, Lq, Lkv, Hq, Hkv, D = fixed_inputs + Lkv = Lk + Q = mx.arange(B*Lq*Hq*D, dtype=mx.float32).reshape(B, Lq, Hq, D) + K = mx.arange(B*Lkv*Hkv*D, dtype=mx.float32).reshape(B, Lkv, Hkv, D) + strat = FlexPrefillSparseAttention() + input = StrategyInput( + Q=Q, + K=K, + block_size=BLOCK_SIZE, + gamma=gamma, + min_budget=1, + max_budget=128, + tau=0.5, + mask=None, + ) + blocks = strat.query_aware_search(Q, K, input.gamma, D, input.min_budget, input.max_budget) + num_valid = mx.sum(blocks >= 0, axis=-1) + assert blocks.shape == (B, Hkv, input.max_budget), "Shape mismatch" + assert mx.all(num_valid >= input.min_budget) & mx.all(num_valid < input.max_budget), "Invalid number of blocks selected" + + # Per head checks + num_blocks = math.ceil(K.shape[1] / BLOCK_SIZE) + for b in range(B): + for h in range(Hkv): + #print(num_valid[b, h]) + valid = blocks[b, h, :int(num_valid[b, h].item())] + assert valid.size == num_valid[b, h] + assert mx.all(valid >= 0) and mx.all(num_blocks > valid) + #assert mx.all( ~(blocks[b,h, 1:] != blocks[b, h, :1])), "Repeting indices" + #print( ~(blocks[b,h, 1:] != blocks[b, h, :1])) + +def test_query_aware_search_variable_gamma(fixed_inputs): + B, Lq, Lkv, Hq, Hkv, D = fixed_inputs + Q = mx.random.normal([B, Lq, Hq, D], scale=2.0) + K = mx.random.normal([B, Lkv, Hkv, D], scale=2.0) + strat = FlexPrefillSparseAttention() + input = StrategyInput( + Q=Q, + K=K, + block_size=BLOCK_SIZE, + gamma=0.3, + min_budget=1, + max_budget=128, + tau=0.5, + mask=None, + ) + blocks = strat.query_aware_search(Q, K, input.gamma, D, input.min_budget, input.max_budget) + num_valid = mx.sum(blocks >= 0, axis=-1) + num_blocks = math.ceil(K.shape[1] / BLOCK_SIZE) + for b in range(B): + for h in range(Hkv): + valid = blocks[b, h, :int(num_valid[b, h].item())] + assert valid.size == num_valid[b, h] + assert mx.all(valid >= 0) and mx.all(num_blocks > valid) + #assert mx.all( ~(blocks[b,h, 1:] != blocks[b, h, :1])), "Repeting indices" + +def test_vertical_slash_search(fixed_inputs): + B, Lq, Lkv, Hq, Hkv, D = fixed_inputs + Q = mx.arange(B*Lq*Hq*D, dtype=mx.float32).reshape(B, Lq, Hq, D) + K = mx.arange(B*Lkv*Hkv*D, dtype=mx.float32).reshape(B, Lkv, Hkv, D) + strat = FlexPrefillSparseAttention() + input = StrategyInput( + Q=Q, + K=K, + block_size=BLOCK_SIZE, + gamma=0.5, + min_budget=1, + max_budget=64, + tau=0.5, + mask=None, + ) + blocks = strat.vertical_slash_search(Q, K, input.gamma, D, input.min_budget, input.max_budget) + +def test_vertical_slash_search_normal_dist(fixed_inputs): + B, Lq, Lkv, Hq, Hkv, D = fixed_inputs + Q = mx.random.normal([B, Lq, Hq, D], scale=2.0) + K = mx.random.normal([B, Lkv, Hkv, D], scale=2.0) + strat = FlexPrefillSparseAttention() + input = StrategyInput( + Q=Q, + K=K, + block_size=BLOCK_SIZE, + gamma=0.5, + min_budget=1, + max_budget=64, + tau=0.5, + mask=None, + ) + blocks = 
strat.vertical_slash_search(Q, K, input.gamma, D, input.min_budget, input.max_budget)
+    print(blocks)
+
+
+if __name__ == "__main__":
+    input = fixed_inputs()
+    test_query_aware_search(input, 0.5, BLOCK_SIZE*8)
+    test_query_aware_search_variable_gamma(input)
+    test_vertical_slash_search(input)
+    test_vertical_slash_search_normal_dist(input)
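+
+
+# NOTE: minimal reference sketch (illustrative only, not the library's API) of the
+# gamma block-selection rule exercised above: pool K into per-block means, score the
+# blocks against a pooled query, then keep the smallest prefix of blocks whose
+# cumulative score mass reaches gamma. Toy 2-D shapes; no GQA grouping, padding,
+# or min/max budget handling. The name reference_block_selection is ours.
+def reference_block_selection(Q, K, gamma, block):
+    # Q: (Lq, D), K: (Lk, D) with Lk a multiple of `block`
+    Lk, D = K.shape
+    num_blocks = Lk // block
+    K_hat = mx.mean(K.reshape(num_blocks, block, D), axis=1)         # (num_blocks, D)
+    q_hat = mx.mean(Q, axis=0, keepdims=True)                        # (1, D)
+    scores = mx.softmax(q_hat @ K_hat.T / math.sqrt(D), axis=-1)[0]  # (num_blocks,)
+    order = mx.argsort(-scores)                                      # descending by score
+    csum = mx.cumsum(scores[order])
+    count = int(mx.argmax(csum >= gamma).item()) + 1                 # first index reaching gamma
+    return mx.sort(order[:count])                                    # sorted selected block ids
+
+# Usage (toy sizes):
+#   reference_block_selection(mx.random.normal([32, 64]),
+#                             mx.random.normal([256, 64]),
+#                             gamma=0.9, block=16)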