
Commit ff515ce

[refactor] Move iter_counter handling to PyExecutor

- Moved iter_counter into PyExecutor to keep iteration tracking consistent.
- This also allows counting iterations in which no requests are scheduled.

Signed-off-by: Robin Kobus <[email protected]>

1 parent c231186 commit ff515ce
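The core of the change: the per-iteration counter now lives in PyExecutor rather than in the model engine, and it is incremented once at the end of every executor-loop iteration, even when no requests were scheduled. A minimal sketch of the resulting control flow (simplified; _should_stop and _schedule are illustrative placeholders, not the actual TensorRT-LLM methods):

    from tensorrt_llm._torch.expert_statistic import ExpertStatistic

    class PyExecutor:
        def __init__(self):
            # The executor owns the counter now, so iterations with an empty
            # schedule are counted as well.
            self.iter_counter = 0

        def _executor_loop(self):
            while not self._should_stop():      # illustrative placeholder
                scheduled = self._schedule()    # illustrative placeholder
                if scheduled:
                    self._forward_step(scheduled)
                # Incremented unconditionally, once per loop iteration.
                self.iter_counter += 1

        def _forward_step(self, scheduled_requests):
            # The expert-statistics iteration is driven by the executor's
            # counter rather than a counter kept by the model engine.
            ExpertStatistic.set_iter(self.iter_counter)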

File tree

5 files changed: +16 −13 lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 0 additions & 1 deletion

@@ -153,7 +153,6 @@ def __init__(
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
-        self.iter_counter = 0
 
         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
         self.max_beam_width = max_beam_width

tensorrt_llm/_torch/expert_statistic.py

Lines changed: 2 additions & 4 deletions

@@ -29,11 +29,9 @@ def create(rank_id: int):
                             rank_id, start, stop)
 
     @staticmethod
-    def set_iter(iter_id: int) -> bool:
+    def set_iter(iter_id: int) -> None:
         if ExpertStatistic.expert_statistic_obj is not None:
-            return ExpertStatistic.expert_statistic_obj._set_iter(iter_id)
-        else:
-            return False
+            ExpertStatistic.expert_statistic_obj._set_iter(iter_id)
 
     @staticmethod
     def set_layer(layer_id: int) -> None:

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 1 addition & 2 deletions

@@ -158,8 +158,7 @@ def maybe_get_cuda_graph(
         engine = self._get_engine()
 
         # disable when doing statistic
-        if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter(
-                engine.iter_counter):
+        if ExpertStatistic.get() is not None:
             return False, None, None, None
 
         can_run_cuda_graph = batch.can_run_cuda_graph
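Because set_iter now returns None (see the expert_statistic.py hunk above), the CUDA graph runner can no longer use its return value to detect active statistics collection; it simply checks whether an ExpertStatistic object exists. A rough sketch of the new gating, assuming ExpertStatistic.get() returns the module-level statistics object or None (inferred from this hunk, not verified against the full source):

    # Hypothetical, simplified version of the check; not the real maybe_get_cuda_graph body.
    def should_try_cuda_graph(batch) -> bool:
        # CUDA graphs stay disabled whenever expert statistics collection is
        # configured, instead of being toggled per iteration via set_iter.
        if ExpertStatistic.get() is not None:
            return False
        return batch.can_run_cuda_graph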

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 0 additions & 2 deletions

@@ -366,7 +366,6 @@ def __init__(
         if self.use_mrope:
             self.mrope_position_ids_cuda = torch.empty(
                 (3, 1, self.max_num_tokens), dtype=torch.int, device='cuda')
-        self.iter_counter = 0
 
         # We look up this key in resource_manager during forward to find the
         # kv cache manager. Can be changed to support multiple model engines
@@ -2338,7 +2337,6 @@ def forward(
             padded_requests, kv_cache_manager, attn_metadata, spec_metadata,
             new_tensors_device, cache_indirection_buffer)
 
-        self.iter_counter += 1
         with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
             if not maybe_graph:
                 # Fallback to eager execution if graph was not used

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 13 additions & 4 deletions

@@ -11,6 +11,7 @@
 
 import torch
 
+from tensorrt_llm._torch.expert_statistic import ExpertStatistic
 from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
 
 try:
@@ -135,6 +136,7 @@ def __init__(self,
 
         self.peft_cache_config = peft_cache_config
 
+        self.iter_counter = 0
         # profile config
         self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
             PROFILE_START_STOP_ENV_VAR_NAME)
@@ -567,7 +569,7 @@ def profile_step():
             formatted_timestamp = datetime.datetime.now().strftime(
                 "%Y-%m-%d %H:%M:%S")
             logger.info(
-                f"iter = {self.model_engine.iter_counter}, "
+                f"iter = {self.iter_counter}, "
                 f"global_rank = {self.global_rank}, "
                 f"rank = {self.dist.rank}, "
                 f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/"
@@ -697,7 +699,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
         stats.cpu_mem_usage = 0
         stats.pinned_mem_usage = 0
 
-        stats.iter = self.model_engine.iter_counter
+        stats.iter = self.iter_counter
 
         kv_cache_manager = self.resource_manager.resource_managers.get(
             ResourceManagerType.KV_CACHE_MANAGER)
@@ -994,6 +996,8 @@ def _executor_loop_pp(self):
                         self.active_requests,
                         previous_batch)
 
+            self.iter_counter += 1
+
     def wait_on_pp_send_handles(self, microbatch_id):
         if self.send_handles[microbatch_id] is not None:
             self.send_handles[microbatch_id].wait()
@@ -1232,6 +1236,8 @@ def _executor_loop(self):
                         iter_stats=iter_stats,
                         iter_start_time=iter_start_time))
 
+            self.iter_counter += 1
+
     def _prepare_draft_requests(self):
         try:
             # Set draft tokens here to make the KV cache manager
@@ -1417,6 +1423,8 @@ def _executor_loop_overlap(self):
 
                 self._kv_connector_terminate_requests()
 
+            self.iter_counter += 1
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
@@ -1820,9 +1828,10 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0):
     def _forward_step(self,
                       scheduled_requests,
                       new_tensors_device: Optional[SampleStateTensors] = None):
+        ExpertStatistic.set_iter(self.iter_counter)
 
        @nvtx_range(
-            f"[Executor] _forward_step {self.model_engine.iter_counter + 1}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
+            f"[Executor] _forward_step {self.iter_counter}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
        )
        def forward(scheduled_requests, resource_manager, new_tensors_device,
                    gather_context_logits, cache_indirection_buffer):
@@ -2160,7 +2169,7 @@ def _handle_responses(self):
 
             # Skip active requests that are not scheduled
             if request.return_perf_metrics and request.py_decoding_iter >= 1:
-                request.update_perf_metrics(self.model_engine.iter_counter)
+                request.update_perf_metrics(self.iter_counter)
 
             request_done = False
             if request.py_decoding_iter == 1 or request.is_finished or \
