Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/bench.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
opencv-python~=4.12 # We use OpenCV to generate MP4 files for random videos
127 changes: 127 additions & 0 deletions tests/benchmarks/test_random_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,130 @@ def test_random_mm_bucket_config_not_mutated(
assert len(mm_data) >= 1
for it in mm_data:
assert it.get("type") == "image_url"


@pytest.mark.benchmark
def test_random_mm_video_sampling(
hf_tokenizer: PreTrainedTokenizerBase
) -> None:
"""Test video sampling functionality in RandomMultiModalDataset."""
ds = RandomMultiModalDataset(random_seed=42)

# Test with video bucket configuration
bucket_config = {
(64, 64, 1): 0.3, # Images
(64, 64, 8): 0.7, # Videos
}

limit_mm_per_prompt = {"image": 2, "video": 2}

samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=5,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)

assert len(samples) == 5

# Check that we have both images and videos
video_count = 0
image_count = 0

for s in samples:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) == 1

item = mm_data[0]
if item.get("type") == "video_url":
video_count += 1
# Verify video URL format
url = item.get("video_url", {}).get("url", "")
assert url.startswith("data:video/mp4;base64,")
elif item.get("type") == "image_url":
image_count += 1
# Verify image URL format
url = item.get("image_url", {}).get("url", "")
assert url.startswith("data:image/jpeg;base64,")

# Should have some videos due to 0.7 probability
assert video_count > 0
assert image_count > 0


@pytest.mark.benchmark
def test_random_mm_video_only_sampling(
hf_tokenizer: PreTrainedTokenizerBase
) -> None:
"""Test sampling with only video buckets."""
ds = RandomMultiModalDataset(random_seed=42)

bucket_config = {
(64, 64, 8): 1.0, # Only videos
}

limit_mm_per_prompt = {"image": 0, "video": 1}

samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)

assert len(samples) == 3

for s in samples:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) == 1

item = mm_data[0]
assert item.get("type") == "video_url"
url = item.get("video_url", {}).get("url", "")
assert url.startswith("data:video/mp4;base64,")


@pytest.mark.benchmark
def test_random_mm_video_deterministic_sampling(
hf_tokenizer: PreTrainedTokenizerBase
) -> None:
"""Test that video sampling is deterministic with same seed."""
seed = 123
ds_a = RandomMultiModalDataset(random_seed=seed)
ds_b = RandomMultiModalDataset(random_seed=seed)

bucket_config = {
(64, 64, 8): 1.0, # Only videos
}

limit_mm_per_prompt = {"image": 0, "video": 1}

a = _collect_mm_samples(
ds_a,
hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)

b = _collect_mm_samples(
ds_b,
hf_tokenizer,
num_requests=3,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
)

fa = [_mm_fingerprint_sample(s) for s in a]
fb = [_mm_fingerprint_sample(s) for s in b]
assert fa == fb
Loading