Skip to content

Commit 9fdfb23

Browse files
committed
add pp_partition to customize each rank's layer number
Signed-off-by: Zhenhuan Chen <[email protected]>
1 parent 1944fb1 commit 9fdfb23

File tree

4 files changed

+48
-19
lines changed

4 files changed

+48
-19
lines changed

docs/source/developer-guide/api-change.md

Lines changed: 16 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -34,7 +34,7 @@ TensorRT LLM classifies APIs into two categories:
3434
All API schemas are:
3535
- Stored as YAML files in the codebase
3636
- Protected by unit tests in `tests/unittest/api_stability/`
37-
- Automatically validated to ensure consistency
37+
- Automatically validated to ensure consistency
3838

3939
## API Change Principles
4040

@@ -44,22 +44,22 @@ All API schemas are:
4444

4545
Argument names should describe what the argument represents, not how it is used internally.
4646

47-
**Good**: `max_new_tokens` (clear meaning)
47+
**Good**: `max_new_tokens` (clear meaning)
4848
**Bad**: `num` (ambiguous)
4949

5050
**Reflect Argument Type and Granularity**
5151

5252
- For **boolean** knobs, prefix with verbs like `enable_` and so on.
5353
Examples: `enable_cache`, `enable_flash_attention`
5454

55-
- For **numerical threshold** knobs, suffix with `_limit`, `_size`, `_count`, `_len_` or `_ratio`
55+
- For **numerical threshold** knobs, suffix with `_limit`, `_size`, `_count`, `_len_` or `_ratio`
5656
Examples: `max_seq_len`, `prefill_batch_size`
5757

5858
**Avoid Redundant Prefixes**
5959

6060
Example (in `MoeConfig`):
6161

62-
**Good**: `backend`
62+
**Good**: `backend`
6363
**Bad**: `moe_backend` (redundant since it's already in `MoeConfig`)
6464

6565
**Use Specific Names for Narrow Scenarios**
@@ -68,7 +68,7 @@ When adding knobs for specific use cases, make the name convey the restriction c
6868

6969
Example (argument to the LLM class):
7070

71-
**Good**: `rope_scaling_factor` → clearly indicates it's for RoPE
71+
**Good**: `rope_scaling_factor` → clearly indicates it's for RoPE
7272
**Bad**: `scaling_factor` → too generic and prone to misuse
7373

7474
### 2. Hierarchical Configuration
@@ -77,13 +77,13 @@ Organize complex or hierarchical arguments into **dedicated configuration datacl
7777

7878
**Guidelines**
7979

80-
- Use the `XxxConfig` suffix consistently
80+
- Use the `XxxConfig` suffix consistently
8181
Examples: `ModelConfig`, `ParallelConfig`, `MoeConfig`
82-
83-
- **Reflect conceptual hierarchy**
82+
83+
- **Reflect conceptual hierarchy**
8484
The dataclass name should represent a coherent functional unit, not an arbitrary grouping
85-
86-
- **Avoid over-nesting**
85+
86+
- **Avoid over-nesting**
8787
Use only one level of configuration hierarchy whenever possible (e.g., `LlmArgs → ParallelConfig`) to balance readability and modularity
8888

8989
### 3. Prefer `LlmArgs` Over Environment Variables
@@ -154,15 +154,15 @@ garbage_collection_gen0_threshold: int = Field(
154154

155155
Add the field to the appropriate schema file:
156156

157-
- **Non-committed arguments**: `tests/unittest/api_stability/references/llm_args.yaml`
157+
- **Non-committed arguments**: `tests/unittest/api_stability/references/llm.yaml`
158158
```yaml
159159
garbage_collection_gen0_threshold:
160160
type: int
161161
default: 20000
162162
status: beta # Must match the status in code
163163
```
164164
165-
- **Committed arguments**: `tests/unittest/api_stability/references_committed/llm_args.yaml`
165+
- **Committed arguments**: `tests/unittest/api_stability/references_committed/llm.yaml`
166166
```yaml
167167
garbage_collection_gen0_threshold:
168168
type: int
@@ -196,16 +196,16 @@ For non-committed APIs, use the `@set_api_status` decorator:
196196
```python
197197
@set_api_status("beta")
198198
def generate_with_streaming(
199-
self,
200-
prompts: List[str],
199+
self,
200+
prompts: List[str],
201201
**kwargs
202202
) -> Iterator[GenerationOutput]:
203203
"""Generate text with streaming output.
204-
204+
205205
Args:
206206
prompts: Input prompts for generation
207207
**kwargs: Additional generation parameters
208-
208+
209209
Returns:
210210
Iterator of generation outputs
211211
"""

tensorrt_llm/llmapi/llm_args.py

Lines changed: 9 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -326,6 +326,7 @@ class _ParallelConfig(StrictBaseModel):
326326
moe_tp_size: int = -1
327327
moe_ep_size: int = -1
328328
cp_config: dict = Field(default_factory=dict)
329+
pp_partition: Optional[List[int]] = Field(default=None)
329330
enable_attention_dp: bool = False
330331
enable_lm_head_tp_in_adp: bool = False
331332

@@ -372,6 +373,7 @@ def to_mapping(self) -> Mapping:
372373
gpus_per_node=self.gpus_per_node,
373374
tp_size=self.tp_size,
374375
pp_size=self.pp_size,
376+
pp_partition=self.pp_partition,
375377
cp_size=self.cp_size,
376378
cp_config=self.cp_config,
377379
enable_attention_dp=self.enable_attention_dp,
@@ -1587,6 +1589,12 @@ class BaseLlmArgs(StrictBaseModel):
15871589
description="Enable LM head TP in attention dp.",
15881590
status="prototype")
15891591

1592+
pp_partition: Optional[List[int]] = Field(
1593+
default=None,
1594+
description=
1595+
"Pipeline parallel partition, a list of each rank's layer number.",
1596+
status="prototype")
1597+
15901598
cp_config: Optional[dict] = Field(default_factory=dict,
15911599
description="Context parallel config.",
15921600
status="prototype")
@@ -1843,6 +1851,7 @@ def validate_parallel_config(self):
18431851
moe_ep_size=self.moe_expert_parallel_size,
18441852
enable_attention_dp=self.enable_attention_dp,
18451853
enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
1854+
pp_partition=self.pp_partition,
18461855
cp_config=self.cp_config)
18471856
return self
18481857

tensorrt_llm/mapping.py

Lines changed: 19 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -49,6 +49,7 @@ def __init__(
4949
cp_config=None,
5050
tp_size=1,
5151
pp_size=1,
52+
pp_partition=None,
5253
moe_cluster_size=-1, # -1 means no moe
5354
moe_tp_size=-1, # -1 means no moe
5455
moe_ep_size=-1, # -1 means no moe
@@ -126,6 +127,7 @@ def __init__(
126127
self.cp_size = cp_size
127128
self.cp_config = cp_config if cp_config is not None else {}
128129
self.pp_size = pp_size
130+
self.pp_partition = pp_partition
129131
self.moe_tp_size = moe_tp_size
130132
self.moe_ep_size = moe_ep_size
131133
self.moe_cluster_size = moe_cluster_size
@@ -156,6 +158,7 @@ def __eq__(self, other):
156158
and self.tp_size == other.tp_size
157159
and self.moe_cluster_size == other.moe_cluster_size
158160
and self.pp_size == other.pp_size
161+
and self.pp_partition == other.pp_partition
159162
and self.moe_tp_size == other.moe_tp_size
160163
and self.moe_ep_size == other.moe_ep_size
161164
and self.attn_tp_size == other.attn_tp_size
@@ -177,6 +180,7 @@ def __hash__(self):
177180
self.attn_cp_size,
178181
# note: we do not allow updating cp_config after initialization
179182
tuple(sorted(self.cp_config.items())),
183+
tuple(self.pp_partition) if self.pp_partition is not None else (),
180184
))
181185

182186
@property
@@ -299,9 +303,20 @@ def has_moe_ep(self):
299303
return self.moe_ep_size > 1
300304

301305
def pp_layers(self, num_layers: int) -> List[int]:
302-
# If num_layers % pp_size = n != 0, first n ranks get one extra layer
303-
return torch.tensor_split(torch.arange(num_layers),
304-
self.pp_size)[self.pp_rank].tolist()
306+
if self.pp_partition is not None:
307+
if len(self.pp_partition) != self.pp_size:
308+
raise ValueError(
309+
f"{len(self.pp_partition)=} does not match {self.pp_size=}."
310+
)
311+
if sum(self.pp_partition) != num_layers:
312+
raise ValueError(
313+
f"{sum(self.pp_partition)=} does not match {num_layers=}.")
314+
return torch.arange(num_layers).split(
315+
self.pp_partition)[self.pp_rank].tolist()
316+
else:
317+
# If num_layers % pp_size = n != 0, first n ranks get one extra layer
318+
return torch.tensor_split(torch.arange(num_layers),
319+
self.pp_size)[self.pp_rank].tolist()
305320

306321
def ep_experts(self, num_experts: int) -> List[int]:
307322
assert self.cp_size == 1
@@ -446,6 +461,7 @@ def __init__(
446461
cp_config=None,
447462
tp_size=1,
448463
pp_size=1,
464+
pp_partition=None,
449465
moe_cluster_size=-1, # -1 means no moe
450466
moe_tp_size=-1, # -1 means no moe
451467
moe_ep_size=-1, # -1 means no moe

tests/unittest/api_stability/references/llm.yaml

Lines changed: 4 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -18,6 +18,10 @@ methods:
1818
annotation: Optional[dict]
1919
default: null
2020
status: prototype
21+
pp_partition:
22+
annotation: Optional[List[int]]
23+
default: null
24+
status: prototype
2125
# Stats
2226
iter_stats_max_iterations:
2327
annotation: Optional[int]

0 commit comments

Comments (0)