Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fc637f1
Add video support for random-mm dataset
BloodAxe Sep 30, 2025
1f0f2e7
Add video support for random-mm dataset
BloodAxe Sep 30, 2025
5f10fb6
Add video support for random-mm dataset
BloodAxe Sep 30, 2025
cd3c843
Add video support for random-mm dataset
BloodAxe Sep 30, 2025
87d12aa
Add video support for random-mm dataset
BloodAxe Sep 30, 2025
5a4fcf8
Added dependency for opencv for benchmarking
BloodAxe Oct 1, 2025
fc2d72c
Fix bug random-mm dataset - ensure that we don't use special tokens w…
BloodAxe Oct 1, 2025
8e4015d
Code formatting
BloodAxe Oct 3, 2025
e12bd5c
Merge branch 'main' into feature/video-support-in-random-mm-dataset
BloodAxe Oct 3, 2025
2eda6ea
Merge main
BloodAxe Oct 6, 2025
933d1ef
Remove mistakenly put (self) from module-level function
BloodAxe Oct 6, 2025
a5b588b
Do not import cv2 at the module level
BloodAxe Oct 6, 2025
fab3ec4
Remove debug prints
BloodAxe Oct 6, 2025
8869f0a
Remove debug prints
BloodAxe Oct 6, 2025
3ca2ec3
Fix issue of not excluding all special tokens
BloodAxe Oct 6, 2025
64b0eec
Merge branch 'main' into feature/video-support-in-random-mm-dataset
BloodAxe Oct 15, 2025
7105106
Merge branch 'main' into feature/video-support-in-random-mm-dataset
ywang96 Oct 16, 2025
e7aebe1
Merge branch 'main' into feature/video-support-in-random-mm-dataset
BloodAxe Oct 20, 2025
78a808c
Change info logging to debug
BloodAxe Oct 20, 2025
26c50d6
Change info logging to debug & update comment explaining why we exclu…
BloodAxe Oct 20, 2025
f0c6cdc
Remove import guard for cv2
BloodAxe Oct 20, 2025
ee30d58
Merge branch 'main' into feature/video-support-in-random-mm-dataset
BloodAxe Oct 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/bench.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
opencv-python~=4.12 # We use OpenCV to generate MP4 files for random videos
123 changes: 123 additions & 0 deletions tests/benchmarks/test_random_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated(
assert len(mm_data) >= 1
for it in mm_data:
assert it.get("type") == "image_url"


@pytest.mark.benchmark
def test_random_mm_video_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Test video sampling functionality in RandomMultiModalDataset.

    Samples five requests, each with exactly one multimodal item, from a
    bucket configuration that mixes an image bucket and a video bucket.
    Every item must be either an ``image_url`` or a ``video_url`` entry
    carrying a base64 data URL with the matching MIME prefix, and both
    kinds must occur at least once across the collected samples.
    """
    ds = RandomMultiModalDataset(random_seed=42)

    # Test with video bucket configuration
    bucket_config = {
        (64, 64, 1): 0.3,  # Images
        (64, 64, 8): 0.7,  # Videos
    }

    limit_mm_per_prompt = {"image": 2, "video": 2}

    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )

    assert len(samples) == 5

    # Tally each kind so we can check that both were produced.
    video_count = 0
    image_count = 0

    for s in samples:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) == 1

        item = mm_data[0]
        item_type = item.get("type")
        if item_type == "video_url":
            video_count += 1
            # Verify video URL format
            url = item.get("video_url", {}).get("url", "")
            assert url.startswith("data:video/mp4;base64,")
        elif item_type == "image_url":
            image_count += 1
            # Verify image URL format
            url = item.get("image_url", {}).get("url", "")
            assert url.startswith("data:image/jpeg;base64,")
        else:
            # Without this branch an unknown item type would be skipped
            # silently and its payload would never be validated.
            pytest.fail(f"Unexpected multimodal item type: {item_type!r}")

    # The sampling is seeded, so with both buckets configured we expect
    # at least one video and at least one image among the five samples.
    assert video_count > 0
    assert image_count > 0


@pytest.mark.benchmark
def test_random_mm_video_only_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Verify that a video-only bucket configuration yields only videos.

    Three requests are collected, each expected to carry exactly one
    ``video_url`` item whose URL is a base64-encoded MP4 data URL.
    """
    dataset = RandomMultiModalDataset(random_seed=42)

    # A single bucket holding all of the probability mass: only videos.
    video_only_buckets = {(64, 64, 8): 1.0}
    modality_limits = {"image": 0, "video": 1}

    collected = _collect_mm_samples(
        dataset,
        hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=modality_limits,
        bucket_config=video_only_buckets,
    )

    assert len(collected) == 3

    for sample in collected:
        items = cast(list[dict[str, Any]], sample.multi_modal_data)
        assert len(items) == 1

        video_item = items[0]
        assert video_item.get("type") == "video_url"

        video_payload = video_item.get("video_url", {})
        data_url = video_payload.get("url", "")
        assert data_url.startswith("data:video/mp4;base64,")


@pytest.mark.benchmark
def test_random_mm_video_deterministic_sampling(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    """Verify that two equally seeded datasets sample identical videos.

    Both datasets are built with the same seed and queried with the same
    arguments; the per-sample fingerprints of the two runs must match.
    """
    seed = 123
    first_ds = RandomMultiModalDataset(random_seed=seed)
    second_ds = RandomMultiModalDataset(random_seed=seed)

    # Arguments shared by both runs, so the dataset is the only difference.
    sampling_kwargs = {
        "num_requests": 3,
        "base_items_per_request": 1,
        "num_mm_items_range_ratio": 0.0,
        "limit_mm_per_prompt": {"image": 0, "video": 1},
        "bucket_config": {
            (64, 64, 8): 1.0,  # Only videos
        },
    }

    first_run = _collect_mm_samples(first_ds, hf_tokenizer, **sampling_kwargs)
    second_run = _collect_mm_samples(second_ds, hf_tokenizer, **sampling_kwargs)

    first_fingerprints = list(map(_mm_fingerprint_sample, first_run))
    second_fingerprints = list(map(_mm_fingerprint_sample, second_run))
    assert first_fingerprints == second_fingerprints
Loading
Loading