diff --git a/data_juicer/ops/mapper/text_tagging_by_prompt_mapper.py b/data_juicer/ops/mapper/text_tagging_by_prompt_mapper.py index a870d0035b..cad48061dc 100644 --- a/data_juicer/ops/mapper/text_tagging_by_prompt_mapper.py +++ b/data_juicer/ops/mapper/text_tagging_by_prompt_mapper.py @@ -114,7 +114,11 @@ def __init__( """ # noqa super().__init__(*args, **kwargs) - self.num_proc = 1 + if not enable_vllm or not is_ray_mode(): + # In Ray mode with vllm, num_proc is left to the caller so Ray can + # schedule one actor per GPU for data parallelism. + self.num_proc = 1 + self.prompt = prompt self.tag_list = tag_list @@ -129,9 +133,6 @@ def __init__( sampling_params = update_sampling_params(sampling_params, hf_model, self.enable_vllm) if enable_vllm: - if not is_ray_mode(): - # cannot initialize vllm replicas on different GPUs - self.num_proc = 1 self.model_key = prepare_model( model_type="vllm", pretrained_model_name_or_path=hf_model,