Skip to content

Commit 3eb78f7

Browse files
lancellysuyoggupta
authored andcommitted
[https://nvbugs/5556998][fix] init_hf_modules in worker_main for models with trust_remote=true (NVIDIA#8931)
Signed-off-by: Lanyu Liao <[email protected]> Co-authored-by: Lanyu Liao <[email protected]>
1 parent f3cbaa9 commit 3eb78f7

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

tensorrt_llm/executor/base_worker.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,32 @@
4040

4141
__all__ = [
4242
"BaseWorker",
43+
"_init_hf_modules",
4344
]
4445

4546

47+
def _init_hf_modules():
48+
"""Initialize cached HuggingFace modules for models with trust_remote_code=True.
49+
50+
This is safe to call multiple times (idempotent) and should be called:
51+
1. At module import time (for main process and spawned subprocesses)
52+
2. At worker_main entry (for forked processes or external MPI ranks)
53+
54+
References: https://github.com/vllm-project/vllm/pull/871
55+
"""
56+
try:
57+
from transformers.dynamic_module_utils import init_hf_modules
58+
init_hf_modules()
59+
logger.debug("HF modules initialized")
60+
except ImportError as e:
61+
logger.warning(f"ImportError initializing HF modules: {e}")
62+
except Exception as e:
63+
logger.error(f"Exception initializing HF modules: {e}")
64+
65+
66+
_init_hf_modules()
67+
68+
4669
class BaseWorker(GenerationExecutor):
4770

4871
class WorkerExit(GeneratorExit):

tensorrt_llm/executor/worker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, logger_debug,
2222
print_traceback_on_error)
2323
from ..sampling_params import BatchedLogitsProcessor
24-
from .base_worker import BaseWorker
24+
from .base_worker import BaseWorker, _init_hf_modules
2525
from .executor import IterationResultQueue
2626
from .ipc import FusedIpcQueue, IpcQueue
2727
from .postproc_worker import (PostprocWorker, PostprocWorkerConfig,
@@ -242,6 +242,10 @@ def worker_main(
242242
llm_args: Optional[BaseLlmArgs] = None,
243243
) -> None:
244244
mpi_comm().barrier()
245+
246+
if llm_args is not None and llm_args.trust_remote_code:
247+
_init_hf_modules()
248+
245249
logger_debug(f"Worker {mpi_rank()} entering worker_main...\n", "green")
246250

247251
result_queue: Optional[IpcQueue] = None

0 commit comments

Comments
 (0)