feat: context parallelism #85
base: master
Conversation
Pull request overview
This PR implements Context Parallelism (CP) infrastructure to enable long-context inference (128K+ tokens) by distributing the sequence dimension across multiple Apple Silicon devices while replicating the full model on each device. This complements the existing pipeline parallelism (RingStrategy) which shards layers.
Key changes include:
- Core CP primitives for load-balanced sharding, numerically stable partial attention merging (sketched below), and ring communication
- Algorithm selection heuristics to choose between pass-KV (prefill), pass-Q, and ring-reduce (decode) based on cache hit rates
- CPAdapter for shards and ContextParallelStrategy for API coordination
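The merging primitive listed above is essentially the online-softmax (log-sum-exp) trick applied across KV shards. Below is a minimal numpy sketch of the idea, assuming each partial result carries its normalized output plus per-row max logit and softmax denominator; the function name and signature are illustrative, not the actual merge_attention.py API.

```python
import numpy as np

def merge_partials(o1, m1, l1, o2, m2, l2):
    """Merge two partial attention outputs computed over disjoint KV chunks.

    o*: normalized partial outputs, shape (num_queries, head_dim)
    m*: per-row max attention logits, shape (num_queries, 1)
    l*: per-row softmax denominators, shape (num_queries, 1)
    """
    m = np.maximum(m1, m2)                   # new running max logit per query row
    a1, a2 = np.exp(m1 - m), np.exp(m2 - m)  # rescale factors for each partial
    l = a1 * l1 + a2 * l2                    # merged softmax denominator
    o = (a1 * l1 * o1 + a2 * l2 * o2) / l    # renormalized merged output
    return o, m, l
```

Merging this way is associative, so partial results can be folded in whatever order the ring steps deliver them without changing the final softmax.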
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 21 comments.
| File | Description |
|---|---|
| `src/dnet/core/cp/sharding.py` | Load-balanced 2N prefill sharding and contiguous decode sharding with roundtrip unshard |
| `src/dnet/core/cp/merge_attention.py` | Numerically stable merging of partial attention outputs using the online softmax algorithm |
| `src/dnet/core/cp/heuristics.py` | Greedy algorithm selection based on context length, cache hit rate, and GQA ratio |
| `src/dnet/core/cp/ring_comm.py` | Ring communication primitives with async send/recv (placeholder gRPC implementation) |
| `src/dnet/core/cp/__init__.py` | Lazy imports for MLX-dependent modules to support cross-platform heuristics |
| `src/dnet/shard/adapters/context_parallel.py` | CPAdapter implementing ring_pass_kv_attention and ring_reduce_attention (serialization incomplete) |
| `src/dnet/api/strategies/context_parallel.py` | CPTopologySolver for full model replication and CPApiAdapter for token distribution |
| `src/dnet/config.py` | ContextParallelSettings with algorithm selection and threshold configuration |
| `src/dnet/shard/models.py` | CP-specific fields in ShardLoadModelRequest for rank topology |
| `src/dnet/protos/dnet_cp.proto` | gRPC service and message definitions for KV/Q block transfers |
| `tests/subsystems/test_cp_*.py` | Unit tests for sharding, merging, heuristics, and ring communication (47 tests) |
| `tests/integration/test_cp_single_system.py` | Integration tests for module interop and server e2e (11 tests) |
| `docs/design/context-parallelism.md` | Comprehensive design document with architecture, algorithms, and verification plan |
| `.github/workflows/cp-integration-tests.yml` | CI workflow for CP testing on a macOS runner |
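For orientation, here is a sketch of the "2N" load-balanced prefill split described in the sharding.py row above: the sequence is cut into twice as many chunks as ranks, and rank i takes chunks i and 2N-1-i so that causal-attention work is roughly even. This mirrors the usual ring-attention balancing scheme; the real sharding.py API (and its roundtrip unshard) will differ in detail.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")

def load_balanced_prefill_shard(tokens: Sequence[T], num_ranks: int) -> List[List[T]]:
    """Pair an 'early' (cheap) chunk with a 'late' (expensive) chunk per rank."""
    chunks = 2 * num_ranks
    step = -(-len(tokens) // chunks)  # ceil division; trailing chunks may be short
    pieces = [list(tokens[i * step:(i + 1) * step]) for i in range(chunks)]
    return [pieces[r] + pieces[chunks - 1 - r] for r in range(num_ranks)]

# 8 tokens on 2 ranks: rank 0 gets chunks 0 and 3, rank 1 gets chunks 1 and 2.
assert load_balanced_prefill_shard(range(8), 2) == [[0, 1, 6, 7], [2, 3, 4, 5]]
```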
- Implement CPRingServiceServicer with SendBlock and StreamBlocks methods
- Add start_cp_ring_server helper to start a gRPC server for CP ring communication
- Remove MockRingCommunicator and MockRankCommunicator
- Rewrite test_ring_full_rotation_4_ranks to use actual gRPC servers
- All communication now goes through the real network, no mocks
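The rough shape of such a servicer, assuming protoc-generated stubs dnet_cp_pb2 / dnet_cp_pb2_grpc; the import path, response message name, and streaming mode here are guesses from the proto description, not the actual definitions.

```python
import asyncio
import grpc
from dnet.protos import dnet_cp_pb2, dnet_cp_pb2_grpc  # assumed import path for generated stubs

class CPRingServiceServicer(dnet_cp_pb2_grpc.CPRingServiceServicer):
    """Receives KV/Q blocks from the previous rank and hands them to the attention loop."""

    def __init__(self, inbox: asyncio.Queue) -> None:
        self._inbox = inbox

    async def SendBlock(self, request, context):
        # Unary RPC: accept one block and acknowledge.
        await self._inbox.put(request)
        return dnet_cp_pb2.SendBlockResponse()  # assumed response message name

    async def StreamBlocks(self, request_iterator, context):
        # Assumed client-streaming RPC: drain a stream of blocks for one ring step.
        async for block in request_iterator:
            await self._inbox.put(block)
        return dnet_cp_pb2.SendBlockResponse()  # assumed response message name

async def start_cp_ring_server(addr: str, inbox: asyncio.Queue) -> grpc.aio.Server:
    server = grpc.aio.server()
    dnet_cp_pb2_grpc.add_CPRingServiceServicer_to_server(CPRingServiceServicer(inbox), server)
    server.add_insecure_port(addr)
    await server.start()
    return server
```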
As per implementation plan section 4.4, added:
- DNET_CP_ENABLED
- DNET_CP_ALGORITHM
- DNET_CP_MIN_CONTEXT_FOR_CP
- DNET_CP_MIN_TOKENS_FOR_PASS_KV
- DNET_CP_CHUNK_OVERLAP
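A minimal sketch of how these variables might map onto ContextParallelSettings, using plain os.environ; the real src/dnet/config.py presumably goes through the project's settings framework, and the field defaults below are placeholders.

```python
import os
from dataclasses import dataclass

@dataclass
class ContextParallelSettings:
    # Field names mirror the env variables; defaults are illustrative only.
    enabled: bool = False
    algorithm: str = "auto"
    min_context_for_cp: int = 8192
    min_tokens_for_pass_kv: int = 1024
    chunk_overlap: int = 0

    @classmethod
    def from_env(cls) -> "ContextParallelSettings":
        env = os.environ
        return cls(
            enabled=env.get("DNET_CP_ENABLED", "false").lower() in ("1", "true"),
            algorithm=env.get("DNET_CP_ALGORITHM", cls.algorithm),
            min_context_for_cp=int(env.get("DNET_CP_MIN_CONTEXT_FOR_CP", cls.min_context_for_cp)),
            min_tokens_for_pass_kv=int(env.get("DNET_CP_MIN_TOKENS_FOR_PASS_KV", cls.min_tokens_for_pass_kv)),
            chunk_overlap=int(env.get("DNET_CP_CHUNK_OVERLAP", cls.chunk_overlap)),
        )
```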
- Import and use ContextParallelStrategy when settings.context_parallel.enabled is true
- Fall back to RingStrategy otherwise
- Add Strategy base type annotation to fix mypy
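A sketch of the selection wiring this describes; the Strategy base class location and the constructor arguments are assumptions, only the branch condition comes from the commit.

```python
from dnet.api.strategies.base import Strategy                      # assumed module path
from dnet.api.strategies.context_parallel import ContextParallelStrategy
from dnet.api.strategies.ring import RingStrategy                  # assumed module path

def select_strategy(settings) -> Strategy:
    # The Strategy base annotation is what satisfies mypy across both branches.
    if settings.context_parallel.enabled:
        return ContextParallelStrategy(settings)                    # assumed constructor args
    return RingStrategy(settings)                                   # assumed constructor args
```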
Now both sides are wired:
- api.py uses ContextParallelStrategy when CP is enabled
- shard.py uses CPAdapter when CP is enabled
- Add cp_config field to ActivationRequest in dnet_ring.proto
- Import dnet_cp.proto for the CPConfig type
- Wire up api.py strategy selection (ContextParallelStrategy)
- Wire up shard.py adapter selection (CPAdapter)
dnet_ring.proto imports dnet_cp.proto, but protoc generates bare imports like 'import dnet_cp_pb2' which fail at runtime. Added post-processing to generate_protos.py to convert these to relative imports.
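The post-processing amounts to rewriting protoc's bare imports into package-relative ones. A sketch of such a pass (the function name is illustrative; generate_protos.py may do this differently):

```python
import re
from pathlib import Path

# protoc emits e.g. "import dnet_cp_pb2 as dnet__cp__pb2", which only resolves if the
# stub directory is on sys.path; rewrite it as a package-relative import instead.
BARE_IMPORT = re.compile(r"^import (\w+_pb2(?:_grpc)?)( as \w+)?$", re.MULTILINE)

def relativize_generated_imports(generated_dir: Path) -> None:
    for stub in generated_dir.glob("*_pb2*.py"):
        text = stub.read_text()
        fixed = BARE_IMPORT.sub(r"from . import \1\2", text)
        if fixed != text:
            stub.write_text(fixed)
```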
CPTopologySolver was receiving all peers, including the API manager node, causing it to try to load the model on the API server (404). Now only actual shard nodes are passed to the solver. Also fixed a stray triple-quote line in context_parallel.py.
Filter shards by checking if they exist in . Since only contains shards that passed health and latency checks, this robustly excludes invalid nodes (like the API server itself) even if flag is unreliable.
Implemented _ingress_worker to deserialize requests and feed ShardRuntime. Implemented _egress_worker to drain runtime output and forward results. Implemented _token_tx_worker to stream generated tokens to API. This resolves the ReadTimeout hang observed during inference.
@andthattoo For a sanity check, I have tested context parallelism on a single machine (…). Could you please check it on your local machine to ensure things are in order?
Scaling appears to be approximately linear.
Phase 1: CPApiAdapter now splits tokens and broadcasts to all ranks
- Added connect_all_ranks() for multi-rank connection
- Added _send_tokens_multi_rank() with sequence splitting via shard_for_mode()

Phase 2: Restored the CP rank check in FitInMemoryPolicy
- Only the last rank samples (with a guard for single-device mode)
- Fixes the root cause of the CP hang; the API now splits properly

Phase 3: Wired multi-rank connection in InferenceManager/http_api
- Added connect_to_cp_ranks() method
- Updated load_model to use multi-rank when CP is enabled with more than one device
During the decode phase (1 token), splitting across ranks gives 0 tokens to some ranks, causing reshape errors. Now decode tokens go directly to the last rank only.
- Add adapter=None to FakeInferenceManager for the isinstance check
- Handle 1D/2D tensors in logits slicing (FakeComputeModel returns 1D)
- Add a type annotation to the by_size dict in needle_in_haystack.py
Token splitting broke cross-chunk attention. Now all ranks get full context. Savings come from sharded KV cache, not token splitting.
…els for best results
Added a needle-in-haystack test.
Scaling is nearly linear, but there are some crashes at 16k+ context on |
- Add CPAttentionWrapper with global RoPE offset and cache handling
- Inject the CP adapter into BaseRingModel
- Update the API to use load-balanced sharding for prefill
- Implement sync bridges for ring attention ops
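The "global RoPE offset" point is the key subtlety: each rank must rotate its queries and keys at their positions in the full sequence (cache length plus the chunk's global start), not at local index 0. A tiny sketch of the position bookkeeping, with apply_rope standing in for whatever RoPE implementation the wrapper actually delegates to:

```python
def global_positions(cache_len: int, chunk_start: int, chunk_len: int) -> list[int]:
    """Global positions for one contiguous chunk owned by a rank.

    With the 2N load-balanced split a rank owns two chunks and calls this once
    per chunk. cache_len is the number of tokens already in the KV cache.
    """
    return [cache_len + chunk_start + i for i in range(chunk_len)]

# Example: a rank owning tokens [1024, 2048) of a fresh prompt (empty cache)
# rotates at positions 1024..2047, e.g.:
#   q = apply_rope(q, positions=global_positions(0, 1024, 1024))  # hypothetical call
```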
- Add a safety check in CPAttentionWrapper.__getattr__
- Add an idempotency check in BaseRingModel.set_cp_adapter
dc63fab is the current working commit from this PR.
Update: I am still debugging merge-attention issues with distributed KV-Cache on multiple shards.
I have got the code to work on multiple shards with a distributed KV cache. However, I am seeing some 8-bit KV cache corruption.

Problem: RoPE (Rotary Position Embeddings) produces values with extreme magnitudes (both very large and very small). When these RoPE'd keys are stored in the 8-bit quantized KV cache, quantization loses precision on the extreme values, and dequantization produces inf or -inf.

Result: k_norm=inf, k_mean=-inf at decode time.

Fix: Use kv_bits: "fp16" for Context Parallelism. A proper fix would require quantization-aware RoPE handling (store keys pre-RoPE, apply RoPE after dequant).
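A sketch of the "proper fix" proposed above: keep the cached keys un-rotated so the 8-bit quantizer never sees RoPE's extreme magnitudes, and rotate only after dequantization at each key's global position. quantize, dequantize, and apply_rope are placeholders, not the project's actual cache API.

```python
def cache_key_pre_rope(k_new, kv_cache: list, quantize):
    # Store the un-rotated key; its dynamic range is much friendlier to 8-bit.
    kv_cache.append(quantize(k_new))

def keys_for_attention(kv_cache: list, dequantize, apply_rope, key_positions):
    # Dequantize first, then rotate each cached key at its own global position,
    # so quantization error never interacts with RoPE's very large/small values.
    return [apply_rope(dequantize(k_q), pos) for k_q, pos in zip(kv_cache, key_positions)]
```

Rotating after dequantization does add per-step work over the cached keys, which may be part of why kv_bits: "fp16" is the pragmatic workaround for now.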
Preliminary scaling numbers for distributed KV-Cache.
Environment variable configurations for tests:

export DNET_LOG=DEBUG
export DNET_CP_ENABLED=true
export DNET_CP_MIN_CONTEXT_FOR_CP=64
export DNET_API_CALLBACK_ADDR="62.210.193.181:58080"
export DNET_KV_MODE=fp16
export DNET_KV_BITS=16
export DNET_KV_TTL_S=600
export DNET_GRPC_MAX_MESSAGE_LENGTH=267108864
export DNET_GRPC_MAX_CONCURRENT_STREAMS=1024
export DNET_GRPC_KEEPALIVE_TIME_MS=2120000
export DNET_GRPC_KEEPALIVE_TIMEOUT_MS=220000
export DNET_TRANSPORT_STREAM_BACKOFF_S=50
export DNET_TRANSPORT_STREAM_IDLE_S=200
export DNET_TRANSPORT_SEND_RETRIES=3
export DNET_TRANSPORT_COMPRESS=false
export DNET_TRANSPORT_COMPRESS_MIN_BYTES=265536
Summary
Implements Context Parallelism (CP) infrastructure for long-context inference (128K+ tokens) by distributing sequence dimensions across multiple devices while replicating all model layers on each device.
Motivation
Long-context LLM inference is memory-bound. Context Parallelism addresses this by distributing the sequence (and therefore the KV cache) across multiple devices instead of concentrating it on one.
Based on Ring Attention papers (arXiv:2310.01889, 2411.01783).
Changes
- Core CP infrastructure (`src/dnet/core/cp/`):
  - `sharding.py` - Load-balanced prefill and contiguous decode sharding
  - `merge_attention.py` - Numerically stable partial output merging
  - `heuristics.py` - Algorithm selection (pass-KV, pass-Q, ring-reduce)
  - `ring_comm.py` - Async ring send/recv with gRPC transport
- Shard adapter (`src/dnet/shard/adapters/context_parallel.py`): `CPAdapter` implementing ring_pass_kv_attention and ring_reduce_attention
- API strategy (`src/dnet/api/strategies/context_parallel.py`): `ContextParallelStrategy`, `CPTopologySolver`, `CPApiAdapter`
- Configuration (`src/dnet/config.py`, `src/dnet/shard/models.py`): `ContextParallelSettings` with algorithm and thresholds; CP fields in `ShardLoadModelRequest`
- Proto (`src/dnet/protos/dnet_cp.proto`): `CPRingService` for block transfers
- Tests: 47 unit tests + 11 integration tests
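For the `heuristics.py` item above, a sketch of the kind of greedy selection described; the 0.8 hit-rate cutoff is invented for the example, the GQA-ratio input is omitted, and only the pass-KV/pass-Q/ring-reduce split and the threshold names come from this PR.

```python
def select_cp_algorithm(new_tokens: int, cached_tokens: int,
                        min_context_for_cp: int, min_tokens_for_pass_kv: int) -> str:
    context = new_tokens + cached_tokens
    if context < min_context_for_cp:
        return "none"         # context too short for CP to pay off
    if new_tokens <= 1:
        return "ring-reduce"  # decode: only partial outputs travel the ring
    cache_hit_rate = cached_tokens / context
    if cache_hit_rate > 0.8 or new_tokens < min_tokens_for_pass_kv:
        return "pass-q"       # cheaper to ship queries than the large cached KV
    return "pass-kv"          # prefill-heavy: ship KV blocks around the ring
```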
Type of Change
Testing
Tests cover:
Checklist
Related Issues
#84