
Commit 461e765

[refactor] Move iter_counter handling to PyExecutor
- Moved iter_counter into PyExecutor to ensure consistent tracking of iterations.
- This allows tracking iterations in which the set of scheduled requests is empty.

Signed-off-by: Robin Kobus <[email protected]>
1 parent 4becb44 commit 461e765
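
In broad strokes, the counter moves out of the model engine's forward pass (which only runs when requests are scheduled) and into the executor loop, which increments it once per pass. A minimal, self-contained sketch of that control flow follows; the class and method names are simplified stand-ins, not the actual TensorRT-LLM code:

# Sketch only: the executor loop owns the iteration counter and increments it
# once per pass, even when nothing is scheduled.
class MiniExecutor:
    def __init__(self):
        self.iter_counter = 0

    def _forward_step(self, scheduled_requests):
        # In the real code this is where ExpertStatistic.set_iter(self.iter_counter)
        # is called and the model engine forward pass runs.
        print(f"iter = {self.iter_counter}: {len(scheduled_requests)} reqs")

    def run(self, schedule):
        for scheduled_requests in schedule:   # each element is one loop pass
            if scheduled_requests:
                self._forward_step(scheduled_requests)
            self.iter_counter += 1            # counted even for empty iterations


executor = MiniExecutor()
executor.run([["req0"], [], ["req1", "req2"]])
assert executor.iter_counter == 3   # the empty pass is counted too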

5 files changed: +16 additions, -13 deletions


tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 0 additions & 1 deletion
@@ -153,7 +153,6 @@ def __init__(
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
-        self.iter_counter = 0

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
         self.max_beam_width = max_beam_width

tensorrt_llm/_torch/expert_statistic.py

Lines changed: 2 additions & 4 deletions
@@ -29,11 +29,9 @@ def create(rank_id: int):
             rank_id, start, stop)

     @staticmethod
-    def set_iter(iter_id: int) -> bool:
+    def set_iter(iter_id: int) -> None:
         if ExpertStatistic.expert_statistic_obj is not None:
-            return ExpertStatistic.expert_statistic_obj._set_iter(iter_id)
-        else:
-            return False
+            ExpertStatistic.expert_statistic_obj._set_iter(iter_id)

     @staticmethod
     def set_layer(layer_id: int) -> None:

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 1 addition & 2 deletions
@@ -158,8 +158,7 @@ def maybe_get_cuda_graph(
         engine = self._get_engine()

         # disable when doing statistic
-        if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter(
-                engine.iter_counter):
+        if ExpertStatistic.get() is not None:
             return False, None, None, None

         can_run_cuda_graph = batch.can_run_cuda_graph
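
Taken together, the expert_statistic.py and cuda_graph_runner.py changes separate "record the current iteration" (set_iter, now fire-and-forget) from "is statistic collection active" (an explicit check for the statistic object). A minimal sketch of that pattern, using a simplified stand-in class rather than the actual ExpertStatistic implementation:

from typing import Optional


class MiniExpertStatistic:
    """Optional singleton; statistics are collected only when an instance exists."""
    _obj: Optional["MiniExpertStatistic"] = None

    def __init__(self):
        self.current_iter = None

    @classmethod
    def get(cls) -> Optional["MiniExpertStatistic"]:
        return cls._obj

    @classmethod
    def set_iter(cls, iter_id: int) -> None:
        # Fire-and-forget: no boolean return; callers that care whether
        # collection is active check get() instead.
        if cls._obj is not None:
            cls._obj.current_iter = iter_id


MiniExpertStatistic.set_iter(0)            # no-op while collection is inactive
assert MiniExpertStatistic.get() is None   # so CUDA graphs would stay enabled

MiniExpertStatistic._obj = MiniExpertStatistic()
MiniExpertStatistic.set_iter(1)
assert MiniExpertStatistic.get().current_iter == 1   # graphs would be disabled here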

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 0 additions & 2 deletions
@@ -366,7 +366,6 @@ def __init__(
         if self.use_mrope:
             self.mrope_position_ids_cuda = torch.empty(
                 (3, 1, self.max_num_tokens), dtype=torch.int, device='cuda')
-        self.iter_counter = 0

         # We look up this key in resource_manager during forward to find the
         # kv cache manager. Can be changed to support multiple model engines
@@ -2338,7 +2337,6 @@ def forward(
             padded_requests, kv_cache_manager, attn_metadata, spec_metadata,
             new_tensors_device, cache_indirection_buffer)

-        self.iter_counter += 1
         with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
             if not maybe_graph:
                 # Fallback to eager execution if graph was not used

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 13 additions & 4 deletions
@@ -11,6 +11,7 @@

 import torch

+from tensorrt_llm._torch.expert_statistic import ExpertStatistic
 from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds

 try:
@@ -136,6 +137,7 @@ def __init__(self,

         self.peft_cache_config = peft_cache_config

+        self.iter_counter = 0
         # profile config
         self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
             PROFILE_START_STOP_ENV_VAR_NAME)
@@ -575,7 +577,7 @@ def profile_step():
             formatted_timestamp = datetime.datetime.now().strftime(
                 "%Y-%m-%d %H:%M:%S")
             logger.info(
-                f"iter = {self.model_engine.iter_counter}, "
+                f"iter = {self.iter_counter}, "
                 f"global_rank = {self.global_rank}, "
                 f"rank = {self.dist.rank}, "
                 f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/"
@@ -705,7 +707,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
         stats.cpu_mem_usage = 0
         stats.pinned_mem_usage = 0

-        stats.iter = self.model_engine.iter_counter
+        stats.iter = self.iter_counter

         kv_cache_manager = self.resource_manager.resource_managers.get(
             ResourceManagerType.KV_CACHE_MANAGER)
@@ -1004,6 +1006,8 @@ def _executor_loop_pp(self):
                         self.active_requests,
                         previous_batch)

+                self.iter_counter += 1
+
     def wait_on_pp_send_handles(self, microbatch_id):
         if self.send_handles[microbatch_id] is not None:
             self.send_handles[microbatch_id].wait()
@@ -1244,6 +1248,8 @@ def _executor_loop(self):
                         iter_stats=iter_stats,
                         iter_start_time=iter_start_time))

+                self.iter_counter += 1
+
     def _prepare_draft_requests(self):
         try:
             # Set draft tokens here to make the KV cache manager
@@ -1472,6 +1478,8 @@ def _executor_loop_overlap(self):

                 self._kv_connector_terminate_requests()

+                self.iter_counter += 1
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
@@ -1875,9 +1883,10 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0):
     def _forward_step(self,
                       scheduled_requests,
                       new_tensors_device: Optional[SampleStateTensors] = None):
+        ExpertStatistic.set_iter(self.iter_counter)

         @nvtx_range(
-            f"[Executor] _forward_step {self.model_engine.iter_counter + 1}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
+            f"[Executor] _forward_step {self.iter_counter}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs"
         )
         def forward(scheduled_requests, resource_manager, new_tensors_device,
                     gather_context_logits, cache_indirection_buffer):
@@ -2215,7 +2224,7 @@ def _handle_responses(self):

                 # Skip active requests that are not scheduled
                 if request.return_perf_metrics and request.py_decoding_iter >= 1:
-                    request.update_perf_metrics(self.model_engine.iter_counter)
+                    request.update_perf_metrics(self.iter_counter)

                 request_done = False
                 if request.py_decoding_iter == 1 or request.is_finished or \
