Summary
Extend Phase 7 virtual batch attention to support multi-request batching (max_num_seqs > 1). The Phase 7 MVP enforces single-request batching due to complexity of handling heterogeneous prefix lengths across multiple requests.
Background
Current limitation (Phase 7 MVP):
- Virtual batch attention raises NotImplementedError when num_reqs > 1
- Server requires the --max-num-seqs 1 configuration
- Single-request processing reduces GPU utilization in high-throughput scenarios
Technical blocker:
Virtual batch transformation (make_block_attention_virtual_batches()) currently assumes homogeneous prefix lengths. With multiple concurrent requests, each request may have different num_prefix_tokens, requiring per-request metadata transformation.
Scope
In Scope
- Per-request virtual batch transformation
  - Remove num_reqs > 1 check in dllm_plugin/attention/virtual_batches.py:56
  - Implement heterogeneous prefix length handling
  - Transform metadata separately for each request in the batch
- Metadata structure updates
  - Support variable-length prefix chunks across requests
  - Correct KV cache slicing for heterogeneous batches
  - Maintain separate query_start_loc offsets per request
- Integration testing
  - Add test with 2+ concurrent requests to tests/test_virtual_batch_multi_request.py
  - Validate correct attention computation with heterogeneous prefix lengths
  - End-to-end integration test with real LLaDA2.0 model
- Performance benchmarking
  - Measure throughput improvement with multi-request batching
  - Compare max_num_seqs=1 vs max_num_seqs=4 vs max_num_seqs=8
  - Document GPU utilization improvements
- Documentation updates
  - Update docs/OPERATOR_LLaDA2.md to reflect multi-request support
  - Remove single-request workarounds from operator guide
  - Add multi-request configuration examples
Out of Scope
- Pipeline parallelism (PP > 1) - remains unsupported (separate issue)
- Tensor parallelism enhancements - TP already supported in Phase 7
- Attention optimizations (single-pass, CUTLASS, FlashInfer) - covered by Phase 8.x
Technical Design
Current Architecture (Phase 7 MVP)
def make_block_attention_virtual_batches(
    attn_metadata: CommonAttentionMetadata,
    num_prefix_tokens: int,  # Single value for all requests
    block_size: int,
) -> tuple[CommonAttentionMetadata | None, CommonAttentionMetadata]:
    if attn_metadata.num_reqs > 1:
        raise NotImplementedError("multi-request not supported")
    # ... transform assuming single request
Proposed Architecture (Phase 7.1)
def make_block_attention_virtual_batches(
    attn_metadata: CommonAttentionMetadata,
    num_prefix_tokens: list[int] | int,  # Per-request or single value
    block_size: int,
) -> tuple[CommonAttentionMetadata | None, CommonAttentionMetadata]:
    # Handle both single-request and multi-request cases
    if isinstance(num_prefix_tokens, int):
        # Single request path (backward compatible)
        num_prefix_tokens = [num_prefix_tokens]
    # Per-request transformation
    for req_idx in range(attn_metadata.num_reqs):
        req_prefix_tokens = num_prefix_tokens[req_idx]
        # Transform metadata for this request
        # Update query_start_loc, seq_lens, block_tables per request
    # Combine transformed metadata for all requests
    # ...
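The per-request loop above can be illustrated with a small, framework-free sketch that only computes cumulative query offsets. It assumes that each request's query tokens consist of its prefix tokens followed by its block tokens; that layout, the `SplitOffsets` container, and the `split_query_offsets` helper are hypothetical names introduced here for illustration and are not part of `CommonAttentionMetadata` or the plugin.

```python
from __future__ import annotations

from dataclasses import dataclass
from itertools import accumulate


@dataclass
class SplitOffsets:
    """Hypothetical container for illustration only."""

    prefix_query_start_loc: list[int]
    block_query_start_loc: list[int]


def split_query_offsets(
    query_lens: list[int],
    num_prefix_tokens: list[int] | int,
) -> SplitOffsets:
    """Split each request's query length into a prefix part and a block part.

    Assumption: request i contributes num_prefix_tokens[i] prefix tokens
    followed by (query_lens[i] - num_prefix_tokens[i]) block tokens.
    """
    if isinstance(num_prefix_tokens, int):
        # Broadcast a single value to every request in the batch.
        num_prefix_tokens = [num_prefix_tokens] * len(query_lens)
    if len(num_prefix_tokens) != len(query_lens):
        raise ValueError("need one prefix length per request")

    prefix_lens: list[int] = []
    block_lens: list[int] = []
    for q_len, n_prefix in zip(query_lens, num_prefix_tokens):
        if not 0 <= n_prefix <= q_len:
            raise ValueError(f"prefix length {n_prefix} outside [0, {q_len}]")
        prefix_lens.append(n_prefix)
        block_lens.append(q_len - n_prefix)

    # Cumulative sums with a leading 0, in the style of query_start_loc.
    return SplitOffsets(
        prefix_query_start_loc=[0, *accumulate(prefix_lens)],
        block_query_start_loc=[0, *accumulate(block_lens)],
    )


# Two requests with 16 and 24 prefix tokens, each followed by a 32-token block.
offsets = split_query_offsets([48, 56], [16, 24])
print(offsets.prefix_query_start_loc)  # [0, 16, 40]
print(offsets.block_query_start_loc)   # [0, 32, 64]
```

Unlike the proposed signature above, this sketch broadcasts an `int` to every request rather than wrapping it in a one-element list; whether the real function should do the same is a design decision left open here.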
Dependencies
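Acceptance Criteria
- Remove the num_reqs > 1 check from virtual_batches.py
- pytest -v tests/test_virtual_batch_multi_request.py passes
- Multi-request batching (max_num_seqs > 1) is supported
- Server runs with --max-num-seqs 4 (or higher) without errors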
Testing Strategy
Unit Tests
def test_virtual_batch_multi_request_heterogeneous_prefix():
    """Test with 2 requests having different prefix lengths."""
    attn_metadata = CommonAttentionMetadata(
        num_reqs=2,
        # Request 0: 16 prefix tokens
        # Request 1: 24 prefix tokens
        ...
    )
    prefix_metadata, block_metadata = make_block_attention_virtual_batches(
        attn_metadata=attn_metadata,
        num_prefix_tokens=[16, 24],  # Heterogeneous
        block_size=32,
    )
    # The prefix metadata is optional in the return type, so check it first
    assert prefix_metadata is not None
    assert prefix_metadata.num_reqs == 2
    assert block_metadata.num_reqs == 2
    # Validate correct per-request slicing
Integration Tests
- End-to-end test with real LLaDA2.0-mini model
- 2 concurrent requests with different prompt lengths
- Validate output correctness for both requests
- Measure GPU utilization improvement
Benchmarks
# Baseline (Phase 7 MVP)
vllm serve inclusionAI/LLaDA2.0-mini --max-num-seqs 1
# Benchmark: X tokens/sec
# Phase 7.1
vllm serve inclusionAI/LLaDA2.0-mini --max-num-seqs 4
# Benchmark: Y tokens/sec (expect Y > X)
Performance Impact
Expected improvements:
- Throughput: +50-100% with max_num_seqs=4 (GPU-dependent)
- GPU utilization: Increase from ~30-50% to ~70-90% under load
- Latency: Per-request latency may increase slightly, but overall throughput improves
Measurement:
- Use GuideLLM with constant rate profile to measure concurrent requests
- Monitor GPU utilization with nvidia-smi dmon
- Compare TTFT and ITL distributions
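To compare a measured run against the expected +50-100% throughput range, the relative gain can be computed as below. This is a minimal sketch: `relative_gain_percent` is a hypothetical helper, and the numbers in the usage lines are placeholders, not measurements.

```python
def relative_gain_percent(baseline_tps: float, candidate_tps: float) -> float:
    """Return the throughput change of candidate over baseline, in percent."""
    if baseline_tps <= 0:
        raise ValueError("baseline throughput must be positive")
    return (candidate_tps - baseline_tps) / baseline_tps * 100.0


# Placeholder values only: X = 100 tokens/sec, Y = 150 tokens/sec.
gain = relative_gain_percent(100.0, 150.0)
print(f"{gain:.1f}%")  # 50.0%
```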
Risks and Mitigations
Risk 1: Complexity of per-request transformation
Mitigation:
- Start with simple case: 2 requests with same block size
- Incremental implementation: homogeneous block sizes first, then heterogeneous
- Extensive testing at each step
Risk 2: Performance regression for single-request case
Mitigation:
- Keep fast path for single-request case (isinstance(num_prefix_tokens, int))
- Benchmark both before/after to ensure no regression
- Use profiling to identify bottlenecks
Risk 3: KV cache slicing bugs
Mitigation:
- Thorough unit tests with manual verification of offsets
- Integration tests with known-good attention outputs
- Reference implementation comparison (if available)
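One way to support the manual verification of offsets is a small invariant check on any query_start_loc-style array: it must start at 0, never decrease, and end at the total token count, so that the per-request slices tile the flattened token range without gaps or overlap. The sketch below assumes offsets are plain Python lists; `slices_from_start_loc` is a hypothetical helper, not an existing function in the plugin.

```python
from __future__ import annotations


def slices_from_start_loc(start_loc: list[int], total_tokens: int) -> list[slice]:
    """Validate cumulative offsets and return one slice per request."""
    if len(start_loc) < 2 or start_loc[0] != 0:
        raise ValueError("offsets must start at 0 and cover at least one request")
    if start_loc[-1] != total_tokens:
        raise ValueError("last offset must equal the total token count")

    slices: list[slice] = []
    for begin, end in zip(start_loc, start_loc[1:]):
        if end < begin:
            raise ValueError(f"offsets decrease from {begin} to {end}")
        slices.append(slice(begin, end))
    return slices


# Stand-in for a flattened token dimension of length 40.
tokens = list(range(40))
parts = [tokens[s] for s in slices_from_start_loc([0, 16, 40], len(tokens))]

# Per-request lengths match, and concatenating the parts restores the input.
assert [len(p) for p in parts] == [16, 24]
assert sum(parts, []) == tokens
```

A check like this could run inside unit tests on both the prefix and block metadata before any attention output is compared.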
Migration Notes
Backward compatibility: Phase 7.1 will remain backward compatible with single-request configurations. Operators can continue using --max-num-seqs 1 if desired.
Recommended upgrade path:
- Deploy Phase 7.1 with --max-num-seqs 1 initially (validate no regression)
- Gradually increase to --max-num-seqs 2, then 4, monitoring performance
- Tune based on GPU memory and throughput requirements
Related Issues
Timeline Estimate
- Research/Design: 1-2 days (understand heterogeneous prefix transformations)
- Implementation: 3-5 days (per-request transformation + integration)
- Testing: 2-3 days (unit tests + integration tests + benchmarks)
- Documentation: 1 day
Total: ~1.5-2 weeks for a single developer
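References
- dllm_plugin/attention/virtual_batches.py
- tests/test_virtual_batch_multi_request.py
- docs/OPERATOR_LLaDA2.md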