diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 2ecf06a660..b0a345e1c2 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -2092,6 +2092,73 @@ def process_messages(
         return messages
 
 
+@dataclass
+class LFMVLPlugin(BasePlugin):
+    r"""Plugin for LFM2.5-VL vision-language models.
+
+    LFM2.5-VL uses dynamic image token counts based on image resolution.
+    The image processor returns a spatial_shapes tensor with [height, width] grid dimensions.
+    Token count per image = (spatial_h * spatial_w) / (downsample_factor^2).
+    """
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        return mm_inputs
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        downsample_factor: int = getattr(image_processor, "downsample_factor", 2)
+
+        if self.expand_mm_tokens and len(images) > 0:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            spatial_shapes = mm_inputs.get("spatial_shapes", [])
+        else:
+            spatial_shapes = []
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
+                    h, w = spatial_shapes[num_image_tokens].tolist()
+                    image_seqlen = (h * w) // (downsample_factor * downsample_factor)
+                else:
+                    image_seqlen = 1
+
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                num_image_tokens += 1
+
+            message["content"] = content.replace("{{image}}", self.image_token)
+
+        return messages
+
+
 PLUGINS = {
     "base": BasePlugin,
     "ernie_vl": ErnieVLPlugin,
@@ -2104,6 +2171,7 @@ def process_messages(
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
+    "lfm2_vl": LFMVLPlugin,
     "minicpm_v": MiniCPMVPlugin,
     "mllama": MllamaPlugin,
     "paligemma": PaliGemmaPlugin,
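Reviewer note: the placeholder expansion above reduces to simple integer arithmetic over the processor's spatial_shapes grid. A minimal sketch of the rule, with made-up grid sizes (at runtime the real values come from the image processor, one [height, width] pair per image):

```python
# Sketch of LFMVLPlugin's token-count rule: tokens = (h * w) // downsample_factor**2.
# The grid sizes below are illustrative only.
spatial_shapes = [(24, 32), (16, 16)]
downsample_factor = 2  # plugin default when the processor lacks the attribute

for h, w in spatial_shapes:
    image_seqlen = (h * w) // (downsample_factor * downsample_factor)
    print(f"{h}x{w} grid -> {image_seqlen} image tokens")
# 24x32 grid -> 192 image tokens
# 16x16 grid -> 64 image tokens
```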
"<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n" + "<|im_start|>assistant\n" + ] + ), + format_tools=ToolFormatter(tool_format="lfm"), + default_system="You are a helpful multimodal assistant by Liquid AI.", + stop_words=["<|im_end|>"], + tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"), + replace_eos=True, + mm_plugin=get_mm_plugin(name="lfm2_vl", image_token=""), +) + + register_template( name="llama2", format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 8920420ad5..ec4d34d7e7 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -1506,6 +1506,17 @@ def register_model_group( ) +register_model_group( + models={ + "LFM2.5-VL-1.6B": { + DownloadSource.DEFAULT: "LiquidAI/LFM2.5-VL-1.6B", + }, + }, + template="lfm2_vl", + multimodal=True, +) + + register_model_group( models={ "Llama-7B": { diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index bf811043d8..19e174bae8 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -151,6 +151,12 @@ def patch_config( if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"): raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.") + if getattr(config, "model_type", None) == "lfm2_vl" and not is_transformers_version_greater_than("4.58.0"): + raise RuntimeError( + "LFM2.5-VL model requires transformers>=4.58.0 or install from commit: " + "pip install git+https://github.com/huggingface/transformers.git@3c2517727ce28a30f5044e01663ee204deb1cdbe" + ) + if getattr(config, "model_type", None) == "qwen3_omni_moe": patch_qwen3_omni_moe_thinker_text_sparse_moe_block() diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index bb416ed440..3187004aa5 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -419,3 +419,15 @@ def test_video_llava_plugin(): ] check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"]) _check_plugin(**check_inputs) + + +@pytest.mark.runs_on(["cpu", "mps"]) +def test_lfm2_vl_plugin(): + """Test LFM2.5-VL plugin instantiation.""" + # Test plugin can be instantiated with correct tokens + lfm2_vl_plugin = get_mm_plugin(name="lfm2_vl", image_token="") + assert lfm2_vl_plugin is not None + assert lfm2_vl_plugin.image_token == "" + assert lfm2_vl_plugin.video_token is None + assert lfm2_vl_plugin.audio_token is None + assert lfm2_vl_plugin.__class__.__name__ == "LFMVLPlugin"