
Conversation

@jaisw7 (Collaborator) commented Jan 2, 2026

Summary

Implements Context Parallelism (CP) infrastructure for long-context inference (128K+ tokens) by distributing sequence dimensions across multiple devices while replicating all model layers on each device.

Motivation

Long-context LLM inference is memory-bound. Context Parallelism enables:

  • Processing sequences longer than single-device memory capacity
  • Ring-based communication to overlap compute with data transfer (see the sketch below)
  • Dynamic algorithm selection (pass-KV for prefill, ring-reduce for decode)

Based on Ring Attention papers (arXiv:2310.01889, 2411.01783).
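
For intuition, the pass-KV pattern can be sketched in a few lines of numpy: queries stay on their rank while the KV blocks of every rank rotate around the ring, and each step's partial result is folded in with an online-softmax update. This is a minimal illustration only; causal masking and the actual async gRPC transfer are omitted, and the names are illustrative rather than the dnet API.

    import numpy as np

    def ring_pass_kv(q, kv_blocks, rank):
        """One rank's view of pass-KV ring attention: q stays local while the
        KV blocks of all ranks arrive one rotation step at a time. The online
        softmax update makes the result match attention over the concatenated
        KV without ever materializing it on one device."""
        world = len(kv_blocks)
        m = np.full((q.shape[0], 1), -np.inf)  # running row-wise max
        s = np.zeros((q.shape[0], 1))          # running softmax denominator
        out = np.zeros_like(q)
        for step in range(world):
            k, v = kv_blocks[(rank - step) % world]  # block held at this step
            scores = (q @ k.T) / np.sqrt(q.shape[-1])
            m_new = np.maximum(m, scores.max(axis=-1, keepdims=True))
            p = np.exp(scores - m_new)
            scale = np.exp(m - m_new)            # rescale earlier accumulators
            s = s * scale + p.sum(axis=-1, keepdims=True)
            out = out * scale + p @ v
            m = m_new
        return out / s

Concatenating all blocks and running plain softmax attention yields the same output; preserving that invariant while overlapping the next block's transfer with compute is exactly what the ring implementation has to get right.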

Changes

  • Core CP infrastructure (src/dnet/core/cp/):

    • sharding.py - Load-balanced prefill and contiguous decode sharding
    • merge_attention.py - Numerically stable partial output merging (see the sketch after this list)
    • heuristics.py - Algorithm selection (pass-KV, pass-Q, ring-reduce)
    • ring_comm.py - Async ring send/recv with gRPC transport
  • Shard adapter (src/dnet/shard/adapters/context_parallel.py):

    • CPAdapter implementing ring_pass_kv_attention and ring_reduce_attention
  • API strategy (src/dnet/api/strategies/context_parallel.py):

    • ContextParallelStrategy, CPTopologySolver, CPApiAdapter
  • Configuration (src/dnet/config.py, src/dnet/shard/models.py):

    • ContextParallelSettings with algorithm, thresholds
    • CP fields in ShardLoadModelRequest
  • Proto (src/dnet/protos/dnet_cp.proto):

    • CPRingService for block transfers
  • Tests: 47 unit tests + 11 integration tests
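
For reference, the merge step can be illustrated with the standard log-sum-exp trick: each rank returns its partial output together with the row-wise log-sum-exp of its attention scores, and partials over disjoint KV blocks then combine exactly. A minimal numpy sketch follows; the names and shapes are illustrative, not the actual module API.

    import numpy as np

    def merge_partials(out_a, lse_a, out_b, lse_b):
        """Merge two partial attention outputs computed over disjoint KV
        blocks. out_* is (n, d); lse_* is the (n, 1) log-sum-exp of that
        block's scores. Subtracting a shared max keeps exp() bounded, so
        the merge stays stable even when the blocks' scales differ widely."""
        m = np.maximum(lse_a, lse_b)
        w_a = np.exp(lse_a - m)
        w_b = np.exp(lse_b - m)
        return (out_a * w_a + out_b * w_b) / (w_a + w_b)

Applied pairwise as blocks arrive, this is the kind of update the "merge numerical stability with varying scales" tests exercise.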

Type of Change

  • New feature (non-breaking change which adds functionality)

Testing

  • Tests pass locally
  • Added new tests for the changes
tests/subsystems/test_cp_*.py ............... 47 passed
tests/integration/test_cp_single_system.py .. 11 passed

Tests cover:

  • Sharding roundtrip (prefill/decode modes)
  • Merge numerical stability with varying scales
  • Ring communication rotation (4-rank simulation)
  • Algorithm selection heuristics
  • Adapter lifecycle and configuration

Checklist

  • My code follows the project's code style
  • I have made corresponding changes to the documentation

Related Issues

#84

@jaisw7 requested review from Copilot and removed the request for Copilot January 2, 2026 21:48
@jaisw7 added the "in progress" label January 2, 2026
Copilot AI review requested due to automatic review settings January 2, 2026 22:03
Copilot AI review requested due to automatic review settings January 2, 2026 22:19
Copilot AI left a comment
Pull request overview

This PR implements Context Parallelism (CP) infrastructure to enable long-context inference (128K+ tokens) by distributing the sequence dimension across multiple Apple Silicon devices while replicating the full model on each device. This complements the existing pipeline parallelism (RingStrategy) which shards layers.

Key changes include:

  • Core CP primitives for load-balanced sharding, numerically stable partial attention merging, and ring communication
  • Algorithm selection heuristics to choose between pass-KV (prefill), pass-Q, and ring-reduce (decode) based on cache hit rates (a sketch follows this list)
  • CPAdapter for shards and ContextParallelStrategy for API coordination
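
As a rough illustration of that selection logic, here is a hedged sketch; the thresholds, names, and the cache-hit-rate estimate are invented for illustration (the real logic lives in heuristics.py and also weighs the GQA ratio, omitted here).

    def select_cp_algorithm(context_len: int, new_tokens: int,
                            min_context_for_cp: int = 4096,
                            min_tokens_for_pass_kv: int = 512) -> str:
        """Greedy choice between the three algorithms named above: pass-KV
        for (mostly cold) prefill, pass-Q when the cache is warm and moving
        queries is cheaper than moving cached KV, ring-reduce for decode."""
        if context_len < min_context_for_cp:
            return "none"         # short context: CP overhead outweighs savings
        if new_tokens == 1:
            return "ring-reduce"  # decode step
        cache_hit_rate = 1.0 - new_tokens / max(context_len, 1)
        if new_tokens >= min_tokens_for_pass_kv and cache_hit_rate < 0.8:
            return "pass-kv"      # long cold prefill: rotate the KV blocks
        return "pass-q"           # warm cache: rotate the (smaller) queries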

Reviewed changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 21 comments.

Summary per file:

  • src/dnet/core/cp/sharding.py - Load-balanced 2N prefill sharding and contiguous decode sharding with roundtrip unshard (see the sketch after this list)
  • src/dnet/core/cp/merge_attention.py - Numerically stable merging of partial attention outputs using the online softmax algorithm
  • src/dnet/core/cp/heuristics.py - Greedy algorithm selection based on context length, cache hit rate, and GQA ratio
  • src/dnet/core/cp/ring_comm.py - Ring communication primitives with async send/recv (placeholder gRPC implementation)
  • src/dnet/core/cp/__init__.py - Lazy imports for MLX-dependent modules to support cross-platform heuristics
  • src/dnet/shard/adapters/context_parallel.py - CPAdapter implementing ring_pass_kv_attention and ring_reduce_attention (serialization incomplete)
  • src/dnet/api/strategies/context_parallel.py - CPTopologySolver for full-model replication and CPApiAdapter for token distribution
  • src/dnet/config.py - ContextParallelSettings with algorithm selection and threshold configuration
  • src/dnet/shard/models.py - CP-specific fields in ShardLoadModelRequest for rank topology
  • src/dnet/protos/dnet_cp.proto - gRPC service and message definitions for KV/Q block transfers
  • tests/subsystems/test_cp_*.py - Unit tests for sharding, merging, heuristics, and ring communication (47 tests)
  • tests/integration/test_cp_single_system.py - Integration tests for module interop and server e2e (11 tests)
  • docs/design/context-parallelism.md - Comprehensive design document with architecture, algorithms, and verification plan
  • .github/workflows/cp-integration-tests.yml - CI workflow for CP testing on a macOS runner
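
The load-balanced 2N scheme in sharding.py (first item in the list above) pairs cheap and expensive chunks of a causal prefill so every rank does comparable attention work. A minimal sketch, assuming this particular chunk pairing and even divisibility (the real code handles padding and tracks chunk indices):

    def shard_prefill_balanced(tokens, world):
        """Cut the sequence into 2*world chunks and give rank r chunks r and
        (2*world - 1 - r): under causal masking, early chunks attend to few
        positions and late chunks to many, so the pairing balances compute."""
        n = 2 * world
        assert len(tokens) % n == 0, "sketch only; real code pads uneven input"
        size = len(tokens) // n
        chunks = [tokens[i * size:(i + 1) * size] for i in range(n)]
        return [chunks[r] + chunks[n - 1 - r] for r in range(world)]

    def unshard_prefill_balanced(shards, world):
        """Roundtrip inverse, matching the 'roundtrip unshard' noted above."""
        n = 2 * world
        size = len(shards[0]) // 2
        chunks = [None] * n
        for r, s in enumerate(shards):
            chunks[r], chunks[n - 1 - r] = s[:size], s[size:]
        return [t for c in chunks for t in c]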


jaisw7 added 12 commits January 2, 2026 17:47
- Implement CPRingServiceServicer with SendBlock and StreamBlocks methods
- Add start_cp_ring_server helper to start gRPC server for CP ring communication
- Remove MockRingCommunicator and MockRankCommunicator
- Rewrite test_ring_full_rotation_4_ranks to use actual gRPC servers
- All communication now goes through real network, no mocks
As per implementation plan section 4.4, added:
- DNET_CP_ENABLED
- DNET_CP_ALGORITHM
- DNET_CP_MIN_CONTEXT_FOR_CP
- DNET_CP_MIN_TOKENS_FOR_PASS_KV
- DNET_CP_CHUNK_OVERLAP
- Import and use ContextParallelStrategy when settings.context_parallel.enabled is true
- Fall back to RingStrategy otherwise
- Add Strategy base type annotation to fix mypy
Now both sides are wired:
- api.py uses ContextParallelStrategy when CP enabled
- shard.py uses CPAdapter when CP enabled
- Add cp_config field to ActivationRequest in dnet_ring.proto
- Import dnet_cp.proto for CPConfig type
- Wire up api.py strategy selection (ContextParallelStrategy)
- Wire up shard.py adapter selection (CPAdapter)
dnet_ring.proto imports dnet_cp.proto, but protoc generates bare imports
like 'import dnet_cp_pb2' which fail at runtime. Added post-processing to
generate_protos.py to convert these to relative imports.
CPTopologySolver was receiving all peers including the API manager node,
causing it to try loading model on API server (404). Now only actual
shard nodes are passed to solver.

Also fixed a stray triple-quote line in context_parallel.py.
Filter shards by checking if they exist in . Since  only
contains shards that passed health and latency checks, this robustly excludes
invalid nodes (like the API server itself) even if  flag is unreliable.
Implemented _ingress_worker to deserialize requests and feed ShardRuntime.
Implemented _egress_worker to drain runtime output and forward results.
Implemented _token_tx_worker to stream generated tokens to API.
This resolves the ReadTimeout hang observed during inference.
@jaisw7 (Collaborator, Author) commented Jan 3, 2026

@andthattoo For a sanity check, I have tested context parallelism on a single machine (62.210.193.181 on Scaleway). dnet-tui needs changes as well; for now, please use curl requests to check that things are working correctly. The environment variable DNET_CP_ENABLED=true controls context parallelism.

Could you please check it on your local machine to ensure things are in order?
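
For anyone reproducing the check, a request along these lines should work; the /v1/chat/completions path and the payload shape are assumptions inferred from the test scripts (which target http://localhost:8080), not a documented contract.

    import json
    import urllib.request

    # Hypothetical smoke test mirroring the curl check suggested above.
    req = urllib.request.Request(
        "http://localhost:8080/v1/chat/completions",
        data=json.dumps({
            "model": "mlx-community/Llama-3.2-3B-Instruct-4bit",
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 16,
        }).encode(),
        headers={"Content-Type": "application/json"},
    )
    print(urllib.request.urlopen(req, timeout=120).read().decode())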

@jaisw7 requested review from andthattoo and erhant January 3, 2026 00:55
@jaisw7 (Collaborator, Author) commented Jan 3, 2026

Scaling appears to be approximately linear

[Test] Context length: ~100 tokens
       Prompt: 369 chars (~92 tokens)
       ✓ Success in 4.20s

[Test] Context length: ~500 tokens
       Prompt: 1,968 chars (~492 tokens)
       ✓ Success in 4.82s

[Test] Context length: ~1,000 tokens
       Prompt: 3,936 chars (~984 tokens)
       ✓ Success in 6.53s

[Test] Context length: ~8,000 tokens
       Prompt: 31,980 chars (~7,995 tokens)
       ✓ Success in 50.54s

jaisw7 added 10 commits January 3, 2026 11:12
Phase 1: CPApiAdapter now splits tokens and broadcasts to all ranks
- Added connect_all_ranks() for multi-rank connection
- Added _send_tokens_multi_rank() with sequence splitting via shard_for_mode()

Phase 2: Restored CP rank check in FitInMemoryPolicy
- Only last rank samples (with guard for single-device mode)
- Fixes the root cause of CP hang - API now splits properly

Phase 3: Wired multi-rank connection in InferenceManager/http_api
- Added connect_to_cp_ranks() method
- Updated load_model to use multi-rank when CP with >1 devices
During decode phase (1 token), splitting across ranks gives 0 tokens
to some ranks causing reshape errors. Now decode tokens go directly
to last rank only.
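
The decode fix in these commits amounts to a small routing guard. A hypothetical sketch (shard_for_mode is named in the commits above, but the signatures here are assumed; note that a later commit below reverts prefill token splitting in favor of a sharded KV cache):

    def shard_for_mode(tokens, world):
        # Stand-in for the real splitter named in the commits above.
        size = (len(tokens) + world - 1) // world
        return [tokens[i * size:(i + 1) * size] for i in range(world)]

    def route_tokens(tokens, world, is_decode, last_rank):
        """Prefill splits the sequence across ranks; decode produces a single
        new token, and splitting it over `world` ranks hands 0 tokens to some
        of them (the reshape error above), so decode goes to one rank only."""
        if is_decode or len(tokens) < world:
            return {last_rank: tokens}             # decode: last rank only
        return dict(enumerate(shard_for_mode(tokens, world)))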
jaisw7 added 5 commits January 3, 2026 15:34
- Add adapter=None to FakeInferenceManager for isinstance check
- Handle 1D/2D tensors in logits slicing (FakeComputeModel returns 1D)
- Add type annotation to needle_in_haystack.py by_size dict
Token splitting broke cross-chunk attention. Now all ranks get full
context. Savings come from sharded KV cache, not token splitting.
@jaisw7 (Collaborator, Author) commented Jan 3, 2026

Added a needle-in-a-haystack test (scripts/needle_in_haystack.py) to sanity-check the implementation on long contexts.

m1@673bd4ec-9389-4eda-aa51-fa7d3abee8e8 dnet % uv run python scripts/needle_in_haystack.py --sizes 8192

============================================================
Needle in Haystack Test
============================================================
Target context: ~8192 tokens
Actual prompt: ~8263 tokens
Needle position: 10%
Expected password: juliet-936-lima
============================================================
Response: juliet-936-lima<|eot_id|>
Latency: 18.73s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~8192 tokens
Actual prompt: ~8290 tokens
Needle position: 25%
Expected password: lima-691-hotel
============================================================
Response: lima-691-hotel<|eot_id|>
Latency: 17.02s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~8192 tokens
Actual prompt: ~8285 tokens
Needle position: 50%
Expected password: alpha-286-echo
============================================================
Response: alpha-286-echo<|eot_id|>
Latency: 19.65s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~8192 tokens
Actual prompt: ~8310 tokens
Needle position: 75%
Expected password: gamma-802-kilo
============================================================
Response: gamma-802-kilo<|eot_id|>
Latency: 19.91s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~8192 tokens
Actual prompt: ~8286 tokens
Needle position: 90%
Expected password: echo-479-foxtrot
============================================================
Response: echo-479-foxtrot<|eot_id|>
Latency: 20.01s
Result: ✓ PASS

============================================================
SUMMARY
============================================================
Passed: 5/5
    8192 tokens: 5/5 passed, avg 19.1s

✓ ALL TESTS PASSED - CP is working correctly!

@jaisw7 (Collaborator, Author) commented Jan 3, 2026

Scaling is nearly linear, but there are some crashes at 16k+ context on Qwen/Qwen3-4B-MLX-4bit.

m1@673bd4ec-9389-4eda-aa51-fa7d3abee8e8 dnet % uv run python scripts/stress_test_cp.py --sizes 128,256,512,1024,2048,4096,8192                               
============================================================
Context Parallelism Stress Test
============================================================
API:        http://localhost:8080
Max tokens: 100
Streaming:  False

[Check] Detecting shards...
        Found 2 shard(s):
          - shard-1 (62.210.193.181:8082)
          - shard-2 (62.210.193.192:8082)

[Check] Verifying model is loaded...
        Model: Qwen/Qwen3-4B-MLX-4bit

[Check] Checking CP settings...
        ✓ Context Parallelism is ENABLED

Test sizes: [128, 256, 512, 1024, 2048, 4096, 8192]
(Recommended for 2 shards - includes sizes that benefit from CP)

[Test] Context length: ~128 tokens
       Prompt: 492 chars (~123 tokens)
       ✓ Success in 4.45s

[Test] Context length: ~256 tokens
       Prompt: 984 chars (~246 tokens)
       ✓ Success in 4.68s

[Test] Context length: ~512 tokens
       Prompt: 1,968 chars (~492 tokens)
       ✓ Success in 5.46s

[Test] Context length: ~1,024 tokens
       Prompt: 4,059 chars (~1,014 tokens)
       ✓ Success in 7.22s

[Test] Context length: ~2,048 tokens
       Prompt: 8,118 chars (~2,029 tokens)
       ✓ Success in 11.16s

[Test] Context length: ~4,096 tokens
       Prompt: 16,359 chars (~4,089 tokens)
       ✓ Success in 21.76s

[Test] Context length: ~8,192 tokens
       Prompt: 32,718 chars (~8,179 tokens)
       ✓ Success in 52.11s

============================================================
Summary
============================================================
Tests passed: 7/7
Shards used:  2
Avg time:     15.26s
Max time:     52.11s

Details:
Context    Time       TTFT       Tokens/s  
---------------------------------------------
128        4.45       -          51.5      
256        4.68       -          74.6      
512        5.46       -          107.9     
1024       7.22       -          152.2     
2048       11.16      -          187.2     
4096       21.76      -          188.4     
8192       52.11      -          155.2     

jaisw7 added 2 commits January 3, 2026 16:19
- Add CPAttentionWrapper with global RoPE offset and cache handling

- Inject CP adapter into BaseRingModel

- Update API to use load-balanced sharding for prefill

- Implement sync bridges for ring attention ops
- Add safety check in CPAttentionWrapper.__getattr__

- Add idempotency check in BaseRingModel.set_cp_adapter
@jaisw7 (Collaborator, Author) commented Jan 4, 2026

dc63fab is the current working commit from this PR.
There is some numerical instability in merge attention in the latest commit.

@jaisw7 (Collaborator, Author) commented Jan 4, 2026

Update: I am still debugging merge-attention issues with the distributed KV cache on multiple shards.

  • Code works correctly on a single shard.
  • Code works correctly with all-to-all broadcast (no distributed KV cache) on multiple shards.
[1/4] Fetching available devices from http://localhost:8080...
    Using 2 shard(s) for Context Parallelism:
      [0] shard-1 (62.210.193.181:8082)
      [1] shard-2 (62.210.193.192:8082)
[2/4] Fetching model config for mlx-community/Llama-3.2-3B-Instruct-4bit...
    Model has 28 layers (full model on each shard)
    Sequence length: 131072
[3/4] Preparing CP topology...
    Topology prepared successfully
    Model: mlx-community/Llama-3.2-3B-Instruct-4bit
    Devices: ['shard-1', 'shard-2']
[4/4] Loading model on all shards (this may take a while)...
    Model loaded successfully!

============================================================
Context Parallelism Ready
============================================================
  Model:      mlx-community/Llama-3.2-3B-Instruct-4bit
  CP Ranks:   2
  Shards:     shard-1, shard-2
  KV Bits:    8bit
  Seq Len:    131072

Each shard has the full model and will process 1/2 of
the context window during inference.

  ✓ shard-1: Model loaded successfully
  ✓ shard-2: Model loaded successfully

============================================================
Needle in Haystack Test
============================================================
Target context: ~128 tokens
Actual prompt: ~199 tokens
Needle position: 10%
Expected password: juliet-431-india
============================================================
Response: ininininininininininivivivlelelelelelelewwwwouinainainaina [ina [ [ [ [ [ theif thea thea thea thein thea the Question of the Question ofinininlininlinllelelerleinelerineineine "le "B::::::::::::::::::''s''s''s isin's the Question's the Question: (in the Question: (d, is thed: (in a problem -d, but not expected, but will likely have been toldier, but they will likely have been told the Question and also have you don'tec it will [ [ [ [ [ [ [ as [ [ [ [ [ [ [ [ [ [ [ [ [ [ [ [ [ as [ [ [ [ as the as the as the as the or or or or or or or or or or or or or or or or or or or or as as as as as as as aseneeneeneeneseenese the  the  the  the the the the the the the the the the the the the the the the the the the

@jaisw7 (Collaborator, Author) commented Jan 5, 2026

I have gotten the code working on multiple shards with a distributed KV cache. However, I am seeing some 8-bit KV cache corruption.

Problem: RoPE (rotary position embeddings) produces values with extreme magnitudes, both very large and very small. When these RoPE'd keys are stored in the 8-bit quantized KV cache, quantization loses precision on the extreme values, and dequantization produces inf or -inf. The result is k_norm=inf, k_mean=-inf at decode time.

Fix: Use kv_bits: "fp16" for Context Parallelism. A proper fix would require quantization-aware RoPE handling (store keys pre-RoPE, apply RoPE after dequant).
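
A minimal sketch of that quantization-aware handling, with a per-tensor affine quantizer and a rope callable as stand-ins (neither is the MLX cache API):

    import numpy as np

    def quantize(x, bits=8):
        """Stand-in per-tensor affine quantization."""
        lo, hi = float(x.min()), float(x.max())
        scale = (hi - lo) / (2**bits - 1) or 1.0
        return np.round((x - lo) / scale), scale, lo

    def dequantize(q, scale, lo):
        return q * scale + lo

    def store_key(cache, pos, k_pre_rope):
        # Store keys BEFORE RoPE, so the quantizer never sees the extreme
        # magnitudes that were corrupting the 8-bit cache.
        cache[pos] = quantize(k_pre_rope)

    def load_key(cache, pos, rope):
        q, scale, lo = cache[pos]
        return rope(dequantize(q, scale, lo), pos)  # RoPE after dequantization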

[1/4] Fetching available devices from http://localhost:8080...
    Using 2 shard(s) for Context Parallelism:
      [0] shard-1 (62.210.193.181:8082)
      [1] shard-2 (62.210.193.192:8082)
[2/4] Fetching model config for mlx-community/Llama-3.2-3B-Instruct-4bit...
    Model has 28 layers (full model on each shard)
    Sequence length: 131072
[3/4] Preparing CP topology...
    Topology prepared successfully
    Model: mlx-community/Llama-3.2-3B-Instruct-4bit
    Devices: ['shard-1', 'shard-2']
[4/4] Loading model on all shards (this may take a while)...
    Model loaded successfully!

============================================================
Context Parallelism Ready
============================================================
  Model:      mlx-community/Llama-3.2-3B-Instruct-4bit
  CP Ranks:   2
  Shards:     shard-1, shard-2
  KV Bits:    fp16
  Seq Len:    131072

Each shard has the full model and will process 1/2 of
the context window during inference.

  ✓ shard-1: Model loaded successfully
  ✓ shard-2: Model loaded successfully

============================================================
Needle in Haystack Test
============================================================
Target context: ~128 tokens
Actual prompt: ~212 tokens
Needle position: 10%
Expected password: lima-355-delta
============================================================
Response: lima-355-delta<|eot_id|>
Latency: 2.39s
Result: ✓ PASS

============================================================
SUMMARY
============================================================
Passed: 1/1
     128 tokens: 1/1 passed, avg 2.4s

✓ ALL TESTS PASSED - CP is working correctly!

@jaisw7 (Collaborator, Author) commented Jan 5, 2026

Preliminary scaling numbers for the distributed KV cache.

m1@673bd4ec-9389-4eda-aa51-fa7d3abee8e8 dnet % uv run python scripts/needle_in_haystack.py --sizes 128,256,512,1024,2048,4096,8192,16384

============================================================
Needle in Haystack Test
============================================================
Target context: ~128 tokens
Actual prompt: ~212 tokens
Needle position: 75%
Expected password: lima-601-lima
============================================================
Response: lima-601-lima<|eot_id|>
Latency: 2.53s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~256 tokens
Actual prompt: ~362 tokens
Needle position: 75%
Expected password: hotel-263-hotel
============================================================
Response: hotel-263-hotel<|eot_id|>
Latency: 2.37s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~512 tokens
Actual prompt: ~585 tokens
Needle position: 75%
Expected password: foxtrot-223-india
============================================================
Response: foxtrot-223-india<|eot_id|>
Latency: 3.03s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~1024 tokens
Actual prompt: ~1111 tokens
Needle position: 75%
Expected password: india-428-bravo
============================================================
Response: india-428-bravo<|eot_id|>
Latency: 3.58s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~2048 tokens
Actual prompt: ~2135 tokens
Needle position: 75%
Expected password: charlie-683-foxtrot
============================================================
Response: charlie-683-foxtrot<|eot_id|>
Latency: 5.19s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~4096 tokens
Actual prompt: ~4175 tokens
Needle position: 75%
Expected password: gamma-347-kilo
============================================================
Response: gamma-347-kilo<|eot_id|>
Latency: 8.59s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~8192 tokens
Actual prompt: ~8259 tokens
Needle position: 75%
Expected password: foxtrot-688-kilo
============================================================
Response: foxtrot-688-kilo<|eot_id|>
Latency: 21.23s
Result: ✓ PASS

============================================================
Needle in Haystack Test
============================================================
Target context: ~16384 tokens
Actual prompt: ~16486 tokens
Needle position: 75%
Expected password: delta-433-lima
============================================================
Response: delta-433-lima<|eot_id|>
Latency: 66.07s
Result: ✓ PASS

============================================================
SUMMARY
============================================================
Passed: 8/8
     128 tokens: 1/1 passed, avg 2.5s
     256 tokens: 1/1 passed, avg 2.4s
     512 tokens: 1/1 passed, avg 3.0s
    1024 tokens: 1/1 passed, avg 3.6s
    2048 tokens: 1/1 passed, avg 5.2s
    4096 tokens: 1/1 passed, avg 8.6s
    8192 tokens: 1/1 passed, avg 21.2s
   16384 tokens: 1/1 passed, avg 66.1s

✓ ALL TESTS PASSED - CP is working correctly!

@jaisw7 (Collaborator, Author) commented Jan 5, 2026

Environment variable configurations for tests:

export DNET_LOG=DEBUG
export DNET_CP_ENABLED=true
export DNET_CP_MIN_CONTEXT_FOR_CP=64
export DNET_API_CALLBACK_ADDR="62.210.193.181:58080"
export DNET_KV_MODE=fp16
export DNET_KV_BITS=16
export DNET_KV_TTL_S=600

export DNET_GRPC_MAX_MESSAGE_LENGTH=267108864
export DNET_GRPC_MAX_CONCURRENT_STREAMS=1024
export DNET_GRPC_KEEPALIVE_TIME_MS=2120000
export DNET_GRPC_KEEPALIVE_TIMEOUT_MS=220000

export DNET_TRANSPORT_STREAM_BACKOFF_S=50
export DNET_TRANSPORT_STREAM_IDLE_S=200
export DNET_TRANSPORT_SEND_RETRIES=3
export DNET_TRANSPORT_COMPRESS=false
export DNET_TRANSPORT_COMPRESS_MIN_BYTES=265536
