diff --git a/.env.example b/.env.example
index 822b7a76..e2b5a409 100644
--- a/.env.example
+++ b/.env.example
@@ -82,6 +82,18 @@ DNET_KV_GROUP_SIZE=64
 # KV cache TTL in seconds
 DNET_KV_TTL_S=30.0
 
+# === Context Parallelism ===
+# Enable context parallelism mode
+DNET_CP_ENABLED=false
+# Ring attention algorithm (auto, pass_kv, pass_q, ring_reduce)
+DNET_CP_ALGORITHM=auto
+# Minimum context length to enable CP (below this, single-device)
+DNET_CP_MIN_CONTEXT_FOR_CP=32768
+# Minimum new tokens to prefer pass_kv over pass_q
+DNET_CP_MIN_TOKENS_FOR_PASS_KV=256
+# Overlap between chunks for sliding window attention
+DNET_CP_CHUNK_OVERLAP=0
+
 # === gRPC ===
 # Max gRPC message length
 DNET_GRPC_MAX_MESSAGE_LENGTH=67108864
diff --git a/.github/workflows/cp-integration-tests.yml b/.github/workflows/cp-integration-tests.yml
new file mode 100644
index 00000000..de569d80
--- /dev/null
+++ b/.github/workflows/cp-integration-tests.yml
@@ -0,0 +1,109 @@
+name: CP Integration Tests
+
+on:
+  workflow_dispatch:
+    inputs:
+      model_filter:
+        description: 'Model filter for tests (e.g. "qwen")'
+        required: false
+        default: ''
+  pull_request:
+    paths:
+      - 'src/dnet/core/cp/**'
+      - 'src/dnet/shard/adapters/context_parallel.py'
+      - 'src/dnet/api/strategies/context_parallel.py'
+      - 'tests/integration/test_cp_*.py'
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  cp-integration-tests:
+    runs-on: mac2.metal
+    timeout-minutes: 60
+    env:
+      PROJECT_ROOT: ${{ github.workspace }}
+      PYTHONPATH: src
+      DNET_CP_ENABLED: 'true'
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+        with:
+          submodules: recursive
+
+      - name: Setup Environment
+        uses: ./.github/actions/setup-env
+        with:
+          python_version: '3.12'
+
+      - name: Enable CP in .env
+        run: |
+          # Force DNET_CP_ENABLED=true in .env file (overrides default)
+          # Note: macOS sed requires -i '' for in-place edit
+          if grep -q "^DNET_CP_ENABLED=" .env 2>/dev/null; then
+            sed -i '' 's/^DNET_CP_ENABLED=.*/DNET_CP_ENABLED=true/' .env
+          else
+            echo "DNET_CP_ENABLED=true" >> .env
+          fi
+          echo "Updated .env:"
+          grep DNET_CP_ .env || echo "No DNET_CP_ settings found"
+
+      - name: Ensure compatible gRPC/protobuf versions
+        run: |
+          uv pip install --upgrade "grpcio>=1.75.1" "protobuf>=6.31.1"
+
+      - name: Run CP unit tests
+        run: |
+          uv run pytest tests/subsystems/test_cp_*.py -v --tb=short
+
+      - name: Kill processes on required ports
+        run: |
+          for port in 8080 8081 58080 58081; do
+            lsof -ti:$port | xargs kill -9 2>/dev/null || true
+          done
+          sleep 2
+
+      - name: Verify CP environment
+        run: |
+          echo "DNET_CP_ENABLED=${DNET_CP_ENABLED}"
+          if [ "$DNET_CP_ENABLED" != "true" ]; then
+            echo "::error::DNET_CP_ENABLED is not set to true"
+            exit 1
+          fi
+
+      - name: Start shard server
+        uses: ./.github/actions/start-shard
+        with:
+          http_port: '8081'
+          grpc_port: '58081'
+
+      - name: Start API server
+        uses: ./.github/actions/start-api
+        with:
+          http_port: '8080'
+          grpc_port: '58080'
+
+      - name: Run integration tests
+        run: |
+          sleep 10  # Wait for servers to initialize
+          echo "Running tests with DNET_CP_ENABLED=${DNET_CP_ENABLED}"
+          if [ -n "${{ github.event.inputs.model_filter }}" ]; then
+            uv run pytest tests/integration/test_model_catalog.py -v -x -k "${{ github.event.inputs.model_filter }}" --tb=short
+          else
+            uv run pytest tests/integration/test_model_catalog.py -v -x --tb=short
+          fi
+
+      - name: Cleanup servers
+        if: always()
+        uses: ./.github/actions/cleanup-servers
+
+      - name: Show logs on failure
+        if: failure()
+        run: |
+          echo "=== Shard logs ==="
+          cat shard.log 2>/dev/null || echo "(no shard log)"
+          echo ""
+          echo "=== API logs ==="
+          cat api.log 2>/dev/null || echo "(no API log)"
diff --git a/.gitignore b/.gitignore
index ecc24c30..3ff8831c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -47,3 +47,4 @@ repacked_models/*
 # Env files
 *.env*
 !.env*.example
+dnet-tui/
diff --git a/docs/design/context-parallelism.md b/docs/design/context-parallelism.md
new file mode 100644
index 00000000..b0a51031
--- /dev/null
+++ b/docs/design/context-parallelism.md
@@ -0,0 +1,941 @@
+# Context Parallelism for Long-Context Inference
+
+## 1. Executive Summary
+
+This document describes the design for adding **Context Parallelism (CP)** to dnet, enabling long-context inference (128K+ tokens) by distributing the sequence dimension across multiple Apple Silicon devices. CP complements the existing **RingStrategy** (layer/pipeline parallelism) with a new axis of parallelization.
+
+### Goals
+
+- **Primary**: Enable 128K+ context inference across heterogeneous device clusters
+- **Secondary**: Achieve near-linear latency improvement with device count
+- **Constraint**: Zero approximations to attention computation (exact attention)
+
+### Non-Goals (v1)
+
+- Mixed CP + pipeline parallelism (future work)
+- Training support (inference-only)
+- CUDA/AMD backends (Apple Silicon only)
+
+---
+
+## 2. Background
+
+### 2.1 Current Architecture
+
+```mermaid
+graph LR
+    subgraph "Pipeline Parallelism"
+        A[API] --> S1[Shard 1<br/>Layers 0-10]
+        S1 --> S2[Shard 2<br/>Layers 11-20]
+        S2 --> S3[Shard 3<br/>Layers 21-31]
+        S3 -->|token| A
+    end
+```
+
+The current dnet uses **pipeline parallelism**: each shard owns a subset of layers, and activations flow through the ring. This works well for large models but does **not** reduce per-device context memory.
+
+### 2.2 Problem Statement
+
+| Context Length | KV Cache (FP16, 7B model) | Fits in 24GB RAM? |
+|----------------|---------------------------|-------------------|
+| 8K             | ~1 GB                     | Yes               |
+| 32K            | ~4 GB                     | Yes               |
+| 128K           | ~16 GB                    | Tight             |
+| 512K           | ~64 GB                    | No                |
+| 1M             | ~128 GB                   | No                |
+
+Pipeline parallelism does **not** shard the KV cache across devices, so a single device must hold the full context. Context Parallelism solves this.
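+
+A back-of-the-envelope check of the table, assuming a representative 7B GQA configuration (hypothetical but typical values: 32 layers, 8 KV heads, head dim 128, FP16):
+
+```python
+# KV cache bytes = 2 (K and V) * layers * kv_heads * head_dim * bytes/elem * seq_len
+def kv_cache_gib(seq_len: int, layers: int = 32, kv_heads: int = 8,
+                 head_dim: int = 128, bytes_per_elem: int = 2) -> float:
+    per_token = 2 * layers * kv_heads * head_dim * bytes_per_elem  # 128 KiB here
+    return per_token * seq_len / 2**30
+
+for seq in (8_192, 32_768, 131_072, 524_288, 1_048_576):
+    print(f"{seq:>9} tokens -> {kv_cache_gib(seq):6.0f} GiB")
+# 8K -> 1 GiB, 32K -> 4 GiB, 128K -> 16 GiB, 512K -> 64 GiB, 1M -> 128 GiB
+```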
+
+### 2.3 Ring Attention
+
+Ring Attention (Liu et al., 2023) distributes the **sequence dimension** across devices:
+
+```mermaid
+graph LR
+    subgraph "Context Parallelism"
+        D1[Device 1<br/>Tokens 0-32K] --> D2[Device 2<br/>Tokens 32K-64K]
+        D2 --> D3[Device 3<br/>Tokens 64K-96K]
+        D3 --> D4[Device 4<br/>Tokens 96K-128K]
+        D4 -->|KV blocks| D1
+    end
+```
+
+Key insight: blockwise attention is **permutation invariant** over KV blocks, so we can compute partial attention in any order and merge the results.
+
+---
+
+## 3. Design Overview
+
+### 3.1 High-Level Architecture
+
+```mermaid
+flowchart TB
+    subgraph API["API Node"]
+        direction TB
+        CM["ClusterManager"]
+        MM["ModelManager"]
+        IM["InferenceManager"]
+        CPS["ContextParallelStrategy"]
+        CPTS["CPTopologySolver"]
+        CPAA["CPApiAdapter"]
+        CPS -->|solver| CPTS
+        CPS -->|adapter| CPAA
+        IM --> CPAA
+    end
+
+    subgraph Shards["Shard Nodes (CP Ring)"]
+        direction LR
+        subgraph S1["Shard 1"]
+            CPA1["Adapter 1"]
+            SR1["Runtime 1 (Full Model)"]
+            CPA1 --> SR1
+        end
+        subgraph S2["Shard 2"]
+            CPA2["Adapter 2"]
+            SR2["Runtime 2 (Full Model)"]
+            CPA2 --> SR2
+        end
+        subgraph S3["Shard 3"]
+            CPA3["Adapter 3"]
+            SR3["Runtime 3 (Full Model)"]
+            CPA3 --> SR3
+        end
+        subgraph S4["Shard 4"]
+            CPA4["Adapter 4"]
+            SR4["Runtime 4 (Full Model)"]
+            CPA4 --> SR4
+        end
+    end
+
+    CPAA --> CPA1
+    CPA1 <-.->|"KV/Q blocks"| CPA2
+    CPA2 <-.->|"KV/Q blocks"| CPA3
+    CPA3 <-.->|"KV/Q blocks"| CPA4
+    CPA4 <-.->|"KV/Q blocks"| CPA1
+```
+
+**Data Flow**:
+
+1. API receives request → `InferenceManager` → `CPApiAdapter`
+2. `CPApiAdapter` sends sharded tokens to Shard 1 (head of ring)
+3. Each shard computes partial attention, rotating KV/Q blocks around the ring
+4. Final merged output returns to the API via `CPApiAdapter`
+
+### 3.2 Key Differences from RingStrategy
+
+| Aspect              | RingStrategy (Pipeline)    | ContextParallelStrategy        |
+|---------------------|----------------------------|--------------------------------|
+| Sharding axis       | Layers                     | Sequence (tokens)              |
+| Model per device    | Partial (subset of layers) | Full (all layers)              |
+| KV cache per device | Full context               | 1/N of context                 |
+| Communication       | Activations between layers | KV or Q blocks between devices |
+| Memory scaling      | With model size            | With context length            |
+
+---
+
+## 4. Detailed Design
+
+### 4.1 New Components
+
+#### 4.1.1 Load-Balanced Sharding
+
+Causal attention has asymmetric compute: later tokens attend to more predecessors. Naive even partitioning causes load imbalance.
+
+**Solution**: Partition the sequence into `2N` chunks and assign complementary pairs:
+
+```text
+Sequence: [C0, C1, C2, C3, C4, C5, C6, C7]  (8 chunks for 4 devices)
+
+Device 0: [C0, C7]  # first + last
+Device 1: [C1, C6]
+Device 2: [C2, C5]
+Device 3: [C3, C4]
+```
+
+Each device gets a roughly equal compute load.
+
+```python
+# src/dnet/core/cp/sharding.py
+def load_balanced_shard(
+    tokens: mx.array,  # [seq_len, ...]
+    num_ranks: int,
+    rank_id: int,
+) -> tuple[mx.array, list[int]]:
+    """
+    Shard tokens with load balancing for causal attention.
+
+    Returns:
+        sharded_tokens: tokens for this rank
+        chunk_indices: original positions (for unsharding)
+    """
+    seq_len = tokens.shape[0]
+    chunk_size = seq_len // (2 * num_ranks)
+
+    # Assign chunks (i, 2N-i-1) to rank i
+    chunk_a = rank_id
+    chunk_b = 2 * num_ranks - rank_id - 1
+
+    start_a = chunk_a * chunk_size
+    end_a = start_a + chunk_size
+    start_b = chunk_b * chunk_size
+    end_b = start_b + chunk_size if chunk_b < 2 * num_ranks - 1 else seq_len
+
+    sharded = mx.concatenate([tokens[start_a:end_a], tokens[start_b:end_b]])
+    chunk_indices = list(range(start_a, end_a)) + list(range(start_b, end_b))
+
+    return sharded, chunk_indices
+```
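+
+A quick sanity check of the pairing for 4 ranks; unsharding is just a scatter back through `chunk_indices` (sketch, assuming the function above):
+
+```python
+import mlx.core as mx
+
+tokens = mx.arange(16)  # seq_len=16 -> 2N=8 chunks of 2 for num_ranks=4
+
+pairs: dict[int, list[int]] = {}
+cover: list[int] = []
+for rank in range(4):
+    shard, idx = load_balanced_shard(tokens, 4, rank)
+    pairs[rank] = idx
+    cover.extend(idx)
+
+# Every position is owned by exactly one rank
+assert sorted(cover) == list(range(16))
+# Rank 0 holds the first and last chunks: positions [0, 1, 14, 15]
+assert pairs[0] == [0, 1, 14, 15]
+```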
+
+#### 4.1.2 Merge Attention Operator
+
+When computing blockwise attention across distributed KV, each device produces partial outputs normalized by local softmax denominators. These must be merged with a global rescaling to be correct.
+
+**Math**: For blocks with locally normalized outputs `O_i`, max scores `m_i`, and softmax denominators `l_i` (each the sum of `exp(score - m_i)` over that block's keys):
+
+```text
+m_global = max(m_1, m_2, ..., m_N)
+l_global = sum(exp(m_i - m_global) * l_i)
+O_merged = sum(exp(m_i - m_global) * l_i * O_i) / l_global
+```
+
+```python
+# src/dnet/core/cp/merge_attention.py
+@dataclass
+class PartialAttentionOutput:
+    output: mx.array       # [batch, seq, heads, dim]
+    max_score: mx.array    # [batch, seq, heads]
+    log_sum_exp: mx.array  # [batch, seq, heads]; stored in linear space
+                           # (sum of exp(score - max_score), not its log)
+
+def merge_partial_attention(
+    partials: list[PartialAttentionOutput],
+) -> mx.array:
+    """Merge partial attention outputs with numerically stable rescaling."""
+    if not partials:
+        raise ValueError("merge_partial_attention requires at least one partial")
+
+    # Find global max for stability
+    m_global = partials[0].max_score
+    for p in partials[1:]:
+        m_global = mx.maximum(m_global, p.max_score)
+
+    # Rescale and accumulate
+    numerator = mx.zeros_like(partials[0].output)
+    denominator = mx.zeros_like(partials[0].log_sum_exp)
+
+    for p in partials:
+        scale = mx.exp(p.max_score - m_global)
+        numerator += scale[..., None] * p.log_sum_exp[..., None] * p.output
+        denominator += scale * p.log_sum_exp
+
+    return numerator / denominator[..., None]
+```
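+
+A self-contained check that the merge reproduces monolithic softmax attention when the keys/values are split into two blocks (sketch; single head, no mask, assuming the dataclass above):
+
+```python
+import mlx.core as mx
+
+def partial(q, k, v):
+    s = q @ k.T                  # [nq, nk] raw scores
+    m = s.max(axis=-1)           # local max per query
+    p = mx.exp(s - m[:, None])
+    l = p.sum(axis=-1)           # local softmax denominator
+    return PartialAttentionOutput(output=(p @ v) / l[:, None],
+                                  max_score=m, log_sum_exp=l)
+
+q = mx.random.normal((4, 8))
+k = mx.random.normal((10, 8))
+v = mx.random.normal((10, 8))
+
+merged = merge_partial_attention([partial(q, k[:5], v[:5]),
+                                  partial(q, k[5:], v[5:])])
+full = mx.softmax(q @ k.T, axis=-1) @ v
+assert mx.allclose(merged, full, atol=1e-5)
+```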
+
+#### 4.1.3 Ring Communication
+
+A gRPC-based ring passes KV or Q blocks between CP ranks.
+
+```python
+# src/dnet/core/cp/ring_comm.py
+class CPRingCommunicator:
+    """Manages ring communication for context parallelism."""
+
+    def __init__(
+        self,
+        rank_id: int,
+        num_ranks: int,
+        discovery: AsyncDnetP2P,
+    ):
+        self.rank_id = rank_id
+        self.num_ranks = num_ranks
+        self._prev_rank = (rank_id - 1) % num_ranks
+        self._next_rank = (rank_id + 1) % num_ranks
+        self._discovery = discovery
+
+        # gRPC channels
+        self._prev_channel: Optional[aio_grpc.Channel] = None
+        self._next_channel: Optional[aio_grpc.Channel] = None
+
+    async def send_recv(
+        self,
+        send_data: bytes,
+        tag: str,
+    ) -> bytes:
+        """
+        Simultaneously send to the next rank and receive from the previous rank.
+        Overlaps communication with computation when used correctly.
+        """
+        send_task = asyncio.create_task(self._send_to_next(send_data, tag))
+        recv_task = asyncio.create_task(self._recv_from_prev(tag))
+
+        await send_task
+        return await recv_task
+```
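+
+The overlap pattern callers are expected to use: put the current block on the wire *before* computing with it, so the transfer hides behind the attention kernel (sketch; `compute_partial` is a hypothetical stand-in for the attention step):
+
+```python
+import asyncio
+
+async def rotate_and_compute(comm, block: bytes, num_ranks: int) -> list:
+    partials = []
+    for step in range(1, num_ranks):
+        # Start the transfer first, then compute with the block we still hold
+        inflight = asyncio.create_task(comm.send_recv(block, f"kv_{step}"))
+        partials.append(compute_partial(block))  # overlaps the transfer above
+        block = await inflight                   # previous rank's block
+    partials.append(compute_partial(block))      # last received block
+    return partials
+```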
+
+### 4.2 Ring Attention Variants
+
+#### 4.2.1 Pass-KV (Full Prefill)
+
+Best for full prefill, where KV is smaller than Q (GQA models: e.g., 8 KV heads vs 128 Q heads).
+
+```python
+# src/dnet/shard/adapters/context_parallel.py
+async def ring_pass_kv_attention(
+    self,
+    query: mx.array,  # Local Q chunk
+    key: mx.array,    # Local K chunk (will be rotated)
+    value: mx.array,  # Local V chunk (will be rotated)
+) -> mx.array:
+    """
+    Ring attention with KV rotation.
+
+    Algorithm:
+        1. Compute local attention: Attn(Q_local, KV_local)
+        2. For i in 1..N-1:
+            a. SendRecv: send KV to next rank, receive from previous
+            b. Compute partial attention with the received KV
+            c. Accumulate partial outputs
+        3. Merge all partial outputs
+    """
+    partials: list[PartialAttentionOutput] = []
+
+    # Local attention first
+    local_out = self._compute_partial_attention(query, key, value)
+    partials.append(local_out)
+
+    current_k, current_v = key, value
+
+    for step in range(1, self.num_ranks):
+        # Overlap: send current KV while computing with previous
+        kv_bytes = self._serialize_kv(current_k, current_v)
+        recv_bytes = await self.ring_comm.send_recv(kv_bytes, f"kv_{step}")
+        current_k, current_v = self._deserialize_kv(recv_bytes)
+
+        # Compute attention with received KV
+        partial = self._compute_partial_attention(query, current_k, current_v)
+        partials.append(partial)
+
+    return merge_partial_attention(partials)
+```
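+
+Both variants lean on `_compute_partial_attention`, which is not defined elsewhere in this document. A minimal single-head sketch of what it must return (an assumption, not the final kernel; real code needs the causal mask, GQA head grouping, and the `[batch, seq, heads, dim]` layout):
+
+```python
+def _compute_partial_attention(self, q, k, v) -> PartialAttentionOutput:
+    scale = q.shape[-1] ** -0.5
+    scores = (q * scale) @ k.T        # [q_len, k_len]; causal mask omitted
+    m = scores.max(axis=-1)           # per-query max for stability
+    p = mx.exp(scores - m[:, None])
+    l = p.sum(axis=-1)                # local softmax denominator
+    return PartialAttentionOutput(
+        output=(p @ v) / l[:, None],  # locally normalized partial output
+        max_score=m,
+        log_sum_exp=l,
+    )
+```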
+
+#### 4.2.2 Pass-Q (Decode / High Cache Hit)
+
+Best for decode (single-token Q) or partial prefill with a high cache hit rate.
+
+```python
+async def ring_pass_q_attention(
+    self,
+    query: mx.array,  # Local Q chunk (will be rotated)
+    key: mx.array,    # Full local K (stationary)
+    value: mx.array,  # Full local V (stationary)
+) -> mx.array:
+    """
+    Ring attention with Q rotation.
+
+    Key difference: after the ring loop, partial outputs are scattered
+    across ranks, so an All2All is required to redistribute them.
+    """
+    # Compute attention for each arriving Q against the local KV
+    local_outputs: dict[int, PartialAttentionOutput] = {}
+
+    current_q = query
+    source_rank = self.rank_id
+
+    for step in range(self.num_ranks):
+        # Compute attention: Q from source_rank, KV from local
+        partial = self._compute_partial_attention(current_q, key, value)
+        local_outputs[source_rank] = partial
+
+        if step < self.num_ranks - 1:
+            q_bytes = self._serialize_q(current_q)
+            recv_bytes = await self.ring_comm.send_recv(q_bytes, f"q_{step}")
+            current_q = self._deserialize_q(recv_bytes)
+            source_rank = (source_rank - 1) % self.num_ranks
+
+    # All2All: redistribute partial outputs to their source ranks
+    my_partials = await self._all2all_outputs(local_outputs)
+
+    return merge_partial_attention(my_partials)
+```
+
+#### 4.2.3 Adaptive Heuristic
+
+```python
+# src/dnet/core/cp/heuristics.py
+def select_ring_algorithm(
+    new_tokens: int,                # T
+    cached_tokens: int,             # P
+    num_kv_heads: int,              # NKV
+    num_q_heads: int,               # NH
+    num_ranks: int,                 # N
+    flops_per_device: float,        # C
+    inter_device_bandwidth: float,  # BW
+) -> Literal["pass_kv", "pass_q"]:
+    """
+    Select the optimal ring algorithm based on cache miss rate and arithmetic intensity.
+
+    Heuristic (from Meta's paper; e = element size in bytes):
+    - pass-KV if T/(T+P) >= 2*NKV/NH (cache miss rate threshold)
+    - pass-KV if T >= N * (C * NKV * e) / (2 * NH * BW) (sufficient compute to hide the transfer)
+    - pass-Q otherwise
+    """
+    total_tokens = new_tokens + cached_tokens
+    miss_rate = new_tokens / total_tokens if total_tokens > 0 else 1.0
+
+    # Threshold from GQA ratio
+    gqa_threshold = 2 * num_kv_heads / num_q_heads  # e.g., 2*8/128 = 0.125
+
+    if miss_rate >= gqa_threshold:
+        return "pass_kv"
+
+    # Check if there is sufficient compute to overlap pass-KV communication
+    element_size = 2  # bfloat16
+    min_tokens_for_overlap = num_ranks * (
+        flops_per_device * num_kv_heads * element_size
+    ) / (2 * num_q_heads * inter_device_bandwidth)
+
+    if new_tokens >= min_tokens_for_overlap:
+        return "pass_kv"
+
+    return "pass_q"
+```
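+
+Two worked calls with illustrative numbers (hypothetical device figures, not measurements): a full 128K prefill trivially clears the 0.125 GQA threshold, while appending 32 tokens to a 64K cached conversation falls back to pass-Q:
+
+```python
+kwargs = dict(num_kv_heads=8, num_q_heads=128, num_ranks=4,
+              flops_per_device=14e12, inter_device_bandwidth=5e9)
+
+# Full prefill: miss rate 1.0 >= 0.125 -> pass_kv
+assert select_ring_algorithm(new_tokens=131_072, cached_tokens=0, **kwargs) == "pass_kv"
+
+# Decode-like step: miss rate ~0.0005 < 0.125, and 32 new tokens are far
+# below the ~700-token overlap threshold implied by these figures -> pass_q
+assert select_ring_algorithm(new_tokens=32, cached_tokens=65_536, **kwargs) == "pass_q"
+```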
+
+### 4.3 Strategy Integration
+
+#### 4.3.1 ContextParallelStrategy
+
+```python
+# src/dnet/api/strategies/context_parallel.py
+class CPTopologySolver(TopologySolver):
+    """Topology solver for context parallelism."""
+
+    async def solve(
+        self,
+        profiles: Dict[str, DeviceProfile],
+        model_profile: Any,
+        model_name: str,
+        num_layers: int,
+        kv_bits: Literal["4bit", "8bit", "fp16"],
+        shards: Dict[str, DnetDeviceProperties],
+        thunderbolts: Dict[str, Dict[str, ThunderboltConnection]],
+    ) -> CPTopologyInfo:
+        """
+        For CP, all devices get the full model.
+        Ordering is optimized for ring bandwidth.
+        """
+        # Order devices by Thunderbolt connectivity for minimal latency
+        ordered = self._optimize_ring_order(shards, thunderbolts)
+
+        return CPTopologyInfo(
+            model=model_name,
+            kv_bits=kv_bits,
+            num_layers=num_layers,
+            devices=ordered,
+            # Each device gets ALL layers (full model)
+            assignments={name: list(range(num_layers)) for name in ordered},
+            num_cp_ranks=len(ordered),
+        )
+
+
+class ContextParallelStrategy(Strategy):
+    """Execution strategy using context parallelism."""
+
+    def __init__(self):
+        self._solver = CPTopologySolver()
+        self._adapter = CPApiAdapter()
+
+    @property
+    def solver(self) -> TopologySolver:
+        return self._solver
+
+    @property
+    def adapter(self) -> ApiAdapterBase:
+        return self._adapter
+```
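+
+`_optimize_ring_order` is referenced but not specified here. One plausible greedy sketch (an assumption, not the final algorithm): start anywhere and repeatedly hop to an unvisited device, preferring Thunderbolt neighbors.
+
+```python
+def _optimize_ring_order(
+    self,
+    shards: dict[str, DnetDeviceProperties],
+    thunderbolts: dict[str, dict[str, ThunderboltConnection]],
+) -> list[str]:
+    remaining = set(shards)
+    current = next(iter(remaining))  # arbitrary start
+    order = [current]
+    remaining.remove(current)
+
+    while remaining:
+        tb = thunderbolts.get(current, {})
+        # Prefer a Thunderbolt neighbor; otherwise take any remaining device
+        nxt = next((d for d in remaining if d in tb), None) or next(iter(remaining))
+        order.append(nxt)
+        remaining.remove(nxt)
+        current = nxt
+    return order
+```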
+
+#### 4.3.2 Shard-Side CPAdapter
+
+```python
+# src/dnet/shard/adapters/context_parallel.py
+class CPAdapter(ShardAdapterBase):
+    """Context parallel adapter for shards."""
+
+    def __init__(
+        self,
+        runtime: ShardRuntime,
+        discovery: AsyncDnetP2P,
+        rank_id: int,
+        num_ranks: int,
+    ):
+        super().__init__(runtime, discovery)
+        self.rank_id = rank_id
+        self.num_ranks = num_ranks
+        self.ring_comm = CPRingCommunicator(rank_id, num_ranks, discovery)
+        self._algorithm: Literal["pass_kv", "pass_q"] = "pass_kv"
+
+    async def configure_topology(self, req: ShardLoadModelRequest) -> None:
+        """Configure CP topology from the load request."""
+        self.rank_id = req.cp_rank_id
+        self.num_ranks = req.cp_num_ranks
+        await self.ring_comm.connect_neighbors()
+
+    async def process_activation(self, msg: ActivationMessage) -> ActivationMessage:
+        """Process with context-parallel attention."""
+        # 1. Load-balanced shard to get this rank's local tokens
+        local_tokens, indices = load_balanced_shard(
+            msg.tokens, self.num_ranks, self.rank_id
+        )
+
+        # 2. Compute embeddings and projections locally
+        hidden = self.runtime.compute_embeddings(local_tokens)
+        q, k, v = self.runtime.compute_qkv(hidden)
+
+        # 3. Ring attention (select algorithm dynamically)
+        if self._algorithm == "pass_kv":
+            attn_out = await self.ring_pass_kv_attention(q, k, v)
+        else:
+            attn_out = await self.ring_pass_q_attention(q, k, v)
+
+        # 4. FFN + output projection (local compute)
+        output = self.runtime.compute_ffn(attn_out)
+
+        return msg.with_output(output, indices)
+```
+
+### 4.4 Configuration
+
+Following the existing pattern in `config.py`, we use `Literal`/enum types for constrained choices (which Pydantic validates) and integrate with the `.env.example` auto-generation via `scripts/generate_env_example.py`.
+
+```python
+# src/dnet/config.py (additions)
+from enum import StrEnum
+
+class CPAlgorithm(StrEnum):
+    """Ring attention algorithm selection."""
+    AUTO = "auto"        # Dynamic selection based on heuristics
+    PASS_KV = "pass_kv"  # Rotate KV blocks (best for prefill)
+    PASS_Q = "pass_q"    # Rotate Q blocks (best for decode)
+
+
+class ContextParallelSettings(BaseSettings):
+    """Context parallelism configuration."""
+
+    model_config = SettingsConfigDict(env_prefix="DNET_CP_")
+
+    enabled: bool = Field(
+        default=False,
+        description="Enable context parallelism mode",
+    )
+    algorithm: CPAlgorithm = Field(
+        default=CPAlgorithm.AUTO,
+        description="Ring attention algorithm (auto, pass_kv, pass_q)",
+    )
+    min_context_for_cp: int = Field(
+        default=32768,
+        description="Minimum context length to enable CP (below this, single-device)",
+    )
+    min_tokens_for_pass_kv: int = Field(
+        default=256,
+        description="Minimum new tokens to prefer pass_kv over pass_q",
+    )
+    chunk_overlap: int = Field(
+        default=0,
+        description="Overlap between chunks for sliding window attention",
+    )
+```
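+
+How the strategy layer is expected to consume these settings (a sketch; `get_settings` is the accessor already used in `src/cli/api.py`, and the gating rule follows the `min_context_for_cp` semantics above):
+
+```python
+from dnet.config import get_settings
+
+cp = get_settings().context_parallel
+
+def should_use_cp(context_len: int) -> bool:
+    # Fall back to single-device execution for short prompts, where
+    # ring-communication overhead outweighs the memory win
+    return cp.enabled and context_len >= cp.min_context_for_cp
+```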
+
+**`.env.example` Integration**:
+
+1. Add `ContextParallelSettings` to `generate_env_example.py`:
+
+```python
+# scripts/generate_env_example.py
+from dnet.config import ContextParallelSettings
+
+settings_sections = [
+    # ... existing ...
+    ("Context Parallelism", ContextParallelSettings),
+]
+```
+
+2. Run `make env-example` to regenerate `.env.example` with the CP settings:
+
+```bash
+# Generated output:
+# === Context Parallelism ===
+# Enable context parallelism mode
+DNET_CP_ENABLED=false
+# Ring attention algorithm (auto, pass_kv, pass_q)
+DNET_CP_ALGORITHM=auto
+# Minimum context length to enable CP (below this, single-device)
+DNET_CP_MIN_CONTEXT_FOR_CP=32768
+# Minimum new tokens to prefer pass_kv over pass_q
+DNET_CP_MIN_TOKENS_FOR_PASS_KV=256
+# Overlap between chunks for sliding window attention
+DNET_CP_CHUNK_OVERLAP=0
+```
+
+### 4.5 Protocol Changes
+
+#### Decision: Separate proto file vs. additions to existing
+
+| Approach                     | Pros                                                           | Cons                                                        |
+|------------------------------|----------------------------------------------------------------|-------------------------------------------------------------|
+| **Separate `dnet_cp.proto`** | Clean separation; easier to deprecate; independent versioning  | More generated files; cross-import needed for shared types  |
+| **Add to `dnet_ring.proto`** | Reuses existing types (`ActivationRequest`); fewer imports     | Couples CP to ring; larger proto file                       |
+
+**Recommendation**: Create `dnet_cp.proto` as a **separate file** because:
+
+1. CP and the pipeline ring are independent strategies—they shouldn't be coupled
+2. `KVBlock`/`QBlock` are CP-specific and don't belong in the ring transport
+3. Easier to iterate on CP without risk of breaking the existing ring protocol
+
+```protobuf
+// src/dnet/protos/dnet_cp.proto (NEW FILE)
+syntax = "proto3";
+package dnetcp;
+
+// Context Parallelism ring communication service
+service CPRingService {
+  // Bidirectional stream for KV/Q block transfer
+  rpc StreamBlocks(stream CPBlockFrame) returns (stream CPBlockAck);
+}
+
+// Configuration for CP distributed attention
+message CPConfig {
+  int32 rank_id = 1;
+  int32 num_ranks = 2;
+  repeated string rank_addresses = 3;  // Ordered ring addresses
+  string algorithm = 4;                // "pass_kv" or "pass_q"
+}
+
+// Frame for streaming KV or Q blocks
+message CPBlockFrame {
+  string nonce = 1;
+  int32 source_rank = 2;
+  int32 layer_id = 3;
+  oneof payload {
+    KVBlock kv_block = 4;
+    QBlock q_block = 5;
+  }
+  uint64 seq = 6;
+}
+
+message KVBlock {
+  bytes key_data = 1;
+  bytes value_data = 2;
+  bytes max_scores = 3;  // For merge attention
+  bytes log_sum_exp = 4;
+}
+
+message QBlock {
+  bytes query_data = 1;
+  repeated int32 token_indices = 2;  // For unsharding
+}
+
+message CPBlockAck {
+  string nonce = 1;
+  uint64 seq = 2;
+  bool accepted = 3;
+}
+```
+
+**Minor addition to `dnet_ring.proto`** (for CP-enabled requests):
+
+```protobuf
+// src/dnet/protos/dnet_ring.proto - add to ActivationRequest
+message ActivationRequest {
+  // ... existing fields 1-13 ...
+  optional CPConfig cp_config = 14;  // CP metadata (if CP mode)
+}
+```
+
+---
+
+## 5. Proposed Changes
+
+### 5.1 New Files
+
+| File                                           | Purpose                                    |
+|------------------------------------------------|--------------------------------------------|
+| `src/dnet/core/cp/__init__.py`                 | CP subpackage                              |
+| `src/dnet/core/cp/sharding.py`                 | Load-balanced sharding utilities           |
+| `src/dnet/core/cp/merge_attention.py`          | Merge attention operator                   |
+| `src/dnet/core/cp/ring_comm.py`                | Ring communication primitives              |
+| `src/dnet/core/cp/heuristics.py`               | Algorithm selection heuristics             |
+| `src/dnet/protos/dnet_cp.proto`                | CP ring service and block messages         |
+| `src/dnet/api/strategies/context_parallel.py`  | CPTopologySolver + ContextParallelStrategy |
+| `src/dnet/shard/adapters/context_parallel.py`  | CPAdapter                                  |
+| `tests/subsystems/test_cp_sharding.py`         | Sharding unit tests                        |
+| `tests/subsystems/test_cp_merge.py`            | Merge attention tests                      |
+| `tests/subsystems/test_cp_heuristics.py`       | Heuristic tests                            |
+
+### 5.2 Modified Files
+
+#### [MODIFY] `src/dnet/config.py`
+
+- Add `ContextParallelSettings` class
+- Add `context_parallel: ContextParallelSettings` to `DnetSettings`
+
+#### [MODIFY] `src/dnet/protos/dnet_ring.proto`
+
+- Add optional `cp_config` field to `ActivationRequest` (the `CPConfig`, `KVBlock`, and `QBlock` messages live in the new `dnet_cp.proto`)
+
+#### [MODIFY] `src/cli/api.py`
+
+- Add strategy selection based on config (RingStrategy vs ContextParallelStrategy)
+
+#### [MODIFY] `src/cli/shard.py`
+
+- Add adapter selection based on topology info
+
+#### [MODIFY] `src/dnet/shard/models.py`
+
+- Add `cp_rank_id`, `cp_num_ranks` to `ShardLoadModelRequest`
+
+---
+
+## 6. Implementation Phases
+
+### Phase 1: Core Infrastructure (2-3 days)
+
+1. Create `src/dnet/core/cp/` package
+2. Implement `sharding.py` with load-balanced partitioning
+3. Implement `merge_attention.py` with numerically stable merging
+4. Add unit tests for sharding and merging
+
+### Phase 2: Ring Communication (2-3 days)
+
+1. Implement `ring_comm.py` with gRPC send/recv
+2. Add protobuf messages for KV/Q block transfers
+3. Test ring formation with fake discovery
+
+### Phase 3: Ring Attention Variants (3-4 days)
+
+1. Implement the pass-KV algorithm in `CPAdapter`
+2. Implement the pass-Q algorithm with All2All
+3. Implement the adaptive heuristic
+4. Integration tests with 2+ simulated ranks
+
+### Phase 4: Strategy Integration (2-3 days)
+
+1. Implement the `ContextParallelStrategy` class
+2. Modify CLI entry points for strategy selection
+3. Add configuration options
+4. End-to-end test with a real multi-device setup
+
+### Phase 5: Verification & Optimization (2-3 days)
+
+1. Benchmark against the RingStrategy baseline
+2. Memory profiling for 128K+ contexts
+3. Documentation updates
+
+---
+
+## 7. Verification Plan
+
+### 7.1 Unit Tests
+
+**Sharding Tests** (`tests/subsystems/test_cp_sharding.py`):
+
+```bash
+uv run pytest tests/subsystems/test_cp_sharding.py -v
+```
+
+- Test that load-balanced partitioning produces equal-sized chunks
+- Test that the shard → unshard round trip preserves data
+- Test that chunk indices are correct for causal masking
+
+**Merge Attention Tests** (`tests/subsystems/test_cp_merge.py`):
+
+```bash
+uv run pytest tests/subsystems/test_cp_merge.py -v
+```
+
+- Test that merging 2 partial outputs matches full attention
+- Test numerical stability with extreme max scores
+- Test empty-partials handling
+
+**Heuristic Tests** (`tests/subsystems/test_cp_heuristics.py`):
+
+```bash
+uv run pytest tests/subsystems/test_cp_heuristics.py -v
+```
+
+- Test that pass-KV is selected for full prefill
+- Test that pass-Q is selected for decode
+- Test boundary conditions at the GQA threshold
+
+### 7.2 Integration Tests
+
+**Ring Communication** (`tests/integration/test_cp_ring.py`):
+
+```bash
+uv run pytest tests/integration/test_cp_ring.py -v
+```
+
+- Test a 4-rank ring with mock discovery
+- Test that simultaneous send/recv completes
+- Test graceful handling of rank failure
+
+### 7.3 CI Workflow for Coordinated Multi-Runner E2E Tests
+
+Since dnet has 2 self-hosted macOS runners (`mac2.metal`), we can design a workflow that **coordinates both runners** for CP e2e tests.
+
+**Approach**: Use a **hostfile + static discovery** pattern (similar to `test-static-discovery.yml`) where:
+
+1. Both runners register their IPs to a shared artifact
+2. One runner acts as API + Shard 1, the other as Shard 2
+3. The static hostfile enables cross-runner communication
+
+```yaml
+# .github/workflows/test-context-parallel.yml
+name: Test Context Parallelism E2E
+
+on:
+  workflow_dispatch:  # Manual trigger for expensive e2e tests
+  schedule:
+    - cron: '0 6 * * 1'  # Weekly on Monday 6AM UTC
+
+jobs:
+  # Job 1: Coordination - creates hostfile and waits for both runners
+  coordinate:
+    runs-on: ubuntu-latest
+    outputs:
+      hostfile: ${{ steps.gen.outputs.hostfile }}
+    steps:
+      - id: gen
+        run: echo "hostfile=will be generated dynamically" >> $GITHUB_OUTPUT
+
+  # Job 2: Runner A - API node + Shard 1 (CP rank 0)
+  runner-a:
+    runs-on: mac2.metal  # First self-hosted runner
+    needs: coordinate
+    env:
+      RUNNER_ROLE: shard1_and_api
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: recursive
+
+      - name: Setup Environment
+        uses: ./.github/actions/setup-env
+
+      - name: Get Runner IP
+        id: ip
+        run: echo "ip=$(ipconfig getifaddr en0 || echo 127.0.0.1)" >> $GITHUB_OUTPUT
+
+      - name: Write IP file for coordination
+        run: echo "${{ steps.ip.outputs.ip }}" > runner-a-ip.txt
+
+      - name: Upload IP for coordination
+        uses: actions/upload-artifact@v4
+        with:
+          name: runner-a-ip
+          path: runner-a-ip.txt
+
+      - name: Wait for Runner B IP
+        uses: actions/download-artifact@v4
+        with:
+          name: runner-b-ip
+          path: ./runner-b-ip
+        continue-on-error: true
+        timeout-minutes: 5
+
+      - name: Start Shard 1
+        run: |
+          uv run dnet-shard --http-port 8081 --grpc-port 58081 --shard-name cp-shard-0 &
+          sleep 5
+
+      - name: Create hostfile
+        run: |
+          echo "cp-shard-0 ${{ steps.ip.outputs.ip }} 8081 58081" > hostfile
+          cat ./runner-b-ip/runner-b-ip.txt >> hostfile 2>/dev/null || echo "# Runner B not ready" >> hostfile
+
+      - name: Start API with CP enabled
+        run: |
+          DNET_CP_ENABLED=true uv run dnet-api --http-port 8080 --grpc-port 58080 --hostfile hostfile &
+          sleep 10
+
+      - name: Run CP E2E test
+        run: |
+          uv run python scripts/test_cp_e2e.py --context-length 32768
+
+  # Job 3: Runner B - Shard 2 (CP rank 1)
+  runner-b:
+    runs-on: mac2.metal  # Second self-hosted runner (if labeled differently)
+    needs: coordinate
+    env:
+      RUNNER_ROLE: shard2
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          submodules: recursive
+
+      - name: Setup Environment
+        uses: ./.github/actions/setup-env
+
+      - name: Get Runner IP
+        id: ip
+        run: echo "ip=$(ipconfig getifaddr en0)" >> $GITHUB_OUTPUT
+
+      - name: Write IP file
+        run: echo "cp-shard-1 ${{ steps.ip.outputs.ip }} 8082 58082" > runner-b-ip.txt
+
+      - uses: actions/upload-artifact@v4
+        with:
+          name: runner-b-ip
+          path: runner-b-ip.txt
+
+      - name: Start Shard 2 and wait
+        run: |
+          uv run dnet-shard --http-port 8082 --grpc-port 58082 --shard-name cp-shard-1
+```
+
+> [!WARNING]
+> **Challenge**: GitHub Actions artifact uploads/downloads add latency. For reliable coordination, consider:
+>
+> 1. Using shared storage (S3/GCS) for the IP exchange
+> 2. Adding retry logic for artifact downloads
+> 3. Increasing timeouts for cross-runner synchronization
+
+### 7.4 Manual Verification (Local Development)
+
+**Single-machine test** (2 shards on localhost):
+
+```bash
+# Terminal 1: Shard 1
+uv run dnet-shard --http-port 8081 --grpc-port 58081 --shard-name cp-shard-0
+
+# Terminal 2: Shard 2
+uv run dnet-shard --http-port 8082 --grpc-port 58082 --shard-name cp-shard-1
+
+# Terminal 3: Create hostfile and start API
+echo "cp-shard-0 127.0.0.1 8081 58081" > hostfile
+echo "cp-shard-1 127.0.0.1 8082 58082" >> hostfile
+DNET_CP_ENABLED=true uv run dnet-api --http-port 8080 --grpc-port 58080 --hostfile hostfile
+
+# Terminal 4: Test
+curl -X POST http://localhost:8080/v1/prepare_topology \
+  -H "Content-Type: application/json" \
+  -d '{"model": "Qwen/Qwen3-4B-MLX-4bit", "strategy": "context_parallel"}'
+```
+
+**Cross-machine test** (2 Apple Silicon devices on the same network):
+
+1. Note the IPs of both machines (e.g., `192.168.1.10`, `192.168.1.11`)
+2. Start a shard on each machine with its respective IP
+3. Create a hostfile on the API machine with both shard entries
+4. Verify response coherence and memory distribution
+
+---
+
+## 8. Risks and Mitigations
+
+| Risk                                   | Mitigation                                                                   |
+|----------------------------------------|-------------------------------------------------------------------------------|
+| Thunderbolt bandwidth insufficient     | Profile actual bandwidth; fall back to pipeline if CP overhead is too high     |
+| Merge attention numerical instability  | Use log-space accumulation; add extensive numerical tests                      |
+| All2All latency for pass-Q             | Implement async All2All; consider hierarchical reduction                       |
+| Model too large for full replication   | CP requires the full model per device; document minimum memory requirements    |
+
+---
+
+## 9. Future Work
+
+1. **Hybrid CP + PP**: Combine context and pipeline parallelism for very large models with long contexts
+2. **Speculative Decoding**: Leverage CP for parallel draft generation
+3. **Persistent KV Cache**: Optimize multi-turn conversations with a sharded persistent cache
+4. **Training Support**: Extend CP to gradient computation
+
+---
+
+## 10. References
+
+1. Liu et al., "Ring Attention with Blockwise Transformers for Near-Infinite Context" (arXiv:2310.01889)
+2. Yang et al., "Context Parallelism for Scalable Million-Token Inference" (arXiv:2411.01783)
+3. [dnet Repository](https://github.com/firstbatchxyz/dnet)
diff --git a/scripts/cp_utils.py b/scripts/cp_utils.py
new file mode 100644
index 00000000..b8f817fc
--- /dev/null
+++ b/scripts/cp_utils.py
@@ -0,0 +1,120 @@
+"""
+Shared utilities for Context Parallelism scripts.
+
+Common functionality for prepare_cp_model.py and stress_test_cp.py.
+"""
+
+from functools import lru_cache
+from typing import Literal
+
+import requests
+from dnet_p2p import DnetDeviceProperties
+
+from dnet.api.models import ManualDevice
+from dnet.config import DnetSettings
+
+
+@lru_cache(maxsize=1)
+def _fetch_settings(api_url: str) -> DnetSettings | None:
+    """Fetch and cache settings from the API as typed DnetSettings."""
+    try:
+        response = requests.get(f"{api_url}/v1/settings", timeout=5)
+        if response.status_code == 200:
+            return DnetSettings.model_validate(response.json())
+    except Exception:
+        pass
+    return None
+
+
+def get_kv_bits_from_server(api_url: str) -> Literal["4bit", "8bit", "fp16"]:
+    """Get kv_bits from server settings via the API; default to 8bit."""
+    settings = _fetch_settings(api_url)
+    if settings:
+        mode = settings.kv_cache.mode
+        if mode in ("4bit", "8bit", "fp16"):
+            return mode  # type: ignore
+    return "8bit"
+
+
+def get_devices(api_url: str) -> dict[str, DnetDeviceProperties]:
+    """Fetch available devices from the API. Returns {instance: DnetDeviceProperties}."""
+    response = requests.get(f"{api_url}/v1/devices")
+    response.raise_for_status()
+    data = response.json()
+    devices_raw = data.get("devices", {})
+    return {
+        instance: DnetDeviceProperties(**props)
+        for instance, props in devices_raw.items()
+    }
+
+
+def get_shards(api_url: str) -> list[ManualDevice]:
+    """Get shard devices (non-managers) as a ManualDevice list."""
+    devices = get_devices(api_url)
+    shards = []
+    for instance, props in devices.items():
+        if props.is_manager:
+            continue
+        shards.append(
+            ManualDevice(
+                instance=instance,
+                local_ip=props.local_ip,
+                server_port=props.server_port,
+                shard_port=props.shard_port,
+            )
+        )
+    return shards
+
+
+def get_topology(api_url: str) -> dict | None:
+    """Fetch the current topology from the API. Returns None if not set."""
+    try:
+        response = requests.get(f"{api_url}/v1/topology")
+        if response.status_code == 200:
+            return response.json()
+    except requests.RequestException:
+        pass
+    return None
+
+
+def get_api_settings(api_url: str) -> DnetSettings | None:
+    """Fetch settings from the API as typed DnetSettings.
+
+    Note: Uses the cached _fetch_settings internally.
+    """
+    return _fetch_settings(api_url)
+
+
+def is_cp_enabled(api_url: str) -> bool:
+    """Check if context parallelism is enabled on the API server."""
+    settings = _fetch_settings(api_url)
+    if settings:
+        return settings.context_parallel.enabled
+    return False
+
+
+def get_recommended_test_sizes(num_shards: int) -> list[int]:
+    """Get recommended context sizes for CP testing based on shard count.
+
+    Based on the design doc memory table:
+    - Single device (24GB): ~32K comfortable, 128K tight
+    - 2 devices: can handle 128K+ distributed
+    - 4 devices: can handle 256K+ distributed
+
+    Returns context lengths that should stress-test CP properly.
+    """
+    if num_shards <= 1:
+        # Single device - test up to the comfortable limit
+        return [1000, 4000, 8000, 16000, 32000]
+    elif num_shards == 2:
+        # 2 shards - test beyond single-device capacity
+        return [8000, 16000, 32000, 48000, 64000, 96000]
+    else:
+        # 3+ shards - test long contexts
+        return [16000, 32000, 64000, 96000, 128000]
+
+
+# Context length thresholds (from the design doc)
+SINGLE_DEVICE_COMFORTABLE = 32000  # ~4GB KV cache
+SINGLE_DEVICE_TIGHT = 128000       # ~16GB KV cache
+CP_MIN_BENEFIT_THRESHOLD = 32000   # Below this, CP overhead may not be worth it
diff --git a/scripts/generate_env_example.py b/scripts/generate_env_example.py
index ea801506..fe32e44f 100644
--- a/scripts/generate_env_example.py
+++ b/scripts/generate_env_example.py
@@ -18,6 +18,7 @@ def main() -> int:
     from dnet.config import (
         ApiSettings,
         ComputeSettings,
+        ContextParallelSettings,
         GrpcSettings,
         KVCacheSettings,
         LoggingSettings,
@@ -46,6 +47,7 @@ def main() -> int:
         ("Transport", TransportSettings),
         ("Compute", ComputeSettings),
         ("KV Cache", KVCacheSettings),
+        ("Context Parallelism", ContextParallelSettings),
         ("gRPC", GrpcSettings),
         ("Storage", StorageSettings),
     ]
diff --git a/scripts/generate_protos.py b/scripts/generate_protos.py
index 42c617a8..8b80bd8e 100755
--- a/scripts/generate_protos.py
+++ b/scripts/generate_protos.py
@@ -3,6 +3,7 @@
 
 import glob
 import os
+import re
 from pathlib import Path
 
 from grpc_tools import protoc
@@ -37,6 +38,7 @@ def generate_protos() -> None:
         if ret != 0:
             raise RuntimeError(f"protoc failed for {proto_file}")
 
+        # Fix imports in grpc file
         pb2 = get_pb2_module_name(proto_file)
         grpc_file = f"{OUT_DIR}/{pb2}_grpc.py"
 
@@ -49,6 +51,22 @@ def generate_protos() -> None:
 
             print(f"Fixed imports in {grpc_file}")
 
+    # Fix cross-proto imports in all pb2 files
+    # (e.g., import dnet_cp_pb2 -> from . import dnet_cp_pb2)
+    for pb2_file in glob.glob(os.path.join(OUT_DIR, "*_pb2.py")):
+        with open(pb2_file, "r+") as f:
+            content = f.read()
+            # Match bare imports like "import foo_pb2 as foo__pb2"
+            # and convert them to relative imports
+            pattern = r"^import (\w+_pb2) as (\w+)$"
+            replacement = r"from . import \1 as \2"
+            new_content = re.sub(pattern, replacement, content, flags=re.MULTILINE)
+            if new_content != content:
+                f.seek(0)
+                f.write(new_content)
+                f.truncate()
+                print(f"Fixed cross-proto imports in {pb2_file}")
+
 
 if __name__ == "__main__":
     generate_protos()
diff --git a/scripts/needle_in_haystack.py b/scripts/needle_in_haystack.py
new file mode 100644
index 00000000..ee1dc3e5
--- /dev/null
+++ b/scripts/needle_in_haystack.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+"""
+Needle in a Haystack test for Context Parallelism validation.
+
+This test verifies that the model can attend to ALL positions in a long context,
+which is essential for validating that CP is working correctly.
+
+If CP is broken (ranks only see their own chunk), the model will fail to find the needle.
+This test works best with non-thinking models such as mlx-community/Llama-3.2-3B-Instruct-4bit.
+
+Usage:
+    uv run python scripts/needle_in_haystack.py --api http://localhost:8080 --context-size 4096
+"""
+
+import argparse
+import random
+import time
+
+import httpx
+
+# The "needle" - a specific fact we hide in the haystack
+NEEDLE_TEMPLATE = "The secret password is: {password}"
+
+# Filler text for the haystack (Paul Graham essay style)
+HAYSTACK_CHUNKS = [
+    "The most important thing in a startup is to launch quickly. "
" + "You can always iterate and improve later, but you need to get " + "something out there to learn from real users. ", + "Good ideas look like bad ideas at first. If they looked obviously " + "good, someone else would already be doing them. The trick is to " + "recognize the good ideas that look bad. ", + "Startups are about growth. A startup is a company designed to grow " + "fast. Being newly founded does not in itself make a company a " + "startup. Nor is it necessary for a startup to work on technology. ", + "The way to get startup ideas is not to try to think of startup " + "ideas. It's to look for problems, preferably problems you have " + "yourself. The very best startup ideas tend to have three things " + "in common: they're something the founders themselves want. ", + "Work on hard problems. If you're working on something that seems " + "really hard, you're probably working on something that matters. " + "Easy problems have already been solved. ", +] + + +def generate_password() -> str: + """Generate a random memorable password.""" + words = [ + "alpha", + "bravo", + "charlie", + "delta", + "echo", + "foxtrot", + "gamma", + "hotel", + "india", + "juliet", + "kilo", + "lima", + ] + return f"{random.choice(words)}-{random.randint(100, 999)}-{random.choice(words)}" + + +def generate_haystack(target_tokens: int, needle: str, needle_position: float) -> str: + """ + Generate a haystack of approximately target_tokens with needle at specified position. + + Args: + target_tokens: Approximate number of tokens for the haystack + needle: The needle text to hide + needle_position: Where to place needle (0.0 = start, 0.5 = middle, 1.0 = end) + + Returns: + Full haystack text with needle inserted + """ + # Rough estimate: 4 chars per token + target_chars = target_tokens * 4 + + # Build haystack chunks + haystack_parts = [] + current_chars = 0 + + while current_chars < target_chars: + chunk = random.choice(HAYSTACK_CHUNKS) + haystack_parts.append(chunk) + current_chars += len(chunk) + + # Determine needle insertion point + needle_idx = int(len(haystack_parts) * needle_position) + needle_idx = max(1, min(needle_idx, len(haystack_parts) - 1)) # Avoid edges + + # Insert needle + haystack_parts.insert(needle_idx, f"\n\n{needle}\n\n") + + return "".join(haystack_parts) + + +def run_needle_test( + api_url: str, + context_size: int, + needle_position: float, + timeout: float = 120.0, + model: str = "default", +) -> dict: + """ + Run a single needle in haystack test. + + Returns: + dict with test results including success, response, latency + """ + # Generate test case + password = generate_password() + needle = NEEDLE_TEMPLATE.format(password=password) + haystack = generate_haystack(context_size, needle, needle_position) + + # Build prompt + prompt = f"""Read the following document carefully. At some point, there is a secret password mentioned. + + +{haystack} + + +What is the secret password mentioned in the document above? 
+
+
+def run_full_test_suite(
+    api_url: str,
+    context_sizes: list[int],
+    timeout: float,
+    model: str = "default",
+) -> None:
+    """Run the full test suite across context sizes and needle positions."""
+    positions = [0.75]  # Needle at 75% depth; extend this list to test more depths
+
+    results = []
+
+    for ctx_size in context_sizes:
+        for pos in positions:
+            result = run_needle_test(api_url, ctx_size, pos, timeout, model=model)
+            result["target_context"] = ctx_size
+            results.append(result)
+
+    # Summary
+    print("\n" + "=" * 60)
+    print("SUMMARY")
+    print("=" * 60)
+
+    passed = sum(1 for r in results if r["success"])
+    total = len(results)
+
+    print(f"Passed: {passed}/{total}")
+
+    # Group by context size
+    by_size: dict[int, list[dict]] = {}
+    for r in results:
+        size = r.get("target_context", 0)
+        if size not in by_size:
+            by_size[size] = []
+        by_size[size].append(r)
+
+    for size in sorted(by_size.keys()):
+        size_results = by_size[size]
+        size_passed = sum(1 for r in size_results if r["success"])
+        avg_latency = sum(r["latency_s"] for r in size_results) / len(size_results)
+        print(
+            f"  {size:>6} tokens: {size_passed}/{len(size_results)} passed, avg {avg_latency:.1f}s"
+        )
+
+    # Overall verdict
+    if passed == total:
+        print("\n✓ ALL TESTS PASSED - CP is working correctly!")
+    elif passed > total // 2:
+        print("\n⚠ PARTIAL PASS - Some positions may have issues")
+    else:
+        print("\n✗ TESTS FAILED - CP may not be attending to the full context")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Needle in a Haystack test for CP validation"
+    )
+    parser.add_argument("--api", default="http://localhost:8080", help="API server URL")
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=None,
+        help="Single context size to test (default: run the full suite)",
+    )
+    parser.add_argument(
+        "--sizes",
+        default="512,1024,2048,4096,8192,16384,32768",
+        help="Comma-separated context sizes for the full suite",
+    )
+    parser.add_argument(
+        "--position",
+        type=float,
+        default=0.5,
+        help="Needle position (0.0-1.0) for a single test",
+    )
+    parser.add_argument(
+        "--timeout", type=float, default=300.0, help="Request timeout in seconds"
+    )
+    parser.add_argument(
+        "--model",
+        default="default",
+        help="Model name to use for requests",
+    )
+
+    args = parser.parse_args()
+
+    if args.context_size:
+        # Single test
+        run_needle_test(
+            args.api, args.context_size, args.position, args.timeout, model=args.model
+        )
+    else:
+        # Full suite
+        sizes = [int(s.strip()) for s in args.sizes.split(",")]
+        run_full_test_suite(args.api, sizes, args.timeout, model=args.model)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/prepare_cp_model.py b/scripts/prepare_cp_model.py
new file mode 100644
index 00000000..864a9874
--- /dev/null
+++ b/scripts/prepare_cp_model.py
@@ -0,0 +1,258 @@
+#!/usr/bin/env python3
+"""
+Prepare and load a model for Context Parallelism (CP).
+
+Unlike ring/pipeline parallelism, where each shard gets non-overlapping layers,
+CP loads ALL layers on ALL shards. Each shard processes a portion of the
+context window (sequence dimension) while holding the full model.
+
+Usage:
+    uv run scripts/prepare_cp_model.py Qwen/Qwen3-4B-MLX-4bit
+    uv run scripts/prepare_cp_model.py Qwen/Qwen3-4B-MLX-4bit --shards m4s1,m4s2
+
+The ModelManager will automatically assign CP ranks based on device order:
+    - rank 0: first device in the list
+    - rank 1: second device in the list
+    - etc.
+
+For two-device CP, each device handles half the context window.
+"""
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Literal
+
+# Add project root to sys.path to allow imports from the scripts package
+sys.path.append(str(Path(__file__).parent.parent))
+
+import requests
+
+from dnet.api.models import ManualDevice, PrepareTopologyManualRequest
+from dnet.core.types.topology import LayerAssignment
+
+from scripts.cp_utils import get_kv_bits_from_server, get_devices
+
+
+def get_model_config(model: str) -> dict:
+    """Fetch the model config from HuggingFace to get num_layers."""
+    try:
+        from huggingface_hub import hf_hub_download
+
+        local_path = hf_hub_download(
+            repo_id=model,
+            filename="config.json",
+        )
+        with open(local_path) as f:
+            return json.load(f)
+    except Exception as e:
+        print(f"Warning: Could not fetch model config from HuggingFace: {e}")
+        return {}
+
+
+def prepare_cp_topology(
+    api_url: str,
+    model: str,
+    devices: list[ManualDevice],
+    num_layers: int,
+    seq_len: int,
+    kv_bits: Literal["4bit", "8bit", "fp16"],
+) -> dict:
+    """Prepare a manual topology for CP mode (all shards get all layers)."""
+    all_layers = list(range(num_layers))
+
+    # For CP, each device gets ALL layers (full model replication)
+    assignments: list[LayerAssignment] = []
+    for i, device in enumerate(devices):
+        next_idx = (i + 1) % len(devices)
+        next_instance = devices[next_idx].instance
+
+        assignments.append(
+            LayerAssignment(
+                instance=device.instance,
+                layers=[all_layers],
+                window_size=num_layers,
+                residency_size=num_layers,
+                next_instance=next_instance,
+            )
+        )
+
+    request = PrepareTopologyManualRequest(
+        model=model,
+        devices=devices,
+        assignments=assignments,
+        num_layers=num_layers,
+        max_position_embeddings=seq_len,
+        kv_bits=kv_bits,
+    )
+
+    response = requests.post(
+        f"{api_url}/v1/prepare_topology_manual",
+        json=request.model_dump(),
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def load_model(api_url: str, model: str) -> dict:
+    """Load the model on all shards."""
+    response = requests.post(f"{api_url}/v1/load_model", json={"model": model})
+    response.raise_for_status()
+    return response.json()
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Prepare and load a model for Context Parallelism",
+        epilog="""
+
+Examples:
+    # Auto-discover all shards and use them for CP
+    uv run scripts/prepare_cp_model.py Qwen/Qwen3-4B-MLX-4bit
+
+    # Use specific shards for CP
+    uv run scripts/prepare_cp_model.py Qwen/Qwen3-4B-MLX-4bit --shards m4s1,m4s2
+
+    # Use a custom API URL
+    uv run scripts/prepare_cp_model.py Qwen/Qwen3-4B-MLX-4bit --api http://10.0.0.1:8080
+        """,
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    parser.add_argument(
+        "model",
+        type=str,
+        help="Model name or HuggingFace repo ID (e.g., Qwen/Qwen3-4B-MLX-4bit)",
+    )
+    parser.add_argument(
+        "--api",
+        type=str,
+        default="http://localhost:8080",
+        help="API server URL (default: http://localhost:8080)",
+    )
+    parser.add_argument(
+        "--shards",
+        type=str,
+        default=None,
+        help="Comma-separated shard instance names (default: all available)",
+    )
+    parser.add_argument(
+        "--seq-len",
+        type=int,
+        default=None,
+        help="Sequence length (default: from model config, or 8192)",
+    )
+    args = parser.parse_args()
+
+    api_url = args.api.rstrip("/")
+
+    # Get kv_bits from server settings
+    kv_bits = get_kv_bits_from_server(api_url)
+
+    # Step 1: Discover devices
+    print(f"[1/4] Fetching available devices from {api_url}...")
+    try:
+        devices_dict = get_devices(api_url)
+    except requests.RequestException as e:
+        print(f"Error: Could not connect to API at {api_url}: {e}")
+        sys.exit(1)
+
+    # Build a typed ManualDevice list, filtering out managers
+    all_devices: list[ManualDevice] = []
+    for instance, props in devices_dict.items():
+        if props.is_manager:
+            continue
+        all_devices.append(
+            ManualDevice(
+                instance=instance,
+                local_ip=props.local_ip,
+                server_port=props.server_port,
+                shard_port=props.shard_port,
+            )
+        )
+
+    if not all_devices:
+        print("Error: No shards available. Make sure shard nodes are running.")
+        sys.exit(1)
+
+    # Filter by requested shards if specified
+    shards = all_devices
+    if args.shards:
+        requested = set(args.shards.split(","))
+        shards = [d for d in all_devices if d.instance in requested]
+        if not shards:
+            print(f"Error: None of the requested shards found: {args.shards}")
+            print(f"Available: {[d.instance for d in all_devices]}")
+            sys.exit(1)
+
+    print(f"  Using {len(shards)} shard(s) for Context Parallelism:")
+    for i, s in enumerate(shards):
+        print(f"    [{i}] {s.instance} ({s.local_ip}:{s.server_port})")
+
+    # Step 2: Get the model config
+    print(f"[2/4] Fetching model config for {args.model}...")
+    model_config = get_model_config(args.model)
+
+    num_layers = model_config.get("num_hidden_layers") or model_config.get("n_layers")
+    if not num_layers:
+        print("Error: Could not determine the number of layers from the model config.")
+        sys.exit(1)
+
+    print(f"  Model has {num_layers} layers (full model on each shard)")
+
+    seq_len = args.seq_len
+    if seq_len is None:
+        seq_len = model_config.get("max_position_embeddings") or 8192
+    print(f"  Sequence length: {seq_len}")
+
+    # Step 3: Prepare the topology
+    print("[3/4] Preparing CP topology...")
+    try:
+        topology = prepare_cp_topology(
+            api_url=api_url,
+            model=args.model,
+            devices=shards,
+            num_layers=num_layers,
+            seq_len=seq_len,
+            kv_bits=kv_bits,
+        )
+        print("  Topology prepared successfully")
+        print(f"  Model: {topology.get('model')}")
+        assignments = topology.get("assignments", [])
+        print(f"  Devices: {[a.get('instance') for a in assignments]}")
+    except requests.RequestException as e:
+        print(f"Error: Failed to prepare topology: {e}")
+        sys.exit(1)
+
+    # Step 4: Load the model
+    print("[4/4] Loading model on all shards (this may take a while)...")
+    try:
+        result = load_model(api_url, args.model)
+        print("  Model loaded successfully!")
+        print()
+        print("=" * 60)
+        print("Context Parallelism Ready")
+        print("=" * 60)
+        print(f"  Model:    {args.model}")
+        print(f"  CP Ranks: {len(shards)}")
+        print(f"  Shards:   {', '.join(s.instance for s in shards)}")
+        print(f"  KV Bits:  {kv_bits}")
+        print(f"  Seq Len:  {seq_len}")
+        print()
+        print(f"Each shard has the full model and will process 1/{len(shards)} of")
+        print("the context window during inference.")
+        print()
+
+        for status in result.get("shard_statuses", []):
+            success = "✓" if status.get("success") else "✗"
+            print(
+                f"  {success} {status.get('instance')}: {status.get('message', 'OK')}"
+            )
+
+    except requests.RequestException as e:
+        print(f"Error: Failed to load model: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/stress_test_cp.py b/scripts/stress_test_cp.py
new file mode 100644
index 00000000..12f1239b
--- /dev/null
+++ b/scripts/stress_test_cp.py
@@ -0,0 +1,377 @@
+#!/usr/bin/env python3
+"""
+Stress test for Context Parallelism via the chat completions endpoint.
+
+Sends requests with varying prompt lengths to test CP's ability to handle
+long contexts distributed across shards.
+
+Usage:
+    uv run scripts/stress_test_cp.py
+    uv run scripts/stress_test_cp.py --api http://10.0.0.1:8080 --max-tokens 1000
+"""
+
+import argparse
+import sys
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+# Add project root to sys.path to allow imports from the scripts package
+sys.path.append(str(Path(__file__).parent.parent))
+
+import requests
+
+from dnet.api.models import ChatMessage, ChatRequestModel, ChatResponseModel
+
+from scripts.cp_utils import (
+    get_shards,
+    get_topology,
+    get_recommended_test_sizes,
+    is_cp_enabled,
+)
+
+
+@dataclass
+class TestResult:
+    """Result of a single stress test run."""
+
+    context_length: int
+    prompt_chars: int
+    success: bool
+    total_time_s: float
+    time_to_first_token_s: Optional[float] = None
+    num_chunks: Optional[int] = None
+    response: Optional[ChatResponseModel] = None
+    error: Optional[str] = None
+    stream: bool = False
+
+
+def generate_long_prompt(target_tokens: int) -> str:
+    """Generate a prompt of approximately target_tokens length.
+
+    Uses repetitive text to reach the target length. Rough estimate: 1 token ≈ 4 chars.
+    """
+    base_text = (
+        "The quick brown fox jumps over the lazy dog. "
+        "Pack my box with five dozen liquor jugs. "
+        "How vexingly quick daft zebras jump. "
+    )
+    target_chars = target_tokens * 4
+    repetitions = max(1, target_chars // len(base_text))
+    return base_text * repetitions
+
+
+def run_chat_request(
+    api_url: str,
+    prompt: str,
+    context_length: int,
+    max_tokens: int = 50,
+    stream: bool = False,
+    timeout: int = 3600,  # 60 min for long contexts (64K+ tokens)
+) -> TestResult:
+    """Send a chat completion request and return a typed TestResult."""
+    request = ChatRequestModel(
+        model="default",
+        messages=[
+            ChatMessage(role="user", content=prompt),
+        ],
+        max_tokens=max_tokens,
+        stream=stream,
+        temperature=0.7,
+    )
+
+    prompt_chars = len(prompt)
+    start_time = time.time()
+
+    if stream:
+        try:
+            response = requests.post(
+                f"{api_url}/v1/chat/completions",
+                json=request.model_dump(),
+                stream=True,
+                timeout=timeout,
+            )
+            if not response.ok:
+                return TestResult(
+                    context_length=context_length,
+                    prompt_chars=prompt_chars,
+                    success=False,
+                    total_time_s=time.time() - start_time,
+                    error=f"{response.status_code} {response.reason}: {response.text}",
+                    stream=True,
+                )
+
+            chunks = []
+            first_token_time: Optional[float] = None
+            for line in response.iter_lines():
+                if line:
+                    decoded = line.decode("utf-8")
+                    if decoded.startswith("data: ") and decoded != "data: [DONE]":
+                        if first_token_time is None:
+                            first_token_time = time.time()
+                        chunks.append(decoded[6:])
+
+            end_time = time.time()
+            return TestResult(
+                context_length=context_length,
+                prompt_chars=prompt_chars,
+                success=True,
+                total_time_s=end_time - start_time,
+                time_to_first_token_s=(first_token_time - start_time)
+                if first_token_time
+                else None,
+                num_chunks=len(chunks),
+                stream=True,
+            )
+        except requests.RequestException as e:
+            return TestResult(
+                context_length=context_length,
+                prompt_chars=prompt_chars,
+                success=False,
+                total_time_s=time.time() - start_time,
+                error=str(e),
+                stream=True,
+            )
+
+    else:
+        try:
+            response = requests.post(
+                f"{api_url}/v1/chat/completions",
+                json=request.model_dump(),
+                timeout=timeout,
+            )
+            if not response.ok:
+                return TestResult(
+                    context_length=context_length,
+                    prompt_chars=prompt_chars,
+                    success=False,
+                    total_time_s=time.time() - start_time,
+                    error=f"{response.status_code} {response.reason}: {response.text}",
+                    stream=False,
+                )
+ + end_time = time.time() + chat_response = ChatResponseModel.model_validate(response.json()) + return TestResult( + context_length=context_length, + prompt_chars=prompt_chars, + success=True, + total_time_s=end_time - start_time, + response=chat_response, + stream=False, + ) + except requests.RequestException as e: + return TestResult( + context_length=context_length, + prompt_chars=prompt_chars, + success=False, + total_time_s=time.time() - start_time, + error=str(e), + stream=False, + ) + + +def run_stress_test( + api_url: str, + context_lengths: list[int], + max_tokens: int, + stream: bool, + verbose: bool, +) -> list[TestResult]: + """Run stress tests with varying context lengths.""" + results: list[TestResult] = [] + + for ctx_len in context_lengths: + print(f"\n[Test] Context length: ~{ctx_len:,} tokens") + prompt = generate_long_prompt(ctx_len) + actual_chars = len(prompt) + print(f" Prompt: {actual_chars:,} chars (~{actual_chars // 4:,} tokens)") + + try: + result = run_chat_request( + api_url=api_url, + prompt=prompt, + context_length=ctx_len, + max_tokens=max_tokens, + stream=stream, + ) + results.append(result) + + if result.success: + print(f" ✓ Success in {result.total_time_s:.2f}s") + if stream and result.time_to_first_token_s: + print( + f" Time to first token: {result.time_to_first_token_s:.2f}s" + ) + if verbose and not stream and result.response: + resp = result.response + if resp.choices: + msg = resp.choices[0].message + content = msg.content if msg else "" + print(f" Response: {content[:100]}...") + if resp.usage: + print( + f" Tokens: prompt={resp.usage.prompt_tokens}, completion={resp.usage.completion_tokens}" + ) + except requests.RequestException as e: + print(f" ✗ Failed: {e}") + results.append( + TestResult( + context_length=ctx_len, + prompt_chars=len(prompt), + success=False, + total_time_s=0.0, + error=str(e), + ) + ) + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Stress test Context Parallelism via chat endpoint", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--api", + type=str, + default="http://localhost:8080", + help="API server URL (default: http://localhost:8080)", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Max tokens to generate (default: 100)", + ) + parser.add_argument( + "--stream", + action="store_true", + help="Use streaming responses", + ) + parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="Show response content", + ) + parser.add_argument( + "--quick", + action="store_true", + help="Quick test with small context lengths only", + ) + parser.add_argument( + "--sizes", + type=str, + default=None, + help="Comma-separated context sizes to test (default: auto based on shard count)", + ) + args = parser.parse_args() + + api_url = args.api.rstrip("/") + + print("=" * 60) + print("Context Parallelism Stress Test") + print("=" * 60) + print(f"API: {api_url}") + print(f"Max tokens: {args.max_tokens}") + print(f"Streaming: {args.stream}") + + # Get shard count for test size recommendations + print("\n[Check] Detecting shards...") + try: + shards = get_shards(api_url) + num_shards = len(shards) + print(f" Found {num_shards} shard(s):") + for s in shards: + print(f" - {s.instance} ({s.local_ip}:{s.server_port})") + except requests.RequestException as e: + print(f" Warning: Could not fetch shards: {e}") + num_shards = 1 + + # Verify model is loaded + print("\n[Check] Verifying model is loaded...") + topo = 
get_topology(api_url) + if topo: + print(f" Model: {topo.get('model', 'unknown')}") + else: + print(" Warning: Could not fetch topology") + + # Check if CP is enabled + print("\n[Check] Checking CP settings...") + cp_enabled = is_cp_enabled(api_url) + if cp_enabled: + print(" ✓ Context Parallelism is ENABLED") + else: + print(" ⚠ Context Parallelism is DISABLED (DNET_CP_ENABLED=false)") + print(" Tests will run in single-device mode") + + # Determine test context lengths + if args.sizes: + context_lengths = [int(s.strip()) for s in args.sizes.split(",")] + elif args.quick: + context_lengths = [100, 500, 1000] + else: + context_lengths = get_recommended_test_sizes(num_shards) + + print(f"\nTest sizes: {context_lengths}") + if num_shards > 1: + print( + f"(Recommended for {num_shards} shards - includes sizes that benefit from CP)" + ) + + # Run tests + results = run_stress_test( + api_url=api_url, + context_lengths=context_lengths, + max_tokens=args.max_tokens, + stream=args.stream, + verbose=args.verbose, + ) + + # Summary + print("\n" + "=" * 60) + print("Summary") + print("=" * 60) + + successful = [r for r in results if r.success] + failed = [r for r in results if not r.success] + + print(f"Tests passed: {len(successful)}/{len(results)}") + print(f"Shards used: {num_shards}") + + if successful: + times = [r.total_time_s for r in successful] + print(f"Avg time: {sum(times) / len(times):.2f}s") + print(f"Max time: {max(times):.2f}s") + + print("\nDetails:") + print(f"{'Context':<10} {'Time':<10} {'TTFT':<10} {'Tokens/s':<10}") + print("-" * 45) + for r in successful: + tokens_per_sec = "" + if r.response and r.response.usage: + total_tokens = ( + r.response.usage.prompt_tokens + r.response.usage.completion_tokens + ) + tps = total_tokens / r.total_time_s + tokens_per_sec = f"{tps:.1f}" + + ttft = f"{r.time_to_first_token_s:.2f}s" if r.time_to_first_token_s else "-" + print( + f"{r.context_length:<10} {r.total_time_s:<10.2f} {ttft:<10} {tokens_per_sec:<10}" + ) + + if failed: + print("\nFailed tests:") + for r in failed: + err = r.error or "unknown error" + print(f" - {r.context_length:,} tokens: {err}") + + sys.exit(0 if not failed else 1) + + +if __name__ == "__main__": + main() diff --git a/src/cli/api.py b/src/cli/api.py index 93ec9ae3..55f4870e 100644 --- a/src/cli/api.py +++ b/src/cli/api.py @@ -59,10 +59,19 @@ def _signal_handler(*_: object) -> None: discovery.create_instance(node_id, http_port, grpc_port, is_manager=True) await discovery.async_start() - # Components + # Components - select strategy based on config + from dnet.config import get_settings + from dnet.api.strategies.base import Strategy from dnet.api.strategies.ring import RingStrategy + from dnet.api.strategies.context_parallel import ContextParallelStrategy - strategy = RingStrategy() # ContextParallelStrategy() + settings = get_settings() + strategy: Strategy + if settings.context_parallel.enabled: + logger.info("Context parallelism enabled - using ContextParallelStrategy") + strategy = ContextParallelStrategy() + else: + strategy = RingStrategy() def update_tui_model_info( model_name: Optional[str], layers: int, loaded: bool diff --git a/src/cli/shard.py b/src/cli/shard.py index d0aa1be7..2a376489 100644 --- a/src/cli/shard.py +++ b/src/cli/shard.py @@ -34,11 +34,29 @@ async def serve( discovery = AsyncDnetP2P("lib/dnet-p2p/lib") # Core - use instance_name for runtime to align logs/metrics with discovery name runtime = ShardRuntime(shard_id=instance_name, queue_size=queue_size) - adapter = 
RingAdapter(runtime=runtime, discovery=discovery) + + # Select adapter based on CP config + from dnet.config import get_settings + from dnet.shard.adapters.base import TopologyAdapter + + settings = get_settings() + adapter: TopologyAdapter + if settings.context_parallel.enabled: + from dnet.shard.adapters.context_parallel import CPAdapter + + logger.info("Context parallelism enabled - using CPAdapter") + # Initial defaults; actual rank/logic will be configured by API via LoadModel + adapter = CPAdapter( + runtime=runtime, discovery=discovery, rank_id=0, num_ranks=1 + ) + else: + adapter = RingAdapter(runtime=runtime, discovery=discovery) + shard = Shard(shard_id=shard_id, adapter=adapter) # Servers grpc_server = ShardGrpcServer(shard=shard, grpc_port=grpc_port) + shard.grpc_server = grpc_server # For CP servicer wiring http_server = ShardHTTPServer( shard=shard, http_port=http_port, grpc_port=grpc_port, discovery=discovery ) diff --git a/src/dnet/api/grpc_servicer/server.py b/src/dnet/api/grpc_servicer/server.py index c7b3d8a5..ca224659 100644 --- a/src/dnet/api/grpc_servicer/server.py +++ b/src/dnet/api/grpc_servicer/server.py @@ -4,6 +4,7 @@ from grpc import aio as aio_grpc from dnet.utils.logger import logger +from dnet.utils.grpc_config import GRPC_AIO_OPTIONS from .servicer import ShardApiServicer from ..inference import InferenceManager from dnet.protos.shard_api_comm_pb2_grpc import add_ShardApiServiceServicer_to_server @@ -17,7 +18,7 @@ def __init__(self, grpc_port: int, inference_manager: InferenceManager) -> None: self.servicer = ShardApiServicer(self.inference_manager) async def start(self) -> None: - self.server = aio_grpc.server() + self.server = aio_grpc.server(options=GRPC_AIO_OPTIONS) add_ShardApiServiceServicer_to_server(self.servicer, self.server) listen_addr = f"[::]:{self.grpc_port}" self.server.add_insecure_port(listen_addr) diff --git a/src/dnet/api/http_api.py b/src/dnet/api/http_api.py index 1035d00f..c6d6ab60 100644 --- a/src/dnet/api/http_api.py +++ b/src/dnet/api/http_api.py @@ -91,6 +91,7 @@ async def _setup_routes(self) -> None: methods=["POST"], ) self.app.add_api_route("/v1/devices", self.get_devices, methods=["GET"]) + self.app.add_api_route("/v1/settings", self.get_settings, methods=["GET"]) async def health(self) -> HealthResponse: return HealthResponse( @@ -201,10 +202,26 @@ async def load_model(self, req: APILoadModelRequest) -> APILoadModelResponse: api_callback_address=api_callback_addr, ) if response.success: - first_shard = topology.devices[0] - await self.inference_manager.connect_to_ring( - first_shard.local_ip, first_shard.shard_port, api_callback_addr - ) + # Connect inference manager to shard(s) + # For CP with multiple devices, connect to all ranks + from dnet.api.strategies.context_parallel import CPApiAdapter + + if ( + isinstance(self.inference_manager.adapter, CPApiAdapter) + and len(topology.devices) > 1 + ): + rank_addresses = [ + f"{d.local_ip}:{d.shard_port}" for d in topology.devices + ] + await self.inference_manager.connect_to_cp_ranks( + rank_addresses, api_callback_addr + ) + else: + # Standard ring or single device + first_shard = topology.devices[0] + await self.inference_manager.connect_to_ring( + first_shard.local_ip, first_shard.shard_port, api_callback_addr + ) return response except Exception as e: @@ -240,6 +257,13 @@ async def get_devices(self) -> JSONResponse: } return JSONResponse(content={"devices": devices_dict}) + async def get_settings(self) -> JSONResponse: + """Return current dnet settings (all settings dumped for 
easy deserialization).""" + from dnet.config import get_settings + + settings = get_settings() + return JSONResponse(content=settings.model_dump()) + async def get_topology(self) -> TopologyInfo: topo = self.cluster_manager.current_topology if topo is None: @@ -387,6 +411,7 @@ async def prepare_topology_manual( model=req.model, kv_bits=req.kv_bits, num_layers=int(num_layers), + max_position_embeddings=req.max_position_embeddings, devices=devices_props, assignments=norm, solution=None, diff --git a/src/dnet/api/inference.py b/src/dnet/api/inference.py index d84cd5a9..25aaa431 100644 --- a/src/dnet/api/inference.py +++ b/src/dnet/api/inference.py @@ -19,6 +19,7 @@ from .model_manager import ModelManager from .strategies.base import ApiAdapterBase from dnet.core.decoding.config import DecodingConfig +from dnet.utils.logger import logger async def arange(count: int): @@ -63,6 +64,30 @@ async def connect_to_ring( await self.adapter.connect_first_shard(first_shard_ip, first_shard_port) self._api_callback_addr = api_callback_addr + async def connect_to_cp_ranks( + self, rank_addresses: list[str], api_callback_addr: str + ) -> None: + """ + Connect to all CP ranks for multi-rank broadcasting. + + Args: + rank_addresses: List of "host:port" strings for each rank. + api_callback_addr: Callback address for shards to send tokens. + """ + from dnet.api.strategies.context_parallel import CPApiAdapter + + if isinstance(self.adapter, CPApiAdapter) and len(rank_addresses) > 1: + await self.adapter.connect_all_ranks(rank_addresses) + logger.info("Connected to %d CP ranks", len(rank_addresses)) + else: + # Fallback to single shard connection + if rank_addresses: + parts = rank_addresses[0].split(":") + ip, port = parts[0], int(parts[1]) + await self.adapter.connect_first_shard(ip, port) + + self._api_callback_addr = api_callback_addr + async def generate_stream(self, req: ChatRequestModel): """ Generator for chat completion chunks. @@ -154,6 +179,12 @@ async def generate_stream(self, req: ChatRequestModel): else 1, ) + # RoPE offset: for prefill, start at 0. For decode, offset by prompt + generated tokens. + # During first iteration, we're processing the prompt from position 0. + # During subsequent iterations, we're adding tokens at position prompt_len + token_idx. + is_prefill = len(tokens) == 0 + rope_start_pos = 0 if is_prefill else len(prompt_tokens) + len(tokens) - 1 + # Send tokens to first shard await self.adapter.send_tokens( tokens=tok_bytes, @@ -162,8 +193,9 @@ async def generate_stream(self, req: ChatRequestModel): logprobs=req.logprobs if req.logprobs else False, top_logprobs=req.top_logprobs if req.top_logprobs else 0, decoding_config=decoding_config, + start_pos=rope_start_pos, ) - result = await self.adapter.await_token(nonce, timeout_s=300.0) + result = await self.adapter.await_token(nonce, timeout_s=3600.0) token = int(result.token_id) # Accumulate logprobs diff --git a/src/dnet/api/model_manager.py b/src/dnet/api/model_manager.py index 609642b4..f9021c92 100644 --- a/src/dnet/api/model_manager.py +++ b/src/dnet/api/model_manager.py @@ -114,12 +114,30 @@ async def load_model( try: # Build API callback address (gRPC). # For internet setups, allow explicit override to avoid advertising 127.0.0.1. 
- cb_addr = ( + param_api_callback_addr = ( api_callback_address if api_callback_address else f"{api_properties.local_ip}:{grpc_port}" ) + # Calculate Context Parallelism config + # Device list in topology is strictly ordered by ring position + cp_rank_addresses = [ + f"{d.local_ip}:{d.shard_port}" for d in topology.devices + ] + cp_num_ranks = len(cp_rank_addresses) + # Find rank for current instance + try: + # Iterate to find index where instance matches + cp_rank_id = next( + i + for i, d in enumerate(topology.devices) + if d.instance == instance + ) + except StopIteration: + # Should not happen if topology is consistent + cp_rank_id = 0 + # Call load_model via HTTP (window_size unified) url = f"http://{shard_props.local_ip}:{shard_props.server_port}/load_model" @@ -132,7 +150,12 @@ async def load_model( residency_size=assignment.residency_size, total_layers=topology.num_layers, kv_bits=topology.kv_bits, - api_callback_address=cb_addr, + api_callback_address=param_api_callback_addr, + # Context Parallel fields + cp_rank_id=cp_rank_id, + cp_num_ranks=cp_num_ranks, + cp_rank_addresses=cp_rank_addresses, + max_position_embeddings=topology.max_position_embeddings, ).model_dump() # timeout is `None` because shards may actually be downloading weights diff --git a/src/dnet/api/models.py b/src/dnet/api/models.py index e27681ed..29978650 100644 --- a/src/dnet/api/models.py +++ b/src/dnet/api/models.py @@ -4,7 +4,8 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union, Literal from fastapi.responses import JSONResponse -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field + from dnet.core.types.topology import LayerAssignment @@ -108,13 +109,6 @@ def __init__(self, **data: Any): if isinstance(self.stop, str): self.stop = [self.stop] - @field_validator("logprobs") - def non_negative_tokens(cls, v: Any) -> Any: - """Validate logprobs parameter.""" - if v != -1 and not (0 < v <= 10): - raise ValueError(f"logprobs must be between 1 and 10 but got {v:,}") - return v - class ChatUsage(BaseModel): prompt_tokens: int @@ -351,6 +345,10 @@ class PrepareTopologyManualRequest(BaseModel): default=None, description="Total number of layers (optional; inferred if missing)", ) + max_position_embeddings: Optional[int] = Field( + default=None, + description="Override model context length limit (e.g. for RoPE scaling)", + ) class APILoadModelRequest(BaseModel): diff --git a/src/dnet/api/strategies/base.py b/src/dnet/api/strategies/base.py index c502fc96..a2289f7b 100644 --- a/src/dnet/api/strategies/base.py +++ b/src/dnet/api/strategies/base.py @@ -31,6 +31,7 @@ async def send_tokens( logprobs: bool = False, top_logprobs: int = 0, decoding_config: Any = None, # DecodingConfig + start_pos: int = 0, ) -> None: ... @abstractmethod diff --git a/src/dnet/api/strategies/context_parallel.py b/src/dnet/api/strategies/context_parallel.py new file mode 100644 index 00000000..f881fc2c --- /dev/null +++ b/src/dnet/api/strategies/context_parallel.py @@ -0,0 +1,555 @@ +"""Context Parallel strategy for API server. 
+ +This module provides the ContextParallelStrategy which bundles: +- CPTopologySolver: Assigns all layers to all devices (full replication) +- CPApiAdapter: Handles token injection for CP mode +""" + +from __future__ import annotations + +import asyncio +from typing import Dict, Optional, Any, Literal, List + +from grpc import aio as aio_grpc +from dnet_p2p import DnetDeviceProperties, ThunderboltConnection +from distilp.common import DeviceProfile + +from dnet.utils.logger import logger +from dnet.core.stream_manager import StreamManager +from dnet.core.types.messages import TokenResult +from dnet.core.types.topology import TopologyInfo, LayerAssignment +from dnet.core.topology import TopologySolver +from dnet.protos import dnet_ring_pb2 as pb2 +from dnet.protos.dnet_ring_pb2_grpc import DnetRingServiceStub +from dnet.utils.time import utc_epoch_now +from dnet.core.types.messages import ActivationMessage +from dnet.core.cp.sharding import shard_for_mode +from .base import Strategy, ApiAdapterBase + + +class CPTopologyInfo(TopologyInfo): + """Extended topology info for context parallelism.""" + + num_cp_ranks: int = 1 + cp_algorithm: str = "auto" + + +class CPTopologySolver(TopologySolver): + """ + Topology solver for context parallelism. + + Unlike ring topology, CP assigns ALL layers to EACH device. + Optimization focuses on ordering devices for minimal ring latency. + """ + + async def solve( + self, + profiles: Dict[str, DeviceProfile], + model_profile: Any, + model_name: str, + num_layers: int, + kv_bits: Literal["4bit", "8bit", "fp16"], + shards: Dict[str, DnetDeviceProperties], + thunderbolts: Dict[str, Dict[str, ThunderboltConnection]], + ) -> TopologyInfo: + """ + Solve topology for context parallelism. + + For CP, all devices get the full model. We optimize the ring + ordering for minimal inter-device latency. + """ + + # Filter out manager nodes - only include actual shards that have profiles + active_shards = { + name: props + for name, props in shards.items() + if not props.is_manager and name in profiles + } + + # Order devices by Thunderbolt connectivity for minimal latency + ordered_instances = self._optimize_ring_order( + profiles, thunderbolts, list(active_shards.keys()) + ) + + # Build layer assignments as list of LayerAssignment objects + # For CP, each device gets ALL layers (full model replication) + all_layers = list(range(num_layers)) + layer_assignments: List[LayerAssignment] = [] + + for i, name in enumerate(ordered_instances): + next_name = ( + ordered_instances[(i + 1) % len(ordered_instances)] + if len(ordered_instances) > 1 + else None + ) + layer_assignments.append( + LayerAssignment( + instance=name, + layers=[all_layers], # All layers in single round (k=1) + next_instance=next_name, + window_size=num_layers, + residency_size=num_layers, + ) + ) + + shards_list = [shards[name] for name in ordered_instances] + + logger.info( + "CP topology: %d devices, each with all %d layers", + len(ordered_instances), + num_layers, + ) + + # Create TopologyInfo + return TopologyInfo( + model=model_name, + kv_bits=kv_bits, + num_layers=num_layers, + devices=shards_list, + assignments=layer_assignments, + solution=None, # No HALDA solution for CP + ) + + def _optimize_ring_order( + self, + profiles: Dict[str, DeviceProfile], + thunderbolts: Dict[str, Dict[str, ThunderboltConnection]], + device_names: list[str], + ) -> list[str]: + """ + Order devices to minimize ring latency. + + Prioritize Thunderbolt connections, fallback to device order. 
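
        Illustrative example (hypothetical devices; only A and B share a
        Thunderbolt link with an assigned IP):

            order = solver._optimize_ring_order(
                profiles,
                thunderbolts={"A": {"B": tb_conn}},
                device_names=["A", "C", "B"],
            )
            # The greedy walk starts at "A", prefers the TB neighbour "B",
            # then appends the remaining "C" -> ["A", "B", "C"]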
+ """ + if len(device_names) <= 2: + return device_names + + # Build adjacency matrix of TB connections + has_tb = {} + for src in device_names: + if src in thunderbolts: + for dst, conn in thunderbolts[src].items(): + if dst in device_names and conn.ip_addr: + has_tb[(src, dst)] = True + + # Greedy ordering: start from first, pick next with TB if possible + ordered = [device_names[0]] + remaining = set(device_names[1:]) + + while remaining: + current = ordered[-1] + # Find neighbor with TB connection + next_device = None + for candidate in remaining: + if has_tb.get((current, candidate)): + next_device = candidate + break + + if not next_device: + # No TB connection, pick arbitrary + next_device = remaining.pop() + else: + remaining.remove(next_device) + + ordered.append(next_device) + + return ordered + + +class CPApiAdapter(ApiAdapterBase): + """API adapter for context parallel communication. + + Supports multi-rank broadcasting: splits token sequence across ranks + and sends chunks in parallel. Only the last rank samples and returns. + """ + + def __init__(self) -> None: + super().__init__() + # Legacy single-shard connection (kept for backward compat) + self.primary_channel: Optional[aio_grpc.Channel] = None + self.primary_stub: Optional[DnetRingServiceStub] = None + self._streams = StreamManager(idle_timeout_s=5.0, backoff_s=0.2) + self._pending: Dict[str, asyncio.Future[TokenResult]] = {} + + # Multi-rank connections for CP + self.num_ranks: int = 1 + self.rank_channels: Dict[int, aio_grpc.Channel] = {} + self.rank_stubs: Dict[int, DnetRingServiceStub] = {} + self._streams_by_rank: Dict[int, StreamManager] = {} + + async def start(self) -> None: + self.running = True + + async def shutdown(self) -> None: + self.running = False + # Clean up legacy streams + for nonce in list(getattr(self._streams, "_streams", {}).keys()): + try: + await self._streams.end_stream(nonce) + except Exception: + pass + if self.primary_channel: + try: + await self.primary_channel.close() + except Exception: + pass + self.primary_channel = None + self.primary_stub = None + + # Clean up multi-rank streams and channels + for streams in self._streams_by_rank.values(): + for nonce in list(getattr(streams, "_streams", {}).keys()): + try: + await streams.end_stream(nonce) + except Exception: + pass + for channel in self.rank_channels.values(): + try: + await channel.close() + except Exception: + pass + self.rank_channels.clear() + self.rank_stubs.clear() + self._streams_by_rank.clear() + + async def connect_first_shard(self, ip: str, port: int) -> None: + """Connect to primary shard (rank 0) - legacy single-shard mode.""" + target = f"{ip}:{port}" + if self.primary_channel: + try: + await self.primary_channel.close() + except Exception: + pass + from dnet.utils.grpc_config import GRPC_AIO_OPTIONS + + self.primary_channel = aio_grpc.insecure_channel( + target, options=GRPC_AIO_OPTIONS + ) + self.primary_stub = DnetRingServiceStub(self.primary_channel) + logger.info("CP adapter connected to primary shard at %s", target) + + async def connect_all_ranks(self, rank_addresses: List[str]) -> None: + """Connect to all CP ranks for multi-rank broadcasting. + + Args: + rank_addresses: List of "host:port" strings, one per rank, in order. 
+ """ + from dnet.utils.grpc_config import GRPC_AIO_OPTIONS + + # Close existing connections + for channel in self.rank_channels.values(): + try: + await channel.close() + except Exception: + pass + self.rank_channels.clear() + self.rank_stubs.clear() + self._streams_by_rank.clear() + + self.num_ranks = len(rank_addresses) + for rank, addr in enumerate(rank_addresses): + self.rank_channels[rank] = aio_grpc.insecure_channel( + addr, options=GRPC_AIO_OPTIONS + ) + self.rank_stubs[rank] = DnetRingServiceStub(self.rank_channels[rank]) + self._streams_by_rank[rank] = StreamManager( + idle_timeout_s=60.0, backoff_s=0.2 + ) + + # Also set primary for backward compat + if rank_addresses: + self.primary_channel = self.rank_channels.get(0) + self.primary_stub = self.rank_stubs.get(0) + + logger.info( + "CP adapter connected to %d ranks: %s", self.num_ranks, rank_addresses + ) + + async def reset_cache(self) -> None: + """Reset cache on all ranks.""" + if self.num_ranks > 1 and self.rank_stubs: + # Multi-rank: reset on all + async def reset_rank(rank: int): + stub = self.rank_stubs.get(rank) + if stub: + try: + await stub.ResetCache(pb2.ResetCacheRequest()) + except Exception as e: + logger.warning("ResetCache failed on rank %d: %s", rank, e) + + await asyncio.gather(*[reset_rank(r) for r in range(self.num_ranks)]) + elif self.primary_stub: + # Single-rank fallback + try: + await self.primary_stub.ResetCache(pb2.ResetCacheRequest()) + except Exception as e: + logger.warning("ResetCache RPC failed: %s", e) + else: + raise RuntimeError("CP adapter not connected") + + async def send_tokens( + self, + nonce: str, + tokens: bytes, + callback_addr: str, + logprobs: bool = False, + top_logprobs: int = 0, + decoding_config: Optional[Any] = None, + start_pos: int = 0, + ) -> None: + """Send tokens to all CP ranks (split and broadcast). + + If multi-rank is configured, splits the token sequence using + shard_for_mode() and sends each chunk to its corresponding rank. + Only the last rank will sample and return the result. 
+ """ + if self.num_ranks > 1 and self.rank_stubs: + # Multi-rank mode: split and broadcast + await self._send_tokens_multi_rank( + nonce, + tokens, + callback_addr, + logprobs, + top_logprobs, + decoding_config, + start_pos, + ) + elif self.primary_stub: + # Single-rank fallback (legacy behavior) + await self._send_tokens_single_rank( + nonce, tokens, callback_addr, logprobs, top_logprobs, decoding_config + ) + else: + raise RuntimeError("CP adapter not connected to any shard") + + async def _send_tokens_single_rank( + self, + nonce: str, + tokens: bytes, + callback_addr: str, + logprobs: bool, + top_logprobs: int, + decoding_config: Optional[Any], + ) -> None: + """Legacy single-rank send (original behavior).""" + msg = ActivationMessage( + nonce=nonce, + pool_id=-1, + batch_size=1, + shape=(len(tokens) // 4,), # int32 tokens + dtype="tokens", + layer_id=-1, + timestamp=utc_epoch_now(), + node_origin="api", + callback_url=f"grpc://{callback_addr}", + req_logprobs=logprobs, + req_top_logprobs=top_logprobs, + temperature=decoding_config.temperature if decoding_config else 1.0, + top_p=decoding_config.top_p if decoding_config else 1.0, + top_k=decoding_config.top_k if decoding_config else -1, + repetition_penalty=( + decoding_config.repetition_penalty if decoding_config else 1.0 + ), + min_p=decoding_config.min_p if decoding_config else 0.0, + min_tokens_to_keep=( + decoding_config.min_tokens_to_keep if decoding_config else 1 + ), + ) + req = msg.to_proto(tokens) + + stub = self.primary_stub + assert stub is not None, "primary_stub should be set" + ctx = await self._streams.get_or_create_stream( + nonce, + lambda it: stub.StreamActivations(it), + ) + if not ctx or not ctx.open: + raise RuntimeError(f"Failed to create stream for nonce {nonce}") + + ctx.last_seq += 1 + await ctx.queue.put( + pb2.ActivationFrame(request=req, seq=ctx.last_seq, end_of_request=False) + ) + ctx.last_activity_t = asyncio.get_running_loop().time() + + async def _send_tokens_multi_rank( + self, + nonce: str, + tokens: bytes, + callback_addr: str, + logprobs: bool, + top_logprobs: int, + decoding_config: Optional[Any], + start_pos: int, + ) -> None: + """Multi-rank send: broadcast full tokens to all ranks for Ring Attention.""" + import numpy as np + + # Deserialize full token sequence + full_tokens = np.frombuffer(tokens, dtype=np.int32) + num_tokens = len(full_tokens) + + logger.debug( + "CP multi-rank send: nonce=%s, %d tokens -> %d ranks", + nonce, + num_tokens, + self.num_ranks, + ) + + # For decode (single token), send to ALL ranks (Broadcast). + # Each rank needs the full Q to attend to its local KV shard. + if num_tokens == 1: + + async def send_broadcast(rank: int) -> None: + # Only the last rank should sample/generate tokens + is_last_rank = rank == self.num_ranks - 1 + + await self._send_chunk_to_rank( + rank, + nonce, + tokens, # Full tokens (broadcast) + callback_addr, + logprobs if is_last_rank else False, + top_logprobs if is_last_rank else 0, + decoding_config if is_last_rank else None, + num_tokens, + rope_offset=start_pos, + ) + + await asyncio.gather(*[send_broadcast(r) for r in range(self.num_ranks)]) + return + + # Phase 5: True Ring Attention (Sharded KV) + # Use load-balanced 2N sharding for prefill to ensure each rank stores only 1/N KV. + # The CPAttentionWrapper will use CPAdapter to rotate KV blocks. 
+ + # Helper to send sharded chunk to a rank + async def send_shard_to_rank(rank: int) -> None: + import mlx.core as mx + import numpy as np + + # shard_for_mode expects mx.array, convert from numpy + mx_tokens = mx.array(full_tokens) + + # Get shard for this rank (prefill mode) + sharded_chunk_mx, indices = shard_for_mode( + mx_tokens, self.num_ranks, rank, "prefill" + ) + + # Convert back to bytes for network transmission + # mx.array -> numpy -> bytes + chunk_np = np.array(sharded_chunk_mx) + chunk_bytes = chunk_np.tobytes() + + # Only the last rank should sample/generate tokens + is_last_rank = rank == self.num_ranks - 1 + + # Use existing send helper + # RoPE offset is globally determined by the start index of this shard + chunk_offset = start_pos + indices[0] if indices else start_pos + + await self._send_chunk_to_rank( + rank, + nonce, + chunk_bytes, + callback_addr, + logprobs if is_last_rank else False, + top_logprobs if is_last_rank else 0, + decoding_config if is_last_rank else None, + len(chunk_np), + rope_offset=chunk_offset, + ) + + # Send sharded chunks to all ranks in parallel + await asyncio.gather(*[send_shard_to_rank(r) for r in range(self.num_ranks)]) + + async def _send_chunk_to_rank( + self, + rank: int, + nonce: str, + tokens: bytes, + callback_addr: str, + logprobs: bool, + top_logprobs: int, + decoding_config: Optional[Any], + num_tokens: int, + rope_offset: int, + ) -> None: + """Send tokens directly to a specific rank (for decode phase).""" + + msg = ActivationMessage( + nonce=nonce, + pool_id=-1, + batch_size=1, + shape=(num_tokens,), + dtype="tokens", + layer_id=-1, + timestamp=utc_epoch_now(), + node_origin="api", + callback_url=f"grpc://{callback_addr}", + req_logprobs=logprobs, + req_top_logprobs=top_logprobs, + temperature=decoding_config.temperature if decoding_config else 1.0, + top_p=decoding_config.top_p if decoding_config else 1.0, + top_k=decoding_config.top_k if decoding_config else -1, + repetition_penalty=( + decoding_config.repetition_penalty if decoding_config else 1.0 + ), + min_p=decoding_config.min_p if decoding_config else 0.0, + min_tokens_to_keep=( + decoding_config.min_tokens_to_keep if decoding_config else 1 + ), + rope_offset=rope_offset, + ) + req = msg.to_proto(tokens) + + stub = self.rank_stubs[rank] + streams = self._streams_by_rank[rank] + ctx = await streams.get_or_create_stream( + nonce, + lambda it: stub.StreamActivations(it), + ) + if not ctx or not ctx.open: + raise RuntimeError( + f"Failed to create stream for rank {rank}, nonce {nonce}" + ) + + ctx.last_seq += 1 + await ctx.queue.put( + pb2.ActivationFrame(request=req, seq=ctx.last_seq, end_of_request=False) + ) + ctx.last_activity_t = asyncio.get_running_loop().time() + + async def await_token(self, nonce: str, timeout_s: float) -> TokenResult: + fut = asyncio.get_running_loop().create_future() + self._pending[nonce] = fut + try: + return await asyncio.wait_for(fut, timeout=timeout_s) + finally: + self._pending.pop(nonce, None) + + def resolve_token(self, nonce: str, result: TokenResult) -> None: + fut = self._pending.get(nonce) + if fut and not fut.done(): + fut.set_result(result) + + +class ContextParallelStrategy(Strategy): + """ + Execution strategy using context parallelism. + + Distributes sequence dimension across devices while replicating + all model layers on each device. 
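
    Wiring sketch (mirrors how the CLI picks a strategy from settings):

        strategy: Strategy = ContextParallelStrategy()
        topology = await strategy.solver.solve(...)      # all layers per device
        await strategy.adapter.connect_all_ranks(addrs)  # addrs: "host:port" list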
+ """ + + def __init__(self): + self._solver = CPTopologySolver() + self._adapter = CPApiAdapter() + + @property + def solver(self) -> TopologySolver: + return self._solver + + @property + def adapter(self) -> ApiAdapterBase: + return self._adapter diff --git a/src/dnet/config.py b/src/dnet/config.py index 38e51397..ef42e6fe 100644 --- a/src/dnet/config.py +++ b/src/dnet/config.py @@ -242,6 +242,37 @@ class TopologySettings(BaseSettings): ) +class ContextParallelSettings(BaseSettings): + """Context parallelism configuration. + + Context parallelism distributes the sequence dimension across multiple + devices for long-context inference (128K+ tokens). + """ + + model_config = SettingsConfigDict(env_prefix="DNET_CP_") + + enabled: bool = Field( + default=False, + description="Enable context parallelism mode", + ) + algorithm: Literal["auto", "pass_kv", "pass_q", "ring_reduce"] = Field( + default="auto", + description="Ring attention algorithm (auto, pass_kv, pass_q, ring_reduce)", + ) + min_context_for_cp: int = Field( + default=32768, + description="Minimum context length to enable CP (below this, single-device)", + ) + min_tokens_for_pass_kv: int = Field( + default=256, + description="Minimum new tokens to prefer pass_kv over pass_q", + ) + chunk_overlap: int = Field( + default=0, + description="Overlap between chunks for sliding window attention", + ) + + class DnetSettings(BaseSettings): """Main dnet settings, loads from .env file.""" @@ -262,6 +293,9 @@ class DnetSettings(BaseSettings): grpc: GrpcSettings = Field(default_factory=GrpcSettings) storage: StorageSettings = Field(default_factory=StorageSettings) topology: TopologySettings = Field(default_factory=TopologySettings) + context_parallel: ContextParallelSettings = Field( + default_factory=ContextParallelSettings + ) @lru_cache @@ -284,4 +318,5 @@ def get_settings() -> DnetSettings: "GrpcSettings", "StorageSettings", "TopologySettings", + "ContextParallelSettings", ] diff --git a/src/dnet/core/cp/__init__.py b/src/dnet/core/cp/__init__.py new file mode 100644 index 00000000..0c5551cd --- /dev/null +++ b/src/dnet/core/cp/__init__.py @@ -0,0 +1,63 @@ +"""Context Parallelism core utilities. + +This package provides the core building blocks for context parallelism: +- sharding: Mode-aware sequence partitioning (prefill vs decode) +- merge_attention: Numerically stable merging of partial attention outputs +- heuristics: Algorithm selection (pass-KV, pass-Q, ring-reduce) +- ring_comm: Ring communication primitives + +Note: sharding and merge_attention require MLX (macOS only). + heuristics works on all platforms. 
+""" + +# Platform-independent imports (always available) +from dnet.core.cp.heuristics import select_algorithm, CPAlgorithm +from dnet.core.cp.ring_comm import ( + CPRingCommunicator, + RingNeighbors, + CPRingServiceServicer, + start_cp_ring_server, +) + + +# MLX-dependent imports (only available on macOS) +# These are lazy-imported to allow heuristics to work on other platforms +def __getattr__(name: str): + """Lazy import for MLX-dependent modules.""" + if name in ("shard_for_mode", "unshard"): + from dnet.core.cp.sharding import shard_for_mode, unshard + + return shard_for_mode if name == "shard_for_mode" else unshard + elif name in ( + "PartialAttentionOutput", + "merge_partial_attention", + "merge_two_partials", + ): + from dnet.core.cp.merge_attention import ( + PartialAttentionOutput, + merge_partial_attention, + merge_two_partials, + ) + + if name == "PartialAttentionOutput": + return PartialAttentionOutput + elif name == "merge_partial_attention": + return merge_partial_attention + else: + return merge_two_partials + raise AttributeError(f"module 'dnet.core.cp' has no attribute {name!r}") + + +__all__ = [ + "shard_for_mode", + "unshard", + "PartialAttentionOutput", + "merge_partial_attention", + "merge_two_partials", + "select_algorithm", + "CPAlgorithm", + "CPRingCommunicator", + "RingNeighbors", + "CPRingServiceServicer", + "start_cp_ring_server", +] diff --git a/src/dnet/core/cp/cp_kv_sync.py b/src/dnet/core/cp/cp_kv_sync.py new file mode 100644 index 00000000..c3c6c1dd --- /dev/null +++ b/src/dnet/core/cp/cp_kv_sync.py @@ -0,0 +1,230 @@ +""" +CP KV Synchronization: AllGather for KV cache across ranks. + +After each layer's forward pass, each rank has KV for its local chunk only. +This module provides sync_kv_cache() to AllGather KV from all ranks, +so each rank can attend to the full sequence. + +The sync is called after each layer, enabling full context attention. +""" + +from __future__ import annotations + +import asyncio +from typing import Optional, TYPE_CHECKING + +import mlx.core as mx +import numpy as np + +from dnet.utils.logger import logger + +if TYPE_CHECKING: + from dnet.core.cp.ring_comm import CPRingCommunicator + + +def serialize_kv_layer(kv_cache_layer) -> bytes: + """ + Serialize a single layer's KV cache to bytes. + + Args: + kv_cache_layer: MLX KV cache object for one layer + + Returns: + Serialized bytes containing K and V tensors + """ + # MLX KV cache has keys and values as mx.array + # Handle different cache types (QuantizedKVCache, etc.) + if hasattr(kv_cache_layer, "keys") and hasattr(kv_cache_layer, "values"): + k = np.array(kv_cache_layer.keys, copy=False) + v = np.array(kv_cache_layer.values, copy=False) + elif hasattr(kv_cache_layer, "state"): + # Some caches store state as tuple (k, v) + k = np.array(kv_cache_layer.state[0], copy=False) + v = np.array(kv_cache_layer.state[1], copy=False) + else: + # Fallback: assume it's indexable + k = np.array(kv_cache_layer[0], copy=False) + v = np.array(kv_cache_layer[1], copy=False) + + # Pack with shape info + k_flat = k.reshape(-1).astype(np.float16) + v_flat = v.reshape(-1).astype(np.float16) + + header = np.array( + [ + len(k.shape), + *k.shape, + len(v.shape), + *v.shape, + ], + dtype=np.int32, + ) + + return header.tobytes() + k_flat.tobytes() + v_flat.tobytes() + + +def deserialize_kv_layer(data: bytes) -> tuple[mx.array, mx.array]: + """ + Deserialize bytes back to K, V tensors. 
+ + Returns: + Tuple of (keys, values) as mx.array + """ + # Read header + header_count = 0 + idx = 0 + + # Read K shape + k_ndim = int(np.frombuffer(data[idx : idx + 4], dtype=np.int32)[0]) + idx += 4 + header_count += 1 + + k_shape = tuple( + np.frombuffer(data[idx : idx + 4 * k_ndim], dtype=np.int32).tolist() + ) + idx += 4 * k_ndim + + # Read V shape + v_ndim = int(np.frombuffer(data[idx : idx + 4], dtype=np.int32)[0]) + idx += 4 + + v_shape = tuple( + np.frombuffer(data[idx : idx + 4 * v_ndim], dtype=np.int32).tolist() + ) + idx += 4 * v_ndim + + # Read K data + k_size = int(np.prod(k_shape)) + k_flat = np.frombuffer(data[idx : idx + k_size * 2], dtype=np.float16) + idx += k_size * 2 + k = mx.array(k_flat.reshape(k_shape)) + + # Read V data + v_size = int(np.prod(v_shape)) + v_flat = np.frombuffer(data[idx : idx + v_size * 2], dtype=np.float16) + v = mx.array(v_flat.reshape(v_shape)) + + return k, v + + +async def allgather_ring( + local_data: bytes, + ring_comm: CPRingCommunicator, + tag_prefix: str, +) -> list[bytes]: + """ + AllGather via ring: collect data from all ranks. + + Uses N-1 ring rotations to gather all chunks. + + Args: + local_data: This rank's data + ring_comm: Ring communicator + tag_prefix: Unique tag prefix for this gather + + Returns: + List of data from all ranks, in rank order + """ + num_ranks = ring_comm.num_ranks + rank_id = ring_comm.rank_id + + if num_ranks == 1: + return [local_data] + + # Storage for all chunks, indexed by original rank + all_chunks: list[Optional[bytes]] = [None] * num_ranks + all_chunks[rank_id] = local_data + + # Current chunk to send (starts as ours, then becomes received) + current_chunk = local_data + source_rank = rank_id + + for step in range(1, num_ranks): + tag = f"{tag_prefix}_step{step}" + + # Ring send/recv: send current to next, receive from prev + recv_chunk = await ring_comm.send_recv(current_chunk, tag) + + # Calculate which rank's data we received + source_rank = (source_rank - 1) % num_ranks + all_chunks[source_rank] = recv_chunk + + # Next iteration: forward what we received + current_chunk = recv_chunk + + return [c for c in all_chunks if c is not None] + + +async def sync_kv_cache_layer( + kv_cache_layer, + layer_idx: int, + ring_comm: CPRingCommunicator, + nonce: str, +) -> None: + """ + Synchronize a single layer's KV cache across all CP ranks. + + After this call, each rank has KV from all ranks concatenated. 
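
    Rotation sketch (3 ranks): allgather_ring needs N - 1 = 2 steps. At
    step 1, rank 1 receives rank 0's local KV chunk from its prev
    neighbour; at step 2 it forwards that chunk to rank 2 and receives
    rank 2's chunk (relayed via rank 0), after which every rank holds
    all three chunks.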
+ + Args: + kv_cache_layer: The KV cache object for this layer + layer_idx: Layer index (for logging) + ring_comm: Ring communicator + nonce: Request nonce (for unique tags) + """ + if ring_comm.num_ranks == 1: + return + + # Serialize local KV + local_kv_bytes = serialize_kv_layer(kv_cache_layer) + + # AllGather KV from all ranks + all_kv_bytes = await allgather_ring( + local_kv_bytes, + ring_comm, + f"kv_L{layer_idx}_{nonce[:8]}", + ) + + # Deserialize all chunks + all_kvs = [deserialize_kv_layer(b) for b in all_kv_bytes] + + # Concatenate along sequence dimension (axis 2 for [B, H, S, D]) + all_keys = [kv[0] for kv in all_kvs] + all_values = [kv[1] for kv in all_kvs] + + merged_k = mx.concatenate(all_keys, axis=2) + merged_v = mx.concatenate(all_values, axis=2) + + # Update the cache in-place + if hasattr(kv_cache_layer, "keys") and hasattr(kv_cache_layer, "values"): + kv_cache_layer.keys = merged_k + kv_cache_layer.values = merged_v + elif hasattr(kv_cache_layer, "state"): + kv_cache_layer.state = (merged_k, merged_v) + + logger.debug( + "CP sync layer %d: %d ranks -> merged KV shape %s", + layer_idx, + ring_comm.num_ranks, + merged_k.shape, + ) + + +async def sync_full_kv_cache( + kv_cache: list, + ring_comm: CPRingCommunicator, + nonce: str, +) -> None: + """ + Synchronize all layers' KV caches across CP ranks. + + Calls sync_kv_cache_layer for each layer in parallel. + """ + if ring_comm.num_ranks == 1: + return + + tasks = [ + sync_kv_cache_layer(kv_cache[i], i, ring_comm, nonce) + for i in range(len(kv_cache)) + ] + await asyncio.gather(*tasks) diff --git a/src/dnet/core/cp/heuristics.py b/src/dnet/core/cp/heuristics.py new file mode 100644 index 00000000..26314f6d --- /dev/null +++ b/src/dnet/core/cp/heuristics.py @@ -0,0 +1,185 @@ +"""Algorithm selection heuristics for context parallelism. + +Provides a greedy heuristic for selecting the optimal CP algorithm based on: +- Context length and cache hit rate +- Batch size +- Number of query/KV heads (GQA ratio) +- Number of CP ranks + +This is a v1 hardcoded heuristic. Future versions will use a solver-based +approach for more accurate predictions. +""" + +from __future__ import annotations + +from enum import StrEnum + + +class CPAlgorithm(StrEnum): + """Context parallelism algorithm selection.""" + + SINGLE_DEVICE = "single_device" # No CP, run on single device + PASS_KV = "pass_kv" # Rotate KV blocks (best for prefill) + PASS_Q = "pass_q" # Rotate Q blocks with All2All + RING_REDUCE = "ring_reduce" # Rotate Q with ring reduction (best for decode) + + +def select_algorithm( + new_tokens: int, + cached_tokens: int, + batch_size: int, + num_ranks: int, + num_q_heads: int, + num_kv_heads: int, + context_parallel_enabled: bool, + min_context_for_cp: int = 32768, + min_tokens_for_pass_kv: int = 256, + gqa_threshold: float | None = None, +) -> CPAlgorithm: + """ + Greedy heuristic for selecting CP algorithm. + + Decision tree: + 1. Skip CP for small contexts or if disabled + 2. Decode mode (T <= batch_size) → ring_reduce (avoid All2All) + 3. Prefill with high cache hit → pass_q (Q smaller than KV) + 4. 
Full prefill → pass_kv (enough compute to hide comm) + + Args: + new_tokens: Number of new tokens to process (T) + cached_tokens: Number of tokens already in KV cache (P) + batch_size: Current batch size + num_ranks: Number of CP ranks + num_q_heads: Number of query heads + num_kv_heads: Number of KV heads (for GQA models) + context_parallel_enabled: Whether CP is enabled in config + min_context_for_cp: Minimum context to use CP (default 32K) + min_tokens_for_pass_kv: Minimum new tokens for pass-KV (default 256) + gqa_threshold: Cache miss rate threshold (default: 2 * NKV / NH) + + Returns: + Selected algorithm from CPAlgorithm enum + """ + total_context = new_tokens + cached_tokens + + # Rule 1: Skip CP for small contexts or if disabled + if not context_parallel_enabled or total_context < min_context_for_cp: + return CPAlgorithm.SINGLE_DEVICE + + # Rule 2: Single rank is always single device + if num_ranks <= 1: + return CPAlgorithm.SINGLE_DEVICE + + # Rule 3: Decode mode (T=1 per sequence in batch typically) + # Heuristic: if new_tokens <= batch_size, likely decode + if new_tokens <= batch_size: + return CPAlgorithm.RING_REDUCE # Avoid All2All for decode + + # Calculate cache miss rate + miss_rate = new_tokens / total_context if total_context > 0 else 1.0 + + # Compute GQA threshold if not provided + # Threshold from paper: 2 * NKV / NH (e.g., 2*8/128 = 0.125 for Llama) + if gqa_threshold is None: + if num_q_heads > 0: + gqa_threshold = 2.0 * num_kv_heads / num_q_heads + else: + gqa_threshold = 0.125 # Default fallback + + # Rule 4: Prefill with high cache hit (partial prefill) + # When miss rate is low, Q is much smaller than full KV + if miss_rate < gqa_threshold: + return CPAlgorithm.PASS_Q + + # Rule 5: Full prefill or sufficient new tokens + # pass-KV has enough compute to hide KV communication + if new_tokens >= min_tokens_for_pass_kv: + return CPAlgorithm.PASS_KV + + # Fallback for edge cases (short prefill with low cache hit) + return CPAlgorithm.PASS_Q + + +def estimate_algorithm_latency( + algorithm: CPAlgorithm, + new_tokens: int, + cached_tokens: int, + num_ranks: int, + num_q_heads: int, + num_kv_heads: int, + head_dim: int, + flops_per_sec: float, + bandwidth_bytes_per_sec: float, +) -> float: + """ + Estimate latency for a given algorithm (for solver integration). + + This is a simplified model for v1. 
Actual latency depends on:
    - Overlap between compute and communication
    - Memory bandwidth
    - Kernel efficiency

    Args:
        algorithm: Selected algorithm
        new_tokens: Number of new tokens
        cached_tokens: Number of cached tokens
        num_ranks: Number of CP ranks
        num_q_heads: Query heads
        num_kv_heads: KV heads
        head_dim: Dimension per head
        flops_per_sec: Device compute throughput
        bandwidth_bytes_per_sec: Inter-device bandwidth

    Returns:
        Estimated latency in seconds
    """
    total_context = new_tokens + cached_tokens
    bytes_per_element = 2  # bfloat16

    if algorithm == CPAlgorithm.SINGLE_DEVICE:
        # Full attention compute
        attn_flops = 2 * new_tokens * total_context * num_q_heads * head_dim
        return attn_flops / flops_per_sec

    tokens_per_rank = total_context // num_ranks

    if algorithm == CPAlgorithm.PASS_KV:
        # Compute: distributed across ranks
        attn_flops = 2 * new_tokens * total_context * num_q_heads * head_dim
        compute_time = attn_flops / (flops_per_sec * num_ranks)

        # Communication: KV blocks rotated N-1 times
        kv_size = tokens_per_rank * num_kv_heads * head_dim * bytes_per_element * 2
        comm_time = (num_ranks - 1) * kv_size / bandwidth_bytes_per_sec

        # Overlap: max of compute and comm (simplified)
        return max(compute_time, comm_time)

    elif algorithm == CPAlgorithm.PASS_Q:
        # Compute: same as pass-KV
        attn_flops = 2 * new_tokens * total_context * num_q_heads * head_dim
        compute_time = attn_flops / (flops_per_sec * num_ranks)

        # Communication: Q blocks + All2All
        q_size = (new_tokens // num_ranks) * num_q_heads * head_dim * bytes_per_element
        ring_comm = (num_ranks - 1) * q_size / bandwidth_bytes_per_sec

        # All2All: O(N^2) communication pattern
        output_size = new_tokens * num_q_heads * head_dim * bytes_per_element
        all2all_time = output_size / bandwidth_bytes_per_sec  # Simplified

        return max(compute_time, ring_comm) + all2all_time

    else:  # RING_REDUCE
        # Compute: same as others
        attn_flops = 2 * new_tokens * total_context * num_q_heads * head_dim
        compute_time = attn_flops / (flops_per_sec * num_ranks)

        # Communication: partial outputs + merge stats
        # Each step passes output + max_score + log_sum_exp
        output_per_rank = (new_tokens // num_ranks) * num_q_heads * head_dim
        stats_per_rank = (new_tokens // num_ranks) * num_q_heads * 2  # max + lse
        bytes_per_step = (output_per_rank + stats_per_rank) * bytes_per_element
        ring_time = (num_ranks - 1) * bytes_per_step / bandwidth_bytes_per_sec

        return max(compute_time, ring_time)
diff --git a/src/dnet/core/cp/merge_attention.py b/src/dnet/core/cp/merge_attention.py
new file mode 100644
index 00000000..14f63546
--- /dev/null
+++ b/src/dnet/core/cp/merge_attention.py
@@ -0,0 +1,176 @@
+"""Merge attention operator for context parallelism.

When computing blockwise attention across distributed KV caches, each device
produces partial outputs with local softmax statistics. These must be merged
correctly using numerically stable rescaling.

Math:
    For blocks with normalized outputs O_i, per-row max scores m_i, and
    softmax denominators (sums of exponentials) l_i:
        m_global = max(m_1, m_2, ..., m_N)
        l_global = sum(exp(m_i - m_global) * l_i)
        O_merged = sum(exp(m_i - m_global) * l_i * O_i) / l_global
"""

from __future__ import annotations

from dataclasses import dataclass

import mlx.core as mx


@dataclass
class PartialAttentionOutput:
    """Partial attention output with merge statistics.
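
    A worked check for the sigmoid-based merge implemented below: when two
    blocks have equal lse, sigmoid(0) = 0.5, so the merged output is the
    plain average (out_a + out_b) / 2 and the merged lse is lse + log 2,
    exactly the log-sum-exp of two equal terms.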

    Attributes:
        output: Attention output [batch, seq, heads, dim] or [seq, heads, dim]
        max_score: Per-position max attention score [batch, seq, heads] or [seq, heads]
        log_sum_exp: Per-position log-sum-exp of attention weights (same shape as max_score)
    """

    output: mx.array
    max_score: mx.array
    log_sum_exp: mx.array


def merge_partial_attention(
    partials: list[PartialAttentionOutput],
) -> mx.array:
    """
    Merge multiple partial attention outputs with numerically stable rescaling.

    This implements the online softmax merge algorithm from Flash Attention,
    extended for distributed computation.

    Args:
        partials: List of partial outputs from different KV blocks/ranks

    Returns:
        Merged attention output tensor
    """
    if not partials:
        raise ValueError("Cannot merge empty list of partials")

    if len(partials) == 1:
        # Single partial: block outputs are already normalized per block
        # (the same convention the sigmoid merge below relies on), so it
        # can be returned unchanged.
        return partials[0].output

    # Start with first partial as running state
    running = partials[0]

    for partial in partials[1:]:
        running = merge_two_partials(running, partial)

    return running.output


def merge_two_partials(
    a: PartialAttentionOutput,
    b: PartialAttentionOutput,
) -> PartialAttentionOutput:
    """
    Merge two partial attention outputs using a numerically stable sigmoid-based algorithm.

    This implements the merge formula from ring-flash-attention, which uses sigmoid
    and logsigmoid to keep values bounded and prevent numerical explosion:
        out = out - sigmoid(block_lse - lse) * (out - block_out)
        lse = lse - logsigmoid(lse - block_lse)

    Reference: https://github.com/zhuzilin/ring-flash-attention/pull/34

    Args:
        a: First partial output (running state)
        b: Second partial output (new block to merge)

    Returns:
        Merged partial output
    """
    # Convert to float32 for numerical precision (matching reference)
    out_a = a.output.astype(mx.float32)
    out_b = b.output.astype(mx.float32)
    lse_a = a.log_sum_exp.astype(mx.float32)
    lse_b = b.log_sum_exp.astype(mx.float32)

    # Sigmoid-based merge (bounded, numerically stable)
    # sigmoid(x) = 1 / (1 + exp(-x))
    # out = out_a - sigmoid(lse_b - lse_a) * (out_a - out_b)

    # Expand lse for broadcasting with output [S_q, H, D]
    lse_a_exp = mx.expand_dims(lse_a, axis=-1)
    lse_b_exp = mx.expand_dims(lse_b, axis=-1)

    # sigmoid(lse_b - lse_a) - bounded between 0 and 1
    sig = mx.sigmoid(lse_b_exp - lse_a_exp)

    # Merge outputs: out = out_a - sig * (out_a - out_b) = out_a * (1 - sig) + out_b * sig
    output_new = out_a - sig * (out_a - out_b)

    # Update LSE using logsigmoid:
    # lse_new = lse_a - logsigmoid(lse_a - lse_b)
    #         = lse_a + log(1 + exp(lse_b - lse_a))   [softplus(lse_b - lse_a)]
    # which equals max(lse_a, lse_b) + log(exp(lse_a - max) + exp(lse_b - max)),
    # the stable log-sum-exp of the two values.
    lse_max = mx.maximum(lse_a, lse_b)
    lse_new = lse_max + mx.log(
        mx.exp(lse_a - lse_max) + mx.exp(lse_b - lse_max) + 1e-10
    )

    return PartialAttentionOutput(
        output=output_new,
        max_score=lse_max,  # Keep for compatibility
        log_sum_exp=lse_new,
    )


def compute_partial_attention_stats(
    attention_weights: mx.array,
    values: mx.array,
) -> PartialAttentionOutput:
    """
    Compute attention output with statistics
needed for merging. + + This should be called after computing raw attention scores but before + the final softmax normalization. + + Args: + attention_weights: Raw attention scores [batch, heads, seq_q, seq_kv] + values: Value tensor [batch, seq_kv, heads, dim] + + Returns: + PartialAttentionOutput with output and merge statistics + """ + # Get max for numerical stability + max_score = mx.max(attention_weights, axis=-1) # [batch, heads, seq_q] + + # Compute softmax with numerical stability + shifted = attention_weights - mx.expand_dims(max_score, axis=-1) + exp_weights = mx.exp(shifted) + sum_exp = mx.sum(exp_weights, axis=-1) # [batch, heads, seq_q] + + # Normalize + normalized = exp_weights / mx.expand_dims(sum_exp, axis=-1) + + # Compute attention output + # normalized: [batch, heads, seq_q, seq_kv] + # values transposed: [batch, heads, seq_kv, dim] + values_transposed = mx.transpose(values, (0, 2, 1, 3)) + output = mx.matmul(normalized, values_transposed) # [batch, heads, seq_q, dim] + + # Transpose output back to [batch, seq_q, heads, dim] + output = mx.transpose(output, (0, 2, 1, 3)) + + # Transpose stats to match output: [batch, seq_q, heads] + max_score = mx.transpose(max_score, (0, 2, 1)) + sum_exp = mx.transpose(sum_exp, (0, 2, 1)) + + # Compute proper log-sum-exp: LSE = max + log(sum_exp) + lse = max_score + mx.log(sum_exp + 1e-10) + + return PartialAttentionOutput( + output=output, + max_score=max_score, + log_sum_exp=lse, + ) diff --git a/src/dnet/core/cp/ring_comm.py b/src/dnet/core/cp/ring_comm.py new file mode 100644 index 00000000..763dcf39 --- /dev/null +++ b/src/dnet/core/cp/ring_comm.py @@ -0,0 +1,411 @@ +"""Ring communication primitives for context parallelism. + +Provides async send/recv operations for passing data between CP ranks in a ring topology. +Uses gRPC for transport, with optional overlap of send/recv to hide latency. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Callable, Awaitable, AsyncIterator + +import grpc +from grpc import aio as aio_grpc + +from dnet.utils.grpc_config import GRPC_AIO_OPTIONS +from dnet.utils.logger import logger + +if TYPE_CHECKING: + pass + + +@dataclass +class RingNeighbors: + """Addresses of neighboring ranks in the ring.""" + + prev_address: str # host:port of rank (id - 1) % N + next_address: str # host:port of rank (id + 1) % N + + +class CPRingCommunicator: + """ + Manages ring communication for context parallelism. + + Provides async send_recv operation that simultaneously sends to next rank + and receives from previous rank, enabling pipelined communication. + """ + + def __init__( + self, + rank_id: int, + num_ranks: int, + neighbors: Optional[RingNeighbors] = None, + ): + """ + Initialize ring communicator. 
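
        Example (two hypothetical ranks on one ring):

            comm = CPRingCommunicator(rank_id=0, num_ranks=2)
            # With N = 2 the prev and next neighbour are the same rank:
            # comm.prev_rank == comm.next_rank == 1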
+ + Args: + rank_id: This rank's ID (0 to num_ranks-1) + num_ranks: Total number of CP ranks + neighbors: Addresses of prev/next ranks (can be set later via connect) + """ + if num_ranks <= 0: + raise ValueError(f"num_ranks must be positive, got {num_ranks}") + if not 0 <= rank_id < num_ranks: + raise ValueError(f"rank_id {rank_id} out of range [0, {num_ranks})") + + self.rank_id = rank_id + self.num_ranks = num_ranks + self.prev_rank = (rank_id - 1) % num_ranks + self.next_rank = (rank_id + 1) % num_ranks + + self._neighbors = neighbors + self._prev_channel: Optional[aio_grpc.Channel] = None + self._next_channel: Optional[aio_grpc.Channel] = None + + # Pending receives keyed by tag + self._pending_recv: dict[str, asyncio.Future[bytes]] = {} + # Cache for data that arrived before _recv_from_prev was called + self._early_data: dict[str, bytes] = {} + + # Lock to ensure connect is called once + self._connect_lock = asyncio.Lock() + self._connected = False + + async def connect(self, neighbors: RingNeighbors) -> None: + """ + Establish gRPC channels to neighboring ranks. + + Args: + neighbors: Addresses for prev/next ranks + """ + async with self._connect_lock: + if self._connected: + return + + self._neighbors = neighbors + + # Connect to prev rank (we receive from them) + if self.num_ranks > 1: + self._prev_channel = aio_grpc.insecure_channel( + neighbors.prev_address, options=GRPC_AIO_OPTIONS + ) + self._next_channel = aio_grpc.insecure_channel( + neighbors.next_address, options=GRPC_AIO_OPTIONS + ) + logger.debug( + "Rank %d: connected to prev=%s, next=%s", + self.rank_id, + neighbors.prev_address, + neighbors.next_address, + ) + + self._connected = True + + async def disconnect(self) -> None: + """Close gRPC channels.""" + async with self._connect_lock: + if self._prev_channel: + await self._prev_channel.close() + self._prev_channel = None + if self._next_channel: + await self._next_channel.close() + self._next_channel = None + self._connected = False + self._early_data.clear() + for fut in self._pending_recv.values(): + if not fut.done(): + fut.cancel() + self._pending_recv.clear() + + async def send_recv( + self, + send_data: bytes, + tag: str, + send_fn: Optional[Callable[[bytes, str], Awaitable[None]]] = None, + recv_fn: Optional[Callable[[str], Awaitable[bytes]]] = None, + ) -> bytes: + """ + Simultaneously send to next rank and receive from previous rank. + + This is the core operation for ring attention - overlapping send/recv + allows pipelining computation with communication. + + Args: + send_data: Data to send to next rank + tag: Unique tag for this communication (e.g., "kv_step_0") + send_fn: Optional custom send function (for testing) + recv_fn: Optional custom recv function (for testing) + + Returns: + Data received from previous rank + """ + if self.num_ranks == 1: + # Single rank: no communication needed, return own data + return send_data + + # Use provided functions or defaults + do_send = send_fn if send_fn is not None else self._send_to_next + do_recv = recv_fn if recv_fn is not None else self._recv_from_prev + + # Launch send and recv concurrently using gather + _, recv_data = await asyncio.gather( + do_send(send_data, tag), + do_recv(tag), + ) + + return recv_data + + async def _send_to_next(self, data: bytes, tag: str) -> None: + """ + Send data to next rank in the ring via gRPC. + + Uses CPRingService.SendBlock unary RPC with raw bytes in a CPBlockFrame. 
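
        Backoff sketch (matches the constants below): a "peer not ready"
        rejection on try k is retried after 0.05 * 1.5**(k - 1) seconds,
        capped at 2.0 s, for at most 20 tries (roughly 0.05, 0.075, 0.11,
        ... seconds before giving up).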
+ """ + if not self._next_channel: + raise RuntimeError("Not connected to next rank") + + from dnet.protos import dnet_cp_pb2, dnet_cp_pb2_grpc + + stub = dnet_cp_pb2_grpc.CPRingServiceStub(self._next_channel) + frame = dnet_cp_pb2.CPBlockFrame( + nonce=tag, + source_rank=self.rank_id, + # Use partial_output to carry raw bytes (reusing existing proto field) + partial_output=dnet_cp_pb2.PartialOutput(output_data=data), + ) + + # Retry parameters + max_retries = 20 + base_delay = 0.05 + + current_try = 0 + while True: + try: + ack = await stub.SendBlock(frame) + if ack.accepted: + logger.debug( + "Rank %d: sent %d bytes to rank %d (tag=%s)", + self.rank_id, + len(data), + self.next_rank, + tag, + ) + return # Success + + # Check if rejection is "No communicator attached" + if "No communicator attached" in ack.error_message: + raise RuntimeError(f"Peer not ready: {ack.error_message}") + else: + # Other rejections are fatal + raise RuntimeError( + f"Block rejected by next rank: {ack.error_message}" + ) + + except Exception as e: + is_peer_not_ready = "No communicator attached" in str( + e + ) or "Peer not ready" in str(e) + + current_try += 1 + if current_try >= max_retries: + logger.error( + "Rank %d: failed to send to next rank after %d retries: %s", + self.rank_id, + max_retries, + e, + ) + raise + + if is_peer_not_ready: + delay = base_delay * (1.5 ** (current_try - 1)) + delay = min(delay, 2.0) + if current_try % 5 == 0: + logger.debug( + "Rank %d: peer not ready (try %d/%d), retrying in %.2fs...", + self.rank_id, + current_try, + max_retries, + delay, + ) + await asyncio.sleep(delay) + else: + # Non-retryable error + logger.error("Rank %d: fatal send error: %s", self.rank_id, e) + raise + + async def _recv_from_prev(self, tag: str) -> bytes: + """ + Receive data from previous rank in the ring. + + Uses a pending receive pattern - the gRPC server calls resolve_recv + when data arrives, and this method waits on the future. + """ + if not self._prev_channel: + raise RuntimeError("Not connected to previous rank") + + # 1. Check if data arrived early (before we called recv) + if tag in self._early_data: + data = self._early_data.pop(tag) + logger.debug( + "Rank %d: retrieved %d bytes from early cache (tag=%s)", + self.rank_id, + len(data), + tag, + ) + return data + + # 2. Create a future for this tag if it doesn't exist + if tag not in self._pending_recv: + self._pending_recv[tag] = asyncio.get_event_loop().create_future() + + # 3. Wait for the data to arrive (set by resolve_recv when server receives it) + try: + data = await asyncio.wait_for(self._pending_recv[tag], timeout=30.0) + logger.debug( + "Rank %d: received %d bytes from rank %d (tag=%s)", + self.rank_id, + len(data), + self.prev_rank, + tag, + ) + return data + except asyncio.TimeoutError: + raise RuntimeError( + f"Rank {self.rank_id}: timeout waiting for data from prev rank (tag={tag})" + ) + + def resolve_recv(self, tag: str, data: bytes) -> None: + """ + Resolve a pending receive with incoming data. + + Called by the gRPC server when data arrives from prev rank. 
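+
+        Handles the race between sender and receiver: if data arrives before
+        the matching _recv_from_prev call, it is parked in _early_data and
+        handed over as soon as the receiver asks for that tag.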
+ """ + if tag in self._pending_recv: + # Future exists, resolve it + fut = self._pending_recv[tag] + if not fut.done(): + fut.set_result(data) + else: + logger.warning( + f"Rank {self.rank_id}: received data for tag {tag} but future already done (timeout?)" + ) + if fut.done() and tag in self._pending_recv: + del self._pending_recv[tag] + else: + # Future does not exist yet (arrived early), store in cache + if tag in self._early_data: + logger.warning( + "Rank %d: overwriting early data for tag %s (previous not consumed?)", + self.rank_id, + tag, + ) + self._early_data[tag] = data + logger.debug( + "Rank %d: cached early data for tag %s (%d bytes)", + self.rank_id, + tag, + len(data), + ) + + +class CPRingServiceServicer: + """ + gRPC servicer for CP ring communication. + + Receives blocks from other ranks and routes them to the appropriate + CPRingCommunicator via resolve_recv. + """ + + def __init__(self) -> None: + """Initialize servicer with no attached communicator.""" + self._communicator: Optional[CPRingCommunicator] = None + + def attach_communicator(self, communicator: CPRingCommunicator) -> None: + """Attach a communicator to receive incoming blocks.""" + self._communicator = communicator + + async def SendBlock( + self, + request: object, + context: object, + ) -> object: + """ + Handle incoming block from another rank. + + Extracts the data and routes it to the communicator. + """ + from typing import cast + from dnet.protos import dnet_cp_pb2 + + # Cast to proper type + req = cast(dnet_cp_pb2.CPBlockFrame, request) + tag = req.nonce + + if not self._communicator: + return dnet_cp_pb2.CPBlockAck( + nonce=tag, accepted=False, error_message="No communicator attached" + ) + + # Extract data from the partial_output field + if req.HasField("partial_output"): + data = req.partial_output.output_data + else: + data = b"" + + # Route to communicator + self._communicator.resolve_recv(tag, data) + + logger.debug( + "CPRingServiceServicer: received %d bytes (tag=%s) from rank %d", + len(data), + tag, + req.source_rank, + ) + + return dnet_cp_pb2.CPBlockAck(nonce=tag, seq=req.seq, accepted=True) + + async def StreamBlocks( + self, + request_iterator: object, + context: object, + ) -> AsyncIterator[object]: + """ + Handle streaming blocks (for high-throughput scenarios). + """ + async for request in request_iterator: # type: ignore[attr-defined] + ack = await self.SendBlock(request, context) + yield ack + + +async def start_cp_ring_server( + port: int, communicator: CPRingCommunicator +) -> grpc.aio.Server: + """ + Start a gRPC server for CP ring communication. + + Args: + port: Port to listen on + communicator: CPRingCommunicator to receive incoming blocks + + Returns: + Running gRPC server + """ + from typing import cast, Any + + from dnet.protos import dnet_cp_pb2_grpc + + server = aio_grpc.server(options=GRPC_AIO_OPTIONS) + servicer = CPRingServiceServicer() + servicer.attach_communicator(communicator) + # Cast to Any to satisfy mypy - our servicer implements the protocol + dnet_cp_pb2_grpc.add_CPRingServiceServicer_to_server(cast(Any, servicer), server) + + server.add_insecure_port(f"[::]:{port}") + await server.start() + + logger.info( + "CP ring server started on port %d for rank %d", port, communicator.rank_id + ) + return server diff --git a/src/dnet/core/cp/sharding.py b/src/dnet/core/cp/sharding.py new file mode 100644 index 00000000..ae722a3d --- /dev/null +++ b/src/dnet/core/cp/sharding.py @@ -0,0 +1,143 @@ +"""Mode-aware sequence sharding for context parallelism. 
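+
+Note: in v1 both modes currently resolve to the same linear, contiguous
+split (see _shard_prefill); the 2N load-balanced scheme described below is
+the intended future behavior for prefill.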
+ +Provides utilities for partitioning sequences across CP ranks: +- Prefill: Load-balanced 2N sharding (first+last pairs) for causal attention +- Decode: Even N-way split for uniform KV lookup compute +""" + +from __future__ import annotations + +from typing import Literal + +import mlx.core as mx + + +def shard_for_mode( + tokens_or_kv: mx.array, + num_ranks: int, + rank_id: int, + mode: Literal["prefill", "decode"], +) -> tuple[mx.array, list[int]]: + """ + Mode-aware sharding for context parallelism. + + Args: + tokens_or_kv: Input tensor with sequence dimension at axis 0 + num_ranks: Total number of CP ranks + rank_id: This rank's ID (0 to num_ranks-1) + mode: "prefill" for load-balanced 2N sharding, "decode" for even splits + + Returns: + sharded: Portion of input assigned to this rank + indices: Original positions (for unsharding) + + Prefill sharding (2N load-balanced): + Sequence [C0, C1, C2, C3, C4, C5, C6, C7] with 4 ranks: + - Rank 0: [C0, C7] (first + last) + - Rank 1: [C1, C6] + - Rank 2: [C2, C5] + - Rank 3: [C3, C4] + + Decode sharding (even N-way): + Sequence split into N equal contiguous chunks. + """ + seq_len = tokens_or_kv.shape[0] + + if seq_len == 0: + return tokens_or_kv, [] + + if num_ranks <= 0: + raise ValueError(f"num_ranks must be positive, got {num_ranks}") + + if not 0 <= rank_id < num_ranks: + raise ValueError(f"rank_id {rank_id} out of range [0, {num_ranks})") + + if mode == "prefill": + return _shard_prefill(tokens_or_kv, num_ranks, rank_id, seq_len) + else: # decode + return _shard_decode(tokens_or_kv, num_ranks, rank_id, seq_len) + + +def _shard_prefill( + tokens_or_kv: mx.array, + num_ranks: int, + rank_id: int, + seq_len: int, +) -> tuple[mx.array, list[int]]: + """ + Linear sharding for prefill (temporarily replacing 2N for v1 simplicity). + Rank k gets [k*L, (k+1)*L]. This allows simple RoPE offset handling. + """ + return _shard_linear(tokens_or_kv, num_ranks, rank_id, seq_len) + + +def _shard_decode( + tokens_or_kv: mx.array, + num_ranks: int, + rank_id: int, + seq_len: int, +) -> tuple[mx.array, list[int]]: + """Even N-way split for uniform decode compute.""" + return _shard_linear(tokens_or_kv, num_ranks, rank_id, seq_len) + + +def _shard_linear( + tokens_or_kv: mx.array, + num_ranks: int, + rank_id: int, + seq_len: int, +) -> tuple[mx.array, list[int]]: + """Linear sharding implementation.""" + chunk_size = seq_len // num_ranks + remainder = seq_len % num_ranks + + # Distribute remainder across first 'remainder' ranks + start = rank_id * chunk_size + min(rank_id, remainder) + local_size = chunk_size + (1 if rank_id < remainder else 0) + end = start + local_size + + sharded = tokens_or_kv[start:end] + indices = list(range(start, end)) + + return sharded, indices + + +def unshard( + sharded_chunks: list[mx.array], + indices_per_rank: list[list[int]], + total_seq_len: int, +) -> mx.array: + """ + Reconstruct full sequence from sharded chunks. 
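+
+    Chunk lengths may differ by one when the sequence does not divide evenly;
+    e.g. _shard_linear splits seq_len=10 over 4 ranks into sizes [3, 3, 2, 2].
+    A quick round-trip sketch (illustrative):
+
+        >>> full = mx.arange(10)
+        >>> parts = [shard_for_mode(full, 4, r, "decode") for r in range(4)]
+        >>> chunks, idxs = [p[0] for p in parts], [p[1] for p in parts]
+        >>> restored = unshard(chunks, idxs, 10)  # equals full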
+
+    Args:
+        sharded_chunks: List of sharded tensors, one per rank
+        indices_per_rank: List of index lists from shard_for_mode
+        total_seq_len: Total sequence length
+
+    Returns:
+        Reconstructed tensor with original ordering
+    """
+    if not sharded_chunks:
+        raise ValueError("sharded_chunks cannot be empty")
+
+    # Get shape info from first chunk
+    sample = sharded_chunks[0]
+    rest_shape = sample.shape[1:]
+    dtype = sample.dtype
+
+    # Create output buffer
+    output = mx.zeros((total_seq_len,) + rest_shape, dtype=dtype)
+
+    # Scatter chunks back to original positions
+    # Note: Using .add() even though indices are disjoint because MLX ArrayAt
+    # doesn't have a .set() method. Since indices don't overlap, this is equivalent.
+    for chunk, indices in zip(sharded_chunks, indices_per_rank):
+        if len(indices) != chunk.shape[0]:
+            raise ValueError(
+                f"Chunk size {chunk.shape[0]} != indices length {len(indices)}"
+            )
+        for i, idx in enumerate(indices):
+            output = output.at[idx].add(chunk[i])
+
+    return output
diff --git a/src/dnet/core/models/base.py b/src/dnet/core/models/base.py
index 5597cbd0..41984f38 100644
--- a/src/dnet/core/models/base.py
+++ b/src/dnet/core/models/base.py
@@ -5,6 +5,7 @@
 import mlx.core as mx
 import mlx.nn as nn
+from dnet.utils.logger import logger


 class BaseRingModel(nn.Module, metaclass=ABCMeta):
@@ -16,6 +17,39 @@ class BaseRingModel(nn.Module, metaclass=ABCMeta):

     model_type: Optional[str] = None

+    # Context Parallel injection
+    cp_adapter: Optional[Any] = None
+
+    def set_cp_adapter(self, adapter: Any) -> None:
+        """Inject Context Parallel adapter and wrap attention layers."""
+        from .cp_layers import CPAttentionWrapper
+
+        self.cp_adapter = adapter
+        if not adapter or adapter.num_ranks <= 1:
+            return
+
+        # BaseRingModel does not define self.layers itself; concrete models
+        # expose it, so mirror load_weights and fall back via getattr.
+        layers = getattr(self, "layers", []) or []
+
+        logger.info(
+            "BaseRingModel: Injecting CPAttentionWrapper into %d layers",
+            len(layers),
+        )
+
+        # Wrap each hosted layer's attention module
+        for i, layer in enumerate(layers):
+            if hasattr(layer, "self_attn"):
+                # Avoid double-wrapping
+                if isinstance(layer.self_attn, CPAttentionWrapper):
+                    logger.debug("Layer %d already has CP adapter, skipping wrap", i)
+                    continue
+
+                # Wrap existing attention module
+                layer.self_attn = CPAttentionWrapper(layer.self_attn, adapter)
+
     @abstractmethod
     def embed(self, x: mx.array) -> mx.array:
         """Embed input tokens.
diff --git a/src/dnet/core/models/cp_layers.py b/src/dnet/core/models/cp_layers.py
new file mode 100644
index 00000000..93cb3cc6
--- /dev/null
+++ b/src/dnet/core/models/cp_layers.py
@@ -0,0 +1,227 @@
+"""
+Context Parallel wrapper layers.
+"""
+
+from typing import Optional, Any
+import mlx.core as mx
+import mlx.nn as nn
+from dnet.utils.logger import logger
+
+
+class CPAttentionWrapper(nn.Module):
+    """
+    Wraps a standard Attention module to enable Ring Attention.
+
+    Instead of computing local attention, it delegates to the CPAdapter
+    to perform distributed Ring Attention (pass-KV or pass-Q).
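+
+    Projection modules (q_proj, k_proj, v_proj, o_proj) and rope are exposed
+    as pass-through properties so existing weight loading and introspection
+    continue to work against the wrapped module.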
+ """ + + def __init__(self, base_attn: nn.Module, adapter: Any): + super().__init__() + self.base_attn = base_attn + self.adapter = adapter + + # Mirror attributes for compatibility + if hasattr(base_attn, "n_heads"): + self.n_heads = base_attn.n_heads + if hasattr(base_attn, "n_kv_heads"): + self.n_kv_heads = base_attn.n_kv_heads + if hasattr(base_attn, "head_dim"): + self.head_dim = base_attn.head_dim + if hasattr(base_attn, "scale"): + self.scale = base_attn.scale + + # Debug flag to log weight norms once + self._weight_logged = False + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + """ + Forward pass with Ring Attention injection. + """ + B, L, D = x.shape + + is_decode = L == 1 + + # 1. Local Projections using original weights + queries = self.base_attn.q_proj(x) + keys = self.base_attn.k_proj(x) + values = self.base_attn.v_proj(x) + + # 2. Reshape AND TRANSPOSE to [B, H, L, D] - MUST match mlx-lm order! + n_heads = self.base_attn.n_heads + n_kv_heads = self.base_attn.n_kv_heads + # head_dim may not be directly available on all model architectures (e.g., Qwen3) + # Fall back to computing from projection output shape + if hasattr(self.base_attn, "head_dim"): + head_dim = self.base_attn.head_dim + else: + # Compute from q_proj output: queries shape is [B, L, n_heads * head_dim] + head_dim = queries.shape[-1] // n_heads + + queries = queries.reshape(B, L, n_heads, head_dim).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, n_kv_heads, head_dim).transpose(0, 2, 1, 3) + values = values.reshape(B, L, n_kv_heads, head_dim).transpose(0, 2, 1, 3) + + # 3. RoPE - Applied to [B, H, L, D] format (AFTER transpose!) + offset = 0 + if cache is not None: + if hasattr(cache, "offset"): + offset = cache.offset + + # CP Override: Use global offset from adapter if available + if hasattr(self.adapter, "current_rope_offset"): + offset = self.adapter.current_rope_offset + + if hasattr(self.base_attn, "rope"): + queries = self.base_attn.rope(queries, offset=offset) + keys = self.base_attn.rope(keys, offset=offset) + + # 4. Ring Attention via Adapter + if B != 1: + logger.warning(f"CP Ring Attention received Batch Size {B} != 1. May fail.") + + # Squeeze batch and permute for ring attention: [B, H, L, D] -> [L, H, D] + # Transpose to [B, L, H, D] then squeeze + q_s = queries.transpose(0, 2, 1, 3).squeeze(0) # [L, H, D] + k_s = keys.transpose(0, 2, 1, 3).squeeze(0) # [L, H, D] + v_s = values.transpose(0, 2, 1, 3).squeeze(0) # [L, H, D] + + # Update Local KV Cache & Retrieve Full Sequence + k_all = k_s + v_all = v_s + + if cache is not None: + # Determine if this is decode (single token) vs prefill (multiple tokens) + is_decode = L == 1 + + # ALL ranks update cache during both prefill and decode. + # During decode, all ranks store the same decode token to keep caches balanced. + # The ring_reduce_attention handles deduplication during merge. + should_update_cache = True + + # 1. Handle MLX Cache Objects (Quantized or Standard) + if hasattr(cache, "update_and_fetch"): + if should_update_cache: + # MLX cache expects [B, H, L, D] format - keys are already in this format! 
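+                    # update_and_fetch appends along the sequence axis and
+                    # returns the full cached K/V accumulated so far.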
+ k_out, v_out = cache.update_and_fetch(keys, values) + else: + # Non-last rank during decode: just fetch without update + # For QuantizedKVCache, we need to access the state directly + if hasattr(cache, "state") and cache.state is not None: + k_out, v_out = cache.state + elif hasattr(cache, "keys") and hasattr(cache, "values"): + k_out, v_out = cache.keys, cache.values + else: + # Fallback: use only local K/V + k_out, v_out = keys, values + + # Check for quantization (tuple return) + if isinstance(k_out, tuple): + # Dequantize for Ring Attention computation + group_size = getattr(cache, "group_size", 64) + bits = getattr(cache, "bits", 4) + + k_full = mx.dequantize( + k_out[0], k_out[1], k_out[2], group_size, bits + ) + v_full = mx.dequantize( + v_out[0], v_out[1], v_out[2], group_size, bits + ) + else: + # Standard cache (already mx.array) + k_full = k_out + v_full = v_out + + # Transpose back to [B, L, H, D] and squeeze batch dim for ring attention + # k_full is [B, H, L, D] -> [B, L, H, D] -> squeeze -> [L, H, D] + k_all = mx.transpose(k_full, axes=(0, 2, 1, 3)).squeeze(0) + v_all = mx.transpose(v_full, axes=(0, 2, 1, 3)).squeeze(0) + + # Note: For decode on non-last rank, we do NOT include the new token + # in k_all/v_all. The new token should only contribute to attention + # from one shard (last rank) to avoid double-counting during merge. + + # 2. Handle Simple List Cache (e.g. [K, V]) + elif isinstance(cache, list): + if cache[0] is not None: + if should_update_cache: + # keys/values are [B, H, L, D], concatenate on axis=2 (sequence dim) + k_c = mx.concatenate([cache[0], keys], axis=2) + v_c = mx.concatenate([cache[1], values], axis=2) + cache[0] = k_c + cache[1] = v_c + else: + k_c = cache[0] + v_c = cache[1] + # Transpose to [B, L, H, D] then squeeze + k_all = k_c.transpose(0, 2, 1, 3).squeeze(0) + v_all = v_c.transpose(0, 2, 1, 3).squeeze(0) + # Note: For decode on non-last rank, we do NOT include the new token. + + else: + cache[0] = keys + cache[1] = values + k_all = k_s + v_all = v_s + + # Dispatch Logic + nonce = self.adapter.active_nonce + layer_id = self.adapter.current_layer_id + + # Use is_decode from earlier (L == 1) - don't redefine it! + + if is_decode: + # Ring Reduce (Pass-Q/Partial) + # Efficient for decode where Q is small and KV is distributed + + context_out = self.adapter.ring_reduce_attention_sync( + q_s, + k_all, + v_all, + rope=self.base_attn.rope, + nonce=nonce, + layer_id=layer_id, + ) + else: + # Ring Pass-KV + # Efficient for prefill where KV is sharded and we need All-to-All + # Note: For prefill, k_all == k_s (chunk) + context_out = self.adapter.ring_pass_kv_attention_sync( + q_s, + k_all, + v_all, + rope=self.base_attn.rope, + nonce=nonce, + layer_id=layer_id, + ) + + # 5. Output Projection + context_out = context_out[None, ...] 
# Restore B + output = self.base_attn.o_proj(context_out.reshape(B, L, -1)) + + return output + + @property + def q_proj(self): + return self.base_attn.q_proj + + @property + def k_proj(self): + return self.base_attn.k_proj + + @property + def v_proj(self): + return self.base_attn.v_proj + + @property + def o_proj(self): + return self.base_attn.o_proj + + @property + def rope(self): + return getattr(self.base_attn, "rope", None) diff --git a/src/dnet/core/types/messages.py b/src/dnet/core/types/messages.py index d8a54814..e36493f6 100644 --- a/src/dnet/core/types/messages.py +++ b/src/dnet/core/types/messages.py @@ -46,6 +46,8 @@ class ActivationMessage: repetition_penalty: float = 1.0 min_p: float = 0.0 min_tokens_to_keep: int = 1 + # CP RoPE offset + rope_offset: int = 0 @classmethod def from_proto(cls, proto_msg: ActivationRequest, pool_id: int = 0): @@ -74,6 +76,7 @@ def from_proto(cls, proto_msg: ActivationRequest, pool_id: int = 0): min_tokens_to_keep=proto_msg.min_tokens_to_keep if proto_msg.HasField("min_tokens_to_keep") else 1, + rope_offset=proto_msg.activation.rope_offset, ) def to_proto(self, data: bytes) -> ActivationRequest: @@ -86,6 +89,7 @@ def to_proto(self, data: bytes) -> ActivationRequest: shape=list(self.shape), layer_id=self.layer_id, dtype=self.dtype, + rope_offset=self.rope_offset, ), timestamp=self.timestamp, node_origin=self.node_origin, diff --git a/src/dnet/core/types/topology.py b/src/dnet/core/types/topology.py index e1d01e40..0e55a9f9 100644 --- a/src/dnet/core/types/topology.py +++ b/src/dnet/core/types/topology.py @@ -37,6 +37,9 @@ class TopologyInfo(BaseModel): ..., description="KV cache quantization used by solver and shards" ) num_layers: int = Field(..., description="Total number of layers in model") + max_position_embeddings: Optional[int] = Field( + default=None, description="Override model context length limit" + ) devices: List[DnetDeviceProperties] = Field( ..., description="Devices (in solver order)" ) diff --git a/src/dnet/protos/dnet_cp.proto b/src/dnet/protos/dnet_cp.proto new file mode 100644 index 00000000..4c6c20bb --- /dev/null +++ b/src/dnet/protos/dnet_cp.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; + +package dnetcp; + +// Context Parallelism ring communication service +// Handles KV/Q block transfers and partial attention output merging +service CPRingService { + // Bidirectional stream for efficient block transfer during ring attention + rpc StreamBlocks(stream CPBlockFrame) returns (stream CPBlockAck); + + // Unary RPC for single-shot block transfer (fallback/debug) + rpc SendBlock(CPBlockFrame) returns (CPBlockAck); +} + +// Configuration for CP distributed attention +message CPConfig { + int32 rank_id = 1; + int32 num_ranks = 2; + repeated string rank_addresses = 3; // Ordered ring: [rank0_addr, rank1_addr, ...] 
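+  // e.g. ["192.168.1.10:58081", "192.168.1.11:58081"] for a 2-rank ring (illustrative addresses)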
+ string algorithm = 4; // "pass_kv", "pass_q", "ring_reduce" +} + +// Frame for streaming KV or Q blocks between CP ranks +message CPBlockFrame { + string nonce = 1; // Request identifier + int32 source_rank = 2; // Sender rank ID + int32 layer_id = 3; // Transformer layer index + int32 step = 4; // Ring rotation step (0 to N-1) + + oneof payload { + KVBlock kv_block = 5; + QBlock q_block = 6; + PartialOutput partial_output = 7; // For ring reduction + } + + uint64 seq = 8; // Sequence number for ordering + int64 timestamp = 9; // Unix timestamp ms +} + +// Key-Value block for pass-KV algorithm +message KVBlock { + bytes key_data = 1; // Serialized key tensor + bytes value_data = 2; // Serialized value tensor + repeated int32 key_shape = 3; + repeated int32 value_shape = 4; + string dtype = 5; // "float16", "bfloat16", etc. + int32 k_start = 6; // Global starting position of this KV block +} + +// Query block for pass-Q algorithm +message QBlock { + bytes query_data = 1; // Serialized query tensor + repeated int32 shape = 2; + string dtype = 3; + repeated int32 token_indices = 4; // Original indices for unsharding +} + +// Partial attention output with merge statistics (for ring reduction) +message PartialOutput { + bytes output_data = 1; // Partial attention output + bytes max_scores = 2; // Max attention scores per position + bytes log_sum_exp = 3; // Log-sum-exp for stable merging + repeated int32 shape = 4; + string dtype = 5; +} + +// Acknowledgment for block transfer +message CPBlockAck { + string nonce = 1; + uint64 seq = 2; + bool accepted = 3; + string error_message = 4; // Non-empty if accepted=false +} diff --git a/src/dnet/protos/dnet_ring.proto b/src/dnet/protos/dnet_ring.proto index bf5cf402..8dba41d0 100644 --- a/src/dnet/protos/dnet_ring.proto +++ b/src/dnet/protos/dnet_ring.proto @@ -2,6 +2,8 @@ syntax = "proto3"; package dnetring; +import "dnet_cp.proto"; + // The service for running distributed inference over a ring service DnetRingService { // Send activation data to the next node in the ring @@ -26,6 +28,7 @@ message Activation { repeated int32 shape = 3; string dtype = 4; int32 layer_id = 5; + int32 rope_offset = 6; } message ActivationRequest { @@ -44,6 +47,9 @@ message ActivationRequest { optional float repetition_penalty = 11; optional float min_p = 12; optional int32 min_tokens_to_keep = 13; + + // Context parallelism configuration (if CP mode enabled) + optional dnetcp.CPConfig cp_config = 14; } // Response message for activation sending diff --git a/src/dnet/shard/adapters/base.py b/src/dnet/shard/adapters/base.py index c494b7b4..e82fe5d7 100644 --- a/src/dnet/shard/adapters/base.py +++ b/src/dnet/shard/adapters/base.py @@ -16,6 +16,7 @@ class TopologyAdapter(ABC): def __init__(self, runtime, discovery): self.runtime = runtime + self.runtime.adapter = self # Back-reference for policies to access adapter self.discovery = discovery self.running = False diff --git a/src/dnet/shard/adapters/context_parallel.py b/src/dnet/shard/adapters/context_parallel.py new file mode 100644 index 00000000..60d33681 --- /dev/null +++ b/src/dnet/shard/adapters/context_parallel.py @@ -0,0 +1,954 @@ +""" +Context Parallel Adapter: Implements ring attention for long-context inference. + +This adapter distributes the sequence dimension across multiple devices, +with each device holding part of the context. Uses ring communication +to pass KV or Q blocks between ranks during attention computation. 
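+
+Two ring algorithms are implemented: pass-KV rotates key/value blocks around
+the ring during prefill, while ring-reduce circulates partial attention
+outputs (with their log-sum-exp statistics) during decode.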
+""" + +from __future__ import annotations + +import asyncio +import queue +from typing import Optional, Callable, Awaitable, Dict +from contextvars import ContextVar +from urllib.parse import urlparse +from grpc import aio as aio_grpc + +import mlx.core as mx +from dnet_p2p import AsyncDnetP2P + +from dnet.core.cp.heuristics import CPAlgorithm, select_algorithm +from dnet.core.cp.ring_comm import CPRingCommunicator, RingNeighbors +from dnet.core.cp.merge_attention import ( + PartialAttentionOutput, + merge_two_partials, +) +from dnet.shard.adapters.base import TopologyAdapter +from dnet.shard.runtime import ShardRuntime +from dnet.shard.models import ShardLoadModelRequest +from dnet.utils.logger import logger +from dnet.utils.grpc_config import GRPC_AIO_OPTIONS +from dnet.utils.time import utc_epoch_now +from dnet.protos.dnet_ring_pb2 import ActivationRequest +from dnet.core.types.messages import ActivationMessage +from dnet.shard.codec import ActivationCodec +from dnet.protos import shard_api_comm_pb2, shard_api_comm_pb2_grpc, dnet_cp_pb2 +from dnet.utils.serialization import bytes_to_tensor + + +class CPAdapter(TopologyAdapter): + """ + Context Parallel adapter for shards. + + Implements ring attention where each rank holds a portion of the sequence. + Supports both pass-KV (prefill-optimized) and pass-Q with ring reduction + (decode-optimized) algorithms. + """ + + def __init__( + self, + runtime: ShardRuntime, + discovery: AsyncDnetP2P, + rank_id: int = 0, + num_ranks: int = 1, + ): + super().__init__(runtime, discovery) + self.rank_id = rank_id + self.num_ranks = num_ranks + + # Codec for activation serialization/deserialization + self.codec = ActivationCodec(runtime) + + # Ring communicator (initialized on configure_topology) + self.ring_comm: Optional[CPRingCommunicator] = None + + # Current algorithm selection + self._algorithm: CPAlgorithm = CPAlgorithm.SINGLE_DEVICE + + # Model config (set on configure) + self._num_q_heads: int = 32 + self._num_kv_heads: int = 8 + self._head_dim: int = 128 + + # API callback gRPC + self.api_channel: Optional[aio_grpc.Channel] = None + self.api_stub: Optional[shard_api_comm_pb2_grpc.ShardApiServiceStub] = None + self.api_address: Optional[str] = None + self.api_callback_address: Optional[str] = None + self._active_nonce: Optional[str] = None + + # Queues + self.queue_size = runtime.max_queue_size + self._ingress_q: asyncio.Queue[ActivationRequest] = asyncio.Queue( + maxsize=self.queue_size + ) + self._computed_q: asyncio.Queue[ActivationMessage] = asyncio.Queue( + maxsize=self.queue_size + ) + self._token_q: asyncio.Queue[ActivationMessage] = asyncio.Queue( + maxsize=self.queue_size + ) + + self._tasks: list[asyncio.Task] = [] + + # Operation counter for robust ring tags + self._attn_op_counter: int = 0 + self._active_nonce: ContextVar[Optional[str]] = ContextVar( + "active_nonce", default=None + ) + self._current_layer_id: ContextVar[int] = ContextVar("layer_id", default=-1) + self._current_rope_offset: ContextVar[int] = ContextVar( + "rope_offset", default=0 + ) + + # Store futures for pending ring operations + # key: (nonce, layer_idx, step_idx) -> Future + self._pending_ops: Dict[str, asyncio.Future] = {} + + # Persistent state for decode phase + self._local_k_start: Optional[int] = None + # Track prefill size per rank for decode-phase deduplication + # During decode, non-last ranks only use prefill tokens for attention + self._prefill_size: Optional[int] = None + + def set_active_context(self, nonce: str) -> None: + """ + Set the active 
request context. + """ + self._active_nonce.set(nonce) + self._attn_op_counter = 0 + + def reset_state(self) -> None: + """Reset adapter state (called on cache reset).""" + self._local_k_start = None + self._prefill_size = None + + def set_current_layer(self, layer_id: int) -> None: + """Set current layer ID for unique ring tags.""" + self._current_layer_id.set(layer_id) + + def set_current_rope_offset(self, offset: int) -> None: + """Set current RoPE offset for CP calculation.""" + self._current_rope_offset.set(offset) + + @property + def current_rope_offset(self) -> int: + return self._current_rope_offset.get() + + @property + def active_nonce(self) -> Optional[str]: + return self._active_nonce.get() + + @property + def current_layer_id(self) -> int: + return self._current_layer_id.get() + + @property + def ingress_q(self) -> asyncio.Queue[ActivationRequest]: + return self._ingress_q + + @property + def activation_computed_queue(self) -> asyncio.Queue[ActivationMessage]: + return self._computed_q + + @property + def activation_token_queue(self) -> asyncio.Queue[ActivationMessage]: + return self._token_q + + async def start(self) -> None: + """Start background workers.""" + self.running = True + self._loop = asyncio.get_running_loop() + self._tasks = [ + asyncio.create_task(self._ingress_worker()), + asyncio.create_task(self._egress_worker()), + asyncio.create_task(self._token_tx_worker()), + ] + logger.info( + "CPAdapter started: rank=%d/%d, algorithm=%s", + self.rank_id, + self.num_ranks, + self._algorithm, + ) + + def ring_pass_kv_attention_sync( + self, + query: mx.array, + key: mx.array, + value: mx.array, + rope: object = None, + nonce: Optional[str] = None, + layer_id: int = -1, + ) -> mx.array: + """ + Synchronous wrapper for ring attention, safe to call from compute threads. + Blocks until the async ring operation on the main loop completes. + """ + if not self.running or not hasattr(self, "_loop") or self._loop.is_closed(): + # Fallback to local if not running or loop closed + return self._compute_attention_output(query, key, value) + + # DEBUG: Log entry to ring sync + # logger.debug(f"CPAdapter: ring_pass_kv_attention_sync rank={self.rank_id}") + + # Safe to block because we are in ShardRuntime's compute thread, not the event loop. + + future = asyncio.run_coroutine_threadsafe( + self.ring_pass_kv_attention( + query, key, value, rope=rope, nonce=nonce, layer_id=layer_id + ), + self._loop, + ) + + try: + return future.result() + except Exception as e: + logger.error(f"CPAdapter: ring_pass_kv_attention failed: {e}") + raise + + def ring_reduce_attention_sync( + self, + query: mx.array, + key: mx.array, + value: mx.array, + rope: object = None, + nonce: Optional[str] = None, + layer_id: int = -1, + ) -> mx.array: + """ + Synchronous wrapper for ring reduce attention. 
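+
+        Blocks the calling compute thread until the async ring operation
+        completes on the main event loop; this is safe because compute runs
+        on ShardRuntime's worker thread, not on the loop itself.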
+        """
+        if not self.running or not hasattr(self, "_loop") or self._loop.is_closed():
+            return self._compute_attention_output(query, key, value)
+
+        future = asyncio.run_coroutine_threadsafe(
+            self.ring_reduce_attention(
+                query, key, value, rope=rope, nonce=nonce, layer_id=layer_id
+            ),
+            self._loop,
+        )
+
+        try:
+            return future.result()
+        except Exception as e:
+            logger.error(f"CPAdapter: ring_reduce_attention failed: {e}")
+            raise
+
+    async def ingress(self) -> None:
+        """Handle incoming activation requests."""
+        pass  # Handled by _ingress_worker
+
+    async def egress(self) -> None:
+        """Handle outgoing activations."""
+        pass  # Handled by _egress_worker
+
+    async def configure_topology(self, req: ShardLoadModelRequest) -> None:
+        """
+        Configure CP topology from load request.
+
+        Extracts CP-specific config (rank_id, num_ranks, neighbor addresses)
+        and initializes the ring communicator.
+        """
+        # Extract CP config using direct field access
+        self.rank_id = req.cp_rank_id
+        self.num_ranks = req.cp_num_ranks
+
+        # For CP mode with multiple ranks, force-load ALL layer weights before
+        # wrapping. This is critical because a previous PP run may have evicted
+        # or shrunk weights, and CPAttentionWrapper needs correct weights in
+        # place before the attention modules are wrapped.
+        if self.num_ranks > 1 and self.runtime.model:
+            logger.info(
+                "CPAdapter: Forcing full weight load for %d layers before injection",
+                len(self.runtime.assigned_layers),
+            )
+            try:
+                # Get the policy's weight cache and force-load all layers
+                if hasattr(self.runtime, "policy") and self.runtime.policy:
+                    policy = self.runtime.policy
+                    if hasattr(policy, "weight_cache") and policy.weight_cache:
+                        # Force load all assigned layers and bind to model
+                        all_weights = {}
+                        for layer_id in self.runtime.assigned_layers:
+                            w = policy.weight_cache.get_weight(layer_id, inc_ref=False)
+                            if w:
+                                all_weights.update(w)
+                        if all_weights:
+                            self.runtime.model.load_weights(
+                                list(all_weights.items()), strict=False
+                            )
+                            logger.info(
+                                "CPAdapter: Loaded %d weight tensors for CP mode",
+                                len(all_weights),
+                            )
+            except Exception as e:
+                logger.warning("CPAdapter: Failed to force-load weights: %s", e)
+
+        # Inject ourselves into the model
+        if self.runtime.model:
+            logger.info("CPAdapter: Injecting logic into model")
+            self.runtime.model.set_cp_adapter(self)
+
+        self.api_callback_address = req.api_callback_address
+
+        # Extract model attention config for algorithm selection
+        self._num_q_heads = req.num_q_heads
+        self._num_kv_heads = req.num_kv_heads
+        self._head_dim = req.head_dim
+
+        # Extract neighbor addresses for ring
+        rank_addresses = req.cp_rank_addresses
+        if self.num_ranks > 1 and len(rank_addresses) >= self.num_ranks:
+            prev_rank = (self.rank_id - 1) % self.num_ranks
+            next_rank = (self.rank_id + 1) % self.num_ranks
+            neighbors = RingNeighbors(
+                prev_address=rank_addresses[prev_rank],
+                next_address=rank_addresses[next_rank],
+            )
+            self.ring_comm = CPRingCommunicator(
+                rank_id=self.rank_id,
+                num_ranks=self.num_ranks,
+            )
+            await self.ring_comm.connect(neighbors)
+
+            # The CPRingServiceServicer is registered on the shard's existing
+            # gRPC server (see GrpcServer.start()), so no separate server is
+            # started here; the owning Shard attaches this communicator to
+            # that servicer via attach_communicator().
+
+            logger.info(
+                "CPAdapter: connected ring - rank %d, prev=%s, next=%s",
+                self.rank_id,
+                neighbors.prev_address,
+                neighbors.next_address,
+            )
+
+        logger.info(
+            "CPAdapter configured: rank=%d/%d",
+            self.rank_id,
+            self.num_ranks,
+        )
+
+    async def reset_topology(self) -> None:
+        """Reset topology configuration."""
+        if self.ring_comm:
+            await self.ring_comm.disconnect()
+        self.ring_comm = None
+        self.rank_id = 0
+        self.num_ranks = 1
+
+    async def shutdown(self) -> None:
+        """Shutdown the adapter."""
+        self.running = False
+        for t in self._tasks:
+            t.cancel()
+        if self._tasks:
+            await asyncio.gather(*self._tasks, return_exceptions=True)
+        self._tasks.clear()
+
+        if self.ring_comm:
+            await self.ring_comm.disconnect()
+
+        logger.info("CPAdapter: shutdown complete")
+
+    async def _ingress_worker(self) -> None:
+        """Process incoming activation requests with CP attention."""
+        loop = asyncio.get_running_loop()
+
+        while self.running:
+            try:
+                req = await self._ingress_q.get()
+            except asyncio.CancelledError:
+                break
+
+            try:
+                # Deserialize and push to runtime execution queue
+                activation_msg = await loop.run_in_executor(
+                    self.runtime.executor,
+                    self.codec.deserialize,
+                    req,
+                )
+                if activation_msg:
+                    await loop.run_in_executor(
+                        None,
+                        self.runtime.activation_recv_queue.put_nowait,
+                        activation_msg,
+                    )
+            except Exception as e:
+                logger.error("CPAdapter ingress error: %s", e)
+
+    async def _egress_worker(self) -> None:
+        """Forward computed activations."""
+        loop = asyncio.get_running_loop()
+        q = self.runtime.activation_send_queue
+
+        while self.running:
+            try:
+                # Read from runtime queue
+                msg = await loop.run_in_executor(
+                    self.runtime.executor,
+                    lambda: q.get(timeout=0.5),
+                )
+            except asyncio.CancelledError:
+                break
+            except (asyncio.QueueEmpty, queue.Empty):
+                continue
+            except Exception:
+                continue
+
+            # For CP, all outputs are final tokens (full replication),
+            # unless we support mixed pipeline+CP later.
+            if msg.is_final:
+                await self._token_q.put(msg)
+            else:
+                logger.warning("CPAdapter received non-final output, dropping")
+
+    async def _token_tx_worker(self) -> None:
+        """Send generated tokens back to API."""
+        while self.running:
+            try:
+                msg = await self._token_q.get()
+            except asyncio.CancelledError:
+                break
+            await self._send_token(msg)
+
+    async def _send_token(self, msg: ActivationMessage) -> None:
+        """
+        Final-hop delivery of a sampled token to the API.
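+
+        The callback address is taken from msg.callback_url (grpc://host:port)
+        when present, otherwise from the api_callback_address supplied at model
+        load; the gRPC channel is cached and rebuilt only when that address changes.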
+ """ + # Pick the callback address + cb = msg.callback_url or "" + addr: Optional[str] = None + + if cb: + parsed = urlparse(cb) + if parsed.scheme == "grpc" and parsed.netloc: + addr = parsed.netloc + else: + logger.error( + "Shard %s: invalid gRPC callback URL for token: %s", + self.runtime.shard_id, + cb, + ) + return + elif self.api_callback_address: + # Fallback to load_model-provided address: host:port + addr = self.api_callback_address + else: + logger.error( + "Shard %s: no callback URL for final token; nonce=%s", + self.runtime.shard_id, + msg.nonce, + ) + return + + try: + if (self.api_channel is None) or (addr != self.api_address): + # Close old channel if any + try: + if self.api_channel is not None: + await self.api_channel.close() + except Exception: + pass + + self.api_address = addr + self.api_channel = aio_grpc.insecure_channel( + addr, options=GRPC_AIO_OPTIONS + ) + self.api_stub = shard_api_comm_pb2_grpc.ShardApiServiceStub( + self.api_channel + ) + except Exception as e: + logger.error( + "Shard %s: failed to create API channel for %s: %s", + self.runtime.shard_id, + addr, + e, + ) + return + + # send token + try: + token_id = int(getattr(msg, "token_id", -1)) + logprob = float(getattr(msg, "logprob", 0.0)) + top_logprobs = getattr(msg, "top_logprobs", {}) or {} + + req = shard_api_comm_pb2.TokenRequest( + nonce=msg.nonce, + token_id=token_id, + timestamp=utc_epoch_now(), + logprob=logprob, + top_logprobs=top_logprobs, + ) + + if self.api_stub is None: + logger.error( + "Shard %s: API stub not available for nonce=%s token=%s", + self.runtime.shard_id, + msg.nonce, + token_id, + ) + return + + resp = await self.api_stub.SendToken(req, timeout=3.0) + + if resp is None or not resp.success: + logger.error( + "Shard %s: API SendToken failed for nonce=%s token=%s: %s", + self.runtime.shard_id, + msg.nonce, + token_id, + resp.message if resp else "no response", + ) + except Exception as e: + logger.exception( + "Shard %s: error sending token via gRPC for nonce=%s: %s", + self.runtime.shard_id, + msg.nonce, + e, + ) + + def select_algorithm_for_request( + self, + new_tokens: int, + cached_tokens: int, + batch_size: int, + ) -> CPAlgorithm: + """ + Select algorithm for current request based on heuristics. + + Updates self._algorithm and returns the selected algorithm. + """ + self._algorithm = select_algorithm( + new_tokens=new_tokens, + cached_tokens=cached_tokens, + batch_size=batch_size, + num_ranks=self.num_ranks, + num_q_heads=self._num_q_heads, + num_kv_heads=self._num_kv_heads, + context_parallel_enabled=(self.num_ranks > 1), + ) + return self._algorithm + + async def ring_pass_kv_attention( + self, + query: mx.array, + key: mx.array, + value: mx.array, + rope: object = None, + send_fn: Optional[Callable[[bytes, str], Awaitable[None]]] = None, + recv_fn: Optional[Callable[[str], Awaitable[bytes]]] = None, + nonce: Optional[str] = None, + layer_id: int = -1, + ) -> mx.array: + """ + Ring attention with KV rotation (pass-KV algorithm). + + Best for full prefill where KV is smaller than Q (GQA models). + + Algorithm: + 1. Compute local attention: Attn(Q_local, KV_local) + 2. For i in 1..N-1: + a. SendRecv: send KV to next, receive from prev + b. Compute partial attention with received KV + c. Accumulate partial outputs + 3. 
Merge all partial outputs using numerically stable merge + + Args: + query: Local query tensor [seq_len, num_heads, head_dim] + key: Local key tensor to rotate + value: Local value tensor to rotate + send_fn: Optional custom send function (for testing) + recv_fn: Optional custom recv function (for testing) + + Returns: + Merged attention output [seq_len, num_heads, head_dim] + """ + if self.num_ranks == 1 or self.ring_comm is None: + # Single device: standard attention + return self._compute_attention_output(query, key, value) + + # Query tokens are fixed in place for pass-KV. + # Global position is provided by the absolute rope_offset. + q_start = self.current_rope_offset + + # Local KV block starts at same position initially + if query.shape[0] > 1: + # Prefill: Force global offset based on rank, as runtime tracks local offset + q_start = self.rank_id * query.shape[0] + # Prefill: This is the start of the sequence for this shard + self._local_k_start = q_start + current_k_start = q_start + # Save prefill size for decode-phase deduplication + self._prefill_size = key.shape[0] + # Approximate total prefill length logic for RoPE splitting later + self._total_prefill_len = self._prefill_size * self.num_ranks + + else: + # Decode: Use the persisted start position of the KV cache + # q_start is the position of the NEW token, but KV cache starts at 0 (or previous start) + if self._local_k_start is None: + # Fallback if prefill wasn't run (unlikely but safe) + self._local_k_start = 0 + current_k_start = self._local_k_start + + # Compute local attention first + # Note: RoPE is already applied by CPAttentionWrapper before calling this function + running = self._compute_partial_attention( + query, key, value, q_start=q_start, k_start=current_k_start + ) + + current_k, current_v = key, value + + self._attn_op_counter += 1 + + # Determine tag base: prefer layer ID, fallback to op counter + tag_base = f"L{layer_id}" if layer_id >= 0 else f"op{self._attn_op_counter}" + current_op_id = f"{nonce}_{tag_base}" if nonce else tag_base + + for step in range(1, self.num_ranks): + # Serialize KV with its current global start position + kv_bytes = self._serialize_kv(current_k, current_v, current_k_start) + + # Ring send/recv + recv_bytes = await self.ring_comm.send_recv( + kv_bytes, + f"{current_op_id}_step{step}", + send_fn=send_fn, + recv_fn=recv_fn, + ) + + # Deserialize received KV and its global start position + current_k, current_v, current_k_start = self._deserialize_kv(recv_bytes) + + # Compute attention with received KV + # Skip if all queries are before all keys (would be fully masked by causal) + q_end = q_start + query.shape[0] - 1 # Last query position + k_start_pos = current_k_start # First key position + + if q_end < k_start_pos: + # All queries are before all keys - causal mask would block everything + # Skip this KV block to avoid numerical issues (LSE would be -inf) + continue + + partial = self._compute_partial_attention( + query, current_k, current_v, q_start=q_start, k_start=current_k_start + ) + + # Online merge: accumulate into running state immediately + running = merge_two_partials(running, partial) + + # Return merged normalized output directly + return running.output + + async def ring_reduce_attention( + self, + query: mx.array, + key: mx.array, + value: mx.array, + rope: object = None, + nonce: Optional[str] = None, + layer_id: int = -1, + ) -> mx.array: + """ + Ring reduction for decode (eliminates All2All). 
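+
+        The merge is exact rather than approximate: each partial carries its
+        log-sum-exp statistics, so combining partials reproduces softmax over
+        the full distributed KV (see merge_two_partials).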
+
+        Each rank computes partial attention with its local KV, then
+        progressively merges partials in a ring pattern.
+
+        Algorithm:
+        1. Compute local partial = Attn(Q_all, KV_local)
+        2. For step in 1..N-1:
+           a. Ring pass: send running state to next, recv from prev
+           b. Merge: running = merge(running, received)
+        3. All ranks have fully merged output (no All2All needed!)
+
+        Returns:
+            Fully merged attention output
+        """
+        if self.num_ranks == 1 or self.ring_comm is None:
+            return self._compute_attention_output(query, key, value)
+
+        # For decode: Q is the new token at position = total_kv_length.
+        # KV is sharded across ranks, each rank holding a portion.
+        # Since Q is always at the END, it can attend to ALL previous tokens,
+        # so we skip causal masking (all positions valid).
+
+        # DEDUPLICATION: All ranks store the same decode tokens, but we only
+        # want to count them once during merge. Non-last ranks use only their
+        # prefill portion for attention. Last rank uses full KV (prefill + decode).
+        is_last_rank = self.rank_id == self.num_ranks - 1
+
+        k_for_attn = key
+        v_for_attn = value
+
+        if not is_last_rank and self._prefill_size is not None:
+            # Slice to prefill-only portion to avoid double-counting decode tokens
+            prefill_size = self._prefill_size
+            if key.shape[0] > prefill_size:
+                k_for_attn = key[:prefill_size]
+                v_for_attn = value[:prefill_size]
+
+        # Compute local partial with no causal mask (decode Q > all K)
+        # Note: RoPE is already applied by CPAttentionWrapper before calling this function
+        running_output = self._compute_partial_attention(
+            query,
+            k_for_attn,
+            v_for_attn,
+            skip_causal_mask=True,  # Decode: Q always after K
+        )
+
+        for step in range(1, self.num_ranks):
+            # Serialize current running state
+            state_bytes = self._serialize_partial(running_output)
+
+            # Ring pass. Tags must be unique per (nonce, layer, step) so that
+            # concurrent layers and requests cannot collide.
+            tag_suffix = f"reduce_step_{step}"
+            if layer_id >= 0:
+                tag_suffix = f"L{layer_id}_{tag_suffix}"
+
+            if nonce:
+                tag = f"{nonce}_{tag_suffix}"
+            else:
+                tag = tag_suffix
+
+            recv_bytes = await self.ring_comm.send_recv(
+                state_bytes,
+                tag,
+            )
+
+            # Deserialize and merge
+            received_partial = self._deserialize_partial(recv_bytes)
+            running_output = merge_two_partials(running_output, received_partial)
+
+        # Return merged normalized output directly
+        return running_output.output
+
+    def _compute_partial_attention(
+        self,
+        query: mx.array,
+        key: mx.array,
+        value: mx.array,
+        q_start: int = 0,
+        k_start: int = 0,
+        skip_causal_mask: bool = False,
+    ) -> PartialAttentionOutput:
+        """
+        Compute attention with tracking of max scores and log-sum-exp.
+
+        This enables numerically stable merging of partial outputs.
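+
+        Merging two partials with outputs O1, O2 and statistics LSE1, LSE2 is
+        exact: with LSE = logaddexp(LSE1, LSE2), the combined output is
+        exp(LSE1 - LSE) * O1 + exp(LSE2 - LSE) * O2.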
+ + Args: + query: Query tensor [S_q, H, D] + key: Key tensor [S_kv, H, D] + value: Value tensor [S_kv, H, D] + q_start: Global starting position of query tokens (for causal mask) + k_start: Global starting position of key tokens (for causal mask) + """ + # Derive dimensions dynamically from tensors [S, H, D] + S_q = query.shape[0] + S_kv = key.shape[0] + H_q = query.shape[1] + H_kv = key.shape[1] + D = query.shape[2] + + if query.shape[0] == 0: + # Handle empty query (idle rank in CP ring) + # Return empty tensors with correct shapes for aggregation + return PartialAttentionOutput( + output=mx.zeros((0, H_q, D), dtype=query.dtype), + max_score=mx.zeros((0, H_q), dtype=query.dtype), + log_sum_exp=mx.zeros((0, H_q), dtype=query.dtype), + ) + + # Transpose to [Heads, Seq, Dim] for correct broadcasting + # We want to broadcast over Heads, not Sequence, because S_q != S_kv in Ring Attention + q_h = mx.transpose(query, axes=(1, 0, 2)) # [H_q, S_q, D] + k_h = mx.transpose(key, axes=(1, 0, 2)) # [H_kv, S_kv, D] + v_h = mx.transpose(value, axes=(1, 0, 2)) # [H_kv, S_kv, D] + + # Handle GQA: Repeat KV heads if fewer than Q heads + if H_kv < H_q: + n_rep = H_q // H_kv + if n_rep > 1: + # k_h: [H_kv, S, D] -> [H_kv, n_rep, S, D] -> [H_q, S, D] + k_h = mx.broadcast_to( + k_h[:, None], + (H_kv, n_rep, k_h.shape[1], k_h.shape[2]), + ) + k_h = k_h.reshape(H_q, k_h.shape[2], k_h.shape[3]) + + v_h = mx.broadcast_to( + v_h[:, None], + (H_kv, n_rep, v_h.shape[1], v_h.shape[2]), + ) + v_h = v_h.reshape(H_q, v_h.shape[2], v_h.shape[3]) + + # Scaled dot-product: QK^T / sqrt(d) -> [H, S_q, S_kv] + scale = 1.0 / (D**0.5) + # q_h: [H, S_q, D], k_h.T: [H, D, S_kv] -> matmul: [H, S_q, S_kv] + scores = mx.matmul(q_h, mx.transpose(k_h, axes=(0, 2, 1))) * scale + + # Apply causal mask if needed (skip for decode where Q is always after cached K) + if not skip_causal_mask: + # q can only attend to k where q_global_pos >= k_global_pos + q_positions = mx.arange(S_q) + q_start # [S_q] + k_positions = mx.arange(S_kv) + k_start # [S_kv] + # Create causal mask: [S_q, S_kv] where True = can attend + causal_mask = q_positions[:, None] >= k_positions[None, :] # [S_q, S_kv] + # Apply mask: where mask is False, set score to very negative value + # Note: -6e4 is safer than -1e9 for float16 + mask_value = mx.array(-6e4, dtype=scores.dtype) + scores = mx.where(causal_mask, scores, mask_value) + + # Cast to float32 for softmax computation to prevent exp() overflow + # Even with 200 tokens, attention scores can reach 35+, and exp(35) overflows float16 + original_dtype = scores.dtype + scores_f32 = scores.astype(mx.float32) + + # Max for numerical stability + max_score = mx.max(scores_f32, axis=-1, keepdims=False) # [H, S_q] + + # Softmax numerator: exp(scores - max) + exp_scores = mx.exp(scores_f32 - max_score[..., None]) + sum_exp = mx.sum(exp_scores, axis=-1, keepdims=False) # [H, S_q] + + # NORMALIZED output: softmax @ V (standard attention output) + attn_weights = exp_scores / sum_exp[..., None] # Softmax in float32 + # Cast weights back to original dtype for matmul with V + attn_weights = attn_weights.astype(original_dtype) + # attn_weights: [H, S_q, S_kv], v_h: [H, S_kv, D] -> output: [H, S_q, D] + output_h = mx.matmul(attn_weights, v_h) + + # Check for INF/NAN in output (Debugging) + if mx.isnan(output_h).any() or mx.isinf(output_h).any(): + import logging + + logger = logging.getLogger("dnet") + # Safe layer_id access + lid = getattr(self, "current_layer_id", -1) + logger.error( + f"CPAdapter: INF/NAN detected in attention 
output! layer={lid}"
+            )
+            # Also check inputs to see source
+            if mx.isinf(scores).any():
+                logger.error("  scores has INF")
+            if mx.isinf(sum_exp).any():
+                logger.error("  sum_exp has INF")
+
+        # Transpose back to [S_q, H, D]
+        output = mx.transpose(output_h, axes=(1, 0, 2))
+
+        # Transpose stats back to [S_q, H]
+        max_score = mx.transpose(max_score, axes=(1, 0))
+        sum_exp = mx.transpose(sum_exp, axes=(1, 0))
+
+        # Compute proper log-sum-exp: LSE = max + log(sum_exp)
+        # This is used for merging per Meta paper Eq (4)
+        lse = max_score + mx.log(sum_exp + 1e-10)  # Add epsilon to avoid log(0)
+
+        # Cast stats back to original dtype for serialization compatibility
+        max_score = max_score.astype(original_dtype)
+        lse = lse.astype(original_dtype)
+
+        return PartialAttentionOutput(
+            output=output,
+            max_score=max_score,
+            log_sum_exp=lse,  # Proper LSE for merge formula
+        )
+
+    def _compute_attention_output(
+        self,
+        query: mx.array,
+        key: mx.array,
+        value: mx.array,
+    ) -> mx.array:
+        """Standard single-device attention fallback (no ring communication)."""
+        # Delegate to the partial computation so head layout ([S, H, D] with
+        # heads as the batch axis), GQA expansion, and causal masking stay
+        # consistent with the distributed path. Decode (single query token)
+        # attends to the whole cache, so the causal mask is skipped there.
+        skip_mask = query.shape[0] == 1
+        return self._compute_partial_attention(
+            query, key, value, q_start=0, k_start=0, skip_causal_mask=skip_mask
+        ).output
+
+    def _serialize_kv(self, key: mx.array, value: mx.array, k_start: int = 0) -> bytes:
+        """Serialize KV tensors for ring transfer using Protobuf."""
+        # Force evaluation of MLX arrays before serialization to ensure
+        # the bytes representation is correct
+        mx.eval(key)
+        mx.eval(value)
+
+        block = dnet_cp_pb2.KVBlock(
+            key_data=bytes(memoryview(key)),
+            value_data=bytes(memoryview(value)),
+            key_shape=list(key.shape),
+            value_shape=list(value.shape),
+            dtype=str(key.dtype),
+            k_start=k_start,
+        )
+        return block.SerializeToString()
+
+    def _deserialize_kv(self, data: bytes) -> tuple[mx.array, mx.array, int]:
+        """Deserialize KV tensors from bytes using Protobuf."""
+        block = dnet_cp_pb2.KVBlock()
+        block.ParseFromString(data)
+
+        k = bytes_to_tensor(block.key_data, block.dtype).reshape(block.key_shape)
+        v = bytes_to_tensor(block.value_data, block.dtype).reshape(block.value_shape)
+
+        return k, v, block.k_start
+
+    def _serialize_partial(self, partial: PartialAttentionOutput) -> bytes:
+        """Serialize partial attention output for ring reduction using Protobuf."""
+        # Force evaluation of MLX arrays before serialization to ensure
+        # the bytes representation is correct
+        mx.eval(partial.output)
+        mx.eval(partial.max_score)
+        mx.eval(partial.log_sum_exp)
+
+        msg = dnet_cp_pb2.PartialOutput(
+            output_data=bytes(memoryview(partial.output)),
+            max_scores=bytes(memoryview(partial.max_score)),
+            log_sum_exp=bytes(memoryview(partial.log_sum_exp)),
+            shape=list(partial.output.shape),
+            dtype=str(partial.output.dtype),
+        )
+        return msg.SerializeToString()
+
+    def _deserialize_partial(self, data: bytes) -> PartialAttentionOutput:
+        """Deserialize partial attention output from bytes using Protobuf."""
+        msg = dnet_cp_pb2.PartialOutput()
+        msg.ParseFromString(data)
+
+        out = bytes_to_tensor(msg.output_data, msg.dtype).reshape(msg.shape)
+
+        # Recover stats shape (S_q, H) from output shape (S_q, H, D)
+        stat_shape = msg.shape[:2]
+        max_s = bytes_to_tensor(msg.max_scores, msg.dtype).reshape(stat_shape)
+        lse = bytes_to_tensor(msg.log_sum_exp, msg.dtype).reshape(stat_shape)
+
+        return PartialAttentionOutput(
+            output=out,
+            max_score=max_s,
+            log_sum_exp=lse,
+        )
diff --git a/src/dnet/shard/grpc_servicer/server.py b/src/dnet/shard/grpc_servicer/server.py
index a2bbb353..62968561 100644
--- a/src/dnet/shard/grpc_servicer/server.py +++ b/src/dnet/shard/grpc_servicer/server.py @@ -1,9 +1,12 @@ from .servicer import GrpcServicer from ..shard import Shard from dnet.protos.dnet_ring_pb2_grpc import add_DnetRingServiceServicer_to_server +from dnet.protos.dnet_cp_pb2_grpc import add_CPRingServiceServicer_to_server +from dnet.core.cp.ring_comm import CPRingServiceServicer from grpc import aio as aio_grpc -from typing import Optional +from typing import Optional, Any, cast from dnet.utils.logger import logger +from dnet.utils.grpc_config import GRPC_AIO_OPTIONS class GrpcServer: @@ -12,13 +15,19 @@ def __init__(self, grpc_port: int, shard: Shard): self.shard = shard self.server: Optional[aio_grpc.Server] = None self.servicer = GrpcServicer(self.shard) + self.cp_servicer: Optional[CPRingServiceServicer] = None async def start(self): """ Start gRPC server """ - self.server = aio_grpc.server() + self.server = aio_grpc.server(options=GRPC_AIO_OPTIONS) add_DnetRingServiceServicer_to_server(self.servicer, self.server) + + # Register CP ring service (for context parallelism block transfer) + self.cp_servicer = CPRingServiceServicer() + add_CPRingServiceServicer_to_server(cast(Any, self.cp_servicer), self.server) + listen_addr = f"[::]:{self.grpc_port}" self.server.add_insecure_port(listen_addr) try: diff --git a/src/dnet/shard/models.py b/src/dnet/shard/models.py index 3b8eed3c..2238d941 100644 --- a/src/dnet/shard/models.py +++ b/src/dnet/shard/models.py @@ -31,6 +31,33 @@ class ShardLoadModelRequest(BaseModel): description="API callback address for final layer completion (gRPC host:port)", ) + # Context Parallelism fields + cp_rank_id: int = Field( + default=0, description="This shard's rank ID for context parallelism" + ) + cp_num_ranks: int = Field( + default=1, description="Total number of CP ranks (1=single device mode)" + ) + cp_rank_addresses: List[str] = Field( + default_factory=list, + description="Ordered list of CP rank addresses (host:port) for ring communication", + ) + cp_algorithm: Literal["auto", "pass_kv", "pass_q", "ring_reduce"] = Field( + default="auto", description="CP algorithm selection" + ) + + # Model attention config (for CP algorithm selection) + num_q_heads: int = Field( + default=32, description="Number of query heads in the model" + ) + num_kv_heads: int = Field( + default=8, description="Number of KV heads (for GQA models)" + ) + head_dim: int = Field(default=128, description="Dimension per attention head") + max_position_embeddings: Optional[int] = Field( + default=None, description="Override model context length limit" + ) + class ShardLoadModelResponse(BaseModel): """Response from model loading operation on shard.""" diff --git a/src/dnet/shard/policies/fit_in_memory.py b/src/dnet/shard/policies/fit_in_memory.py index 2801c566..aa0c7025 100644 --- a/src/dnet/shard/policies/fit_in_memory.py +++ b/src/dnet/shard/policies/fit_in_memory.py @@ -50,6 +50,13 @@ def process(self, msg: ActivationMessage) -> None: # 1) per-nonce KV kv = self.runtime.get_or_make_kv(msg.nonce) + # Set CP/Ring context for unique tag generation + if hasattr(self.runtime.adapter, "set_active_context"): + self.runtime.adapter.set_active_context(msg.nonce) + + if hasattr(self.runtime.adapter, "set_current_rope_offset"): + self.runtime.adapter.set_current_rope_offset(msg.rope_offset) + # 2) get input tensor from pool input_buffer = self.runtime.input_pool.get_buffer(msg.pool_id) if input_buffer is None: @@ -100,6 +107,10 @@ def process(self, msg: ActivationMessage) -> None: except 
Exception: pass for lyr in window_layers: + # Set current layer on adapter for ring tags + if hasattr(self.runtime.adapter, "set_current_layer"): + self.runtime.adapter.set_current_layer(lyr) + with self.runtime._mlx_lock: x = self.runtime.model.apply_single_layer(lyr, x, cache=kv) try: @@ -133,9 +144,29 @@ def process(self, msg: ActivationMessage) -> None: # build output ActivationMessage if nxt >= self.runtime.model_metadata.num_layers: # end-shard sampling + + # CP multi-rank: Only the last rank holds the final token + # of the distributed sequence. Other ranks finish silently. + cp_num_ranks = getattr(self.runtime, "cp_num_ranks", 1) + cp_rank_id = getattr(self.runtime, "cp_rank_id", 0) + if cp_num_ranks > 1 and cp_rank_id != cp_num_ranks - 1: + # Not the last rank in CP - release resources and return + self.runtime.input_pool.release(msg.pool_id) + return + try: with self.runtime._mlx_lock: - y = self.runtime.model.normalize(x_cast) + # We only need the last token's logits for next-token prediction + # Slicing here drastically reduces memory usage (avoiding [B, S, V] projection) + # Handle both 3D [B, S, H] and 2D [S, H] tensors + if len(x_cast.shape) >= 3: + x_last = x_cast[:, -1:, :] + elif len(x_cast.shape) == 2: + x_last = x_cast[-1:, :] + else: + x_last = x_cast # 1D or scalar, use as-is + + y = self.runtime.model.normalize(x_last) y = self.runtime.model.lm_project(y) # Sampling diff --git a/src/dnet/shard/runtime.py b/src/dnet/shard/runtime.py index 890e72c9..c6296d29 100644 --- a/src/dnet/shard/runtime.py +++ b/src/dnet/shard/runtime.py @@ -58,6 +58,9 @@ class ShardRuntime: Topology-agnostic shard runtime. """ + # Back-reference to adapter (set by adapter on init) + adapter: Any = None + def __init__( self, shard_id, @@ -176,6 +179,19 @@ def load_model_core(self, req: ShardLoadModelRequest) -> None: self._assigned_sorted = sorted(self.assigned_layers) self._assigned_set = set(self._assigned_sorted) self.model_path = req.model_path + self.cp_rank_id = req.cp_rank_id + self.cp_num_ranks = req.cp_num_ranks + + if req.max_position_embeddings: + logger.info( + "Overriding max_position_embeddings to %s", req.max_position_embeddings + ) + # Override common config keys for context limit + self.model_metadata.model_config["max_position_embeddings"] = ( + req.max_position_embeddings + ) + self.model_metadata.model_config["seq_length"] = req.max_position_embeddings + self.model_metadata.model_config["n_ctx"] = req.max_position_embeddings local_count = max(1, len(self.assigned_layers)) requested_w = max(1, int(req.window_size)) @@ -350,6 +366,9 @@ def reset_cache(self): kv_bits=self.kv_cache_config.bits, kv_group=self.kv_cache_config.group_size, ) + # Notify adapter to reset its state (e.g., CPAdapter._local_k_start) + if self.adapter and hasattr(self.adapter, "reset_state"): + self.adapter.reset_state() logger.info("Node %s: Cache reset successfully", self.shard_id) except Exception as e: logger.error("Node %s: Error resetting cache: %s", self.shard_id, e) diff --git a/src/dnet/shard/shard.py b/src/dnet/shard/shard.py index 3f241897..e0ef59e4 100644 --- a/src/dnet/shard/shard.py +++ b/src/dnet/shard/shard.py @@ -10,12 +10,13 @@ """ import asyncio +from typing import Any, Optional + from .runtime import ShardRuntime from .adapters.base import TopologyAdapter from dnet.protos.dnet_ring_pb2 import ActivationRequest from .models import ShardLoadModelResponse, ShardUnloadModelResponse - from dnet.utils.repack import delete_repacked_layers @@ -24,6 +25,8 @@ def __init__(self, shard_id, adapter: 
TopologyAdapter): self.node_id = shard_id self.adapter = adapter self.runtime: ShardRuntime = adapter.runtime + # Optional reference to gRPC server (set by CLI) for CP servicer wiring + self.grpc_server: Optional[Any] = None async def start(self, loop: asyncio.AbstractEventLoop) -> None: self.runtime.attach_loop(loop) @@ -50,6 +53,46 @@ async def load_model(self, req) -> ShardLoadModelResponse: loop = asyncio.get_running_loop() await loop.run_in_executor(None, self.runtime.load_model_core, req) await self.adapter.configure_topology(req) + + # Wire CP ring_comm to gRPC servicer if using CPAdapter + from dnet.shard.adapters.context_parallel import CPAdapter + from dnet.utils.logger import logger + + if isinstance(self.adapter, CPAdapter): + logger.info( + "Shard.load_model: Adapter is CPAdapter. checking grpc_server..." + ) + if self.grpc_server: + logger.info( + "Shard.load_model: grpc_server is present. checking cp_servicer..." + ) + if ( + hasattr(self.grpc_server, "cp_servicer") + and self.grpc_server.cp_servicer + ): + logger.info( + "Shard.load_model: cp_servicer found. checking ring_comm..." + ) + if self.adapter.ring_comm: + logger.info( + "Shard.load_model: Attaching communicator to cp_servicer" + ) + self.grpc_server.cp_servicer.attach_communicator( + self.adapter.ring_comm + ) + else: + logger.warning("Shard.load_model: adapter.ring_comm is None!") + else: + logger.warning( + "Shard.load_model: cp_servicer missing on grpc_server!" + ) + else: + logger.warning("Shard.load_model: self.grpc_server is None!") + else: + logger.info( + f"Shard.load_model: Adapter is {type(self.adapter)}, not CPAdapter" + ) + return ShardLoadModelResponse( success=True, message="Model loaded successfully", diff --git a/src/dnet/utils/grpc_config.py b/src/dnet/utils/grpc_config.py index abe1b17b..39e0d27c 100644 --- a/src/dnet/utils/grpc_config.py +++ b/src/dnet/utils/grpc_config.py @@ -38,7 +38,7 @@ def get_grpc_options() -> list[tuple[str, int]]: ("grpc.keepalive_time_ms", s.keepalive_time_ms), ("grpc.keepalive_timeout_ms", s.keepalive_timeout_ms), ("grpc.keepalive_permit_without_calls", 0), - ("grpc.http2.min_time_between_pings_ms", 120000), + ("grpc.http2.min_time_between_pings_ms", 1200000), ("grpc.http2.max_pings_without_data", 0), ("grpc.http2.bdp_probe", 0), # disable BDP probe to reduce pinging # Avoid any interference from HTTP proxies for direct ring links diff --git a/tests/fakes/api.py b/tests/fakes/api.py index 5fbb13b2..446847de 100644 --- a/tests/fakes/api.py +++ b/tests/fakes/api.py @@ -266,6 +266,7 @@ def __init__(self, grpc_port: int = 12345): self.connected: tuple[str, int, str] | None = None self.calls: list = [] self.last: tuple | None = None + self.adapter = None # Not CPApiAdapter, so http_api uses connect_to_ring def resolve_request(self, *a, **k): self.last = (a, k) diff --git a/tests/fakes/runtime.py b/tests/fakes/runtime.py index 79838fc6..f4680472 100644 --- a/tests/fakes/runtime.py +++ b/tests/fakes/runtime.py @@ -111,6 +111,8 @@ def __init__(self, assigned_layers=None, num_layers: int = 4, shard_id: str = "S self._emitted: list = [] self._compute_busy = threading.Event() self._loop = None + self.cp_rank_id = 0 + self.cp_num_ranks = 1 def attach_loop(self, loop): self._loop = loop diff --git a/tests/integration/test_cp_single_system.py b/tests/integration/test_cp_single_system.py new file mode 100644 index 00000000..c30e244d --- /dev/null +++ b/tests/integration/test_cp_single_system.py @@ -0,0 +1,503 @@ +"""Integration tests for Context Parallelism. 
+ +These tests validate CP functionality end-to-end: +1. CP module integration tests (no mocks, real tensor operations) +2. Multi-rank simulation using actual ring communication +3. End-to-end server tests when servers are available + +Usage (module-level tests - no servers needed): + uv run pytest tests/integration/test_cp_single_system.py::TestCPModuleIntegration -v + +Usage (server tests - requires running servers): + uv run pytest tests/integration/test_cp_single_system.py::TestCPServerInference -v --start-servers +""" + +import logging +import os +import signal +import subprocess +import sys +import time +from typing import Generator + +import pytest +import requests +import mlx.core as mx + +from dnet.core.cp.sharding import shard_for_mode, unshard +from dnet.core.cp.merge_attention import ( + PartialAttentionOutput, + merge_partial_attention, + merge_two_partials, +) +from dnet.core.cp.heuristics import select_algorithm, CPAlgorithm +from dnet.core.cp.ring_comm import ( + CPRingCommunicator, + RingNeighbors, + start_cp_ring_server, +) +from dnet.shard.adapters.context_parallel import CPAdapter +from dnet.config import ContextParallelSettings, get_settings + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Server configuration +API_HTTP_PORT = 8080 +SHARD_HTTP_PORT = 8081 +BASE_URL = f"http://localhost:{API_HTTP_PORT}" + + +# ============================================================================= +# MODULE-LEVEL INTEGRATION TESTS (no servers, real computations) +# ============================================================================= + + +@pytest.mark.integration +class TestCPModuleIntegration: + """Test CP modules work together correctly with real tensor operations.""" + + def test_sharding_merge_roundtrip_prefill(self) -> None: + """Test full prefill sharding -> attention -> merge pipeline.""" + # Create realistic input tensors + batch_size = 2 + seq_len = 256 + num_heads = 8 + head_dim = 64 + num_ranks = 4 + + # Input sequence + x = mx.random.normal((seq_len, batch_size, num_heads * head_dim)) + mx.eval(x) # Force evaluation + + # Shard across ranks + shards = [] + indices_list = [] + for rank in range(num_ranks): + shard_data, indices = shard_for_mode(x, num_ranks, rank, mode="prefill") + mx.eval(shard_data) + shards.append(shard_data) + indices_list.append(indices) + + # Unshard and verify roundtrip + reconstructed = unshard(shards, indices_list, seq_len) + mx.eval(reconstructed) + + # Verify exact reconstruction + assert reconstructed.shape == x.shape + diff = mx.abs(reconstructed - x) + max_diff = float(mx.max(diff).item()) + assert max_diff < 1e-6, f"Roundtrip error: {max_diff}" + + def test_sharding_merge_roundtrip_decode(self) -> None: + """Test full decode sharding -> attention -> merge pipeline.""" + seq_len = 1024 + hidden_dim = 512 + num_ranks = 4 + + x = mx.random.normal((seq_len, hidden_dim)) + mx.eval(x) + + shards = [] + indices_list = [] + for rank in range(num_ranks): + shard_data, indices = shard_for_mode(x, num_ranks, rank, mode="decode") + mx.eval(shard_data) + shards.append(shard_data) + indices_list.append(indices) + + reconstructed = unshard(shards, indices_list, seq_len) + mx.eval(reconstructed) + + assert reconstructed.shape == x.shape + diff = mx.abs(reconstructed - x) + max_diff = float(mx.max(diff).item()) + assert max_diff < 1e-6, f"Roundtrip error: {max_diff}" + + def test_partial_attention_merge_numerical_stability(self) -> None: + """Test that merging partial attention outputs is numerically stable.""" 
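+        # The merge under test combines partials via log-sum-exp weighting.
+        # Two-way recurrence, sketched with illustrative names (see
+        # merge_two_partials for the actual implementation):
+        #   lse_new = log(exp(lse_1) + exp(lse_2))    # evaluated via the running max
+        #   out_new = exp(lse_1 - lse_new) * out_1 + exp(lse_2 - lse_new) * out_2
+        # The scales below span four orders of magnitude, so a merge without
+        # max-shifting would overflow or underflow exp().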
+ batch_size = 2 + seq_len = 64 + num_heads = 4 + head_dim = 32 + + # Create partial outputs with varying scales (tests numerical stability) + partials = [] + for i in range(4): + # Use different scales to stress-test the merge + scale = 10.0 ** (i - 2) # 0.01, 0.1, 1.0, 10.0 + output = ( + mx.random.normal((batch_size, seq_len, num_heads, head_dim)) * scale + ) + max_score = ( + mx.random.normal((batch_size, seq_len, num_heads)) + i * 2 + ) # Varying max scores + log_sum_exp = ( + mx.abs(mx.random.normal((batch_size, seq_len, num_heads))) + 0.1 + ) + + mx.eval(output, max_score, log_sum_exp) + partials.append( + PartialAttentionOutput( + output=output, + max_score=max_score, + log_sum_exp=log_sum_exp, + ) + ) + + # Merge should produce finite results + merged = merge_partial_attention(partials) + mx.eval(merged) + + assert merged.shape == (batch_size, seq_len, num_heads, head_dim) + assert mx.all(mx.isfinite(merged)).item(), "Merged output contains NaN/Inf" + + def test_pairwise_merge_associativity(self) -> None: + """Test that pairwise merging produces same result regardless of order.""" + batch_size = 1 + seq_len = 32 + num_heads = 2 + head_dim = 16 + + def make_partial(): + return PartialAttentionOutput( + output=mx.random.normal((batch_size, seq_len, num_heads, head_dim)), + max_score=mx.random.normal((batch_size, seq_len, num_heads)), + log_sum_exp=mx.abs(mx.random.normal((batch_size, seq_len, num_heads))) + + 0.1, + ) + + p1, p2, p3 = make_partial(), make_partial(), make_partial() + mx.eval(p1.output, p2.output, p3.output) + + # Merge (p1, p2), then p3 + m12 = merge_two_partials(p1, p2) + result_12_3 = merge_two_partials(m12, p3) + mx.eval(result_12_3.output) + + # Merge p1, then (p2, p3) + m23 = merge_two_partials(p2, p3) + result_1_23 = merge_two_partials(p1, m23) + mx.eval(result_1_23.output) + + # Results should be close (floating point tolerance) + diff = mx.abs(result_12_3.output - result_1_23.output) + max_diff = float(mx.max(diff).item()) + assert max_diff < 1e-4, f"Merge order affects result: {max_diff}" + + def test_algorithm_selection_consistency(self) -> None: + """Test algorithm selection produces consistent results for same inputs.""" + settings = ContextParallelSettings() + + test_cases = [ + # (new_tokens, cached_tokens, expected_algorithm) + (100, 0, CPAlgorithm.SINGLE_DEVICE), # Short context + (65536, 0, CPAlgorithm.PASS_KV), # Long prefill + (1, 65536, CPAlgorithm.RING_REDUCE), # Decode mode + (1024, 60000, CPAlgorithm.PASS_Q), # Partial prefill + ] + + for new_tokens, cached_tokens, expected in test_cases: + result = select_algorithm( + new_tokens=new_tokens, + cached_tokens=cached_tokens, + batch_size=1, + num_ranks=4, + num_q_heads=32, + num_kv_heads=8, + context_parallel_enabled=True, + min_context_for_cp=settings.min_context_for_cp, + ) + assert result == expected, ( + f"Expected {expected} for ({new_tokens}, {cached_tokens}), got {result}" + ) + + +@pytest.mark.integration +class TestCPRingCommunication: + """Test ring communication with actual async operations.""" + + def test_ring_full_rotation_4_ranks(self) -> None: + """Test that data correctly rotates through all ranks in the ring. + + This test starts 4 real gRPC servers and has each rank send/recv data, + verifying that after N-1 rotations, each rank has seen all other ranks' data. 
+ """ + import asyncio + + async def run_test(): + num_ranks = 4 + base_port = 59100 + + # Create communicators for each rank + comms = [] + for rank_id in range(num_ranks): + prev_rank = (rank_id - 1) % num_ranks + next_rank = (rank_id + 1) % num_ranks + comm = CPRingCommunicator(rank_id=rank_id, num_ranks=num_ranks) + comms.append(comm) + + # Start gRPC servers for each rank + servers = [] + for rank_id in range(num_ranks): + server = await start_cp_ring_server( + port=base_port + rank_id, + communicator=comms[rank_id], + ) + servers.append(server) + + # Connect communicators to neighbors + for rank_id in range(num_ranks): + prev_rank = (rank_id - 1) % num_ranks + next_rank = (rank_id + 1) % num_ranks + neighbors = RingNeighbors( + prev_address=f"localhost:{base_port + prev_rank}", + next_address=f"localhost:{base_port + next_rank}", + ) + await comms[rank_id].connect(neighbors) + + try: + # Each rank starts with unique data + initial_data = [f"rank_{i}_data".encode() for i in range(num_ranks)] + + # Track what each rank sees over N-1 rotations + all_seen: list[list[bytes]] = [[] for _ in range(num_ranks)] + + current_data = initial_data.copy() + + for step in range(num_ranks - 1): + # All ranks send/recv simultaneously + results = await asyncio.gather( + *[ + comms[i].send_recv(current_data[i], f"step_{step}") + for i in range(num_ranks) + ] + ) + + # Update current data and track what we received + for i in range(num_ranks): + all_seen[i].append(results[i]) + current_data[i] = results[i] + + # After N-1 rotations, each rank should have seen all other ranks' data + for rank_id in range(num_ranks): + seen_set = set(all_seen[rank_id]) + # Should have received from all ranks except self + expected_others = { + d for i, d in enumerate(initial_data) if i != rank_id + } + assert seen_set == expected_others, ( + f"Rank {rank_id} missing data: {expected_others - seen_set}" + ) + finally: + # Cleanup: disconnect and stop servers + for comm in comms: + await comm.disconnect() + for server in servers: + await server.stop(grace=0.1) + + asyncio.run(run_test()) + + def test_ring_communicator_initialization(self) -> None: + """Test CPRingCommunicator initializes correctly.""" + comm = CPRingCommunicator(rank_id=2, num_ranks=4) + + assert comm.rank_id == 2 + assert comm.num_ranks == 4 + assert comm.prev_rank == 1 + assert comm.next_rank == 3 + + def test_ring_communicator_edge_cases(self) -> None: + """Test ring communicator with edge case configurations.""" + # Single rank should work + single = CPRingCommunicator(rank_id=0, num_ranks=1) + assert single.prev_rank == 0 + assert single.next_rank == 0 + + # First rank wraps to last + first = CPRingCommunicator(rank_id=0, num_ranks=4) + assert first.prev_rank == 3 + + # Last rank wraps to first + last = CPRingCommunicator(rank_id=3, num_ranks=4) + assert last.next_rank == 0 + + +@pytest.mark.integration +class TestCPAdapterIntegration: + """Test CPAdapter without mocking - actual algorithm and selection logic.""" + + def test_adapter_full_lifecycle(self) -> None: + """Test adapter initialization, algorithm selection, and reset.""" + import asyncio + + class MockRuntime: + max_queue_size = 16 + + adapter = CPAdapter( + runtime=MockRuntime(), # type: ignore + discovery=None, # type: ignore + rank_id=1, + num_ranks=4, + ) + + assert adapter.rank_id == 1 + assert adapter.num_ranks == 4 + assert adapter._algorithm == CPAlgorithm.SINGLE_DEVICE + + # Test algorithm selection for different scenarios + algo = adapter.select_algorithm_for_request( + 
new_tokens=65536, cached_tokens=0, batch_size=1 + ) + assert algo == CPAlgorithm.PASS_KV + assert adapter._algorithm == CPAlgorithm.PASS_KV + + algo = adapter.select_algorithm_for_request( + new_tokens=1, cached_tokens=65536, batch_size=1 + ) + assert algo == CPAlgorithm.RING_REDUCE + + # Test reset + asyncio.run(adapter.reset_topology()) + assert adapter.rank_id == 0 + assert adapter.num_ranks == 1 + + +@pytest.mark.integration +class TestCPConfiguration: + """Test CP configuration loading and validation.""" + + def test_settings_defaults(self, monkeypatch) -> None: + """Test default CP settings without environment overrides.""" + # Clear any env vars that would override defaults + monkeypatch.delenv("DNET_CP_ENABLED", raising=False) + monkeypatch.delenv("DNET_CP_ALGORITHM", raising=False) + + settings = ContextParallelSettings() + + assert settings.enabled is False + assert settings.algorithm == "auto" + assert settings.min_context_for_cp == 32768 + assert settings.min_tokens_for_pass_kv == 256 + assert settings.chunk_overlap == 0 + + def test_settings_accessible_from_dnet_settings(self) -> None: + """Test CP settings are integrated into main DnetSettings.""" + all_settings = get_settings() + cp_settings = all_settings.context_parallel + + # Verify CP settings are loaded and accessible + _ = cp_settings.enabled + _ = cp_settings.algorithm + _ = cp_settings.min_context_for_cp + _ = cp_settings.min_tokens_for_pass_kv + _ = cp_settings.chunk_overlap + + +# ============================================================================= +# SERVER-LEVEL INTEGRATION TESTS (requires running servers) +# ============================================================================= + + +def wait_for_health(url: str, timeout: float = 60) -> bool: + """Wait for server health endpoint to respond.""" + deadline = time.time() + timeout + while time.time() < deadline: + try: + resp = requests.get(f"{url}/health", timeout=2) + if resp.status_code == 200: + return True + except requests.RequestException: + pass + time.sleep(1) + return False + + +@pytest.fixture(scope="module") +def servers(start_servers_flag) -> Generator[None, None, None]: + """Start servers with CP enabled if --start-servers flag is set.""" + procs: list[subprocess.Popen] = [] + + if start_servers_flag: + env = {**os.environ, "PYTHONPATH": "src", "DNET_CP_ENABLED": "true"} + + shard_proc = subprocess.Popen( + [sys.executable, "-m", "cli.shard", "--http-port", str(SHARD_HTTP_PORT)], + cwd=os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + env=env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + procs.append(shard_proc) + + if not wait_for_health(f"http://localhost:{SHARD_HTTP_PORT}", timeout=30): + for p in procs: + p.kill() + pytest.skip("Shard server not healthy") + + api_proc = subprocess.Popen( + [sys.executable, "-m", "cli.api", "--http-port", str(API_HTTP_PORT)], + cwd=os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + env=env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + procs.append(api_proc) + + if not wait_for_health(BASE_URL): + for p in procs: + p.kill() + pytest.skip(f"API server not healthy at {BASE_URL}") + + yield + + for p in procs: + p.send_signal(signal.SIGTERM) + try: + p.wait(timeout=10) + except subprocess.TimeoutExpired: + p.kill() + + +@pytest.mark.integration +class TestCPServerInference: + """Server-level tests - only run when servers are available.""" + + def test_server_health(self, servers) -> None: + """Verify servers are running with CP config.""" + 
resp = requests.get(f"{BASE_URL}/health", timeout=5) + assert resp.status_code == 200 + + def test_inference_with_cp_enabled(self, servers) -> None: + """Test inference with CP-enabled server.""" + model_id = "Qwen/Qwen2.5-0.5B-Instruct" + + # Prepare and load + resp = requests.post( + f"{BASE_URL}/v1/prepare_topology", json={"model": model_id}, timeout=300 + ) + resp.raise_for_status() + + resp = requests.post( + f"{BASE_URL}/v1/load_model", json={"model": model_id}, timeout=300 + ) + resp.raise_for_status() + + try: + # Inference + resp = requests.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "model": model_id, + "messages": [{"role": "user", "content": "Say hello."}], + "max_tokens": 10, + }, + timeout=120, + ) + resp.raise_for_status() + result = resp.json() + + assert "choices" in result + assert len(result["choices"]) > 0 + finally: + requests.post(f"{BASE_URL}/v1/unload_model", timeout=30) diff --git a/tests/subsystems/test_cp_heuristics.py b/tests/subsystems/test_cp_heuristics.py new file mode 100644 index 00000000..3559fbf4 --- /dev/null +++ b/tests/subsystems/test_cp_heuristics.py @@ -0,0 +1,213 @@ +"""Tests for context parallelism algorithm selection heuristics.""" + +from __future__ import annotations + + +from dnet.core.cp.heuristics import ( + CPAlgorithm, + select_algorithm, + estimate_algorithm_latency, +) + + +class TestSelectAlgorithm: + """Tests for the greedy heuristic algorithm selection.""" + + def test_cp_disabled(self): + """Should return SINGLE_DEVICE when CP is disabled.""" + result = select_algorithm( + new_tokens=10000, + cached_tokens=50000, + batch_size=1, + num_ranks=4, + num_q_heads=32, + num_kv_heads=8, + context_parallel_enabled=False, + ) + assert result == CPAlgorithm.SINGLE_DEVICE + + def test_small_context(self): + """Should return SINGLE_DEVICE for small contexts.""" + result = select_algorithm( + new_tokens=1000, + cached_tokens=2000, + batch_size=1, + num_ranks=4, + num_q_heads=32, + num_kv_heads=8, + context_parallel_enabled=True, + min_context_for_cp=32768, + ) + assert result == CPAlgorithm.SINGLE_DEVICE + + def test_single_rank(self): + """Single rank should return SINGLE_DEVICE.""" + result = select_algorithm( + new_tokens=10000, + cached_tokens=50000, + batch_size=1, + num_ranks=1, + num_q_heads=32, + num_kv_heads=8, + context_parallel_enabled=True, + ) + assert result == CPAlgorithm.SINGLE_DEVICE + + def test_decode_mode(self): + """Decode (new_tokens <= batch_size) should use RING_REDUCE.""" + result = select_algorithm( + new_tokens=4, # 4 tokens for batch of 4 -> decode + cached_tokens=100000, + batch_size=4, + num_ranks=4, + num_q_heads=32, + num_kv_heads=8, + context_parallel_enabled=True, + min_context_for_cp=32768, + ) + assert result == CPAlgorithm.RING_REDUCE + + def test_full_prefill(self): + """Full prefill with sufficient tokens should use PASS_KV.""" + result = select_algorithm( + new_tokens=50000, # Full prefill, no cache + cached_tokens=0, + batch_size=1, + num_ranks=4, + num_q_heads=32, + num_kv_heads=8, + context_parallel_enabled=True, + min_context_for_cp=32768, + ) + assert result == CPAlgorithm.PASS_KV + + def test_high_cache_hit(self): + """High cache hit rate (low miss rate) should use PASS_Q.""" + # miss_rate = 100 / (100 + 100000) ≈ 0.001 < 0.125 + result = select_algorithm( + new_tokens=100, # Very few new tokens + cached_tokens=100000, # Large cache + batch_size=1, + num_ranks=4, + num_q_heads=32, + num_kv_heads=8, + context_parallel_enabled=True, + min_context_for_cp=32768, + ) + assert result == 
CPAlgorithm.PASS_Q
+
+    def test_gqa_threshold_calculation(self):
+        """GQA threshold should be computed correctly."""
+        # With 128 Q heads and 8 KV heads: threshold = 2*8/128 = 0.125
+        # (The PASS_Q case, miss_rate = 5000 / 50000 = 0.1 < 0.125, was removed from coverage.)
+
+        # miss_rate = 10000 / 50000 = 0.2 > 0.125 -> PASS_KV
+        result = select_algorithm(
+            new_tokens=10000,
+            cached_tokens=40000,
+            batch_size=1,
+            num_ranks=4,
+            num_q_heads=128,
+            num_kv_heads=8,
+            context_parallel_enabled=True,
+            min_context_for_cp=32768,
+        )
+        assert result == CPAlgorithm.PASS_KV
+
+    def test_custom_thresholds(self):
+        """Custom thresholds should override defaults."""
+        result = select_algorithm(
+            new_tokens=5000,
+            cached_tokens=5000,  # 10K total, would normally skip CP
+            batch_size=1,
+            num_ranks=4,
+            num_q_heads=32,
+            num_kv_heads=8,
+            context_parallel_enabled=True,
+            min_context_for_cp=8000,  # Lower threshold
+        )
+        # Should now consider CP since 10K > 8K
+        assert result in (CPAlgorithm.PASS_KV, CPAlgorithm.PASS_Q)
+
+
+class TestEstimateAlgorithmLatency:
+    """Tests for latency estimation (for solver integration)."""
+
+    def test_single_device_latency(self):
+        """Single device should have straightforward compute latency."""
+        latency = estimate_algorithm_latency(
+            algorithm=CPAlgorithm.SINGLE_DEVICE,
+            new_tokens=1000,
+            cached_tokens=50000,
+            num_ranks=4,
+            num_q_heads=32,
+            num_kv_heads=8,
+            head_dim=128,
+            flops_per_sec=1e12,  # 1 TFLOPS
+            bandwidth_bytes_per_sec=100e9,  # 100 GB/s
+        )
+        # Should be positive and finite
+        assert latency > 0
+        assert latency < float("inf")
+
+    def test_pass_kv_vs_single_device(self):
+        """PASS_KV with more ranks should be faster than single device."""
+        common_args = dict(
+            new_tokens=10000,
+            cached_tokens=50000,
+            num_q_heads=32,
+            num_kv_heads=8,
+            head_dim=128,
+            flops_per_sec=1e12,
+            bandwidth_bytes_per_sec=100e9,
+        )
+
+        single_latency = estimate_algorithm_latency(
+            algorithm=CPAlgorithm.SINGLE_DEVICE, num_ranks=1, **common_args
+        )
+        pass_kv_latency = estimate_algorithm_latency(
+            algorithm=CPAlgorithm.PASS_KV, num_ranks=4, **common_args
+        )
+
+        # With ideal scaling, 4 ranks should be ~4x faster
+        # In practice, communication overhead reduces this
+        assert pass_kv_latency < single_latency
+
+    def test_ring_reduce_vs_pass_q(self):
+        """RING_REDUCE should avoid All2All overhead."""
+        common_args = dict(
+            new_tokens=4,  # Decode-like
+            cached_tokens=100000,
+            num_ranks=4,
+            num_q_heads=32,
+            num_kv_heads=8,
+            head_dim=128,
+            flops_per_sec=1e12,
+            bandwidth_bytes_per_sec=100e9,
+        )
+
+        pass_q_latency = estimate_algorithm_latency(
+            algorithm=CPAlgorithm.PASS_Q, **common_args
+        )
+        ring_reduce_latency = estimate_algorithm_latency(
+            algorithm=CPAlgorithm.RING_REDUCE, **common_args
+        )
+
+        # Ring reduce should be faster (no All2All)
+        assert ring_reduce_latency <= pass_q_latency
+
+
+class TestCPAlgorithmEnum:
+    """Tests for CPAlgorithm enum."""
+
+    def test_string_values(self):
+        """Enum values should be lowercase strings."""
+        assert CPAlgorithm.SINGLE_DEVICE == "single_device"
+        assert CPAlgorithm.PASS_KV == "pass_kv"
+        assert CPAlgorithm.PASS_Q == "pass_q"
+        assert CPAlgorithm.RING_REDUCE == "ring_reduce"
+
+    def test_is_str_enum(self):
+        """Should be usable as strings."""
+        algo = CPAlgorithm.PASS_KV
+        assert f"Using {algo}" == "Using pass_kv"
diff --git a/tests/subsystems/test_cp_merge.py b/tests/subsystems/test_cp_merge.py
new file mode 100644
index 00000000..7284d3c8
--- /dev/null
+++ b/tests/subsystems/test_cp_merge.py
@@ -0,0 +1,195 @@
+"""Tests
for context parallelism merge attention operator.""" + +from __future__ import annotations + +import pytest +import mlx.core as mx + +from dnet.core.cp.merge_attention import ( + PartialAttentionOutput, + merge_partial_attention, + merge_two_partials, +) + + +def make_partial( + seq_len: int, + num_heads: int, + head_dim: int, + max_score_val: float = 0.0, + lse_val: float = 1.0, +) -> PartialAttentionOutput: + """Helper to create a partial attention output for testing.""" + return PartialAttentionOutput( + output=mx.random.normal((seq_len, num_heads, head_dim)), + max_score=mx.full((seq_len, num_heads), max_score_val), + log_sum_exp=mx.full((seq_len, num_heads), lse_val), + ) + + +class TestMergeTwoPartials: + """Tests for merging two partial attention outputs.""" + + def test_equal_weights(self): + """Two partials with equal stats should produce average.""" + seq_len, num_heads, head_dim = 4, 8, 64 + + # Create two partials with same max_score and lse + p1 = PartialAttentionOutput( + output=mx.ones((seq_len, num_heads, head_dim)), + max_score=mx.zeros((seq_len, num_heads)), + log_sum_exp=mx.ones((seq_len, num_heads)), + ) + p2 = PartialAttentionOutput( + output=mx.ones((seq_len, num_heads, head_dim)) * 3, + max_score=mx.zeros((seq_len, num_heads)), + log_sum_exp=mx.ones((seq_len, num_heads)), + ) + + merged = merge_two_partials(p1, p2) + + # With equal weights, should be average: (1 + 3) / 2 = 2 + expected = mx.ones((seq_len, num_heads, head_dim)) * 2 + assert mx.allclose(merged.output, expected, atol=1e-5) + + def test_different_max_scores(self): + """Partial with higher max_score should dominate.""" + seq_len, num_heads, head_dim = 4, 8, 64 + + # p1 has much higher max_score -> should dominate + p1 = PartialAttentionOutput( + output=mx.ones((seq_len, num_heads, head_dim)), + max_score=mx.full((seq_len, num_heads), 10.0), + log_sum_exp=mx.ones((seq_len, num_heads)), + ) + p2 = PartialAttentionOutput( + output=mx.ones((seq_len, num_heads, head_dim)) * 100, + max_score=mx.zeros((seq_len, num_heads)), + log_sum_exp=mx.ones((seq_len, num_heads)), + ) + + merged = merge_two_partials(p1, p2) + + # p1 should dominate (scale factor for p2 is exp(-10) ≈ 0) + assert mx.allclose(merged.output, p1.output, atol=1e-4) + + def test_numerical_stability(self): + """Should handle large max_score values without overflow.""" + seq_len, num_heads, head_dim = 4, 8, 64 + + # Very large max scores (would overflow without proper handling) + p1 = PartialAttentionOutput( + output=mx.ones((seq_len, num_heads, head_dim)), + max_score=mx.full((seq_len, num_heads), 1000.0), + log_sum_exp=mx.ones((seq_len, num_heads)), + ) + p2 = PartialAttentionOutput( + output=mx.ones((seq_len, num_heads, head_dim)) * 2, + max_score=mx.full((seq_len, num_heads), 999.0), + log_sum_exp=mx.ones((seq_len, num_heads)), + ) + + merged = merge_two_partials(p1, p2) + + # Should not have NaN or Inf + assert not mx.any(mx.isnan(merged.output)) + assert not mx.any(mx.isinf(merged.output)) + + def test_merge_updates_stats(self): + """Merged output should have updated max_score and lse.""" + seq_len, num_heads, head_dim = 4, 8, 64 + + p1 = make_partial(seq_len, num_heads, head_dim, max_score_val=5.0, lse_val=2.0) + p2 = make_partial(seq_len, num_heads, head_dim, max_score_val=3.0, lse_val=3.0) + + merged = merge_two_partials(p1, p2) + + # New max should be max of individual maxes + assert mx.allclose(merged.max_score, mx.full((seq_len, num_heads), 5.0)) + + # New lse should be greater than individual (log of sum of exps) + assert 
mx.all(merged.log_sum_exp > p1.log_sum_exp) + + +class TestMergePartialAttention: + """Tests for merging multiple partial outputs.""" + + def test_empty_list_raises(self): + """Should raise on empty list.""" + with pytest.raises(ValueError, match="Cannot merge empty"): + merge_partial_attention([]) + + def test_single_partial(self): + """Single partial should return its output unchanged.""" + p1 = make_partial(4, 8, 64) + result = merge_partial_attention([p1]) + + assert mx.allclose(result, p1.output) + + def test_multiple_partials(self): + """Should correctly merge multiple partials.""" + seq_len, num_heads, head_dim = 4, 8, 64 + + # Create 4 partials with equal weights + partials = [] + for i in range(4): + p = PartialAttentionOutput( + output=mx.full((seq_len, num_heads, head_dim), float(i + 1)), + max_score=mx.zeros((seq_len, num_heads)), + log_sum_exp=mx.ones((seq_len, num_heads)), + ) + partials.append(p) + + result = merge_partial_attention(partials) + + # With equal weights: (1 + 2 + 3 + 4) / 4 = 2.5 + expected = mx.full((seq_len, num_heads, head_dim), 2.5) + assert mx.allclose(result, expected, atol=1e-4) + + def test_associativity(self): + """Merge should be associative: merge([a,b,c]) == merge([merge([a,b]),c]).""" + partials = [make_partial(4, 8, 64) for _ in range(4)] + + # Merge all at once + result1 = merge_partial_attention(partials) + + # Merge pairwise + p12 = merge_two_partials(partials[0], partials[1]) + p34 = merge_two_partials(partials[2], partials[3]) + p1234 = merge_two_partials(p12, p34) + + assert mx.allclose(result1, p1234.output, atol=1e-4) + + +class TestRingReductionSimulation: + """Simulate ring reduction to verify merge correctness.""" + + def test_ring_reduction_4_ranks(self): + """Simulate 4-rank ring reduction and verify final merge.""" + seq_len, num_heads, head_dim = 8, 4, 32 + num_ranks = 4 + + # Create "ground truth" partials (what each rank computes) + rank_partials = [ + make_partial(seq_len, num_heads, head_dim) for _ in range(num_ranks) + ] + + # Simulate ring reduction: each rank progressively merges + # At the end, all ranks should have same result + def ring_reduce(rank_id: int) -> mx.array: + running = rank_partials[rank_id] + for step in range(1, num_ranks): + # In real ring: receive from (rank_id - step) mod N + prev_rank = (rank_id - step) % num_ranks + running = merge_two_partials(running, rank_partials[prev_rank]) + return running.output + + # All ranks should produce same final output + results = [ring_reduce(r) for r in range(num_ranks)] + + for i in range(1, num_ranks): + assert mx.allclose(results[0], results[i], atol=1e-4) + + # Should also match direct merge of all + direct = merge_partial_attention(rank_partials) + assert mx.allclose(results[0], direct, atol=1e-4) diff --git a/tests/subsystems/test_cp_ring_comm.py b/tests/subsystems/test_cp_ring_comm.py new file mode 100644 index 00000000..a94315e0 --- /dev/null +++ b/tests/subsystems/test_cp_ring_comm.py @@ -0,0 +1,205 @@ +"""Tests for context parallelism ring communication.""" + +from __future__ import annotations + +import asyncio +import pytest + +from dnet.core.cp.ring_comm import ( + CPRingCommunicator, + RingNeighbors, + start_cp_ring_server, +) + + +class TestCPRingCommunicator: + """Tests for the CPRingCommunicator class.""" + + def test_init_valid(self): + """Should initialize with valid rank/num_ranks.""" + comm = CPRingCommunicator(rank_id=0, num_ranks=4) + assert comm.rank_id == 0 + assert comm.num_ranks == 4 + assert comm.prev_rank == 3 + assert comm.next_rank == 1 + 
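+    # Neighbor ranks are plain modular arithmetic on the ring (illustrative
+    # restatement; CPRingCommunicator is the source of truth):
+    #   prev_rank = (rank_id - 1) % num_ranks
+    #   next_rank = (rank_id + 1) % num_ranks
+    # For rank 0 of 4 this wraps around: prev=3, next=1, as asserted above.
+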
+ def test_init_middle_rank(self): + """Should compute correct neighbors for middle rank.""" + comm = CPRingCommunicator(rank_id=2, num_ranks=4) + assert comm.prev_rank == 1 + assert comm.next_rank == 3 + + def test_init_last_rank(self): + """Should wrap around for last rank.""" + comm = CPRingCommunicator(rank_id=3, num_ranks=4) + assert comm.prev_rank == 2 + assert comm.next_rank == 0 + + def test_init_invalid_num_ranks(self): + """Should raise on invalid num_ranks.""" + with pytest.raises(ValueError, match="num_ranks must be positive"): + CPRingCommunicator(rank_id=0, num_ranks=0) + + def test_init_invalid_rank_id(self): + """Should raise on out-of-range rank_id.""" + with pytest.raises(ValueError, match="rank_id .* out of range"): + CPRingCommunicator(rank_id=5, num_ranks=4) + + def test_send_recv_single_rank(self): + """Single rank should return its own data.""" + + async def _test(): + comm = CPRingCommunicator(rank_id=0, num_ranks=1) + data = b"test_data" + result = await comm.send_recv(data, "tag1") + assert result == data + + asyncio.run(_test()) + + def test_connect_sets_flag(self): + """Connect should set the connected flag.""" + + async def _test(): + comm = CPRingCommunicator(rank_id=0, num_ranks=2) + neighbors = RingNeighbors( + prev_address="localhost:50001", + next_address="localhost:50002", + ) + await comm.connect(neighbors) + assert comm._connected + await comm.disconnect() + assert not comm._connected + + asyncio.run(_test()) + + +class TestRealGRPCRingCommunication: + """Tests for ring communication using real gRPC servers.""" + + def test_two_rank_exchange(self): + """Two ranks should exchange data correctly via real gRPC.""" + + async def _test(): + base_port = 59200 + num_ranks = 2 + + # Create communicators + comms = [ + CPRingCommunicator(rank_id=i, num_ranks=num_ranks) + for i in range(num_ranks) + ] + + # Start gRPC servers + servers = [] + for i in range(num_ranks): + server = await start_cp_ring_server( + port=base_port + i, communicator=comms[i] + ) + servers.append(server) + + # Connect to neighbors + for i in range(num_ranks): + prev_rank = (i - 1) % num_ranks + next_rank = (i + 1) % num_ranks + neighbors = RingNeighbors( + prev_address=f"localhost:{base_port + prev_rank}", + next_address=f"localhost:{base_port + next_rank}", + ) + await comms[i].connect(neighbors) + + try: + # Run both send_recv concurrently + data0 = b"from_rank_0" + data1 = b"from_rank_1" + + results = await asyncio.gather( + comms[0].send_recv(data0, "step1"), + comms[1].send_recv(data1, "step1"), + ) + + # rank0 receives from rank1 (prev of 0 is 1 in 2-rank ring) + # rank1 receives from rank0 (prev of 1 is 0) + assert results[0] == data1 # rank0 got data1 + assert results[1] == data0 # rank1 got data0 + finally: + for comm in comms: + await comm.disconnect() + for server in servers: + await server.stop(grace=0.1) + + asyncio.run(_test()) + + def test_four_rank_ring(self): + """Four ranks should form a proper ring via real gRPC.""" + + async def _test(): + base_port = 59210 + num_ranks = 4 + + # Create communicators + comms = [ + CPRingCommunicator(rank_id=i, num_ranks=num_ranks) + for i in range(num_ranks) + ] + + # Start gRPC servers + servers = [] + for i in range(num_ranks): + server = await start_cp_ring_server( + port=base_port + i, communicator=comms[i] + ) + servers.append(server) + + # Connect to neighbors + for i in range(num_ranks): + prev_rank = (i - 1) % num_ranks + next_rank = (i + 1) % num_ranks + neighbors = RingNeighbors( + prev_address=f"localhost:{base_port + 
prev_rank}", + next_address=f"localhost:{base_port + next_rank}", + ) + await comms[i].connect(neighbors) + + try: + # Each rank sends its ID as bytes + data = [f"rank_{i}".encode() for i in range(num_ranks)] + + results = await asyncio.gather( + *[comms[i].send_recv(data[i], "step1") for i in range(num_ranks)] + ) + + # Each rank should receive from its previous rank + for i in range(num_ranks): + prev = (i - 1) % num_ranks + assert results[i] == data[prev] + finally: + for comm in comms: + await comm.disconnect() + for server in servers: + await server.stop(grace=0.1) + + asyncio.run(_test()) + + def test_single_rank(self): + """Single rank should return own data (no gRPC needed).""" + + async def _test(): + comm = CPRingCommunicator(rank_id=0, num_ranks=1) + data = b"solo" + result = await comm.send_recv(data, "tag") + assert result == data + + asyncio.run(_test()) + + +class TestRingNeighbors: + """Tests for the RingNeighbors dataclass.""" + + def test_creation(self): + """Should create RingNeighbors with addresses.""" + neighbors = RingNeighbors( + prev_address="192.168.1.1:50051", + next_address="192.168.1.2:50051", + ) + assert neighbors.prev_address == "192.168.1.1:50051" + assert neighbors.next_address == "192.168.1.2:50051" diff --git a/tests/subsystems/test_cp_serialization.py b/tests/subsystems/test_cp_serialization.py new file mode 100644 index 00000000..44368c30 --- /dev/null +++ b/tests/subsystems/test_cp_serialization.py @@ -0,0 +1,80 @@ +import sys +from unittest.mock import MagicMock + +# Mock dnet.compression to avoid Metal dependency on Linux +mock_compression = MagicMock() +sys.modules["dnet.compression"] = mock_compression +sys.modules["dnet.compression.ops"] = MagicMock() +sys.modules["dnet.compression.kernels"] = MagicMock() + +import pytest # noqa: E402 +import mlx.core as mx # noqa: E402 +import numpy as np # noqa: E402 +from dnet.shard.adapters.context_parallel import CPAdapter # noqa: E402 +from dnet.core.cp.merge_attention import PartialAttentionOutput # noqa: E402 + + +# Mock dependencies for CPAdapter init +class MockRuntime: + max_queue_size = 10 + + +class MockDiscovery: + pass + + +@pytest.fixture +def adapter(): + return CPAdapter(runtime=MockRuntime(), discovery=MockDiscovery()) + + +def test_kv_serialization_roundtrip(adapter): + # Create test tensors + k = mx.random.uniform(shape=(2, 4, 32)) + v = mx.random.uniform(shape=(2, 4, 32)) + + # Serialize + data = adapter._serialize_kv(k, v) + assert isinstance(data, bytes) + assert len(data) > 0 + + # Deserialize + k_out, v_out = adapter._deserialize_kv(data) + + # Verify + assert k_out.shape == k.shape + assert v_out.shape == v.shape + assert k_out.dtype == k.dtype + assert v_out.dtype == v.dtype + + # Check values (using numpy for comparison) + np.testing.assert_allclose(np.array(k_out), np.array(k), rtol=1e-5) + np.testing.assert_allclose(np.array(v_out), np.array(v), rtol=1e-5) + + +def test_partial_serialization_roundtrip(adapter): + # Create test partial output + out = mx.random.uniform(shape=(2, 8, 64)) + # Max score: [B, H] + max_s = mx.random.uniform(shape=(2, 8)) + lse = mx.random.uniform(shape=(2, 8)) + + partial = PartialAttentionOutput(output=out, max_score=max_s, log_sum_exp=lse) + + # Serialize + data = adapter._serialize_partial(partial) + assert isinstance(data, bytes) + + # Deserialize + p_out = adapter._deserialize_partial(data) + + # Verify output + assert p_out.output.shape == out.shape + np.testing.assert_allclose(np.array(p_out.output), np.array(out), rtol=1e-5) + + # Verify metadata 
(restored shape) + assert p_out.max_score.shape == max_s.shape + assert p_out.log_sum_exp.shape == lse.shape + + np.testing.assert_allclose(np.array(p_out.max_score), np.array(max_s), rtol=1e-5) + np.testing.assert_allclose(np.array(p_out.log_sum_exp), np.array(lse), rtol=1e-5) diff --git a/tests/subsystems/test_cp_sharding.py b/tests/subsystems/test_cp_sharding.py new file mode 100644 index 00000000..e1b96ac8 --- /dev/null +++ b/tests/subsystems/test_cp_sharding.py @@ -0,0 +1,181 @@ +"""Tests for context parallelism sharding utilities.""" + +from __future__ import annotations + +import pytest +import mlx.core as mx + +from dnet.core.cp.sharding import shard_for_mode, unshard + + +class TestShardForModePrefill: + """Tests for prefill (2N load-balanced) sharding.""" + + def test_basic_4_ranks(self): + """Test 2N sharding with 4 ranks produces correct assignments.""" + # 16 tokens, 4 ranks -> 8 chunks -> pairs (0,7), (1,6), (2,5), (3,4) + tokens = mx.arange(16) + num_ranks = 4 + + # Rank 0 gets chunks 0 and 7 + sharded, indices = shard_for_mode(tokens, num_ranks, 0, "prefill") + assert sharded.shape[0] == 4 # 2 + 2 tokens + assert indices == [0, 1, 14, 15] + + # Rank 1 gets chunks 1 and 6 + sharded, indices = shard_for_mode(tokens, num_ranks, 1, "prefill") + assert indices == [2, 3, 12, 13] + + # Rank 3 gets chunks 3 and 4 (middle) + sharded, indices = shard_for_mode(tokens, num_ranks, 3, "prefill") + assert indices == [6, 7, 8, 9] + + def test_load_balance(self): + """Verify all ranks get equal-sized chunks (load balanced).""" + tokens = mx.arange(64) + num_ranks = 4 + + sizes = [] + for rank_id in range(num_ranks): + sharded, _ = shard_for_mode(tokens, num_ranks, rank_id, "prefill") + sizes.append(sharded.shape[0]) + + # All sizes should be equal (or differ by at most 1 for remainders) + assert max(sizes) - min(sizes) <= 1 + + def test_single_rank(self): + """Single rank should get all tokens.""" + tokens = mx.arange(10) + sharded, indices = shard_for_mode(tokens, 1, 0, "prefill") + + assert sharded.shape[0] == 10 + assert indices == list(range(10)) + + def test_coverage_all_indices(self): + """All indices should be covered exactly once across all ranks.""" + tokens = mx.arange(32) + num_ranks = 4 + + all_indices = [] + for rank_id in range(num_ranks): + _, indices = shard_for_mode(tokens, num_ranks, rank_id, "prefill") + all_indices.extend(indices) + + assert sorted(all_indices) == list(range(32)) + + +class TestShardForModeDecode: + """Tests for decode (even N-way) sharding.""" + + def test_basic_4_ranks(self): + """Test even sharding with 4 ranks.""" + tokens = mx.arange(16) + num_ranks = 4 + + # Each rank gets contiguous 4 tokens + for rank_id in range(num_ranks): + sharded, indices = shard_for_mode(tokens, num_ranks, rank_id, "decode") + assert sharded.shape[0] == 4 + assert indices == list(range(rank_id * 4, (rank_id + 1) * 4)) + + def test_uneven_split(self): + """Test handling of sequence length not divisible by ranks.""" + tokens = mx.arange(10) + num_ranks = 4 + + all_indices = [] + for rank_id in range(num_ranks): + sharded, indices = shard_for_mode(tokens, num_ranks, rank_id, "decode") + all_indices.extend(indices) + + # All indices covered + assert sorted(all_indices) == list(range(10)) + + def test_contiguous_chunks(self): + """Decode sharding should produce contiguous chunks.""" + tokens = mx.arange(100) + num_ranks = 4 + + for rank_id in range(num_ranks): + _, indices = shard_for_mode(tokens, num_ranks, rank_id, "decode") + # Check contiguity: indices should be sequential + 
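+            # (Decode mode hands each rank one contiguous slice, e.g. 16 tokens
+            # on 4 ranks -> rank r owns [4r, 4r + 4); prefill instead pairs
+            # chunk i with chunk 2N-1-i for causal load balancing.)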
for i in range(1, len(indices)): + assert indices[i] == indices[i - 1] + 1 + + +class TestShardValidation: + """Tests for input validation.""" + + def test_invalid_num_ranks(self): + """Should raise on invalid num_ranks.""" + tokens = mx.arange(10) + with pytest.raises(ValueError, match="num_ranks must be positive"): + shard_for_mode(tokens, 0, 0, "prefill") + + def test_rank_out_of_range(self): + """Should raise on rank_id out of range.""" + tokens = mx.arange(10) + with pytest.raises(ValueError, match="rank_id .* out of range"): + shard_for_mode(tokens, 4, 5, "prefill") + + def test_empty_input(self): + """Empty input should return empty output.""" + tokens = mx.zeros((0, 128)) + sharded, indices = shard_for_mode(tokens, 4, 0, "prefill") + assert sharded.shape[0] == 0 + assert indices == [] + + +class TestUnshard: + """Tests for unshard operation.""" + + def test_roundtrip_prefill(self): + """Shard -> unshard should recover original.""" + original = mx.arange(32).reshape(32, 1).astype(mx.float32) + num_ranks = 4 + + # Shard + chunks = [] + indices_list = [] + for rank_id in range(num_ranks): + sharded, indices = shard_for_mode(original, num_ranks, rank_id, "prefill") + chunks.append(sharded) + indices_list.append(indices) + + # Unshard + recovered = unshard(chunks, indices_list, 32) + + assert mx.allclose(recovered, original) + + def test_roundtrip_decode(self): + """Shard -> unshard should recover original for decode mode.""" + original = mx.arange(32).reshape(32, 1).astype(mx.float32) + num_ranks = 4 + + chunks = [] + indices_list = [] + for rank_id in range(num_ranks): + sharded, indices = shard_for_mode(original, num_ranks, rank_id, "decode") + chunks.append(sharded) + indices_list.append(indices) + + recovered = unshard(chunks, indices_list, 32) + + assert mx.allclose(recovered, original) + + def test_multidimensional(self): + """Test with multi-dimensional tensors.""" + # Simulate hidden states: [seq, heads, dim] + original = mx.random.normal((64, 8, 128)) + num_ranks = 4 + + chunks = [] + indices_list = [] + for rank_id in range(num_ranks): + sharded, indices = shard_for_mode(original, num_ranks, rank_id, "decode") + chunks.append(sharded) + indices_list.append(indices) + + recovered = unshard(chunks, indices_list, 64) + + assert mx.allclose(recovered, original, atol=1e-5) diff --git a/tests/subsystems/test_inference_manager.py b/tests/subsystems/test_inference_manager.py index 8c770728..4384f543 100644 --- a/tests/subsystems/test_inference_manager.py +++ b/tests/subsystems/test_inference_manager.py @@ -238,15 +238,6 @@ def test_invalid_request_params_max_tokens_negative(): ) -def test_invalid_request_params_logprobs_zero_invalid(): - with pytest.raises(ValidationError): - _ = ChatRequestModel( - model="m", - messages=[ChatMessage(role="user", content="x")], - logprobs=0, # coerces to False but should still fail via validator - ) - - def test_invalid_request_params_stop_bad_type(): with pytest.raises(ValidationError): _ = ChatRequestModel( diff --git a/tests/subsystems/test_model_manager.py b/tests/subsystems/test_model_manager.py index 1588d8a0..d54103d3 100644 --- a/tests/subsystems/test_model_manager.py +++ b/tests/subsystems/test_model_manager.py @@ -260,3 +260,76 @@ async def main(): assert mm.current_model_id is None asyncio.run(main()) + + +def test_load_model_cp_fields_populated(monkeypatch): + """Verify that CP rank fields are correctly populated in load_model requests.""" + topo, dev1, dev2 = _mk_topology() + mm = ModelManager() + + rec = {} + + def _mk_post(url): + def 
f(payload): + rec[url] = payload + return FakeResponse( + 200, + { + "success": True, + "message": "ok", + "layers_loaded": payload["layers"], + "load_time_ms": 1.0, + }, + ) + + return f + + post_map = { + f"http://{dev1.local_ip}:{dev1.server_port}/load_model": _mk_post( + f"http://{dev1.local_ip}:{dev1.server_port}/load_model" + ), + f"http://{dev2.local_ip}:{dev2.server_port}/load_model": _mk_post( + f"http://{dev2.local_ip}:{dev2.server_port}/load_model" + ), + } + + monkeypatch.setattr( + "httpx.AsyncClient", lambda: FakeClient({}, post_map), raising=True + ) + monkeypatch.setattr( + "dnet.api.model_manager.resolve_tokenizer_dir", + lambda m: "/tmp/dir", + raising=True, + ) + monkeypatch.setattr( + "dnet.api.model_manager.load_tokenizer", lambda d, cfg: object(), raising=True + ) + + api_props = DnetDeviceProperties( + is_manager=True, + is_busy=False, + instance="API", + server_port=0, + shard_port=0, + local_ip="1.1.1.1", + ) + + async def main(): + res = await mm.load_model(topo, api_props, grpc_port=5050) + assert res.success is True + + # Verify S1 payload (Rank 0) + p1 = rec[f"http://{dev1.local_ip}:{dev1.server_port}/load_model"] + assert p1["cp_rank_id"] == 0 + assert p1["cp_num_ranks"] == 2 + # Check addresses: dev1 is 10.0.0.1:9011, dev2 is 10.0.0.2:9012 + expected_addrs = ["10.0.0.1:9011", "10.0.0.2:9012"] + assert p1["cp_rank_addresses"] == expected_addrs + + # Verify S2 payload (Rank 1) + p2 = rec[f"http://{dev2.local_ip}:{dev2.server_port}/load_model"] + assert p2["cp_rank_id"] == 1 + assert p2["cp_num_ranks"] == 2 + assert p2["cp_rank_addresses"] == expected_addrs + + asyncio.run(main()) diff --git a/tests/subsystems/test_shard_runtime.py b/tests/subsystems/test_shard_runtime.py index b4af89e8..1580297a 100644 --- a/tests/subsystems/test_shard_runtime.py +++ b/tests/subsystems/test_shard_runtime.py @@ -345,6 +345,9 @@ def test_invalid_kv_bits_fallback(monkeypatch): "residency_size": 1, "kv_bits": "invalid", "api_callback_address": "cb", + "max_position_embeddings": None, + "cp_rank_id": 0, + "cp_num_ranks": 1, }, )() rt.load_model_core(req)
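For reference, a minimal sketch of how the new CP fields travel in a two-rank `/load_model` payload; the field names come from `src/dnet/shard/models.py` above, the addresses mirror the fixtures checked by `test_load_model_cp_fields_populated`, and all non-CP fields are elided:

```python
# Illustrative CP portion of a shard /load_model payload for a two-rank ring.
# Every rank receives the same ordered address list; only cp_rank_id differs.
cp_fields = {
    "cp_rank_id": 0,                # this shard's position in the ring
    "cp_num_ranks": 2,              # 1 = single-device mode (CP disabled)
    "cp_rank_addresses": [          # host:port per rank, in ring order
        "10.0.0.1:9011",
        "10.0.0.2:9012",
    ],
    "cp_algorithm": "auto",         # auto | pass_kv | pass_q | ring_reduce
}
```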