Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions src/transformers/models/llama4/image_processing_llama4_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def get_factors(dividend: int) -> set[int]:
set: A set containing all factors of the number.
"""
factors_set = set()

for i in range(1, int(dividend**0.5) + 1):
# Loop up to sqrt(dividend) and add both i and dividend//i when divisible
limit = int(dividend**0.5)
for i in range(1, limit + 1):
if dividend % i == 0:
factors_set.add(i)
factors_set.add(dividend // i)
Expand Down Expand Up @@ -131,21 +132,26 @@ def find_supported_resolutions(max_num_chunks: int, patch_size: SizeDict) -> tor
if height != width:
raise ValueError("`size` must be square.")

patch_size = height
patch_size_val = height # store to avoid name confusion, preserve assignment for safety

asp_dict = defaultdict(list)
for chunk_size in range(max_num_chunks, 0, -1):
_factors = sorted(get_factors(chunk_size))
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
for height, width in _asp_ratios:
ratio_float = height / width
asp_dict[ratio_float].append((height, width))

# get the resolutions multiplied by the patch_size
possible_resolutions = []
for value in asp_dict.values():
for height, depth in value:
possible_resolutions.append((height * patch_size, depth * patch_size))

# Group the factor pairs of every chunk count (from max_num_chunks down to 1) by aspect ratio
range_chunks = range(max_num_chunks, 0, -1)
for chunk_size in range_chunks:
factors = get_factors(chunk_size)
factor_list = sorted(factors)
# Pair each factor with its complementary factor (chunk_size // factor)
for factor in factor_list:
second = chunk_size // factor
ratio_float = factor / second
asp_dict[ratio_float].append((factor, second))

# Flatten asp_dict.values() into a single list with one comprehension,
# scaling each (height, width) pair by the patch size to get the resolutions
possible_resolutions = [
(height * patch_size_val, width * patch_size_val) for value in asp_dict.values() for height, width in value
]

return possible_resolutions

Expand Down