-
Notifications
You must be signed in to change notification settings - Fork 6.6k
[feat]: implement "local" caption upsampling for Flux.2 #12718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
7350d07
e6a0ab6
b4a8406
0b1f884
ceb8a3a
b07bee3
82685f2
6397a67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ | |
| from ..pipeline_utils import DiffusionPipeline | ||
| from .image_processor import Flux2ImageProcessor | ||
| from .pipeline_output import Flux2PipelineOutput | ||
| from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I | ||
|
|
||
|
|
||
| if is_torch_xla_available(): | ||
|
|
@@ -57,24 +58,107 @@ | |
| """ | ||
|
|
||
|
|
||
| def format_text_input(prompts: List[str], system_message: str = None): | ||
| # Adapted from | ||
| # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 | ||
| def format_input( | ||
| prompts: List[str], | ||
| system_message: str = SYSTEM_MESSAGE, | ||
| images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None, | ||
| ): | ||
| """ | ||
| Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images | ||
| to the input. | ||
|
|
||
| Args: | ||
| prompts: List of text prompts | ||
| system_message: System message to use (default: SYSTEM_MESSAGE) | ||
| images (optional): List of images to add to the input. | ||
|
|
||
| Returns: | ||
| List of conversations, where each conversation is a list of message dicts | ||
| """ | ||
| # Remove [IMG] tokens from prompts to avoid Pixtral validation issues | ||
| # when truncation is enabled. The processor counts [IMG] tokens and fails | ||
| # if the count changes after truncation. | ||
| cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] | ||
|
|
||
| return [ | ||
| if images is None or len(images) == 0: | ||
| return [ | ||
| [ | ||
| { | ||
| "role": "system", | ||
| "content": [{"type": "text", "text": system_message}], | ||
| }, | ||
| {"role": "user", "content": [{"type": "text", "text": prompt}]}, | ||
| ] | ||
| for prompt in cleaned_txt | ||
| ] | ||
| else: | ||
| assert len(images) == len(prompts), "Number of images must match number of prompts" | ||
| images = _validate_and_process_images(images) | ||
|
|
||
| messages = [ | ||
| [ | ||
| { | ||
| "role": "system", | ||
| "content": [{"type": "text", "text": system_message}], | ||
| }, | ||
| ] | ||
| for _ in cleaned_txt | ||
| ] | ||
|
|
||
| for i, (el, images) in enumerate(zip(messages, images)): | ||
| # optionally add the images per batch element. | ||
| if images is not None: | ||
| el.append( | ||
| { | ||
| "role": "user", | ||
| "content": [{"type": "image", "image": image_obj} for image_obj in images], | ||
| } | ||
| ) | ||
| # add the text. | ||
| el.append( | ||
| { | ||
| "role": "user", | ||
| "content": [{"type": "text", "text": cleaned_txt[i]}], | ||
| } | ||
| ) | ||
|
|
||
| return messages | ||
|
|
||
|
|
||
| # Adapted from | ||
| # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19 | ||
| def _validate_and_process_images( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we have a separate step to validate and process the images and then run
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. We now first |
||
| images: List[List[PIL.Image.Image]] | List[PIL.Image.Image], | ||
| image_processor: Flux2ImageProcessor, | ||
| upsampling_max_image_size: int, | ||
| ) -> List[List[PIL.Image.Image]]: | ||
| # Simple validation: ensure it's a list of PIL images or list of lists of PIL images | ||
| if not images: | ||
| return [] | ||
|
|
||
| # Check if it's a list of lists or a list of images | ||
| if isinstance(images[0], PIL.Image.Image): | ||
| # It's a list of images, convert to list of lists | ||
| images = [[im] for im in images] | ||
|
|
||
| # potentially concatenate multiple images to reduce the size | ||
| images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images] | ||
|
|
||
| # cap the pixels | ||
| images = [ | ||
| [ | ||
| { | ||
| "role": "system", | ||
| "content": [{"type": "text", "text": system_message}], | ||
| }, | ||
| {"role": "user", "content": [{"type": "text", "text": prompt}]}, | ||
| image_processor._resize_to_target_area(img_i, upsampling_max_image_size, return_if_small_image=True) | ||
| for img_i in img_i | ||
| ] | ||
| for prompt in cleaned_txt | ||
| for img_i in images | ||
| ] | ||
| return images | ||
|
|
||
|
|
||
| # Taken from | ||
| # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251 | ||
| def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: | ||
| a1, b1 = 8.73809524e-05, 1.89833333 | ||
| a2, b2 = 0.00016927, 0.45666666 | ||
|
|
@@ -214,9 +298,10 @@ def __init__( | |
| self.tokenizer_max_length = 512 | ||
| self.default_sample_size = 128 | ||
|
|
||
| # fmt: off | ||
| self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." | ||
| # fmt: on | ||
| self.system_message = SYSTEM_MESSAGE | ||
| self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I | ||
| self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I | ||
| self.upsampling_max_image_size = 768**2 | ||
|
|
||
| @staticmethod | ||
| def _get_mistral_3_small_prompt_embeds( | ||
|
|
@@ -226,9 +311,7 @@ def _get_mistral_3_small_prompt_embeds( | |
| dtype: Optional[torch.dtype] = None, | ||
| device: Optional[torch.device] = None, | ||
| max_sequence_length: int = 512, | ||
| # fmt: off | ||
| system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.", | ||
| # fmt: on | ||
| system_message: str = SYSTEM_MESSAGE, | ||
| hidden_states_layers: List[int] = (10, 20, 30), | ||
| ): | ||
| dtype = text_encoder.dtype if dtype is None else dtype | ||
|
|
@@ -237,7 +320,7 @@ def _get_mistral_3_small_prompt_embeds( | |
| prompt = [prompt] if isinstance(prompt, str) else prompt | ||
|
|
||
| # Format input messages | ||
| messages_batch = format_text_input(prompts=prompt, system_message=system_message) | ||
| messages_batch = format_input(prompts=prompt, system_message=system_message) | ||
|
|
||
| # Process all messages at once | ||
| inputs = tokenizer.apply_chat_template( | ||
|
|
@@ -426,6 +509,64 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch | |
|
|
||
| return torch.stack(x_list, dim=0) | ||
|
|
||
| def upsample_prompt( | ||
| self, | ||
| prompt: Union[str, List[str]], | ||
| images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]] = None, | ||
| temperature: float = 0.15, | ||
| device: torch.device = None, | ||
| ) -> List[str]: | ||
| prompt = [prompt] if isinstance(prompt, str) else prompt | ||
| device = self.text_encoder.device if device is None else device | ||
|
|
||
| # Set system message based on whether images are provided | ||
| if images is None or len(images) == 0 or images[0] is None: | ||
| system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I | ||
| else: | ||
| system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I | ||
|
|
||
| # Format input messages | ||
| messages_batch = format_input(prompts=prompt, system_message=system_message, images=images) | ||
|
|
||
| # Process all messages at once | ||
| # with image processing a too short max length can throw an error in here. | ||
| inputs = self.tokenizer.apply_chat_template( | ||
| messages_batch, | ||
| add_generation_prompt=True, | ||
| tokenize=True, | ||
| return_dict=True, | ||
| return_tensors="pt", | ||
| padding="max_length", | ||
| truncation=True, | ||
| max_length=2048, | ||
| ) | ||
|
|
||
| # Move to device | ||
| inputs["input_ids"] = inputs["input_ids"].to(device) | ||
| inputs["attention_mask"] = inputs["attention_mask"].to(device) | ||
|
|
||
| if "pixel_values" in inputs: | ||
| inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype) | ||
|
|
||
| # Generate text using the model's generate method | ||
| generated_ids = self.text_encoder.generate( | ||
| **inputs, | ||
| max_new_tokens=512, | ||
| do_sample=True, | ||
| temperature=temperature, | ||
| use_cache=True, | ||
| ) | ||
|
|
||
| # Decode only the newly generated tokens (skip input tokens) | ||
| # Extract only the generated portion | ||
| input_length = inputs["input_ids"].shape[1] | ||
| generated_tokens = generated_ids[:, input_length:] | ||
|
|
||
| upsampled_prompt = self.tokenizer.tokenizer.batch_decode( | ||
| generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True | ||
| ) | ||
| return upsampled_prompt | ||
|
|
||
| def encode_prompt( | ||
| self, | ||
| prompt: Union[str, List[str]], | ||
|
|
@@ -620,6 +761,7 @@ def __call__( | |
| callback_on_step_end_tensor_inputs: List[str] = ["latents"], | ||
| max_sequence_length: int = 512, | ||
| text_encoder_out_layers: Tuple[int] = (10, 20, 30), | ||
| caption_upsample_temperature: float = None, | ||
| ): | ||
| r""" | ||
| Function invoked when calling the pipeline for generation. | ||
|
|
@@ -635,11 +777,11 @@ def __call__( | |
| The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` | ||
| instead. | ||
| guidance_scale (`float`, *optional*, defaults to 1.0): | ||
| Guidance scale as defined in [Classifier-Free Diffusion | ||
| Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. | ||
| of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting | ||
| `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to | ||
| the text `prompt`, usually at the expense of lower image quality. | ||
| Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages | ||
| a model to generate images more aligned with `prompt` at the expense of lower image quality. | ||
|
|
||
| Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to | ||
| the [paper](https://huggingface.co/papers/2210.03142) to learn more. | ||
| height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | ||
| The height in pixels of the generated image. This is set to 1024 by default for the best results. | ||
| width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): | ||
|
|
@@ -684,6 +826,9 @@ def __call__( | |
| max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. | ||
| text_encoder_out_layers (`Tuple[int]`): | ||
| Layer indices to use in the `text_encoder` to derive the final prompt embeddings. | ||
| caption_upsample_temperature (`float`): | ||
| When specified, we will try to perform caption upsampling for potentially improved outputs. We | ||
| recommend setting it to 0.15 if caption upsampling is to be performed. | ||
|
|
||
| Examples: | ||
|
|
||
|
|
@@ -718,6 +863,10 @@ def __call__( | |
| device = self._execution_device | ||
|
|
||
| # 3. prepare text embeddings | ||
| if caption_upsample_temperature: | ||
| prompt = self.upsample_prompt( | ||
| prompt, images=image, temperature=caption_upsample_temperature, device=device | ||
| ) | ||
| prompt_embeds, text_ids = self.encode_prompt( | ||
| prompt=prompt, | ||
| prompt_embeds=prompt_embeds, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| """ | ||
| These system prompts come from: | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed internally, this new-line character thingy messes up the quality a bit. Hence, I have decided to keep these system messages one-to-one same as the original implementation linked above. If we run make style && make quality, this order will be completely destroyed. We can change the |
||
| https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54 | ||
| """ | ||
|
|
||
| SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object | ||
| attribution and actions without speculation.""" | ||
|
|
||
| SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent. | ||
| Guidelines: | ||
| 1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs. | ||
| 2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context. | ||
| 3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish. | ||
| Output only the revised prompt and nothing else.""" | ||
|
|
||
| SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests). | ||
| Rules: | ||
| - Single instruction only, no commentary | ||
| - Use clear, analytical language (avoid "whimsical," "cascading," etc.) | ||
| - Specify what changes AND what stays the same (face, lighting, composition) | ||
| - Reference actual image elements | ||
| - Turn negatives into positives ("don't change X" → "keep X") | ||
| - Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels") | ||
| - Keep content PG-13 | ||
| Output only the final instruction in plain text and nothing else.""" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ohh do you want to add a new method called something like
_resize_if_exceeds_area? or rename this one if we only use it this way. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup.
I created
_resize_if_exceeds_area() which is basically: