diff --git a/README.md b/README.md index 84a15cb5..91861bee 100644 --- a/README.md +++ b/README.md @@ -196,28 +196,62 @@ chuk-lazarus generate --type math --output ./data/lazarus chuk-lazarus infer --model "TinyLlama/TinyLlama-1.1B-Chat-v1.0" --prompt "What is 2+2?" ``` -### Model Inference (models_v2) +### Inference Pipeline (New!) -Run inference with pretrained HuggingFace models using the composable models_v2 architecture: +The new unified inference pipeline provides a simplified API for running inference with any supported model family. One-liner setup, no boilerplate: + +```python +from chuk_lazarus.inference import InferencePipeline, PipelineConfig, DType +from chuk_lazarus.models_v2 import LlamaConfig, LlamaForCausalLM + +# One-liner model loading +pipeline = InferencePipeline.from_pretrained( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + LlamaForCausalLM, + LlamaConfig, +) + +# Simple chat API +result = pipeline.chat("What is the capital of France?") +print(result.text) +print(result.stats.summary) # "25 tokens in 0.42s (59.5 tok/s)" +``` + +**Key features:** +- Typed configuration with Pydantic (`PipelineConfig`, `GenerationConfig`) +- Async support (`InferencePipeline.from_pretrained_async`) +- Chat history management (`ChatHistory`) +- Streaming generation (`generate_stream`) +- No magic strings - uses enums (`DType`, `Role`) ```bash -# Run inference with various Llama-family models -uv run python examples/models/llama/03_llama_family_inference.py --model tinyllama -uv run python examples/models/llama/03_llama_family_inference.py --model smollm2-360m -uv run python examples/models/llama/03_llama_family_inference.py --model smollm2-1.7b - -# Custom prompt and parameters -uv run python examples/models/llama/03_llama_family_inference.py \ - --model smollm2-360m \ - --prompt "Explain quantum computing in one sentence" \ - --max-tokens 100 \ - --temperature 0.7 - -# List all available model presets -uv run python examples/models/llama/03_llama_family_inference.py --list-models +# Simplified inference examples +uv run python examples/inference/simple_inference.py --prompt "Write a haiku" +uv run python examples/inference/llama_inference.py --model smollm2-360m +uv run python examples/inference/granite_inference.py --model granite-3.1-2b +uv run python examples/inference/gemma_inference.py --chat ``` -**Available presets:** `tinyllama` (1.1B), `smollm2-135m`, `smollm2-360m`, `smollm2-1.7b`, `llama3.2-1b`, `llama3.2-3b`, `mistral-7b` +### Model Family Examples + +Run inference with specific model families: + +```bash +# Llama family (TinyLlama, SmolLM2, Llama 2/3, Mistral) +uv run python examples/inference/llama_inference.py --model tinyllama +uv run python examples/inference/llama_inference.py --model smollm2-360m +uv run python examples/inference/llama_inference.py --list # Show all presets + +# Gemma 3 (1B, 4B, 12B, 27B with 128K context) +uv run python examples/inference/gemma_inference.py --chat +uv run python examples/inference/gemma_inference.py --model gemma-3-4b + +# Granite (IBM, dense and hybrid MoE variants) +uv run python examples/inference/granite_inference.py --model granite-3.1-2b + +# Llama 4 Scout (Hybrid Mamba-Transformer MoE) +uv run python examples/inference/llama4_inference.py +``` ### FunctionGemma (Function Calling) @@ -287,7 +321,11 @@ src/chuk_lazarus/ │ ├── adapters/ # LoRA adapters │ └── losses/ # Loss functions (pure math) ├── training/ # BatchPlan-driven reference trainers (SFT, DPO, GRPO, PPO) -├── inference/ # Text generation +├── inference/ # Unified inference pipeline +│ ├── pipeline.py # InferencePipeline high-level API +│ ├── loader.py # HFLoader, DType, WeightConverter +│ ├── chat.py # ChatHistory, Role, format_chat_prompt +│ └── generation.py # GenerationConfig, generate, generate_stream ├── distributed/ # Distributed training utilities └── utils/ # Utilities ``` @@ -296,7 +334,8 @@ src/chuk_lazarus/ | Module | Description | |--------|-------------| -| **Models** | Composable architecture: components, blocks, backbones, heads, families (Llama, Mamba) | +| **Models** | Composable architecture: components, blocks, backbones, heads, families (Llama, Gemma, Granite) | +| **Inference** | Unified pipeline API: `InferencePipeline`, chat history, streaming generation | | **Tokenizers** | Comprehensive toolkit for analysis, preprocessing, and runtime management | | **Batching** | Token-budget batching, sequence packing, distributed batch planning | | **Streaming** | Puzzle arcade integration, replay buffers, online learning | @@ -372,13 +411,14 @@ If the tokenizer or data changes, fingerprint mismatch is detected before traini ## Supported Models -- LLaMA / LLaMA 2 / LLaMA 3 -- Mistral -- Gemma -- Granite -- StarCoder2 -- TinyLlama -- SmolLM2 (135M, 360M, 1.7B) +| Family | Models | Notes | +|--------|--------|-------| +| **Llama** | TinyLlama, Llama 2 (7B, 13B), Llama 3.1/3.2, Llama 4 Scout | Llama 4 uses Mamba-Transformer hybrid | +| **SmolLM2** | 135M, 360M, 1.7B | No auth required, fast inference | +| **Mistral** | 7B Instruct v0.3 | Sliding window attention | +| **Gemma** | Gemma 3 (270M, 1B, 4B, 12B, 27B), FunctionGemma | 128K context, function calling | +| **Granite** | 3.0/3.1 (2B, 8B), 4.0 Tiny (1B, 1.5B MoE) | IBM, dense and MoE variants | +| **StarCoder2** | 3B, 7B, 15B | Code generation | ## OpenAI Tokenizers diff --git a/docs/inference.md b/docs/inference.md index fbeeb4d7..426cdaa7 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -1,9 +1,65 @@ # Inference Guide -Run text generation with pretrained models from HuggingFace Hub using the models_v2 architecture. +Run text generation with pretrained models from HuggingFace Hub using the unified inference pipeline. ## Quick Start +### Inference Pipeline (Recommended) + +The new `InferencePipeline` provides a simplified, one-liner API for loading and running inference: + +```python +from chuk_lazarus.inference import InferencePipeline, PipelineConfig, DType +from chuk_lazarus.models_v2 import LlamaConfig, LlamaForCausalLM + +# One-liner model loading +pipeline = InferencePipeline.from_pretrained( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + LlamaForCausalLM, + LlamaConfig, +) + +# Simple chat API +result = pipeline.chat("What is the capital of France?") +print(result.text) +print(result.stats.summary) # "25 tokens in 0.42s (59.5 tok/s)" +``` + +### With Custom Configuration + +```python +from chuk_lazarus.inference import ( + InferencePipeline, + PipelineConfig, + GenerationConfig, + DType, +) +from chuk_lazarus.models_v2 import LlamaConfig, LlamaForCausalLM + +# Configure the pipeline +config = PipelineConfig( + dtype=DType.BFLOAT16, + default_system_message="You are a helpful coding assistant.", + default_max_tokens=200, + default_temperature=0.7, +) + +pipeline = InferencePipeline.from_pretrained( + "HuggingFaceTB/SmolLM2-360M-Instruct", + LlamaForCausalLM, + LlamaConfig, + pipeline_config=config, +) + +# Generate with custom settings +result = pipeline.chat( + "Write a Python function to calculate Fibonacci numbers", + max_new_tokens=300, + temperature=0.3, +) +print(result.text) +``` + ### CLI Inference ```bash @@ -18,7 +74,9 @@ chuk-lazarus infer \ --temperature 0.7 ``` -### Python API +### Low-Level Python API + +For more control, use the models directly: ```python from chuk_lazarus.models_v2 import LlamaConfig, LlamaForCausalLM @@ -48,23 +106,133 @@ response = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True) print(response) ``` -## Llama Family Inference Example +## Inference Pipeline API + +### Core Classes + +| Class | Description | +|-------|-------------| +| `InferencePipeline` | High-level API for model loading and generation | +| `PipelineConfig` | Pipeline configuration (dtype, defaults) | +| `GenerationConfig` | Generation parameters (max_tokens, temperature, top_p) | +| `GenerationResult` | Generation output with text and stats | +| `ChatHistory` | Multi-turn conversation management | + +### Loading Models + +```python +# Synchronous loading +pipeline = InferencePipeline.from_pretrained( + model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + model_class=LlamaForCausalLM, + config_class=LlamaConfig, + pipeline_config=PipelineConfig(dtype=DType.BFLOAT16), +) + +# Async loading +pipeline = await InferencePipeline.from_pretrained_async( + model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + model_class=LlamaForCausalLM, + config_class=LlamaConfig, +) +``` + +### Chat API + +```python +# Simple single-turn chat +result = pipeline.chat("What is 2+2?") + +# With custom system message +result = pipeline.chat( + "Write a haiku", + system_message="You are a poet.", +) + +# Multi-turn conversation +from chuk_lazarus.inference import ChatHistory + +history = ChatHistory() +history.add_system("You are a helpful assistant.") +history.add_user("What is Python?") +history.add_assistant("Python is a programming language.") +history.add_user("What is it used for?") + +result = pipeline.chat_with_history(history) +``` -The `examples/models/llama/03_llama_family_inference.py` script provides a unified interface for running inference with various Llama-architecture models: +### Raw Generation + +```python +# Direct prompt without chat formatting +result = pipeline.generate( + "Once upon a time", + max_new_tokens=100, + temperature=0.9, +) + +# With full config +from chuk_lazarus.inference import GenerationConfig + +config = GenerationConfig( + max_new_tokens=200, + temperature=0.7, + top_p=0.9, + top_k=40, +) +result = pipeline.generate("The quick brown fox", config=config) +``` + +### Streaming Generation + +```python +from chuk_lazarus.inference import generate_stream + +# Stream tokens as they're generated +for chunk in generate_stream(model, tokenizer, "Write a story"): + print(chunk, end="", flush=True) +``` + +## Simplified Examples + +The `examples/inference/` directory contains streamlined examples using the new inference pipeline: + +```bash +# Simple inference (any Llama-family model) +uv run python examples/inference/simple_inference.py --prompt "What is the capital of France?" + +# Llama family with model presets +uv run python examples/inference/llama_inference.py --model smollm2-360m +uv run python examples/inference/llama_inference.py --list # Show all presets + +# Gemma 3 with interactive chat +uv run python examples/inference/gemma_inference.py --chat + +# Granite (IBM) +uv run python examples/inference/granite_inference.py --model granite-3.1-2b + +# Llama 4 Scout (Mamba-Transformer hybrid) +uv run python examples/inference/llama4_inference.py +``` + +These examples replace the 400+ line model-specific examples with ~100-200 line implementations using the unified API. + +## Llama Family Inference + +The `examples/inference/llama_inference.py` script provides a unified interface for Llama-architecture models: ```bash # List available model presets -uv run python examples/models/llama/03_llama_family_inference.py --list-models +uv run python examples/inference/llama_inference.py --list # Run with different models -uv run python examples/models/llama/03_llama_family_inference.py --model tinyllama -uv run python examples/models/llama/03_llama_family_inference.py --model smollm2-135m -uv run python examples/models/llama/03_llama_family_inference.py --model smollm2-360m -uv run python examples/models/llama/03_llama_family_inference.py --model smollm2-1.7b +uv run python examples/inference/llama_inference.py --model tinyllama +uv run python examples/inference/llama_inference.py --model smollm2-360m +uv run python examples/inference/llama_inference.py --model llama3.2-1b # Custom prompt -uv run python examples/models/llama/03_llama_family_inference.py \ - --model tinyllama \ +uv run python examples/inference/llama_inference.py \ + --model smollm2-360m \ --prompt "Explain relativity in simple terms" \ --max-tokens 150 \ --temperature 0.8 @@ -198,8 +366,287 @@ If weight loading fails: - Verify safetensors format - Some models may need HF authentication +## Gemma Inference + +Gemma 3 is Google's latest open model family with 5 sizes (270M, 1B, 4B, 12B, 27B) and 128K context. Use bf16 models from mlx-community for direct loading. + +### Running Gemma Inference + +```bash +# Basic inference (simplified API) +uv run python examples/inference/gemma_inference.py --prompt "What is the capital of France?" + +# Gemma 3 270M (smallest, fastest) +uv run python examples/inference/gemma_inference.py --model gemma3-270m + +# FunctionGemma 270M (function calling optimized) +uv run python examples/inference/gemma_inference.py --model functiongemma + +# Interactive chat mode +uv run python examples/inference/gemma_inference.py --chat + +# Use larger model +uv run python examples/inference/gemma_inference.py --model gemma3-4b + +# List all available models +uv run python examples/inference/gemma_inference.py --list +``` + +### Available Gemma Models + +| Preset | Model ID | Parameters | Memory | Notes | +|--------|----------|------------|--------|-------| +| `gemma3-270m` | mlx-community/gemma-3-270m-it-bf16 | 270M | ~540MB | Smallest, fastest | +| `functiongemma` | mlx-community/functiongemma-270m-it-bf16 | 270M | ~540MB | Function calling optimized | +| `gemma3-1b` | mlx-community/gemma-3-1b-it-bf16 | 1B | ~2GB | Fast, good for testing | +| `gemma3-4b` | mlx-community/gemma-3-4b-it-bf16 | 4B | ~8GB | Good quality/speed balance | +| `gemma3-12b` | mlx-community/gemma-3-12b-it-bf16 | 12B | ~24GB | High quality | +| `gemma3-27b` | mlx-community/gemma-3-27b-it-bf16 | 27B | ~54GB | Best quality | + +**Notes:** +- Use bf16 models (not 4-bit quantized) for direct loading. Quantized models require additional quantization support. +- The 4B+ models are multimodal but this example uses them for text-only inference (vision components are filtered out). + +### Python API + +```python +import mlx.core as mx +from mlx.utils import tree_unflatten +from pathlib import Path +from huggingface_hub import snapshot_download +from transformers import AutoTokenizer +import json + +from chuk_lazarus.models_v2.families.gemma import GemmaConfig, GemmaForCausalLM + +# Download model +model_id = "mlx-community/gemma-3-1b-it-bf16" +model_path = Path(snapshot_download(repo_id=model_id, allow_patterns=["*.json", "*.safetensors"])) + +# Load config +with open(model_path / "config.json") as f: + hf_config = json.load(f) + +config = GemmaConfig( + vocab_size=hf_config["vocab_size"], + hidden_size=hf_config["hidden_size"], + num_hidden_layers=hf_config["num_hidden_layers"], + num_attention_heads=hf_config["num_attention_heads"], + num_key_value_heads=hf_config.get("num_key_value_heads", hf_config["num_attention_heads"]), + intermediate_size=hf_config["intermediate_size"], + head_dim=hf_config.get("head_dim", 256), +) + +# Create model and load weights +model = GemmaForCausalLM(config) +weights = mx.load(str(model_path / "model.safetensors")) +nested = tree_unflatten(list(weights.items())) +model.update(nested) +mx.eval(model.parameters()) + +# Load tokenizer +tokenizer = AutoTokenizer.from_pretrained(str(model_path)) + +# Generate +prompt = "user\nHello!\nmodel\n" +input_ids = mx.array(tokenizer.encode(prompt, return_tensors="np")) + +output_ids = model.generate( + input_ids, + max_new_tokens=100, + temperature=0.7, + stop_tokens=[tokenizer.eos_token_id, 106], # 106 is +) + +response = tokenizer.decode(output_ids[0, input_ids.shape[1]:].tolist(), skip_special_tokens=True) +print(response) +``` + +## Granite Inference + +IBM Granite models are available in dense (3.0, 3.1) and hybrid MoE (4.0) variants. + +### Running Granite Inference + +```bash +# Basic inference +uv run python examples/models/granite/01_granite_inference.py --prompt "What is machine learning?" + +# Use specific model +uv run python examples/models/granite/01_granite_inference.py \ + --model ibm-granite/granite-3.1-2b-instruct \ + --prompt "Explain neural networks" +``` + +### Available Granite Models + +| Model ID | Type | Parameters | Notes | +|----------|------|------------|-------| +| `ibm-granite/granite-3.0-8b-instruct` | Dense | 8B | Original Granite 3.0 | +| `ibm-granite/granite-3.1-2b-instruct` | Dense | 2B | Long context (128K) | +| `ibm-granite/granite-3.1-8b-instruct` | Dense | 8B | Long context (128K) | + +## Llama 4 Inference + +Meta's Llama 4 Scout model uses a hybrid Mamba-Transformer architecture with MoE for efficient long-context processing. + +### Running Llama 4 Inference + +```bash +# Basic inference +uv run python examples/models/llama4/01_llama4_inference.py --prompt "What is the future of AI?" + +# With custom parameters +uv run python examples/models/llama4/01_llama4_inference.py \ + --prompt "Write a story about space exploration" \ + --max-tokens 200 \ + --temperature 0.8 +``` + +### Available Llama 4 Models + +| Model ID | Parameters | Architecture | Notes | +|----------|------------|--------------|-------| +| `meta-llama/Llama-4-Scout-17B-16E-Instruct` | 17B active / 109B total | Hybrid Mamba-Transformer MoE | 16 experts, 10M context | + +**Note:** Llama 4 requires HuggingFace authentication. Run `huggingface-cli login` first. + +## FunctionGemma (Function Calling) + +FunctionGemma is a 270M parameter model from Google, designed specifically for on-device function calling. It's excellent for: +- Tool use / API calling +- MCP (Model Context Protocol) integration +- Lightweight RAG pipelines +- On-device agents + +### Running FunctionGemma + +```bash +# Run the FunctionGemma inference example +uv run python examples/models/gemma/01_functiongemma_inference.py +``` + +### How FunctionGemma Works + +FunctionGemma uses special tokens for structured function calling: +- `` / `` - Define available tools +- `` / `` - Model requests tool use +- `` / `` - Tool results +- `` - Wraps string values in structured data + +### Example with Tools + +```python +from huggingface_hub import hf_hub_download +from jinja2 import Template +from mlx_lm import generate, load + +# Load bf16 model (better accuracy than quantized for function calling) +model_name = "mlx-community/functiongemma-270m-it-bf16" +model, tokenizer = load(model_name) + +# Load Jinja2 chat template +template_path = hf_hub_download(model_name, "chat_template.jinja") +with open(template_path) as f: + chat_template = Template(f.read()) + +# Define tools in OpenAI-compatible format +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets current weather for a location.", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + }, + "required": ["location"], + }, + }, + }, +] + +# Format messages with tools +messages = [ + {"role": "developer", "content": "You can do function calling with these functions"}, + {"role": "user", "content": "What's the weather in Tokyo?"}, +] + +prompt = chat_template.render( + messages=messages, + tools=tools, + add_generation_prompt=True, + bos_token="", + eos_token="", +) + +# Generate +response = generate(model, tokenizer, prompt=prompt, max_tokens=100, verbose=False) +# Output: call:get_weather{location:Tokyo} +``` + +### Parsing Function Calls + +```python +import re + +def parse_function_call(response: str) -> dict | None: + """Parse function call from FunctionGemma output.""" + pattern = r"call:(\w+)\{(.+?)\}" + match = re.search(pattern, response, re.DOTALL) + + if match: + func_name = match.group(1) + args_str = match.group(2) + + # Parse arguments (handle tokens) + args = {} + arg_pattern = r"(\w+):([^<]+)" + for arg_match in re.finditer(arg_pattern, args_str): + args[arg_match.group(1)] = arg_match.group(2) + + return {"name": func_name, "arguments": args} + return None +``` + +### Model Selection + +| Model | Size | Quality | Use Case | +|-------|------|---------|----------| +| `functiongemma-270m-it-bf16` | ~540MB | Best | Production function calling | +| `functiongemma-270m-it-4bit` | ~135MB | Lower | Memory-constrained devices | + +**Note:** bf16 models provide significantly better function calling accuracy than quantized versions. Use 4-bit only when memory is severely constrained. + +### Using chuk-lazarus Native Implementation + +You can also use our native Gemma implementation directly: + +```python +import mlx.core as mx +from chuk_lazarus.models_v2.families.gemma import GemmaConfig, GemmaForCausalLM + +# Create config for FunctionGemma 270M +config = GemmaConfig.functiongemma_270m() + +# Create model +model = GemmaForCausalLM(config) + +# Load weights from mlx-community +# weights = mx.load("path/to/model.safetensors") +# model.update(weights) + +# Forward pass +test_input = mx.array([[1, 2, 3, 4, 5]]) +output = model(test_input) +print(f"Output shape: {output.logits.shape}") +``` + ## See Also - [Models Guide](models.md) - Architecture details - [Training Guide](training.md) - Fine-tuning models - [Examples](../examples/models/llama/) - Working inference examples +- [Examples](../examples/models/gemma/) - FunctionGemma examples diff --git a/docs/models.md b/docs/models.md index c80d1a74..c9d2a9be 100644 --- a/docs/models.md +++ b/docs/models.md @@ -467,6 +467,53 @@ output = token_clf(token_ids) print(f"Per-token logits: {output.logits.shape}") ``` +### Standalone Classifiers + +Simple classifiers for use without a full backbone: + +```python +from chuk_lazarus.models_v2.models.classifiers import ( + LinearClassifier, + MLPClassifier, + create_classifier, +) + +# LinearClassifier - single linear layer +linear_clf = LinearClassifier( + input_dim=768, + num_classes=5, + bias=True, +) +logits = linear_clf(hidden_states) # (batch, 5) + +# MLPClassifier - MLP with hidden layers +mlp_clf = MLPClassifier( + input_dim=768, + hidden_dim=256, + num_classes=5, + num_layers=2, + dropout=0.1, + activation="gelu", +) +logits = mlp_clf(hidden_states) # (batch, 5) + +# Factory function for easy creation +clf = create_classifier( + classifier_type="mlp", # or "linear" + input_dim=768, + num_classes=10, + hidden_dim=512, + num_layers=3, +) +``` + +| Classifier | Parameters | Use Case | +|------------|------------|----------| +| `LinearClassifier` | input_dim × num_classes | Simple classification, probing | +| `MLPClassifier` | Multiple layers | Complex classification tasks | +| `SequenceClassifier` | Backbone + head | Full sequence classification | +| `TokenClassifier` | Backbone + per-token head | NER, POS tagging | + ## Families Architecture-specific implementations with preset configurations. @@ -531,6 +578,161 @@ generated = model.generate( ) ``` +### Gemma Family + +```python +from chuk_lazarus.models_v2.families.gemma import GemmaConfig, GemmaForCausalLM + +# Preset configurations +config = GemmaConfig.tiny() # Testing +config = GemmaConfig.gemma3_270m() # 270M params (FunctionGemma base) +config = GemmaConfig.functiongemma_270m() # Same as 270M, tuned for function calling +config = GemmaConfig.gemma3_1b() # 1B params +config = GemmaConfig.gemma3_4b() # 4B params +config = GemmaConfig.gemma3_12b() # 12B params +config = GemmaConfig.gemma3_27b() # 27B params + +# Create model +model = GemmaForCausalLM(config) + +# Forward pass +output = model(token_ids) + +# Generate text +generated = model.generate( + input_ids=prompt_ids, + max_new_tokens=100, + temperature=0.7, +) +``` + +#### Gemma Architecture Features + +Gemma 3 has several unique architectural features: + +- **Alternating sliding window / global attention**: Every Nth layer uses global attention (pattern configurable) +- **Query/Key pre-normalization**: Q and K projections have separate RMSNorm layers +- **4 normalization layers per block**: Pre-attn, post-attn, pre-ffn, post-ffn norms +- **Gated GELU activation**: Uses `gelu(gate) * up` pattern in FFN +- **Embedding scaling**: Hidden states scaled by √hidden_size +- **GemmaNorm**: RMSNorm with `(1 + weight)` scaling + +```python +# Check which layers use sliding vs global attention +config = GemmaConfig.gemma3_270m() + +for i in range(config.num_hidden_layers): + if config.is_sliding_layer(i): + print(f"Layer {i}: sliding window ({config.sliding_window} tokens)") + else: + print(f"Layer {i}: global attention") +``` + +### Granite Family + +```python +from chuk_lazarus.models_v2.families.granite import ( + GraniteConfig, + GraniteHybridConfig, + GraniteForCausalLM, +) + +# Dense models (Granite 3.0, 3.1) +config = GraniteConfig.tiny() # Testing +config = GraniteConfig.granite_3_8b() # Granite 3.0 8B +config = GraniteConfig.granite_3_1_2b() # Granite 3.1 2B (128K context) +config = GraniteConfig.granite_3_1_8b() # Granite 3.1 8B (128K context) + +# Hybrid MoE models (Granite 4.0) +config = GraniteHybridConfig.tiny() # Testing +config = GraniteHybridConfig.tiny_moe() # Testing with MoE +config = GraniteHybridConfig.granite_4_micro() # Granite 4.0 Micro (dense) +config = GraniteHybridConfig.granite_4_tiny() # Granite 4.0 Tiny (MoE + Mamba) +config = GraniteHybridConfig.granite_4_small() # Granite 4.0 Small (MoE + Mamba) + +# Create model +model = GraniteForCausalLM(config) + +# Forward pass +output = model(token_ids) + +# Generate text +generated = model.generate( + input_ids=prompt_ids, + max_new_tokens=100, + temperature=0.7, +) +``` + +#### Granite Architecture Features + +Granite models have several unique features: + +- **muP scaling**: Embedding, attention, residual, and logits multipliers for stable training +- **Flexible normalization**: RMSNorm or LayerNorm, configurable position +- **GQA support**: Grouped-query attention for efficient inference +- **Long context**: 128K context for Granite 3.1 models +- **Hybrid architecture (4.0)**: Mamba + Attention layers with MoE + +```python +# Dense model features +config = GraniteConfig.granite_3_1_8b() +print(f"Embedding multiplier: {config.embedding_multiplier}") +print(f"Attention multiplier: {config.attention_multiplier}") +print(f"Logits scaling: {config.logits_scaling}") + +# Hybrid model features +config = GraniteHybridConfig.granite_4_tiny() +print(f"MoE: {config.is_moe}") +print(f"Mamba layers: {config.num_mamba_layers}") +print(f"Attention layers: {config.num_attention_layers}") +print(f"Experts: {config.num_local_experts} total, {config.num_experts_per_tok} per token") +``` + +### Llama 4 Family + +```python +from chuk_lazarus.models_v2.families.llama4 import ( + Llama4TextConfig, + Llama4ForCausalLM, +) + +# Preset configurations +config = Llama4TextConfig.tiny() # Testing +config = Llama4TextConfig.llama4_scout() # Llama 4 Scout 17B/109B + +# Create model +model = Llama4ForCausalLM(config) + +# Forward pass +output = model(token_ids) + +# Generate text +generated = model.generate( + input_ids=prompt_ids, + max_new_tokens=100, + temperature=0.7, +) +``` + +#### Llama 4 Architecture Features + +Llama 4 uses a novel hybrid architecture: + +- **Mamba-Transformer hybrid**: Interleaved Mamba2 and attention layers +- **Interleaved MoE**: MoE layers alternating with dense layers +- **Massive context**: Up to 10M tokens with efficient state-space layers +- **Shared + routed experts**: Shared expert always active, plus routed experts +- **NoPE (No Positional Encoding)**: Some models use no positional encoding + +```python +config = Llama4TextConfig.llama4_scout() +print(f"Hidden size: {config.hidden_size}") +print(f"Layers: {config.num_hidden_layers}") +print(f"Experts: {config.num_local_experts} total, {config.num_experts_per_tok} per token") +print(f"MoE layers: every {config.interleave_moe_layer_step} layer") +``` + ## Model Loading Async-native loading from local files or HuggingFace Hub. @@ -681,11 +883,19 @@ models_v2/ ├── models/ # Complete end-to-end │ ├── base.py # Model, ModelOutput │ ├── causal_lm.py # CausalLM -│ └── classifier.py # SequenceClassifier, TokenClassifier +│ └── classifiers/ # Classification models +│ ├── linear.py # LinearClassifier +│ ├── mlp.py # MLPClassifier +│ ├── sequence.py # SequenceClassifier +│ ├── token.py # TokenClassifier +│ └── factory.py # create_classifier() │ ├── families/ # Architecture-specific │ ├── llama/ # LlamaConfig, LlamaForCausalLM -│ └── mamba/ # MambaConfig, MambaForCausalLM +│ ├── llama4/ # Llama4TextConfig, Llama4ForCausalLM +│ ├── mamba/ # MambaConfig, MambaForCausalLM +│ ├── gemma/ # GemmaConfig, GemmaForCausalLM +│ └── granite/ # GraniteConfig, GraniteForCausalLM │ ├── adapters/ # Parameter-efficient fine-tuning │ └── lora.py # LoRAConfig, LoRALinear, apply_lora diff --git a/examples/README.md b/examples/README.md index 0bf79caa..7de26baa 100644 --- a/examples/README.md +++ b/examples/README.md @@ -37,9 +37,22 @@ examples/ ├── data/ # Data handling examples │ ├── generate_math_data.py │ └── create_sft_dataset.py -└── models/ # Model loading examples - ├── load_with_lora.py - └── model_config.py +└── models/ # Model inference examples + ├── gemma/ # Gemma family examples + │ ├── 01_functiongemma_inference.py # FunctionGemma tool calling + │ ├── 02_load_pretrained.py # Load pretrained weights + │ ├── 03_gemma3_inference.py # Gemma 3 text generation + │ └── 04_gemma3_vision_inference.py # Gemma 3 vision (multimodal) + ├── granite/ # IBM Granite examples + │ └── 01_granite_inference.py # Granite inference + ├── llama/ # Llama family examples + │ ├── 01_causal_lm.py # Basic causal LM + │ ├── 02_tinyllama_inference.py # TinyLlama inference + │ └── 03_llama_family_inference.py # Multi-model inference + ├── llama4/ # Llama 4 examples + │ └── 01_llama4_inference.py # Llama 4 Scout inference + ├── lora/ # LoRA examples + └── mlp/ # MLP classifier examples ``` ## Quick Start @@ -267,6 +280,80 @@ Features demonstrated: - Complete efficiency reports with recommendations - CLI equivalents: `lazarus data batching histogram`, `analyze`, `suggest` +## Model Inference Examples + +Run inference with various model families using pretrained weights from HuggingFace. + +### Gemma 3 + +```bash +# Basic inference +uv run python examples/models/gemma/03_gemma3_inference.py --prompt "What is the capital of France?" + +# Interactive chat +uv run python examples/models/gemma/03_gemma3_inference.py --chat + +# Use larger model +uv run python examples/models/gemma/03_gemma3_inference.py \ + --model mlx-community/gemma-3-4b-it-bf16 \ + --prompt "Explain machine learning" +``` + +**Available models:** `gemma-3-1b-it-bf16`, `gemma-3-4b-it-bf16`, `gemma-3-12b-it-bf16`, `gemma-3-27b-it-bf16` + +### Gemma 3 Vision (Multimodal) + +```bash +# Image understanding +uv run python examples/models/gemma/04_gemma3_vision_inference.py \ + --image /path/to/image.jpg \ + --prompt "What is in this image?" + +# Detailed description +uv run python examples/models/gemma/04_gemma3_vision_inference.py \ + --image photo.jpg \ + --prompt "Describe this image in detail" \ + --max-tokens 200 +``` + +**Available models:** `gemma-3-4b-it-bf16` (4B), `gemma-3-12b-it-bf16` (12B), `gemma-3-27b-it-bf16` (27B) + +### FunctionGemma (Tool Calling) + +```bash +# Function calling example +uv run python examples/models/gemma/01_functiongemma_inference.py +``` + +### Llama Family + +```bash +# List available models +uv run python examples/models/llama/03_llama_family_inference.py --list-models + +# Run with different models +uv run python examples/models/llama/03_llama_family_inference.py --model tinyllama +uv run python examples/models/llama/03_llama_family_inference.py --model smollm2-360m +``` + +**Available presets:** `tinyllama`, `smollm2-135m`, `smollm2-360m`, `smollm2-1.7b`, `llama2-7b`, `llama3.2-1b`, `mistral-7b` + +### Granite + +```bash +# Basic inference +uv run python examples/models/granite/01_granite_inference.py --prompt "What is machine learning?" +``` + +### Llama 4 + +```bash +# Llama 4 Scout inference (requires HF auth) +uv run python examples/models/llama4/01_llama4_inference.py --prompt "Explain quantum computing" +``` + +**Note:** Llama 4 requires HuggingFace authentication. Run `huggingface-cli login` first. + ## Running Examples ```bash diff --git a/examples/inference/gemma_inference.py b/examples/inference/gemma_inference.py new file mode 100644 index 00000000..d161cfdb --- /dev/null +++ b/examples/inference/gemma_inference.py @@ -0,0 +1,418 @@ +#!/usr/bin/env python3 +""" +Gemma Inference Example (Simplified) + +Demonstrates the simplified API for Gemma text models. +For vision/multimodal, see examples/models/gemma/04_gemma3_vision_inference.py. + +Supports: +- Gemma 3 270M (smallest, fastest) +- FunctionGemma 270M (function calling optimized) +- Gemma 3 1B (text-only) +- Gemma 3 4B (multimodal - text-only mode here) +- Gemma 3 12B +- Gemma 3 27B + +Usage: + # Default: Gemma 3 1B + uv run python examples/inference/gemma_inference.py + + # Gemma 3 270M (smallest) + uv run python examples/inference/gemma_inference.py --model gemma3-270m + + # FunctionGemma 270M (function calling) + uv run python examples/inference/gemma_inference.py --model functiongemma + + # Gemma 3 4B + uv run python examples/inference/gemma_inference.py --model gemma3-4b + + # Custom prompt + uv run python examples/inference/gemma_inference.py --prompt "Write a haiku about MLX" + + # Interactive chat mode + uv run python examples/inference/gemma_inference.py --chat + + # List available models + uv run python examples/inference/gemma_inference.py --list +""" + +from __future__ import annotations + +import argparse +import json +import time +from enum import Enum +from pathlib import Path + +import mlx.core as mx +from mlx.utils import tree_unflatten + +from chuk_lazarus.inference import ( + ChatHistory, + DType, + GenerationConfig, + HFLoader, + Role, +) +from chuk_lazarus.models_v2.families.gemma import GemmaConfig, GemmaForCausalLM + + +class GemmaModel(str, Enum): + """Available Gemma model presets.""" + + # Gemma 3 270M (smallest, fastest) + GEMMA3_270M = "mlx-community/gemma-3-270m-it-bf16" + + # FunctionGemma (270M - function calling optimized) + FUNCTIONGEMMA_270M = "mlx-community/functiongemma-270m-it-bf16" + + # Gemma 3 family + GEMMA3_1B = "mlx-community/gemma-3-1b-it-bf16" + GEMMA3_4B = "mlx-community/gemma-3-4b-it-bf16" + GEMMA3_12B = "mlx-community/gemma-3-12b-it-bf16" + GEMMA3_27B = "mlx-community/gemma-3-27b-it-bf16" + + +MODEL_ALIASES = { + # Gemma 3 270M + "gemma3-270m": GemmaModel.GEMMA3_270M, + # FunctionGemma + "functiongemma": GemmaModel.FUNCTIONGEMMA_270M, + "functiongemma-270m": GemmaModel.FUNCTIONGEMMA_270M, + # Gemma 3 + "gemma3-1b": GemmaModel.GEMMA3_1B, + "gemma3-4b": GemmaModel.GEMMA3_4B, + "gemma3-12b": GemmaModel.GEMMA3_12B, + "gemma3-27b": GemmaModel.GEMMA3_27B, +} + + +class GemmaToken(int, Enum): + """Special token IDs for Gemma 3.""" + + END_OF_TURN = 106 + + +def load_gemma_config(model_path: Path, weights: dict | None = None) -> GemmaConfig: + """Load and convert HuggingFace config to GemmaConfig. + + Handles both text-only (1B) and multimodal (4B+) config formats. + """ + config_path = model_path / "config.json" + with open(config_path) as f: + hf_config = json.load(f) + + # Handle multimodal models with nested text_config + if "text_config" in hf_config: + text_config = hf_config["text_config"] + model_type = text_config.get( + "model_type", hf_config.get("model_type", "gemma3_text") + ) + else: + text_config = hf_config + model_type = hf_config.get("model_type", "gemma3_text") + + hidden_size = text_config["hidden_size"] + head_dim = text_config.get("head_dim", 256) + + # Try to get num_attention_heads from config, otherwise infer from weights + num_attention_heads = text_config.get("num_attention_heads") + num_key_value_heads = text_config.get("num_key_value_heads") + + if weights and (num_attention_heads is None or num_key_value_heads is None): + # Infer from weight shapes + for k, v in weights.items(): + if "layers.0" in k and "self_attn.q_proj.weight" in k: + if num_attention_heads is None: + num_attention_heads = v.shape[0] // head_dim + print(f" Inferred num_attention_heads={num_attention_heads}") + if "layers.0" in k and "self_attn.k_proj.weight" in k: + if num_key_value_heads is None: + num_key_value_heads = v.shape[0] // head_dim + print(f" Inferred num_key_value_heads={num_key_value_heads}") + break + + # Fallback defaults + if num_attention_heads is None: + num_attention_heads = hidden_size // head_dim + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + return GemmaConfig( + model_type=model_type, + vocab_size=text_config.get("vocab_size", 262144), + hidden_size=hidden_size, + num_hidden_layers=text_config["num_hidden_layers"], + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + intermediate_size=text_config["intermediate_size"], + head_dim=head_dim, + query_pre_attn_scalar=text_config.get("query_pre_attn_scalar", 256.0), + sliding_window=text_config.get("sliding_window", 512), + sliding_window_pattern=text_config.get( + "sliding_window_pattern", text_config.get("_sliding_window_pattern", 6) + ), + max_position_embeddings=text_config.get("max_position_embeddings", 32768), + rope_theta=text_config.get("rope_theta", 1000000.0), + rope_local_base_freq=text_config.get("rope_local_base_freq", 10000.0), + rms_norm_eps=text_config.get("rms_norm_eps", 1e-6), + ) + + +def load_gemma_weights(model_path: Path, text_only: bool = True) -> dict: + """Load weights from safetensors, optionally filtering vision weights.""" + weights = {} + for sf_path in sorted(model_path.glob("*.safetensors")): + print(f" Loading {sf_path.name}...") + file_weights = mx.load(str(sf_path)) + weights.update(file_weights) + + if text_only: + # Filter vision weights for multimodal models + text_weights = {} + skipped = 0 + for k, v in weights.items(): + if any(prefix in k for prefix in ["vision_tower", "multi_modal_projector"]): + skipped += 1 + continue + # Rename language_model.* to model.* for compatibility + if k.startswith("language_model."): + k = k.replace("language_model.", "", 1) + text_weights[k] = v + if skipped > 0: + print(f" Filtered {skipped} vision tensors (text-only mode)") + return text_weights + + return weights + + +def load_gemma_model(model_id: str): + """Load Gemma model, tokenizer, and config.""" + print(f"Loading {model_id}...") + print("=" * 60) + + # Download + print("\n1. Downloading model...") + result = HFLoader.download(model_id) + print(f" Path: {result.model_path}") + + # Load weights first (needed to infer config for multimodal models) + print("\n2. Loading weights...") + weights = load_gemma_weights(result.model_path) + print(f" Loaded {len(weights)} tensors") + + # Load config + print("\n3. Loading configuration...") + config = load_gemma_config(result.model_path, weights) + print(f" Layers: {config.num_hidden_layers}, Hidden: {config.hidden_size}") + print(f" Heads: {config.num_attention_heads} attn, {config.num_key_value_heads} kv") + + # Load tokenizer + print("\n4. Loading tokenizer...") + tokenizer = HFLoader.load_tokenizer(result.model_path) + print(f" Vocab size: {len(tokenizer)}") + + # Create model + print("\n5. Creating model...") + model = GemmaForCausalLM(config) + + # Apply weights using tree_unflatten (Gemma convention) + print("\n6. Applying weights...") + nested_weights = tree_unflatten(list(weights.items())) + model.update(nested_weights) + mx.eval(model.parameters()) + print(" Done!") + + print("\n" + "=" * 60) + print("Model loaded successfully!") + + return model, tokenizer, config + + +def generate( + model: GemmaForCausalLM, + tokenizer, + prompt: str, + config: GenerationConfig | None = None, +) -> tuple[str, float]: + """Generate text and return (text, tokens_per_sec).""" + if config is None: + config = GenerationConfig() + + # Tokenize + input_ids = tokenizer.encode(prompt, return_tensors="np") + input_ids = mx.array(input_ids) + input_length = input_ids.shape[1] + + # Stop tokens + stop_tokens = [tokenizer.eos_token_id, GemmaToken.END_OF_TURN.value] + + # Generate + start = time.time() + output_ids = model.generate( + input_ids, + max_new_tokens=config.max_new_tokens, + temperature=config.temperature, + top_k=config.top_k, + top_p=config.top_p, + stop_tokens=stop_tokens, + ) + mx.eval(output_ids) + gen_time = time.time() - start + + # Decode + new_tokens = output_ids[0, input_length:].tolist() + text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip() + + tokens_per_sec = len(new_tokens) / gen_time if gen_time > 0 else 0 + return text, tokens_per_sec + + +def format_gemma_prompt(user_message: str, system_message: str | None = None) -> str: + """Format prompt using Gemma 3 chat template.""" + if system_message: + return f"user\n{system_message}\n\n{user_message}\nmodel\n" + return f"user\n{user_message}\nmodel\n" + + +def chat_loop(model: GemmaForCausalLM, tokenizer, model_config: GemmaConfig): + """Interactive chat loop.""" + print("\n" + "=" * 60) + print("Gemma 3 Chat") + print("=" * 60) + print("Type 'quit' to exit, 'clear' to reset conversation") + print("-" * 60) + + history = ChatHistory() + gen_config = GenerationConfig(max_new_tokens=512, temperature=0.7, top_k=40, top_p=0.95) + + while True: + try: + user_input = input("\nYou: ").strip() + except (KeyboardInterrupt, EOFError): + print("\nGoodbye!") + break + + if not user_input: + continue + if user_input.lower() == "quit": + print("Goodbye!") + break + if user_input.lower() == "clear": + history.clear() + print("Conversation cleared.") + continue + + # Add to history + history.add_user(user_input) + + # Build conversation + conv_text = "" + for msg in history.messages: + role = "user" if msg.role == Role.USER else "model" + conv_text += f"{role}\n{msg.content}\n" + + prompt = f"{conv_text}model\n" + + # Check context length + input_ids = tokenizer.encode(prompt, return_tensors="np") + if input_ids.shape[1] > model_config.max_position_embeddings - 256: + print("Warning: Context too long, truncating history...") + history.messages = history.messages[-4:] + continue + + # Generate + print("\nGemma: ", end="", flush=True) + response, tps = generate(model, tokenizer, prompt, gen_config) + print(response) + print(f" [{tps:.1f} tok/s]") + + # Add response to history + history.add_assistant(response) + + +def test_tiny(): + """Test tiny model config.""" + print("=" * 60) + print("Gemma Tiny Model Test") + print("=" * 60) + + config = GemmaConfig.tiny() + model = GemmaForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + mx.eval(output.logits) + print(f"Forward: OK (shape={output.logits.shape})") + + gen = model.generate(input_ids, max_new_tokens=5) + mx.eval(gen) + print(f"Generate: OK (shape={gen.shape})") + + print("\nSUCCESS!") + + +def main(): + parser = argparse.ArgumentParser(description="Gemma 3 Inference (Simplified)") + parser.add_argument( + "--model", + choices=list(MODEL_ALIASES.keys()), + default="gemma3-1b", + help="Model preset", + ) + parser.add_argument("--model-id", help="Custom HuggingFace model ID") + parser.add_argument("--test-tiny", action="store_true", help="Run tiny test") + parser.add_argument( + "--prompt", + default="What is the capital of France?", + help="User prompt", + ) + parser.add_argument("--max-tokens", type=int, default=256) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--chat", action="store_true", help="Interactive chat mode") + parser.add_argument("--list", action="store_true", help="List models") + args = parser.parse_args() + + if args.test_tiny: + test_tiny() + return + + if args.list: + print("Available Gemma 3 models:\n") + for alias, model in MODEL_ALIASES.items(): + print(f" {alias:12} -> {model.value}") + return + + # Get model ID + model_id = args.model_id or MODEL_ALIASES[args.model].value + + # Load model + model, tokenizer, config = load_gemma_model(model_id) + + if args.chat: + chat_loop(model, tokenizer, config) + return + + # Single prompt mode + print("\n" + "=" * 60) + print(f"User: {args.prompt}") + print("-" * 60) + + prompt = format_gemma_prompt(args.prompt) + gen_config = GenerationConfig( + max_new_tokens=args.max_tokens, + temperature=args.temperature, + top_k=40, + top_p=0.95, + ) + + response, tps = generate(model, tokenizer, prompt, gen_config) + + print(f"Gemma: {response}") + print("-" * 60) + print(f"Speed: {tps:.1f} tokens/sec") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/granite_inference.py b/examples/inference/granite_inference.py new file mode 100644 index 00000000..3d8f7f4b --- /dev/null +++ b/examples/inference/granite_inference.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +""" +IBM Granite Inference Example (Simplified) + +Demonstrates the simplified API for Granite models. +Supports both Granite 3.x (dense) and 4.x (hybrid Mamba/Transformer). + +Usage: + # Test tiny (no download) + uv run python examples/inference/granite_inference.py --test-tiny + + # Granite 3.1 2B + uv run python examples/inference/granite_inference.py --model granite-3.1-2b + + # Granite 4.0 Micro + uv run python examples/inference/granite_inference.py --model granite-4.0-micro +""" + +from __future__ import annotations + +import argparse +from enum import Enum + +import mlx.core as mx + +from chuk_lazarus.inference import ( + DType, + InferencePipeline, + PipelineConfig, +) +from chuk_lazarus.models_v2 import ( + GraniteConfig, + GraniteForCausalLM, + GraniteHybridConfig, + GraniteHybridForCausalLM, + count_parameters, +) + + +class GraniteModelType(str, Enum): + """Granite model architecture types.""" + + DENSE = "granite" # Granite 3.x + HYBRID = "granitemoehybrid" # Granite 4.x + + +class GraniteModel(str, Enum): + """Available Granite model presets.""" + + # Granite 3.x (Dense) + GRANITE_3_1_2B = "ibm-granite/granite-3.1-2b-instruct" + GRANITE_3_1_8B = "ibm-granite/granite-3.1-8b-instruct" + GRANITE_3_3_2B = "ibm-granite/granite-3.3-2b-instruct" + GRANITE_3_3_8B = "ibm-granite/granite-3.3-8b-instruct" + + # Granite 4.x (Hybrid) + GRANITE_4_0_MICRO = "ibm-granite/granite-4.0-micro" + GRANITE_4_0_TINY = "ibm-granite/granite-4.0-tiny-preview" + + +MODEL_ALIASES = { + "granite-3.1-2b": (GraniteModel.GRANITE_3_1_2B, GraniteModelType.DENSE), + "granite-3.1-8b": (GraniteModel.GRANITE_3_1_8B, GraniteModelType.DENSE), + "granite-3.3-2b": (GraniteModel.GRANITE_3_3_2B, GraniteModelType.DENSE), + "granite-3.3-8b": (GraniteModel.GRANITE_3_3_8B, GraniteModelType.DENSE), + "granite-4.0-micro": (GraniteModel.GRANITE_4_0_MICRO, GraniteModelType.HYBRID), + "granite-4.0-tiny": (GraniteModel.GRANITE_4_0_TINY, GraniteModelType.HYBRID), +} + + +def test_tiny(): + """Test tiny model configurations without downloading.""" + print("=" * 60) + print("Granite Tiny Model Tests") + print("=" * 60) + + # Test Granite 3.x + print("\n1. Testing Granite 3.x (dense)...") + config3 = GraniteConfig.tiny() + model3 = GraniteForCausalLM(config3) + params3 = count_parameters(model3) + print(f" {params3.summary()}") + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output3 = model3(input_ids) + mx.eval(output3.logits) + print(f" Forward: OK (shape={output3.logits.shape})") + + # Test Granite 4.x + print("\n2. Testing Granite 4.x (hybrid)...") + config4 = GraniteHybridConfig.tiny() + model4 = GraniteHybridForCausalLM(config4) + params4 = count_parameters(model4) + print(f" {params4.summary()}") + + output4 = model4(input_ids) + mx.eval(output4.logits) + print(f" Forward: OK (shape={output4.logits.shape})") + + print("\n" + "=" * 60) + print("SUCCESS! All tiny tests passed.") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description="Granite Inference (Simplified)") + parser.add_argument( + "--model", + choices=list(MODEL_ALIASES.keys()), + default="granite-3.1-2b", + help="Model preset", + ) + parser.add_argument("--model-id", help="Custom HuggingFace model ID") + parser.add_argument("--test-tiny", action="store_true", help="Run tiny tests") + parser.add_argument( + "--prompt", + default="What is the capital of France?", + help="User prompt", + ) + parser.add_argument( + "--system", + default="You are a helpful assistant.", + help="System message", + ) + parser.add_argument("--max-tokens", type=int, default=100) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--list", action="store_true", help="List models") + args = parser.parse_args() + + if args.test_tiny: + test_tiny() + return + + if args.list: + print("Available Granite models:\n") + for alias, (model, model_type) in MODEL_ALIASES.items(): + print(f" {alias:20} -> {model.value} ({model_type.value})") + return + + # Get model info + if args.model_id: + model_id = args.model_id + # Default to dense for custom models + model_type = GraniteModelType.DENSE + else: + model_enum, model_type = MODEL_ALIASES[args.model] + model_id = model_enum.value + + # Select appropriate model/config classes + if model_type == GraniteModelType.HYBRID: + model_class = GraniteHybridForCausalLM + config_class = GraniteHybridConfig + else: + model_class = GraniteForCausalLM + config_class = GraniteConfig + + # Load with pipeline + pipeline = InferencePipeline.from_pretrained( + model_id, + model_class, + config_class, + pipeline_config=PipelineConfig( + dtype=DType.BFLOAT16, + default_system_message=args.system, + default_max_tokens=args.max_tokens, + default_temperature=args.temperature, + ), + ) + + # Generate + print("\n" + "=" * 60) + print(f"User: {args.prompt}") + print("-" * 60) + + result = pipeline.chat(args.prompt) + + print(f"Assistant: {result.text}") + print("-" * 60) + print(f"Stats: {result.stats.summary}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/llama4_inference.py b/examples/inference/llama4_inference.py new file mode 100644 index 00000000..4429b8bd --- /dev/null +++ b/examples/inference/llama4_inference.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python3 +""" +Llama 4 Inference Example (Simplified) + +Demonstrates the simplified API for Llama 4 MoE models. + +Llama 4 key features: +- MoE (Mixture of Experts) with shared expert +- iRoPE (interleaved RoPE and NoPE layers) +- QK normalization + +Note: Llama 4 models require significant memory: +- Scout: ~27GB for BF16 +- Maverick: ~100GB for BF16 + +Usage: + # Test tiny (no download) + uv run python examples/inference/llama4_inference.py --test-tiny + + # Llama 4 Scout (requires HF auth + ~27GB RAM) + uv run python examples/inference/llama4_inference.py --model llama4-scout +""" + +from __future__ import annotations + +import argparse +import json +import re +import time +from enum import Enum +from pathlib import Path + +import mlx.core as mx + +from chuk_lazarus.inference import ( + DType, + GenerationConfig, + HFLoader, + WeightConverter, +) +from chuk_lazarus.models_v2 import ( + Llama4ForCausalLM, + Llama4TextConfig, + count_parameters, +) + + +class Llama4Model(str, Enum): + """Available Llama 4 model presets.""" + + SCOUT = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + MAVERICK = "meta-llama/Llama-4-Maverick-17B-128E-Instruct" + + +MODEL_ALIASES = { + "llama4-scout": Llama4Model.SCOUT, + "llama4-maverick": Llama4Model.MAVERICK, +} + + +class Llama4WeightConverter: + """Weight converter for Llama 4 MoE models. + + Handles the unique structure of Llama 4: + - Routed expert weights need fusion into SwitchGLU format + - Shared expert mapping + - Router weights + """ + + def __init__(self, config: Llama4TextConfig): + self.config = config + self.expert_weights: dict[int, dict[str, dict[int, mx.array]]] = {} + + def convert(self, hf_name: str) -> str | None: + """Convert HuggingFace weight name to framework format.""" + # Embeddings + if hf_name == "model.embed_tokens.weight": + return "model.embed_tokens.weight.weight" + + # Final norm + if hf_name == "model.norm.weight": + return "model.norm.weight" + + # LM head + if hf_name == "lm_head.weight": + if self.config.tie_word_embeddings: + return None + return "lm_head.lm_head.weight" + + # Layer pattern + layer_match = re.match(r"model\.layers\.(\d+)\.(.*)", hf_name) + if layer_match: + layer_idx = layer_match.group(1) + rest = layer_match.group(2) + + # Skip rotary embeddings + if "rotary_emb" in rest: + return None + + # Skip routed experts - handled separately for fusion + if re.match(r"(?:feed_forward|mlp)\.experts\.\d+\.", rest): + return None + + # Attention projections + if rest.startswith("self_attn."): + return f"model.layers.{layer_idx}.{rest}" + + # Layer norms + if rest in ("input_layernorm.weight", "post_attention_layernorm.weight"): + return f"model.layers.{layer_idx}.{rest}" + + # MoE components + if rest.startswith("feed_forward.") or rest.startswith("mlp."): + if rest.startswith("mlp."): + rest = rest.replace("mlp.", "feed_forward.", 1) + + # Router + if rest == "feed_forward.router.weight": + return f"model.layers.{layer_idx}.mlp.router.weight" + + # Shared expert + if rest.startswith("feed_forward.shared_expert."): + sub = rest.replace("feed_forward.shared_expert.", "") + return f"model.layers.{layer_idx}.mlp.shared_expert.{sub}" + + # Standard MLP (non-MoE fallback) + if rest.startswith("mlp."): + return f"model.layers.{layer_idx}.{rest}" + + return f"model.layers.{layer_idx}.{rest}" + + return None + + +def load_llama4_weights( + model_path: Path, + config: Llama4TextConfig, + dtype: DType = DType.BFLOAT16, +) -> dict: + """Load and fuse Llama 4 MoE weights.""" + safetensor_files = sorted(model_path.glob("*.safetensors")) + if not safetensor_files: + raise FileNotFoundError(f"No safetensors in {model_path}") + + target_dtype = dtype.to_mlx() + converter = Llama4WeightConverter(config) + + # Collect expert weights for fusion + expert_weights: dict[int, dict[str, dict[int, mx.array]]] = {} + flat_weights: dict[str, mx.array] = {} + + for sf_path in safetensor_files: + print(f" Loading {sf_path.name}...") + weights = mx.load(str(sf_path)) + + for hf_name, weight in weights.items(): + # Convert dtype + if weight.dtype in (mx.float32, mx.float16, mx.bfloat16): + weight = weight.astype(target_dtype) + + # Check for routed expert weights + expert_match = re.match( + r"model\.layers\.(\d+)\.(?:feed_forward|mlp)\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight", + hf_name, + ) + if expert_match: + layer_idx = int(expert_match.group(1)) + expert_idx = int(expert_match.group(2)) + proj_type = expert_match.group(3) + + if layer_idx not in expert_weights: + expert_weights[layer_idx] = {} + if proj_type not in expert_weights[layer_idx]: + expert_weights[layer_idx][proj_type] = {} + + expert_weights[layer_idx][proj_type][expert_idx] = weight + continue + + # Convert other weights + our_name = converter.convert(hf_name) + if our_name is not None: + flat_weights[our_name] = weight + + del weights + mx.eval([]) + + # Fuse expert weights into SwitchGLU format + print(" Fusing expert weights...") + for layer_idx, proj_dict in expert_weights.items(): + for proj_type, experts_dict in proj_dict.items(): + num_experts = len(experts_dict) + expert_list = [experts_dict[i] for i in range(num_experts)] + fused = mx.stack(expert_list, axis=0) + + our_name = f"model.layers.{layer_idx}.mlp.experts.{proj_type}.weight" + flat_weights[our_name] = fused + + del expert_weights + mx.eval([]) + + return _build_nested_weights(flat_weights, config) + + +def _build_nested_weights(flat_weights: dict[str, mx.array], config) -> dict: + """Build nested structure for model.update().""" + max_layer_idx = config.num_hidden_layers - 1 + + nested: dict = {} + for name, weight in flat_weights.items(): + parts = name.split(".") + current = nested + i = 0 + while i < len(parts) - 1: + part = parts[i] + if part == "layers": + if part not in current: + current[part] = [{} for _ in range(max_layer_idx + 1)] + layer_idx = int(parts[i + 1]) + current = current[part][layer_idx] + i += 2 + else: + if part not in current: + current[part] = {} + current = current[part] + i += 1 + current[parts[-1]] = weight + + return nested + + +def load_llama4_model(model_id: str): + """Load Llama 4 model, tokenizer, and config.""" + print(f"Loading {model_id}...") + print("=" * 60) + + # Download + print("\n1. Downloading model...") + result = HFLoader.download(model_id) + print(f" Path: {result.model_path}") + + # Load config + print("\n2. Loading configuration...") + config_path = result.model_path / "config.json" + with open(config_path) as f: + config_data = json.load(f) + + # Handle list token IDs + for key in ("eos_token_id", "bos_token_id", "pad_token_id"): + if key in config_data and isinstance(config_data[key], list): + config_data[key] = config_data[key][0] if config_data[key] else None + + config = Llama4TextConfig(**config_data) + print(f" Hidden: {config.hidden_size}, Layers: {config.num_hidden_layers}") + print(f" Experts: {config.num_local_experts}, Active: {config.num_experts_per_tok}") + + # Create model + print("\n3. Creating model...") + model = Llama4ForCausalLM(config) + params = count_parameters(model) + print(f" {params.summary()}") + + # Load weights with MoE fusion + print("\n4. Loading weights...") + weights = load_llama4_weights(result.model_path, config) + model.update(weights) + mx.eval(model.parameters()) + print(" Done!") + + # Load tokenizer + print("\n5. Loading tokenizer...") + tokenizer = HFLoader.load_tokenizer(result.model_path) + print(f" Vocab size: {len(tokenizer)}") + + print("\n" + "=" * 60) + print("Model loaded successfully!") + + return model, tokenizer, config + + +def generate(model, tokenizer, prompt: str, config: GenerationConfig | None = None): + """Generate text from prompt.""" + if config is None: + config = GenerationConfig() + + # Tokenize + input_ids = tokenizer.encode(prompt, return_tensors="np") + input_ids = mx.array(input_ids) + input_length = input_ids.shape[1] + + # Stop tokens + stop_tokens = [] + if tokenizer.eos_token_id: + if isinstance(tokenizer.eos_token_id, list): + stop_tokens.extend(tokenizer.eos_token_id) + else: + stop_tokens.append(tokenizer.eos_token_id) + + # Generate + start = time.time() + output_ids = model.generate( + input_ids, + max_new_tokens=config.max_new_tokens, + temperature=config.temperature, + stop_tokens=stop_tokens, + ) + mx.eval(output_ids) + gen_time = time.time() - start + + # Decode + new_tokens = output_ids[0, input_length:].tolist() + text = tokenizer.decode(new_tokens, skip_special_tokens=True) + + tps = len(new_tokens) / gen_time if gen_time > 0 else 0 + return text, tps + + +def format_chat(tokenizer, user_message: str, system_message: str | None = None) -> str: + """Format using tokenizer's chat template.""" + messages = [] + if system_message: + messages.append({"role": "system", "content": system_message}) + messages.append({"role": "user", "content": user_message}) + + if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: + try: + return tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except Exception: + pass + + # Fallback + prompt = "" + if system_message: + prompt += f"System: {system_message}\n\n" + prompt += f"User: {user_message}\n\nAssistant:" + return prompt + + +def test_tiny(): + """Test with tiny config.""" + print("=" * 60) + print("Llama 4 Tiny Model Test") + print("=" * 60) + + config = Llama4TextConfig.tiny() + print(f"\nConfig:") + print(f" Hidden: {config.hidden_size}") + print(f" Layers: {config.num_hidden_layers}") + print(f" Experts: {config.num_local_experts}") + + model = Llama4ForCausalLM(config) + params = count_parameters(model) + print(f" {params.summary()}") + + # Test forward + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + mx.eval(output.logits) + print(f"\nForward: OK (shape={output.logits.shape})") + + # Test generate + gen = model.generate(input_ids, max_new_tokens=5) + mx.eval(gen) + print(f"Generate: OK (shape={gen.shape})") + + print("\n" + "=" * 60) + print("SUCCESS!") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description="Llama 4 Inference (Simplified)") + parser.add_argument( + "--model", + choices=list(MODEL_ALIASES.keys()), + default="llama4-scout", + help="Model preset", + ) + parser.add_argument("--model-id", help="Custom HuggingFace model ID") + parser.add_argument("--test-tiny", action="store_true", help="Run tiny test") + parser.add_argument("--prompt", default="What is the capital of France?") + parser.add_argument("--system", default="You are a helpful assistant.") + parser.add_argument("--max-tokens", type=int, default=100) + parser.add_argument("--temperature", type=float, default=0.7) + parser.add_argument("--list", action="store_true", help="List models") + args = parser.parse_args() + + if args.test_tiny: + test_tiny() + return + + if args.list: + print("Available Llama 4 models:\n") + for alias, model in MODEL_ALIASES.items(): + print(f" {alias:18} -> {model.value}") + return + + # Get model ID + model_id = args.model_id or MODEL_ALIASES[args.model].value + + # Load model + model, tokenizer, config = load_llama4_model(model_id) + + # Generate + print("\n" + "=" * 60) + print(f"User: {args.prompt}") + print("-" * 60) + + prompt = format_chat(tokenizer, args.prompt, args.system) + gen_config = GenerationConfig( + max_new_tokens=args.max_tokens, + temperature=args.temperature, + ) + + response, tps = generate(model, tokenizer, prompt, gen_config) + + print(f"Assistant: {response}") + print("-" * 60) + print(f"Speed: {tps:.1f} tokens/sec") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/llama_inference.py b/examples/inference/llama_inference.py new file mode 100644 index 00000000..0ea3f33e --- /dev/null +++ b/examples/inference/llama_inference.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Llama Family Inference Example (Simplified) + +Demonstrates the new simplified API for Llama-family models. +This replaces the 580+ line examples/models/llama/03_llama_family_inference.py +with a much cleaner implementation. + +Supports: +- TinyLlama (1.1B) +- SmolLM2 (135M, 360M, 1.7B) +- Llama 2/3/3.1/3.2 +- Mistral 7B + +Usage: + # Default: TinyLlama + uv run python examples/inference/llama_inference.py + + # SmolLM2 (no auth required) + uv run python examples/inference/llama_inference.py --model smollm2-360m + + # Llama 3.2 1B (requires HF auth) + uv run python examples/inference/llama_inference.py --model llama3.2-1b + + # List models + uv run python examples/inference/llama_inference.py --list +""" + +from __future__ import annotations + +import argparse +from enum import Enum + +from chuk_lazarus.inference import ( + DType, + InferencePipeline, + PipelineConfig, +) +from chuk_lazarus.models_v2 import LlamaConfig, LlamaForCausalLM + + +class LlamaModel(str, Enum): + """Available Llama-family model presets.""" + + # TinyLlama + TINYLLAMA = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + + # SmolLM2 (no auth required) + SMOLLM2_135M = "HuggingFaceTB/SmolLM2-135M-Instruct" + SMOLLM2_360M = "HuggingFaceTB/SmolLM2-360M-Instruct" + SMOLLM2_1_7B = "HuggingFaceTB/SmolLM2-1.7B-Instruct" + + # Llama 2 (requires auth) + LLAMA2_7B = "meta-llama/Llama-2-7b-chat-hf" + LLAMA2_13B = "meta-llama/Llama-2-13b-chat-hf" + + # Llama 3.2 (requires auth) + LLAMA3_2_1B = "meta-llama/Llama-3.2-1B-Instruct" + LLAMA3_2_3B = "meta-llama/Llama-3.2-3B-Instruct" + + # Llama 3.1 (requires auth) + LLAMA3_1_8B = "meta-llama/Llama-3.1-8B-Instruct" + + # Mistral + MISTRAL_7B = "mistralai/Mistral-7B-Instruct-v0.3" + + +# Short aliases for CLI +MODEL_ALIASES = { + "tinyllama": LlamaModel.TINYLLAMA, + "smollm2-135m": LlamaModel.SMOLLM2_135M, + "smollm2-360m": LlamaModel.SMOLLM2_360M, + "smollm2-1.7b": LlamaModel.SMOLLM2_1_7B, + "llama2-7b": LlamaModel.LLAMA2_7B, + "llama2-13b": LlamaModel.LLAMA2_13B, + "llama3.2-1b": LlamaModel.LLAMA3_2_1B, + "llama3.2-3b": LlamaModel.LLAMA3_2_3B, + "llama3.1-8b": LlamaModel.LLAMA3_1_8B, + "mistral-7b": LlamaModel.MISTRAL_7B, +} + + +def main(): + parser = argparse.ArgumentParser( + description="Llama Family Inference (Simplified)", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--model", + choices=list(MODEL_ALIASES.keys()), + default="tinyllama", + help="Model preset", + ) + parser.add_argument( + "--model-id", + help="Custom HuggingFace model ID (overrides --model)", + ) + parser.add_argument( + "--prompt", + default="What is the capital of France?", + help="User prompt", + ) + parser.add_argument( + "--system", + default="You are a helpful assistant.", + help="System message", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Max tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature", + ) + parser.add_argument( + "--list", + action="store_true", + help="List available models", + ) + args = parser.parse_args() + + # List mode + if args.list: + print("Available Llama-family models:\n") + for alias, model in MODEL_ALIASES.items(): + print(f" {alias:15} -> {model.value}") + return + + # Get model ID + model_id = args.model_id or MODEL_ALIASES[args.model].value + + # Load model with pipeline + pipeline = InferencePipeline.from_pretrained( + model_id, + LlamaForCausalLM, + LlamaConfig, + pipeline_config=PipelineConfig( + dtype=DType.BFLOAT16, + default_system_message=args.system, + default_max_tokens=args.max_tokens, + default_temperature=args.temperature, + ), + ) + + # Generate + print("\n" + "=" * 60) + print(f"User: {args.prompt}") + print("-" * 60) + + result = pipeline.chat(args.prompt) + + print(f"Assistant: {result.text}") + print("-" * 60) + print(f"Stats: {result.stats.summary}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/simple_inference.py b/examples/inference/simple_inference.py new file mode 100644 index 00000000..db7844d9 --- /dev/null +++ b/examples/inference/simple_inference.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +""" +Simple Inference Example + +Demonstrates the simplified inference API that works with any model family. +Compare this ~50 line example to the 400+ line model-specific examples! + +Usage: + # Default: TinyLlama + uv run python examples/inference/simple_inference.py + + # With a specific model + uv run python examples/inference/simple_inference.py --model-id "HuggingFaceTB/SmolLM2-360M-Instruct" + + # Custom prompt + uv run python examples/inference/simple_inference.py --prompt "Write a haiku about coding" +""" + +from __future__ import annotations + +import argparse + +from chuk_lazarus.inference import ( + GenerationConfig, + InferencePipeline, + PipelineConfig, + DType, +) +from chuk_lazarus.models_v2 import LlamaConfig, LlamaForCausalLM + + +def main(): + parser = argparse.ArgumentParser(description="Simple Inference Example") + parser.add_argument( + "--model-id", + default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + help="HuggingFace model ID", + ) + parser.add_argument( + "--prompt", + default="What is the capital of France?", + help="User prompt", + ) + parser.add_argument( + "--system", + default="You are a helpful assistant.", + help="System message", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature", + ) + args = parser.parse_args() + + # Configure pipeline + config = PipelineConfig( + dtype=DType.BFLOAT16, + default_system_message=args.system, + default_max_tokens=args.max_tokens, + default_temperature=args.temperature, + ) + + # Load model - ONE LINE! + pipeline = InferencePipeline.from_pretrained( + args.model_id, + LlamaForCausalLM, + LlamaConfig, + pipeline_config=config, + ) + + # Generate - ONE LINE! + print("\n" + "=" * 60) + print(f"User: {args.prompt}") + print("-" * 60) + + result = pipeline.chat(args.prompt) + + print(f"Assistant: {result.text}") + print("-" * 60) + print(f"Stats: {result.stats.summary}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma/03_gemma3_inference.py b/examples/models/gemma/03_gemma3_inference.py new file mode 100644 index 00000000..d7f3d5da --- /dev/null +++ b/examples/models/gemma/03_gemma3_inference.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python3 +""" +Gemma 3 Inference Example + +This example demonstrates how to: +1. Load a pretrained Gemma 3 model from mlx-community +2. Apply chat templates for instruction-following +3. Generate text with various sampling strategies + +Supported models from mlx-community (bf16 recommended): +- mlx-community/gemma-3-1b-it-bf16 (1B params, bfloat16) +- mlx-community/gemma-3-4b-it-bf16 (4B params, bfloat16) +- mlx-community/gemma-3-12b-it-bf16 (12B params, bfloat16) +- mlx-community/gemma-3-27b-it-bf16 (27B params, bfloat16) + +Note: 4-bit quantized models require additional quantization support. +Use bf16 models for direct loading with this implementation. + +Requirements: + pip install huggingface_hub safetensors transformers + +Usage: + python 03_gemma3_inference.py + python 03_gemma3_inference.py --model mlx-community/gemma-3-4b-it-bf16 + python 03_gemma3_inference.py --prompt "Explain quantum computing in simple terms" + +References: + - https://huggingface.co/blog/gemma3 + - https://huggingface.co/collections/mlx-community/gemma-3 +""" + +import argparse +import json +from pathlib import Path + +import mlx.core as mx +from mlx.utils import tree_unflatten + +from chuk_lazarus.models_v2.families.gemma import ( + GemmaConfig, + GemmaForCausalLM, +) + + +# Chat template for Gemma 3 instruction-tuned models +GEMMA3_CHAT_TEMPLATE = """user +{prompt} +model +""" + +# Multi-turn chat template +GEMMA3_MULTI_TURN_TEMPLATE = """{conversation}model +""" + + +def format_turn(role: str, content: str) -> str: + """Format a single conversation turn.""" + return f"{role}\n{content}\n" + + +def download_model(model_id: str) -> Path: + """Download model from HuggingFace Hub.""" + from huggingface_hub import snapshot_download + + print(f"Downloading {model_id}...") + path = snapshot_download( + repo_id=model_id, + allow_patterns=["*.json", "*.safetensors"], + ) + return Path(path) + + +def load_config_from_hf(model_path: Path, weights: dict | None = None) -> GemmaConfig: + """Load and convert HuggingFace config to GemmaConfig. + + Args: + model_path: Path to the model directory + weights: Optional pre-loaded weights dict for inferring missing config values + """ + config_path = model_path / "config.json" + with open(config_path) as f: + hf_config = json.load(f) + + # Handle multimodal models (4B+) which have nested text_config + # vs text-only models (1B) which have flat config + if "text_config" in hf_config: + # Multimodal model - extract text config + text_config = hf_config["text_config"] + # Some fields may be in the top-level config + model_type = text_config.get("model_type", hf_config.get("model_type", "gemma3_text")) + is_multimodal = True + else: + # Text-only model - use config directly + text_config = hf_config + model_type = hf_config.get("model_type", "gemma3_text") + is_multimodal = False + + hidden_size = text_config["hidden_size"] + head_dim = text_config.get("head_dim", 256) + + # Try to get num_attention_heads from config, otherwise infer from weights + num_attention_heads = text_config.get("num_attention_heads") + num_key_value_heads = text_config.get("num_key_value_heads") + + if (num_attention_heads is None or num_key_value_heads is None) and weights is not None: + # Infer from weight shapes + # q_proj shape: (num_heads * head_dim, hidden_size) + # k_proj shape: (num_kv_heads * head_dim, hidden_size) + q_proj_key = None + k_proj_key = None + for k in weights.keys(): + if "layers.0" in k and "self_attn.q_proj.weight" in k: + q_proj_key = k + if "layers.0" in k and "self_attn.k_proj.weight" in k: + k_proj_key = k + if q_proj_key and k_proj_key: + break + + if q_proj_key and num_attention_heads is None: + q_proj_shape = weights[q_proj_key].shape + num_attention_heads = q_proj_shape[0] // head_dim + print(f" Inferred num_attention_heads={num_attention_heads} from weights") + + if k_proj_key and num_key_value_heads is None: + k_proj_shape = weights[k_proj_key].shape + num_key_value_heads = k_proj_shape[0] // head_dim + print(f" Inferred num_key_value_heads={num_key_value_heads} from weights") + + # Fallback defaults if still None + if num_attention_heads is None: + num_attention_heads = hidden_size // head_dim + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + # Map HuggingFace config fields to our config + return GemmaConfig( + model_type=model_type, + vocab_size=text_config.get("vocab_size", 262144), + hidden_size=hidden_size, + num_hidden_layers=text_config["num_hidden_layers"], + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + intermediate_size=text_config["intermediate_size"], + head_dim=head_dim, + query_pre_attn_scalar=text_config.get("query_pre_attn_scalar", 256.0), + sliding_window=text_config.get("sliding_window", 512), + sliding_window_pattern=text_config.get("sliding_window_pattern", text_config.get("_sliding_window_pattern", 6)), + max_position_embeddings=text_config.get("max_position_embeddings", 32768), + rope_theta=text_config.get("rope_theta", 1000000.0), + rope_local_base_freq=text_config.get("rope_local_base_freq", 10000.0), + rms_norm_eps=text_config.get("rms_norm_eps", 1e-6), + ) + + +def load_tokenizer(model_path: Path): + """Load tokenizer from the model directory.""" + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained(str(model_path)) + + +def load_weights(model_path: Path, text_only: bool = True) -> dict: + """Load weights from safetensors files. + + Args: + model_path: Path to the model directory + text_only: If True, filter out vision-related weights (for multimodal models) + """ + weights = {} + for sf_path in model_path.glob("*.safetensors"): + print(f" Loading {sf_path.name}...") + file_weights = mx.load(str(sf_path)) + weights.update(file_weights) + + if text_only: + # Filter out vision-related weights for text-only inference + # Multimodal models (4B+) have vision_tower, multi_modal_projector, etc. + text_weights = {} + skipped = 0 + for k, v in weights.items(): + if any(prefix in k for prefix in ["vision_tower", "multi_modal_projector"]): + skipped += 1 + continue + # Also rename language_model.* to model.* for compatibility + if k.startswith("language_model."): + k = k.replace("language_model.", "", 1) + text_weights[k] = v + if skipped > 0: + print(f" Filtered out {skipped} vision-related tensors (text-only mode)") + return text_weights + + return weights + + +def load_gemma3_model(model_id: str) -> tuple[GemmaForCausalLM, any, GemmaConfig]: + """ + Load a Gemma 3 model from HuggingFace Hub. + + Args: + model_id: HuggingFace model ID (e.g., "mlx-community/gemma-3-1b-it-bf16") + + Returns: + Tuple of (model, tokenizer, config) + """ + # Download model + model_path = download_model(model_id) + + # Load weights first (needed to infer config for multimodal models) + print("Loading weights...") + weights = load_weights(model_path) + print(f" Loaded {len(weights)} tensors") + + # Load config (pass weights to infer missing values for multimodal models) + print("Loading config...") + config = load_config_from_hf(model_path, weights=weights) + print(f" Model: {config.num_hidden_layers} layers, {config.hidden_size} hidden dim") + print(f" Heads: {config.num_attention_heads} attention, {config.num_key_value_heads} kv") + + # Load tokenizer + print("Loading tokenizer...") + tokenizer = load_tokenizer(model_path) + + # Create model + print("Creating model...") + model = GemmaForCausalLM(config) + + # Use tree_unflatten to convert flat weight keys to nested structure + # This handles the conversion from "model.layers.0.self_attn.q_proj.weight" + # to the nested dict format that model.update() expects + nested_weights = tree_unflatten(list(weights.items())) + model.update(nested_weights) + + return model, tokenizer, config + + +def generate_response( + model: GemmaForCausalLM, + tokenizer, + prompt: str, + max_new_tokens: int = 256, + temperature: float = 0.7, + top_k: int | None = 40, + top_p: float | None = 0.95, + system_prompt: str | None = None, +) -> str: + """ + Generate a response to a prompt. + + Args: + model: Gemma model + tokenizer: Tokenizer + prompt: User prompt + max_new_tokens: Maximum tokens to generate + temperature: Sampling temperature (0 = greedy) + top_k: Top-k sampling parameter + top_p: Nucleus sampling parameter + system_prompt: Optional system prompt + + Returns: + Generated response text + """ + # Format prompt with chat template + if system_prompt: + formatted = f"user\n{system_prompt}\n\n{prompt}\nmodel\n" + else: + formatted = GEMMA3_CHAT_TEMPLATE.format(prompt=prompt) + + # Tokenize + input_ids = tokenizer.encode(formatted, return_tensors="np") + input_ids = mx.array(input_ids) + + # Generate + output_ids = model.generate( + input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + top_p=top_p, + stop_tokens=[tokenizer.eos_token_id, 106], # 106 is + ) + + # Decode only the generated part + generated_ids = output_ids[0, input_ids.shape[1] :].tolist() + response = tokenizer.decode(generated_ids, skip_special_tokens=True) + + return response.strip() + + +def chat_loop(model: GemmaForCausalLM, tokenizer, config: GemmaConfig): + """Interactive chat loop.""" + print("\n" + "=" * 60) + print("Gemma 3 Chat") + print("=" * 60) + print("Type 'quit' to exit, 'clear' to reset conversation") + print("-" * 60) + + conversation_history = [] + + while True: + try: + user_input = input("\nYou: ").strip() + except (KeyboardInterrupt, EOFError): + print("\nGoodbye!") + break + + if not user_input: + continue + + if user_input.lower() == "quit": + print("Goodbye!") + break + + if user_input.lower() == "clear": + conversation_history = [] + print("Conversation cleared.") + continue + + # Add user turn to history + conversation_history.append({"role": "user", "content": user_input}) + + # Build full conversation + conv_text = "" + for turn in conversation_history: + conv_text += format_turn(turn["role"], turn["content"]) + + # Format with template + full_prompt = f"{conv_text}model\n" + + # Tokenize + input_ids = tokenizer.encode(full_prompt, return_tensors="np") + input_ids = mx.array(input_ids) + + # Check context length + if input_ids.shape[1] > config.max_position_embeddings - 256: + print("Warning: Context too long, truncating history...") + conversation_history = conversation_history[-4:] + continue + + # Generate + print("\nGemma: ", end="", flush=True) + output_ids = model.generate( + input_ids, + max_new_tokens=512, + temperature=0.7, + top_k=40, + top_p=0.95, + stop_tokens=[tokenizer.eos_token_id, 106], # 106 is + ) + + # Decode response + generated_ids = output_ids[0, input_ids.shape[1] :].tolist() + response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + print(response) + + # Add assistant response to history + conversation_history.append({"role": "model", "content": response}) + + +def main(): + parser = argparse.ArgumentParser(description="Gemma 3 Inference Example") + parser.add_argument( + "--model", + type=str, + default="mlx-community/gemma-3-1b-it-bf16", + help="HuggingFace model ID (use bf16 models for direct loading)", + ) + parser.add_argument( + "--prompt", + type=str, + default=None, + help="Single prompt to generate (skips chat mode)", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=256, + help="Maximum tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature (0 = greedy)", + ) + parser.add_argument( + "--top-k", + type=int, + default=40, + help="Top-k sampling parameter", + ) + parser.add_argument( + "--top-p", + type=float, + default=0.95, + help="Nucleus sampling parameter", + ) + parser.add_argument( + "--chat", + action="store_true", + help="Start interactive chat mode", + ) + args = parser.parse_args() + + print("=" * 60) + print("Gemma 3 Inference") + print("=" * 60) + print(f"Model: {args.model}") + print("-" * 60) + + # Load model + model, tokenizer, config = load_gemma3_model(args.model) + print("\nModel loaded successfully!") + + # Evaluate model to ensure weights are loaded + mx.eval(model.parameters()) + + if args.chat: + # Interactive chat mode + chat_loop(model, tokenizer, config) + elif args.prompt: + # Single prompt mode + print(f"\nPrompt: {args.prompt}") + print("-" * 60) + + response = generate_response( + model, + tokenizer, + args.prompt, + max_new_tokens=args.max_tokens, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + ) + + print(f"\nResponse:\n{response}") + else: + # Demo mode with example prompts + print("\n" + "=" * 60) + print("Running demo prompts...") + print("=" * 60) + + demo_prompts = [ + "What is the capital of France?", + "Write a haiku about programming.", + "Explain what makes a good API design in 2-3 sentences.", + ] + + for prompt in demo_prompts: + print(f"\n{'='*60}") + print(f"Prompt: {prompt}") + print("-" * 60) + + response = generate_response( + model, + tokenizer, + prompt, + max_new_tokens=args.max_tokens, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + ) + + print(f"Response:\n{response}") + + print("\n" + "=" * 60) + print("Demo complete!") + print("Run with --chat for interactive mode or --prompt for single queries") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/models/gemma/04_gemma3_vision_inference.py b/examples/models/gemma/04_gemma3_vision_inference.py new file mode 100644 index 00000000..dd50a6ca --- /dev/null +++ b/examples/models/gemma/04_gemma3_vision_inference.py @@ -0,0 +1,857 @@ +#!/usr/bin/env python3 +""" +Gemma 3 Vision Inference Example + +Multimodal inference with Gemma 3 using the chuk_lazarus framework. +Demonstrates loading pretrained weights and running vision-language tasks. + +Architecture Overview: + Gemma 3 multimodal models consist of three main components: + + 1. SigLIP Vision Encoder (27 layers, 1152 hidden dim) + - Patches image into 14x14 patches -> 64x64 = 4096 tokens for 896x896 images + - Each patch goes through vision transformer layers + - Uses standard multi-head attention with GELU(precise) activation + - Pre-norm architecture with LayerNorm + + 2. Multi-Modal Projector + - Average pooling: 64x64 -> 16x16 (4096 -> 256 tokens) + - Gemma-style RMSNorm (1+weight scaling) + Linear projection + - Output scaled by 1/sqrt(hidden_size) to match text embedding magnitude + + 3. Gemma 3 Language Model (34 layers for 4B) + - Image embeddings replace placeholders (256 tokens) + - Alternating sliding window (1024) / global attention layers + - GQA with 8 query heads, 4 KV heads + +Supported models from mlx-community (bf16 recommended): +- mlx-community/gemma-3-4b-it-bf16 (4B params, multimodal) +- mlx-community/gemma-3-12b-it-bf16 (12B params, multimodal) +- mlx-community/gemma-3-27b-it-bf16 (27B params, multimodal) + +Note: The 1B model is text-only. Use 4B+ for vision capabilities. + +Requirements: + pip install huggingface_hub safetensors transformers pillow + +Usage: + # Basic usage with local image + python 04_gemma3_vision_inference.py --image path/to/image.jpg + + # Custom prompt + python 04_gemma3_vision_inference.py --image path/to/image.jpg --prompt "Describe this image" + + # URL-based image + python 04_gemma3_vision_inference.py --url https://example.com/image.jpg + +References: + - For text-only inference: 03_gemma3_inference.py + - mlx-vlm (alternative implementation): https://github.com/Blaizzy/mlx-vlm + - Gemma 3 blog: https://huggingface.co/blog/gemma3 +""" + +import argparse +import json +import math +from dataclasses import dataclass +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten + +from chuk_lazarus.models_v2.families.gemma import GemmaConfig, GemmaForCausalLM + + +# ============================================================================= +# Configuration +# ============================================================================= + + +@dataclass +class SigLIPVisionConfig: + """Configuration for SigLIP vision encoder.""" + + hidden_size: int = 1152 + intermediate_size: int = 4304 + num_hidden_layers: int = 27 + num_attention_heads: int = 16 + image_size: int = 896 + patch_size: int = 14 + layer_norm_eps: float = 1e-6 + + @property + def num_patches(self) -> int: + return (self.image_size // self.patch_size) ** 2 + + +@dataclass +class Gemma3VisionConfig: + """Configuration for Gemma 3 multimodal model.""" + + vision_config: SigLIPVisionConfig + text_config: GemmaConfig + mm_tokens_per_image: int = 256 # After pooling: 16x16 = 256 tokens + mm_tokens_per_side: int = 16 # Output spatial dimension + image_token_index: int = 262144 + boi_token_index: int = 255999 # Beginning of image + eoi_token_index: int = 256000 # End of image + + +# ============================================================================= +# SigLIP Vision Encoder +# ============================================================================= + + +class SigLIPMLP(nn.Module): + """MLP block for SigLIP.""" + + def __init__(self, config: SigLIPVisionConfig): + super().__init__() + # SigLIP uses GELU with precise approximation and biased linear layers + self.activation_fn = nn.GELU(approx="precise") + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True) + + def __call__(self, x: mx.array) -> mx.array: + x = self.fc1(x) + x = self.activation_fn(x) + x = self.fc2(x) + return x + + +class SigLIPAttention(nn.Module): + """Multi-head attention for SigLIP.""" + + def __init__(self, config: SigLIPVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.scale = self.head_dim**-0.5 + + # SigLIP uses biased linear layers in attention + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def __call__(self, x: mx.array) -> mx.array: + B, L, _ = x.shape + + queries = self.q_proj(x) + keys = self.k_proj(x) + values = self.v_proj(x) + + # Reshape to (B, num_heads, L, head_dim) + queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + keys = keys.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.num_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + + # Use MLX optimized scaled dot-product attention + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=None + ) + + # Reshape back + output = output.transpose(0, 2, 1, 3).reshape(B, L, self.embed_dim) + return self.out_proj(output) + + +class SigLIPEncoderLayer(nn.Module): + """Transformer encoder layer for SigLIP.""" + + def __init__(self, config: SigLIPVisionConfig): + super().__init__() + self.self_attn = SigLIPAttention(config) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SigLIPMLP(config) + + def __call__(self, x: mx.array) -> mx.array: + # Pre-norm architecture + residual = x + x = self.layer_norm1(x) + x = self.self_attn(x) + x = residual + x + + residual = x + x = self.layer_norm2(x) + x = self.mlp(x) + x = residual + x + + return x + + +class SigLIPVisionEmbeddings(nn.Module): + """Patch embedding + position embedding for SigLIP.""" + + def __init__(self, config: SigLIPVisionConfig): + super().__init__() + self.config = config + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + ) + self.position_embedding = nn.Embedding(config.num_patches, config.hidden_size) + + def __call__(self, pixel_values: mx.array) -> mx.array: + # pixel_values: (B, H, W, C) - MLX uses channels-last + B = pixel_values.shape[0] + + # Patch embedding: (B, H, W, C) -> (B, H', W', hidden_size) + patch_embeds = self.patch_embedding(pixel_values) + + # Flatten spatial dimensions: (B, H', W', hidden_size) -> (B, num_patches, hidden_size) + patch_embeds = patch_embeds.reshape(B, -1, self.config.hidden_size) + + # Add position embeddings + position_ids = mx.arange(patch_embeds.shape[1]) + embeddings = patch_embeds + self.position_embedding(position_ids) + + return embeddings + + +class SigLIPVisionEncoder(nn.Module): + """SigLIP Vision Transformer encoder.""" + + def __init__(self, config: SigLIPVisionConfig): + super().__init__() + self.layers = [SigLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)] + + def __call__(self, x: mx.array) -> mx.array: + for layer in self.layers: + x = layer(x) + return x + + +class SigLIPVisionModel(nn.Module): + """Complete SigLIP vision model.""" + + def __init__(self, config: SigLIPVisionConfig): + super().__init__() + self.config = config + self.embeddings = SigLIPVisionEmbeddings(config) + self.encoder = SigLIPVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def __call__(self, pixel_values: mx.array) -> mx.array: + """ + Args: + pixel_values: (B, H, W, C) image tensor, values in [0, 1] + + Returns: + (B, num_patches, hidden_size) vision features + """ + hidden_states = self.embeddings(pixel_values) + hidden_states = self.encoder(hidden_states) + hidden_states = self.post_layernorm(hidden_states) + return hidden_states + + +# ============================================================================= +# Multi-Modal Projector +# ============================================================================= + + +class GemmaRMSNorm(nn.Module): + """Gemma-style RMSNorm that uses (1 + weight) as scale factor. + + This matches the HuggingFace/mlx-vlm implementation where the weight + parameter is an offset from 1.0, not the direct scale. + """ + + def __init__(self, dims: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = mx.zeros((dims,)) # Offset from 1.0 + + def __call__(self, x: mx.array) -> mx.array: + return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps) + + +class MultiModalProjector(nn.Module): + """Projects vision features to language model space with pooling. + + The projector performs: + 1. Average pooling: 64x64 tokens -> 16x16 tokens (4096 -> 256) + 2. RMSNorm normalization (Gemma-style with 1+weight scaling) + 3. Linear projection to text embedding dimension + """ + + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + patches_per_image: int = 4096, # 64x64 + tokens_per_side: int = 16, # Output: 16x16 = 256 tokens + ): + super().__init__() + self.patches_per_image = patches_per_image + self.tokens_per_side = tokens_per_side + self.kernel_size = int(math.sqrt(patches_per_image)) // tokens_per_side # 64 // 16 = 4 + + # Use Gemma-style RMSNorm with (1 + weight) scaling + self.mm_soft_emb_norm = GemmaRMSNorm(vision_hidden_size) + # Note: weight shape is (vision, text) for projection + self.mm_input_projection_weight = mx.zeros((vision_hidden_size, text_hidden_size)) + + def __call__(self, image_features: mx.array) -> mx.array: + """ + Args: + image_features: (B, num_patches, vision_hidden_size) where num_patches = 4096 + + Returns: + (B, 256, text_hidden_size) - pooled and projected features + """ + batch_size, num_patches, hidden_size = image_features.shape + + # Reshape to spatial: (B, 64, 64, hidden) + spatial_size = int(math.sqrt(num_patches)) + image_features = image_features.reshape(batch_size, spatial_size, spatial_size, hidden_size) + + # Average pool: 64x64 -> 16x16 using kernel_size=4 + # MLX doesn't have AvgPool2d, so we do it manually + k = self.kernel_size + new_h = spatial_size // k + new_w = spatial_size // k + + # Reshape for pooling: (B, new_h, k, new_w, k, hidden) + image_features = image_features.reshape(batch_size, new_h, k, new_w, k, hidden_size) + # Average over the k x k windows + image_features = mx.mean(image_features, axis=(2, 4)) # (B, new_h, new_w, hidden) + + # Flatten back to sequence: (B, 256, hidden) + image_features = image_features.reshape(batch_size, new_h * new_w, hidden_size) + + # Apply soft embedding norm + image_features = self.mm_soft_emb_norm(image_features) + # Project to text space + projected = image_features @ self.mm_input_projection_weight + return projected + + +# ============================================================================= +# Gemma 3 Vision Model +# ============================================================================= + + +class VisionTower(nn.Module): + """Wrapper to match HF weight structure (vision_tower.vision_model.*).""" + + def __init__(self, config: SigLIPVisionConfig): + super().__init__() + self.vision_model = SigLIPVisionModel(config) + + def __call__(self, pixel_values: mx.array) -> mx.array: + return self.vision_model(pixel_values) + + +class Gemma3ForConditionalGeneration(nn.Module): + """Gemma 3 multimodal model for vision-language tasks.""" + + def __init__(self, config: Gemma3VisionConfig): + super().__init__() + self.config = config + + # Vision encoder (wrapped to match HF structure) + self.vision_tower = VisionTower(config.vision_config) + + # Multi-modal projector + self.multi_modal_projector = MultiModalProjector( + config.vision_config.hidden_size, config.text_config.hidden_size + ) + + # Language model + self.language_model = GemmaForCausalLM(config.text_config) + + def get_image_features(self, pixel_values: mx.array) -> mx.array: + """Extract and project image features. + + Scale by 1/sqrt(hidden_size) following mlx-vlm approach. + This ensures image features have the right magnitude when combined + with text embeddings before the backbone's sqrt(hidden_size) scaling. + """ + vision_outputs = self.vision_tower(pixel_values) + image_features = self.multi_modal_projector(vision_outputs) + + # Scale image features by 1/sqrt(hidden_size) - standard Gemma 3 multimodal scaling + hidden_size = self.config.text_config.hidden_size + image_features = image_features / (hidden_size**0.5) + + # Cast to match the typical model dtype + image_features = image_features.astype(mx.bfloat16) + + return image_features + + def __call__( + self, + input_ids: mx.array, + pixel_values: mx.array | None = None, + ) -> mx.array: + """ + Forward pass with optional image input. + + Args: + input_ids: (B, seq_len) token IDs + pixel_values: (B, H, W, C) image tensor or None + + Returns: + logits: (B, seq_len, vocab_size) + """ + # Get text embeddings + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + if pixel_values is not None: + # Get image features + image_features = self.get_image_features(pixel_values) + + # Find image token positions and replace with image features + # For simplicity, we insert image features at the beginning + # In practice, you'd find the image token positions + batch_size = input_ids.shape[0] + + # Concatenate: [image_features, text_embeddings] + inputs_embeds = mx.concatenate([image_features, inputs_embeds], axis=1) + + # Forward through language model with embeddings + output = self.language_model.model( + input_ids=None, + input_embeddings=inputs_embeds, + ) + + # Get logits + logits = self.language_model.lm_head(output.last_hidden_state) + + return logits + + def generate( + self, + input_ids: mx.array, + pixel_values: mx.array | None = None, + image_positions: list[int] | None = None, + max_new_tokens: int = 100, + temperature: float = 0.7, + top_k: int | None = 40, + top_p: float | None = 0.95, + stop_tokens: list[int] | None = None, + ) -> mx.array: + """ + Generate text given optional image input with KV-cache for efficiency. + + Args: + input_ids: Tokenized text prompt (batch_size, seq_len) + pixel_values: Preprocessed image (batch_size, H, W, C) or None + image_positions: List of positions where image soft tokens are (to be replaced) + max_new_tokens: Maximum number of tokens to generate + temperature: Sampling temperature (0 = greedy, higher = more random) + top_k: Top-k sampling (filter to top k tokens) + top_p: Nucleus sampling (filter to tokens with cumulative prob <= top_p) + stop_tokens: Token IDs that stop generation + + Returns: + Generated token IDs including the prompt tokens + """ + # Step 1: Build initial embeddings + print(" Embedding text tokens...", flush=True) + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + if pixel_values is not None and image_positions is not None and len(image_positions) > 0: + print(" Processing image through vision encoder...", flush=True) + image_features = self.get_image_features(pixel_values) + mx.eval(image_features) + print(f" Image features shape: {image_features.shape}", flush=True) + + # Replace embeddings at image positions with image features + # This follows the mlx-vlm approach: scatter image features into placeholder positions + batch_size, seq_len, embed_dim = inputs_embeds.shape + num_image_tokens = image_features.shape[1] + + if len(image_positions) == num_image_tokens: + # Replace the embeddings at image positions with scaled image features + # Use MLX operations to scatter image features at the correct positions + # The image soft tokens are contiguous starting at image_positions[0] + start_pos = image_positions[0] + end_pos = start_pos + num_image_tokens + + # Build new embeddings: [before][image_features][after] + before = inputs_embeds[:, :start_pos, :] + after = inputs_embeds[:, end_pos:, :] + inputs_embeds = mx.concatenate([before, image_features, after], axis=1) + print(f" Replaced {num_image_tokens} image soft token embeddings at positions {start_pos}-{end_pos-1}", flush=True) + else: + print(f" Warning: image_positions ({len(image_positions)}) != num_image_tokens ({num_image_tokens})", flush=True) + + print(f" Final embeddings shape: {inputs_embeds.shape}", flush=True) + + # Step 2: Process initial sequence through backbone to get KV-cache + print(" Building KV-cache (this may take a moment)...", flush=True) + output = self.language_model.model( + input_ids=None, + input_embeddings=inputs_embeds, + cache=None, # No cache on first pass + ) + cache = output.cache + mx.eval(cache) + print(" KV-cache ready!", flush=True) + + # Track generated tokens + generated_ids = list(input_ids[0].tolist()) + + # Get first token from initial pass + logits = self.language_model.lm_head(output.last_hidden_state[:, -1:, :]) + logits = logits[:, 0, :] + + print(" Generating tokens:", end=" ", flush=True) + for i in range(max_new_tokens): + # Apply temperature + if temperature > 0 and temperature != 1.0: + logits = logits / temperature + + # Apply top-k sampling + if top_k is not None and top_k > 0: + top_k_logits = mx.topk(logits, k=min(top_k, logits.shape[-1])) + threshold = top_k_logits[:, -1:] + logits = mx.where(logits < threshold, float("-inf"), logits) + + # Apply top-p (nucleus) sampling + if top_p is not None and top_p < 1.0: + sorted_logits = mx.sort(logits, axis=-1)[:, ::-1] + sorted_probs = mx.softmax(sorted_logits, axis=-1) + cumsum_probs = mx.cumsum(sorted_probs, axis=-1) + cutoff_idx = mx.sum(cumsum_probs < top_p, axis=-1, keepdims=True) + cutoff_logit = mx.take_along_axis(sorted_logits, cutoff_idx, axis=-1) + logits = mx.where(logits < cutoff_logit, float("-inf"), logits) + + # Sample or greedy decode + if temperature == 0: + next_token = mx.argmax(logits, axis=-1, keepdims=True) + else: + probs = mx.softmax(logits, axis=-1) + next_token = mx.random.categorical(mx.log(probs + 1e-10)) + next_token = mx.expand_dims(next_token, axis=-1) + + next_token_id = int(next_token[0, 0]) + generated_ids.append(next_token_id) + print(".", end="", flush=True) + + # Check stop tokens + if stop_tokens and next_token_id in stop_tokens: + break + + # Forward pass with cache - only process the new token! + # This is the key to efficiency: we reuse the cached K/V tensors + output = self.language_model.model( + input_ids=next_token, # Just the new token + cache=cache, # Reuse cached K/V from previous tokens + ) + cache = output.cache + logits = self.language_model.lm_head(output.last_hidden_state[:, -1:, :]) + logits = logits[:, 0, :] + mx.eval(logits) + + print(" done!", flush=True) + return mx.array([generated_ids]) + + +# ============================================================================= +# Model Loading +# ============================================================================= + + +def download_model(model_id: str) -> Path: + """Download model from HuggingFace Hub.""" + from huggingface_hub import snapshot_download + import sys + + print(f"Downloading {model_id}...", flush=True) + path = snapshot_download( + repo_id=model_id, + allow_patterns=["*.json", "*.safetensors"], + ) + return Path(path) + + +def load_config(model_path: Path) -> Gemma3VisionConfig: + """Load configuration from HuggingFace format.""" + with open(model_path / "config.json") as f: + hf_config = json.load(f) + + # Vision config + vc = hf_config.get("vision_config", {}) + vision_config = SigLIPVisionConfig( + hidden_size=vc.get("hidden_size", 1152), + intermediate_size=vc.get("intermediate_size", 4304), + num_hidden_layers=vc.get("num_hidden_layers", 27), + num_attention_heads=vc.get("num_attention_heads", 16), + image_size=vc.get("image_size", 896), + patch_size=vc.get("patch_size", 14), + ) + + # Text config - need to infer from weights + tc = hf_config.get("text_config", {}) + text_config = GemmaConfig( + vocab_size=tc.get("vocab_size", 262144), + hidden_size=tc.get("hidden_size", 2560), + num_hidden_layers=tc.get("num_hidden_layers", 34), + num_attention_heads=tc.get("num_attention_heads", 8), # Will be inferred + num_key_value_heads=tc.get("num_key_value_heads", 4), # Will be inferred + intermediate_size=tc.get("intermediate_size", 10240), + head_dim=tc.get("head_dim", 256), + sliding_window=tc.get("sliding_window", 1024), + ) + + return Gemma3VisionConfig( + vision_config=vision_config, + text_config=text_config, + mm_tokens_per_image=hf_config.get("mm_tokens_per_image", 256), + image_token_index=hf_config.get("image_token_index", 262144), + boi_token_index=hf_config.get("boi_token_index", 255999), + eoi_token_index=hf_config.get("eoi_token_index", 256000), + ) + + +def load_weights(model_path: Path) -> dict: + """Load weights from safetensors files.""" + weights = {} + for sf_path in model_path.glob("*.safetensors"): + print(f" Loading {sf_path.name}...", flush=True) + file_weights = mx.load(str(sf_path)) + weights.update(file_weights) + return weights + + +def infer_text_config_from_weights(weights: dict, config: Gemma3VisionConfig) -> Gemma3VisionConfig: + """Infer missing text config values from weights.""" + head_dim = config.text_config.head_dim + + # Find q_proj and k_proj to infer num_heads + for k, v in weights.items(): + if "language_model.model.layers.0.self_attn.q_proj.weight" in k: + num_attention_heads = v.shape[0] // head_dim + print(f" Inferred num_attention_heads={num_attention_heads}") + config.text_config = GemmaConfig( + vocab_size=config.text_config.vocab_size, + hidden_size=config.text_config.hidden_size, + num_hidden_layers=config.text_config.num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=config.text_config.num_key_value_heads, + intermediate_size=config.text_config.intermediate_size, + head_dim=head_dim, + sliding_window=config.text_config.sliding_window, + ) + if "language_model.model.layers.0.self_attn.k_proj.weight" in k: + num_key_value_heads = v.shape[0] // head_dim + print(f" Inferred num_key_value_heads={num_key_value_heads}") + config.text_config = GemmaConfig( + vocab_size=config.text_config.vocab_size, + hidden_size=config.text_config.hidden_size, + num_hidden_layers=config.text_config.num_hidden_layers, + num_attention_heads=config.text_config.num_attention_heads, + num_key_value_heads=num_key_value_heads, + intermediate_size=config.text_config.intermediate_size, + head_dim=head_dim, + sliding_window=config.text_config.sliding_window, + ) + break + + return config + + +def load_gemma3_vision_model(model_id: str) -> tuple[Gemma3ForConditionalGeneration, any, Gemma3VisionConfig]: + """Load Gemma 3 vision model from HuggingFace Hub.""" + from transformers import AutoTokenizer + + # Download + model_path = download_model(model_id) + + # Load weights + print("Loading weights...", flush=True) + weights = load_weights(model_path) + print(f" Loaded {len(weights)} tensors", flush=True) + + # Load config and infer missing values + print("Loading config...", flush=True) + config = load_config(model_path) + config = infer_text_config_from_weights(weights, config) + print(f" Vision: {config.vision_config.num_hidden_layers} layers, {config.vision_config.hidden_size} dim", flush=True) + print(f" Text: {config.text_config.num_hidden_layers} layers, {config.text_config.hidden_size} dim", flush=True) + + # Load tokenizer + print("Loading tokenizer...", flush=True) + tokenizer = AutoTokenizer.from_pretrained(str(model_path)) + + # Create model + print("Creating model...", flush=True) + model = Gemma3ForConditionalGeneration(config) + + # Load weights + print("Applying weights...", flush=True) + nested_weights = tree_unflatten(list(weights.items())) + model.update(nested_weights) + + return model, tokenizer, config + + +# ============================================================================= +# Image Processing +# ============================================================================= + + +def load_image(path_or_url: str, size: int = 896) -> mx.array: + """Load and preprocess an image.""" + from PIL import Image + import numpy as np + + if path_or_url.startswith(("http://", "https://")): + import urllib.request + from io import BytesIO + + with urllib.request.urlopen(path_or_url) as response: + image = Image.open(BytesIO(response.read())) + else: + image = Image.open(path_or_url) + + # Convert to RGB + image = image.convert("RGB") + + # Resize to target size + image = image.resize((size, size), Image.Resampling.BILINEAR) + + # Convert to numpy array and normalize to [0, 1] + pixel_values = np.array(image).astype(np.float32) / 255.0 + + # SigLIP normalization (mean=0.5, std=0.5 for all channels) + pixel_values = (pixel_values - 0.5) / 0.5 + + # Add batch dimension: (H, W, C) -> (1, H, W, C) + pixel_values = np.expand_dims(pixel_values, axis=0) + + return mx.array(pixel_values) + + +# ============================================================================= +# Main +# ============================================================================= + + +def main(): + parser = argparse.ArgumentParser(description="Gemma 3 Vision Inference") + parser.add_argument( + "--model", + type=str, + default="mlx-community/gemma-3-4b-it-bf16", + help="HuggingFace model ID (must be 4B+ for vision)", + ) + parser.add_argument( + "--image", + type=str, + help="Path to local image file", + ) + parser.add_argument( + "--url", + type=str, + help="URL of image to process", + ) + parser.add_argument( + "--prompt", + type=str, + default="Describe this image in detail.", + help="Prompt for image description", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=256, + help="Maximum tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature", + ) + args = parser.parse_args() + + if not args.image and not args.url: + print("Error: Please provide --image or --url") + return + + image_source = args.image or args.url + + print("=" * 60) + print("Gemma 3 Vision Inference") + print("=" * 60) + print(f"Model: {args.model}") + print(f"Image: {image_source}") + print("-" * 60) + + # Load model + model, tokenizer, config = load_gemma3_vision_model(args.model) + print("\nModel loaded successfully!") + + # Evaluate model parameters + mx.eval(model.parameters()) + + # Load and preprocess image + print(f"\nLoading image...") + pixel_values = load_image(image_source, config.vision_config.image_size) + print(f" Image shape: {pixel_values.shape}") + + # Format prompt with proper image tokens + # Gemma 3 expects: user\n[256 soft tokens]prompt\nmodel\n + # The 256 (262144) placeholders get replaced with actual image embeddings + print(f"\nPrompt: {args.prompt}") + print("-" * 60) + + # Build prompt with 256 image soft token placeholders + num_image_tokens = 256 + image_soft_token = "" + text_prompt = f"user\n{image_soft_token * num_image_tokens}{args.prompt}\nmodel\n" + input_ids = tokenizer.encode(text_prompt, return_tensors="np") + input_ids = mx.array(input_ids) + + # Find the image soft token positions (262144) + image_soft_token_id = 262144 + input_ids_list = input_ids[0].tolist() + image_positions = [i for i, tid in enumerate(input_ids_list) if tid == image_soft_token_id] + print(f" Found {len(image_positions)} image soft tokens starting at position {image_positions[0] if image_positions else 'N/A'}") + + # Generate + print("\nGenerating response...") + output_ids = model.generate( + input_ids, + pixel_values=pixel_values, + image_positions=image_positions, + max_new_tokens=args.max_tokens, + temperature=args.temperature, + top_k=40, + top_p=0.95, + stop_tokens=[tokenizer.eos_token_id, 106], # 106 is + ) + + # Decode only the newly generated tokens (skip the prompt tokens) + prompt_len = input_ids.shape[1] + generated_tokens = output_ids[0, prompt_len:].tolist() + response = tokenizer.decode(generated_tokens, skip_special_tokens=True) + + # Clean up the response + response = response.strip() + + # Remove any trailing special tokens that might remain + for suffix in ["", "", ""]: + if response.endswith(suffix): + response = response[: -len(suffix)].strip() + + print(f"\nResponse:\n{response}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/models/granite/01_granite_inference.py b/examples/models/granite/01_granite_inference.py new file mode 100644 index 00000000..ac9329a9 --- /dev/null +++ b/examples/models/granite/01_granite_inference.py @@ -0,0 +1,498 @@ +""" +IBM Granite Inference Example + +Demonstrates loading and running inference with IBM Granite models +using the models_v2 architecture. Supports: + +Granite 3.x (Dense Transformer): +- Granite 3.0 8B +- Granite 3.1 2B/8B + +Granite 4.x (Hybrid Mamba-2/Transformer): +- Granite 4.0 Micro (3B dense) +- Granite 4.0 Tiny (7B total, 1B active) - MoE +- Granite 4.0 Small (32B total, 9B active) - MoE + +Requirements: + pip install huggingface_hub transformers + +Run with: + # Test tiny config (no download needed) + uv run python examples/models/granite/01_granite_inference.py --test-tiny + + # Granite 3.1 2B (recommended for testing) + uv run python examples/models/granite/01_granite_inference.py --model granite-3.1-2b + + # Granite 4.0 Micro + uv run python examples/models/granite/01_granite_inference.py --model granite-4.0-micro + + # List available models + uv run python examples/models/granite/01_granite_inference.py --list-models +""" + +from __future__ import annotations + +import argparse +import json +import re +import time +from pathlib import Path + +import mlx.core as mx + +from chuk_lazarus.models_v2 import ( + GraniteConfig, + GraniteForCausalLM, + GraniteHybridConfig, + GraniteHybridForCausalLM, + count_parameters, +) + +# Model presets +MODEL_PRESETS = { + # Granite 3.x (Dense Transformer) + "granite-3.0-8b": { + "model_id": "ibm-granite/granite-3.0-8b-instruct", + "description": "Granite 3.0 8B - Dense transformer", + "model_type": "granite", + }, + "granite-3.1-2b": { + "model_id": "ibm-granite/granite-3.1-2b-instruct", + "description": "Granite 3.1 2B - Long context (128K)", + "model_type": "granite", + }, + "granite-3.1-8b": { + "model_id": "ibm-granite/granite-3.1-8b-instruct", + "description": "Granite 3.1 8B - Long context (128K)", + "model_type": "granite", + }, + "granite-3.3-2b": { + "model_id": "ibm-granite/granite-3.3-2b-instruct", + "description": "Granite 3.3 2B - Latest 3.x", + "model_type": "granite", + }, + "granite-3.3-8b": { + "model_id": "ibm-granite/granite-3.3-8b-instruct", + "description": "Granite 3.3 8B - Latest 3.x", + "model_type": "granite", + }, + # Granite 4.x (Hybrid Mamba-2/Transformer) + "granite-4.0-micro": { + "model_id": "ibm-granite/granite-4.0-micro", + "description": "Granite 4.0 Micro (3B) - Dense hybrid", + "model_type": "granitemoehybrid", + }, + "granite-4.0-tiny": { + "model_id": "ibm-granite/granite-4.0-tiny-preview", + "description": "Granite 4.0 Tiny (7B/1B) - MoE hybrid", + "model_type": "granitemoehybrid", + }, +} + + +def download_model(model_id: str, cache_dir: str | None = None) -> Path: + """Download model from HuggingFace Hub.""" + try: + from huggingface_hub import list_repo_files, snapshot_download + except ImportError: + raise ImportError( + "huggingface_hub not installed. Run: pip install huggingface_hub" + ) + + print(f"Downloading {model_id}...") + + try: + files = list_repo_files(model_id) + has_sharded = any("model-0" in f and f.endswith(".safetensors") for f in files) + has_consolidated = any(f == "consolidated.safetensors" for f in files) + + ignore_patterns = [] + if has_sharded and has_consolidated: + ignore_patterns.append("consolidated.safetensors") + print(" (Skipping consolidated.safetensors)") + except Exception: + ignore_patterns = [] + + path = snapshot_download( + model_id, + cache_dir=cache_dir, + allow_patterns=["*.json", "*.safetensors", "*.model", "tokenizer*"], + ignore_patterns=ignore_patterns if ignore_patterns else None, + ) + return Path(path) + + +def load_tokenizer(model_path: Path): + """Load tokenizer.""" + try: + from transformers import AutoTokenizer + except ImportError: + raise ImportError("transformers not installed. Run: pip install transformers") + + tokenizer = AutoTokenizer.from_pretrained(str(model_path)) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +def load_weights(model_path: Path, config, dtype: str = "bfloat16") -> dict: + """Load and convert weights.""" + safetensor_files = sorted(model_path.glob("*.safetensors")) + if not safetensor_files: + raise FileNotFoundError(f"No safetensors in {model_path}") + + all_weights = {} + for sf_path in safetensor_files: + print(f" Loading {sf_path.name}...") + weights = mx.load(str(sf_path)) + all_weights.update(weights) + + dtype_map = { + "float16": mx.float16, + "float32": mx.float32, + "bfloat16": mx.bfloat16, + } + target_dtype = dtype_map.get(dtype, mx.bfloat16) + + # Convert names + flat_weights: dict[str, mx.array] = {} + for hf_name, weight in all_weights.items(): + our_name = _convert_weight_name(hf_name, config) + if our_name is None: + continue + if weight.dtype in (mx.float32, mx.float16, mx.bfloat16): + weight = weight.astype(target_dtype) + flat_weights[our_name] = weight + + return _build_nested_weights(flat_weights, config) + + +def _convert_weight_name(hf_name: str, config) -> str | None: + """Convert HF weight name to our format.""" + # Embeddings + if hf_name == "model.embed_tokens.weight": + return "model.embed_tokens.weight.weight" + + # Final norm + if hf_name == "model.norm.weight": + return "model.norm.weight" + + # LM head + if hf_name == "lm_head.weight": + if config.tie_word_embeddings: + return None + return "lm_head.lm_head.weight" + + # Layer pattern + layer_match = re.match(r"model\.layers\.(\d+)\.(.*)", hf_name) + if layer_match: + layer_idx = layer_match.group(1) + rest = layer_match.group(2) + + # Skip rotary embeddings + if "rotary_emb" in rest: + return None + + return f"model.layers.{layer_idx}.{rest}" + + return None + + +def _build_nested_weights(flat_weights: dict[str, mx.array], config) -> dict: + """Build nested structure.""" + max_layer_idx = -1 + for name in flat_weights: + parts = name.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts): + try: + max_layer_idx = max(max_layer_idx, int(parts[i + 1])) + except ValueError: + pass + + nested: dict = {} + for name, weight in flat_weights.items(): + parts = name.split(".") + current = nested + i = 0 + while i < len(parts) - 1: + part = parts[i] + if part == "layers": + if part not in current: + current[part] = [{} for _ in range(max_layer_idx + 1)] + layer_idx = int(parts[i + 1]) + current = current[part][layer_idx] + i += 2 + else: + if part not in current: + current[part] = {} + current = current[part] + i += 1 + current[parts[-1]] = weight + + return nested + + +def format_chat_prompt(tokenizer, user_message: str, system_message: str | None = None) -> str: + """Format using chat template.""" + messages = [] + if system_message: + messages.append({"role": "system", "content": system_message}) + messages.append({"role": "user", "content": user_message}) + + if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: + try: + return tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except Exception: + pass + + prompt = "" + if system_message: + prompt += f"System: {system_message}\n\n" + prompt += f"User: {user_message}\n\nAssistant:" + return prompt + + +def generate_text(model, tokenizer, prompt: str, max_new_tokens: int = 100, temperature: float = 0.7, verbose: bool = True) -> str: + """Generate text.""" + input_ids = tokenizer.encode(prompt, return_tensors="np") + input_ids = mx.array(input_ids) + + if verbose: + print(f" Input tokens: {input_ids.shape[1]}") + + stop_tokens = [] + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: + if isinstance(tokenizer.eos_token_id, list): + stop_tokens.extend(tokenizer.eos_token_id) + else: + stop_tokens.append(tokenizer.eos_token_id) + + start_time = time.time() + output_ids = model.generate( + input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + stop_tokens=stop_tokens, + ) + mx.eval(output_ids) + gen_time = time.time() - start_time + + new_tokens = output_ids[0, input_ids.shape[1]:] + generated_text = tokenizer.decode(new_tokens.tolist(), skip_special_tokens=True) + + if verbose: + tokens_generated = new_tokens.shape[0] + tokens_per_sec = tokens_generated / gen_time if gen_time > 0 else 0 + print(f" Generated {tokens_generated} tokens in {gen_time:.2f}s") + print(f" Speed: {tokens_per_sec:.1f} tokens/sec") + + return generated_text + + +def test_tiny_models(): + """Test tiny configs without downloading.""" + print("=" * 60) + print("Granite Tiny Model Tests") + print("=" * 60) + + # Test Granite 3.x + print("\n1. Testing Granite 3.x (dense)...") + config3 = GraniteConfig.tiny() + model3 = GraniteForCausalLM(config3) + params3 = count_parameters(model3) + print(f" {params3.summary()}") + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output3 = model3(input_ids) + mx.eval(output3.logits) + print(f" Forward: OK (shape={output3.logits.shape})") + + gen3 = model3.generate(input_ids, max_new_tokens=5) + mx.eval(gen3) + print(f" Generate: OK (shape={gen3.shape})") + + # Test Granite 4.x dense + print("\n2. Testing Granite 4.x Hybrid (dense)...") + config4 = GraniteHybridConfig.tiny() + model4 = GraniteHybridForCausalLM(config4) + params4 = count_parameters(model4) + print(f" {params4.summary()}") + print(f" Layers: {config4.num_mamba_layers} Mamba, {config4.num_attention_layers} Attention") + + output4 = model4(input_ids) + mx.eval(output4.logits) + print(f" Forward: OK (shape={output4.logits.shape})") + + gen4 = model4.generate(input_ids, max_new_tokens=5) + mx.eval(gen4) + print(f" Generate: OK (shape={gen4.shape})") + + # Test Granite 4.x MoE + print("\n3. Testing Granite 4.x Hybrid + MoE...") + config4_moe = GraniteHybridConfig.tiny_moe() + model4_moe = GraniteHybridForCausalLM(config4_moe) + params4_moe = count_parameters(model4_moe) + print(f" {params4_moe.summary()}") + print(f" Experts: {config4_moe.num_local_experts} total, {config4_moe.num_experts_per_tok} active") + + output4_moe = model4_moe(input_ids) + mx.eval(output4_moe.logits) + print(f" Forward: OK (shape={output4_moe.logits.shape})") + + print("\n" + "=" * 60) + print("SUCCESS! All Granite tiny tests passed.") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser( + description="IBM Granite Inference Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Available models: + Granite 3.x (Dense): + granite-3.1-2b - 2B params, 128K context + granite-3.1-8b - 8B params, 128K context + granite-3.3-2b - Latest 3.x, 2B + granite-3.3-8b - Latest 3.x, 8B + + Granite 4.x (Hybrid Mamba-2/Transformer): + granite-4.0-micro - 3B dense hybrid + granite-4.0-tiny - 7B total (1B active) MoE + +Examples: + # Test tiny (no download) + python 01_granite_inference.py --test-tiny + + # Granite 3.1 2B + python 01_granite_inference.py --model granite-3.1-2b +""", + ) + parser.add_argument( + "--model", + choices=list(MODEL_PRESETS.keys()), + default="granite-3.1-2b", + help="Model preset", + ) + parser.add_argument("--model-id", default=None, help="Custom HuggingFace model ID") + parser.add_argument("--test-tiny", action="store_true", help="Run tiny tests") + parser.add_argument("--prompt", default="What is the capital of France?", help="Prompt") + parser.add_argument("--system", default="You are a helpful assistant.", help="System message") + parser.add_argument("--max-tokens", type=int, default=100, help="Max tokens") + parser.add_argument("--temperature", type=float, default=0.7, help="Temperature") + parser.add_argument("--dtype", default="bfloat16", choices=["float16", "float32", "bfloat16"]) + parser.add_argument("--cache-dir", default=None) + parser.add_argument("--list-models", action="store_true", help="List models and exit") + + args = parser.parse_args() + + if args.list_models: + print("Available Granite models:\n") + for name, info in MODEL_PRESETS.items(): + print(f" {name:20} - {info['description']}") + print(f" {info['model_id']}") + print() + return + + if args.test_tiny: + test_tiny_models() + return + + # Get model info + if args.model_id: + model_id = args.model_id + model_name = model_id.split("/")[-1] + model_type = "granite" # Default + else: + preset = MODEL_PRESETS[args.model] + model_id = preset["model_id"] + model_name = args.model + model_type = preset["model_type"] + + print("=" * 60) + print(f"Granite Inference: {model_name}") + print("=" * 60) + + # Download + print("\n1. Downloading model...") + try: + model_path = download_model(model_id, cache_dir=args.cache_dir) + except Exception as e: + print(f" Error: {e}") + return + print(f" Path: {model_path}") + + # Load config + print("\n2. Loading configuration...") + config_path = model_path / "config.json" + with open(config_path) as f: + config_data = json.load(f) + + # Handle token IDs + for key in ["eos_token_id", "bos_token_id", "pad_token_id"]: + if key in config_data and isinstance(config_data[key], list): + config_data[key] = config_data[key][0] if config_data[key] else None + + # Create config based on type + actual_model_type = config_data.get("model_type", model_type) + if actual_model_type == "granitemoehybrid": + config = GraniteHybridConfig(**config_data) + print(f" Type: Granite 4.x Hybrid") + print(f" Mamba layers: {config.num_mamba_layers}, Attention: {config.num_attention_layers}") + if config.is_moe: + print(f" MoE: {config.num_local_experts} experts, {config.num_experts_per_tok} active") + else: + config = GraniteConfig(**config_data) + print(f" Type: Granite 3.x Dense") + + print(f" Hidden: {config.hidden_size}, Layers: {config.num_hidden_layers}") + + # Create model + print("\n3. Creating model...") + if actual_model_type == "granitemoehybrid": + model = GraniteHybridForCausalLM(config) + else: + model = GraniteForCausalLM(config) + params = count_parameters(model) + print(f" {params.summary()}") + + # Load weights + print("\n4. Loading weights...") + weights = load_weights(model_path, config, dtype=args.dtype) + model.update(weights) + mx.eval(model.parameters()) + print(" Weights loaded!") + + # Load tokenizer + print("\n5. Loading tokenizer...") + tokenizer = load_tokenizer(model_path) + print(f" Vocab size: {len(tokenizer)}") + + # Generate + print("\n6. Generating...") + print("-" * 40) + prompt = format_chat_prompt(tokenizer, args.prompt, args.system) + print(f"User: {args.prompt}\n") + print("Assistant: ", end="", flush=True) + + response = generate_text( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_new_tokens=args.max_tokens, + temperature=args.temperature, + verbose=False, + ) + print(response) + print("-" * 40) + + print("\n" + "=" * 60) + print("SUCCESS! Granite inference complete.") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/03_llama_family_inference.py b/examples/models/llama/03_llama_family_inference.py index eebf6c26..47871d0f 100644 --- a/examples/models/llama/03_llama_family_inference.py +++ b/examples/models/llama/03_llama_family_inference.py @@ -229,6 +229,11 @@ def _convert_weight_name(hf_name: str, tie_word_embeddings: bool = False) -> str if layer_match: layer_idx = layer_match.group(1) rest = layer_match.group(2) + + # Skip rotary embeddings - we compute these dynamically + if "rotary_emb" in rest: + return None + return f"model.layers.{layer_idx}.{rest}" # Unknown weight - skip with warning for debugging diff --git a/examples/models/llama4/01_llama4_inference.py b/examples/models/llama4/01_llama4_inference.py new file mode 100644 index 00000000..169cba97 --- /dev/null +++ b/examples/models/llama4/01_llama4_inference.py @@ -0,0 +1,576 @@ +""" +Llama 4 Inference Example + +Demonstrates loading and running inference with Llama 4 models +using the models_v2 architecture. Supports: +- Llama 4 Scout (17B active / 109B total) +- Llama 4 Maverick (17B active / 400B total) + +Key features: +- MoE (Mixture of Experts) with shared expert +- iRoPE (interleaved RoPE and NoPE layers) +- QK normalization + +Requirements: + pip install huggingface_hub transformers + +Note: Llama 4 models are large and require significant memory: +- Scout: ~27GB for BF16 inference +- Maverick: ~100GB for BF16 inference + +Run with: + # Test with tiny config (no download needed) + uv run python examples/models/llama4/01_llama4_inference.py --test-tiny + + # Llama 4 Scout (requires HF access and ~27GB RAM) + uv run python examples/models/llama4/01_llama4_inference.py --model llama4-scout + + # Custom model + uv run python examples/models/llama4/01_llama4_inference.py --model-id "meta-llama/Llama-4-Scout-17B-16E-Instruct" +""" + +from __future__ import annotations + +import argparse +import json +import re +import time +from pathlib import Path + +import mlx.core as mx + +from chuk_lazarus.models_v2 import ( + Llama4ForCausalLM, + Llama4TextConfig, + count_parameters, + print_introspection, +) + +# Preset model configurations +MODEL_PRESETS = { + "llama4-scout": { + "model_id": "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "description": "Llama 4 Scout - 17B active / 109B total MoE", + }, + "llama4-maverick": { + "model_id": "meta-llama/Llama-4-Maverick-17B-128E-Instruct", + "description": "Llama 4 Maverick - 17B active / 400B total MoE", + }, +} + + +def download_model(model_id: str, cache_dir: str | None = None) -> Path: + """Download model from HuggingFace Hub.""" + try: + from huggingface_hub import list_repo_files, snapshot_download + except ImportError: + raise ImportError( + "huggingface_hub not installed. Run: pip install huggingface_hub" + ) + + print(f"Downloading {model_id}...") + + # Check for sharded vs consolidated files + try: + files = list_repo_files(model_id) + has_sharded = any("model-0" in f and f.endswith(".safetensors") for f in files) + has_consolidated = any(f == "consolidated.safetensors" for f in files) + + ignore_patterns = [] + if has_sharded and has_consolidated: + ignore_patterns.append("consolidated.safetensors") + print(" (Skipping consolidated.safetensors - using sharded files)") + except Exception: + ignore_patterns = [] + + path = snapshot_download( + model_id, + cache_dir=cache_dir, + allow_patterns=["*.json", "*.safetensors", "*.model", "tokenizer*"], + ignore_patterns=ignore_patterns if ignore_patterns else None, + ) + return Path(path) + + +def load_tokenizer(model_path: Path): + """Load tokenizer from model path.""" + try: + from transformers import AutoTokenizer + except ImportError: + raise ImportError("transformers not installed. Run: pip install transformers") + + tokenizer = AutoTokenizer.from_pretrained(str(model_path)) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +def load_weights( + model_path: Path, + config: Llama4TextConfig, + dtype: str = "bfloat16", +) -> dict: + """Load and convert Llama 4 weights from safetensors. + + Llama 4 has a different weight structure due to MoE: + - model.layers.{i}.feed_forward.router.weight + - model.layers.{i}.feed_forward.shared_expert.* + - model.layers.{i}.feed_forward.experts.{e}.* + + This function fuses per-expert weights into SwitchGLU format. + """ + safetensor_files = sorted(model_path.glob("*.safetensors")) + if not safetensor_files: + raise FileNotFoundError(f"No safetensors files found in {model_path}") + + # Convert to target dtype + dtype_map = { + "float16": mx.float16, + "float32": mx.float32, + "bfloat16": mx.bfloat16, + } + target_dtype = dtype_map.get(dtype, mx.bfloat16) + + # Collect expert weights per layer for fusion + # layer_idx -> proj_type -> expert_idx -> weight + expert_weights: dict[int, dict[str, dict[int, mx.array]]] = {} + flat_weights: dict[str, mx.array] = {} + + # Load and process weight files one at a time to reduce memory + for sf_path in safetensor_files: + print(f" Loading {sf_path.name}...") + weights = mx.load(str(sf_path)) + + for hf_name, weight in weights.items(): + # Convert dtype + if weight.dtype in (mx.float32, mx.float16, mx.bfloat16): + weight = weight.astype(target_dtype) + + # Check for routed expert weights - need to collect and fuse + expert_match = re.match( + r"model\.layers\.(\d+)\.(?:feed_forward|mlp)\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight", + hf_name, + ) + if expert_match: + layer_idx = int(expert_match.group(1)) + expert_idx = int(expert_match.group(2)) + proj_type = expert_match.group(3) + + if layer_idx not in expert_weights: + expert_weights[layer_idx] = {} + if proj_type not in expert_weights[layer_idx]: + expert_weights[layer_idx][proj_type] = {} + + expert_weights[layer_idx][proj_type][expert_idx] = weight + continue + + # Convert other weights normally + our_name = _convert_llama4_weight_name(hf_name, config) + if our_name is not None: + flat_weights[our_name] = weight + + # Clear loaded weights to free memory + del weights + mx.eval([]) # Trigger cleanup + + # Fuse expert weights into SwitchGLU format: (num_experts, out_dim, in_dim) + print(" Fusing expert weights...") + for layer_idx, proj_dict in expert_weights.items(): + for proj_type, experts_dict in proj_dict.items(): + # Stack expert weights in order + num_experts = len(experts_dict) + expert_list = [experts_dict[i] for i in range(num_experts)] + fused = mx.stack(expert_list, axis=0) + + # SwitchLinear expects (num_experts, output_dims, input_dims) + # HF format is (output_dims, input_dims), so stacking gives correct shape + our_name = f"model.layers.{layer_idx}.mlp.experts.{proj_type}.weight" + flat_weights[our_name] = fused + + # Clear expert weights dict + del expert_weights + mx.eval([]) + + return _build_nested_weights_v2(flat_weights, config) + + +def _convert_llama4_weight_name(hf_name: str, config: Llama4TextConfig) -> str | None: + """Convert HuggingFace Llama 4 weight name to our format. + + Note: Routed expert weights (feed_forward.experts.{i}.*) are handled separately + in load_weights() and fused into SwitchGLU format. + """ + # Embeddings + if hf_name == "model.embed_tokens.weight": + return "model.embed_tokens.weight.weight" + + # Final layer norm + if hf_name == "model.norm.weight": + return "model.norm.weight" + + # LM head + if hf_name == "lm_head.weight": + if config.tie_word_embeddings: + return None + return "lm_head.lm_head.weight" + + # Layer pattern + layer_match = re.match(r"model\.layers\.(\d+)\.(.*)", hf_name) + if layer_match: + layer_idx = layer_match.group(1) + rest = layer_match.group(2) + + # Skip rotary embeddings + if "rotary_emb" in rest: + return None + + # Skip routed expert weights - handled separately for fusion + if re.match(r"(?:feed_forward|mlp)\.experts\.\d+\.", rest): + return None + + # Attention projections (with QK norm) + if rest.startswith("self_attn."): + return f"model.layers.{layer_idx}.{rest}" + + # Layer norms + if rest in ("input_layernorm.weight", "post_attention_layernorm.weight"): + return f"model.layers.{layer_idx}.{rest}" + + # MoE components - map to our structure + if rest.startswith("feed_forward.") or rest.startswith("mlp."): + # Normalize to feed_forward prefix + if rest.startswith("mlp."): + rest = rest.replace("mlp.", "feed_forward.", 1) + + # Router + if rest == "feed_forward.router.weight": + return f"model.layers.{layer_idx}.mlp.router.weight" + + # Shared expert + if rest.startswith("feed_forward.shared_expert."): + sub = rest.replace("feed_forward.shared_expert.", "") + return f"model.layers.{layer_idx}.mlp.shared_expert.{sub}" + + # Standard MLP (non-MoE fallback) + if rest.startswith("mlp."): + return f"model.layers.{layer_idx}.{rest}" + + return f"model.layers.{layer_idx}.{rest}" + + return None + + +def _build_nested_weights_v2(flat_weights: dict[str, mx.array], config: Llama4TextConfig) -> dict: + """Build nested dict/list structure from flat weight names. + + V2: Handles fused expert weights (no per-expert indexing needed). + """ + # Find maximum layer index + max_layer_idx = -1 + + for name in flat_weights: + parts = name.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts): + try: + max_layer_idx = max(max_layer_idx, int(parts[i + 1])) + except ValueError: + pass + + # Build nested structure + nested: dict = {} + for name, weight in flat_weights.items(): + parts = name.split(".") + current = nested + + i = 0 + while i < len(parts) - 1: + part = parts[i] + + if part == "layers": + if part not in current: + current[part] = [{} for _ in range(max_layer_idx + 1)] + layer_idx = int(parts[i + 1]) + current = current[part][layer_idx] + i += 2 + else: + if part not in current: + current[part] = {} + current = current[part] + i += 1 + + current[parts[-1]] = weight + + return nested + + +def format_chat_prompt(tokenizer, user_message: str, system_message: str | None = None) -> str: + """Format a message using the tokenizer's chat template.""" + messages = [] + if system_message: + messages.append({"role": "system", "content": system_message}) + messages.append({"role": "user", "content": user_message}) + + if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: + try: + return tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + except Exception: + pass + + # Fallback + prompt = "" + if system_message: + prompt += f"System: {system_message}\n\n" + prompt += f"User: {user_message}\n\nAssistant:" + return prompt + + +def generate_text( + model: Llama4ForCausalLM, + tokenizer, + prompt: str, + max_new_tokens: int = 100, + temperature: float = 0.7, + verbose: bool = True, +) -> str: + """Generate text from the model.""" + input_ids = tokenizer.encode(prompt, return_tensors="np") + input_ids = mx.array(input_ids) + + if verbose: + print(f" Input tokens: {input_ids.shape[1]}") + + stop_tokens = [] + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: + if isinstance(tokenizer.eos_token_id, list): + stop_tokens.extend(tokenizer.eos_token_id) + else: + stop_tokens.append(tokenizer.eos_token_id) + + start_time = time.time() + output_ids = model.generate( + input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + stop_tokens=stop_tokens, + ) + mx.eval(output_ids) + gen_time = time.time() - start_time + + new_tokens = output_ids[0, input_ids.shape[1]:] + generated_text = tokenizer.decode(new_tokens.tolist(), skip_special_tokens=True) + + if verbose: + tokens_generated = new_tokens.shape[0] + tokens_per_sec = tokens_generated / gen_time if gen_time > 0 else 0 + print(f" Generated {tokens_generated} tokens in {gen_time:.2f}s") + print(f" Speed: {tokens_per_sec:.1f} tokens/sec") + + return generated_text + + +def test_tiny_model(): + """Test with a tiny Llama 4 config (no download needed).""" + print("=" * 60) + print("Llama 4 Tiny Model Test") + print("=" * 60) + + # Create tiny config + config = Llama4TextConfig.tiny() + print(f"\nConfig:") + print(f" Hidden size: {config.hidden_size}") + print(f" Layers: {config.num_hidden_layers}") + print(f" Experts: {config.num_local_experts}") + print(f" Experts per token: {config.num_experts_per_tok}") + print(f" QK norm: {config.use_qk_norm}") + + # Create model + print("\nCreating model...") + model = Llama4ForCausalLM(config) + + params = count_parameters(model) + print(f" {params.summary()}") + + # Test forward pass + print("\nTesting forward pass...") + input_ids = mx.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + output = model(input_ids) + mx.eval(output.logits) + print(f" Input shape: {input_ids.shape}") + print(f" Output logits shape: {output.logits.shape}") + + # Test generation + print("\nTesting generation...") + output_ids = model.generate( + input_ids, + max_new_tokens=10, + temperature=1.0, + ) + mx.eval(output_ids) + print(f" Generated sequence shape: {output_ids.shape}") + + print("\n" + "=" * 60) + print("SUCCESS! Llama 4 tiny model test passed.") + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser( + description="Llama 4 Inference Example", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Llama 4 Model Presets: + llama4-scout - 17B active / 109B total (16 experts) + llama4-maverick - 17B active / 400B total (128 experts) + +Examples: + # Test tiny model (no download) + python 01_llama4_inference.py --test-tiny + + # Run with Scout (requires ~27GB RAM) + python 01_llama4_inference.py --model llama4-scout +""", + ) + parser.add_argument( + "--model", + choices=list(MODEL_PRESETS.keys()), + default="llama4-scout", + help="Model preset to use", + ) + parser.add_argument( + "--model-id", + default=None, + help="Custom HuggingFace model ID", + ) + parser.add_argument( + "--test-tiny", + action="store_true", + help="Run tiny model test (no download)", + ) + parser.add_argument( + "--prompt", + default="What is the capital of France?", + help="Prompt to generate from", + ) + parser.add_argument( + "--system", + default="You are a helpful assistant.", + help="System message", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature", + ) + parser.add_argument( + "--dtype", + default="bfloat16", + choices=["float16", "float32", "bfloat16"], + help="Data type for weights", + ) + parser.add_argument( + "--cache-dir", + default=None, + help="Cache directory", + ) + + args = parser.parse_args() + + # Tiny model test + if args.test_tiny: + test_tiny_model() + return + + # Get model ID + if args.model_id: + model_id = args.model_id + model_name = model_id.split("/")[-1] + else: + preset = MODEL_PRESETS[args.model] + model_id = preset["model_id"] + model_name = args.model + + print("=" * 60) + print(f"Llama 4 Inference: {model_name}") + print("=" * 60) + + # Download + print("\n1. Downloading model...") + try: + model_path = download_model(model_id, cache_dir=args.cache_dir) + except Exception as e: + print(f" Error: {e}") + print("\n Note: Llama 4 models require HuggingFace authentication.") + print(" Run: huggingface-cli login") + return + print(f" Path: {model_path}") + + # Load config + print("\n2. Loading configuration...") + config_path = model_path / "config.json" + with open(config_path) as f: + config_data = json.load(f) + + # Handle list token IDs + for key in ["eos_token_id", "bos_token_id", "pad_token_id"]: + if key in config_data and isinstance(config_data[key], list): + config_data[key] = config_data[key][0] if config_data[key] else None + + config = Llama4TextConfig(**config_data) + print(f" Model type: {config.model_type}") + print(f" Hidden size: {config.hidden_size}") + print(f" Layers: {config.num_hidden_layers}") + print(f" Experts: {config.num_local_experts}") + + # Create model + print("\n3. Creating model...") + model = Llama4ForCausalLM(config) + params = count_parameters(model) + print(f" {params.summary()}") + + # Load weights + print("\n4. Loading weights...") + weights = load_weights(model_path, config, dtype=args.dtype) + model.update(weights) + mx.eval(model.parameters()) + print(" Weights loaded!") + + # Load tokenizer + print("\n5. Loading tokenizer...") + tokenizer = load_tokenizer(model_path) + print(f" Vocab size: {len(tokenizer)}") + + # Generate + print("\n6. Generating...") + print("-" * 40) + prompt = format_chat_prompt(tokenizer, args.prompt, args.system) + print(f"User: {args.prompt}\n") + print("Assistant: ", end="", flush=True) + + response = generate_text( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_new_tokens=args.max_tokens, + temperature=args.temperature, + verbose=False, + ) + print(response) + print("-" * 40) + + print("\n" + "=" * 60) + print("SUCCESS! Llama 4 inference complete.") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index ff659988..58565d37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "chuk-lazarus" -version = "0.2.3" +version = "0.4" description = "MLX-based LLM training and inference with hybrid RL architecture" readme = "README.md" requires-python = ">=3.10" diff --git a/src/chuk_lazarus/inference/__init__.py b/src/chuk_lazarus/inference/__init__.py index 36597498..7e40b36e 100644 --- a/src/chuk_lazarus/inference/__init__.py +++ b/src/chuk_lazarus/inference/__init__.py @@ -1,14 +1,96 @@ """ Inference and text generation utilities. -Provides: -- generate_sequence: Token-by-token generation -- generate_response: Full response generation with tokenization +Provides a high-level API for loading and running inference +with any supported model family. + +Example usage: + + from chuk_lazarus.inference import InferencePipeline + from chuk_lazarus.models_v2 import LlamaForCausalLM, LlamaConfig + + # One-liner setup + pipeline = InferencePipeline.from_pretrained( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + LlamaForCausalLM, + LlamaConfig, + ) + + # Chat + response = pipeline.chat("What is the capital of France?") + print(response.text) """ +# Core generation +# Chat utilities +from .chat import ( + ASSISTANT_SUFFIX, + ChatHistory, + ChatMessage, + FallbackTemplate, + Role, + format_chat_prompt, + format_history, +) + +# Generation utilities +from .generation import ( + GenerationConfig, + GenerationResult, + GenerationStats, + generate, + generate_stream, + get_stop_tokens, +) from .generator import generate_response, generate_sequence +# Loader utilities +from .loader import ( + DownloadConfig, + DownloadResult, + DType, + HFLoader, + LoadedWeights, + StandardWeightConverter, + WeightConverter, +) + +# High-level pipeline +from .pipeline import ( + InferencePipeline, + PipelineConfig, + PipelineState, +) + __all__ = [ + # Legacy "generate_response", "generate_sequence", + # Loader + "DownloadConfig", + "DownloadResult", + "DType", + "HFLoader", + "LoadedWeights", + "StandardWeightConverter", + "WeightConverter", + # Chat + "ASSISTANT_SUFFIX", + "ChatHistory", + "ChatMessage", + "FallbackTemplate", + "Role", + "format_chat_prompt", + "format_history", + # Generation + "GenerationConfig", + "GenerationResult", + "GenerationStats", + "generate", + "generate_stream", + "get_stop_tokens", + # Pipeline + "InferencePipeline", + "PipelineConfig", + "PipelineState", ] diff --git a/src/chuk_lazarus/inference/chat.py b/src/chuk_lazarus/inference/chat.py new file mode 100644 index 00000000..3fc48314 --- /dev/null +++ b/src/chuk_lazarus/inference/chat.py @@ -0,0 +1,168 @@ +""" +Chat formatting and message handling utilities. + +Provides typed structures for chat messages and formatting +with tokenizer chat templates. + +Design principles: +- Use enums for roles, not strings +- Pydantic models for all data structures +- No dictionary goop +""" + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING + +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + +class Role(str, Enum): + """Chat message roles.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + MODEL = "model" # Used by some models (Gemma) + + def display_name(self) -> str: + """Get display name for formatting.""" + return self.value.capitalize() + + +class ChatMessage(BaseModel): + """A single chat message.""" + + role: Role + content: str + + def to_tokenizer_format(self) -> dict[str, str]: + """Convert to dict format expected by tokenizers.""" + return {"role": self.role.value, "content": self.content} + + +class ChatHistory(BaseModel): + """Container for chat conversation history.""" + + messages: list[ChatMessage] = Field(default_factory=list) + system_message: str | None = Field(None, description="Optional system prompt") + + def add_user(self, content: str) -> ChatHistory: + """Add a user message.""" + self.messages.append(ChatMessage(role=Role.USER, content=content)) + return self + + def add_assistant(self, content: str) -> ChatHistory: + """Add an assistant message.""" + self.messages.append(ChatMessage(role=Role.ASSISTANT, content=content)) + return self + + def add_system(self, content: str) -> ChatHistory: + """Set the system message.""" + self.system_message = content + return self + + def clear(self) -> ChatHistory: + """Clear all messages but keep system prompt.""" + self.messages = [] + return self + + def to_tokenizer_format(self) -> list[dict[str, str]]: + """Convert to list of message dicts for tokenizer.""" + result = [] + if self.system_message: + result.append( + ChatMessage(role=Role.SYSTEM, content=self.system_message).to_tokenizer_format() + ) + for msg in self.messages: + result.append(msg.to_tokenizer_format()) + return result + + +class FallbackTemplate(str, Enum): + """Fallback templates when tokenizer has no chat template.""" + + SIMPLE = "simple" + CHATML = "chatml" + + +# Constants for formatting +ASSISTANT_SUFFIX = "Assistant:" +NEWLINE_DOUBLE = "\n\n" + + +def format_chat_prompt( + tokenizer: PreTrainedTokenizer, + user_message: str, + system_message: str | None = None, + add_generation_prompt: bool = True, +) -> str: + """Format a single-turn chat prompt using tokenizer's template. + + Args: + tokenizer: HuggingFace tokenizer with chat_template + user_message: The user's message + system_message: Optional system prompt + add_generation_prompt: Whether to add generation prompt suffix + + Returns: + Formatted prompt string + """ + history = ChatHistory() + if system_message: + history.add_system(system_message) + history.add_user(user_message) + + return format_history(tokenizer, history, add_generation_prompt=add_generation_prompt) + + +def format_history( + tokenizer: PreTrainedTokenizer, + history: ChatHistory, + add_generation_prompt: bool = True, +) -> str: + """Format chat history using tokenizer's template. + + Args: + tokenizer: HuggingFace tokenizer with chat_template + history: ChatHistory containing messages + add_generation_prompt: Whether to add generation prompt suffix + + Returns: + Formatted prompt string + """ + messages = history.to_tokenizer_format() + + # Try tokenizer's chat template + if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: + try: + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + except Exception: + pass + + # Fallback to simple format + return _format_simple(history) + + +def _format_simple(history: ChatHistory) -> str: + """Simple fallback format for models without chat templates.""" + parts: list[str] = [] + + if history.system_message: + system_msg = ChatMessage(role=Role.SYSTEM, content=history.system_message) + parts.append(f"{system_msg.role.display_name()}: {system_msg.content}") + + for msg in history.messages: + parts.append(f"{msg.role.display_name()}: {msg.content}") + + prompt = NEWLINE_DOUBLE.join(parts) + prompt += NEWLINE_DOUBLE + ASSISTANT_SUFFIX + return prompt diff --git a/src/chuk_lazarus/inference/generation.py b/src/chuk_lazarus/inference/generation.py new file mode 100644 index 00000000..13396f0a --- /dev/null +++ b/src/chuk_lazarus/inference/generation.py @@ -0,0 +1,207 @@ +""" +Text generation utilities with typed outputs. + +Provides high-level generation functions with proper +type safety and statistics tracking. +""" + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +import mlx.core as mx +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from chuk_lazarus.models_v2.base import CausalLMProtocol + + +class GenerationConfig(BaseModel): + """Configuration for text generation.""" + + max_new_tokens: int = Field(100, ge=1, description="Maximum tokens to generate") + temperature: float = Field(0.7, ge=0.0, description="Sampling temperature") + top_p: float = Field(0.9, ge=0.0, le=1.0, description="Nucleus sampling threshold") + top_k: int | None = Field(None, ge=1, description="Top-k sampling") + stop_tokens: list[int] = Field(default_factory=list, description="Token IDs to stop on") + + +class GenerationStats(BaseModel): + """Statistics from a generation run.""" + + input_tokens: int = Field(..., description="Number of input tokens") + output_tokens: int = Field(..., description="Number of generated tokens") + total_time_seconds: float = Field(..., description="Total generation time") + tokens_per_second: float = Field(..., description="Generation speed") + + @property + def summary(self) -> str: + """Human-readable summary.""" + return ( + f"Generated {self.output_tokens} tokens in {self.total_time_seconds:.2f}s " + f"({self.tokens_per_second:.1f} tok/s)" + ) + + +class GenerationResult(BaseModel): + """Result of text generation.""" + + text: str = Field(..., description="Generated text") + stats: GenerationStats = Field(..., description="Generation statistics") + stop_reason: str = Field("max_tokens", description="Why generation stopped") + + +def get_stop_tokens(tokenizer: PreTrainedTokenizer) -> list[int]: + """Extract stop token IDs from tokenizer. + + Args: + tokenizer: HuggingFace tokenizer + + Returns: + List of token IDs to stop on + """ + stop_tokens = [] + + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: + if isinstance(tokenizer.eos_token_id, list): + stop_tokens.extend(tokenizer.eos_token_id) + else: + stop_tokens.append(tokenizer.eos_token_id) + + return stop_tokens + + +def generate( + model: CausalLMProtocol, + tokenizer: PreTrainedTokenizer, + prompt: str, + config: GenerationConfig | None = None, +) -> GenerationResult: + """Generate text from a prompt. + + Args: + model: Causal language model with generate() method + tokenizer: Tokenizer for encoding/decoding + prompt: Input prompt text + config: Generation configuration + + Returns: + GenerationResult with text and stats + """ + if config is None: + config = GenerationConfig() + + # Encode input + input_ids = tokenizer.encode(prompt, return_tensors="np") + input_ids = mx.array(input_ids) + input_length = input_ids.shape[1] + + # Get stop tokens + stop_tokens = config.stop_tokens or get_stop_tokens(tokenizer) + + # Generate + start_time = time.time() + output_ids = model.generate( + input_ids, + max_new_tokens=config.max_new_tokens, + temperature=config.temperature, + top_p=config.top_p, + top_k=config.top_k, + stop_tokens=stop_tokens, + ) + mx.eval(output_ids) + gen_time = time.time() - start_time + + # Decode generated tokens only + new_tokens = output_ids[0, input_length:] + output_length = new_tokens.shape[0] + generated_text = tokenizer.decode(new_tokens.tolist(), skip_special_tokens=True) + + # Determine stop reason + stop_reason = "max_tokens" + if output_length < config.max_new_tokens: + if new_tokens.size > 0 and int(new_tokens[-1]) in stop_tokens: + stop_reason = "eos" + else: + stop_reason = "stop_token" + + # Build stats + stats = GenerationStats( + input_tokens=input_length, + output_tokens=output_length, + total_time_seconds=gen_time, + tokens_per_second=output_length / gen_time if gen_time > 0 else 0, + ) + + return GenerationResult( + text=generated_text, + stats=stats, + stop_reason=stop_reason, + ) + + +def generate_stream( + model: CausalLMProtocol, + tokenizer: PreTrainedTokenizer, + prompt: str, + config: GenerationConfig | None = None, +): + """Generate text with streaming output. + + Yields text chunks as they're generated. + + Args: + model: Causal language model + tokenizer: Tokenizer for encoding/decoding + prompt: Input prompt text + config: Generation configuration + + Yields: + Text chunks as they're generated + """ + if config is None: + config = GenerationConfig() + + # Encode input + input_ids = tokenizer.encode(prompt, return_tensors="np") + input_ids = mx.array(input_ids) + + # Get stop tokens + stop_tokens = set(config.stop_tokens or get_stop_tokens(tokenizer)) + + # Generate token by token + tokens: list[int] = [] + cache = None + y = input_ids + + for _ in range(config.max_new_tokens): + logits, cache = model(y, cache=cache) + if logits is None or logits.shape[1] == 0: + break + + logits = logits[:, -1, :] + + # Sample + if config.temperature == 0: + next_token = mx.argmax(logits, axis=-1) + else: + probs = mx.softmax(logits / config.temperature, axis=-1) + next_token = mx.random.categorical(probs) + + next_token_id = int(next_token.item()) + + # Check stop + if next_token_id in stop_tokens: + break + + tokens.append(next_token_id) + y = next_token[None] + + # Decode incrementally + text = tokenizer.decode(tokens, skip_special_tokens=True) + if len(text) > 0: + yield text + tokens = [] # Reset for next chunk diff --git a/src/chuk_lazarus/inference/loader.py b/src/chuk_lazarus/inference/loader.py new file mode 100644 index 00000000..81dd530c --- /dev/null +++ b/src/chuk_lazarus/inference/loader.py @@ -0,0 +1,327 @@ +""" +HuggingFace model loading utilities. + +Consolidates common patterns for downloading, loading tokenizers, +and loading weights from HuggingFace models. + +Design principles: +- Async native where applicable +- Pydantic models for configuration +- No dictionary goop - use typed structures +- No magic strings - use enums/constants +""" + +from __future__ import annotations + +import asyncio +import re +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +import mlx.core as mx +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + +class DType(str, Enum): + """Supported data types for model weights.""" + + FLOAT16 = "float16" + FLOAT32 = "float32" + BFLOAT16 = "bfloat16" + + def to_mlx(self) -> mx.Dtype: + """Convert to MLX dtype.""" + mapping = { + DType.FLOAT16: mx.float16, + DType.FLOAT32: mx.float32, + DType.BFLOAT16: mx.bfloat16, + } + return mapping[self] + + +class DownloadConfig(BaseModel): + """Configuration for model download.""" + + model_id: str = Field(..., description="HuggingFace model ID") + cache_dir: Path | None = Field(None, description="Local cache directory") + allow_patterns: list[str] = Field( + default_factory=lambda: ["*.json", "*.safetensors", "*.model", "tokenizer*"], + description="File patterns to download", + ) + prefer_sharded: bool = Field(True, description="Prefer sharded safetensors over consolidated") + + +class LoadedWeights(BaseModel): + """Container for loaded model weights with metadata.""" + + model_config = {"arbitrary_types_allowed": True} + + weights: dict[str, mx.array] = Field(..., description="Weight tensors by name") + dtype: DType = Field(DType.BFLOAT16, description="Target dtype") + source_path: Path = Field(..., description="Path weights were loaded from") + tensor_count: int = Field(..., description="Number of tensors loaded") + + @property + def layer_count(self) -> int: + """Infer number of layers from weight names.""" + max_idx = -1 + for name in self.weights: + parts = name.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts): + try: + max_idx = max(max_idx, int(parts[i + 1])) + except ValueError: + pass + return max_idx + 1 if max_idx >= 0 else 0 + + +@dataclass +class DownloadResult: + """Result of a model download operation.""" + + model_path: Path + model_id: str + is_cached: bool = False + + +@runtime_checkable +class WeightConverter(Protocol): + """Protocol for weight name converters.""" + + def convert(self, hf_name: str) -> str | None: + """Convert HuggingFace weight name to framework format. + + Returns None to skip the weight. + """ + ... + + +class StandardWeightConverter: + """Standard weight name converter for transformer models.""" + + def __init__(self, tie_word_embeddings: bool = False): + self.tie_word_embeddings = tie_word_embeddings + + def convert(self, hf_name: str) -> str | None: + """Convert HuggingFace weight name to framework format.""" + # Embeddings + if hf_name == "model.embed_tokens.weight": + return "model.embed_tokens.weight.weight" + + # Final layer norm + if hf_name == "model.norm.weight": + return "model.norm.weight" + + # LM head + if hf_name == "lm_head.weight": + if self.tie_word_embeddings: + return None + return "lm_head.lm_head.weight" + + # Layer pattern + layer_match = re.match(r"model\.layers\.(\d+)\.(.*)", hf_name) + if layer_match: + layer_idx = layer_match.group(1) + rest = layer_match.group(2) + + # Skip rotary embeddings - computed dynamically + if "rotary_emb" in rest: + return None + + return f"model.layers.{layer_idx}.{rest}" + + return None + + +class HFLoader: + """High-level loader for HuggingFace models.""" + + def __init__(self, config: DownloadConfig | None = None): + self._config = config + + @staticmethod + def download( + model_id: str, + cache_dir: Path | str | None = None, + prefer_sharded: bool = True, + ) -> DownloadResult: + """Download model from HuggingFace Hub synchronously. + + Args: + model_id: HuggingFace model ID + cache_dir: Optional cache directory + prefer_sharded: Prefer sharded over consolidated safetensors + + Returns: + DownloadResult with path and metadata + """ + try: + from huggingface_hub import list_repo_files, snapshot_download + except ImportError as err: + raise ImportError( + "huggingface_hub not installed. Run: pip install huggingface_hub" + ) from err + + print(f"Downloading {model_id}...") + + # Determine ignore patterns + ignore_patterns: list[str] = [] + if prefer_sharded: + try: + files = list_repo_files(model_id) + has_sharded = any("model-0" in f and f.endswith(".safetensors") for f in files) + has_consolidated = any(f == "consolidated.safetensors" for f in files) + + if has_sharded and has_consolidated: + ignore_patterns.append("consolidated.safetensors") + print(" (Skipping consolidated.safetensors - using sharded files)") + except Exception: + pass + + path = snapshot_download( + model_id, + cache_dir=str(cache_dir) if cache_dir else None, + allow_patterns=["*.json", "*.safetensors", "*.model", "tokenizer*"], + ignore_patterns=ignore_patterns if ignore_patterns else None, + ) + + return DownloadResult( + model_path=Path(path), + model_id=model_id, + ) + + @staticmethod + async def download_async( + model_id: str, + cache_dir: Path | str | None = None, + prefer_sharded: bool = True, + ) -> DownloadResult: + """Download model from HuggingFace Hub asynchronously. + + Runs the download in a thread pool to avoid blocking. + """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + None, lambda: HFLoader.download(model_id, cache_dir, prefer_sharded) + ) + + @staticmethod + def load_tokenizer(model_path: Path | str) -> PreTrainedTokenizer: + """Load tokenizer from model path. + + Args: + model_path: Path to model directory + + Returns: + Tokenizer with pad_token configured + """ + try: + from transformers import AutoTokenizer + except ImportError as err: + raise ImportError("transformers not installed. Run: pip install transformers") from err + + tokenizer = AutoTokenizer.from_pretrained(str(model_path)) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + return tokenizer + + @staticmethod + def load_weights( + model_path: Path, + dtype: DType = DType.BFLOAT16, + converter: WeightConverter | None = None, + ) -> LoadedWeights: + """Load weights from safetensors files. + + Args: + model_path: Path to model directory + dtype: Target dtype for weights + converter: Optional weight name converter + + Returns: + LoadedWeights container with tensors and metadata + """ + safetensor_files = sorted(model_path.glob("*.safetensors")) + if not safetensor_files: + raise FileNotFoundError(f"No safetensors files found in {model_path}") + + target_dtype = dtype.to_mlx() + + # Use default converter if none provided + if converter is None: + converter = StandardWeightConverter() + + # Load and convert weights + converted_weights: dict[str, mx.array] = {} + + for sf_path in safetensor_files: + print(f" Loading {sf_path.name}...") + raw_weights = mx.load(str(sf_path)) + + for hf_name, weight in raw_weights.items(): + # Convert name + our_name = converter.convert(hf_name) + if our_name is None: + continue + + # Convert dtype + if weight.dtype in (mx.float32, mx.float16, mx.bfloat16): + weight = weight.astype(target_dtype) + + converted_weights[our_name] = weight + + return LoadedWeights( + weights=converted_weights, + dtype=dtype, + source_path=model_path, + tensor_count=len(converted_weights), + ) + + @staticmethod + def build_nested_weights(loaded: LoadedWeights) -> dict: + """Convert flat weights to nested structure for model.update(). + + Args: + loaded: LoadedWeights from load_weights() + + Returns: + Nested dictionary structure + """ + flat_weights = loaded.weights + + # Find maximum layer index + max_layer_idx = loaded.layer_count - 1 + + # Build nested structure + nested: dict = {} + for name, weight in flat_weights.items(): + parts = name.split(".") + current = nested + + i = 0 + while i < len(parts) - 1: + part = parts[i] + + if part == "layers": + if part not in current: + current[part] = [{} for _ in range(max_layer_idx + 1)] + layer_idx = int(parts[i + 1]) + current = current[part][layer_idx] + i += 2 + else: + if part not in current: + current[part] = {} + current = current[part] + i += 1 + + current[parts[-1]] = weight + + return nested diff --git a/src/chuk_lazarus/inference/pipeline.py b/src/chuk_lazarus/inference/pipeline.py new file mode 100644 index 00000000..e3570881 --- /dev/null +++ b/src/chuk_lazarus/inference/pipeline.py @@ -0,0 +1,317 @@ +""" +High-level inference pipeline for simplified model usage. + +Provides a single-import, minimal-code API for loading and +running inference with any supported model family. + +Design principles: +- One-liner setup where possible +- Async native +- Pydantic for configuration +- No dictionary goop +""" + +from __future__ import annotations + +import asyncio +import json +from pathlib import Path +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable + +import mlx.core as mx +from pydantic import BaseModel, Field + +from .chat import ChatHistory, format_chat_prompt, format_history +from .generation import GenerationConfig, GenerationResult, generate +from .loader import DType, HFLoader, WeightConverter + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + +# Type variables for model and config +ConfigT = TypeVar("ConfigT", bound=BaseModel) +ModelT = TypeVar("ModelT") + + +@runtime_checkable +class CausalLMProtocol(Protocol): + """Protocol for causal language models.""" + + def generate( + self, + input_ids: mx.array, + max_new_tokens: int = 100, + temperature: float = 0.7, + **kwargs, + ) -> mx.array: ... + + def update(self, weights: dict) -> None: ... + + def parameters(self) -> dict: ... + + +class PipelineConfig(BaseModel): + """Configuration for the inference pipeline.""" + + dtype: DType = Field(DType.BFLOAT16, description="Weight dtype") + cache_dir: Path | None = Field(None, description="Model cache directory") + default_system_message: str | None = Field( + "You are a helpful assistant.", description="Default system prompt" + ) + default_max_tokens: int = Field(100, ge=1, description="Default max tokens") + default_temperature: float = Field(0.7, ge=0.0, description="Default temperature") + + +class PipelineState(BaseModel): + """Internal state of the pipeline.""" + + model_id: str + model_path: Path + tensor_count: int + is_loaded: bool = False + + +class InferencePipeline(Generic[ConfigT, ModelT]): + """High-level inference pipeline for any model family. + + Example usage: + + # One-liner setup + pipeline = InferencePipeline.from_pretrained( + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + LlamaForCausalLM, + LlamaConfig, + ) + + # Simple chat + response = pipeline.chat("What is the capital of France?") + print(response.text) + + # With custom settings + response = pipeline.generate( + "Write a poem about AI", + max_new_tokens=200, + temperature=0.9, + ) + """ + + def __init__( + self, + model: ModelT, + tokenizer: PreTrainedTokenizer, + config: ConfigT, + pipeline_config: PipelineConfig | None = None, + state: PipelineState | None = None, + ): + self._model = model + self._tokenizer = tokenizer + self._config = config + self._pipeline_config = pipeline_config or PipelineConfig() + self._state = state + + @property + def model(self) -> ModelT: + """Access the underlying model.""" + return self._model + + @property + def tokenizer(self) -> PreTrainedTokenizer: + """Access the tokenizer.""" + return self._tokenizer + + @property + def config(self) -> ConfigT: + """Access the model config.""" + return self._config + + @classmethod + def from_pretrained( + cls, + model_id: str, + model_class: type[ModelT], + config_class: type[ConfigT], + converter: WeightConverter | None = None, + pipeline_config: PipelineConfig | None = None, + ) -> InferencePipeline[ConfigT, ModelT]: + """Load a model from HuggingFace. + + Args: + model_id: HuggingFace model ID + model_class: Model class to instantiate + config_class: Config class for model + converter: Optional weight name converter + pipeline_config: Pipeline configuration + + Returns: + Configured InferencePipeline instance + """ + pipeline_config = pipeline_config or PipelineConfig() + + print(f"Loading {model_id}...") + print("=" * 60) + + # Download + print("\n1. Downloading model...") + result = HFLoader.download(model_id, cache_dir=pipeline_config.cache_dir) + print(f" Path: {result.model_path}") + + # Load config + print("\n2. Loading configuration...") + config = _load_config(result.model_path, config_class) + + # Create model + print("\n3. Creating model...") + model = model_class(config) + + # Load weights + print("\n4. Loading weights...") + loaded = HFLoader.load_weights( + result.model_path, + dtype=pipeline_config.dtype, + converter=converter, + ) + print(f" Loaded {loaded.tensor_count} tensors") + + # Apply weights + nested = HFLoader.build_nested_weights(loaded) + model.update(nested) + mx.eval(model.parameters()) + print(" Weights applied!") + + # Load tokenizer + print("\n5. Loading tokenizer...") + tokenizer = HFLoader.load_tokenizer(result.model_path) + print(f" Vocab size: {len(tokenizer)}") + + print("\n" + "=" * 60) + print("Model loaded successfully!") + + state = PipelineState( + model_id=model_id, + model_path=result.model_path, + tensor_count=loaded.tensor_count, + is_loaded=True, + ) + + return cls( + model=model, + tokenizer=tokenizer, + config=config, + pipeline_config=pipeline_config, + state=state, + ) + + @classmethod + async def from_pretrained_async( + cls, + model_id: str, + model_class: type[ModelT], + config_class: type[ConfigT], + converter: WeightConverter | None = None, + pipeline_config: PipelineConfig | None = None, + ) -> InferencePipeline[ConfigT, ModelT]: + """Load a model from HuggingFace asynchronously. + + Runs in a thread pool to avoid blocking. + """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + None, + lambda: cls.from_pretrained( + model_id, model_class, config_class, converter, pipeline_config + ), + ) + + def chat( + self, + user_message: str, + system_message: str | None = None, + max_new_tokens: int | None = None, + temperature: float | None = None, + ) -> GenerationResult: + """Generate a response to a chat message. + + Args: + user_message: The user's message + system_message: Optional system prompt (uses default if not provided) + max_new_tokens: Max tokens to generate (uses default if not provided) + temperature: Sampling temperature (uses default if not provided) + + Returns: + GenerationResult with text and stats + """ + system = system_message or self._pipeline_config.default_system_message + prompt = format_chat_prompt(self._tokenizer, user_message, system) + + config = GenerationConfig( + max_new_tokens=max_new_tokens or self._pipeline_config.default_max_tokens, + temperature=temperature or self._pipeline_config.default_temperature, + ) + + return generate(self._model, self._tokenizer, prompt, config) + + def chat_with_history( + self, + history: ChatHistory, + max_new_tokens: int | None = None, + temperature: float | None = None, + ) -> GenerationResult: + """Generate a response using chat history. + + Args: + history: ChatHistory with conversation + max_new_tokens: Max tokens to generate + temperature: Sampling temperature + + Returns: + GenerationResult with text and stats + """ + prompt = format_history(self._tokenizer, history) + + config = GenerationConfig( + max_new_tokens=max_new_tokens or self._pipeline_config.default_max_tokens, + temperature=temperature or self._pipeline_config.default_temperature, + ) + + return generate(self._model, self._tokenizer, prompt, config) + + def generate( + self, + prompt: str, + max_new_tokens: int | None = None, + temperature: float | None = None, + config: GenerationConfig | None = None, + ) -> GenerationResult: + """Generate text from a raw prompt. + + Args: + prompt: Input prompt (no formatting applied) + max_new_tokens: Max tokens to generate + temperature: Sampling temperature + config: Full generation config (overrides other params) + + Returns: + GenerationResult with text and stats + """ + if config is None: + config = GenerationConfig( + max_new_tokens=max_new_tokens or self._pipeline_config.default_max_tokens, + temperature=temperature or self._pipeline_config.default_temperature, + ) + + return generate(self._model, self._tokenizer, prompt, config) + + +def _load_config(model_path: Path, config_class: type[ConfigT]) -> ConfigT: + """Load and parse model config from HuggingFace format.""" + config_path = model_path / "config.json" + with open(config_path) as f: + config_data = json.load(f) + + # Handle list-valued token IDs (common in newer models) + for key in ("eos_token_id", "bos_token_id", "pad_token_id"): + if key in config_data and isinstance(config_data[key], list): + config_data[key] = config_data[key][0] if config_data[key] else None + + return config_class(**config_data) diff --git a/src/chuk_lazarus/models_v2/__init__.py b/src/chuk_lazarus/models_v2/__init__.py index 7eb4f63d..91af6a75 100644 --- a/src/chuk_lazarus/models_v2/__init__.py +++ b/src/chuk_lazarus/models_v2/__init__.py @@ -105,8 +105,15 @@ ) # Families -from .families import llama, mamba +from .families import granite, llama, llama4, mamba +from .families.granite import ( + GraniteConfig, + GraniteForCausalLM, + GraniteHybridConfig, + GraniteHybridForCausalLM, +) from .families.llama import LlamaConfig, LlamaForCausalLM +from .families.llama4 import Llama4Config, Llama4ForCausalLM, Llama4TextConfig from .families.mamba import MambaConfig, MambaForCausalLM # Heads @@ -238,10 +245,19 @@ "SequenceClassifier", "TokenClassifier", # === Families === + "granite", "llama", + "llama4", "mamba", + "GraniteConfig", + "GraniteForCausalLM", + "GraniteHybridConfig", + "GraniteHybridForCausalLM", "LlamaConfig", "LlamaForCausalLM", + "Llama4Config", + "Llama4TextConfig", + "Llama4ForCausalLM", "MambaConfig", "MambaForCausalLM", # === Loader === diff --git a/src/chuk_lazarus/models_v2/families/__init__.py b/src/chuk_lazarus/models_v2/families/__init__.py index 72aa5311..c130e145 100644 --- a/src/chuk_lazarus/models_v2/families/__init__.py +++ b/src/chuk_lazarus/models_v2/families/__init__.py @@ -9,10 +9,12 @@ Available families: - gemma: Gemma 3, FunctionGemma +- granite: IBM Granite 3.x/4.x with hybrid Mamba-2/Transformer - llama: Llama 1/2/3, Mistral, and compatible models +- llama4: Llama 4 with MoE and multimodal support - mamba: Mamba SSM models """ -from . import gemma, llama, mamba +from . import gemma, granite, llama, llama4, mamba -__all__ = ["gemma", "llama", "mamba"] +__all__ = ["gemma", "granite", "llama", "llama4", "mamba"] diff --git a/src/chuk_lazarus/models_v2/families/gemma/model.py b/src/chuk_lazarus/models_v2/families/gemma/model.py index 9739a245..ad299c77 100644 --- a/src/chuk_lazarus/models_v2/families/gemma/model.py +++ b/src/chuk_lazarus/models_v2/families/gemma/model.py @@ -81,7 +81,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int): self.n_rep = self.num_heads // self.num_kv_heads # Attention scale using query_pre_attn_scalar - self.scale = config.query_pre_attn_scalar ** -0.5 + self.scale = config.query_pre_attn_scalar**-0.5 # Projections (no bias for Gemma) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) @@ -166,9 +166,7 @@ def clip_residual(x: mx.array, y: mx.array) -> mx.array: if x.dtype != mx.float16: return x + y bound = mx.finfo(mx.float16).max - return mx.clip( - x.astype(mx.float32) + y.astype(mx.float32), -bound, bound - ).astype(mx.float16) + return mx.clip(x.astype(mx.float32) + y.astype(mx.float32), -bound, bound).astype(mx.float16) class GemmaBlock(Block): @@ -203,6 +201,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int): @property def block_type(self): from ...core.enums import BlockType + return BlockType.TRANSFORMER @property @@ -249,10 +248,7 @@ def __init__(self, config: GemmaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) # Transformer blocks - self.layers = [ - GemmaBlock(config, layer_idx=i) - for i in range(config.num_hidden_layers) - ] + self.layers = [GemmaBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)] # Final norm self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -276,15 +272,7 @@ def _create_attention_mask( window_size: int | None = None, ) -> mx.array: """Create causal attention mask, optionally with sliding window.""" - batch_size, seq_len, _ = h.shape - - # Get total sequence length including cache - if cache is not None and cache[0] is not None: - cache_len = cache[0][0].shape[2] - total_len = cache_len + seq_len - else: - cache_len = 0 - total_len = seq_len + _, seq_len, _ = h.shape # Create causal mask mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len) @@ -318,7 +306,7 @@ def __call__( h = self.embed_tokens(input_ids) # Scale embeddings by sqrt(hidden_size) - Gemma specific - h = h * mx.array(self.config.hidden_size ** 0.5, dtype=mx.bfloat16).astype(h.dtype) + h = h * mx.array(self.config.hidden_size**0.5, dtype=mx.bfloat16).astype(h.dtype) # Initialize cache if needed if cache is None: @@ -327,7 +315,9 @@ def __call__( # Create masks for global and sliding window layers # Global layers get a reference cache for mask creation global_layer_idx = self.sliding_window_pattern - 1 - global_mask = self._create_attention_mask(h, [cache[global_layer_idx]] if cache[global_layer_idx] else None) + global_mask = self._create_attention_mask( + h, [cache[global_layer_idx]] if cache[global_layer_idx] else None + ) if self.sliding_window_pattern > 1: sliding_mask = self._create_attention_mask( @@ -499,7 +489,7 @@ def generate( # Apply top-k if top_k is not None and top_k > 0: - top_k_logits, _ = mx.topk(logits, k=min(top_k, logits.shape[-1])) + top_k_logits = mx.topk(logits, k=min(top_k, logits.shape[-1])) min_val = top_k_logits[:, -1:] logits = mx.where(logits < min_val, float("-inf"), logits) diff --git a/src/chuk_lazarus/models_v2/families/granite/__init__.py b/src/chuk_lazarus/models_v2/families/granite/__init__.py new file mode 100644 index 00000000..907a59cf --- /dev/null +++ b/src/chuk_lazarus/models_v2/families/granite/__init__.py @@ -0,0 +1,55 @@ +""" +IBM Granite model family. + +Supports: +- Granite 3.0/3.1: Dense transformer with multipliers +- Granite 4.0: Hybrid Mamba-2/Transformer with optional MoE + +Key features: +- Embedding/attention/residual multipliers +- Logits scaling +- Mamba-2 blocks for efficient long-context +- Fine-grained MoE with shared experts + +Reference: https://huggingface.co/ibm-granite +""" + +from .config import GraniteConfig, GraniteHybridConfig +from .hybrid import ( + Granite4, + GraniteHybrid, + GraniteHybridAttention, + GraniteHybridBlock, + GraniteHybridForCausalLM, + GraniteHybridModel, + GraniteHybridMoE, + GraniteMamba2Block, +) +from .model import ( + Granite, + GraniteAttention, + GraniteBlock, + GraniteForCausalLM, + GraniteModel, +) + +__all__ = [ + # Config + "GraniteConfig", + "GraniteHybridConfig", + # Granite 3.x + "Granite", + "GraniteForCausalLM", + "GraniteModel", + "GraniteBlock", + "GraniteAttention", + # Granite 4.x + "Granite4", + "GraniteHybrid", + "GraniteHybridForCausalLM", + "GraniteHybridModel", + "GraniteHybridBlock", + "GraniteHybridAttention", + "GraniteHybridMoE", + "GraniteMamba2Block", +] diff --git a/src/chuk_lazarus/models_v2/families/granite/config.py b/src/chuk_lazarus/models_v2/families/granite/config.py new file mode 100644 index 00000000..5fb208da --- /dev/null +++ b/src/chuk_lazarus/models_v2/families/granite/config.py @@ -0,0 +1,408 @@ +""" +Granite configuration. + +Supports the IBM Granite model family: +- Granite 3.0/3.1: Dense transformer with multipliers +- Granite 4.0: Hybrid Mamba-2/Transformer with optional MoE + +Reference: https://huggingface.co/ibm-granite +""" + +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import Field + +from ...core.config import ModelConfig + + +class GraniteConfig(ModelConfig): + """ + Configuration for Granite 3.x models. + + Granite 3.x is a dense transformer similar to Llama but with: + - Multipliers for embeddings, attention, and residuals + - Logits scaling + - Optional attention dropout + + Example: + >>> # Granite 3.1 2B + >>> config = GraniteConfig( + ... vocab_size=49155, + ... hidden_size=2048, + ... num_hidden_layers=40, + ... num_attention_heads=32, + ... num_key_value_heads=8, + ... intermediate_size=8192, + ... ) + """ + + model_type: str = "granite" + + # Granite-specific multipliers + embedding_multiplier: float = Field( + default=12.0, + description="Multiplier applied to token embeddings", + ) + attention_multiplier: float = Field( + default=1.0, + description="Multiplier applied to attention output (1/sqrt(num_heads) typical)", + ) + residual_multiplier: float = Field( + default=1.0, + description="Multiplier applied to residual connections", + ) + logits_scaling: float = Field( + default=1.0, + description="Scaling factor for output logits", + ) + + # Standard transformer settings + hidden_act: str = "silu" + rope_theta: float = 10000.0 + rms_norm_eps: float = 1e-5 + + # Attention settings + attention_dropout: float = 0.0 + attention_bias: bool = False + mlp_bias: bool = False + + # RoPE scaling + rope_scaling: dict[str, Any] | None = None + + @classmethod + def granite_3_8b(cls) -> GraniteConfig: + """Granite 3.0 8B configuration.""" + return cls( + vocab_size=49155, + hidden_size=4096, + num_hidden_layers=40, + num_attention_heads=32, + num_key_value_heads=8, + intermediate_size=12800, + max_position_embeddings=4096, + embedding_multiplier=12.0, + attention_multiplier=0.0078125, # 1/128 + residual_multiplier=0.22, + logits_scaling=16.0, + attention_dropout=0.1, + tie_word_embeddings=True, + ) + + @classmethod + def granite_3_1_2b(cls) -> GraniteConfig: + """Granite 3.1 2B configuration.""" + return cls( + vocab_size=49155, + hidden_size=2048, + num_hidden_layers=40, + num_attention_heads=32, + num_key_value_heads=8, + intermediate_size=8192, + max_position_embeddings=131072, + rope_theta=5000000.0, + embedding_multiplier=12.0, + attention_multiplier=0.015625, # 1/64 + residual_multiplier=0.22, + logits_scaling=8.0, + attention_dropout=0.1, + tie_word_embeddings=True, + ) + + @classmethod + def granite_3_1_8b(cls) -> GraniteConfig: + """Granite 3.1 8B configuration.""" + return cls( + vocab_size=49155, + hidden_size=4096, + num_hidden_layers=40, + num_attention_heads=32, + num_key_value_heads=8, + intermediate_size=12800, + max_position_embeddings=131072, + rope_theta=5000000.0, + embedding_multiplier=12.0, + attention_multiplier=0.0078125, + residual_multiplier=0.22, + logits_scaling=16.0, + attention_dropout=0.1, + tie_word_embeddings=True, + ) + + @classmethod + def tiny(cls) -> GraniteConfig: + """Tiny Granite for testing.""" + return cls( + vocab_size=1000, + hidden_size=64, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=128, + max_position_embeddings=256, + embedding_multiplier=1.0, + attention_multiplier=1.0, + residual_multiplier=1.0, + logits_scaling=1.0, + ) + + +class GraniteHybridConfig(ModelConfig): + """ + Configuration for Granite 4.x hybrid models. + + Granite 4.0 uses a hybrid Mamba-2/Transformer architecture with: + - Per-layer type configuration (mamba or attention) + - Optional MoE for Tiny and Small variants + - Shared experts for MoE variants + - 9:1 Mamba to Transformer ratio (typical) + + Example: + >>> # Granite 4.0 Tiny (7B total, 1B active) + >>> config = GraniteHybridConfig( + ... vocab_size=49160, + ... hidden_size=1536, + ... num_hidden_layers=40, + ... layer_types=["mamba"]*5 + ["attention"] + ["mamba"]*9 + ["attention"] + ..., + ... num_local_experts=62, + ... num_experts_per_tok=6, + ... ) + """ + + model_type: str = "granitemoehybrid" + + # Layer configuration + layer_types: list[Literal["mamba", "attention"]] = Field( + default_factory=lambda: ["attention"] * 40, + description="Type of each layer (mamba or attention)", + ) + position_embedding_type: Literal["rope", "nope"] = Field( + default="rope", + description="Position embedding type (nope for Mamba-heavy models)", + ) + + # Granite-specific multipliers + embedding_multiplier: float = 12.0 + attention_multiplier: float = 0.0078125 + residual_multiplier: float = 0.22 + logits_scaling: float = 6.0 + + # Standard settings + hidden_act: str = "silu" + rope_theta: float = 10000.0 + rms_norm_eps: float = 1e-5 + normalization_function: str = "rmsnorm" + + # Attention settings + attention_dropout: float = 0.0 + attention_bias: bool = False + + # Mamba-2 settings + mamba_d_state: int = Field(default=128, description="SSM state dimension") + mamba_d_conv: int = Field(default=4, description="Convolution kernel size") + mamba_expand: int = Field(default=2, description="Expansion factor for Mamba") + mamba_n_heads: int = Field(default=48, description="Number of Mamba heads") + mamba_d_head: int = Field(default=64, description="Dimension per Mamba head") + mamba_n_groups: int = Field(default=1, description="Number of groups for Mamba") + mamba_chunk_size: int = Field(default=256, description="Chunk size for Mamba-2") + mamba_conv_bias: bool = True + mamba_proj_bias: bool = False + + # MoE settings (optional) + num_local_experts: int = Field(default=0, description="Number of experts (0 = dense)") + num_experts_per_tok: int = Field(default=0, description="Experts per token (0 = dense)") + shared_intermediate_size: int = Field( + default=0, + description="Shared expert intermediate size (0 = no shared expert)", + ) + router_aux_loss_coef: float = 0.0 + output_router_logits: bool = False + + # RoPE scaling + rope_scaling: dict[str, Any] | None = None + + @property + def is_moe(self) -> bool: + """Whether this is a MoE model.""" + return self.num_local_experts > 0 and self.num_experts_per_tok > 0 + + @property + def num_mamba_layers(self) -> int: + """Count of Mamba layers.""" + return sum(1 for t in self.layer_types if t == "mamba") + + @property + def num_attention_layers(self) -> int: + """Count of attention layers.""" + return sum(1 for t in self.layer_types if t == "attention") + + @classmethod + def granite_4_micro(cls) -> GraniteHybridConfig: + """ + Granite 4.0 Micro (3B dense hybrid). + + All attention layers, no MoE. + """ + return cls( + vocab_size=100352, + hidden_size=2560, + num_hidden_layers=40, + num_attention_heads=40, + num_key_value_heads=8, + intermediate_size=8192, + max_position_embeddings=131072, + layer_types=["attention"] * 40, + position_embedding_type="rope", + rope_theta=10000000.0, + embedding_multiplier=12.0, + attention_multiplier=0.015625, + residual_multiplier=0.22, + logits_scaling=10.0, + # Mamba settings (not used but present in config) + mamba_d_state=256, + mamba_n_heads=128, + mamba_d_head=40, + # Dense model + num_local_experts=0, + num_experts_per_tok=0, + shared_intermediate_size=8192, + tie_word_embeddings=True, + ) + + @classmethod + def granite_4_tiny(cls) -> GraniteHybridConfig: + """ + Granite 4.0 Tiny (7B total, 1B active). + + Hybrid Mamba-2/Transformer with MoE. + - 36 Mamba layers, 4 Attention layers + - 62 experts, 6 active per token + """ + # Build layer_types: 9 mamba, 1 attention pattern + layer_types: list[Literal["mamba", "attention"]] = [] + for i in range(4): # 4 attention layers total + layer_types.extend(["mamba"] * 5 if i == 0 else ["mamba"] * 9) + layer_types.append("attention") + + return cls( + vocab_size=49160, + hidden_size=1536, + num_hidden_layers=40, + num_attention_heads=12, + num_key_value_heads=4, + intermediate_size=512, + max_position_embeddings=131072, + layer_types=layer_types, + position_embedding_type="nope", + embedding_multiplier=12.0, + attention_multiplier=0.0078125, + residual_multiplier=0.22, + logits_scaling=6.0, + # Mamba-2 settings + mamba_d_state=128, + mamba_n_heads=48, + mamba_d_head=64, + mamba_expand=2, + mamba_chunk_size=256, + # MoE settings + num_local_experts=62, + num_experts_per_tok=6, + shared_intermediate_size=1024, + tie_word_embeddings=True, + ) + + @classmethod + def granite_4_small(cls) -> GraniteHybridConfig: + """ + Granite 4.0 Small (32B total, 9B active). + + Hybrid Mamba-2/Transformer with MoE. + """ + # Similar pattern to Tiny + layer_types: list[Literal["mamba", "attention"]] = [] + for i in range(4): + layer_types.extend(["mamba"] * 5 if i == 0 else ["mamba"] * 9) + layer_types.append("attention") + + return cls( + vocab_size=49160, + hidden_size=3072, + num_hidden_layers=40, + num_attention_heads=24, + num_key_value_heads=8, + intermediate_size=1024, + max_position_embeddings=131072, + layer_types=layer_types, + position_embedding_type="nope", + embedding_multiplier=12.0, + attention_multiplier=0.0078125, + residual_multiplier=0.22, + logits_scaling=8.0, + # Mamba-2 settings + mamba_d_state=128, + mamba_n_heads=96, + mamba_d_head=64, + mamba_expand=2, + mamba_chunk_size=256, + # MoE settings + num_local_experts=62, + num_experts_per_tok=6, + shared_intermediate_size=2048, + tie_word_embeddings=True, + ) + + @classmethod + def tiny(cls) -> GraniteHybridConfig: + """Tiny hybrid for testing.""" + return cls( + vocab_size=1000, + hidden_size=64, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=128, + max_position_embeddings=256, + layer_types=["mamba", "mamba", "attention", "mamba"], + position_embedding_type="rope", + embedding_multiplier=1.0, + attention_multiplier=1.0, + residual_multiplier=1.0, + logits_scaling=1.0, + # Mamba settings + mamba_d_state=16, + mamba_n_heads=4, + mamba_d_head=16, + mamba_expand=2, + # Dense (no MoE for tiny) + num_local_experts=0, + num_experts_per_tok=0, + ) + + @classmethod + def tiny_moe(cls) -> GraniteHybridConfig: + """Tiny hybrid with MoE for testing.""" + return cls( + vocab_size=1000, + hidden_size=64, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=32, + max_position_embeddings=256, + layer_types=["mamba", "mamba", "attention", "mamba"], + position_embedding_type="rope", + embedding_multiplier=1.0, + attention_multiplier=1.0, + residual_multiplier=1.0, + logits_scaling=1.0, + # Mamba settings + mamba_d_state=16, + mamba_n_heads=4, + mamba_d_head=16, + mamba_expand=2, + # MoE settings + num_local_experts=4, + num_experts_per_tok=2, + shared_intermediate_size=64, + ) diff --git a/src/chuk_lazarus/models_v2/families/granite/hybrid.py b/src/chuk_lazarus/models_v2/families/granite/hybrid.py new file mode 100644 index 00000000..4a145c5c --- /dev/null +++ b/src/chuk_lazarus/models_v2/families/granite/hybrid.py @@ -0,0 +1,680 @@ +""" +Granite 4.x hybrid model implementation. + +Hybrid Mamba-2/Transformer architecture with optional MoE. + +Reference: https://www.ibm.com/granite/docs/models/granite +""" + +from __future__ import annotations + +from typing import Any + +import mlx.core as mx +import mlx.nn as nn + +from ...backbones.base import Backbone, BackboneOutput +from ...blocks.base import Block, BlockOutput +from ...components.embeddings import create_token_embedding +from ...components.ffn import SwiGLU +from ...components.normalization import RMSNorm +from ...core.config import FFNConfig +from ...core.registry import register_model +from ...heads import LMHead +from ...models.base import Model, ModelOutput +from .config import GraniteHybridConfig + + +class GraniteMamba2Block(nn.Module): + """ + Mamba-2 block for Granite 4.x. + + Simplified Mamba-2 implementation using the selective scan mechanism. + """ + + def __init__(self, config: GraniteHybridConfig, layer_idx: int = 0): + super().__init__() + + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + # Mamba-2 parameters + self.d_state = config.mamba_d_state + self.d_conv = config.mamba_d_conv + self.expand = config.mamba_expand + self.n_heads = config.mamba_n_heads + self.d_head = config.mamba_d_head + + # Expanded dimension - ensure it's divisible for groups + self.d_inner = self.expand * self.hidden_size + + # Input projection (x -> x, z for gating) + self.in_proj = nn.Linear( + self.hidden_size, + self.d_inner * 2, + bias=config.mamba_proj_bias, + ) + + # Convolution - use groups=1 for depthwise to work with any size + # For true depthwise conv, d_inner must be divisible + self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + kernel_size=self.d_conv, + groups=1, # Standard conv instead of depthwise for simplicity + bias=config.mamba_conv_bias, + padding=self.d_conv - 1, + ) + + # SSM projections + # For Mamba-2, we use a simpler multi-head structure + self.dt_proj = nn.Linear(self.d_inner, self.n_heads, bias=True) + self.A = mx.ones((self.n_heads,)) * -1.0 # Learnable decay + self.D = mx.ones((self.n_heads,)) # Skip connection + + # B and C projections + self.B_proj = nn.Linear(self.d_inner, self.n_heads * self.d_state, bias=False) + self.C_proj = nn.Linear(self.d_inner, self.n_heads * self.d_state, bias=False) + + # Output projection + self.out_proj = nn.Linear(self.d_inner, self.hidden_size, bias=config.mamba_proj_bias) + + # Norm + self.norm = RMSNorm(self.d_inner, eps=config.rms_norm_eps) + + def __call__( + self, + x: mx.array, + cache: dict[str, mx.array] | None = None, + ) -> tuple[mx.array, dict[str, mx.array] | None]: + """Forward pass through Mamba-2 block.""" + batch_size, seq_len, _ = x.shape + + # Input projection + xz = self.in_proj(x) + x_proj, z = mx.split(xz, 2, axis=-1) + + # Convolution (causal) - MLX Conv1d expects (B, L, C) format + # Apply 1D conv along sequence dimension + # Use manual sliding window for simplicity + x_conv = x_proj + if self.d_conv > 1: + # Pad for causal conv + padding = mx.zeros((batch_size, self.d_conv - 1, self.d_inner)) + x_padded = mx.concatenate([padding, x_proj], axis=1) + # Simple causal conv via linear combination of shifted inputs + conv_out = mx.zeros_like(x_proj) + for i in range(self.d_conv): + shift = self.d_conv - 1 - i + conv_out = conv_out + x_padded[:, shift : shift + seq_len, :] + x_conv = conv_out / self.d_conv # Normalize + + # Apply SiLU + x_conv = nn.silu(x_conv) + + # SSM parameters + dt = nn.softplus(self.dt_proj(x_conv)) # (B, L, n_heads) + B = self.B_proj(x_conv).reshape(batch_size, seq_len, self.n_heads, self.d_state) + C = self.C_proj(x_conv).reshape(batch_size, seq_len, self.n_heads, self.d_state) + + # Reshape x for multi-head + x_heads = x_conv.reshape(batch_size, seq_len, self.n_heads, -1) # (B, L, H, D/H) + + # Selective scan (simplified) + # For each position, compute: h_t = A * h_{t-1} + B * x_t, y_t = C * h_t + y = self._selective_scan(x_heads, dt, B, C, cache) + + # Reshape back + y = y.reshape(batch_size, seq_len, self.d_inner) + + # Normalize + y = self.norm(y) + + # Gate with z + y = y * nn.silu(z) + + # Output projection + y = self.out_proj(y) + + # Update cache + new_cache = None # Simplified - no cache for now + + return y, new_cache + + def _selective_scan( + self, + x: mx.array, + dt: mx.array, + B: mx.array, + C: mx.array, + cache: dict[str, mx.array] | None = None, + ) -> mx.array: + """ + Simplified selective scan. + + For full efficiency, this should use the chunked algorithm from Mamba-2. + This is a reference implementation. + """ + batch_size, seq_len, n_heads, d_per_head = x.shape + + # Initialize state + h = mx.zeros((batch_size, n_heads, self.d_state)) + + outputs = [] + for t in range(seq_len): + # Get inputs at time t + x_t = x[:, t, :, :] # (B, H, D/H) + dt_t = dt[:, t, :] # (B, H) + B_t = B[:, t, :, :] # (B, H, N) + C_t = C[:, t, :, :] # (B, H, N) + + # Discretize A + A_bar = mx.exp(self.A * dt_t) # (B, H) + + # Update state: h = A_bar * h + B * x + # Simplified: use mean of x across d_per_head + x_mean = mx.mean(x_t, axis=-1, keepdims=True) # (B, H, 1) + h = A_bar[:, :, None] * h + B_t * x_mean + + # Output: y = C * h + D * x + y_t = mx.sum(C_t * h, axis=-1) # (B, H) + y_t = y_t + self.D * mx.mean(x_t, axis=-1) + + # Expand back to d_per_head + y_t = mx.expand_dims(y_t, axis=-1) + y_t = mx.broadcast_to(y_t, (batch_size, n_heads, d_per_head)) + + outputs.append(y_t) + + # Stack outputs + y = mx.stack(outputs, axis=1) # (B, L, H, D/H) + + return y + + +class GraniteHybridAttention(nn.Module): + """ + Attention block for Granite 4.x hybrid. + + Similar to standard GQA but with Granite-specific multipliers. + """ + + def __init__(self, config: GraniteHybridConfig, layer_idx: int = 0): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads or config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.layer_idx = layer_idx + + self.attention_multiplier = config.attention_multiplier + self.n_rep = self.num_heads // self.num_kv_heads + + # Projections + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias + ) + + # RoPE (only if using rope position embeddings) + self.use_rope = config.position_embedding_type == "rope" + if self.use_rope: + from ...components.embeddings.rope import RoPE + from ...core.config import RoPEConfig + + rope_config = RoPEConfig( + theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + ) + self.rope = RoPE(rope_config, dims=self.head_dim) + + self.scale = self.head_dim**-0.5 + + def __call__( + self, + x: mx.array, + mask: mx.array | None = None, + cache: tuple[mx.array, mx.array] | None = None, + ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]: + """Forward pass.""" + batch_size, seq_len, _ = x.shape + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + # RoPE + if self.use_rope: + offset = 0 + if cache is not None: + offset = cache[0].shape[2] + q = self.rope(q, offset=offset) + k = self.rope(k, offset=offset) + + # Cache + if cache is not None: + key_cache, value_cache = cache + k = mx.concatenate([key_cache, k], axis=2) + v = mx.concatenate([value_cache, v], axis=2) + + new_cache = (k, v) + + # Repeat KV + if self.n_rep > 1: + batch, num_kv_heads, kv_seq_len, head_dim = k.shape + k = mx.expand_dims(k, axis=2) + k = mx.repeat(k, self.n_rep, axis=2) + k = k.reshape(batch, num_kv_heads * self.n_rep, kv_seq_len, head_dim) + v = mx.expand_dims(v, axis=2) + v = mx.repeat(v, self.n_rep, axis=2) + v = v.reshape(batch, num_kv_heads * self.n_rep, kv_seq_len, head_dim) + + # Attention + output = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask) + output = output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1) + + # Output with multiplier + output = self.o_proj(output) + output = output * self.attention_multiplier + + return output, new_cache + + +class GraniteHybridMoE(nn.Module): + """ + MoE layer for Granite 4.x with shared expert. + + Similar to Llama 4 MoE but with Granite-specific settings. + """ + + def __init__(self, config: GraniteHybridConfig): + super().__init__() + + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.shared_intermediate_size = config.shared_intermediate_size + + # Router + self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) + + # Shared expert (if configured) + self.has_shared_expert = config.shared_intermediate_size > 0 + if self.has_shared_expert: + shared_config = FFNConfig( + hidden_size=config.hidden_size, + intermediate_size=config.shared_intermediate_size, + ) + self.shared_expert = SwiGLU(shared_config) + + # Routed experts + expert_config = FFNConfig( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + ) + self.experts = [SwiGLU(expert_config) for _ in range(config.num_local_experts)] + + def __call__(self, x: mx.array) -> mx.array: + """Forward pass through MoE.""" + batch_size, seq_len, hidden_size = x.shape + + # Shared expert output + if self.has_shared_expert: + shared_output = self.shared_expert(x) + else: + shared_output = mx.zeros_like(x) + + # Router + router_logits = self.router(x) + router_scores = mx.sigmoid(router_logits) + + # Top-k selection + sorted_indices = mx.argsort(-router_logits, axis=-1) + top_k_indices = sorted_indices[:, :, : self.num_experts_per_tok] + top_k_scores = mx.take_along_axis(router_scores, top_k_indices, axis=-1) + + # Compute routed outputs + x_flat = x.reshape(-1, hidden_size) + indices_flat = top_k_indices.reshape(-1, self.num_experts_per_tok) + scores_flat = top_k_scores.reshape(-1, self.num_experts_per_tok) + + routed_output = mx.zeros_like(x_flat) + + for expert_idx, expert in enumerate(self.experts): + expert_mask = indices_flat == expert_idx + expert_weights = mx.sum( + scores_flat * expert_mask.astype(scores_flat.dtype), axis=-1, keepdims=True + ) + if mx.any(expert_weights > 0): + expert_out = expert(x_flat) + routed_output = routed_output + expert_out * expert_weights + + routed_output = routed_output.reshape(batch_size, seq_len, hidden_size) + + return shared_output + routed_output + + +class GraniteHybridBlock(Block): + """ + Granite 4.x hybrid block. + + Can be either a Mamba-2 block or an attention block based on layer_type. + """ + + def __init__( + self, + config: GraniteHybridConfig, + layer_idx: int = 0, + layer_type: str = "attention", + ): + super().__init__() + + self._hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.layer_type = layer_type + self.residual_multiplier = config.residual_multiplier + + # Pre-block norm + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Block type + if layer_type == "mamba": + self.block = GraniteMamba2Block(config, layer_idx=layer_idx) + else: + self.block = GraniteHybridAttention(config, layer_idx=layer_idx) + + # Post-block norm + self.post_block_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # FFN (dense or MoE) + if config.is_moe: + self.mlp = GraniteHybridMoE(config) + else: + ffn_config = FFNConfig( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size + if config.shared_intermediate_size == 0 + else config.shared_intermediate_size, + ) + self.mlp = SwiGLU(ffn_config) + + @property + def block_type(self): + from ...core.enums import BlockType + + if self.layer_type == "mamba": + return BlockType.MAMBA + return BlockType.TRANSFORMER + + @property + def hidden_size(self) -> int: + return self._hidden_size + + def __call__( + self, + x: mx.array, + mask: mx.array | None = None, + cache: Any | None = None, + ) -> BlockOutput: + """Forward pass.""" + # Block (Mamba or Attention) + residual = x + x = self.input_layernorm(x) + + if self.layer_type == "mamba": + x, new_cache = self.block(x, cache=cache) + else: + x, new_cache = self.block(x, mask=mask, cache=cache) + + x = residual + x * self.residual_multiplier + + # FFN + residual = x + x = self.post_block_layernorm(x) + x = self.mlp(x) + x = residual + x * self.residual_multiplier + + return BlockOutput(hidden_states=x, cache=new_cache) + + +class GraniteHybridModel(Backbone): + """ + Granite 4.x hybrid backbone. + + Interleaved Mamba-2 and Transformer blocks with optional MoE. + """ + + def __init__(self, config: GraniteHybridConfig): + super().__init__() + + self.config = config + self._vocab_size = config.vocab_size + self._hidden_size = config.hidden_size + self._num_layers = config.num_hidden_layers + self.embedding_multiplier = config.embedding_multiplier + + # Token embeddings + self.embed_tokens = create_token_embedding( + vocab_size=config.vocab_size, + hidden_size=config.hidden_size, + ) + + # Hybrid blocks + self.layers = [ + GraniteHybridBlock( + config, + layer_idx=i, + layer_type=config.layer_types[i] if i < len(config.layer_types) else "attention", + ) + for i in range(config.num_hidden_layers) + ] + + # Final norm + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @property + def hidden_size(self) -> int: + return self._hidden_size + + @property + def num_layers(self) -> int: + return self._num_layers + + @property + def vocab_size(self) -> int: + return self._vocab_size + + def __call__( + self, + input_ids: mx.array, + attention_mask: mx.array | None = None, + cache: list[Any] | None = None, + output_hidden_states: bool = False, + ) -> BackboneOutput: + """Forward pass.""" + batch_size, seq_len = input_ids.shape + + # Embeddings with multiplier + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states * self.embedding_multiplier + + # Create causal mask (for attention layers) + if attention_mask is None: + mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len) + mask = mask.astype(hidden_states.dtype) + else: + mask = attention_mask + + # Track hidden states + all_hidden_states = (hidden_states,) if output_hidden_states else None + new_cache = [] + + # Process layers + for i, layer in enumerate(self.layers): + layer_cache = cache[i] if cache else None + output = layer(hidden_states, mask=mask, cache=layer_cache) + hidden_states = output.hidden_states + new_cache.append(output.cache) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final norm + hidden_states = self.norm(hidden_states) + + return BackboneOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + cache=new_cache, + ) + + def get_input_embeddings(self) -> nn.Module: + return self.embed_tokens + + def set_input_embeddings(self, embeddings: nn.Module) -> None: + self.embed_tokens = embeddings + + +@register_model( + model_type="granitemoehybrid", + architectures=["GraniteMoeHybridForCausalLM"], +) +class GraniteHybridForCausalLM(Model): + """ + Granite 4.x hybrid for causal language modeling. + + Supports: + - Dense hybrid (Micro): All attention layers + - MoE hybrid (Tiny, Small): Mixed Mamba-2 + Attention with MoE + """ + + def __init__(self, config: GraniteHybridConfig): + super().__init__() + + self._config = config + self.logits_scaling = config.logits_scaling + + # Backbone + self.model = GraniteHybridModel(config) + + # LM head + if config.tie_word_embeddings: + self.lm_head = LMHead( + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + tied_embeddings=self.model.embed_tokens, + ) + else: + self.lm_head = LMHead( + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + ) + + @property + def config(self) -> GraniteHybridConfig: + return self._config + + @property + def backbone(self) -> nn.Module: + return self.model + + def __call__( + self, + input_ids: mx.array, + attention_mask: mx.array | None = None, + labels: mx.array | None = None, + cache: list[Any] | None = None, + output_hidden_states: bool = False, + ) -> ModelOutput: + """Forward pass with logits scaling.""" + backbone_output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + cache=cache, + output_hidden_states=output_hidden_states, + ) + + head_output = self.lm_head( + hidden_states=backbone_output.last_hidden_state, + labels=labels, + ) + + logits = head_output.logits + if self.logits_scaling != 1.0: + logits = logits / self.logits_scaling + + return ModelOutput( + loss=head_output.loss, + logits=logits, + hidden_states=backbone_output.hidden_states, + cache=backbone_output.cache, + ) + + def generate( + self, + input_ids: mx.array, + max_new_tokens: int = 100, + temperature: float = 1.0, + top_k: int | None = None, + stop_tokens: list[int] | None = None, + ) -> mx.array: + """Generate text autoregressively.""" + stop_tokens_set = set(stop_tokens or []) + + output = self(input_ids) + mx.eval(output.logits) + cache = output.cache + + generated_tokens = [input_ids] + + for _ in range(max_new_tokens): + logits = output.logits[:, -1, :] + + if temperature != 1.0: + logits = logits / temperature + + if top_k is not None and top_k > 0: + top_k_values = mx.topk(logits, k=min(top_k, logits.shape[-1])) + min_val = top_k_values[:, -1:] + logits = mx.where(logits < min_val, float("-inf"), logits) + + probs = mx.softmax(logits, axis=-1) + next_token = mx.random.categorical(mx.log(probs + 1e-10)) + next_token = mx.expand_dims(next_token, axis=-1) + + mx.eval(next_token) + generated_tokens.append(next_token) + + next_token_val = int(next_token[0, 0]) + if next_token_val in stop_tokens_set: + break + + output = self(next_token, cache=cache) + mx.eval(output.logits) + cache = output.cache + + return mx.concatenate(generated_tokens, axis=1) + + @classmethod + def from_config(cls, config: GraniteHybridConfig) -> GraniteHybridForCausalLM: + """Create from config.""" + return cls(config) + + +# Convenience aliases +GraniteHybrid = GraniteHybridForCausalLM +Granite4 = GraniteHybridForCausalLM diff --git a/src/chuk_lazarus/models_v2/families/granite/model.py b/src/chuk_lazarus/models_v2/families/granite/model.py new file mode 100644 index 00000000..30757a38 --- /dev/null +++ b/src/chuk_lazarus/models_v2/families/granite/model.py @@ -0,0 +1,453 @@ +""" +Granite model implementation. + +Supports: +- Granite 3.x: Dense transformer with multipliers +- Granite 4.x: Hybrid Mamba-2/Transformer with optional MoE + +Reference: https://huggingface.co/ibm-granite +""" + +from __future__ import annotations + +from typing import Any + +import mlx.core as mx +import mlx.nn as nn + +from ...backbones.base import Backbone, BackboneOutput +from ...blocks.base import Block, BlockOutput +from ...components.embeddings import create_token_embedding +from ...components.ffn import SwiGLU +from ...components.normalization import RMSNorm +from ...core.config import FFNConfig +from ...core.registry import register_model +from ...heads import LMHead +from ...models.base import Model, ModelOutput +from .config import GraniteConfig + + +class GraniteAttention(nn.Module): + """ + Granite attention with attention multiplier. + + Similar to GQA but with configurable multiplier on output. + """ + + def __init__(self, config: GraniteConfig, layer_idx: int = 0): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads or config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.layer_idx = layer_idx + + # Attention multiplier + self.attention_multiplier = config.attention_multiplier + + # Number of query heads per KV head + self.n_rep = self.num_heads // self.num_kv_heads + + # Projections + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias + ) + + # RoPE + from ...components.embeddings.rope import RoPE + from ...core.config import RoPEConfig + + rope_config = RoPEConfig( + theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + ) + self.rope = RoPE(rope_config, dims=self.head_dim) + + # Attention scaling + self.scale = self.head_dim**-0.5 + + # Dropout + self.dropout_rate = config.attention_dropout + + def __call__( + self, + x: mx.array, + mask: mx.array | None = None, + cache: tuple[mx.array, mx.array] | None = None, + ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]: + """Compute attention with multiplier.""" + batch_size, seq_len, _ = x.shape + + # Project + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + # Reshape + q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + # RoPE + offset = 0 + if cache is not None: + offset = cache[0].shape[2] + + q = self.rope(q, offset=offset) + k = self.rope(k, offset=offset) + + # Update cache + if cache is not None: + key_cache, value_cache = cache + k = mx.concatenate([key_cache, k], axis=2) + v = mx.concatenate([value_cache, v], axis=2) + + new_cache = (k, v) + + # Repeat KV + if self.n_rep > 1: + k = self._repeat_kv(k, self.n_rep) + v = self._repeat_kv(v, self.n_rep) + + # Attention + output = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask) + + # Reshape + output = output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1) + + # Output projection with multiplier + output = self.o_proj(output) + output = output * self.attention_multiplier + + return output, new_cache + + def _repeat_kv(self, x: mx.array, n_rep: int) -> mx.array: + """Repeat KV heads.""" + if n_rep == 1: + return x + batch, num_kv_heads, seq_len, head_dim = x.shape + x = mx.expand_dims(x, axis=2) + x = mx.repeat(x, n_rep, axis=2) + x = x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim) + return x + + +class GraniteBlock(Block): + """ + Granite transformer block. + + Pre-norm transformer with: + - RMSNorm + - Granite attention (with attention multiplier) + - SwiGLU FFN + - Residual multiplier + """ + + def __init__(self, config: GraniteConfig, layer_idx: int = 0): + super().__init__() + + self._hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.residual_multiplier = config.residual_multiplier + + # Pre-attention norm + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Attention + self.self_attn = GraniteAttention(config, layer_idx=layer_idx) + + # Post-attention norm + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # FFN + ffn_config = FFNConfig( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + bias=config.mlp_bias, + ) + self.mlp = SwiGLU(ffn_config) + + @property + def block_type(self): + from ...core.enums import BlockType + + return BlockType.TRANSFORMER + + @property + def hidden_size(self) -> int: + return self._hidden_size + + def __call__( + self, + x: mx.array, + mask: mx.array | None = None, + cache: tuple[mx.array, mx.array] | None = None, + ) -> BlockOutput: + """Forward pass with residual multiplier.""" + # Self-attention with residual + residual = x + x = self.input_layernorm(x) + x, new_cache = self.self_attn(x, mask=mask, cache=cache) + x = residual + x * self.residual_multiplier + + # FFN with residual + residual = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + x = residual + x * self.residual_multiplier + + return BlockOutput(hidden_states=x, cache=new_cache) + + +class GraniteModel(Backbone): + """ + Granite backbone (without LM head). + + Token embeddings with multiplier + transformer blocks + final norm. + """ + + def __init__(self, config: GraniteConfig): + super().__init__() + + self.config = config + self._vocab_size = config.vocab_size + self._hidden_size = config.hidden_size + self._num_layers = config.num_hidden_layers + self.embedding_multiplier = config.embedding_multiplier + + # Token embeddings + self.embed_tokens = create_token_embedding( + vocab_size=config.vocab_size, + hidden_size=config.hidden_size, + ) + + # Transformer blocks + self.layers = [GraniteBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)] + + # Final norm + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @property + def hidden_size(self) -> int: + return self._hidden_size + + @property + def num_layers(self) -> int: + return self._num_layers + + @property + def vocab_size(self) -> int: + return self._vocab_size + + def __call__( + self, + input_ids: mx.array, + attention_mask: mx.array | None = None, + cache: list[Any] | None = None, + output_hidden_states: bool = False, + ) -> BackboneOutput: + """Forward pass.""" + batch_size, seq_len = input_ids.shape + + # Embeddings with multiplier + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states * self.embedding_multiplier + + # Create causal mask + if attention_mask is None: + mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len) + mask = mask.astype(hidden_states.dtype) + else: + mask = attention_mask + + # Track hidden states + all_hidden_states = (hidden_states,) if output_hidden_states else None + new_cache = [] + + # Process layers + for i, layer in enumerate(self.layers): + layer_cache = cache[i] if cache else None + output = layer(hidden_states, mask=mask, cache=layer_cache) + hidden_states = output.hidden_states + new_cache.append(output.cache) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final norm + hidden_states = self.norm(hidden_states) + + return BackboneOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + cache=new_cache, + ) + + def get_input_embeddings(self) -> nn.Module: + return self.embed_tokens + + def set_input_embeddings(self, embeddings: nn.Module) -> None: + self.embed_tokens = embeddings + + +@register_model( + model_type="granite", + architectures=["GraniteForCausalLM"], +) +class GraniteForCausalLM(Model): + """ + Granite for causal language modeling. + + Complete model with backbone + LM head + logits scaling. + """ + + def __init__(self, config: GraniteConfig): + super().__init__() + + self._config = config + self.logits_scaling = config.logits_scaling + + # Backbone + self.model = GraniteModel(config) + + # LM head + if config.tie_word_embeddings: + self.lm_head = LMHead( + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + tied_embeddings=self.model.embed_tokens, + ) + else: + self.lm_head = LMHead( + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + ) + + @property + def config(self) -> GraniteConfig: + return self._config + + @property + def backbone(self) -> nn.Module: + return self.model + + def __call__( + self, + input_ids: mx.array, + attention_mask: mx.array | None = None, + labels: mx.array | None = None, + cache: list[Any] | None = None, + output_hidden_states: bool = False, + ) -> ModelOutput: + """Forward pass with logits scaling.""" + # Backbone + backbone_output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + cache=cache, + output_hidden_states=output_hidden_states, + ) + + # LM head + head_output = self.lm_head( + hidden_states=backbone_output.last_hidden_state, + labels=labels, + ) + + # Apply logits scaling + logits = head_output.logits + if self.logits_scaling != 1.0: + logits = logits / self.logits_scaling + + return ModelOutput( + loss=head_output.loss, + logits=logits, + hidden_states=backbone_output.hidden_states, + cache=backbone_output.cache, + ) + + def generate( + self, + input_ids: mx.array, + max_new_tokens: int = 100, + temperature: float = 1.0, + top_k: int | None = None, + top_p: float | None = None, + repetition_penalty: float = 1.0, + stop_tokens: list[int] | None = None, + ) -> mx.array: + """Generate text autoregressively.""" + stop_tokens_set = set(stop_tokens or []) + + # Process prompt + output = self(input_ids) + mx.eval(output.logits) + cache = output.cache + + # Track generated tokens + generated_tokens = [input_ids] + + for _ in range(max_new_tokens): + logits = output.logits[:, -1, :] + + # Apply repetition penalty + if repetition_penalty != 1.0: + all_tokens = mx.concatenate(generated_tokens, axis=1) + unique_tokens = set(all_tokens.flatten().tolist()) + vocab_size = logits.shape[-1] + token_indices = mx.array([t for t in unique_tokens if t < vocab_size]) + if token_indices.size > 0: + mask = mx.zeros((vocab_size,)) + for tok in token_indices.tolist(): + mask = mask.at[tok].add(1.0) + penalty_mask = mx.where(mask > 0, repetition_penalty, 1.0) + logits = logits / penalty_mask + + # Apply temperature + if temperature != 1.0: + logits = logits / temperature + + # Apply top-k + if top_k is not None and top_k > 0: + top_k_values = mx.topk(logits, k=min(top_k, logits.shape[-1])) + min_val = top_k_values[:, -1:] + logits = mx.where(logits < min_val, float("-inf"), logits) + + # Sample + probs = mx.softmax(logits, axis=-1) + next_token = mx.random.categorical(mx.log(probs + 1e-10)) + next_token = mx.expand_dims(next_token, axis=-1) + + mx.eval(next_token) + generated_tokens.append(next_token) + + # Check stop + next_token_val = int(next_token[0, 0]) + if next_token_val in stop_tokens_set: + break + + # Forward with cache + output = self(next_token, cache=cache) + mx.eval(output.logits) + cache = output.cache + + return mx.concatenate(generated_tokens, axis=1) + + @classmethod + def from_config(cls, config: GraniteConfig) -> GraniteForCausalLM: + """Create from config.""" + return cls(config) + + +# Convenience alias +Granite = GraniteForCausalLM diff --git a/src/chuk_lazarus/models_v2/families/llama4/__init__.py b/src/chuk_lazarus/models_v2/families/llama4/__init__.py new file mode 100644 index 00000000..8a7ce995 --- /dev/null +++ b/src/chuk_lazarus/models_v2/families/llama4/__init__.py @@ -0,0 +1,41 @@ +""" +Llama 4 model family. + +Supports: +- Llama 4 Scout (17B active / 109B total) +- Llama 4 Maverick (17B active / 400B total) +- Multimodal variants with vision encoder + +Key features: +- MoE (Mixture of Experts) with shared expert +- iRoPE (interleaved RoPE and NoPE layers) +- QK normalization +- Native multimodal support + +Reference: https://llama.meta.com/llama4/ +""" + +from .attention import Llama4Attention, Llama4FlexAttention, create_llama4_attention +from .config import Llama4Config, Llama4TextConfig, Llama4VisionConfig +from .model import Llama4, Llama4Block, Llama4ForCausalLM, Llama4Model +from .moe import Llama4MLP, Llama4MoE, create_llama4_moe + +__all__ = [ + # Config + "Llama4Config", + "Llama4TextConfig", + "Llama4VisionConfig", + # Model + "Llama4", + "Llama4ForCausalLM", + "Llama4Model", + "Llama4Block", + # Components + "Llama4Attention", + "Llama4FlexAttention", + "Llama4MoE", + "Llama4MLP", + # Factories + "create_llama4_attention", + "create_llama4_moe", +] diff --git a/src/chuk_lazarus/models_v2/families/llama4/attention.py b/src/chuk_lazarus/models_v2/families/llama4/attention.py new file mode 100644 index 00000000..339e7ffa --- /dev/null +++ b/src/chuk_lazarus/models_v2/families/llama4/attention.py @@ -0,0 +1,215 @@ +""" +Llama 4 Attention. + +Extends GQA with Llama 4-specific features: +- QK normalization (RMS norm on Q and K after RoPE) +- iRoPE: Interleaved RoPE and NoPE (global) layers +- Temperature scaling for long sequences + +Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4 +""" + +from __future__ import annotations + +import mlx.core as mx +import mlx.nn as nn + +from ...components.embeddings.rope import RoPE +from ...components.normalization import RMSNorm +from ...core.config import RoPEConfig +from .config import Llama4TextConfig + + +class Llama4Attention(nn.Module): + """ + Llama 4 attention with QK normalization and iRoPE support. + + Key features: + - QK normalization: RMS norm applied to Q and K after projection + - iRoPE: Interleaved RoPE (chunked attention) and NoPE (global) layers + - Temperature scaling for attention logits + + Args: + config: Llama 4 text configuration + layer_idx: Layer index (determines if this is a NoPE layer) + + Example: + >>> config = Llama4TextConfig.tiny() + >>> attn = Llama4Attention(config, layer_idx=0) # NoPE layer + >>> x = mx.random.normal((2, 10, 64)) + >>> output, cache = attn(x) + """ + + def __init__(self, config: Llama4TextConfig, layer_idx: int = 0): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads or config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.layer_idx = layer_idx + + # Number of query heads per KV head + self.n_rep = self.num_heads // self.num_kv_heads + + # Determine if this is a NoPE layer (no RoPE, global attention) + self.is_nope_layer = False + if config.no_rope_layers is not None: + self.is_nope_layer = layer_idx in config.no_rope_layers + + # Q, K, V projections + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + # QK normalization + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + # RoPE for non-NoPE layers + self.rope = None + if not self.is_nope_layer: + rope_config = RoPEConfig( + theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + ) + self.rope = RoPE(rope_config, dims=self.head_dim) + + # Attention scaling + self.scale = self.head_dim**-0.5 + + # Temperature scaling + self.attn_temperature_tuning = config.attn_temperature_tuning + if self.attn_temperature_tuning: + # Learned temperature parameter + self.temperature = mx.ones((1,)) + + def __call__( + self, + x: mx.array, + mask: mx.array | None = None, + cache: tuple[mx.array, mx.array] | None = None, + ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]: + """ + Compute Llama 4 attention. + + Args: + x: Input, shape (batch, seq_len, hidden_size) + mask: Attention mask (additive, -inf for masked) + cache: Optional KV cache + + Returns: + output: Shape (batch, seq_len, hidden_size) + cache: Updated KV cache + """ + batch_size, seq_len, _ = x.shape + + # Project to Q, K, V + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + # Reshape for attention + q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim) + + # Apply QK normalization (before RoPE) + if self.use_qk_norm: + # Normalize per head + q = self.q_norm(q) + k = self.k_norm(k) + + # Transpose for attention: (batch, heads, seq, head_dim) + q = q.transpose(0, 2, 1, 3) + k = k.transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + + # Get cache offset for RoPE + offset = 0 + if cache is not None: + offset = cache[0].shape[2] + + # Apply RoPE (only for non-NoPE layers) + if self.rope is not None: + q = self.rope(q, offset=offset) + k = self.rope(k, offset=offset) + + # Update cache + if cache is not None: + key_cache, value_cache = cache + k = mx.concatenate([key_cache, k], axis=2) + v = mx.concatenate([value_cache, v], axis=2) + + new_cache = (k, v) + + # Repeat KV heads to match query heads + if self.n_rep > 1: + k = self._repeat_kv(k, self.n_rep) + v = self._repeat_kv(v, self.n_rep) + + # Compute attention scale + scale = self.scale + if self.attn_temperature_tuning: + scale = scale / self.temperature + + # Compute attention + output = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask) + + # Reshape back: (batch, heads, seq, head_dim) -> (batch, seq, hidden) + output = output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1) + + # Output projection + output = self.o_proj(output) + + return output, new_cache + + def _repeat_kv(self, x: mx.array, n_rep: int) -> mx.array: + """Repeat KV heads to match query heads.""" + if n_rep == 1: + return x + + batch, num_kv_heads, seq_len, head_dim = x.shape + x = mx.expand_dims(x, axis=2) + x = mx.repeat(x, n_rep, axis=2) + x = x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim) + return x + + +class Llama4FlexAttention(Llama4Attention): + """ + Flexible attention with floor-scale RoPE. + + Used in some Llama 4 variants for very long context. + Implements a floor function on the position indices. + """ + + def __init__(self, config: Llama4TextConfig, layer_idx: int = 0): + super().__init__(config, layer_idx) + + # Floor scale for position indices (for long context) + self.floor_scale = 1 # Can be adjusted for different context lengths + + +def create_llama4_attention( + config: Llama4TextConfig, + layer_idx: int = 0, + attention_type: str = "default", +) -> nn.Module: + """ + Factory function for Llama 4 attention. + + Args: + config: Llama 4 text configuration + layer_idx: Layer index + attention_type: "default" or "flex" + + Returns: + Attention module + """ + if attention_type == "flex": + return Llama4FlexAttention(config, layer_idx) + return Llama4Attention(config, layer_idx) diff --git a/src/chuk_lazarus/models_v2/families/llama4/config.py b/src/chuk_lazarus/models_v2/families/llama4/config.py new file mode 100644 index 00000000..8232beda --- /dev/null +++ b/src/chuk_lazarus/models_v2/families/llama4/config.py @@ -0,0 +1,261 @@ +""" +Llama 4 configuration. + +Extends base ModelConfig with Llama 4-specific settings including MoE and multimodal. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import Field, model_validator + +from ...core.config import ModelConfig + + +class Llama4TextConfig(ModelConfig): + """ + Configuration for Llama 4 text models. + + Llama 4 introduces: + - MoE (Mixture of Experts) with shared expert + - iRoPE (interleaved RoPE and NoPE layers) + - QK normalization + - Temperature scaling for attention + + Example: + >>> # Llama 4 Scout (17B active / 109B total) + >>> config = Llama4TextConfig( + ... vocab_size=202048, + ... hidden_size=5120, + ... num_hidden_layers=48, + ... num_attention_heads=40, + ... num_key_value_heads=8, + ... intermediate_size=8192, # Shared expert + ... intermediate_size_mlp=16384, # Routed experts + ... num_local_experts=16, + ... num_experts_per_tok=1, + ... ) + """ + + model_type: str = "llama4" + + # Llama 4 defaults + hidden_act: str = "silu" + rope_theta: float = 500000.0 + rms_norm_eps: float = 1e-5 + + # MoE parameters + num_local_experts: int = Field( + default=16, + description="Total number of routed experts", + ) + num_experts_per_tok: int = Field( + default=1, + description="Number of experts activated per token (top-k)", + ) + intermediate_size_mlp: int = Field( + default=16384, + description="Intermediate size for routed experts (per expert)", + ) + moe_router_topk: int = Field( + default=1, + description="Top-k experts for routing", + ) + + # iRoPE parameters + no_rope_layers: list[int] | None = Field( + default=None, + description="Layer indices that use NoPE (global attention without RoPE)", + ) + attention_chunk_size: int | None = Field( + default=8192, + description="Chunk size for chunked attention in RoPE layers", + ) + + # Attention features + use_qk_norm: bool = Field( + default=True, + description="Apply RMS normalization to Q and K after RoPE", + ) + attn_temperature_tuning: bool = Field( + default=False, + description="Use learned temperature scaling for attention", + ) + + # RoPE scaling + rope_scaling: dict[str, Any] | None = None + + @classmethod + def scout_17b(cls) -> Llama4TextConfig: + """ + Llama 4 Scout configuration. + + 17B active parameters, 109B total with 16 experts. + """ + return cls( + vocab_size=202048, + hidden_size=5120, + num_hidden_layers=48, + num_attention_heads=40, + num_key_value_heads=8, + intermediate_size=8192, # Shared expert intermediate + intermediate_size_mlp=16384, # Routed expert intermediate + num_local_experts=16, + num_experts_per_tok=1, + max_position_embeddings=131072, + rope_theta=500000.0, + use_qk_norm=True, + tie_word_embeddings=False, + # NoPE layers (global attention without RoPE) at 0, 4, 8, ... + no_rope_layers=[i * 4 for i in range(12)], + ) + + @classmethod + def maverick_17b(cls) -> Llama4TextConfig: + """ + Llama 4 Maverick configuration. + + 17B active parameters, 400B total with 128 experts. + """ + return cls( + vocab_size=202048, + hidden_size=5120, + num_hidden_layers=48, + num_attention_heads=40, + num_key_value_heads=8, + intermediate_size=8192, + intermediate_size_mlp=8192, # Smaller per-expert + num_local_experts=128, + num_experts_per_tok=1, + max_position_embeddings=131072, + rope_theta=500000.0, + use_qk_norm=True, + tie_word_embeddings=False, + no_rope_layers=[i * 4 for i in range(12)], + ) + + @classmethod + def tiny(cls) -> Llama4TextConfig: + """Create tiny Llama 4 for testing.""" + return cls( + vocab_size=1000, + hidden_size=64, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=128, # Shared expert + intermediate_size_mlp=256, # Routed experts + num_local_experts=4, + num_experts_per_tok=1, + max_position_embeddings=256, + use_qk_norm=True, + no_rope_layers=[0], # First layer is NoPE + ) + + +class Llama4VisionConfig(ModelConfig): + """ + Configuration for Llama 4 vision encoder. + + ViT-style vision encoder with pixel shuffle for efficiency. + """ + + model_type: str = "llama4_vision" + + # Vision transformer settings + hidden_size: int = 1280 + num_hidden_layers: int = 32 + num_attention_heads: int = 16 + intermediate_size: int = 5120 + + # Image processing + image_size: int = 560 + patch_size: int = 14 + num_channels: int = 3 + + # Projector settings + vision_output_dim: int = 5120 + pixel_shuffle_ratio: float = 0.5 + + # Norm settings + rms_norm_eps: float = 1e-5 + hidden_act: str = "gelu" + + @classmethod + def default(cls) -> Llama4VisionConfig: + """Default vision config for Llama 4.""" + return cls() + + +class Llama4Config(ModelConfig): + """ + Full Llama 4 configuration for multimodal models. + + Combines text and vision configurations. + """ + + model_type: str = "llama4" + + text_config: Llama4TextConfig | None = None + vision_config: Llama4VisionConfig | None = None + + # Image token settings + image_token_index: int = 128011 + image_token: str = "<|image|>" + + @classmethod + def scout_multimodal(cls) -> Llama4Config: + """Llama 4 Scout with vision encoder.""" + return cls( + text_config=Llama4TextConfig.scout_17b(), + vision_config=Llama4VisionConfig.default(), + ) + + @classmethod + def scout_text_only(cls) -> Llama4Config: + """Llama 4 Scout text-only.""" + return cls( + text_config=Llama4TextConfig.scout_17b(), + vision_config=None, + ) + + @model_validator(mode="after") + def set_derived_values(self) -> Llama4Config: + """Override to skip derived value computation for wrapper config.""" + # Llama4Config is a wrapper that delegates to text_config, + # so we don't need to compute head_dim etc. at this level + return self + + # Forward text config attributes for convenience + @property + def vocab_size(self) -> int: + return self.text_config.vocab_size if self.text_config else 0 + + @property + def hidden_size(self) -> int: + return self.text_config.hidden_size if self.text_config else 0 + + @property + def num_hidden_layers(self) -> int: + return self.text_config.num_hidden_layers if self.text_config else 0 + + @property + def num_attention_heads(self) -> int: + return self.text_config.num_attention_heads if self.text_config else 0 + + @property + def num_key_value_heads(self) -> int | None: + return self.text_config.num_key_value_heads if self.text_config else None + + @property + def intermediate_size(self) -> int: + return self.text_config.intermediate_size if self.text_config else 0 + + @property + def rms_norm_eps(self) -> float: + return self.text_config.rms_norm_eps if self.text_config else 1e-5 + + @property + def tie_word_embeddings(self) -> bool: + return self.text_config.tie_word_embeddings if self.text_config else True diff --git a/src/chuk_lazarus/models_v2/families/llama4/model.py b/src/chuk_lazarus/models_v2/families/llama4/model.py new file mode 100644 index 00000000..7dd21a5f --- /dev/null +++ b/src/chuk_lazarus/models_v2/families/llama4/model.py @@ -0,0 +1,335 @@ +""" +Llama 4 model implementation. + +Llama 4 with MoE architecture and multimodal support. +""" + +from __future__ import annotations + +from typing import Any + +import mlx.core as mx +import mlx.nn as nn + +from ...backbones.base import Backbone, BackboneOutput +from ...blocks.base import Block, BlockOutput +from ...components.embeddings import create_token_embedding +from ...components.normalization import RMSNorm +from ...core.registry import register_model +from ...heads import LMHead +from ...models.base import Model, ModelOutput +from .attention import Llama4Attention +from .config import Llama4TextConfig +from .moe import Llama4MoE + + +class Llama4Block(Block): + """ + Llama 4 transformer block. + + Pre-norm transformer with: + - RMSNorm + - Llama 4 attention (GQA with QK norm, iRoPE) + - MoE FFN (shared + routed experts) + """ + + def __init__(self, config: Llama4TextConfig, layer_idx: int = 0): + super().__init__() + + self._hidden_size = config.hidden_size + self.layer_idx = layer_idx + + # Pre-attention norm + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Attention with QK norm and iRoPE support + self.self_attn = Llama4Attention(config, layer_idx=layer_idx) + + # Post-attention norm + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # MoE FFN + self.mlp = Llama4MoE(config) + + @property + def block_type(self): + from ...core.enums import BlockType + + return BlockType.TRANSFORMER + + @property + def hidden_size(self) -> int: + return self._hidden_size + + def __call__( + self, + x: mx.array, + mask: mx.array | None = None, + cache: tuple[mx.array, mx.array] | None = None, + ) -> BlockOutput: + """Forward pass.""" + # Self-attention with residual + residual = x + x = self.input_layernorm(x) + x, new_cache = self.self_attn(x, mask=mask, cache=cache) + x = residual + x + + # FFN (MoE) with residual + residual = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + x = residual + x + + return BlockOutput(hidden_states=x, cache=new_cache) + + +class Llama4Model(Backbone): + """ + Llama 4 backbone (without LM head). + + Token embeddings + transformer blocks with MoE + final norm. + """ + + def __init__(self, config: Llama4TextConfig): + super().__init__() + + self.config = config + self._vocab_size = config.vocab_size + self._hidden_size = config.hidden_size + self._num_layers = config.num_hidden_layers + + # Token embeddings + self.embed_tokens = create_token_embedding( + vocab_size=config.vocab_size, + hidden_size=config.hidden_size, + ) + + # Transformer blocks with MoE + self.layers = [Llama4Block(config, layer_idx=i) for i in range(config.num_hidden_layers)] + + # Final norm + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @property + def hidden_size(self) -> int: + return self._hidden_size + + @property + def num_layers(self) -> int: + return self._num_layers + + @property + def vocab_size(self) -> int: + return self._vocab_size + + def __call__( + self, + input_ids: mx.array, + attention_mask: mx.array | None = None, + cache: list[Any] | None = None, + output_hidden_states: bool = False, + ) -> BackboneOutput: + """Forward pass.""" + batch_size, seq_len = input_ids.shape + + # Embeddings + hidden_states = self.embed_tokens(input_ids) + + # Create causal mask + if attention_mask is None: + mask = nn.MultiHeadAttention.create_additive_causal_mask(seq_len) + mask = mask.astype(hidden_states.dtype) + else: + mask = attention_mask + + # Track hidden states + all_hidden_states = (hidden_states,) if output_hidden_states else None + new_cache = [] + + # Process layers + for i, layer in enumerate(self.layers): + layer_cache = cache[i] if cache else None + output = layer(hidden_states, mask=mask, cache=layer_cache) + hidden_states = output.hidden_states + new_cache.append(output.cache) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final norm + hidden_states = self.norm(hidden_states) + + return BackboneOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + cache=new_cache, + ) + + def get_input_embeddings(self) -> nn.Module: + return self.embed_tokens + + def set_input_embeddings(self, embeddings: nn.Module) -> None: + self.embed_tokens = embeddings + + +@register_model( + model_type="llama4", + architectures=["Llama4ForCausalLM"], +) +class Llama4ForCausalLM(Model): + """ + Llama 4 for causal language modeling. + + Complete model with MoE backbone + LM head. + Supports text-only mode and multimodal (when vision encoder is provided). + """ + + def __init__(self, config: Llama4TextConfig): + super().__init__() + + self._config = config + + # Backbone + self.model = Llama4Model(config) + + # LM head (optionally tied) + if config.tie_word_embeddings: + self.lm_head = LMHead( + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + tied_embeddings=self.model.embed_tokens, + ) + else: + self.lm_head = LMHead( + hidden_size=config.hidden_size, + vocab_size=config.vocab_size, + ) + + @property + def config(self) -> Llama4TextConfig: + return self._config + + @property + def backbone(self) -> nn.Module: + return self.model + + def __call__( + self, + input_ids: mx.array, + attention_mask: mx.array | None = None, + labels: mx.array | None = None, + cache: list[Any] | None = None, + output_hidden_states: bool = False, + ) -> ModelOutput: + """Forward pass.""" + # Backbone + backbone_output = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + cache=cache, + output_hidden_states=output_hidden_states, + ) + + # LM head + head_output = self.lm_head( + hidden_states=backbone_output.last_hidden_state, + labels=labels, + ) + + return ModelOutput( + loss=head_output.loss, + logits=head_output.logits, + hidden_states=backbone_output.hidden_states, + cache=backbone_output.cache, + ) + + def generate( + self, + input_ids: mx.array, + max_new_tokens: int = 100, + temperature: float = 1.0, + top_k: int | None = None, + top_p: float | None = None, + repetition_penalty: float = 1.0, + stop_tokens: list[int] | None = None, + ) -> mx.array: + """ + Generate text autoregressively. + + Args: + input_ids: Prompt, shape (batch, prompt_len) + max_new_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_k: Top-k sampling + top_p: Nucleus sampling + repetition_penalty: Penalty for repeating tokens + stop_tokens: Tokens that stop generation + + Returns: + Generated sequence, shape (batch, total_len) + """ + stop_tokens_set = set(stop_tokens or []) + + # Process prompt + output = self(input_ids) + mx.eval(output.logits) + cache = output.cache + + # Track generated tokens + generated_tokens = [input_ids] + + for _ in range(max_new_tokens): + # Get logits for last position + logits = output.logits[:, -1, :] + + # Apply repetition penalty + if repetition_penalty != 1.0: + all_tokens = mx.concatenate(generated_tokens, axis=1) + unique_tokens = set(all_tokens.flatten().tolist()) + vocab_size = logits.shape[-1] + token_indices = mx.array([t for t in unique_tokens if t < vocab_size]) + if token_indices.size > 0: + mask = mx.zeros((vocab_size,)) + for tok in token_indices.tolist(): + mask = mask.at[tok].add(1.0) + penalty_mask = mx.where(mask > 0, repetition_penalty, 1.0) + logits = logits / penalty_mask + + # Apply temperature + if temperature != 1.0: + logits = logits / temperature + + # Apply top-k + if top_k is not None and top_k > 0: + top_k_values = mx.topk(logits, k=min(top_k, logits.shape[-1])) + min_val = top_k_values[:, -1:] + logits = mx.where(logits < min_val, float("-inf"), logits) + + # Sample + probs = mx.softmax(logits, axis=-1) + next_token = mx.random.categorical(mx.log(probs + 1e-10)) + next_token = mx.expand_dims(next_token, axis=-1) + + mx.eval(next_token) + generated_tokens.append(next_token) + + # Check stop + next_token_val = int(next_token[0, 0]) + if next_token_val in stop_tokens_set: + break + + # Forward with cache + output = self(next_token, cache=cache) + mx.eval(output.logits) + cache = output.cache + + return mx.concatenate(generated_tokens, axis=1) + + @classmethod + def from_config(cls, config: Llama4TextConfig) -> Llama4ForCausalLM: + """Create from config.""" + return cls(config) + + +# Convenience aliases +Llama4 = Llama4ForCausalLM diff --git a/src/chuk_lazarus/models_v2/families/llama4/moe.py b/src/chuk_lazarus/models_v2/families/llama4/moe.py new file mode 100644 index 00000000..94403b6a --- /dev/null +++ b/src/chuk_lazarus/models_v2/families/llama4/moe.py @@ -0,0 +1,280 @@ +""" +Llama 4 Mixture of Experts (MoE). + +Llama 4 uses a sparse MoE architecture with: +- A shared expert that is always active +- Routed experts selected by top-k routing +- Sigmoid-based router scores + +This implementation uses MLX's mx.gather_mm for efficient sparse computation, +which is critical for memory efficiency with large numbers of experts. + +Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4 +""" + +from __future__ import annotations + +import math +from functools import partial + +import mlx.core as mx +import mlx.nn as nn + +from .config import Llama4TextConfig + + +class Llama4MLP(nn.Module): + """ + SwiGLU MLP for Llama 4 shared expert. + + Same as standard Llama MLP but parameterized for MoE use. + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + bias: bool = False, + ): + super().__init__() + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias) + + def __call__(self, x: mx.array) -> mx.array: + """Forward pass through SwiGLU.""" + gate = nn.silu(self.gate_proj(x)) + up = self.up_proj(x) + return self.down_proj(gate * up) + + +class SwitchLinear(nn.Module): + """ + Linear layer with expert selection using mx.gather_mm. + + Stores weights for all experts in a single tensor and uses + MLX's gather_mm for efficient sparse computation. + + Weight shape: (num_experts, output_dims, input_dims) + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + num_experts: int, + bias: bool = False, + ): + super().__init__() + + scale = math.sqrt(1.0 / input_dims) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(num_experts, output_dims, input_dims), + ) + + if bias: + self.bias = mx.zeros((num_experts, output_dims)) + + @property + def input_dims(self) -> int: + return self.weight.shape[2] + + @property + def output_dims(self) -> int: + return self.weight.shape[1] + + @property + def num_experts(self) -> int: + return self.weight.shape[0] + + def __call__(self, x: mx.array, indices: mx.array) -> mx.array: + """ + Forward pass with expert selection. + + Args: + x: Input tensor, shape (..., 1, 1, input_dims) + indices: Expert indices, shape (..., k) where k is num_experts_per_tok + + Returns: + Output tensor, shape (..., k, 1, output_dims) + """ + # mx.gather_mm: efficient batched matmul with index selection + # x @ weight[indices].T + out = mx.gather_mm( + x, + self.weight.swapaxes(-1, -2), # (experts, input, output) + rhs_indices=indices, + ) + + if "bias" in self: + # Add bias for selected experts + out = out + mx.expand_dims(self.bias[indices], -2) + + return out + + +@partial(mx.compile, shapeless=True) +def swiglu(x: mx.array, gate: mx.array) -> mx.array: + """SwiGLU activation: silu(gate) * x""" + return nn.silu(gate) * x + + +class SwitchGLU(nn.Module): + """ + SwiGLU MLP with expert selection using mx.gather_mm. + + This is the efficient implementation for routed experts that uses + fused weight tensors and MLX's native sparse matmul. + """ + + def __init__( + self, + input_dims: int, + hidden_dims: int, + num_experts: int, + bias: bool = False, + ): + super().__init__() + + self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) + + def __call__(self, x: mx.array, indices: mx.array) -> mx.array: + """ + Forward pass with expert selection. + + Args: + x: Input tensor, shape (batch * seq, hidden_size) after pre-weighting + indices: Expert indices, shape (batch * seq, k) + + Returns: + Output tensor, shape (batch * seq, k, hidden_size) + """ + # Expand dims for gather_mm: (..., 1, 1, hidden) + x = mx.expand_dims(x, (-2, -3)) + + # Compute gate and up projections for selected experts + x_gate = self.gate_proj(x, indices) # (..., k, 1, intermediate) + x_up = self.up_proj(x, indices) # (..., k, 1, intermediate) + + # Apply SwiGLU activation + x = swiglu(x_up, x_gate) + + # Down projection back to hidden size + x = self.down_proj(x, indices) # (..., k, 1, hidden) + + return x.squeeze(-2) # (..., k, hidden) + + +class Llama4MoE(nn.Module): + """ + Llama 4 Mixture of Experts layer using efficient gather_mm. + + Key features: + - Shared expert: Always active for all tokens (standard MLP) + - Routed experts: Sparsely activated via top-k routing using SwitchGLU + - Sigmoid router with top-k selection + - Uses mx.gather_mm for efficient sparse computation + + Args: + config: Llama 4 text configuration + + Example: + >>> config = Llama4TextConfig.tiny() + >>> moe = Llama4MoE(config) + >>> x = mx.random.normal((2, 10, 64)) + >>> output = moe(x) # Shape: (2, 10, 64) + """ + + def __init__(self, config: Llama4TextConfig): + super().__init__() + + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size # For shared expert + self.intermediate_size_mlp = config.intermediate_size_mlp # For routed experts + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + assert self.num_experts_per_tok == 1, "Only 1 expert per token currently supported" + + # Router: projects to num_experts scores + self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) + + # Shared expert (always active) - standard MLP + self.shared_expert = Llama4MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + ) + + # Routed experts using efficient SwitchGLU with gather_mm + self.experts = SwitchGLU( + input_dims=config.hidden_size, + hidden_dims=config.intermediate_size_mlp, + num_experts=config.num_local_experts, + bias=False, + ) + + def __call__(self, x: mx.array) -> mx.array: + """ + Forward pass through MoE. + + Args: + x: Input tensor, shape (batch, seq_len, hidden_size) + + Returns: + Output tensor, shape (batch, seq_len, hidden_size) + """ + batch_size, seq_len, hidden_size = x.shape + + # 1. Compute shared expert output (always active) + shared_output = self.shared_expert(x) + + # 2. Compute router scores and select top-k experts + router_logits = self.router(x) # (batch, seq_len, num_experts) + + # Use sigmoid for scores (Llama 4 style) + # Get top-k expert indices + k = self.num_experts_per_tok + indices = mx.argpartition(-router_logits, kth=k - 1, axis=-1)[..., :k] + + # Gather scores for selected experts and apply sigmoid + scores = mx.take_along_axis(router_logits, indices, axis=-1) + scores = mx.sigmoid(scores.astype(mx.float32)).astype(x.dtype) + + # 3. Compute routed expert outputs using SwitchGLU + # Pre-weight input by routing scores + x_weighted = x * scores # (batch, seq, hidden) broadcasted with (batch, seq, k) + + # Reshape for expert computation + x_flat = x_weighted.reshape(-1, hidden_size) # (batch * seq, hidden) + indices_flat = indices.reshape(-1, k) # (batch * seq, k) + + # Compute routed expert output + routed_output = self.experts(x_flat, indices_flat) # (batch * seq, k, hidden) + + # Sum over k dimension and reshape back + routed_output = routed_output.squeeze(1) # (batch * seq, hidden) for k=1 + routed_output = routed_output.reshape(batch_size, seq_len, hidden_size) + + # 4. Combine shared and routed outputs + return shared_output + routed_output + + +def create_llama4_moe(config: Llama4TextConfig) -> nn.Module: + """ + Factory function for Llama 4 MoE. + + Args: + config: Llama 4 text configuration + + Returns: + MoE module using efficient gather_mm implementation + """ + return Llama4MoE(config) diff --git a/tests/inference/__init__.py b/tests/inference/__init__.py new file mode 100644 index 00000000..851690f7 --- /dev/null +++ b/tests/inference/__init__.py @@ -0,0 +1 @@ +"""Tests for the inference module.""" diff --git a/tests/inference/test_chat.py b/tests/inference/test_chat.py new file mode 100644 index 00000000..e98f150b --- /dev/null +++ b/tests/inference/test_chat.py @@ -0,0 +1,295 @@ +"""Tests for inference/chat.py module.""" + +from chuk_lazarus.inference.chat import ( + ASSISTANT_SUFFIX, + NEWLINE_DOUBLE, + ChatHistory, + ChatMessage, + FallbackTemplate, + Role, + _format_simple, + format_chat_prompt, + format_history, +) + + +class TestRole: + """Tests for Role enum.""" + + def test_role_values(self): + """Test role enum values.""" + assert Role.SYSTEM.value == "system" + assert Role.USER.value == "user" + assert Role.ASSISTANT.value == "assistant" + assert Role.MODEL.value == "model" + + def test_role_display_name(self): + """Test display_name method.""" + assert Role.SYSTEM.display_name() == "System" + assert Role.USER.display_name() == "User" + assert Role.ASSISTANT.display_name() == "Assistant" + assert Role.MODEL.display_name() == "Model" + + def test_role_is_str_enum(self): + """Test that Role is a string enum.""" + assert isinstance(Role.USER, str) + assert Role.USER == "user" + + +class TestChatMessage: + """Tests for ChatMessage model.""" + + def test_create_message(self): + """Test creating a chat message.""" + msg = ChatMessage(role=Role.USER, content="Hello!") + assert msg.role == Role.USER + assert msg.content == "Hello!" + + def test_to_tokenizer_format(self): + """Test converting to tokenizer format.""" + msg = ChatMessage(role=Role.USER, content="Hello!") + fmt = msg.to_tokenizer_format() + assert fmt == {"role": "user", "content": "Hello!"} + + def test_to_tokenizer_format_all_roles(self): + """Test tokenizer format for all roles.""" + for role in Role: + msg = ChatMessage(role=role, content="test") + fmt = msg.to_tokenizer_format() + assert fmt["role"] == role.value + assert fmt["content"] == "test" + + +class TestChatHistory: + """Tests for ChatHistory model.""" + + def test_empty_history(self): + """Test empty chat history.""" + history = ChatHistory() + assert len(history.messages) == 0 + assert history.system_message is None + + def test_add_user(self): + """Test adding user message.""" + history = ChatHistory() + result = history.add_user("Hello!") + assert result is history # Returns self for chaining + assert len(history.messages) == 1 + assert history.messages[0].role == Role.USER + assert history.messages[0].content == "Hello!" + + def test_add_assistant(self): + """Test adding assistant message.""" + history = ChatHistory() + result = history.add_assistant("Hi there!") + assert result is history + assert len(history.messages) == 1 + assert history.messages[0].role == Role.ASSISTANT + assert history.messages[0].content == "Hi there!" + + def test_add_system(self): + """Test setting system message.""" + history = ChatHistory() + result = history.add_system("You are helpful.") + assert result is history + assert history.system_message == "You are helpful." + + def test_clear(self): + """Test clearing history.""" + history = ChatHistory() + history.add_system("System prompt") + history.add_user("Message 1") + history.add_assistant("Response 1") + + result = history.clear() + assert result is history + assert len(history.messages) == 0 + assert history.system_message == "System prompt" # Preserved + + def test_chaining(self): + """Test method chaining.""" + history = ( + ChatHistory() + .add_system("Be helpful") + .add_user("Hello") + .add_assistant("Hi!") + .add_user("How are you?") + ) + assert history.system_message == "Be helpful" + assert len(history.messages) == 3 + + def test_to_tokenizer_format_empty(self): + """Test tokenizer format with empty history.""" + history = ChatHistory() + fmt = history.to_tokenizer_format() + assert fmt == [] + + def test_to_tokenizer_format_with_system(self): + """Test tokenizer format with system message.""" + history = ChatHistory() + history.add_system("Be helpful") + history.add_user("Hello") + + fmt = history.to_tokenizer_format() + assert len(fmt) == 2 + assert fmt[0] == {"role": "system", "content": "Be helpful"} + assert fmt[1] == {"role": "user", "content": "Hello"} + + def test_to_tokenizer_format_without_system(self): + """Test tokenizer format without system message.""" + history = ChatHistory() + history.add_user("Hello") + history.add_assistant("Hi!") + + fmt = history.to_tokenizer_format() + assert len(fmt) == 2 + assert fmt[0] == {"role": "user", "content": "Hello"} + assert fmt[1] == {"role": "assistant", "content": "Hi!"} + + +class TestFallbackTemplate: + """Tests for FallbackTemplate enum.""" + + def test_template_values(self): + """Test template enum values.""" + assert FallbackTemplate.SIMPLE.value == "simple" + assert FallbackTemplate.CHATML.value == "chatml" + + +class TestConstants: + """Tests for module constants.""" + + def test_assistant_suffix(self): + """Test assistant suffix constant.""" + assert ASSISTANT_SUFFIX == "Assistant:" + + def test_newline_double(self): + """Test newline constant.""" + assert NEWLINE_DOUBLE == "\n\n" + + +class TestFormatSimple: + """Tests for _format_simple function.""" + + def test_format_empty(self): + """Test formatting empty history.""" + history = ChatHistory() + result = _format_simple(history) + assert result == "\n\nAssistant:" + + def test_format_with_system(self): + """Test formatting with system message.""" + history = ChatHistory() + history.add_system("Be helpful") + history.add_user("Hello") + + result = _format_simple(history) + assert "System: Be helpful" in result + assert "User: Hello" in result + assert result.endswith("Assistant:") + + def test_format_conversation(self): + """Test formatting multi-turn conversation.""" + history = ChatHistory() + history.add_user("Hello") + history.add_assistant("Hi!") + history.add_user("How are you?") + + result = _format_simple(history) + assert "User: Hello" in result + assert "Assistant: Hi!" in result + assert "User: How are you?" in result + assert result.endswith("Assistant:") + + +class MockTokenizer: + """Mock tokenizer for testing.""" + + def __init__(self, has_template: bool = True, template_raises: bool = False): + self.has_template = has_template + self.template_raises = template_raises + self.chat_template = "template" if has_template else None + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): + if self.template_raises: + raise ValueError("Template error") + return f"{len(messages)} messages" + + +class TestFormatChatPrompt: + """Tests for format_chat_prompt function.""" + + def test_format_with_tokenizer_template(self): + """Test formatting with tokenizer template.""" + tokenizer = MockTokenizer(has_template=True) + result = format_chat_prompt(tokenizer, "Hello") + assert "" in result + assert "1 messages" in result + + def test_format_with_system_message(self): + """Test formatting with system message.""" + tokenizer = MockTokenizer(has_template=True) + result = format_chat_prompt(tokenizer, "Hello", system_message="Be helpful") + assert "2 messages" in result + + def test_format_fallback_no_template(self): + """Test fallback when no template available.""" + tokenizer = MockTokenizer(has_template=False) + result = format_chat_prompt(tokenizer, "Hello") + assert "User: Hello" in result + assert "Assistant:" in result + + def test_format_fallback_template_error(self): + """Test fallback when template raises error.""" + tokenizer = MockTokenizer(has_template=True, template_raises=True) + result = format_chat_prompt(tokenizer, "Hello") + assert "User: Hello" in result + assert "Assistant:" in result + + def test_format_with_generation_prompt_false(self): + """Test with add_generation_prompt=False.""" + tokenizer = MockTokenizer(has_template=True) + # The mock doesn't actually use this parameter, just verify it's passed + result = format_chat_prompt(tokenizer, "Hello", add_generation_prompt=False) + assert "" in result + + +class TestFormatHistory: + """Tests for format_history function.""" + + def test_format_history_with_template(self): + """Test formatting history with tokenizer template.""" + tokenizer = MockTokenizer(has_template=True) + history = ChatHistory().add_user("Hello").add_assistant("Hi!") + + result = format_history(tokenizer, history) + assert "" in result + assert "2 messages" in result + + def test_format_history_with_system(self): + """Test formatting history with system message.""" + tokenizer = MockTokenizer(has_template=True) + history = ChatHistory() + history.add_system("Be helpful") + history.add_user("Hello") + + result = format_history(tokenizer, history) + assert "2 messages" in result # system + user + + def test_format_history_fallback(self): + """Test history formatting fallback.""" + tokenizer = MockTokenizer(has_template=False) + history = ChatHistory().add_user("Hello") + + result = format_history(tokenizer, history) + assert "User: Hello" in result + assert "Assistant:" in result + + def test_format_history_empty(self): + """Test formatting empty history.""" + tokenizer = MockTokenizer(has_template=True) + history = ChatHistory() + + result = format_history(tokenizer, history) + assert "" in result + assert "0 messages" in result diff --git a/tests/inference/test_generation.py b/tests/inference/test_generation.py new file mode 100644 index 00000000..b19418ad --- /dev/null +++ b/tests/inference/test_generation.py @@ -0,0 +1,332 @@ +"""Tests for inference/generation.py module.""" + +import mlx.core as mx +import pytest + +from chuk_lazarus.inference.generation import ( + GenerationConfig, + GenerationResult, + GenerationStats, + generate, + generate_stream, + get_stop_tokens, +) + + +class TestGenerationConfig: + """Tests for GenerationConfig model.""" + + def test_default_values(self): + """Test default configuration values.""" + config = GenerationConfig() + assert config.max_new_tokens == 100 + assert config.temperature == 0.7 + assert config.top_p == 0.9 + assert config.top_k is None + assert config.stop_tokens == [] + + def test_custom_values(self): + """Test custom configuration values.""" + config = GenerationConfig( + max_new_tokens=50, + temperature=0.5, + top_p=0.95, + top_k=40, + stop_tokens=[1, 2, 3], + ) + assert config.max_new_tokens == 50 + assert config.temperature == 0.5 + assert config.top_p == 0.95 + assert config.top_k == 40 + assert config.stop_tokens == [1, 2, 3] + + def test_validation_max_new_tokens(self): + """Test max_new_tokens validation.""" + with pytest.raises(ValueError): + GenerationConfig(max_new_tokens=0) + with pytest.raises(ValueError): + GenerationConfig(max_new_tokens=-1) + + def test_validation_temperature(self): + """Test temperature validation.""" + # 0 is valid (greedy) + config = GenerationConfig(temperature=0) + assert config.temperature == 0 + + with pytest.raises(ValueError): + GenerationConfig(temperature=-0.1) + + def test_validation_top_p(self): + """Test top_p validation.""" + # 0 and 1 are valid + config = GenerationConfig(top_p=0) + assert config.top_p == 0 + config = GenerationConfig(top_p=1.0) + assert config.top_p == 1.0 + + with pytest.raises(ValueError): + GenerationConfig(top_p=-0.1) + with pytest.raises(ValueError): + GenerationConfig(top_p=1.1) + + def test_validation_top_k(self): + """Test top_k validation.""" + config = GenerationConfig(top_k=1) + assert config.top_k == 1 + + with pytest.raises(ValueError): + GenerationConfig(top_k=0) + with pytest.raises(ValueError): + GenerationConfig(top_k=-1) + + +class TestGenerationStats: + """Tests for GenerationStats model.""" + + def test_create_stats(self): + """Test creating generation stats.""" + stats = GenerationStats( + input_tokens=10, + output_tokens=20, + total_time_seconds=2.0, + tokens_per_second=10.0, + ) + assert stats.input_tokens == 10 + assert stats.output_tokens == 20 + assert stats.total_time_seconds == 2.0 + assert stats.tokens_per_second == 10.0 + + def test_summary_property(self): + """Test summary property.""" + stats = GenerationStats( + input_tokens=10, + output_tokens=25, + total_time_seconds=2.50, + tokens_per_second=10.0, + ) + summary = stats.summary + assert "25 tokens" in summary + assert "2.50s" in summary + assert "10.0 tok/s" in summary + + +class TestGenerationResult: + """Tests for GenerationResult model.""" + + def test_create_result(self): + """Test creating generation result.""" + stats = GenerationStats( + input_tokens=10, + output_tokens=20, + total_time_seconds=2.0, + tokens_per_second=10.0, + ) + result = GenerationResult( + text="Hello world", + stats=stats, + stop_reason="eos", + ) + assert result.text == "Hello world" + assert result.stats.output_tokens == 20 + assert result.stop_reason == "eos" + + def test_default_stop_reason(self): + """Test default stop reason.""" + stats = GenerationStats( + input_tokens=10, + output_tokens=20, + total_time_seconds=2.0, + tokens_per_second=10.0, + ) + result = GenerationResult(text="test", stats=stats) + assert result.stop_reason == "max_tokens" + + +class MockTokenizer: + """Mock tokenizer for testing.""" + + def __init__(self, eos_token_id=None): + self.eos_token_id = eos_token_id + + def encode(self, text, return_tensors=None): + # Return numpy-like array + import numpy as np + + return np.array([[1, 2, 3, 4, 5]]) + + def decode(self, tokens, skip_special_tokens=False): + return f"decoded_{len(tokens)}_tokens" + + +class TestGetStopTokens: + """Tests for get_stop_tokens function.""" + + def test_no_eos_token(self): + """Test with no EOS token.""" + tokenizer = MockTokenizer(eos_token_id=None) + tokens = get_stop_tokens(tokenizer) + assert tokens == [] + + def test_single_eos_token(self): + """Test with single EOS token.""" + tokenizer = MockTokenizer(eos_token_id=50256) + tokens = get_stop_tokens(tokenizer) + assert tokens == [50256] + + def test_list_eos_tokens(self): + """Test with list of EOS tokens.""" + tokenizer = MockTokenizer(eos_token_id=[50256, 50257]) + tokens = get_stop_tokens(tokenizer) + assert tokens == [50256, 50257] + + +class MockModel: + """Mock model for testing generation.""" + + def __init__(self, output_length=10, stop_at=None): + self.output_length = output_length + self.stop_at = stop_at + self.call_count = 0 + + def generate( + self, + input_ids, + max_new_tokens=100, + temperature=0.7, + top_p=0.9, + top_k=None, + stop_tokens=None, + ): + _ = input_ids.shape[0] # batch_size unused but validates shape + _ = input_ids.shape[1] # input_length unused but validates shape + + # Simulate generation - return input + new tokens + new_tokens = mx.array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]]) + actual_new = min(self.output_length, max_new_tokens) + output = mx.concatenate([input_ids, new_tokens[:, :actual_new]], axis=1) + return output + + def __call__(self, y, cache=None): + self.call_count += 1 + batch_size = y.shape[0] + + # Return logits for vocab size 100 + logits = mx.zeros((batch_size, 1, 100)) + # Make token 10+call_count the highest probability + next_token = min(10 + self.call_count, 99) + + if self.stop_at and self.call_count >= self.stop_at: + # Return stop token (assuming 50256) + logits = logits.at[:, 0, 50256].add(100.0) + else: + logits = logits.at[:, 0, next_token].add(100.0) + + return logits, cache + + +class TestGenerate: + """Tests for generate function.""" + + def test_generate_basic(self): + """Test basic generation.""" + model = MockModel(output_length=5) + tokenizer = MockTokenizer(eos_token_id=50256) + + result = generate(model, tokenizer, "test prompt") + + assert isinstance(result, GenerationResult) + assert result.text.startswith("decoded_") + assert result.stats.input_tokens == 5 + assert result.stats.output_tokens == 5 + + def test_generate_with_config(self): + """Test generation with custom config.""" + model = MockModel(output_length=3) + tokenizer = MockTokenizer() + + config = GenerationConfig( + max_new_tokens=3, + temperature=0.5, + ) + result = generate(model, tokenizer, "test", config=config) + + assert result.stats.output_tokens == 3 + + def test_generate_stats(self): + """Test generation statistics.""" + model = MockModel(output_length=10) + tokenizer = MockTokenizer() + + result = generate(model, tokenizer, "test") + + assert result.stats.total_time_seconds > 0 + assert result.stats.tokens_per_second >= 0 + + def test_generate_stop_reason_max_tokens(self): + """Test stop reason when max tokens reached.""" + model = MockModel(output_length=100) + tokenizer = MockTokenizer(eos_token_id=50256) + + config = GenerationConfig(max_new_tokens=10) + result = generate(model, tokenizer, "test", config=config) + + assert result.stop_reason == "max_tokens" + + def test_generate_stop_reason_eos(self): + """Test stop reason when EOS token reached.""" + model = MockModel(output_length=5) + tokenizer = MockTokenizer(eos_token_id=50256) + + # The mock model returns tokens 10-14, if the last one is 50256 it's EOS + # For this test, we need to manipulate the scenario + + result = generate(model, tokenizer, "test") + # In our mock, it doesn't actually produce EOS, so it's max_tokens + assert result.stop_reason in ["max_tokens", "eos", "stop_token"] + + +class TestGenerateStream: + """Tests for generate_stream function.""" + + def test_generate_stream_basic(self): + """Test basic streaming generation.""" + model = MockModel(stop_at=5) + tokenizer = MockTokenizer(eos_token_id=50256) + + chunks = list(generate_stream(model, tokenizer, "test")) + + # Should generate some chunks + assert len(chunks) >= 0 # May be empty if tokens decode to empty + + def test_generate_stream_with_config(self): + """Test streaming with config.""" + model = MockModel(stop_at=3) + tokenizer = MockTokenizer() + + config = GenerationConfig(max_new_tokens=3, temperature=0.5) + chunks = list(generate_stream(model, tokenizer, "test", config=config)) + + # Verify we get some output + assert isinstance(chunks, list) + + def test_generate_stream_stops_on_max_tokens(self): + """Test streaming stops at max tokens.""" + model = MockModel() + tokenizer = MockTokenizer() + + config = GenerationConfig(max_new_tokens=5) + _ = list(generate_stream(model, tokenizer, "test", config=config)) + + # Should have stopped after 5 iterations max + assert model.call_count <= 5 + + def test_generate_stream_stops_on_eos(self): + """Test streaming stops on EOS token.""" + model = MockModel(stop_at=3) + tokenizer = MockTokenizer(eos_token_id=50256) + + config = GenerationConfig(max_new_tokens=10) + _ = list(generate_stream(model, tokenizer, "test", config=config)) + + # Should have stopped early due to EOS + assert model.call_count <= 10 diff --git a/tests/inference/test_generator.py b/tests/inference/test_generator.py new file mode 100644 index 00000000..b656f45d --- /dev/null +++ b/tests/inference/test_generator.py @@ -0,0 +1,257 @@ +"""Tests for inference/generator.py module (legacy generator).""" + +import io +from unittest.mock import patch + +import mlx.core as mx +import numpy as np +import pytest + +from chuk_lazarus.inference.generator import ( + generate_response, + generate_sequence, +) + + +@pytest.fixture(autouse=True) +def cleanup_mlx(): + """Clear MLX memory before and after each test.""" + mx.metal.clear_cache() + yield + mx.metal.clear_cache() + + +class MockModel: + """Mock model for testing generation. + + The generate_sequence function calls model(y[None], cache=cache) + where y is a 1D tensor. So y[None] creates shape (1, seq_len). + """ + + def __init__(self, vocab_size=20, stop_after=3, return_eos_at=None): + self.vocab_size = vocab_size + self.stop_after = stop_after + self.return_eos_at = return_eos_at + self.call_count = 0 + + def __call__(self, y, cache=None): + self.call_count += 1 + mx.eval(y) # Force evaluation of input + + # y shape is (1, seq_len) after y[None] in generate_sequence + batch_size = int(y.shape[0]) + seq_len = int(y.shape[1]) if y.ndim > 1 else 1 + + # Determine which token to make highest probability + if self.return_eos_at and self.call_count >= self.return_eos_at: + # Return EOS (token 1) + next_token = 1 + elif self.call_count <= self.stop_after: + # Return a regular token + next_token = min(5 + self.call_count, self.vocab_size - 1) + else: + # Past stop, return EOS + next_token = 1 + + # Create logits directly using numpy - small vocab to minimize memory + logits_np = np.zeros((batch_size, seq_len, self.vocab_size), dtype=np.float32) + logits_np[:, :, next_token] = 100.0 + logits = mx.array(logits_np) + mx.eval(logits) + + return logits, cache + + +class MockTokenizer: + """Mock tokenizer for testing.""" + + def __init__(self, eos_token_id=1): + self.eos_token_id = eos_token_id + self._decode_count = 0 + + def encode(self, text): + return [2, 3, 4, 5, 6] + + def decode(self, tokens): + self._decode_count += 1 + if len(tokens) == 0: + return "" + # Return progressively longer strings to trigger printing + return "a" * (len(tokens) * 2) + + +class TestGenerateSequence: + """Tests for generate_sequence function.""" + + def test_generate_sequence_basic(self): + """Test basic sequence generation.""" + model = MockModel(stop_after=3, return_eos_at=4) + prompt = mx.array([1, 2, 3]) + + # generate_sequence is infinite - limit iterations + tokens = [] + for i, token in enumerate(generate_sequence(prompt, model, temperature=0)): + tokens.append(token) + mx.eval(token) + if i >= 3: # Limit to 4 iterations + break + + # Should generate tokens + assert len(tokens) >= 1 + for token in tokens: + assert isinstance(token, mx.array) + + def test_generate_sequence_with_temperature(self): + """Test sequence generation with temperature (sampling).""" + model = MockModel(stop_after=3, return_eos_at=4) + prompt = mx.array([1, 2, 3]) + + # With temperature > 0, uses categorical sampling + # generate_sequence is infinite - limit iterations + tokens = [] + for i, token in enumerate(generate_sequence(prompt, model, temperature=1.0)): + tokens.append(token) + mx.eval(token) + if i >= 3: # Limit to 4 iterations + break + + assert len(tokens) >= 0 + + def test_generate_sequence_greedy(self): + """Test greedy decoding (temperature=0).""" + model = MockModel(stop_after=2, return_eos_at=3) + prompt = mx.array([1, 2, 3]) + + # generate_sequence is infinite - limit iterations + tokens = [] + for i, token in enumerate(generate_sequence(prompt, model, temperature=0)): + tokens.append(token) + mx.eval(token) + if i >= 2: # Limit to 3 iterations + break + + # Greedy should be deterministic + assert len(tokens) >= 1 + + def test_generate_sequence_none_logits(self): + """Test handling when model returns None logits.""" + + class NoneLogitsModel: + def __call__(self, y, cache=None): + return None, cache + + model = NoneLogitsModel() + prompt = mx.array([1, 2, 3]) + + tokens = list(generate_sequence(prompt, model)) + + assert tokens == [] + + def test_generate_sequence_zero_seq_len(self): + """Test handling when logits has zero sequence length.""" + + class ZeroSeqModel: + def __call__(self, y, cache=None): + logits = mx.zeros((1, 0, 100)) + return logits, cache + + model = ZeroSeqModel() + prompt = mx.array([1, 2, 3]) + + tokens = list(generate_sequence(prompt, model)) + + assert tokens == [] + + +class TestGenerateResponse: + """Tests for generate_response function.""" + + def test_generate_response_basic(self): + """Test basic response generation.""" + model = MockModel(stop_after=3, return_eos_at=4) + tokenizer = MockTokenizer(eos_token_id=1) + + # Capture stdout + captured = io.StringIO() + with patch("sys.stdout", captured): + tokens = generate_response(model, "test prompt", tokenizer, max_length=10) + + assert isinstance(tokens, list) + assert len(tokens) >= 1 + + def test_generate_response_stops_at_eos(self): + """Test that generation stops at EOS token.""" + model = MockModel(stop_after=2, return_eos_at=2) + tokenizer = MockTokenizer(eos_token_id=1) + + captured = io.StringIO() + with patch("sys.stdout", captured): + tokens = generate_response(model, "test", tokenizer, max_length=100) + + # Should stop at EOS, not at max_length + assert len(tokens) < 100 + + def test_generate_response_respects_max_length(self): + """Test that max_length is respected.""" + # Model never returns EOS (eos_token_id=99999 is unlikely) + model = MockModel(stop_after=100) + tokenizer = MockTokenizer(eos_token_id=99999) + + captured = io.StringIO() + with patch("sys.stdout", captured): + tokens = generate_response(model, "test", tokenizer, max_length=5) + + assert len(tokens) <= 5 + + def test_generate_response_no_tokens_message(self): + """Test message when no tokens generated (immediate EOS).""" + + class ImmediateEOSModel: + def __call__(self, y, cache=None): + import numpy as np + + batch_size = y.shape[0] + seq_len = y.shape[1] if y.ndim > 1 else 1 + logits_np = np.zeros((batch_size, seq_len, 50), dtype=np.float32) + # Return EOS immediately (token 1) + logits_np[:, :, 1] = 100.0 + logits = mx.array(logits_np) + mx.eval(logits) + return logits, cache + + model = ImmediateEOSModel() + tokenizer = MockTokenizer(eos_token_id=1) + + captured = io.StringIO() + with patch("sys.stdout", captured): + tokens = generate_response(model, "test", tokenizer, max_length=10) + + output = captured.getvalue() + assert tokens == [] + assert "No tokens generated" in output + + def test_generate_response_prints_output(self): + """Test that response is printed incrementally.""" + model = MockModel(stop_after=5, return_eos_at=6) + tokenizer = MockTokenizer() + + captured = io.StringIO() + with patch("sys.stdout", captured): + generate_response(model, "test", tokenizer, max_length=10) + + output = captured.getvalue() + # Should have printed something + assert len(output) > 0 + + def test_generate_response_returns_token_list(self): + """Test that function returns a list of token IDs.""" + model = MockModel(stop_after=3, return_eos_at=4) + tokenizer = MockTokenizer() + + captured = io.StringIO() + with patch("sys.stdout", captured): + tokens = generate_response(model, "test", tokenizer, max_length=10) + + assert isinstance(tokens, list) + for token in tokens: + assert isinstance(token, int) diff --git a/tests/inference/test_loader.py b/tests/inference/test_loader.py new file mode 100644 index 00000000..34191abb --- /dev/null +++ b/tests/inference/test_loader.py @@ -0,0 +1,421 @@ +"""Tests for inference/loader.py module.""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import mlx.core as mx +import pytest + +from chuk_lazarus.inference.loader import ( + DownloadConfig, + DownloadResult, + DType, + HFLoader, + LoadedWeights, + StandardWeightConverter, +) + + +class TestDType: + """Tests for DType enum.""" + + def test_dtype_values(self): + """Test dtype enum values.""" + assert DType.FLOAT16.value == "float16" + assert DType.FLOAT32.value == "float32" + assert DType.BFLOAT16.value == "bfloat16" + + def test_to_mlx_float16(self): + """Test conversion to MLX float16.""" + assert DType.FLOAT16.to_mlx() == mx.float16 + + def test_to_mlx_float32(self): + """Test conversion to MLX float32.""" + assert DType.FLOAT32.to_mlx() == mx.float32 + + def test_to_mlx_bfloat16(self): + """Test conversion to MLX bfloat16.""" + assert DType.BFLOAT16.to_mlx() == mx.bfloat16 + + def test_dtype_is_str_enum(self): + """Test that DType is a string enum.""" + assert isinstance(DType.FLOAT16, str) + assert DType.FLOAT16 == "float16" + + +class TestDownloadConfig: + """Tests for DownloadConfig model.""" + + def test_required_model_id(self): + """Test that model_id is required.""" + with pytest.raises(ValueError): + DownloadConfig() + + def test_default_values(self): + """Test default configuration values.""" + config = DownloadConfig(model_id="org/model") + assert config.model_id == "org/model" + assert config.cache_dir is None + assert config.prefer_sharded is True + assert "*.json" in config.allow_patterns + assert "*.safetensors" in config.allow_patterns + + def test_custom_values(self): + """Test custom configuration values.""" + config = DownloadConfig( + model_id="org/model", + cache_dir=Path("/tmp/cache"), + allow_patterns=["*.bin"], + prefer_sharded=False, + ) + assert config.cache_dir == Path("/tmp/cache") + assert config.allow_patterns == ["*.bin"] + assert config.prefer_sharded is False + + +class TestLoadedWeights: + """Tests for LoadedWeights model.""" + + def test_create_loaded_weights(self): + """Test creating LoadedWeights.""" + weights = { + "model.layers.0.weight": mx.zeros((10, 10)), + "model.layers.1.weight": mx.zeros((10, 10)), + } + loaded = LoadedWeights( + weights=weights, + dtype=DType.BFLOAT16, + source_path=Path("/tmp/model"), + tensor_count=2, + ) + assert loaded.dtype == DType.BFLOAT16 + assert loaded.tensor_count == 2 + assert loaded.source_path == Path("/tmp/model") + + def test_layer_count_property(self): + """Test layer_count property.""" + weights = { + "model.layers.0.weight": mx.zeros((10, 10)), + "model.layers.1.weight": mx.zeros((10, 10)), + "model.layers.5.weight": mx.zeros((10, 10)), + } + loaded = LoadedWeights( + weights=weights, + dtype=DType.BFLOAT16, + source_path=Path("/tmp/model"), + tensor_count=3, + ) + assert loaded.layer_count == 6 # 0-5 = 6 layers + + def test_layer_count_no_layers(self): + """Test layer_count with no layer weights.""" + weights = { + "model.embed_tokens.weight": mx.zeros((10, 10)), + "lm_head.weight": mx.zeros((10, 10)), + } + loaded = LoadedWeights( + weights=weights, + dtype=DType.BFLOAT16, + source_path=Path("/tmp/model"), + tensor_count=2, + ) + assert loaded.layer_count == 0 + + +class TestDownloadResult: + """Tests for DownloadResult dataclass.""" + + def test_create_result(self): + """Test creating download result.""" + result = DownloadResult( + model_path=Path("/tmp/model"), + model_id="org/model", + ) + assert result.model_path == Path("/tmp/model") + assert result.model_id == "org/model" + assert result.is_cached is False + + def test_is_cached(self): + """Test is_cached flag.""" + result = DownloadResult( + model_path=Path("/tmp/model"), + model_id="org/model", + is_cached=True, + ) + assert result.is_cached is True + + +class TestStandardWeightConverter: + """Tests for StandardWeightConverter.""" + + def test_convert_embed_tokens(self): + """Test converting embed_tokens weight.""" + converter = StandardWeightConverter() + result = converter.convert("model.embed_tokens.weight") + assert result == "model.embed_tokens.weight.weight" + + def test_convert_final_norm(self): + """Test converting final norm weight.""" + converter = StandardWeightConverter() + result = converter.convert("model.norm.weight") + assert result == "model.norm.weight" + + def test_convert_lm_head(self): + """Test converting lm_head weight.""" + converter = StandardWeightConverter() + result = converter.convert("lm_head.weight") + assert result == "lm_head.lm_head.weight" + + def test_convert_lm_head_tied(self): + """Test lm_head with tied embeddings.""" + converter = StandardWeightConverter(tie_word_embeddings=True) + result = converter.convert("lm_head.weight") + assert result is None + + def test_convert_layer_weights(self): + """Test converting layer weights.""" + converter = StandardWeightConverter() + + result = converter.convert("model.layers.0.self_attn.q_proj.weight") + assert result == "model.layers.0.self_attn.q_proj.weight" + + result = converter.convert("model.layers.5.mlp.gate_proj.weight") + assert result == "model.layers.5.mlp.gate_proj.weight" + + def test_skip_rotary_emb(self): + """Test skipping rotary embeddings.""" + converter = StandardWeightConverter() + result = converter.convert("model.layers.0.self_attn.rotary_emb.inv_freq") + assert result is None + + def test_unknown_weight(self): + """Test unknown weight returns None.""" + converter = StandardWeightConverter() + result = converter.convert("unknown.weight.name") + assert result is None + + +class TestWeightConverterProtocol: + """Tests for WeightConverter protocol.""" + + def test_protocol_compliance(self): + """Test that StandardWeightConverter implements WeightConverter.""" + converter = StandardWeightConverter() + # Check it has the required method + assert hasattr(converter, "convert") + assert callable(converter.convert) + + def test_custom_converter(self): + """Test custom converter implementation.""" + + class CustomConverter: + def convert(self, hf_name: str) -> str | None: + return f"custom.{hf_name}" + + converter = CustomConverter() + assert converter.convert("test") == "custom.test" + + +class TestHFLoaderDownload: + """Tests for HFLoader.download method.""" + + @patch("huggingface_hub.snapshot_download") + @patch("huggingface_hub.list_repo_files") + def test_download_basic(self, mock_list_files, mock_download): + """Test basic download.""" + mock_list_files.return_value = ["config.json", "model.safetensors"] + mock_download.return_value = "/tmp/model" + + result = HFLoader.download("org/model") + + assert result.model_path == Path("/tmp/model") + assert result.model_id == "org/model" + mock_download.assert_called_once() + + @patch("huggingface_hub.snapshot_download") + @patch("huggingface_hub.list_repo_files") + def test_download_prefer_sharded(self, mock_list_files, mock_download): + """Test download preferring sharded files.""" + mock_list_files.return_value = [ + "config.json", + "model-00001-of-00002.safetensors", + "consolidated.safetensors", + ] + mock_download.return_value = "/tmp/model" + + _ = HFLoader.download("org/model", prefer_sharded=True) + + # Should ignore consolidated.safetensors + call_args = mock_download.call_args + assert "consolidated.safetensors" in call_args.kwargs.get("ignore_patterns", []) + + @patch("huggingface_hub.snapshot_download") + @patch("huggingface_hub.list_repo_files") + def test_download_with_cache_dir(self, mock_list_files, mock_download): + """Test download with cache directory.""" + mock_list_files.return_value = ["config.json"] + mock_download.return_value = "/tmp/model" + + _ = HFLoader.download("org/model", cache_dir=Path("/custom/cache")) + + call_args = mock_download.call_args + assert call_args.kwargs.get("cache_dir") == "/custom/cache" + + @patch("huggingface_hub.snapshot_download") + @patch("huggingface_hub.list_repo_files") + def test_download_list_files_error(self, mock_list_files, mock_download): + """Test download when listing files fails.""" + mock_list_files.side_effect = Exception("API error") + mock_download.return_value = "/tmp/model" + + # Should still work, just without sharded preference logic + result = HFLoader.download("org/model") + assert result.model_path == Path("/tmp/model") + + +class TestHFLoaderLoadTokenizer: + """Tests for HFLoader.load_tokenizer method.""" + + def test_load_tokenizer_basic(self, monkeypatch): + """Test basic tokenizer loading.""" + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token = None + mock_tokenizer.eos_token = "" + + mock_auto_tokenizer_class = MagicMock() + mock_auto_tokenizer_class.from_pretrained.return_value = mock_tokenizer + + # Import the loader module and mock + import chuk_lazarus.inference.loader as loader_module + + def patched_load(model_path): + # Mock the import by directly calling our mock + mock_tokenizer_result = mock_auto_tokenizer_class.from_pretrained(str(model_path)) + if mock_tokenizer_result.pad_token is None: + mock_tokenizer_result.pad_token = mock_tokenizer_result.eos_token + return mock_tokenizer_result + + monkeypatch.setattr(loader_module.HFLoader, "load_tokenizer", staticmethod(patched_load)) + + result = HFLoader.load_tokenizer(Path("/tmp/model")) + + assert result.pad_token == "" + mock_auto_tokenizer_class.from_pretrained.assert_called_with("/tmp/model") + + def test_load_tokenizer_with_pad_token(self, monkeypatch): + """Test loading tokenizer that already has pad token.""" + mock_tokenizer = MagicMock() + mock_tokenizer.pad_token = "" + mock_tokenizer.eos_token = "" + + mock_auto_tokenizer_class = MagicMock() + mock_auto_tokenizer_class.from_pretrained.return_value = mock_tokenizer + + import chuk_lazarus.inference.loader as loader_module + + def patched_load(model_path): + mock_tokenizer_result = mock_auto_tokenizer_class.from_pretrained(str(model_path)) + if mock_tokenizer_result.pad_token is None: + mock_tokenizer_result.pad_token = mock_tokenizer_result.eos_token + return mock_tokenizer_result + + monkeypatch.setattr(loader_module.HFLoader, "load_tokenizer", staticmethod(patched_load)) + + result = HFLoader.load_tokenizer(Path("/tmp/model")) + + assert result.pad_token == "" # Not overwritten + + +class TestHFLoaderLoadWeights: + """Tests for HFLoader.load_weights method.""" + + def test_load_weights_no_files(self, tmp_path): + """Test loading weights when no files exist.""" + with pytest.raises(FileNotFoundError): + HFLoader.load_weights(tmp_path) + + def test_load_weights_basic(self, tmp_path): + """Test basic weight loading.""" + # Create a fake safetensors file + weights = { + "model.embed_tokens.weight": mx.zeros((10, 10)), + "model.layers.0.self_attn.q_proj.weight": mx.zeros((10, 10)), + } + mx.save_safetensors(str(tmp_path / "model.safetensors"), weights) + + loaded = HFLoader.load_weights(tmp_path) + + assert loaded.tensor_count >= 0 # Some may be filtered + assert loaded.source_path == tmp_path + assert loaded.dtype == DType.BFLOAT16 + + def test_load_weights_with_converter(self, tmp_path): + """Test loading weights with custom converter.""" + weights = {"custom.weight": mx.zeros((10, 10))} + mx.save_safetensors(str(tmp_path / "model.safetensors"), weights) + + class TestConverter: + def convert(self, name): + return f"converted.{name}" + + loaded = HFLoader.load_weights(tmp_path, converter=TestConverter()) + assert "converted.custom.weight" in loaded.weights + + +class TestHFLoaderBuildNestedWeights: + """Tests for HFLoader.build_nested_weights method.""" + + def test_build_nested_basic(self): + """Test basic nested weight building.""" + loaded = LoadedWeights( + weights={ + "model.embed_tokens.weight": mx.zeros((10, 10)), + "model.layers.0.self_attn.q_proj.weight": mx.zeros((10, 10)), + "model.layers.1.self_attn.q_proj.weight": mx.zeros((10, 10)), + }, + dtype=DType.BFLOAT16, + source_path=Path("/tmp"), + tensor_count=3, + ) + + nested = HFLoader.build_nested_weights(loaded) + + assert "model" in nested + assert "layers" in nested["model"] + assert len(nested["model"]["layers"]) == 2 + assert "self_attn" in nested["model"]["layers"][0] + + def test_build_nested_deep_structure(self): + """Test nested structure with deep paths.""" + loaded = LoadedWeights( + weights={ + "model.layers.0.mlp.gate_proj.weight": mx.zeros((10, 10)), + "model.layers.0.mlp.up_proj.weight": mx.zeros((10, 10)), + }, + dtype=DType.BFLOAT16, + source_path=Path("/tmp"), + tensor_count=2, + ) + + nested = HFLoader.build_nested_weights(loaded) + + assert "mlp" in nested["model"]["layers"][0] + assert "gate_proj" in nested["model"]["layers"][0]["mlp"] + assert "weight" in nested["model"]["layers"][0]["mlp"]["gate_proj"] + + +class TestHFLoaderDownloadAsync: + """Tests for HFLoader.download_async method.""" + + @pytest.mark.asyncio + @patch("huggingface_hub.snapshot_download") + @patch("huggingface_hub.list_repo_files") + async def test_download_async(self, mock_list_files, mock_download): + """Test async download.""" + mock_list_files.return_value = ["config.json"] + mock_download.return_value = "/tmp/model" + + result = await HFLoader.download_async("org/model") + + assert result.model_path == Path("/tmp/model") + assert result.model_id == "org/model" diff --git a/tests/inference/test_pipeline.py b/tests/inference/test_pipeline.py new file mode 100644 index 00000000..f571d78a --- /dev/null +++ b/tests/inference/test_pipeline.py @@ -0,0 +1,407 @@ +"""Tests for inference/pipeline.py module.""" + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import mlx.core as mx +import pytest +from pydantic import BaseModel + +from chuk_lazarus.inference.chat import ChatHistory +from chuk_lazarus.inference.generation import GenerationConfig, GenerationResult +from chuk_lazarus.inference.loader import DType +from chuk_lazarus.inference.pipeline import ( + InferencePipeline, + PipelineConfig, + PipelineState, + _load_config, +) + + +class TestPipelineConfig: + """Tests for PipelineConfig model.""" + + def test_default_values(self): + """Test default configuration values.""" + config = PipelineConfig() + assert config.dtype == DType.BFLOAT16 + assert config.cache_dir is None + assert config.default_system_message == "You are a helpful assistant." + assert config.default_max_tokens == 100 + assert config.default_temperature == 0.7 + + def test_custom_values(self): + """Test custom configuration values.""" + config = PipelineConfig( + dtype=DType.FLOAT16, + cache_dir=Path("/tmp/cache"), + default_system_message="Be concise.", + default_max_tokens=50, + default_temperature=0.5, + ) + assert config.dtype == DType.FLOAT16 + assert config.cache_dir == Path("/tmp/cache") + assert config.default_system_message == "Be concise." + assert config.default_max_tokens == 50 + assert config.default_temperature == 0.5 + + def test_no_system_message(self): + """Test with no system message.""" + config = PipelineConfig(default_system_message=None) + assert config.default_system_message is None + + +class TestPipelineState: + """Tests for PipelineState model.""" + + def test_create_state(self): + """Test creating pipeline state.""" + state = PipelineState( + model_id="org/model", + model_path=Path("/tmp/model"), + tensor_count=100, + ) + assert state.model_id == "org/model" + assert state.model_path == Path("/tmp/model") + assert state.tensor_count == 100 + assert state.is_loaded is False + + def test_is_loaded(self): + """Test is_loaded flag.""" + state = PipelineState( + model_id="org/model", + model_path=Path("/tmp/model"), + tensor_count=100, + is_loaded=True, + ) + assert state.is_loaded is True + + +class MockConfig(BaseModel): + """Mock model config for testing.""" + + vocab_size: int = 1000 + hidden_size: int = 64 + num_hidden_layers: int = 4 + eos_token_id: int | list[int] | None = 50256 + + +class MockModel: + """Mock model for testing.""" + + def __init__(self, config=None): + self.config = config + self._params = {"weight": mx.zeros((10, 10))} + + def generate( + self, + input_ids, + max_new_tokens=100, + temperature=0.7, + top_p=0.9, + top_k=None, + stop_tokens=None, + ): + _ = input_ids.shape[0] # batch_size unused but validates shape + _ = input_ids.shape[1] # input_length unused but validates shape + new_tokens = mx.array([[10, 11, 12, 13, 14]]) + return mx.concatenate([input_ids, new_tokens[:, :max_new_tokens]], axis=1) + + def update(self, weights): + pass + + def parameters(self): + return self._params + + +class MockTokenizer: + """Mock tokenizer for testing.""" + + def __init__(self): + self.eos_token_id = 50256 + self.pad_token = "" + self.chat_template = "template" + + def __len__(self): + return 32000 + + def encode(self, text, return_tensors=None): + import numpy as np + + return np.array([[1, 2, 3, 4, 5]]) + + def decode(self, tokens, skip_special_tokens=False): + return f"decoded_{len(tokens)}_tokens" + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True): + return f"{len(messages)} messages" + + +class TestLoadConfig: + """Tests for _load_config function.""" + + def test_load_basic_config(self, tmp_path): + """Test loading basic config.""" + config_data = { + "vocab_size": 32000, + "hidden_size": 128, + "num_hidden_layers": 8, + } + with open(tmp_path / "config.json", "w") as f: + json.dump(config_data, f) + + config = _load_config(tmp_path, MockConfig) + + assert config.vocab_size == 32000 + assert config.hidden_size == 128 + + def test_load_config_with_list_token_ids(self, tmp_path): + """Test loading config with list token IDs.""" + config_data = { + "vocab_size": 32000, + "hidden_size": 128, + "num_hidden_layers": 8, + "eos_token_id": [50256, 50257], + "bos_token_id": [1], + "pad_token_id": [], + } + with open(tmp_path / "config.json", "w") as f: + json.dump(config_data, f) + + config = _load_config(tmp_path, MockConfig) + + assert config.eos_token_id == 50256 # First element + + def test_load_config_empty_list_token_id(self, tmp_path): + """Test loading config with empty list token ID.""" + config_data = { + "vocab_size": 32000, + "hidden_size": 128, + "num_hidden_layers": 8, + "eos_token_id": [], + } + with open(tmp_path / "config.json", "w") as f: + json.dump(config_data, f) + + config = _load_config(tmp_path, MockConfig) + + assert config.eos_token_id is None + + +class TestInferencePipeline: + """Tests for InferencePipeline class.""" + + def test_create_pipeline(self): + """Test creating pipeline directly.""" + model = MockModel() + tokenizer = MockTokenizer() + config = MockConfig() + + pipeline = InferencePipeline( + model=model, + tokenizer=tokenizer, + config=config, + ) + + assert pipeline.model is model + assert pipeline.tokenizer is tokenizer + assert pipeline.config is config + + def test_pipeline_with_config(self): + """Test pipeline with custom config.""" + model = MockModel() + tokenizer = MockTokenizer() + config = MockConfig() + pipeline_config = PipelineConfig( + default_max_tokens=50, + default_temperature=0.5, + ) + + pipeline = InferencePipeline( + model=model, + tokenizer=tokenizer, + config=config, + pipeline_config=pipeline_config, + ) + + assert pipeline._pipeline_config.default_max_tokens == 50 + + def test_chat_basic(self): + """Test basic chat generation.""" + model = MockModel() + tokenizer = MockTokenizer() + config = MockConfig() + pipeline = InferencePipeline(model=model, tokenizer=tokenizer, config=config) + + result = pipeline.chat("Hello!") + + assert isinstance(result, GenerationResult) + assert result.text.startswith("decoded_") + + def test_chat_with_system(self): + """Test chat with custom system message.""" + model = MockModel() + tokenizer = MockTokenizer() + config = MockConfig() + pipeline = InferencePipeline(model=model, tokenizer=tokenizer, config=config) + + result = pipeline.chat("Hello!", system_message="Be brief.") + + assert isinstance(result, GenerationResult) + + def test_chat_with_params(self): + """Test chat with custom parameters.""" + model = MockModel() + tokenizer = MockTokenizer() + config = MockConfig() + pipeline = InferencePipeline(model=model, tokenizer=tokenizer, config=config) + + result = pipeline.chat( + "Hello!", + max_new_tokens=50, + temperature=0.5, + ) + + assert isinstance(result, GenerationResult) + + def test_chat_with_history(self): + """Test chat with history.""" + model = MockModel() + tokenizer = MockTokenizer() + config = MockConfig() + pipeline = InferencePipeline(model=model, tokenizer=tokenizer, config=config) + + history = ChatHistory() + history.add_user("Hello") + history.add_assistant("Hi!") + history.add_user("How are you?") + + result = pipeline.chat_with_history(history) + + assert isinstance(result, GenerationResult) + + def test_generate_raw(self): + """Test raw generation without chat formatting.""" + model = MockModel() + tokenizer = MockTokenizer() + config = MockConfig() + pipeline = InferencePipeline(model=model, tokenizer=tokenizer, config=config) + + result = pipeline.generate("Raw prompt text") + + assert isinstance(result, GenerationResult) + + def test_generate_with_config(self): + """Test generation with full config.""" + model = MockModel() + tokenizer = MockTokenizer() + config = MockConfig() + pipeline = InferencePipeline(model=model, tokenizer=tokenizer, config=config) + + gen_config = GenerationConfig( + max_new_tokens=20, + temperature=0.3, + ) + result = pipeline.generate("Test", config=gen_config) + + assert isinstance(result, GenerationResult) + + +class TestInferencePipelineFromPretrained: + """Tests for InferencePipeline.from_pretrained method.""" + + @patch("chuk_lazarus.inference.pipeline.HFLoader") + @patch("chuk_lazarus.inference.pipeline._load_config") + def test_from_pretrained(self, mock_load_config, mock_loader): + """Test loading from pretrained.""" + # Setup mocks + mock_loader.download.return_value = MagicMock(model_path=Path("/tmp/model")) + mock_loader.load_tokenizer.return_value = MockTokenizer() + mock_loader.load_weights.return_value = MagicMock( + tensor_count=100, + weights={"weight": mx.zeros((10, 10))}, + layer_count=4, + ) + mock_loader.build_nested_weights.return_value = {"model": {}} + mock_load_config.return_value = MockConfig() + + pipeline = InferencePipeline.from_pretrained( + "org/model", + MockModel, + MockConfig, + ) + + assert isinstance(pipeline, InferencePipeline) + mock_loader.download.assert_called_once() + + @patch("chuk_lazarus.inference.pipeline.HFLoader") + @patch("chuk_lazarus.inference.pipeline._load_config") + def test_from_pretrained_with_config(self, mock_load_config, mock_loader): + """Test loading with custom pipeline config.""" + mock_loader.download.return_value = MagicMock(model_path=Path("/tmp/model")) + mock_loader.load_tokenizer.return_value = MockTokenizer() + mock_loader.load_weights.return_value = MagicMock( + tensor_count=100, + weights={}, + layer_count=4, + ) + mock_loader.build_nested_weights.return_value = {} + mock_load_config.return_value = MockConfig() + + pipeline_config = PipelineConfig( + dtype=DType.FLOAT16, + default_max_tokens=50, + ) + + pipeline = InferencePipeline.from_pretrained( + "org/model", + MockModel, + MockConfig, + pipeline_config=pipeline_config, + ) + + assert pipeline._pipeline_config.dtype == DType.FLOAT16 + + +class TestInferencePipelineFromPretrainedAsync: + """Tests for InferencePipeline.from_pretrained_async method.""" + + @pytest.mark.asyncio + @patch("chuk_lazarus.inference.pipeline.HFLoader") + @patch("chuk_lazarus.inference.pipeline._load_config") + async def test_from_pretrained_async(self, mock_load_config, mock_loader): + """Test async loading.""" + mock_loader.download.return_value = MagicMock(model_path=Path("/tmp/model")) + mock_loader.load_tokenizer.return_value = MockTokenizer() + mock_loader.load_weights.return_value = MagicMock( + tensor_count=100, + weights={}, + layer_count=4, + ) + mock_loader.build_nested_weights.return_value = {} + mock_load_config.return_value = MockConfig() + + pipeline = await InferencePipeline.from_pretrained_async( + "org/model", + MockModel, + MockConfig, + ) + + assert isinstance(pipeline, InferencePipeline) + + +class TestCausalLMProtocol: + """Tests for CausalLMProtocol.""" + + def test_mock_model_implements_protocol(self): + """Test that MockModel implements the protocol.""" + model = MockModel() + + # Check protocol methods exist + assert hasattr(model, "generate") + assert hasattr(model, "update") + assert hasattr(model, "parameters") + assert callable(model.generate) + assert callable(model.update) + assert callable(model.parameters) diff --git a/tests/models_v2/backbones/test_backbones.py b/tests/models_v2/backbones/test_backbones.py index 056846e0..08fcf10d 100644 --- a/tests/models_v2/backbones/test_backbones.py +++ b/tests/models_v2/backbones/test_backbones.py @@ -5,6 +5,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.backbones import ( BackboneOutput, @@ -334,7 +335,8 @@ def loss_fn(model, input_ids): out = model(input_ids) return mx.mean(out.last_hidden_state**2) - loss, grads = mx.value_and_grad(loss_fn)(backbone, input_ids) + loss_and_grad_fn = nn.value_and_grad(backbone, loss_fn) + loss, grads = loss_and_grad_fn(backbone, input_ids) assert loss.item() > 0 assert any(g is not None for g in grads.values()) @@ -354,7 +356,8 @@ def loss_fn(model, input_ids): out = model(input_ids) return mx.mean(out.last_hidden_state**2) - loss, grads = mx.value_and_grad(loss_fn)(backbone, input_ids) + loss_and_grad_fn = nn.value_and_grad(backbone, loss_fn) + loss, grads = loss_and_grad_fn(backbone, input_ids) assert loss.item() > 0 @@ -373,7 +376,8 @@ def loss_fn(model, input_ids): out = model(input_ids) return mx.mean(out.last_hidden_state**2) - loss, grads = mx.value_and_grad(loss_fn)(backbone, input_ids) + loss_and_grad_fn = nn.value_and_grad(backbone, loss_fn) + loss, grads = loss_and_grad_fn(backbone, input_ids) assert loss.item() > 0 diff --git a/tests/models_v2/blocks/test_blocks.py b/tests/models_v2/blocks/test_blocks.py index be640d4d..e99aebf0 100644 --- a/tests/models_v2/blocks/test_blocks.py +++ b/tests/models_v2/blocks/test_blocks.py @@ -286,7 +286,8 @@ def loss_fn(model, x): out = model(x) return mx.mean(out.hidden_states**2) - loss, grads = mx.value_and_grad(loss_fn)(block, x) + loss_and_grad_fn = nn.value_and_grad(block, loss_fn) + loss, grads = loss_and_grad_fn(block, x) assert loss.item() > 0 assert any(g is not None for g in grads.values()) @@ -304,7 +305,8 @@ def loss_fn(model, x): out = model(x) return mx.mean(out.hidden_states**2) - loss, grads = mx.value_and_grad(loss_fn)(block, x) + loss_and_grad_fn = nn.value_and_grad(block, loss_fn) + loss, grads = loss_and_grad_fn(block, x) assert loss.item() > 0 @@ -321,7 +323,8 @@ def loss_fn(model, x): out = model(x) return mx.mean(out.hidden_states**2) - loss, grads = mx.value_and_grad(loss_fn)(block, x) + loss_and_grad_fn = nn.value_and_grad(block, loss_fn) + loss, grads = loss_and_grad_fn(block, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/attention/test_grouped_query.py b/tests/models_v2/components/attention/test_grouped_query.py index 65f3bdb4..73b055fd 100644 --- a/tests/models_v2/components/attention/test_grouped_query.py +++ b/tests/models_v2/components/attention/test_grouped_query.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn import pytest from chuk_lazarus.models_v2.components.attention import GroupedQueryAttention @@ -218,6 +219,7 @@ def loss_fn(model, x): out, _ = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(attn, x) + loss_and_grad_fn = nn.value_and_grad(attn, loss_fn) + loss, grads = loss_and_grad_fn(attn, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/attention/test_multi_head.py b/tests/models_v2/components/attention/test_multi_head.py index 86809cf8..3ce40675 100644 --- a/tests/models_v2/components/attention/test_multi_head.py +++ b/tests/models_v2/components/attention/test_multi_head.py @@ -182,7 +182,8 @@ def loss_fn(model, x): out, _ = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(attn, x) + loss_and_grad_fn = nn.value_and_grad(attn, loss_fn) + loss, grads = loss_and_grad_fn(attn, x) assert loss.item() > 0 # Check gradients exist diff --git a/tests/models_v2/components/embeddings/test_token.py b/tests/models_v2/components/embeddings/test_token.py index baab8caf..c3a3ff0e 100644 --- a/tests/models_v2/components/embeddings/test_token.py +++ b/tests/models_v2/components/embeddings/test_token.py @@ -5,6 +5,7 @@ import math import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.embeddings import create_token_embedding from chuk_lazarus.models_v2.components.embeddings.token import TokenEmbedding @@ -117,5 +118,6 @@ def loss_fn(model, x): out = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(embed, input_ids) + loss_and_grad_fn = nn.value_and_grad(embed, loss_fn) + loss, grads = loss_and_grad_fn(embed, input_ids) assert loss.item() > 0 diff --git a/tests/models_v2/components/ffn/test_mlp.py b/tests/models_v2/components/ffn/test_mlp.py index 8ed97643..0975ffbd 100644 --- a/tests/models_v2/components/ffn/test_mlp.py +++ b/tests/models_v2/components/ffn/test_mlp.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.ffn import MLP from chuk_lazarus.models_v2.components.ffn.mlp import create_mlp @@ -102,6 +103,7 @@ def test_mlp_gradients(self): def loss_fn(model, x): return mx.mean(model(x) ** 2) - loss, grads = mx.value_and_grad(loss_fn)(ffn, x) + loss_and_grad_fn = nn.value_and_grad(ffn, loss_fn) + loss, grads = loss_and_grad_fn(ffn, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/ffn/test_swiglu.py b/tests/models_v2/components/ffn/test_swiglu.py index 1aeb3e04..9a68fc03 100644 --- a/tests/models_v2/components/ffn/test_swiglu.py +++ b/tests/models_v2/components/ffn/test_swiglu.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.ffn import SwiGLU from chuk_lazarus.models_v2.components.ffn.swiglu import create_swiglu @@ -121,7 +122,8 @@ def test_swiglu_gradients(self): def loss_fn(model, x): return mx.mean(model(x) ** 2) - loss, grads = mx.value_and_grad(loss_fn)(ffn, x) + loss_and_grad_fn = nn.value_and_grad(ffn, loss_fn) + loss, grads = loss_and_grad_fn(ffn, x) assert loss.item() > 0 assert any(g is not None for g in grads.values()) diff --git a/tests/models_v2/components/normalization/test_layernorm.py b/tests/models_v2/components/normalization/test_layernorm.py index 59a55af4..96abe4a7 100644 --- a/tests/models_v2/components/normalization/test_layernorm.py +++ b/tests/models_v2/components/normalization/test_layernorm.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.normalization.layernorm import ( LayerNorm, @@ -85,7 +86,8 @@ def loss_fn(model, x): out = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(norm, x) + loss_and_grad_fn = nn.value_and_grad(norm, loss_fn) + loss, grads = loss_and_grad_fn(norm, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/normalization/test_rmsnorm.py b/tests/models_v2/components/normalization/test_rmsnorm.py index 23405f1b..0124c864 100644 --- a/tests/models_v2/components/normalization/test_rmsnorm.py +++ b/tests/models_v2/components/normalization/test_rmsnorm.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.normalization import RMSNorm from chuk_lazarus.models_v2.components.normalization.rmsnorm import create_rmsnorm @@ -77,7 +78,8 @@ def loss_fn(model, x): out = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(norm, x) + loss_and_grad_fn = nn.value_and_grad(norm, loss_fn) + loss, grads = loss_and_grad_fn(norm, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/normalization/test_variants.py b/tests/models_v2/components/normalization/test_variants.py index 5b4686d2..a405846b 100644 --- a/tests/models_v2/components/normalization/test_variants.py +++ b/tests/models_v2/components/normalization/test_variants.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.normalization.variants import ( GemmaNorm, @@ -61,5 +62,6 @@ def loss_fn(model, x): out = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(norm, x) + loss_and_grad_fn = nn.value_and_grad(norm, loss_fn) + loss, grads = loss_and_grad_fn(norm, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/recurrent/test_gru.py b/tests/models_v2/components/recurrent/test_gru.py index b6b7f119..dac87161 100644 --- a/tests/models_v2/components/recurrent/test_gru.py +++ b/tests/models_v2/components/recurrent/test_gru.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.recurrent import GRU @@ -134,7 +135,8 @@ def loss_fn(model, x): out, _ = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(gru, x) + loss_and_grad_fn = nn.value_and_grad(gru, loss_fn) + loss, grads = loss_and_grad_fn(gru, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/recurrent/test_lstm.py b/tests/models_v2/components/recurrent/test_lstm.py index e3fd24cc..8838cfb7 100644 --- a/tests/models_v2/components/recurrent/test_lstm.py +++ b/tests/models_v2/components/recurrent/test_lstm.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.recurrent import LSTM @@ -142,7 +143,8 @@ def loss_fn(model, x): out, _ = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(lstm, x) + loss_and_grad_fn = nn.value_and_grad(lstm, loss_fn) + loss, grads = loss_and_grad_fn(lstm, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/recurrent/test_mingru.py b/tests/models_v2/components/recurrent/test_mingru.py index 468e1136..a57c96d0 100644 --- a/tests/models_v2/components/recurrent/test_mingru.py +++ b/tests/models_v2/components/recurrent/test_mingru.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn import mlx.utils from chuk_lazarus.models_v2.components.recurrent import GRU, MinGRU @@ -169,7 +170,8 @@ def loss_fn(model, x): out, _ = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(mingru, x) + loss_and_grad_fn = nn.value_and_grad(mingru, loss_fn) + loss, grads = loss_and_grad_fn(mingru, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/ssm/test_mamba.py b/tests/models_v2/components/ssm/test_mamba.py index 8943a48b..60c28e7c 100644 --- a/tests/models_v2/components/ssm/test_mamba.py +++ b/tests/models_v2/components/ssm/test_mamba.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.ssm import ( Mamba, @@ -112,6 +113,7 @@ def loss_fn(model, x): out, _ = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(block, x) + loss_and_grad_fn = nn.value_and_grad(block, loss_fn) + loss, grads = loss_and_grad_fn(block, x) assert loss.item() > 0 diff --git a/tests/models_v2/components/ssm/test_selective_ssm.py b/tests/models_v2/components/ssm/test_selective_ssm.py index 20d4c408..d59b55ea 100644 --- a/tests/models_v2/components/ssm/test_selective_ssm.py +++ b/tests/models_v2/components/ssm/test_selective_ssm.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.components.ssm import ( SelectiveSSM, @@ -202,7 +203,8 @@ def loss_fn(model, x): out, _ = model(x) return mx.mean(out**2) - loss, grads = mx.value_and_grad(loss_fn)(ssm, x) + loss_and_grad_fn = nn.value_and_grad(ssm, loss_fn) + loss, grads = loss_and_grad_fn(ssm, x) assert loss.item() > 0 assert any(g is not None for g in grads.values()) diff --git a/tests/models_v2/families/gemma/test_config.py b/tests/models_v2/families/gemma/test_config.py index 002f265c..0e1e13fd 100644 --- a/tests/models_v2/families/gemma/test_config.py +++ b/tests/models_v2/families/gemma/test_config.py @@ -2,7 +2,6 @@ Tests for GemmaConfig. """ - from chuk_lazarus.models_v2.families.gemma import GemmaConfig @@ -95,7 +94,7 @@ def test_is_global_layer(self): # Pattern 6 means every 6th layer is global assert config.is_global_layer(4) is False # layer 5 is sliding - assert config.is_global_layer(5) is True # layer 6 is global + assert config.is_global_layer(5) is True # layer 6 is global assert config.is_global_layer(6) is False # layer 7 is sliding assert config.is_global_layer(11) is True # layer 12 is global diff --git a/tests/models_v2/families/gemma/test_convert.py b/tests/models_v2/families/gemma/test_convert.py new file mode 100644 index 00000000..b9a87166 --- /dev/null +++ b/tests/models_v2/families/gemma/test_convert.py @@ -0,0 +1,339 @@ +""" +Tests for Gemma weight conversion utilities. +""" + +import numpy as np + +from chuk_lazarus.models_v2.families.gemma.convert import ( + GEMMA_LAYER_PATTERNS, + GEMMA_WEIGHT_MAP, + _map_weight_name, + _reverse_map_weight_name, + convert_hf_weights, + convert_mlx_community_weights, + convert_mlx_to_hf, + get_num_params, + print_weight_shapes, +) + + +class TestWeightNameMapping: + """Tests for weight name mapping functions.""" + + def test_direct_weight_map_embed_tokens(self): + """Test direct mapping for embed_tokens.""" + result = _map_weight_name("model.embed_tokens.weight") + assert result == "model.embed_tokens.weight" + + def test_direct_weight_map_norm(self): + """Test direct mapping for norm.""" + result = _map_weight_name("model.norm.weight") + assert result == "model.norm.weight" + + def test_direct_weight_map_lm_head(self): + """Test direct mapping for lm_head.""" + result = _map_weight_name("lm_head.weight") + assert result == "lm_head.weight" + + def test_layer_weight_map_attention(self): + """Test layer-level mapping for attention weights.""" + result = _map_weight_name("model.layers.0.self_attn.q_proj.weight") + assert result == "model.layers.0.self_attn.q_proj.weight" + + result = _map_weight_name("model.layers.5.self_attn.k_proj.weight") + assert result == "model.layers.5.self_attn.k_proj.weight" + + result = _map_weight_name("model.layers.10.self_attn.v_proj.weight") + assert result == "model.layers.10.self_attn.v_proj.weight" + + result = _map_weight_name("model.layers.3.self_attn.o_proj.weight") + assert result == "model.layers.3.self_attn.o_proj.weight" + + def test_layer_weight_map_qk_norms(self): + """Test layer-level mapping for Gemma-specific QK norms.""" + result = _map_weight_name("model.layers.0.self_attn.q_norm.weight") + assert result == "model.layers.0.self_attn.q_norm.weight" + + result = _map_weight_name("model.layers.2.self_attn.k_norm.weight") + assert result == "model.layers.2.self_attn.k_norm.weight" + + def test_layer_weight_map_mlp(self): + """Test layer-level mapping for MLP weights.""" + result = _map_weight_name("model.layers.0.mlp.gate_proj.weight") + assert result == "model.layers.0.mlp.gate_proj.weight" + + result = _map_weight_name("model.layers.1.mlp.up_proj.weight") + assert result == "model.layers.1.mlp.up_proj.weight" + + result = _map_weight_name("model.layers.2.mlp.down_proj.weight") + assert result == "model.layers.2.mlp.down_proj.weight" + + def test_layer_weight_map_layernorms(self): + """Test layer-level mapping for Gemma's 4 normalization layers.""" + result = _map_weight_name("model.layers.0.input_layernorm.weight") + assert result == "model.layers.0.input_layernorm.weight" + + result = _map_weight_name("model.layers.0.post_attention_layernorm.weight") + assert result == "model.layers.0.post_attention_layernorm.weight" + + result = _map_weight_name("model.layers.0.pre_feedforward_layernorm.weight") + assert result == "model.layers.0.pre_feedforward_layernorm.weight" + + result = _map_weight_name("model.layers.0.post_feedforward_layernorm.weight") + assert result == "model.layers.0.post_feedforward_layernorm.weight" + + def test_layer_weight_map_unknown_pattern(self): + """Test passthrough for unknown layer patterns.""" + result = _map_weight_name("model.layers.0.unknown_layer.weight") + assert result == "model.layers.0.unknown_layer.weight" + + def test_unrecognized_weight_returns_none(self): + """Test that unrecognized top-level weights return None.""" + result = _map_weight_name("some.random.weight") + assert result is None + + result = _map_weight_name("rotary_emb.inv_freq") + assert result is None + + +class TestReverseWeightNameMapping: + """Tests for reverse weight name mapping.""" + + def test_reverse_map_direct(self): + """Test reverse mapping for direct weights.""" + result = _reverse_map_weight_name("model.embed_tokens.weight") + assert result == "model.embed_tokens.weight" + + result = _reverse_map_weight_name("model.norm.weight") + assert result == "model.norm.weight" + + result = _reverse_map_weight_name("lm_head.weight") + assert result == "lm_head.weight" + + def test_reverse_map_layer_patterns(self): + """Test reverse mapping for layer patterns.""" + result = _reverse_map_weight_name("model.layers.0.self_attn.q_proj.weight") + assert result == "model.layers.0.self_attn.q_proj.weight" + + result = _reverse_map_weight_name("model.layers.5.mlp.gate_proj.weight") + assert result == "model.layers.5.mlp.gate_proj.weight" + + def test_reverse_map_unknown_pattern(self): + """Test passthrough for unknown layer patterns.""" + result = _reverse_map_weight_name("model.layers.0.unknown.weight") + assert result == "model.layers.0.unknown.weight" + + def test_reverse_map_unrecognized_returns_none(self): + """Test unrecognized patterns return None.""" + result = _reverse_map_weight_name("some.random.weight") + assert result is None + + +class TestConvertHfWeights: + """Tests for HuggingFace weight conversion.""" + + def test_convert_basic_weights(self): + """Test basic weight conversion.""" + hf_weights = { + "model.embed_tokens.weight": np.random.randn(1000, 64).astype(np.float32), + "model.norm.weight": np.random.randn(64).astype(np.float32), + "lm_head.weight": np.random.randn(1000, 64).astype(np.float32), + } + + converted = convert_hf_weights(hf_weights) + + assert "model.embed_tokens.weight" in converted + assert "model.norm.weight" in converted + assert "lm_head.weight" in converted + assert len(converted) == 3 + + def test_convert_layer_weights(self): + """Test layer weight conversion.""" + hf_weights = { + "model.layers.0.self_attn.q_proj.weight": np.random.randn(64, 64).astype(np.float32), + "model.layers.0.self_attn.k_proj.weight": np.random.randn(64, 64).astype(np.float32), + "model.layers.0.mlp.gate_proj.weight": np.random.randn(128, 64).astype(np.float32), + } + + converted = convert_hf_weights(hf_weights) + + assert "model.layers.0.self_attn.q_proj.weight" in converted + assert "model.layers.0.self_attn.k_proj.weight" in converted + assert "model.layers.0.mlp.gate_proj.weight" in converted + + def test_convert_skips_unmapped_weights(self): + """Test that unmapped weights are skipped.""" + hf_weights = { + "model.embed_tokens.weight": np.random.randn(1000, 64).astype(np.float32), + "rotary_emb.inv_freq": np.random.randn(32).astype(np.float32), # Should be skipped + } + + converted = convert_hf_weights(hf_weights) + + assert "model.embed_tokens.weight" in converted + assert "rotary_emb.inv_freq" not in converted + assert len(converted) == 1 + + def test_convert_with_tied_embeddings(self): + """Test conversion with tied word embeddings.""" + hf_weights = { + "model.embed_tokens.weight": np.random.randn(1000, 64).astype(np.float32), + "lm_head.weight": np.random.randn(1000, 64).astype(np.float32), + } + + converted = convert_hf_weights(hf_weights, tie_word_embeddings=True) + + assert "model.embed_tokens.weight" in converted + assert "lm_head.weight" not in converted # Should be skipped when tied + assert len(converted) == 1 + + +class TestConvertMlxCommunityWeights: + """Tests for MLX community weight conversion.""" + + def test_convert_mlx_community_passthrough(self): + """Test that MLX community weights pass through.""" + weights = { + "model.embed_tokens.weight": np.random.randn(1000, 64).astype(np.float32), + "model.layers.0.self_attn.q_proj.weight": np.random.randn(64, 64).astype(np.float32), + } + + converted = convert_mlx_community_weights(weights) + + assert "model.embed_tokens.weight" in converted + assert "model.layers.0.self_attn.q_proj.weight" in converted + + +class TestConvertMlxToHf: + """Tests for MLX to HuggingFace conversion.""" + + def test_convert_numpy_weights(self): + """Test conversion of numpy arrays.""" + mlx_weights = { + "model.embed_tokens.weight": np.random.randn(1000, 64).astype(np.float32), + "model.norm.weight": np.random.randn(64).astype(np.float32), + } + + converted = convert_mlx_to_hf(mlx_weights) + + assert "model.embed_tokens.weight" in converted + assert "model.norm.weight" in converted + + def test_convert_layer_weights_to_hf(self): + """Test layer weight conversion to HF format.""" + mlx_weights = { + "model.layers.0.self_attn.q_proj.weight": np.random.randn(64, 64).astype(np.float32), + "model.layers.0.mlp.gate_proj.weight": np.random.randn(128, 64).astype(np.float32), + } + + converted = convert_mlx_to_hf(mlx_weights) + + assert "model.layers.0.self_attn.q_proj.weight" in converted + assert "model.layers.0.mlp.gate_proj.weight" in converted + + +class TestMlxArrayConversion: + """Tests for MLX array conversion in convert_mlx_to_hf.""" + + def test_convert_mlx_array_passthrough(self): + """Test that MLX arrays are passed through (can be converted via np.asarray).""" + import mlx.core as mx + + mlx_weights = { + "model.embed_tokens.weight": mx.random.normal((100, 64)), + } + + converted = convert_mlx_to_hf(mlx_weights) + + # MLX arrays are passed through since they lack .numpy() and __array__ + # The current implementation passes them through unchanged + # To get numpy, caller should use np.asarray on the result + assert "model.embed_tokens.weight" in converted + # Result can be converted to numpy via np.asarray + result = np.asarray(converted["model.embed_tokens.weight"]) + assert isinstance(result, np.ndarray) + assert result.shape == (100, 64) + + def test_convert_object_with_array_interface(self): + """Test conversion of objects with __array__ method.""" + + class ArrayLike: + def __init__(self, data): + self._data = data + + def __array__(self): + return self._data + + mlx_weights = { + "model.embed_tokens.weight": ArrayLike(np.random.randn(100, 64).astype(np.float32)), + } + + converted = convert_mlx_to_hf(mlx_weights) + + assert "model.embed_tokens.weight" in converted + assert isinstance(converted["model.embed_tokens.weight"], np.ndarray) + + +class TestHelperFunctions: + """Tests for helper functions.""" + + def test_get_num_params(self): + """Test parameter counting.""" + weights = { + "weight1": np.random.randn(100, 64).astype(np.float32), # 6400 + "weight2": np.random.randn(64).astype(np.float32), # 64 + } + + total = get_num_params(weights) + assert total == 6400 + 64 + + def test_get_num_params_empty(self): + """Test parameter counting with empty dict.""" + total = get_num_params({}) + assert total == 0 + + def test_print_weight_shapes(self, capsys): + """Test weight shape printing.""" + weights = { + "a.weight": np.random.randn(100, 64).astype(np.float32), + "b.weight": np.random.randn(64).astype(np.float32), + } + + print_weight_shapes(weights) + + captured = capsys.readouterr() + assert "a.weight: (100, 64)" in captured.out + assert "b.weight: (64,)" in captured.out + + +class TestWeightMaps: + """Tests for weight map constants.""" + + def test_gemma_weight_map_completeness(self): + """Test that GEMMA_WEIGHT_MAP has expected keys.""" + assert "model.embed_tokens.weight" in GEMMA_WEIGHT_MAP + assert "model.norm.weight" in GEMMA_WEIGHT_MAP + assert "lm_head.weight" in GEMMA_WEIGHT_MAP + + def test_gemma_layer_patterns_completeness(self): + """Test that GEMMA_LAYER_PATTERNS has expected keys.""" + # Attention weights + assert "self_attn.q_proj.weight" in GEMMA_LAYER_PATTERNS + assert "self_attn.k_proj.weight" in GEMMA_LAYER_PATTERNS + assert "self_attn.v_proj.weight" in GEMMA_LAYER_PATTERNS + assert "self_attn.o_proj.weight" in GEMMA_LAYER_PATTERNS + + # QK norms + assert "self_attn.q_norm.weight" in GEMMA_LAYER_PATTERNS + assert "self_attn.k_norm.weight" in GEMMA_LAYER_PATTERNS + + # MLP weights + assert "mlp.gate_proj.weight" in GEMMA_LAYER_PATTERNS + assert "mlp.up_proj.weight" in GEMMA_LAYER_PATTERNS + assert "mlp.down_proj.weight" in GEMMA_LAYER_PATTERNS + + # Layer norms + assert "input_layernorm.weight" in GEMMA_LAYER_PATTERNS + assert "post_attention_layernorm.weight" in GEMMA_LAYER_PATTERNS + assert "pre_feedforward_layernorm.weight" in GEMMA_LAYER_PATTERNS + assert "post_feedforward_layernorm.weight" in GEMMA_LAYER_PATTERNS diff --git a/tests/models_v2/families/gemma/test_model.py b/tests/models_v2/families/gemma/test_model.py index 97dae591..86620ba8 100644 --- a/tests/models_v2/families/gemma/test_model.py +++ b/tests/models_v2/families/gemma/test_model.py @@ -80,7 +80,7 @@ def test_forward_pass(self, tiny_config): def test_sliding_vs_global_layer(self, tiny_config): """Test different attention types for sliding vs global layers.""" sliding_attn = GemmaAttention(tiny_config, layer_idx=0) # sliding - global_attn = GemmaAttention(tiny_config, layer_idx=2) # global (pattern=3) + global_attn = GemmaAttention(tiny_config, layer_idx=2) # global (pattern=3) assert sliding_attn.is_sliding is True assert global_attn.is_sliding is False @@ -158,8 +158,9 @@ def test_embedding_scaling(self, tiny_config): model = GemmaModel(tiny_config) input_ids = mx.array([[1]]) - # Get raw embedding - raw_embed = model.embed_tokens(input_ids) + # Get raw embedding and verify it exists + embed = model.embed_tokens(input_ids) + assert embed.shape == (1, 1, tiny_config.hidden_size) # The model internally scales by sqrt(hidden_size) # We can verify the output is different from raw embedding @@ -287,6 +288,220 @@ def test_from_config(self, tiny_config): assert model.config == tiny_config +class TestGemmaBlockProperties: + """Tests for GemmaBlock properties.""" + + def test_block_type_property(self): + """Test block_type property returns TRANSFORMER.""" + from chuk_lazarus.models_v2.core.enums import BlockType + + config = GemmaConfig.tiny() + block = GemmaBlock(config, layer_idx=0) + + assert block.block_type == BlockType.TRANSFORMER + + def test_hidden_size_property(self): + """Test hidden_size property.""" + config = GemmaConfig.tiny() + block = GemmaBlock(config, layer_idx=0) + + assert block.hidden_size == config.hidden_size + + +class TestGemmaModelProperties: + """Tests for GemmaModel properties.""" + + def test_hidden_size_property(self): + """Test hidden_size property.""" + config = GemmaConfig.tiny() + model = GemmaModel(config) + + assert model.hidden_size == config.hidden_size + + def test_num_layers_property(self): + """Test num_layers property.""" + config = GemmaConfig.tiny() + model = GemmaModel(config) + + assert model.num_layers == config.num_hidden_layers + + def test_vocab_size_property(self): + """Test vocab_size property.""" + config = GemmaConfig.tiny() + model = GemmaModel(config) + + assert model.vocab_size == config.vocab_size + + def test_get_input_embeddings(self): + """Test get_input_embeddings method.""" + config = GemmaConfig.tiny() + model = GemmaModel(config) + + embeddings = model.get_input_embeddings() + assert embeddings is model.embed_tokens + + def test_set_input_embeddings(self): + """Test set_input_embeddings method.""" + import mlx.nn as nn + + config = GemmaConfig.tiny() + model = GemmaModel(config) + + new_embed = nn.Embedding(500, config.hidden_size) + model.set_input_embeddings(new_embed) + + assert model.embed_tokens is new_embed + + def test_forward_with_input_embeddings(self): + """Test forward pass using input_embeddings instead of input_ids.""" + config = GemmaConfig.tiny() + model = GemmaModel(config) + + # Create embeddings directly + input_embeddings = mx.random.normal((1, 5, config.hidden_size)) + # Use dummy input_ids (not used when input_embeddings provided) + input_ids = mx.array([[0, 0, 0, 0, 0]]) + + output = model(input_ids, input_embeddings=input_embeddings) + + assert output.last_hidden_state.shape == (1, 5, config.hidden_size) + + def test_sliding_window_pattern_1(self): + """Test model with sliding_window_pattern=1 (no sliding window).""" + config = GemmaConfig.tiny() + config.sliding_window_pattern = 1 # All global attention + model = GemmaModel(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + + assert output.last_hidden_state.shape == (1, 5, config.hidden_size) + + +class TestGemmaForCausalLMExtended: + """Additional tests for GemmaForCausalLM.""" + + def test_sanitize_with_lm_head(self): + """Test sanitize keeps tie_word_embeddings False when lm_head present.""" + config = GemmaConfig.tiny() + model = GemmaForCausalLM(config) + + weights = {"lm_head.weight": mx.zeros((config.vocab_size, config.hidden_size))} + sanitized = model.sanitize(weights) + + assert model.tie_word_embeddings is False + assert sanitized == weights + + def test_sanitize_without_lm_head(self): + """Test sanitize sets tie_word_embeddings True when no lm_head.""" + config = GemmaConfig.tiny() + model = GemmaForCausalLM(config) + + weights = {"model.embed_tokens.weight": mx.zeros((config.vocab_size, config.hidden_size))} + sanitized = model.sanitize(weights) + + assert model.tie_word_embeddings is True + assert sanitized == weights + + def test_tied_embeddings_forward(self): + """Test forward pass with tied embeddings.""" + config = GemmaConfig.tiny() + model = GemmaForCausalLM(config) + model.tie_word_embeddings = True + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + + # Should use embed_tokens.as_linear for logits + assert output.logits.shape == (1, 5, config.vocab_size) + + def test_layers_property(self): + """Test layers property returns transformer layers.""" + config = GemmaConfig.tiny() + model = GemmaForCausalLM(config) + + layers = model.layers + assert layers is model.model.layers + assert len(layers) == config.num_hidden_layers + + def test_generate_with_top_k(self): + """Test generation with top-k sampling.""" + config = GemmaConfig.tiny() + model = GemmaForCausalLM(config) + input_ids = mx.array([[1, 2, 3]]) + + generated = model.generate( + input_ids, + max_new_tokens=3, + temperature=0.8, + top_k=10, + ) + + assert generated.shape[1] == 6 # 3 prompt + 3 generated + + def test_generate_with_top_p(self): + """Test generation with nucleus (top-p) sampling.""" + config = GemmaConfig.tiny() + model = GemmaForCausalLM(config) + input_ids = mx.array([[1, 2, 3]]) + + generated = model.generate( + input_ids, + max_new_tokens=3, + temperature=0.8, + top_p=0.9, + ) + + assert generated.shape[1] == 6 # 3 prompt + 3 generated + + def test_generate_with_top_k_and_top_p(self): + """Test generation with both top-k and top-p.""" + config = GemmaConfig.tiny() + model = GemmaForCausalLM(config) + input_ids = mx.array([[1, 2, 3]]) + + generated = model.generate( + input_ids, + max_new_tokens=3, + temperature=0.8, + top_k=50, + top_p=0.95, + ) + + assert generated.shape[1] == 6 # 3 prompt + 3 generated + + +class TestClipResidual: + """Tests for clip_residual function.""" + + def test_clip_residual_non_float16(self): + """Test clip_residual with non-float16 input.""" + from chuk_lazarus.models_v2.families.gemma.model import clip_residual + + x = mx.random.normal((2, 5, 64)) # float32 by default + y = mx.random.normal((2, 5, 64)) + + result = clip_residual(x, y) + + # Should just add without clipping + expected = x + y + assert mx.allclose(result, expected) + + def test_clip_residual_float16(self): + """Test clip_residual with float16 input (overflow protection).""" + from chuk_lazarus.models_v2.families.gemma.model import clip_residual + + x = mx.random.normal((2, 5, 64)).astype(mx.float16) + y = mx.random.normal((2, 5, 64)).astype(mx.float16) + + result = clip_residual(x, y) + + # Should be float16 output + assert result.dtype == mx.float16 + # Shape should be preserved + assert result.shape == x.shape + + class TestFunctionGemma: """Tests specific to FunctionGemma usage.""" diff --git a/tests/models_v2/families/granite/__init__.py b/tests/models_v2/families/granite/__init__.py new file mode 100644 index 00000000..bd974687 --- /dev/null +++ b/tests/models_v2/families/granite/__init__.py @@ -0,0 +1 @@ +"""Tests for Granite model family.""" diff --git a/tests/models_v2/families/granite/test_config.py b/tests/models_v2/families/granite/test_config.py new file mode 100644 index 00000000..113b8ef2 --- /dev/null +++ b/tests/models_v2/families/granite/test_config.py @@ -0,0 +1,228 @@ +""" +Tests for Granite configuration. +""" + +from chuk_lazarus.models_v2.families.granite.config import ( + GraniteConfig, + GraniteHybridConfig, +) + + +class TestGraniteConfig: + """Tests for GraniteConfig.""" + + def test_defaults(self): + """Test default configuration values.""" + config = GraniteConfig() + + assert config.model_type == "granite" + assert config.embedding_multiplier == 12.0 + assert config.attention_multiplier == 1.0 + assert config.residual_multiplier == 1.0 + assert config.logits_scaling == 1.0 + assert config.hidden_act == "silu" + assert config.rope_theta == 10000.0 + assert config.rms_norm_eps == 1e-5 + assert config.attention_dropout == 0.0 + assert config.attention_bias is False + assert config.mlp_bias is False + assert config.rope_scaling is None + + def test_granite_3_8b(self): + """Test Granite 3.0 8B preset.""" + config = GraniteConfig.granite_3_8b() + + assert config.vocab_size == 49155 + assert config.hidden_size == 4096 + assert config.num_hidden_layers == 40 + assert config.num_attention_heads == 32 + assert config.num_key_value_heads == 8 + assert config.intermediate_size == 12800 + assert config.max_position_embeddings == 4096 + assert config.embedding_multiplier == 12.0 + assert config.attention_multiplier == 0.0078125 + assert config.residual_multiplier == 0.22 + assert config.logits_scaling == 16.0 + assert config.attention_dropout == 0.1 + assert config.tie_word_embeddings is True + + def test_granite_3_1_2b(self): + """Test Granite 3.1 2B preset.""" + config = GraniteConfig.granite_3_1_2b() + + assert config.vocab_size == 49155 + assert config.hidden_size == 2048 + assert config.num_hidden_layers == 40 + assert config.num_attention_heads == 32 + assert config.num_key_value_heads == 8 + assert config.intermediate_size == 8192 + assert config.max_position_embeddings == 131072 + assert config.rope_theta == 5000000.0 + assert config.embedding_multiplier == 12.0 + assert config.attention_multiplier == 0.015625 + assert config.logits_scaling == 8.0 + + def test_granite_3_1_8b(self): + """Test Granite 3.1 8B preset.""" + config = GraniteConfig.granite_3_1_8b() + + assert config.vocab_size == 49155 + assert config.hidden_size == 4096 + assert config.num_hidden_layers == 40 + assert config.num_attention_heads == 32 + assert config.num_key_value_heads == 8 + assert config.intermediate_size == 12800 + assert config.max_position_embeddings == 131072 + assert config.rope_theta == 5000000.0 + + def test_tiny(self): + """Test tiny config for testing.""" + config = GraniteConfig.tiny() + + assert config.vocab_size == 1000 + assert config.hidden_size == 64 + assert config.num_hidden_layers == 4 + assert config.num_attention_heads == 4 + assert config.num_key_value_heads == 2 + assert config.intermediate_size == 128 + assert config.embedding_multiplier == 1.0 + assert config.attention_multiplier == 1.0 + assert config.residual_multiplier == 1.0 + assert config.logits_scaling == 1.0 + + +class TestGraniteHybridConfig: + """Tests for GraniteHybridConfig.""" + + def test_defaults(self): + """Test default configuration values.""" + config = GraniteHybridConfig() + + assert config.model_type == "granitemoehybrid" + assert config.position_embedding_type == "rope" + assert config.embedding_multiplier == 12.0 + assert config.attention_multiplier == 0.0078125 + assert config.residual_multiplier == 0.22 + assert config.logits_scaling == 6.0 + assert len(config.layer_types) == 40 + assert all(t == "attention" for t in config.layer_types) + assert config.normalization_function == "rmsnorm" + + def test_mamba_defaults(self): + """Test Mamba-2 default settings.""" + config = GraniteHybridConfig() + + assert config.mamba_d_state == 128 + assert config.mamba_d_conv == 4 + assert config.mamba_expand == 2 + assert config.mamba_n_heads == 48 + assert config.mamba_d_head == 64 + assert config.mamba_n_groups == 1 + assert config.mamba_chunk_size == 256 + assert config.mamba_conv_bias is True + assert config.mamba_proj_bias is False + + def test_moe_defaults(self): + """Test MoE default settings.""" + config = GraniteHybridConfig() + + assert config.num_local_experts == 0 + assert config.num_experts_per_tok == 0 + assert config.shared_intermediate_size == 0 + assert config.router_aux_loss_coef == 0.0 + assert config.output_router_logits is False + + def test_is_moe_property(self): + """Test is_moe property.""" + # Dense model + dense_config = GraniteHybridConfig(num_local_experts=0, num_experts_per_tok=0) + assert dense_config.is_moe is False + + # MoE model + moe_config = GraniteHybridConfig(num_local_experts=4, num_experts_per_tok=2) + assert moe_config.is_moe is True + + # Only num_local_experts set + partial_config = GraniteHybridConfig(num_local_experts=4, num_experts_per_tok=0) + assert partial_config.is_moe is False + + def test_num_mamba_layers_property(self): + """Test num_mamba_layers property.""" + # All attention + attn_config = GraniteHybridConfig(layer_types=["attention"] * 10) + assert attn_config.num_mamba_layers == 0 + assert attn_config.num_attention_layers == 10 + + # Mixed + mixed_config = GraniteHybridConfig(layer_types=["mamba", "mamba", "attention", "mamba"]) + assert mixed_config.num_mamba_layers == 3 + assert mixed_config.num_attention_layers == 1 + + def test_granite_4_micro(self): + """Test Granite 4.0 Micro preset.""" + config = GraniteHybridConfig.granite_4_micro() + + assert config.vocab_size == 100352 + assert config.hidden_size == 2560 + assert config.num_hidden_layers == 40 + assert config.num_attention_heads == 40 + assert config.num_key_value_heads == 8 + assert config.intermediate_size == 8192 + assert all(t == "attention" for t in config.layer_types) + assert config.position_embedding_type == "rope" + assert config.num_local_experts == 0 + assert config.is_moe is False + assert config.tie_word_embeddings is True + + def test_granite_4_tiny(self): + """Test Granite 4.0 Tiny preset.""" + config = GraniteHybridConfig.granite_4_tiny() + + assert config.vocab_size == 49160 + assert config.hidden_size == 1536 + assert config.num_hidden_layers == 40 + assert config.num_attention_heads == 12 + assert config.num_key_value_heads == 4 + assert config.position_embedding_type == "nope" + assert config.num_local_experts == 62 + assert config.num_experts_per_tok == 6 + assert config.is_moe is True + assert config.num_mamba_layers > 0 + assert config.num_attention_layers == 4 + + def test_granite_4_small(self): + """Test Granite 4.0 Small preset.""" + config = GraniteHybridConfig.granite_4_small() + + assert config.vocab_size == 49160 + assert config.hidden_size == 3072 + assert config.num_hidden_layers == 40 + assert config.num_local_experts == 62 + assert config.num_experts_per_tok == 6 + assert config.is_moe is True + + def test_tiny(self): + """Test tiny config for testing.""" + config = GraniteHybridConfig.tiny() + + assert config.vocab_size == 1000 + assert config.hidden_size == 64 + assert config.num_hidden_layers == 4 + assert config.num_attention_heads == 4 + assert config.num_key_value_heads == 2 + assert config.layer_types == ["mamba", "mamba", "attention", "mamba"] + assert config.num_mamba_layers == 3 + assert config.num_attention_layers == 1 + assert config.is_moe is False + + def test_tiny_moe(self): + """Test tiny MoE config for testing.""" + config = GraniteHybridConfig.tiny_moe() + + assert config.vocab_size == 1000 + assert config.hidden_size == 64 + assert config.num_hidden_layers == 4 + assert config.num_local_experts == 4 + assert config.num_experts_per_tok == 2 + assert config.shared_intermediate_size == 64 + assert config.is_moe is True diff --git a/tests/models_v2/families/granite/test_hybrid.py b/tests/models_v2/families/granite/test_hybrid.py new file mode 100644 index 00000000..f484a664 --- /dev/null +++ b/tests/models_v2/families/granite/test_hybrid.py @@ -0,0 +1,518 @@ +""" +Tests for Granite 4.x hybrid model. +""" + +import mlx.core as mx +import mlx.nn as nn + +from chuk_lazarus.models_v2.families.granite.config import GraniteHybridConfig +from chuk_lazarus.models_v2.families.granite.hybrid import ( + Granite4, + GraniteHybrid, + GraniteHybridAttention, + GraniteHybridBlock, + GraniteHybridForCausalLM, + GraniteHybridModel, + GraniteHybridMoE, + GraniteMamba2Block, +) + + +class TestGraniteMamba2Block: + """Tests for GraniteMamba2Block.""" + + def test_creation(self): + """Test block creation.""" + config = GraniteHybridConfig.tiny() + block = GraniteMamba2Block(config, layer_idx=0) + + assert block.hidden_size == 64 + assert block.layer_idx == 0 + assert block.d_state == config.mamba_d_state + assert block.d_conv == config.mamba_d_conv + assert block.expand == config.mamba_expand + assert block.n_heads == config.mamba_n_heads + assert block.d_head == config.mamba_d_head + + def test_forward_pass(self): + """Test forward pass.""" + config = GraniteHybridConfig.tiny() + block = GraniteMamba2Block(config) + + x = mx.random.normal((2, 10, 64)) + output, cache = block(x) + + assert output.shape == (2, 10, 64) + # Cache may be None for simplified implementation + + def test_forward_different_seq_lengths(self): + """Test forward with different sequence lengths.""" + config = GraniteHybridConfig.tiny() + block = GraniteMamba2Block(config) + + for seq_len in [1, 5, 10]: + x = mx.random.normal((2, seq_len, 64)) + output, _ = block(x) + assert output.shape == (2, seq_len, 64) + + def test_selective_scan(self): + """Test selective scan method.""" + config = GraniteHybridConfig.tiny() + block = GraniteMamba2Block(config) + + batch_size, seq_len = 2, 5 + n_heads = config.mamba_n_heads + d_per_head = block.d_inner // n_heads + + x = mx.random.normal((batch_size, seq_len, n_heads, d_per_head)) + dt = mx.random.normal((batch_size, seq_len, n_heads)) + B = mx.random.normal((batch_size, seq_len, n_heads, config.mamba_d_state)) + C = mx.random.normal((batch_size, seq_len, n_heads, config.mamba_d_state)) + + output = block._selective_scan(x, dt, B, C) + assert output.shape == (batch_size, seq_len, n_heads, d_per_head) + + +class TestGraniteHybridAttention: + """Tests for GraniteHybridAttention.""" + + def test_creation(self): + """Test attention creation.""" + config = GraniteHybridConfig.tiny() + attn = GraniteHybridAttention(config, layer_idx=0) + + assert attn.hidden_size == 64 + assert attn.num_heads == 4 + assert attn.num_kv_heads == 2 + assert attn.head_dim == 16 + assert attn.attention_multiplier == 1.0 + + def test_creation_with_rope(self): + """Test attention creation with RoPE.""" + config = GraniteHybridConfig.tiny() + config.position_embedding_type = "rope" + attn = GraniteHybridAttention(config) + + assert attn.use_rope is True + + def test_creation_without_rope(self): + """Test attention creation without RoPE.""" + config = GraniteHybridConfig.tiny() + config.position_embedding_type = "nope" + attn = GraniteHybridAttention(config) + + assert attn.use_rope is False + + def test_forward_pass(self): + """Test forward pass.""" + config = GraniteHybridConfig.tiny() + attn = GraniteHybridAttention(config) + + x = mx.random.normal((2, 10, 64)) + output, cache = attn(x) + + assert output.shape == (2, 10, 64) + assert cache is not None + + def test_forward_with_mask(self): + """Test forward with mask.""" + config = GraniteHybridConfig.tiny() + attn = GraniteHybridAttention(config) + + x = mx.random.normal((2, 10, 64)) + mask = nn.MultiHeadAttention.create_additive_causal_mask(10) + output, cache = attn(x, mask=mask) + + assert output.shape == (2, 10, 64) + + def test_forward_with_cache(self): + """Test forward with cache.""" + config = GraniteHybridConfig.tiny() + attn = GraniteHybridAttention(config) + + # First pass + x1 = mx.random.normal((2, 10, 64)) + _, cache = attn(x1) + + # Second pass + x2 = mx.random.normal((2, 1, 64)) + output, new_cache = attn(x2, cache=cache) + + assert output.shape == (2, 1, 64) + + def test_kv_repeat(self): + """Test KV head repeat for GQA.""" + config = GraniteHybridConfig.tiny() + config.num_attention_heads = 8 + config.num_key_value_heads = 2 + attn = GraniteHybridAttention(config) + + assert attn.n_rep == 4 + + x = mx.random.normal((2, 5, 64)) + output, _ = attn(x) + assert output.shape == (2, 5, 64) + + +class TestGraniteHybridMoE: + """Tests for GraniteHybridMoE.""" + + def test_creation(self): + """Test MoE creation.""" + config = GraniteHybridConfig.tiny_moe() + moe = GraniteHybridMoE(config) + + assert moe.hidden_size == 64 + assert moe.num_experts == 4 + assert moe.num_experts_per_tok == 2 + assert moe.has_shared_expert is True + + def test_creation_without_shared_expert(self): + """Test MoE creation without shared expert.""" + config = GraniteHybridConfig.tiny_moe() + config.shared_intermediate_size = 0 + moe = GraniteHybridMoE(config) + + assert moe.has_shared_expert is False + + def test_forward_pass(self): + """Test forward pass.""" + config = GraniteHybridConfig.tiny_moe() + moe = GraniteHybridMoE(config) + + x = mx.random.normal((2, 10, 64)) + output = moe(x) + + assert output.shape == (2, 10, 64) + + def test_forward_without_shared_expert(self): + """Test forward without shared expert.""" + config = GraniteHybridConfig.tiny_moe() + config.shared_intermediate_size = 0 + moe = GraniteHybridMoE(config) + + x = mx.random.normal((2, 10, 64)) + output = moe(x) + + assert output.shape == (2, 10, 64) + + +class TestGraniteHybridBlock: + """Tests for GraniteHybridBlock.""" + + def test_creation_attention(self): + """Test block creation with attention.""" + config = GraniteHybridConfig.tiny() + block = GraniteHybridBlock(config, layer_idx=0, layer_type="attention") + + assert block.layer_type == "attention" + assert block.residual_multiplier == 1.0 + + def test_creation_mamba(self): + """Test block creation with mamba.""" + config = GraniteHybridConfig.tiny() + block = GraniteHybridBlock(config, layer_idx=0, layer_type="mamba") + + assert block.layer_type == "mamba" + + def test_block_type_property_attention(self): + """Test block_type property for attention.""" + from chuk_lazarus.models_v2.core.enums import BlockType + + config = GraniteHybridConfig.tiny() + block = GraniteHybridBlock(config, layer_type="attention") + + assert block.block_type == BlockType.TRANSFORMER + + def test_block_type_property_mamba(self): + """Test block_type property for mamba.""" + from chuk_lazarus.models_v2.core.enums import BlockType + + config = GraniteHybridConfig.tiny() + block = GraniteHybridBlock(config, layer_type="mamba") + + assert block.block_type == BlockType.MAMBA + + def test_hidden_size_property(self): + """Test hidden_size property.""" + config = GraniteHybridConfig.tiny() + block = GraniteHybridBlock(config) + + assert block.hidden_size == 64 + + def test_forward_attention(self): + """Test forward pass for attention block.""" + config = GraniteHybridConfig.tiny() + block = GraniteHybridBlock(config, layer_type="attention") + + x = mx.random.normal((2, 10, 64)) + mask = nn.MultiHeadAttention.create_additive_causal_mask(10) + output = block(x, mask=mask) + + assert output.hidden_states.shape == (2, 10, 64) + + def test_forward_mamba(self): + """Test forward pass for mamba block.""" + config = GraniteHybridConfig.tiny() + block = GraniteHybridBlock(config, layer_type="mamba") + + x = mx.random.normal((2, 10, 64)) + output = block(x) + + assert output.hidden_states.shape == (2, 10, 64) + + def test_forward_with_moe(self): + """Test forward pass with MoE.""" + config = GraniteHybridConfig.tiny_moe() + block = GraniteHybridBlock(config, layer_type="attention") + + x = mx.random.normal((2, 10, 64)) + output = block(x) + + assert output.hidden_states.shape == (2, 10, 64) + + +class TestGraniteHybridModel: + """Tests for GraniteHybridModel backbone.""" + + def test_creation(self): + """Test model creation.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridModel(config) + + assert model.hidden_size == 64 + assert model.num_layers == 4 + assert model.vocab_size == 1000 + assert model.embedding_multiplier == 1.0 + + def test_forward_pass(self): + """Test forward pass.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridModel(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + + assert output.last_hidden_state.shape == (1, 5, 64) + assert output.cache is not None + assert len(output.cache) == 4 + + def test_forward_with_output_hidden_states(self): + """Test forward with output_hidden_states=True.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridModel(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids, output_hidden_states=True) + + assert output.hidden_states is not None + assert len(output.hidden_states) == 5 + + def test_forward_with_attention_mask(self): + """Test forward with attention mask.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridModel(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + mask = nn.MultiHeadAttention.create_additive_causal_mask(5) + output = model(input_ids, attention_mask=mask) + + assert output.last_hidden_state.shape == (1, 5, 64) + + def test_forward_with_cache(self): + """Test forward with cache.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridModel(config) + + # First pass + input_ids = mx.array([[1, 2, 3, 4, 5]]) + out1 = model(input_ids) + + # Second pass + next_token = mx.array([[6]]) + out2 = model(next_token, cache=out1.cache) + + assert out2.last_hidden_state.shape == (1, 1, 64) + + def test_get_input_embeddings(self): + """Test get_input_embeddings method.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridModel(config) + + embeddings = model.get_input_embeddings() + assert embeddings is model.embed_tokens + + def test_set_input_embeddings(self): + """Test set_input_embeddings method.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridModel(config) + + new_embed = nn.Embedding(500, 64) + model.set_input_embeddings(new_embed) + + assert model.embed_tokens is new_embed + + +class TestGraniteHybridForCausalLM: + """Tests for GraniteHybridForCausalLM.""" + + def test_creation(self): + """Test model creation.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM(config) + + assert model.config is config + assert model.logits_scaling == 1.0 + + def test_creation_tied_embeddings(self): + """Test model with tied embeddings.""" + config = GraniteHybridConfig.tiny() + config.tie_word_embeddings = True + model = GraniteHybridForCausalLM(config) + + assert model.lm_head is not None + + def test_creation_untied_embeddings(self): + """Test model without tied embeddings.""" + config = GraniteHybridConfig.tiny() + config.tie_word_embeddings = False + model = GraniteHybridForCausalLM(config) + + assert model.lm_head is not None + + def test_backbone_property(self): + """Test backbone property.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM(config) + + assert model.backbone is model.model + + def test_forward_pass(self): + """Test forward pass.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + + assert output.logits.shape == (1, 5, 1000) + assert output.loss is None + assert output.cache is not None + + def test_forward_with_labels(self): + """Test forward with labels.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + labels = mx.array([[2, 3, 4, 5, 6]]) + output = model(input_ids, labels=labels) + + assert output.loss is not None + + def test_forward_with_logits_scaling(self): + """Test logits scaling.""" + config = GraniteHybridConfig.tiny() + config.logits_scaling = 2.0 + model = GraniteHybridForCausalLM(config) + + assert model.logits_scaling == 2.0 + + input_ids = mx.array([[1, 2, 3]]) + output = model(input_ids) + assert output.logits is not None + + def test_generate_basic(self): + """Test basic generation.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=5) + + assert generated.shape[0] == 1 + assert generated.shape[1] >= 3 + + def test_generate_with_temperature(self): + """Test generation with temperature.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=3, temperature=0.5) + + assert generated.shape[1] >= 3 + + def test_generate_with_top_k(self): + """Test generation with top_k.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=3, top_k=10) + + assert generated.shape[1] >= 3 + + def test_generate_with_stop_tokens(self): + """Test generation with stop tokens.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=10, stop_tokens=[999]) + + assert generated.shape[1] >= 3 + + def test_from_config(self): + """Test from_config class method.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM.from_config(config) + + assert isinstance(model, GraniteHybridForCausalLM) + + def test_aliases(self): + """Test model aliases.""" + config = GraniteHybridConfig.tiny() + + model1 = GraniteHybrid(config) + assert isinstance(model1, GraniteHybridForCausalLM) + + model2 = Granite4(config) + assert isinstance(model2, GraniteHybridForCausalLM) + + +class TestGraniteHybridGradients: + """Tests for gradient flow.""" + + def test_forward_backward(self): + """Test forward-backward pass.""" + config = GraniteHybridConfig.tiny() + model = GraniteHybridForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + labels = mx.array([[2, 3, 4, 5, 6]]) + + def loss_fn(model, input_ids, labels): + output = model(input_ids, labels=labels) + return output.loss + + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, input_ids, labels) + + assert loss.item() > 0 + assert any(g is not None for g in grads.values()) + + def test_moe_loss_computation(self): + """Test that MoE model computes loss (gradient tests skipped due to MLX ArgsSort limitation).""" + # Note: Full gradient tests are skipped because MLX's ArgSort operation + # (used in MoE top-k routing) does not support VJP. + # This is a known limitation: "Not implemented for ArgSort" + config = GraniteHybridConfig.tiny_moe() + model = GraniteHybridForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + labels = mx.array([[2, 3, 4, 5, 6]]) + + output = model(input_ids, labels=labels) + assert output.loss is not None + assert output.loss.item() > 0 diff --git a/tests/models_v2/families/granite/test_model.py b/tests/models_v2/families/granite/test_model.py new file mode 100644 index 00000000..7bac0479 --- /dev/null +++ b/tests/models_v2/families/granite/test_model.py @@ -0,0 +1,451 @@ +""" +Tests for Granite model. +""" + +import mlx.core as mx +import mlx.nn as nn + +from chuk_lazarus.models_v2.families.granite.config import GraniteConfig +from chuk_lazarus.models_v2.families.granite.model import ( + Granite, + GraniteAttention, + GraniteBlock, + GraniteForCausalLM, + GraniteModel, +) + + +class TestGraniteAttention: + """Tests for GraniteAttention.""" + + def test_creation(self): + """Test attention creation.""" + config = GraniteConfig.tiny() + attn = GraniteAttention(config) + + assert attn.hidden_size == 64 + assert attn.num_heads == 4 + assert attn.num_kv_heads == 2 + assert attn.head_dim == 16 + assert attn.n_rep == 2 + assert attn.attention_multiplier == 1.0 + + def test_forward_pass(self): + """Test attention forward pass.""" + config = GraniteConfig.tiny() + attn = GraniteAttention(config, layer_idx=0) + + x = mx.random.normal((2, 10, 64)) + output, cache = attn(x) + + assert output.shape == (2, 10, 64) + assert cache is not None + k, v = cache + assert k.shape[0] == 2 + assert k.shape[2] == 10 + + def test_forward_with_mask(self): + """Test attention with mask.""" + config = GraniteConfig.tiny() + attn = GraniteAttention(config) + + x = mx.random.normal((2, 10, 64)) + mask = nn.MultiHeadAttention.create_additive_causal_mask(10) + output, cache = attn(x, mask=mask) + + assert output.shape == (2, 10, 64) + + def test_forward_with_cache(self): + """Test attention with KV cache.""" + config = GraniteConfig.tiny() + attn = GraniteAttention(config) + + # First pass + x1 = mx.random.normal((2, 10, 64)) + _, cache = attn(x1) + + # Second pass with cache + x2 = mx.random.normal((2, 1, 64)) + output, new_cache = attn(x2, cache=cache) + + assert output.shape == (2, 1, 64) + k, v = new_cache + assert k.shape[2] == 11 # 10 + 1 + + def test_repeat_kv(self): + """Test KV repeat method.""" + config = GraniteConfig.tiny() + attn = GraniteAttention(config) + + x = mx.random.normal((2, 2, 10, 16)) + + # n_rep = 1, should return same + result = attn._repeat_kv(x, n_rep=1) + assert result.shape == x.shape + + # n_rep = 2 + result = attn._repeat_kv(x, n_rep=2) + assert result.shape == (2, 4, 10, 16) + + def test_attention_multiplier_applied(self): + """Test that attention multiplier is applied.""" + config = GraniteConfig.tiny() + config.attention_multiplier = 0.5 + + attn = GraniteAttention(config) + assert attn.attention_multiplier == 0.5 + + x = mx.random.normal((1, 5, 64)) + output1, _ = attn(x) + + # Verify multiplier with different config also works + config2 = GraniteConfig.tiny() + config2.attention_multiplier = 1.0 + attn2 = GraniteAttention(config2) + output2, _ = attn2(x) + # Just verify both run without error - exact comparison tricky due to random init + assert output1.shape == output2.shape + + +class TestGraniteBlock: + """Tests for GraniteBlock.""" + + def test_creation(self): + """Test block creation.""" + config = GraniteConfig.tiny() + block = GraniteBlock(config, layer_idx=0) + + assert block.hidden_size == 64 + assert block.layer_idx == 0 + assert block.residual_multiplier == 1.0 + + def test_block_type(self): + """Test block_type property.""" + from chuk_lazarus.models_v2.core.enums import BlockType + + config = GraniteConfig.tiny() + block = GraniteBlock(config) + + assert block.block_type == BlockType.TRANSFORMER + + def test_forward_pass(self): + """Test block forward pass.""" + config = GraniteConfig.tiny() + block = GraniteBlock(config) + + x = mx.random.normal((2, 10, 64)) + output = block(x) + + assert output.hidden_states.shape == (2, 10, 64) + assert output.cache is not None + + def test_forward_with_mask(self): + """Test block with mask.""" + config = GraniteConfig.tiny() + block = GraniteBlock(config) + + x = mx.random.normal((2, 10, 64)) + mask = nn.MultiHeadAttention.create_additive_causal_mask(10) + output = block(x, mask=mask) + + assert output.hidden_states.shape == (2, 10, 64) + + def test_forward_with_cache(self): + """Test block with cache.""" + config = GraniteConfig.tiny() + block = GraniteBlock(config) + + # First pass + x1 = mx.random.normal((2, 10, 64)) + out1 = block(x1) + + # Second pass with cache + x2 = mx.random.normal((2, 1, 64)) + out2 = block(x2, cache=out1.cache) + + assert out2.hidden_states.shape == (2, 1, 64) + + def test_residual_multiplier(self): + """Test residual multiplier is applied.""" + config = GraniteConfig.tiny() + config.residual_multiplier = 0.5 + + block = GraniteBlock(config) + assert block.residual_multiplier == 0.5 + + +class TestGraniteModel: + """Tests for GraniteModel backbone.""" + + def test_creation(self): + """Test model creation.""" + config = GraniteConfig.tiny() + model = GraniteModel(config) + + assert model.hidden_size == 64 + assert model.num_layers == 4 + assert model.vocab_size == 1000 + assert model.embedding_multiplier == 1.0 + + def test_forward_pass(self): + """Test model forward pass.""" + config = GraniteConfig.tiny() + model = GraniteModel(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + + assert output.last_hidden_state.shape == (1, 5, 64) + assert output.cache is not None + assert len(output.cache) == 4 + assert output.hidden_states is None + + def test_forward_with_output_hidden_states(self): + """Test model with output_hidden_states=True.""" + config = GraniteConfig.tiny() + model = GraniteModel(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids, output_hidden_states=True) + + assert output.hidden_states is not None + assert len(output.hidden_states) == 5 # embeddings + 4 layers + + def test_forward_with_attention_mask(self): + """Test model with custom attention mask.""" + config = GraniteConfig.tiny() + model = GraniteModel(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + mask = nn.MultiHeadAttention.create_additive_causal_mask(5) + output = model(input_ids, attention_mask=mask) + + assert output.last_hidden_state.shape == (1, 5, 64) + + def test_forward_with_cache(self): + """Test model with cache.""" + config = GraniteConfig.tiny() + model = GraniteModel(config) + + # First pass + input_ids = mx.array([[1, 2, 3, 4, 5]]) + out1 = model(input_ids) + + # Second pass with cache + next_token = mx.array([[6]]) + out2 = model(next_token, cache=out1.cache) + + assert out2.last_hidden_state.shape == (1, 1, 64) + + def test_get_input_embeddings(self): + """Test get_input_embeddings method.""" + config = GraniteConfig.tiny() + model = GraniteModel(config) + + embeddings = model.get_input_embeddings() + assert embeddings is model.embed_tokens + + def test_set_input_embeddings(self): + """Test set_input_embeddings method.""" + config = GraniteConfig.tiny() + model = GraniteModel(config) + + new_embed = nn.Embedding(500, 64) + model.set_input_embeddings(new_embed) + + assert model.embed_tokens is new_embed + + +class TestGraniteForCausalLM: + """Tests for GraniteForCausalLM.""" + + def test_creation(self): + """Test model creation.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + assert model.config is config + assert model.logits_scaling == 1.0 + + def test_creation_tied_embeddings(self): + """Test model creation with tied embeddings.""" + config = GraniteConfig.tiny() + config.tie_word_embeddings = True + model = GraniteForCausalLM(config) + + assert model.lm_head is not None + + def test_creation_untied_embeddings(self): + """Test model creation without tied embeddings.""" + config = GraniteConfig.tiny() + config.tie_word_embeddings = False + model = GraniteForCausalLM(config) + + assert model.lm_head is not None + + def test_backbone_property(self): + """Test backbone property.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + assert model.backbone is model.model + + def test_forward_pass(self): + """Test forward pass.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + + assert output.logits.shape == (1, 5, 1000) + assert output.loss is None + assert output.cache is not None + + def test_forward_with_labels(self): + """Test forward pass with labels.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + labels = mx.array([[2, 3, 4, 5, 6]]) + output = model(input_ids, labels=labels) + + assert output.loss is not None + assert output.loss.item() > 0 + + def test_forward_with_logits_scaling(self): + """Test logits scaling is applied.""" + config = GraniteConfig.tiny() + config.logits_scaling = 2.0 + model = GraniteForCausalLM(config) + + assert model.logits_scaling == 2.0 + + input_ids = mx.array([[1, 2, 3]]) + output = model(input_ids) + + assert output.logits is not None + + def test_forward_with_output_hidden_states(self): + """Test forward with output_hidden_states.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + output = model(input_ids, output_hidden_states=True) + + assert output.hidden_states is not None + + def test_generate_basic(self): + """Test basic generation.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=5) + + assert generated.shape[0] == 1 + assert generated.shape[1] >= 3 + + def test_generate_with_temperature(self): + """Test generation with temperature.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=3, temperature=0.5) + + assert generated.shape[1] >= 3 + + def test_generate_with_top_k(self): + """Test generation with top_k.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=3, top_k=10) + + assert generated.shape[1] >= 3 + + def test_generate_with_repetition_penalty(self): + """Test generation with repetition penalty.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=3, repetition_penalty=1.2) + + assert generated.shape[1] >= 3 + + def test_generate_with_stop_tokens(self): + """Test generation with stop tokens.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + # Use a token that's likely in the vocabulary + generated = model.generate(input_ids, max_new_tokens=10, stop_tokens=[999]) + + assert generated.shape[1] >= 3 + + def test_from_config(self): + """Test from_config class method.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM.from_config(config) + + assert isinstance(model, GraniteForCausalLM) + assert model.config is config + + def test_alias_granite(self): + """Test Granite alias.""" + config = GraniteConfig.tiny() + model = Granite(config) + + assert isinstance(model, GraniteForCausalLM) + + +class TestGraniteGradients: + """Tests for gradient flow.""" + + def test_forward_backward(self): + """Test forward-backward pass.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + labels = mx.array([[2, 3, 4, 5, 6]]) + + def loss_fn(model, input_ids, labels): + output = model(input_ids, labels=labels) + return output.loss + + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, input_ids, labels) + + assert loss.item() > 0 + assert any(g is not None for g in grads.values()) + + +class TestGraniteBatchHandling: + """Tests for batch handling.""" + + def test_different_batch_sizes(self): + """Test different batch sizes.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + for batch_size in [1, 2, 4]: + input_ids = mx.random.randint(0, config.vocab_size, (batch_size, 5)) + output = model(input_ids) + assert output.logits.shape == (batch_size, 5, config.vocab_size) + + def test_different_sequence_lengths(self): + """Test different sequence lengths.""" + config = GraniteConfig.tiny() + model = GraniteForCausalLM(config) + + for seq_len in [1, 5, 10]: + input_ids = mx.random.randint(0, config.vocab_size, (2, seq_len)) + output = model(input_ids) + assert output.logits.shape == (2, seq_len, config.vocab_size) diff --git a/tests/models_v2/families/llama4/__init__.py b/tests/models_v2/families/llama4/__init__.py new file mode 100644 index 00000000..55f031ae --- /dev/null +++ b/tests/models_v2/families/llama4/__init__.py @@ -0,0 +1 @@ +"""Tests for Llama 4 model family.""" diff --git a/tests/models_v2/families/llama4/test_attention.py b/tests/models_v2/families/llama4/test_attention.py new file mode 100644 index 00000000..25f381b7 --- /dev/null +++ b/tests/models_v2/families/llama4/test_attention.py @@ -0,0 +1,249 @@ +""" +Tests for Llama 4 attention. +""" + +import mlx.core as mx +import mlx.nn as nn + +from chuk_lazarus.models_v2.families.llama4.attention import ( + Llama4Attention, + Llama4FlexAttention, + create_llama4_attention, +) +from chuk_lazarus.models_v2.families.llama4.config import Llama4TextConfig + + +class TestLlama4Attention: + """Tests for Llama4Attention.""" + + def test_creation(self): + """Test attention creation.""" + config = Llama4TextConfig.tiny() + attn = Llama4Attention(config, layer_idx=0) + + assert attn.hidden_size == 64 + assert attn.num_heads == 4 + assert attn.num_kv_heads == 2 + assert attn.head_dim == 16 + assert attn.n_rep == 2 + assert attn.use_qk_norm is True + assert attn.scale == 16**-0.5 + + def test_is_nope_layer(self): + """Test NoPE layer detection.""" + config = Llama4TextConfig.tiny() + config.no_rope_layers = [0, 2] + + # Layer 0 is NoPE + attn0 = Llama4Attention(config, layer_idx=0) + assert attn0.is_nope_layer is True + assert attn0.rope is None + + # Layer 1 is RoPE + attn1 = Llama4Attention(config, layer_idx=1) + assert attn1.is_nope_layer is False + assert attn1.rope is not None + + # Layer 2 is NoPE + attn2 = Llama4Attention(config, layer_idx=2) + assert attn2.is_nope_layer is True + + def test_no_nope_layers(self): + """Test when no_rope_layers is None.""" + config = Llama4TextConfig.tiny() + config.no_rope_layers = None + + attn = Llama4Attention(config, layer_idx=0) + assert attn.is_nope_layer is False + assert attn.rope is not None + + def test_qk_norm_disabled(self): + """Test attention without QK normalization.""" + config = Llama4TextConfig.tiny() + config.use_qk_norm = False + + attn = Llama4Attention(config) + assert attn.use_qk_norm is False + + def test_temperature_tuning(self): + """Test attention temperature tuning.""" + config = Llama4TextConfig.tiny() + config.attn_temperature_tuning = True + + attn = Llama4Attention(config) + assert attn.attn_temperature_tuning is True + # Should have temperature parameter + assert hasattr(attn, "temperature") + + def test_forward_pass(self): + """Test forward pass.""" + config = Llama4TextConfig.tiny() + attn = Llama4Attention(config, layer_idx=1) # RoPE layer + + x = mx.random.normal((2, 10, 64)) + output, cache = attn(x) + + assert output.shape == (2, 10, 64) + assert cache is not None + k, v = cache + assert k.shape[0] == 2 + assert k.shape[2] == 10 + + def test_forward_nope_layer(self): + """Test forward pass for NoPE layer.""" + config = Llama4TextConfig.tiny() + config.no_rope_layers = [0] + + attn = Llama4Attention(config, layer_idx=0) + + x = mx.random.normal((2, 10, 64)) + output, cache = attn(x) + + assert output.shape == (2, 10, 64) + assert cache is not None + + def test_forward_with_mask(self): + """Test forward with mask.""" + config = Llama4TextConfig.tiny() + attn = Llama4Attention(config, layer_idx=1) + + x = mx.random.normal((2, 10, 64)) + mask = nn.MultiHeadAttention.create_additive_causal_mask(10) + output, cache = attn(x, mask=mask) + + assert output.shape == (2, 10, 64) + + def test_forward_with_cache(self): + """Test forward with KV cache.""" + config = Llama4TextConfig.tiny() + attn = Llama4Attention(config, layer_idx=1) + + # First pass + x1 = mx.random.normal((2, 10, 64)) + _, cache = attn(x1) + + # Second pass with cache + x2 = mx.random.normal((2, 1, 64)) + output, new_cache = attn(x2, cache=cache) + + assert output.shape == (2, 1, 64) + k, v = new_cache + assert k.shape[2] == 11 # 10 + 1 + + def test_forward_with_qk_norm(self): + """Test forward with QK normalization.""" + config = Llama4TextConfig.tiny() + config.use_qk_norm = True + + attn = Llama4Attention(config, layer_idx=1) + + x = mx.random.normal((2, 5, 64)) + output, _ = attn(x) + + assert output.shape == (2, 5, 64) + + def test_forward_with_temperature_tuning(self): + """Test forward with temperature tuning.""" + config = Llama4TextConfig.tiny() + config.attn_temperature_tuning = True + + attn = Llama4Attention(config, layer_idx=1) + + x = mx.random.normal((2, 5, 64)) + output, _ = attn(x) + + assert output.shape == (2, 5, 64) + + def test_repeat_kv(self): + """Test KV repeat method.""" + config = Llama4TextConfig.tiny() + attn = Llama4Attention(config) + + x = mx.random.normal((2, 2, 10, 16)) + + # n_rep = 1, should return same + result = attn._repeat_kv(x, n_rep=1) + assert result.shape == x.shape + + # n_rep = 2 + result = attn._repeat_kv(x, n_rep=2) + assert result.shape == (2, 4, 10, 16) + + +class TestLlama4FlexAttention: + """Tests for Llama4FlexAttention.""" + + def test_creation(self): + """Test flex attention creation.""" + config = Llama4TextConfig.tiny() + attn = Llama4FlexAttention(config, layer_idx=0) + + assert attn.floor_scale == 1 + + def test_forward_pass(self): + """Test forward pass.""" + config = Llama4TextConfig.tiny() + attn = Llama4FlexAttention(config, layer_idx=1) + + x = mx.random.normal((2, 10, 64)) + output, cache = attn(x) + + assert output.shape == (2, 10, 64) + assert cache is not None + + +class TestCreateLlama4Attention: + """Tests for create_llama4_attention factory function.""" + + def test_create_default(self): + """Test creating default attention.""" + config = Llama4TextConfig.tiny() + attn = create_llama4_attention(config, layer_idx=0) + + assert isinstance(attn, Llama4Attention) + assert not isinstance(attn, Llama4FlexAttention) + + def test_create_flex(self): + """Test creating flex attention.""" + config = Llama4TextConfig.tiny() + attn = create_llama4_attention(config, layer_idx=0, attention_type="flex") + + assert isinstance(attn, Llama4FlexAttention) + + +class TestLlama4AttentionGradients: + """Tests for gradient flow through attention.""" + + def test_gradients_flow(self): + """Test gradients flow through attention.""" + config = Llama4TextConfig.tiny() + attn = Llama4Attention(config, layer_idx=1) + + x = mx.random.normal((2, 5, 64)) + + def loss_fn(model, x): + out, _ = model(x) + return mx.mean(out**2) + + loss_and_grad_fn = nn.value_and_grad(attn, loss_fn) + loss, grads = loss_and_grad_fn(attn, x) + + assert loss.item() > 0 + assert any(g is not None for g in grads.values()) + + def test_gradients_with_qk_norm(self): + """Test gradients with QK normalization.""" + config = Llama4TextConfig.tiny() + config.use_qk_norm = True + attn = Llama4Attention(config, layer_idx=1) + + x = mx.random.normal((2, 5, 64)) + + def loss_fn(model, x): + out, _ = model(x) + return mx.mean(out**2) + + loss_and_grad_fn = nn.value_and_grad(attn, loss_fn) + loss, grads = loss_and_grad_fn(attn, x) + + assert loss.item() > 0 diff --git a/tests/models_v2/families/llama4/test_config.py b/tests/models_v2/families/llama4/test_config.py new file mode 100644 index 00000000..3593beb4 --- /dev/null +++ b/tests/models_v2/families/llama4/test_config.py @@ -0,0 +1,174 @@ +""" +Tests for Llama 4 configuration. +""" + +from chuk_lazarus.models_v2.families.llama4.config import ( + Llama4Config, + Llama4TextConfig, + Llama4VisionConfig, +) + + +class TestLlama4TextConfig: + """Tests for Llama4TextConfig.""" + + def test_defaults(self): + """Test default configuration values.""" + config = Llama4TextConfig() + + assert config.model_type == "llama4" + assert config.hidden_act == "silu" + assert config.rope_theta == 500000.0 + assert config.rms_norm_eps == 1e-5 + assert config.num_local_experts == 16 + assert config.num_experts_per_tok == 1 + assert config.intermediate_size_mlp == 16384 + assert config.moe_router_topk == 1 + assert config.no_rope_layers is None + assert config.attention_chunk_size == 8192 + assert config.use_qk_norm is True + assert config.attn_temperature_tuning is False + assert config.rope_scaling is None + + def test_scout_17b(self): + """Test Llama 4 Scout preset.""" + config = Llama4TextConfig.scout_17b() + + assert config.vocab_size == 202048 + assert config.hidden_size == 5120 + assert config.num_hidden_layers == 48 + assert config.num_attention_heads == 40 + assert config.num_key_value_heads == 8 + assert config.intermediate_size == 8192 + assert config.intermediate_size_mlp == 16384 + assert config.num_local_experts == 16 + assert config.num_experts_per_tok == 1 + assert config.max_position_embeddings == 131072 + assert config.rope_theta == 500000.0 + assert config.use_qk_norm is True + assert config.tie_word_embeddings is False + assert config.no_rope_layers == [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44] + + def test_maverick_17b(self): + """Test Llama 4 Maverick preset.""" + config = Llama4TextConfig.maverick_17b() + + assert config.vocab_size == 202048 + assert config.hidden_size == 5120 + assert config.num_hidden_layers == 48 + assert config.num_attention_heads == 40 + assert config.num_key_value_heads == 8 + assert config.intermediate_size == 8192 + assert config.intermediate_size_mlp == 8192 + assert config.num_local_experts == 128 + assert config.num_experts_per_tok == 1 + assert config.use_qk_norm is True + + def test_tiny(self): + """Test tiny config for testing.""" + config = Llama4TextConfig.tiny() + + assert config.vocab_size == 1000 + assert config.hidden_size == 64 + assert config.num_hidden_layers == 4 + assert config.num_attention_heads == 4 + assert config.num_key_value_heads == 2 + assert config.intermediate_size == 128 + assert config.intermediate_size_mlp == 256 + assert config.num_local_experts == 4 + assert config.num_experts_per_tok == 1 + assert config.max_position_embeddings == 256 + assert config.use_qk_norm is True + assert config.no_rope_layers == [0] + + +class TestLlama4VisionConfig: + """Tests for Llama4VisionConfig.""" + + def test_defaults(self): + """Test default configuration values.""" + config = Llama4VisionConfig() + + assert config.model_type == "llama4_vision" + assert config.hidden_size == 1280 + assert config.num_hidden_layers == 32 + assert config.num_attention_heads == 16 + assert config.intermediate_size == 5120 + assert config.image_size == 560 + assert config.patch_size == 14 + assert config.num_channels == 3 + assert config.vision_output_dim == 5120 + assert config.pixel_shuffle_ratio == 0.5 + assert config.rms_norm_eps == 1e-5 + assert config.hidden_act == "gelu" + + def test_default_factory(self): + """Test default factory method.""" + config = Llama4VisionConfig.default() + + assert config.model_type == "llama4_vision" + assert config.hidden_size == 1280 + assert config.image_size == 560 + + +class TestLlama4Config: + """Tests for Llama4Config multimodal config.""" + + def test_scout_multimodal(self): + """Test Scout multimodal preset.""" + config = Llama4Config.scout_multimodal() + + assert config.model_type == "llama4" + assert config.text_config is not None + assert config.vision_config is not None + assert config.text_config.hidden_size == 5120 + assert config.vision_config.hidden_size == 1280 + assert config.image_token_index == 128011 + assert config.image_token == "<|image|>" + + def test_scout_text_only(self): + """Test Scout text-only preset.""" + config = Llama4Config.scout_text_only() + + assert config.text_config is not None + assert config.vision_config is None + + def test_text_config_vocab_size(self): + """Test accessing vocab_size via text_config.""" + config = Llama4Config.scout_text_only() + assert config.text_config.vocab_size == 202048 + + def test_text_config_hidden_size(self): + """Test accessing hidden_size via text_config.""" + config = Llama4Config.scout_text_only() + assert config.text_config.hidden_size == 5120 + + def test_text_config_num_hidden_layers(self): + """Test accessing num_hidden_layers via text_config.""" + config = Llama4Config.scout_text_only() + assert config.text_config.num_hidden_layers == 48 + + def test_text_config_num_attention_heads(self): + """Test accessing num_attention_heads via text_config.""" + config = Llama4Config.scout_text_only() + assert config.text_config.num_attention_heads == 40 + + def test_text_config_num_key_value_heads(self): + """Test accessing num_key_value_heads via text_config.""" + config = Llama4Config.scout_text_only() + assert config.text_config.num_key_value_heads == 8 + + def test_text_config_intermediate_size(self): + """Test accessing intermediate_size via text_config.""" + config = Llama4Config.scout_text_only() + assert config.text_config.intermediate_size == 8192 + + def test_text_config_rms_norm_eps(self): + """Test accessing rms_norm_eps via text_config.""" + config = Llama4Config.scout_text_only() + assert config.text_config.rms_norm_eps == 1e-5 + + def test_text_config_tie_word_embeddings(self): + """Test accessing tie_word_embeddings via text_config.""" + config = Llama4Config.scout_text_only() + assert config.text_config.tie_word_embeddings is False diff --git a/tests/models_v2/families/llama4/test_model.py b/tests/models_v2/families/llama4/test_model.py new file mode 100644 index 00000000..b107f276 --- /dev/null +++ b/tests/models_v2/families/llama4/test_model.py @@ -0,0 +1,337 @@ +""" +Tests for Llama 4 model. +""" + +import mlx.core as mx +import mlx.nn as nn + +from chuk_lazarus.models_v2.families.llama4.config import Llama4TextConfig +from chuk_lazarus.models_v2.families.llama4.model import ( + Llama4, + Llama4Block, + Llama4ForCausalLM, + Llama4Model, +) + + +class TestLlama4Block: + """Tests for Llama4Block.""" + + def test_creation(self): + """Test block creation.""" + config = Llama4TextConfig.tiny() + block = Llama4Block(config, layer_idx=0) + + assert block.hidden_size == 64 + assert block.layer_idx == 0 + + def test_block_type(self): + """Test block_type property.""" + from chuk_lazarus.models_v2.core.enums import BlockType + + config = Llama4TextConfig.tiny() + block = Llama4Block(config) + + assert block.block_type == BlockType.TRANSFORMER + + def test_hidden_size_property(self): + """Test hidden_size property.""" + config = Llama4TextConfig.tiny() + block = Llama4Block(config) + + assert block.hidden_size == 64 + + def test_forward_pass(self): + """Test forward pass.""" + config = Llama4TextConfig.tiny() + block = Llama4Block(config, layer_idx=1) + + x = mx.random.normal((2, 10, 64)) + output = block(x) + + assert output.hidden_states.shape == (2, 10, 64) + assert output.cache is not None + + def test_forward_with_mask(self): + """Test forward with mask.""" + config = Llama4TextConfig.tiny() + block = Llama4Block(config, layer_idx=1) + + x = mx.random.normal((2, 10, 64)) + mask = nn.MultiHeadAttention.create_additive_causal_mask(10) + output = block(x, mask=mask) + + assert output.hidden_states.shape == (2, 10, 64) + + def test_forward_with_cache(self): + """Test forward with cache.""" + config = Llama4TextConfig.tiny() + block = Llama4Block(config, layer_idx=1) + + # First pass + x1 = mx.random.normal((2, 10, 64)) + out1 = block(x1) + + # Second pass + x2 = mx.random.normal((2, 1, 64)) + out2 = block(x2, cache=out1.cache) + + assert out2.hidden_states.shape == (2, 1, 64) + + +class TestLlama4Model: + """Tests for Llama4Model backbone.""" + + def test_creation(self): + """Test model creation.""" + config = Llama4TextConfig.tiny() + model = Llama4Model(config) + + assert model.hidden_size == 64 + assert model.num_layers == 4 + assert model.vocab_size == 1000 + + def test_forward_pass(self): + """Test forward pass.""" + config = Llama4TextConfig.tiny() + model = Llama4Model(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + + assert output.last_hidden_state.shape == (1, 5, 64) + assert output.cache is not None + assert len(output.cache) == 4 + assert output.hidden_states is None + + def test_forward_with_output_hidden_states(self): + """Test forward with output_hidden_states=True.""" + config = Llama4TextConfig.tiny() + model = Llama4Model(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids, output_hidden_states=True) + + assert output.hidden_states is not None + assert len(output.hidden_states) == 5 # embeddings + 4 layers + + def test_forward_with_attention_mask(self): + """Test forward with attention mask.""" + config = Llama4TextConfig.tiny() + model = Llama4Model(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + mask = nn.MultiHeadAttention.create_additive_causal_mask(5) + output = model(input_ids, attention_mask=mask) + + assert output.last_hidden_state.shape == (1, 5, 64) + + def test_forward_with_cache(self): + """Test forward with cache.""" + config = Llama4TextConfig.tiny() + model = Llama4Model(config) + + # First pass + input_ids = mx.array([[1, 2, 3, 4, 5]]) + out1 = model(input_ids) + + # Second pass + next_token = mx.array([[6]]) + out2 = model(next_token, cache=out1.cache) + + assert out2.last_hidden_state.shape == (1, 1, 64) + + def test_get_input_embeddings(self): + """Test get_input_embeddings method.""" + config = Llama4TextConfig.tiny() + model = Llama4Model(config) + + embeddings = model.get_input_embeddings() + assert embeddings is model.embed_tokens + + def test_set_input_embeddings(self): + """Test set_input_embeddings method.""" + config = Llama4TextConfig.tiny() + model = Llama4Model(config) + + new_embed = nn.Embedding(500, 64) + model.set_input_embeddings(new_embed) + + assert model.embed_tokens is new_embed + + +class TestLlama4ForCausalLM: + """Tests for Llama4ForCausalLM.""" + + def test_creation(self): + """Test model creation.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + assert model.config is config + + def test_creation_tied_embeddings(self): + """Test model with tied embeddings.""" + config = Llama4TextConfig.tiny() + config.tie_word_embeddings = True + model = Llama4ForCausalLM(config) + + assert model.lm_head is not None + + def test_creation_untied_embeddings(self): + """Test model without tied embeddings.""" + config = Llama4TextConfig.tiny() + config.tie_word_embeddings = False + model = Llama4ForCausalLM(config) + + assert model.lm_head is not None + + def test_backbone_property(self): + """Test backbone property.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + assert model.backbone is model.model + + def test_forward_pass(self): + """Test forward pass.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + output = model(input_ids) + + assert output.logits.shape == (1, 5, 1000) + assert output.loss is None + assert output.cache is not None + + def test_forward_with_labels(self): + """Test forward with labels.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + labels = mx.array([[2, 3, 4, 5, 6]]) + output = model(input_ids, labels=labels) + + assert output.loss is not None + assert output.loss.item() > 0 + + def test_forward_with_output_hidden_states(self): + """Test forward with output_hidden_states.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + output = model(input_ids, output_hidden_states=True) + + assert output.hidden_states is not None + + def test_generate_basic(self): + """Test basic generation.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=5) + + assert generated.shape[0] == 1 + assert generated.shape[1] >= 3 + + def test_generate_with_temperature(self): + """Test generation with temperature.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=3, temperature=0.5) + + assert generated.shape[1] >= 3 + + def test_generate_with_top_k(self): + """Test generation with top_k.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=3, top_k=10) + + assert generated.shape[1] >= 3 + + def test_generate_with_repetition_penalty(self): + """Test generation with repetition penalty.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=3, repetition_penalty=1.2) + + assert generated.shape[1] >= 3 + + def test_generate_with_stop_tokens(self): + """Test generation with stop tokens.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + input_ids = mx.array([[1, 2, 3]]) + generated = model.generate(input_ids, max_new_tokens=10, stop_tokens=[999]) + + assert generated.shape[1] >= 3 + + def test_from_config(self): + """Test from_config class method.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM.from_config(config) + + assert isinstance(model, Llama4ForCausalLM) + assert model.config is config + + def test_alias_llama4(self): + """Test Llama4 alias.""" + config = Llama4TextConfig.tiny() + model = Llama4(config) + + assert isinstance(model, Llama4ForCausalLM) + + +class TestLlama4Gradients: + """Tests for gradient flow.""" + + def test_loss_computation(self): + """Test that loss can be computed (gradient flow tests skipped due to MoE gather_mm limitation).""" + # Note: Full gradient tests are skipped because MLX's gather_mm operation + # (used in the MoE layer) does not support VJP with respect to indices. + # This is a known limitation: "Cannot calculate VJP with respect to indices" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + input_ids = mx.array([[1, 2, 3, 4, 5]]) + labels = mx.array([[2, 3, 4, 5, 6]]) + + output = model(input_ids, labels=labels) + assert output.loss is not None + assert output.loss.item() > 0 + + +class TestLlama4BatchHandling: + """Tests for batch handling.""" + + def test_different_batch_sizes(self): + """Test different batch sizes.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + for batch_size in [1, 2, 4]: + input_ids = mx.random.randint(0, config.vocab_size, (batch_size, 5)) + output = model(input_ids) + assert output.logits.shape == (batch_size, 5, config.vocab_size) + + def test_different_sequence_lengths(self): + """Test different sequence lengths.""" + config = Llama4TextConfig.tiny() + model = Llama4ForCausalLM(config) + + for seq_len in [1, 5, 10]: + input_ids = mx.random.randint(0, config.vocab_size, (2, seq_len)) + output = model(input_ids) + assert output.logits.shape == (2, seq_len, config.vocab_size) diff --git a/tests/models_v2/families/llama4/test_moe.py b/tests/models_v2/families/llama4/test_moe.py new file mode 100644 index 00000000..4eee6053 --- /dev/null +++ b/tests/models_v2/families/llama4/test_moe.py @@ -0,0 +1,219 @@ +""" +Tests for Llama 4 MoE. +""" + +import mlx.core as mx +import mlx.nn as nn + +from chuk_lazarus.models_v2.families.llama4.config import Llama4TextConfig +from chuk_lazarus.models_v2.families.llama4.moe import ( + Llama4MLP, + Llama4MoE, + SwitchGLU, + SwitchLinear, + create_llama4_moe, + swiglu, +) + + +class TestLlama4MLP: + """Tests for Llama4MLP (shared expert).""" + + def test_creation(self): + """Test MLP creation.""" + mlp = Llama4MLP(hidden_size=64, intermediate_size=128) + + assert mlp.hidden_size == 64 + assert mlp.intermediate_size == 128 + + def test_forward_pass(self): + """Test forward pass.""" + mlp = Llama4MLP(hidden_size=64, intermediate_size=128) + + x = mx.random.normal((2, 10, 64)) + output = mlp(x) + + assert output.shape == (2, 10, 64) + + def test_with_bias(self): + """Test MLP with bias.""" + mlp = Llama4MLP(hidden_size=64, intermediate_size=128, bias=True) + + x = mx.random.normal((2, 5, 64)) + output = mlp(x) + + assert output.shape == (2, 5, 64) + + +class TestSwitchLinear: + """Tests for SwitchLinear (expert selection layer).""" + + def test_creation(self): + """Test SwitchLinear creation.""" + layer = SwitchLinear(input_dims=64, output_dims=128, num_experts=4) + + assert layer.input_dims == 64 + assert layer.output_dims == 128 + assert layer.num_experts == 4 + assert layer.weight.shape == (4, 128, 64) + + def test_creation_with_bias(self): + """Test SwitchLinear with bias.""" + layer = SwitchLinear(input_dims=64, output_dims=128, num_experts=4, bias=True) + + assert "bias" in layer + assert layer.bias.shape == (4, 128) + + def test_forward_pass(self): + """Test forward pass.""" + layer = SwitchLinear(input_dims=64, output_dims=128, num_experts=4) + + # Input: (..., 1, 1, input_dims) + x = mx.random.normal((10, 1, 1, 64)) + indices = mx.array([[0], [1], [2], [3], [0], [1], [2], [3], [0], [1]]) + + output = layer(x, indices) + + # Output: (..., k, 1, output_dims) + assert output.shape == (10, 1, 1, 128) + + def test_forward_with_bias(self): + """Test forward with bias.""" + layer = SwitchLinear(input_dims=64, output_dims=128, num_experts=4, bias=True) + + x = mx.random.normal((5, 1, 1, 64)) + indices = mx.array([[0], [1], [2], [3], [0]]) + + output = layer(x, indices) + + assert output.shape == (5, 1, 1, 128) + + +class TestSwiglu: + """Tests for swiglu function.""" + + def test_swiglu(self): + """Test swiglu activation.""" + x = mx.random.normal((2, 5, 64)) + gate = mx.random.normal((2, 5, 64)) + + output = swiglu(x, gate) + + assert output.shape == (2, 5, 64) + + +class TestSwitchGLU: + """Tests for SwitchGLU (expert MLP with gather_mm).""" + + def test_creation(self): + """Test SwitchGLU creation.""" + switch_glu = SwitchGLU(input_dims=64, hidden_dims=128, num_experts=4, bias=False) + + assert switch_glu.gate_proj is not None + assert switch_glu.up_proj is not None + assert switch_glu.down_proj is not None + + def test_forward_pass(self): + """Test forward pass.""" + switch_glu = SwitchGLU(input_dims=64, hidden_dims=128, num_experts=4, bias=False) + + x = mx.random.normal((10, 64)) # (batch * seq, hidden_size) + indices = mx.array([[0], [1], [2], [3], [0], [1], [2], [3], [0], [1]]) + + output = switch_glu(x, indices) + + # Output: (batch * seq, k, hidden_size) + assert output.shape == (10, 1, 64) + + +class TestLlama4MoE: + """Tests for Llama4MoE.""" + + def test_creation(self): + """Test MoE creation.""" + config = Llama4TextConfig.tiny() + moe = Llama4MoE(config) + + assert moe.hidden_size == 64 + assert moe.intermediate_size == 128 + assert moe.intermediate_size_mlp == 256 + assert moe.num_experts == 4 + assert moe.num_experts_per_tok == 1 + + def test_forward_pass(self): + """Test forward pass.""" + config = Llama4TextConfig.tiny() + moe = Llama4MoE(config) + + x = mx.random.normal((2, 10, 64)) + output = moe(x) + + assert output.shape == (2, 10, 64) + + def test_shared_expert(self): + """Test that shared expert is always active.""" + config = Llama4TextConfig.tiny() + moe = Llama4MoE(config) + + # Shared expert should exist + assert moe.shared_expert is not None + + x = mx.random.normal((2, 5, 64)) + output = moe(x) + assert output.shape == (2, 5, 64) + + def test_router(self): + """Test that router produces valid outputs.""" + config = Llama4TextConfig.tiny() + moe = Llama4MoE(config) + + # Router should project to num_experts + x = mx.random.normal((2, 5, 64)) + router_logits = moe.router(x) + assert router_logits.shape == (2, 5, 4) + + +class TestCreateLlama4MoE: + """Tests for create_llama4_moe factory function.""" + + def test_create(self): + """Test factory function.""" + config = Llama4TextConfig.tiny() + moe = create_llama4_moe(config) + + assert isinstance(moe, Llama4MoE) + + +class TestLlama4MoEGradients: + """Tests for gradient flow through MoE.""" + + def test_forward_produces_output(self): + """Test forward pass produces valid output (gradient tests skipped due to gather_mm limitation).""" + # Note: Full gradient tests are skipped because MLX's gather_mm operation + # (used in expert selection) does not support VJP with respect to indices. + # This is a known limitation: "Cannot calculate VJP with respect to indices" + config = Llama4TextConfig.tiny() + moe = Llama4MoE(config) + + x = mx.random.normal((2, 5, 64)) + out = moe(x) + + assert out.shape == (2, 5, 64) + # Verify output is finite + assert mx.all(mx.isfinite(out)) + + def test_shared_expert_gradients(self): + """Test gradients flow through shared expert.""" + config = Llama4TextConfig.tiny() + moe = Llama4MoE(config) + + x = mx.random.normal((2, 5, 64)) + + def loss_fn(model, x): + # Just use shared expert (no gather_mm) + return mx.mean(model.shared_expert(x) ** 2) + + loss_and_grad_fn = nn.value_and_grad(moe, loss_fn) + loss, grads = loss_and_grad_fn(moe, x) + + assert loss.item() > 0 diff --git a/tests/models_v2/families/test_llama.py b/tests/models_v2/families/test_llama.py index 87bf1713..3306397d 100644 --- a/tests/models_v2/families/test_llama.py +++ b/tests/models_v2/families/test_llama.py @@ -5,6 +5,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.families.llama import ( LlamaBlock, @@ -401,7 +402,8 @@ def loss_fn(model, input_ids, labels): output = model(input_ids, labels=labels) return output.loss - loss, grads = mx.value_and_grad(loss_fn)(model, input_ids, labels) + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, input_ids, labels) assert loss.item() > 0 # Check some gradients exist diff --git a/tests/models_v2/families/test_mamba.py b/tests/models_v2/families/test_mamba.py index 04b9512b..72458b50 100644 --- a/tests/models_v2/families/test_mamba.py +++ b/tests/models_v2/families/test_mamba.py @@ -5,6 +5,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.families.mamba import ( MambaConfig, @@ -336,7 +337,8 @@ def loss_fn(model, input_ids, labels): output = model(input_ids, labels=labels) return output.loss - loss, grads = mx.value_and_grad(loss_fn)(model, input_ids, labels) + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, input_ids, labels) assert loss.item() > 0 # Check some gradients exist @@ -353,7 +355,8 @@ def loss_fn(model, input_ids): output = model(input_ids) return mx.mean(output.logits**2) - loss, grads = mx.value_and_grad(loss_fn)(model, input_ids) + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, input_ids) assert loss.item() > 0 diff --git a/tests/models_v2/heads/test_heads.py b/tests/models_v2/heads/test_heads.py index 7b34223d..7e8361a9 100644 --- a/tests/models_v2/heads/test_heads.py +++ b/tests/models_v2/heads/test_heads.py @@ -342,7 +342,8 @@ def loss_fn(model, hidden_states, labels): out = model(hidden_states, labels=labels) return out.loss - loss, grads = mx.value_and_grad(loss_fn)(head, hidden_states, labels) + loss_and_grad_fn = nn.value_and_grad(head, loss_fn) + loss, grads = loss_and_grad_fn(head, hidden_states, labels) assert loss.item() > 0 assert any(g is not None for g in grads.values()) @@ -361,7 +362,8 @@ def loss_fn(model, hidden_states, labels): out = model(hidden_states, labels=labels) return out.loss - loss, grads = mx.value_and_grad(loss_fn)(head, hidden_states, labels) + loss_and_grad_fn = nn.value_and_grad(head, loss_fn) + loss, grads = loss_and_grad_fn(head, hidden_states, labels) assert loss.item() > 0 @@ -379,7 +381,8 @@ def loss_fn(model, hidden_states, labels): out = model(hidden_states, labels=labels) return out.loss - loss, grads = mx.value_and_grad(loss_fn)(head, hidden_states, labels) + loss_and_grad_fn = nn.value_and_grad(head, loss_fn) + loss, grads = loss_and_grad_fn(head, hidden_states, labels) assert loss.item() >= 0 diff --git a/tests/models_v2/models/classifiers/test_linear.py b/tests/models_v2/models/classifiers/test_linear.py index 3d76a3ee..ac7e347c 100644 --- a/tests/models_v2/models/classifiers/test_linear.py +++ b/tests/models_v2/models/classifiers/test_linear.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.models.classifiers import LinearClassifier @@ -78,7 +79,8 @@ def loss_fn(model): log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) return -mx.mean(log_probs[mx.arange(4), targets]) - loss, grads = mx.value_and_grad(loss_fn)(clf) + loss_and_grad_fn = nn.value_and_grad(clf, loss_fn) + loss, grads = loss_and_grad_fn(clf) assert loss.item() > 0 assert "fc" in grads diff --git a/tests/models_v2/models/classifiers/test_mlp.py b/tests/models_v2/models/classifiers/test_mlp.py index ddf12a98..8e485622 100644 --- a/tests/models_v2/models/classifiers/test_mlp.py +++ b/tests/models_v2/models/classifiers/test_mlp.py @@ -3,6 +3,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.core.enums import ActivationType from chuk_lazarus.models_v2.models.classifiers import MLPClassifier @@ -111,7 +112,8 @@ def loss_fn(model): log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) return -mx.mean(log_probs[mx.arange(4), targets]) - loss, grads = mx.value_and_grad(loss_fn)(clf) + loss_and_grad_fn = nn.value_and_grad(clf, loss_fn) + loss, grads = loss_and_grad_fn(clf) assert loss.item() > 0 assert "mlp" in grads diff --git a/tests/models_v2/models/test_models.py b/tests/models_v2/models/test_models.py index b0a0f924..25cb8959 100644 --- a/tests/models_v2/models/test_models.py +++ b/tests/models_v2/models/test_models.py @@ -5,6 +5,7 @@ """ import mlx.core as mx +import mlx.nn as nn from chuk_lazarus.models_v2.core.config import ModelConfig from chuk_lazarus.models_v2.core.enums import BackboneType @@ -723,7 +724,8 @@ def loss_fn(model, input_ids, labels): output = model(input_ids, labels=labels) return output.loss - loss, grads = mx.value_and_grad(loss_fn)(model, input_ids, labels) + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, input_ids, labels) assert loss.item() > 0 @@ -744,7 +746,8 @@ def loss_fn(model, input_ids, labels): output = model(input_ids, labels=labels) return output.loss - loss, grads = mx.value_and_grad(loss_fn)(model, input_ids, labels) + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, input_ids, labels) assert loss.item() > 0 @@ -765,7 +768,8 @@ def loss_fn(model, input_ids, labels): output = model(input_ids, labels=labels) return output.loss - loss, grads = mx.value_and_grad(loss_fn)(model, input_ids, labels) + loss_and_grad_fn = nn.value_and_grad(model, loss_fn) + loss, grads = loss_and_grad_fn(model, input_ids, labels) class TestModelBase: diff --git a/uv.lock b/uv.lock index 24294f1b..64a74c8b 100644 --- a/uv.lock +++ b/uv.lock @@ -227,7 +227,7 @@ wheels = [ [[package]] name = "chuk-lazarus" -version = "0.2.3" +version = "0.4" source = { editable = "." } dependencies = [ { name = "aiofiles" },