diff --git a/.gitignore b/.gitignore index 99e488d9e..ec3bbc394 100755 --- a/.gitignore +++ b/.gitignore @@ -51,3 +51,9 @@ uv.lock workspace/* .claude/* remote_code/* + +results/ +lavis +cookies.txt +external/ +*.ipynb \ No newline at end of file diff --git a/examples/models/qwen25vl.sh b/examples/models/qwen25vl.sh index e16bedf2f..e6c3c864a 100644 --- a/examples/models/qwen25vl.sh +++ b/examples/models/qwen25vl.sh @@ -1,6 +1,6 @@ # Run and exactly reproduce qwen2vl results! # mme as an example -export HF_HOME="~/.cache/huggingface" +export HF_HOME="~/flash/.cache/huggingface" # pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git # pip3 install qwen_vl_utils # use `interleave_visuals=True` to control the visual token position, currently only for mmmu_val and mmmu_pro (and potentially for other interleaved image-text tasks), please do not use it unless you are sure about the operation details. @@ -11,8 +11,11 @@ export HF_HOME="~/.cache/huggingface" # --tasks mmmu_pro \ # --batch_size 1 +echo "Running Qwen2.5-VL-7B-Instruct" accelerate launch --num_processes=8 --main_process_port=12346 -m lmms_eval \ --model qwen2_5_vl \ --model_args=pretrained=Qwen/Qwen2.5-VL-7B-Instruct,max_pixels=12845056,attn_implementation=flash_attention_2,interleave_visuals=False \ --tasks mme \ - --batch_size 1 \ No newline at end of file + --batch_size 1 + +# uv run python -m lmms_eval --model qwen2_5_vl --model_args=pretrained=Qwen/Qwen2.5-VL-3B-Instruct,max_pixels=602112,interleave_visuals=False,attn_implementation=flash_attention_2,video_sampler=uniform --tasks egoschema --batch_size 1 --output_path results/test.jsonl \ No newline at end of file diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py index fb2b37439..86bf5aa47 100755 --- a/lmms_eval/__main__.py +++ b/lmms_eval/__main__.py @@ -12,6 +12,10 @@ import torch import yaml +from dotenv import load_dotenv + +load_dotenv() + warnings.simplefilter("ignore", category=DeprecationWarning) import hashlib @@ -274,6 +278,8 @@ def parse_eval_args() -> argparse.Namespace: ) parser.add_argument("--process_with_media", action="store_true", help="Whether you will process you dataset with audio, image. By default set to False" "In case some benchmarks need to be processed with media, set this flag to True.") parser.add_argument("--force_simple", action="store_true", help="Force the evaluation to use the simple mode of the models") + parser.add_argument("--video_sampler", type=str, default=None, help="Video sampler to use") + parser.add_argument("--video_sampler_kwargs", default="", help="String arguments for video sampler, e.g. `max_num_frames=32,ratio=1,t1=0.8,t2=-100,all_depth=5`",) args = parser.parse_args() return args @@ -481,9 +487,28 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None: request_caching_args = request_caching_arg_to_dict(cache_requests=args.cache_requests) datetime_str = utils.get_datetime_str(timezone=args.timezone) + # Configure metrics logging destination for downstream log_metrics calls. 
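+ # Only the main process (rank 0) resolves and creates this file; the resulting path is exported via the LMMS_EVAL_METRICS_PATH environment variable, which log_metrics() falls back to when no explicit metrics_path is supplied.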
+ os.environ.pop("LMMS_EVAL_METRICS_PATH", None) + if args.output_path: + from lmms_eval.loggers.evaluation_tracker import GeneralConfigTracker + + fallback_model_name = args.model if isinstance(args.model, str) else str(args.model) + candidate_model_name = GeneralConfigTracker._get_model_name(args.model_args or "") or fallback_model_name + sanitized_model_name = utils.sanitize_model_name(candidate_model_name) or utils.sanitize_model_name(fallback_model_name) or "model" + + is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + if not is_distributed or (is_distributed and torch.distributed.get_rank() == 0): + metrics_dir = Path(args.output_path).expanduser().resolve() / sanitized_model_name + metrics_dir.mkdir(parents=True, exist_ok=True) + date_id = datetime_str.replace(":", "-") + metrics_path = metrics_dir / f"{date_id}_metrics.json" + os.environ["LMMS_EVAL_METRICS_PATH"] = str(metrics_path) + results = evaluator.simple_evaluate( model=args.model, model_args=args.model_args, + video_sampler=args.video_sampler, + video_sampler_kwargs=args.video_sampler_kwargs, tasks=task_names, num_fewshot=args.num_fewshot, batch_size=args.batch_size, diff --git a/lmms_eval/api/instance.py b/lmms_eval/api/instance.py index 18cfb7399..ebfed9751 100755 --- a/lmms_eval/api/instance.py +++ b/lmms_eval/api/instance.py @@ -10,7 +10,9 @@ class Instance: metadata: Tuple[str, int, int] = field(default_factory=lambda: (None, None, None)) # TODO: better typehints here resps: list = field(default_factory=list) filtered_resps: dict = field(default_factory=dict) - + video_metadata: object = None + num_input_tokens: int = None + num_input_vision_tokens: int = None # initialized after init task_name: str = None doc_id: str = None diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index ed9aa7576..00e915e9e 100755 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -1,6 +1,7 @@ import abc import ast import copy +import importlib import inspect import itertools import json @@ -27,6 +28,10 @@ ) import datasets +try: + import torch +except ImportError: + torch = None import numpy as np from accelerate import Accelerator from datasets import Audio, DownloadConfig, Image, Sequence @@ -939,7 +944,47 @@ def _download_from_youtube(path): cache_dir = dataset_kwargs["cache_dir"] cache_dir = os.path.join(hf_home, cache_dir) accelerator = Accelerator() + external_downloader_spec = dataset_kwargs.get("external_downloader") if dataset_kwargs is not None else None if accelerator.is_main_process: + if external_downloader_spec: + + def _resolve_external_downloader(spec): + if callable(spec): + return spec, {} + if isinstance(spec, str): + if "." not in spec: + raise ValueError("external_downloader string must be a fully qualified path, e.g. 
'pkg.module.fn'") + module_path, attr_name = spec.rsplit(".", 1) + return getattr(importlib.import_module(module_path), attr_name), {} + if isinstance(spec, dict): + if "fn" not in spec: + raise ValueError("external_downloader dict must include a 'fn' key") + fn, base_kwargs = _resolve_external_downloader(spec["fn"]) + kwargs = {**base_kwargs, **spec.get("kwargs", {})} + return fn, kwargs + raise TypeError(f"Unsupported external_downloader spec type: {type(spec)}") + + downloader_fn, downloader_kwargs = _resolve_external_downloader(external_downloader_spec) + downloader_kwargs.setdefault("cache_dir", cache_dir) + downloader_kwargs.setdefault("videos_dir", os.path.join(cache_dir, "videos")) + download_result = downloader_fn(**downloader_kwargs) + + def _set_nested(target_dict, dotted_key, value): + keys = dotted_key.split(".") + curr = target_dict + for key in keys[:-1]: + if key not in curr or not isinstance(curr[key], dict): + curr[key] = {} + curr = curr[key] + curr[keys[-1]] = value + + if isinstance(external_downloader_spec, dict): + result_target = external_downloader_spec.get("result_dataset_kwarg") + if result_target and download_result is not None: + if dataset_kwargs is None: + dataset_kwargs = {} + _set_nested(dataset_kwargs, result_target, os.path.expanduser(str(download_result))) + force_download = dataset_kwargs.get("force_download", False) force_unzip = dataset_kwargs.get("force_unzip", False) revision = dataset_kwargs.get("revision", "main") @@ -1024,8 +1069,21 @@ def concat_tar_parts(tar_parts, output_tar): eval_logger.info(f"Symbolic link created successfully: {cache_path} -> {cache_dir}") accelerator.wait_for_everyone() - dataset_kwargs.pop("cache_dir") - dataset_kwargs.pop("video") + if dataset_kwargs is not None: + if accelerator.num_processes > 1: + if torch is not None and torch.distributed.is_available() and torch.distributed.is_initialized(): + shared_dataset_kwargs = [dataset_kwargs if accelerator.is_main_process else None] + torch.distributed.broadcast_object_list(shared_dataset_kwargs, src=0) + dataset_kwargs = shared_dataset_kwargs[0] + elif accelerator.is_main_process: + eval_logger.warning("Multiple processes detected but torch.distributed is not initialized. Secondary ranks may not receive updated dataset kwargs.") + if "external_downloader" in dataset_kwargs: + external_downloader = dataset_kwargs.pop("external_downloader", None) + if 'data_files' in external_downloader: + dataset_kwargs['data_files'] = external_downloader.pop('data_files') + dataset_kwargs['split'] = external_downloader.pop('split', 'test') + dataset_kwargs.pop("cache_dir", None) + dataset_kwargs.pop("video", None) if "builder_script" in dataset_kwargs: builder_script = dataset_kwargs["builder_script"] diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py index c1dc315a6..3398c57a5 100755 --- a/lmms_eval/evaluator.py +++ b/lmms_eval/evaluator.py @@ -46,6 +46,7 @@ run_task_tests, simple_parse_args_string, ) +from lmms_eval.video_samplers import get_video_sampler_cls @positional_deprecated @@ -83,6 +84,8 @@ def simple_evaluate( distributed_executor_backend: str = "accelerate", cli_args=None, force_simple: bool = False, + video_sampler: Optional[str] = None, + video_sampler_kwargs: Optional[str] = None, ): """Instantiate and evaluate a model on a list of tasks. 
@@ -183,6 +186,18 @@ def simple_evaluate( if task_manager is None: task_manager = TaskManager(verbosity, model_name=model) + video_sampler_instance = None + if isinstance(video_sampler, str): + if video_sampler_kwargs is None: + video_sampler_kwargs = "" + video_sampler_instance = get_video_sampler_cls(video_sampler).create_from_arg_string(video_sampler_kwargs, + { + "batch_size": batch_size, + "device": device, + } + ) + + if isinstance(model, str): if model_args is None: model_args = "" @@ -191,6 +206,7 @@ def simple_evaluate( { "batch_size": batch_size, "max_batch_size": max_batch_size, + "video_sampler": video_sampler_instance, "device": device, }, ) @@ -255,6 +271,8 @@ def _adjust_config(task_dict): evaluation_tracker.general_config_tracker.log_experiment_args( model_source=model, model_args=model_args, + video_sampler=video_sampler, + video_sampler_kwargs=video_sampler_kwargs, system_instruction=system_instruction, chat_template=lm.chat_template if apply_chat_template else None, fewshot_as_multiturn=fewshot_as_multiturn, @@ -411,6 +429,20 @@ def evaluate( if distributed_executor_backend == "accelerate" and not hasattr(lm, "accelerator"): lm.accelerator = Accelerator() + def _serialize_video_metadata(meta): + if isinstance(meta, dict): + return {k: _serialize_video_metadata(v) for k, v in meta.items()} + if isinstance(meta, (list, tuple)): + return [_serialize_video_metadata(v) for v in meta] + if hasattr(meta, "tolist"): + try: + return meta.tolist() + except TypeError: + pass + if isinstance(meta, (np.generic,)): + return meta.item() + return meta + for task_output in eval_tasks: task: Task = task_output.task task_name = task_output.task_name @@ -578,6 +610,18 @@ def evaluate( # else: # filtered_arguments.append(_handle_non_serializable(value)) + video_metadata_list = [] + seen_instance_ids = set() + for req in requests: + if id(req) in seen_instance_ids: + continue + seen_instance_ids.add(id(req)) + meta = getattr(req, "video_metadata", None) + num_input_tokens = getattr(req, "num_input_tokens", None) + num_input_vision_tokens = getattr(req, "num_input_vision_tokens", None) + if meta is None: + continue + video_metadata_list.append(_serialize_video_metadata(meta)) example = { "doc_id": doc_id, "doc": saved_doc, @@ -595,6 +639,14 @@ def evaluate( ), # Removing prompt hash and target hash here } + if video_metadata_list: + example["video_metadata"] = ( + video_metadata_list if len(video_metadata_list) > 1 else video_metadata_list[0] + ) + if num_input_tokens is not None: + example["num_input_tokens"] = num_input_tokens + if num_input_vision_tokens is not None: + example["num_input_vision_tokens"] = num_input_vision_tokens example.update(metrics) task_output.logged_samples.append(example) for metric, value in metrics.items(): diff --git a/lmms_eval/loggers/evaluation_tracker.py b/lmms_eval/loggers/evaluation_tracker.py index 65936cdf9..a966885d3 100644 --- a/lmms_eval/loggers/evaluation_tracker.py +++ b/lmms_eval/loggers/evaluation_tracker.py @@ -78,6 +78,8 @@ def log_experiment_args( self, model_source: str, model_args: str, + video_sampler: str, + video_sampler_kwargs: str, system_instruction: str, chat_template: str, fewshot_as_multiturn: bool, @@ -86,6 +88,8 @@ def log_experiment_args( self.model_source = model_source self.model_name = GeneralConfigTracker._get_model_name(model_args) self.model_name_sanitized = sanitize_model_name(self.model_name) + self.video_sampler = video_sampler + self.video_sampler_kwargs = video_sampler_kwargs self.system_instruction = system_instruction 
self.system_instruction_sha = hash_string(system_instruction) if system_instruction else None self.chat_template = chat_template diff --git a/lmms_eval/models/chat/qwen2_5_vl.py b/lmms_eval/models/chat/qwen2_5_vl.py index fbac95b0a..5d0998271 100644 --- a/lmms_eval/models/chat/qwen2_5_vl.py +++ b/lmms_eval/models/chat/qwen2_5_vl.py @@ -5,6 +5,7 @@ from loguru import logger as eval_logger from PIL import Image from tqdm import tqdm +import torch from lmms_eval import utils from lmms_eval.api.instance import Instance @@ -26,6 +27,18 @@ class Qwen2_5_VL(Qwen2_5_VLSimple): is_simple = False + def get_num_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>") + vision_end_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_end|>") + start = (inputs.input_ids == vision_start_token_id) # [B, L] + end = (inputs.input_ids == vision_end_token_id) # [B, L] + level = start.cumsum(dim=1) - end.cumsum(dim=1) + in_vision_span = level > 0 + in_vision_span = in_vision_span | start | end + num_vision_tokens = in_vision_span.sum(dim=1) + num_tokens = inputs.attention_mask.sum(dim=1) + return num_tokens, num_vision_tokens + def generate_until(self, requests: List[Instance]) -> List[str]: res = [] @@ -36,15 +49,22 @@ def _collate(x): # we group requests by their generation_kwargs, # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # in the same batch. - re_ords = utils.Collator([reg.args for reg in requests], _collate, group_fn=lambda x: x[2], grouping=True) + re_ords = utils.Collator([(idx, reg.args) for idx, reg in enumerate(requests)], _collate, group_fn=lambda x: x[1][2], grouping=True) chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + chunk_offset = 0 num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") e2e_latency = 0 + vision_processing_latency = 0 total_tokens = 0 + total_num_input_vision_tokens = 0 + total_num_input_tokens = 0 for chunk in chunks: + chunk_request_indices, chunk = zip(*chunk) + chunk_requests = [requests[idx] for idx in chunk_request_indices] ctx, doc_to_messages, all_gen_kwargs, doc_id, task, split = zip(*chunk) chat_messages = [doc_to_messages[idx](self.task_dict[task][split][ids]) for idx, (ids, task, split) in enumerate(zip(doc_id, task, split))] + eval_logger.debug(f"chat_messages: {chat_messages}") chat_messages: List[ChatMessages] = [ChatMessages(**{"messages": message}) for message in chat_messages] visuals = [] videos = [] @@ -57,26 +77,38 @@ def _collate(x): gen_kwargs = all_gen_kwargs[0] # Apply chat template - video_kwargs = { - "max_pixels": self.max_pixels, - "min_pixels": self.min_pixels, - } + if self.resized_height is not None and self.resized_width is not None: + video_kwargs = { + "resized_height": self.resized_height, + "resized_width": self.resized_width + } + elif self.max_pixels is not None and self.min_pixels is not None: + video_kwargs = { + "max_pixels": self.max_pixels, + "min_pixels": self.min_pixels, + } if self.fps is not None: video_kwargs["fps"] = self.fps else: video_kwargs["nframes"] = self.max_num_frames + video_kwargs["video_sampler"] = self.video_sampler + eval_logger.debug(f"video sampler in worker: {self.video_sampler!r}") batched_messages = [chat_message.to_hf_messages(video_kwargs=video_kwargs) for chat_message in chat_messages] - texts = 
self.processor.apply_chat_template(batched_messages, tokenize=False, add_generation_prompt=True) - image_inputs, video_inputs = process_vision_info(batched_messages) - if video_inputs is not None: - total_frames = video_inputs[0].shape[0] - indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int) - # Append the last frame index if not already included - if total_frames - 1 not in indices: - indices = np.append(indices, total_frames - 1) - video_inputs[0] = video_inputs[0][indices] - padding_side = "left" if self.batch_size > 1 else "right" - inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, padding_side=padding_side, return_tensors="pt") + texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batched_messages] + # Vision Info + start_time = time.time() + image_inputs, video_inputs = process_vision_info(batched_messages, return_video_metadata=True) + video_metadata_seq = [] + if video_inputs: + frames, video_metadata_seq = zip(*video_inputs) + video_inputs = list(frames) + video_metadata_seq = list(video_metadata_seq) + else: + video_inputs = None + assert len(chunk_requests) == len(video_metadata_seq) + inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + end_time = time.time() + vision_info_time = end_time - start_time if self.device_map == "auto": inputs = inputs.to("cuda") @@ -116,13 +148,23 @@ def _collate(x): use_cache=self.use_cache, ) end_time = time.time() - + generation_latency = end_time - start_time generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)] answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) # Calculate timing metrics for batch - e2e_latency += end_time - start_time + vision_processing_latency += vision_info_time + e2e_latency += generation_latency total_tokens += sum(len(ids) for ids in generated_ids_trimmed) + num_input_tokens, num_input_vision_tokens = self.get_num_tokens(inputs) + total_num_input_tokens += num_input_tokens.sum() + total_num_input_vision_tokens += num_input_vision_tokens.sum() + + for k, (inst, meta) in enumerate(zip(chunk_requests, video_metadata_seq)): + inst.video_metadata = meta + inst.num_input_tokens = num_input_tokens[k].cpu().item() + inst.num_input_vision_tokens = num_input_vision_tokens[k].cpu().item() + for ans, context in zip(answers, texts): clean_ans = parse_reasoning_model_answer(ans) @@ -145,9 +187,70 @@ def _collate(x): "avg_speed": avg_speed, "additional_metrics": { "rank": self.rank, + "vision_processing_latency": vision_processing_latency, + "total_num_input_tokens": total_num_input_tokens, + "total_num_input_vision_tokens": total_num_input_vision_tokens, + "num_requests": len(requests), + }, + } + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + gathered = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(gathered, metric_dict) + + if self.rank == 0: + total_tokens = sum(m["total_tokens"] for m in gathered) + total_requests = sum(m["additional_metrics"]["num_requests"] for m in gathered) + total_vision_processing_latency = sum(m["additional_metrics"]["vision_processing_latency"] for m in gathered) + total_e2e_latency = sum(m["e2e_latency"] for m in gathered) + total_num_input_tokens = sum(m["additional_metrics"]["total_num_input_tokens"].cpu().item() for m in gathered) + 
total_num_input_vision_tokens = sum(m["additional_metrics"]["total_num_input_vision_tokens"].cpu().item() for m in gathered) + + throughput = total_tokens / total_e2e_latency if total_e2e_latency > 0 else 0.0 + avg_latency_per_req = total_e2e_latency / total_requests if total_requests else 0.0 + avg_vision_processing_latency = total_vision_processing_latency / total_requests if total_vision_processing_latency > 0 else 0.0 + avg_num_input_tokens = total_num_input_tokens / total_requests if total_num_input_tokens > 0 else 0.0 + avg_num_input_vision_tokens = total_num_input_vision_tokens / total_requests if total_num_input_vision_tokens > 0 else 0.0 + avg_total_tokens = total_tokens / total_requests if total_tokens > 0 else 0.0 + + metric_dict = { + "total_tokens": total_tokens, + "e2e_latency": total_e2e_latency, + "avg_speed": throughput, + "additional_metrics": { + "rank": self.rank, + "vision_processing_latency": total_vision_processing_latency, + "total_num_input_tokens": total_num_input_tokens, + "total_num_input_vision_tokens": total_num_input_vision_tokens, + "num_requests": total_requests, + "avg_num_output_tokens": avg_total_tokens, + "avg_num_input_tokens": avg_num_input_tokens, + "avg_num_input_vision_tokens": avg_num_input_vision_tokens, + "avg_vision_processing_latency": avg_vision_processing_latency, + "avg_e2e_latency": avg_latency_per_req, + "per_worker": gathered + }, + } + log_metrics(**metric_dict) + else: + metric_dict = { + "total_tokens": total_tokens, + "e2e_latency": e2e_latency, + "avg_speed": avg_speed, + "additional_metrics": { + "rank": self.rank, + "vision_processing_latency": vision_processing_latency, + "total_num_input_tokens": total_num_input_tokens, + "total_num_input_vision_tokens": total_num_input_vision_tokens, + "num_requests": len(requests), + "avg_num_output_tokens": total_tokens / len(requests), + "avg_num_input_tokens": total_num_input_tokens / len(requests), + "avg_num_input_vision_tokens": total_num_input_vision_tokens / len(requests), + "avg_vision_processing_latency": vision_processing_latency / len(requests), + "avg_e2e_latency": e2e_latency / len(requests), }, } - log_metrics(**metric_dict) + log_metrics(**metric_dict) pbar.close() return res diff --git a/lmms_eval/models/chat/qwen3_vl.py b/lmms_eval/models/chat/qwen3_vl.py index 397bb44b2..253d486e0 100644 --- a/lmms_eval/models/chat/qwen3_vl.py +++ b/lmms_eval/models/chat/qwen3_vl.py @@ -16,6 +16,7 @@ from lmms_eval.models.simple.qwen3_vl import Qwen3_VL as Qwen3_VLSimple from lmms_eval.protocol import ChatMessages +import torch try: from qwen_vl_utils import process_vision_info except ImportError: @@ -26,6 +27,18 @@ class Qwen3_VL(Qwen3_VLSimple): is_simple = False + def get_num_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>") + vision_end_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_end|>") + start = (inputs.input_ids == vision_start_token_id) # [B, L] + end = (inputs.input_ids == vision_end_token_id) # [B, L] + level = start.cumsum(dim=1) - end.cumsum(dim=1) + in_vision_span = level > 0 + in_vision_span = in_vision_span | start | end + num_vision_tokens = in_vision_span.sum(dim=1) + num_tokens = inputs.attention_mask.sum(dim=1) + return num_tokens, num_vision_tokens + def generate_until(self, requests: List[Instance]) -> List[str]: res = [] @@ -36,15 +49,27 @@ def _collate(x): # we group requests by their generation_kwargs, # so that we don't try to execute e.g. 
greedy sampling and temp=0.8 sampling # in the same batch. - re_ords = utils.Collator([reg.args for reg in requests], _collate, group_fn=lambda x: x[2], grouping=True) + re_ords = utils.Collator([(idx, reg.args) for idx, reg in enumerate(requests)], _collate, group_fn=lambda x: x[1][2], grouping=True) chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) + chunk_offset = 0 num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1 pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") e2e_latency = 0 total_tokens = 0 + vision_processing_latency = 0 + total_num_input_vision_tokens = 0 + total_num_input_tokens = 0 for chunk in chunks: + # Vision Info + start_time = time.time() + video_metadata_seq = [] + chunk_request_indices, chunk = zip(*chunk) + chunk_requests = [requests[idx] for idx in chunk_request_indices] ctx, doc_to_messages, all_gen_kwargs, doc_id, task, split = zip(*chunk) chat_messages = [doc_to_messages[idx](self.task_dict[task][split][ids]) for idx, (ids, task, split) in enumerate(zip(doc_id, task, split))] + if self.video_sampler.will_process_messages: + chat_messages, video_metadata_seq = zip(*[self.video_sampler.process_messages(chat_message, eval_logger) for chat_message in chat_messages]) + assert len(chunk_requests) == len(video_metadata_seq) chat_messages: List[ChatMessages] = [ChatMessages(**{"messages": message}) for message in chat_messages] visuals = [] videos = [] @@ -57,26 +82,35 @@ def _collate(x): gen_kwargs = all_gen_kwargs[0] # Apply chat template - video_kwargs = { - "max_pixels": self.max_pixels, - "min_pixels": self.min_pixels, - } + if self.resized_height is not None and self.resized_width is not None: + video_kwargs = { + "resized_height": self.resized_height, + "resized_width": self.resized_width + } + elif self.max_pixels is not None and self.min_pixels is not None: + video_kwargs = { + "max_pixels": self.max_pixels, + "min_pixels": self.min_pixels, + } if self.fps is not None: video_kwargs["fps"] = self.fps else: video_kwargs["nframes"] = self.max_num_frames + video_kwargs["video_sampler"] = self.video_sampler batched_messages = [chat_message.to_hf_messages(video_kwargs=video_kwargs) for chat_message in chat_messages] texts = self.processor.apply_chat_template(batched_messages, tokenize=False, add_generation_prompt=True) - image_inputs, video_inputs = process_vision_info(batched_messages) - if video_inputs is not None: - total_frames = video_inputs[0].shape[0] - indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int) - # Append the last frame index if not already included - if total_frames - 1 not in indices: - indices = np.append(indices, total_frames - 1) - video_inputs[0] = video_inputs[0][indices] + image_inputs, video_inputs = process_vision_info(batched_messages, return_video_metadata=True) + if video_inputs: + frames, video_metadata_seq = zip(*video_inputs) + video_inputs = list(frames) + video_metadata_seq = list(video_metadata_seq) + assert len(chunk_requests) == len(video_metadata_seq) + else: + video_inputs = None padding_side = "left" if self.batch_size > 1 else "right" inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, padding_side=padding_side, return_tensors="pt") + end_time = time.time() + vision_info_time = end_time - start_time if self.device_map == "auto": inputs = inputs.to("cuda") @@ -116,13 +150,23 @@ def _collate(x): use_cache=self.use_cache, ) end_time = time.time() + 
generation_latency = end_time - start_time generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, cont)] answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) # Calculate timing metrics for batch - e2e_latency += end_time - start_time + vision_processing_latency += vision_info_time + e2e_latency += generation_latency total_tokens += sum(len(ids) for ids in generated_ids_trimmed) + num_input_tokens, num_input_vision_tokens = self.get_num_tokens(inputs) + total_num_input_tokens += num_input_tokens.sum() + total_num_input_vision_tokens += num_input_vision_tokens.sum() + + for k, (inst, meta) in enumerate(zip(chunk_requests, video_metadata_seq)): + inst.video_metadata = meta + inst.num_input_tokens = num_input_tokens[k].cpu().item() + inst.num_input_vision_tokens = num_input_vision_tokens[k].cpu().item() for ans, context in zip(answers, texts): clean_ans = parse_reasoning_model_answer(ans) @@ -145,6 +189,68 @@ def _collate(x): "avg_speed": avg_speed, "additional_metrics": { "rank": self.rank, + "vision_processing_latency": vision_processing_latency, + "total_num_input_tokens": total_num_input_tokens, + "total_num_input_vision_tokens": total_num_input_vision_tokens, + "num_requests": len(requests), + }, + } + + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + gathered = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(gathered, metric_dict) + + if self.rank == 0: + total_tokens = sum(m["total_tokens"] for m in gathered) + total_requests = sum(m["additional_metrics"]["num_requests"] for m in gathered) + total_vision_processing_latency = sum(m["additional_metrics"]["vision_processing_latency"] for m in gathered) + total_e2e_latency = sum(m["e2e_latency"] for m in gathered) + total_num_input_tokens = sum(m["additional_metrics"]["total_num_input_tokens"].cpu().item() for m in gathered) + total_num_input_vision_tokens = sum(m["additional_metrics"]["total_num_input_vision_tokens"].cpu().item() for m in gathered) + + throughput = total_tokens / total_e2e_latency if total_e2e_latency > 0 else 0.0 + avg_latency_per_req = total_e2e_latency / total_requests if total_requests else 0.0 + avg_vision_processing_latency = total_vision_processing_latency / total_requests if total_vision_processing_latency > 0 else 0.0 + avg_num_input_tokens = total_num_input_tokens / total_requests if total_num_input_tokens > 0 else 0.0 + avg_num_input_vision_tokens = total_num_input_vision_tokens / total_requests if total_num_input_vision_tokens > 0 else 0.0 + avg_total_tokens = total_tokens / total_requests if total_tokens > 0 else 0.0 + + metric_dict = { + "total_tokens": total_tokens, + "e2e_latency": total_e2e_latency, + "avg_speed": throughput, + "additional_metrics": { + "rank": self.rank, + "vision_processing_latency": total_vision_processing_latency, + "total_num_input_tokens": total_num_input_tokens, + "total_num_input_vision_tokens": total_num_input_vision_tokens, + "num_requests": total_requests, + "avg_num_output_tokens": avg_total_tokens, + "avg_num_input_tokens": avg_num_input_tokens, + "avg_num_input_vision_tokens": avg_num_input_vision_tokens, + "avg_vision_processing_latency": avg_vision_processing_latency, + "avg_e2e_latency": avg_latency_per_req, + "per_worker": gathered + }, + } + log_metrics(**metric_dict) + else: + metric_dict = { + "total_tokens": total_tokens, + "e2e_latency": e2e_latency, + "avg_speed": avg_speed, + 
"additional_metrics": { + "rank": self.rank, + "vision_processing_latency": vision_processing_latency, + "total_num_input_tokens": total_num_input_tokens, + "total_num_input_vision_tokens": total_num_input_vision_tokens, + "num_requests": len(requests), + "avg_num_output_tokens": total_tokens / len(requests), + "avg_num_input_tokens": total_num_input_tokens / len(requests), + "avg_num_input_vision_tokens": total_num_input_vision_tokens / len(requests), + "avg_vision_processing_latency": vision_processing_latency / len(requests), + "avg_e2e_latency": e2e_latency / len(requests), }, } log_metrics(**metric_dict) diff --git a/lmms_eval/models/chat/sglang.py b/lmms_eval/models/chat/sglang.py index 42097e4eb..d3fde2a3c 100644 --- a/lmms_eval/models/chat/sglang.py +++ b/lmms_eval/models/chat/sglang.py @@ -290,8 +290,12 @@ def generate_until(self, requests) -> List[str]: if video_inputs is not None: video_inputs, video_metadatas = zip(*video_inputs) video_inputs, video_metadatas = list(video_inputs), list(video_metadatas) + for req, meta in zip(batch_requests, video_metadatas): + req.video_metadata = meta else: video_metadatas = None + for req in batch_requests: + req.video_metadata = None assert image_inputs is None or video_inputs is None, "Only one of image or video inputs should be provided" inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, video_metadata=video_metadatas, **video_kwargs, padding=True, return_tensors="pt") # If video inputs is not None, we need to replace the image token ids with the video token ids before generating diff --git a/lmms_eval/models/chat/vllm_generate.py b/lmms_eval/models/chat/vllm_generate.py index 6c9eba758..2938442a1 100644 --- a/lmms_eval/models/chat/vllm_generate.py +++ b/lmms_eval/models/chat/vllm_generate.py @@ -108,6 +108,7 @@ def make_one_request(self, request: Instance) -> Tuple[list[dict], dict]: video_metadatas.append(video_metadata) kwargs["fps"] = fps kwargs["do_sample_frames"] = False + request.video_metadata = video_metadatas if video_metadatas else None if len(videos) == 0: video_inputs = None video_metadatas = None diff --git a/lmms_eval/models/model_utils/gen_metrics.py b/lmms_eval/models/model_utils/gen_metrics.py index aff03cb27..6a4e52a43 100644 --- a/lmms_eval/models/model_utils/gen_metrics.py +++ b/lmms_eval/models/model_utils/gen_metrics.py @@ -1,5 +1,8 @@ +import json +import os import time -from typing import Any, Callable, Dict, List +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union import torch from loguru import logger as eval_logger @@ -31,30 +34,94 @@ def calculate_token_throughput(token_count: int, duration: float) -> float: return token_count / duration -def log_metrics(e2e_latency: float, total_tokens: int, avg_speed: float, additional_metrics: Dict[str, Any] = None): +def _json_default_serializer(value: Any) -> Any: + if isinstance(value, torch.Tensor): + if value.ndim == 0: + return value.item() + return value.tolist() + if isinstance(value, (set, tuple)): + return list(value) + return str(value) + + +def _persist_metrics(metrics: Dict[str, Any], metrics_path: Union[str, Path]) -> None: + destination = Path(metrics_path) + destination.parent.mkdir(parents=True, exist_ok=True) + + existing: List[Dict[str, Any]] = [] + if destination.exists(): + try: + with destination.open("r", encoding="utf-8") as handle: + content = handle.read().strip() + if content: + loaded = json.loads(content) + if isinstance(loaded, list): + existing = loaded + else: + existing = 
[loaded] + except json.JSONDecodeError: + eval_logger.warning(f"Existing metrics file {destination} is not valid JSON. Overwriting.") + except Exception as exc: + eval_logger.warning(f"Could not read metrics file at {destination}: {exc}") + + existing.append(metrics) + + tmp_path = destination.with_suffix(destination.suffix + ".tmp") + with tmp_path.open("w", encoding="utf-8") as handle: + json.dump(existing, handle, indent=2, ensure_ascii=False, default=_json_default_serializer) + handle.write("\n") + tmp_path.replace(destination) + + +def log_metrics( + e2e_latency: float, + total_tokens: int, + avg_speed: float, + additional_metrics: Optional[Dict[str, Any]] = None, + metrics_path: Optional[Union[str, Path]] = None, +): """ - Log the metrics in a structured format. + Log the metrics in a structured format and optionally persist them to disk. Args: e2e_latency (float): The end-to-end latency in seconds. total_tokens (int): The total number of tokens processed. avg_speed (float): The average speed in tokens per second. - additional_metrics (Dict[str, Any]): Additional metrics to log. + additional_metrics (Dict[str, Any], optional): Additional metrics to log. + metrics_path (Union[str, Path], optional): Path to a JSON file where metrics should be appended. + If not provided, the `LMMS_EVAL_METRICS_PATH` environment variable will be used when available. """ required_stats = f"Metric summary - Total time: {e2e_latency:.3f}s, Total tokens: {total_tokens}, Avg speed: {avg_speed:.1f} tokens/s" - if additional_metrics is not None: + if additional_metrics: required_stats += ", Additional metrics: " required_stats += ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in additional_metrics.items()) eval_logger.info(required_stats) + metrics_payload: Dict[str, Any] = { + "e2e_latency": float(e2e_latency), + "total_tokens": float(total_tokens), + "avg_speed": float(avg_speed), + "logged_at": time.time(), + } + if additional_metrics: + metrics_payload["additional_metrics"] = additional_metrics + + destination = metrics_path or os.getenv("LMMS_EVAL_METRICS_PATH") + if destination: + try: + _persist_metrics(metrics_payload, destination) + except Exception as exc: + eval_logger.warning(f"Failed to persist metrics to {destination}: {exc}") + class GenMetrics: """ A class to manage the generation of metrics for model evaluation. 
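+ Minimal usage sketch (illustrative only; it assumes `__enter__`/`__exit__` record the start/end timestamps, and `run_generation` is a hypothetical helper): + + with GenMetrics(metrics_path="results/gen_metrics.json") as gen_metrics: + answers = run_generation(batch) + gen_metrics.log_metric(answers, additional_metrics={"rank": 0})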
""" - def __init__(self, tokenize_fn: Callable = space_tokenizer): + def __init__(self, tokenize_fn: Callable = space_tokenizer, metrics_path: Optional[Union[str, Path]] = None): self.tokenize_fn = tokenize_fn + self.metrics_path = Path(metrics_path) if metrics_path else None def __enter__(self): """ @@ -71,15 +138,23 @@ def log_metric(self, content: List[Any], additional_metrics: Dict[str, Any] = No num_tokens = sum(self.tokenize_fn(item) for item in content) duration = self.end_time - self.start_time throughput = calculate_token_throughput(num_tokens, duration) - self.metrics = { + base_metrics = { "num_tokens": num_tokens, "duration": duration, "throughput": throughput, } + self.metrics = base_metrics.copy() if additional_metrics: self.metrics.update(additional_metrics) - log_metrics(self.metrics) + supplementary_metrics = {k: v for k, v in self.metrics.items() if k not in {"num_tokens", "duration", "throughput"}} + log_metrics( + e2e_latency=duration, + total_tokens=num_tokens, + avg_speed=throughput, + additional_metrics=supplementary_metrics or None, + metrics_path=self.metrics_path, + ) def __exit__(self, exc_type, exc_value, traceback): """ diff --git a/lmms_eval/models/simple/qwen2_5_vl.py b/lmms_eval/models/simple/qwen2_5_vl.py index 55742d7b5..111594ca7 100644 --- a/lmms_eval/models/simple/qwen2_5_vl.py +++ b/lmms_eval/models/simple/qwen2_5_vl.py @@ -45,6 +45,8 @@ def __init__( batch_size: Optional[Union[int, str]] = 1, use_cache=True, attn_implementation: Optional[str] = None, + resized_height: Optional[int] = None, + resized_width: Optional[int] = None, min_pixels: int = 256 * 28 * 28, max_pixels: int = 1605632, max_num_frames: int = 32, @@ -54,6 +56,7 @@ def __init__( system_prompt: Optional[str] = "You are a helpful assistant.", interleave_visuals: Optional[bool] = False, reasoning_prompt: Optional[str] = None, + video_sampler: Optional[str] = None, **kwargs, ) -> None: super().__init__() @@ -128,6 +131,9 @@ def __init__( else: self._rank = 0 self._world_size = 1 + self.resized_height = resized_height + self.resized_width = resized_width + self.video_sampler = video_sampler @property def config(self): @@ -223,7 +229,7 @@ def _collate(x): for i in range(len(contexts)): if "" in contexts[i]: contexts[i] = contexts[i].replace("", "") - + batched_messages = [] for i, context in enumerate(contexts): if "" in context: @@ -242,7 +248,12 @@ def _collate(x): first_frame = vr[0].asnumpy() height, width = first_frame.shape[:2] # max_pixels = height * width - processed_visuals.append({"type": "video", "video": visual, "max_pixels": self.max_pixels, "min_pixels": self.min_pixels}) + processed_visuals.append({ + "type": "video", + "video": visual, + "max_pixels": self.max_pixels, + "min_pixels": self.min_pixels, + }) elif isinstance(visual, Image.Image): # Handle both single and multiple images base64_image = visual.convert("RGB") buffer = BytesIO() @@ -284,18 +295,18 @@ def _collate(x): texts = self.processor.apply_chat_template(batched_messages, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(batched_messages) - if video_inputs is not None: - total_frames = video_inputs[0].shape[0] - indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int) - # Ensure unique indices if linspace produces duplicates for few frames - indices = np.unique(indices) - # Append the last frame index if not already included - if total_frames - 1 not in indices: - indices = np.append(indices, total_frames - 1) - indices = np.unique(indices) # Ensure 
uniqueness again - video_inputs[0] = video_inputs[0][indices] - padding_side = "left" if self.batch_size > 1 else "right" - inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, padding_side=padding_side, return_tensors="pt") + # if video_inputs is not None: + # total_frames = video_inputs[0].shape[0] + # indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int) + # # Ensure unique indices if linspace produces duplicates for few frames + # indices = np.unique(indices) + # # Append the last frame index if not already included + # if total_frames - 1 not in indices: + # indices = np.append(indices, total_frames - 1) + # indices = np.unique(indices) # Ensure uniqueness again + # video_inputs[0] = video_inputs[0][indices] + inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + if self.device_map == "auto": inputs = inputs.to("cuda") else: @@ -345,9 +356,9 @@ def _collate(x): self.cache_hook.add_partial("generate_until", (context, gen_kwargs), clean_ans) pbar.update(1) - # eval_logger.debug(f"Question: {context}") - # eval_logger.debug(f"Model Raw Response: {ans}") - # eval_logger.debug(f"Model Clean Response: {clean_ans}") + eval_logger.debug(f"Question: {context}") + eval_logger.debug(f"Model Raw Response: {ans}") + eval_logger.debug(f"Model Clean Response: {clean_ans}") # reorder this group of results back to original unsorted form res = re_ords.get_original(res) diff --git a/lmms_eval/models/simple/qwen3_vl.py b/lmms_eval/models/simple/qwen3_vl.py index f800c7372..a1e059fa2 100644 --- a/lmms_eval/models/simple/qwen3_vl.py +++ b/lmms_eval/models/simple/qwen3_vl.py @@ -46,6 +46,8 @@ def __init__( batch_size: Optional[Union[int, str]] = 1, use_cache=True, attn_implementation: Optional[str] = None, + resized_height: Optional[int] = None, + resized_width: Optional[int] = None, min_pixels: int = 256 * 28 * 28, max_pixels: int = 1605632, max_num_frames: int = 32, @@ -55,11 +57,12 @@ def __init__( system_prompt: Optional[str] = "You are a helpful assistant.", interleave_visuals: Optional[bool] = False, reasoning_prompt: Optional[str] = None, + video_sampler: Optional[str] = None, **kwargs, ) -> None: super().__init__() # Do not use kwargs for now - assert kwargs == {}, f"Unexpected kwargs: {kwargs}" + # assert kwargs == {}, f"Unexpected kwargs: {kwargs}" # Validate attention implementation valid_attn_implementations = [None, "flash_attention_2", "sdpa", "eager"] @@ -75,6 +78,7 @@ def __init__( raise ValueError("max_image_size is only applicable if use_custom_video_loader is True") accelerator = Accelerator() + eval_logger.info(f"Accelerator: {accelerator.distributed_type}") self.accelerator = accelerator if accelerator.num_processes > 1: self._device = torch.device(f"cuda:{accelerator.local_process_index}") @@ -84,14 +88,16 @@ def __init__( self.device_map = device_map if device_map else device # Prepare model loading arguments - model_kwargs = { + model_kwargs = kwargs.copy() + model_kwargs.update({ "torch_dtype": "bfloat16", "device_map": self.device_map, - } + }) # Add attention implementation if specified if attn_implementation is not None: model_kwargs["attn_implementation"] = attn_implementation + print("model_kwargs: ", model_kwargs) # check whether its an MoE model match = re.search(r"A\d+B", pretrained) @@ -132,6 +138,9 @@ def __init__( else: self._rank = 0 self._world_size = 1 + self.resized_height = resized_height + self.resized_width = resized_width + self.video_sampler = 
video_sampler @property def config(self): @@ -287,18 +296,17 @@ def _collate(x): batched_messages.append(message) texts = self.processor.apply_chat_template(batched_messages, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(batched_messages) - if video_inputs is not None: - total_frames = video_inputs[0].shape[0] - indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int) - # Ensure unique indices if linspace produces duplicates for few frames - indices = np.unique(indices) - # Append the last frame index if not already included - if total_frames - 1 not in indices: - indices = np.append(indices, total_frames - 1) - indices = np.unique(indices) # Ensure uniqueness again - video_inputs[0] = video_inputs[0][indices] - padding_side = "left" if self.batch_size > 1 else "right" - inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, padding_side=padding_side, return_tensors="pt") + # if video_inputs is not None: + # total_frames = video_inputs[0].shape[0] + # indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int) + # # Ensure unique indices if linspace produces duplicates for few frames + # indices = np.unique(indices) + # # Append the last frame index if not already included + # if total_frames - 1 not in indices: + # indices = np.append(indices, total_frames - 1) + # indices = np.unique(indices) # Ensure uniqueness again + # video_inputs[0] = video_inputs[0][indices] + inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") if self.device_map == "auto": inputs = inputs.to("cuda") diff --git a/lmms_eval/protocol.py b/lmms_eval/protocol.py index 46d423d3e..4a7abf9af 100644 --- a/lmms_eval/protocol.py +++ b/lmms_eval/protocol.py @@ -27,6 +27,7 @@ class ChatImageContent(BaseModel): class ChatVideoContent(BaseModel): type: Literal["video"] = "video" url: Any + question: str = None class ChatAudioContent(BaseModel): @@ -74,6 +75,8 @@ def to_hf_messages(self, video_kwargs: Dict[str, str] = None): elif content.type == "image": hf_message["content"].append({"type": "image", "image": content.url}) elif content.type == "video": + if content.question is not None: + video_kwargs["question"] = content.question hf_message["content"].append({"type": "video", "video": content.url, **video_kwargs}) elif content.type == "audio": hf_message["content"].append({"type": "audio", "audio": content.url}) diff --git a/lmms_eval/tasks/egoschema/egoschema.yaml b/lmms_eval/tasks/egoschema/egoschema.yaml index 5a69399c7..af9bc5f1b 100755 --- a/lmms_eval/tasks/egoschema/egoschema.yaml +++ b/lmms_eval/tasks/egoschema/egoschema.yaml @@ -2,7 +2,7 @@ dataset_name: "GENERATION" task: "egoschema" test_split: test output_type: generate_until -doc_to_visual: !function utils.egoschema_doc_to_visual +doc_to_messages: !function utils.egoschema_doc_to_messages doc_to_text: !function utils.egoschema_doc_to_text doc_to_target: !function utils.egoschema_doc_to_answer process_results: !function utils.egoschema_process_results_generation @@ -17,4 +17,4 @@ lmms_eval_specific_kwargs: post_prompt: "\nAnswer with the option's letter from the given choices directly." aria: pre_prompt: "Please answer the question about the video:\n" - post_prompt: "\nAnswer with the option's letter from the given choices directly." \ No newline at end of file + post_prompt: "\nAnswer with the option's letter from the given choices directly." 
diff --git a/lmms_eval/tasks/egoschema/egoschema_mcppl.yaml b/lmms_eval/tasks/egoschema/egoschema_mcppl.yaml index b3a380b01..6375865db 100755 --- a/lmms_eval/tasks/egoschema/egoschema_mcppl.yaml +++ b/lmms_eval/tasks/egoschema/egoschema_mcppl.yaml @@ -2,7 +2,7 @@ dataset_name: "MC_PPL" task: "egoschema_mcppl" test_split: test output_type: multiple_choice -doc_to_visual: !function utils.egoschema_doc_to_visual +doc_to_messages: !function utils.egoschema_doc_to_messages doc_to_text: "question" doc_to_target: !function utils.egoschema_doc_to_answer doc_to_choice: !function utils.egoschema_doc_to_choice diff --git a/lmms_eval/tasks/egoschema/egoschema_subset.yaml b/lmms_eval/tasks/egoschema/egoschema_subset.yaml index 6b63c34f8..32f280a2a 100755 --- a/lmms_eval/tasks/egoschema/egoschema_subset.yaml +++ b/lmms_eval/tasks/egoschema/egoschema_subset.yaml @@ -2,7 +2,7 @@ dataset_name: "Subset" task: "egoschema_subset" test_split: test output_type: generate_until -doc_to_visual: !function utils.egoschema_doc_to_visual +doc_to_messages: !function utils.egoschema_doc_to_messages doc_to_text: !function utils.egoschema_doc_to_text doc_to_target: !function utils.egoschema_doc_to_answer process_results: !function utils.egoschema_process_results_generation diff --git a/lmms_eval/tasks/egoschema/egoschema_subset_mcppl.yaml b/lmms_eval/tasks/egoschema/egoschema_subset_mcppl.yaml index ebe868649..51b9e8a4a 100755 --- a/lmms_eval/tasks/egoschema/egoschema_subset_mcppl.yaml +++ b/lmms_eval/tasks/egoschema/egoschema_subset_mcppl.yaml @@ -2,7 +2,7 @@ dataset_name: "Subset" task: "egoschema_subset_mcppl" test_split: test output_type: multiple_choice -doc_to_visual: !function utils.egoschema_doc_to_visual +doc_to_messages: !function utils.egoschema_doc_to_messages doc_to_text: "question" doc_to_target: !function utils.egoschema_doc_to_answer doc_to_choice: !function utils.egoschema_doc_to_choice diff --git a/lmms_eval/tasks/egoschema/utils.py b/lmms_eval/tasks/egoschema/utils.py index 6b1100fff..735145e5e 100755 --- a/lmms_eval/tasks/egoschema/utils.py +++ b/lmms_eval/tasks/egoschema/utils.py @@ -31,6 +31,26 @@ from loguru import logger as eval_logger +from PIL import Image as PIL_Image + +def egoschema_doc_to_messages(doc, lmms_eval_specific_kwargs=None): + visuals = egoschema_doc_to_visual(doc) + if visuals is None: + visuals = [] + text = egoschema_doc_to_text(doc, lmms_eval_specific_kwargs=lmms_eval_specific_kwargs) + messages = [{"role": "user", "content": []}] + content = [] + for visual in visuals: + if isinstance(visual, PIL_Image.Image): + content.append({"type": "image", "url": visual}) + elif isinstance(visual, dict): + content.append({"type": "audio", "url": visual}) + elif isinstance(visual, str): + content.append({"type": "video", "url": visual, "question": egoschema_doc_to_question(doc)}) + content.append({"type": "text", "text": text}) + messages[0]["content"] = content + return messages + # Pass in video path here # Can only work correctly with video llm @@ -45,6 +65,8 @@ def egoschema_doc_to_visual(doc): sys.exit(f"video path:{video_path} does not exist, please check") return [video_path] +def egoschema_doc_to_question(doc): + return doc["question"] # This is the place where you format your question def egoschema_doc_to_text(doc, lmms_eval_specific_kwargs=None): @@ -57,7 +79,7 @@ def egoschema_doc_to_text(doc, lmms_eval_specific_kwargs=None): if "post_prompt" in lmms_eval_specific_kwargs: post_prompt = lmms_eval_specific_kwargs["post_prompt"] - question = doc["question"] + question = 
egoschema_doc_to_question(doc) if "option" in doc: for op in doc["option"]: question += "\n" + op diff --git a/lmms_eval/tasks/minerva/README.md b/lmms_eval/tasks/minerva/README.md new file mode 100644 index 000000000..623a209df --- /dev/null +++ b/lmms_eval/tasks/minerva/README.md @@ -0,0 +1,45 @@ +## MINERVA +MINERVA consists of ~1.5K +challenging question-answer-decoy (QAD) sets for variable-length videos. For +each question, we provide 5 answer choices, as well as detailed, +manually-annotated reasoning traces. Every question in MINERVA requires complex +reasoning using two or more skills +(for example numerical reasoning, temporal reasoning, spatial navigation). +Videos also span multiple domains (short films, sports, instructional videos, +etc.), with various video lengths (from 2 minutes to over 1.5 hours). The +hand-crafted, detailed reasoning trace accompanying each question outlines +the steps that are required to come to the correct answer. +These traces include timestamps where necessary to refer to relevant sections of +the video, and also describe key actions and objects, as well as outline logical +reasoning steps. More details are provided in our +[arXiv](https://arxiv.org/abs/2505.00681) paper. + +### Downloading the Data +We provide a JSON file that contains the YouTube IDs and annotations. + +The JSON file contains the following fields: + +- key: Unique identifier for each question +- video_id: YouTube URL +- question: Free-form question +- answer: Free-form answer +- answer_choice_{i}: Decoys for MCQ evaluation, i in range(0,4) +- answer_id: ID of the correct answer in the decoys +- reasoning: Detailed reasoning trace +- question type: A comma-separated list of multiple skills needed to answer the +question +- split: Coarse video domain +- category: Fine-grained video domain + +[MINERVA json](https://storage.mtls.cloud.google.com/neptunedata/minerva.json) + +### Citing this work + +```latex +@article{minerva25, + title={MINERVA: Evaluating Complex Video Reasoning}, + author={Nagrani, Arsha and Menon, Sachit and Iscen, Ahmet and Buch, Shyamal and Mehran, Ramin and Jha, Nilpa and Hauth, Anja and Zhu, Yukun and Vondrick, Carl and Sirotenko, Mikhail and Schmid, Cordelia and Weyand, Tobias}, + journal={arXiv preprint arXiv:2505.00681}, + year={2025} +} +``` \ No newline at end of file diff --git a/lmms_eval/tasks/minerva/_default_template_yaml b/lmms_eval/tasks/minerva/_default_template_yaml new file mode 100644 index 000000000..a060727e6 --- /dev/null +++ b/lmms_eval/tasks/minerva/_default_template_yaml @@ -0,0 +1,11 @@ +dataset_path: json +dataset_kwargs: + video: True + cache_dir: minerva + external_downloader: + fn: lmms_eval.tasks.minerva.utils.download_dataset + result_dataset_kwarg: data_files.test +lmms_eval_specific_kwargs: + default: + pre_prompt: "" + post_prompt: "" \ No newline at end of file diff --git a/lmms_eval/tasks/minerva/download_videos.py b/lmms_eval/tasks/minerva/download_videos.py new file mode 100644 index 000000000..257dec8b7 --- /dev/null +++ b/lmms_eval/tasks/minerva/download_videos.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +""" +Script to download YouTube videos from the minerva.json dataset. +Downloads unique videos only, skipping duplicates and already downloaded files. 
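+Requires yt-dlp on PATH; each download passes a cookies.txt file from the current working directory to yt-dlp via --cookies.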
+""" + +import json +import os +import subprocess +from pathlib import Path +from typing import Set +from tqdm import tqdm + + +def load_video_ids(json_path: Path) -> list: + """Load all video IDs from the JSON file.""" + print(f"Loading video IDs from {json_path}...") + with open(json_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + video_ids = [item['video_id'] for item in data if 'video_id' in item] + print(f"Found {len(video_ids)} total entries") + return video_ids + +def get_unique_video_ids(video_ids: list) -> list: + """Get unique video IDs while preserving order.""" + seen = set() + unique_ids = [] + for vid in video_ids: + if vid not in seen: + seen.add(vid) + unique_ids.append(vid) + print(f"Found {len(unique_ids)} unique video IDs") + return unique_ids + +def get_already_downloaded(output_dir: Path) -> Set[str]: + """Check which videos have already been downloaded.""" + if not output_dir.exists(): + return set() + + downloaded = set() + # Common video extensions + extensions = ['.mp4', '.webm', '.mkv', '.flv', '.avi', '.mov'] + + for file in output_dir.iterdir(): + if file.is_file() and file.suffix in extensions: + # Extract video ID from filename (assumes format: video_id.ext) + video_id = file.stem + downloaded.add(video_id) + + print(f"Found {len(downloaded)} already downloaded videos") + return downloaded + +def download_video(video_id: str, output_dir: Path) -> bool: + """Download a single YouTube video using yt-dlp.""" + url = f"https://www.youtube.com/watch?v={video_id}" + output_template = str(output_dir / f"{video_id}.%(ext)s") + if os.path.exists(output_template): + return True + + try: + # Using yt-dlp for downloading + # -f best: download best quality + # --no-playlist: don't download playlists + # -o: output template + cmd = [ + "yt-dlp", + "-f", "best", + "--no-playlist", + "-o", output_template, + "--cookies", "cookies.txt", + url + ] + print('cmd', ' '.join(cmd)) + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + + if result.returncode == 0: + return True + else: + print(f"✗ Failed to download {video_id}: {result.stderr}") + return False + + except subprocess.TimeoutExpired: + print(f"✗ Timeout while downloading {video_id}") + return False + except FileNotFoundError: + print("Error: yt-dlp is not installed. 
Please install it with:") + print(" pip install yt-dlp") + print(" or: sudo apt-get install yt-dlp") + raise + except Exception as e: + print(f"✗ Error downloading {video_id}: {e}") + return False + +def main(output_dir: Path, json_file: Path): + """Main function to orchestrate video downloads.""" + # Create output directory if it doesn't exist + output_dir.mkdir(parents=True, exist_ok=True) + + # Load video IDs from JSON + all_video_ids = load_video_ids(json_file) + + # Get unique video IDs + unique_video_ids = get_unique_video_ids(all_video_ids) + + # Check which videos are already downloaded + already_downloaded = get_already_downloaded(output_dir) + + # Filter out already downloaded videos + to_download = [vid for vid in unique_video_ids if vid not in already_downloaded] + + if not to_download: + print("\nAll videos are already downloaded!") + return + + print(f"\nNeed to download {len(to_download)} videos") + print(f"Skipping {len(already_downloaded)} already downloaded videos\n") + + # Download videos + successful = 0 + failed = 0 + + for i, video_id in enumerate(tqdm(to_download, desc="Downloading videos", total=len(to_download)), 1): + print(f"\n[{i}/{len(to_download)}] Processing {video_id}") + if download_video(video_id, output_dir): + successful += 1 + else: + failed += 1 + + # Print summary + print("\n" + "="*60) + print("DOWNLOAD SUMMARY") + print("="*60) + print(f"Total unique videos: {len(unique_video_ids)}") + print(f"Already downloaded: {len(already_downloaded)}") + print(f"Attempted downloads: {len(to_download)}") + print(f"Successful: {successful}") + print(f"Failed: {failed}") + print("="*60) diff --git a/lmms_eval/tasks/minerva/minerva.yaml b/lmms_eval/tasks/minerva/minerva.yaml new file mode 100755 index 000000000..41909111e --- /dev/null +++ b/lmms_eval/tasks/minerva/minerva.yaml @@ -0,0 +1,23 @@ +dataset_name: "MINERVA" +task: "minerva" +test_split: test +output_type: generate_until +doc_to_messages: !function utils.minerva_doc_to_messages +doc_to_text: !function utils.minerva_doc_to_text +doc_to_target: !function utils.minerva_doc_to_answer +process_results: !function utils.minerva_process_results_generation +metric_list: + - metric: submission + aggregation: !function utils.minerva_aggregate_mc + higher_is_better: true + - metric: score + aggregation: !function utils.minerva_aggregate_score + higher_is_better: true +include: _default_template_yaml +lmms_eval_specific_kwargs: + default: + pre_prompt: "You will be given a question about a video and five possible answer options. You are provided frames from the video.\n\n" + post_prompt: "\n\nOutput the final answer in the format “Final Answer: (X)” where X is the correct letter choice from (a)-(e). DO NOT OUTPUT text or any other words with the answer." + aria: + pre_prompt: "Please answer the question about the video:\n" + post_prompt: "\nAnswer with the option's letter from the given choices directly." 
\ No newline at end of file diff --git a/lmms_eval/tasks/minerva/utils.py b/lmms_eval/tasks/minerva/utils.py new file mode 100755 index 000000000..598f81bb5 --- /dev/null +++ b/lmms_eval/tasks/minerva/utils.py @@ -0,0 +1,301 @@ +import datetime +import json +import os +import random +import sys +from pathlib import Path + +import numpy as np +import yaml +from decord import VideoReader, cpu +from lmms_eval.tasks.minerva.download_videos import main as download_videos + +import lmms_eval.tasks._task_utils.file_utils as file_utils + +with open(Path(__file__).parent / "_default_template_yaml", "r") as f: + raw_data = f.readlines() + safe_data = [] + for i, line in enumerate(raw_data): + # remove function definition since yaml load cannot handle it + if "!function" not in line: + safe_data.append(line) + + config = yaml.safe_load("".join(safe_data)) + +# We will unzip all the zip files +# To HF HOME cache dir +# And load it here +HF_HOME = os.environ["HF_HOME"] if "HF_HOME" in os.environ else os.path.expanduser("~/.cache/huggingface/hub") +_CACHE_SUBDIR = config["dataset_kwargs"]["cache_dir"] +CACHE_DIR = Path(os.path.join(HF_HOME, _CACHE_SUBDIR)).expanduser() +VIDEOS_DIR = CACHE_DIR / "videos" +cache_dir = str(CACHE_DIR) +videos_dir = str(VIDEOS_DIR) +OPTIONS = ["a", "b", "c", "d", "e"] + +from loguru import logger as eval_logger + +from PIL import Image as PIL_Image + +import requests +MINERVA_JSON_URL = "https://storage.mtls.cloud.google.com/neptunedata/minerva.json" + + +def download_file(url: str, target_dir: str | Path, filename: str | None = None) -> Path: + target_path = Path(target_dir).expanduser().resolve() + target_path.mkdir(parents=True, exist_ok=True) + + if filename is None: + filename = url.split("/")[-1] or "downloaded_file" + file_path = target_path / filename + if file_path.exists(): + return file_path + + with requests.get(url, stream=True, timeout=30) as response: + response.raise_for_status() + with file_path.open("wb") as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + return file_path + + +def download_dataset( + cache_dir: str | Path | None = None, + videos_dir: str | Path | None = None, + url: str = MINERVA_JSON_URL, + **_, +) -> Path: + target_cache_dir = Path(cache_dir).expanduser().resolve() if cache_dir else CACHE_DIR + target_videos_dir = Path(videos_dir).expanduser().resolve() if videos_dir else VIDEOS_DIR + + json_file = download_file(url, target_cache_dir, "minerva.json") + download_videos(output_dir=target_videos_dir, json_file=json_file) + return json_file + + +def minerva_doc_to_question(doc): + return doc["question"] + +def minerva_doc_to_messages(doc, lmms_eval_specific_kwargs=None): + visuals = minerva_doc_to_visual(doc) + if visuals is None: + visuals = [] + text = minerva_doc_to_text(doc, lmms_eval_specific_kwargs=lmms_eval_specific_kwargs) + messages = [{"role": "user", "content": []}] + content = [] + for visual in visuals: + if isinstance(visual, PIL_Image.Image): + content.append({"type": "image", "url": visual}) + elif isinstance(visual, dict): + content.append({"type": "audio", "url": visual}) + elif isinstance(visual, str): + content.append({"type": "video", "url": visual, "question": minerva_doc_to_question(doc)}) + content.append({"type": "text", "text": text}) + messages[0]["content"] = content + return messages + + +# Pass in video path here +# Can only work correctly with video llm +def minerva_doc_to_visual(doc): + video_path = doc["video_id"] + ".mp4" + video_path = os.path.join(videos_dir, 
video_path) + if os.path.exists(video_path): + video_path = video_path + elif os.path.exists(video_path.replace("mp4", "MP4")): + video_path = video_path.replace("mp4", "MP4") + else: + sys.exit(f"video path:{video_path} does not exist, please check") + return [video_path] + + +# This is the place where you format your question +def minerva_doc_to_text(doc, lmms_eval_specific_kwargs=None): + if lmms_eval_specific_kwargs is None: + lmms_eval_specific_kwargs = {} + pre_prompt = "" + post_prompt = "" + if "pre_prompt" in lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs["pre_prompt"] + if "post_prompt" in lmms_eval_specific_kwargs: + post_prompt = lmms_eval_specific_kwargs["post_prompt"] + + question = minerva_doc_to_question(doc) + for i, choice in enumerate(OPTIONS): + question += f"\n({choice}) {doc[f'answer_choice_{i}']}" + + return f"{pre_prompt}{question}{post_prompt}" + + +def minerva_doc_to_answer(doc): + return doc["answer_id"] + + +# Process result for mc_ppl +def minerva_process_results(doc, result): + # Initialize minimum value and index + min_value = float("inf") + min_index = -1 + + # Iterate through the results to find the index of the lowest value + for i, (value, _) in enumerate(result): + if value < min_value: + min_value = value + min_index = i + + # Return the result with the index of the lowest value + return {"submission": {doc["video_id"]: min_index}, "score": {"pred": min_index, "ground_truth": doc["answer_id"]}} + + +def get_multi_choice_info(doc): + all_choices = [] + index2ans = {} + for i in range(len(OPTIONS)): + # import pdb;pdb.set_trace() + index2ans[OPTIONS[i]] = doc[f'answer_choice_{i}'].strip() + all_choices.append(OPTIONS[i]) + + return index2ans, all_choices + + +def parse_multi_choice_response(response, all_choices, index2ans): + """ + Parse the prediction from the generated response. + Return the predicted index e.g., A, B, C, D. + https://github.com/MMMU-Benchmark/MMMU/blob/51ce7f3e829c16bb44bc5445782686b4c3508794/eval/eval_utils.py#L10 + """ + for char in [",", ".", "!", "?", ";", ":", "'"]: + response = response.strip(char) + response = " " + response + " " # add space to avoid partial match + + index_ans = True + ans_with_brack = False + ans_with_space = False + ans_with_dot = False + candidates = [] + # import pdb; pdb.set_trace() + for choice in all_choices: # e.g., (a) (b) (c) (d) (e) + if f"({choice})" in response: + candidates.append(f"({choice})") + ans_with_brack = True + + for choice in all_choices: + if f"{choice}" == response.strip(): + candidates.append(f"{choice}") + ans_with_space = True + + # if len(candidates) == 0: + for choice in all_choices: # e.g., A B C D + if f"{choice} " in response: + candidates.append(f"{choice} ") + ans_with_space = True + + # if len(candidates) == 0: + for choice in all_choices: # e.g., A. B. C. D. + if f"{choice}." in response: + candidates.append(f"{choice}.") + ans_with_dot = True + + # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example + if len(candidates) == 0 and len(response.split()) > 5: + for index, ans in index2ans.items(): + if ans.lower() in response.lower(): + candidates.append(index) + index_ans = False # it's content ans. + + if len(candidates) == 0: # still not get answer, randomly choose one. 
+ pred_index = random.choice(all_choices) + elif len(candidates) > 1: + # candidates = list(set(candidates)) + start_indexes = [] + if index_ans: + # if ans_with_brack: + for can in candidates: + index = response.rfind(can) + start_indexes.append(index) # -1 will be ignored anyway + # start_indexes = [generated_response.index(f'({can})') for can in candidates] + # if ans_with_space: + # for can in candidates: + # index = response.rfind(f"{can} ") + # start_indexes.append(index) + # if ans_with_dot: + # for can in candidates: + # index = response.rfind(f"{can}.") + # start_indexes.append(index) + # if not ans_with_brack and not ans_with_space and not ans_with_dot: + # for can in candidates: + # index = response.rfind(f" {can} ") + # start_indexes.append(index) + else: + for can in candidates: + index = response.lower().rfind(index2ans[can].lower()) + start_indexes.append(index) + # get the first one + pred_index = candidates[np.argmin(start_indexes)] + pred_index = pred_index.replace("(", "").replace(")", "").replace(".", "").strip() + else: # if only one candidate, use it. + pred_index = candidates[0] + pred_index = pred_index.replace("(", "").replace(")", "").replace(".", "").strip() + print(f"IN UTILS parse_multi_choice_response response: {response}, pred_index: {pred_index}, candidates: {candidates}, all_choices: {all_choices}, index2ans: {index2ans}") + + return pred_index, len(candidates) > 0 + + +# Process result for mcq answer generation +def minerva_process_results_generation(doc, result): + # import pdb;pdb.set_trace() + pred = result[0] + + index2ans, all_choices = get_multi_choice_info(doc) + parsed_pred, matched_tag = parse_multi_choice_response(pred, all_choices, index2ans) + + pred_to_index = {choice: i for i, choice in enumerate(all_choices)} + index = pred_to_index.get(parsed_pred, -1) # Default to -1 if the prediction is not found + + return {"submission": {doc["video_id"]: index}, "score": {"pred": index, "ground_truth": doc["answer_id"]}} + + +def minerva_aggregate_submissions(results, args, task): + now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + submission_file_name = f"inference_results_minerva_{task}_{now_date_time}.json" + path = file_utils.generate_submission_file(submission_file_name, args) + + # results is a list of 5031 dict, + # need to convert results into a single dict with 5031 key-value pairs + combined_submission = {} + + for submission_dict in results: + combined_submission.update(submission_dict) + + with open(path, "w") as f: + json.dump(combined_submission, f, indent=4) + + eval_logger.info(f"Submission file saved to {path}") + + +# Factory into different aggregate +def minerva_aggregate_mc(results, args): + minerva_aggregate_submissions(results, args, "MC") + + +def minerva_aggregate_mc_ppl(results, args): + minerva_aggregate_submissions(results, args, "MC_PPL") + + +def minerva_aggregate_score(results, args): + yes_count = 0 + + # results is a list of dict + for answer_dict in results: + if str(answer_dict["ground_truth"]) == str(answer_dict["pred"]): + yes_count = yes_count + 1 + + accuracy = yes_count / len(results) + + return accuracy + + +def minerva_doc_to_choice(doc): + return [doc[f'answer_choice_{i}'] for i in range(len(OPTIONS))] diff --git a/lmms_eval/video_samplers/__init__.py b/lmms_eval/video_samplers/__init__.py new file mode 100644 index 000000000..3be5204f6 --- /dev/null +++ b/lmms_eval/video_samplers/__init__.py @@ -0,0 +1,61 @@ +# samplers/__init__.py +from __future__ import annotations + +from typing import 
Dict, Type, Any, Optional, Iterable +import importlib +import pkgutil + +# ---- Registry -------------------------------------------------------------- + +VIDEO_SAMPLER_REGISTRY: Dict[str, Type] = {} + +def register_video_sampler(name: str, *, overwrite: bool = False): + """Class decorator to register a sampler class under a name.""" + def decorate(cls: Type): + if not overwrite and name in VIDEO_SAMPLER_REGISTRY: + raise KeyError( + f"Sampler name {name!r} already registered with " + f"{VIDEO_SAMPLER_REGISTRY[name]!r}" + ) + from .base import BaseVideoSampler + if not issubclass(cls, BaseVideoSampler): + raise TypeError(f"{cls.__name__} must subclass BaseVideoSampler") + VIDEO_SAMPLER_REGISTRY[name] = cls + return cls + return decorate + +# ---- Lookups --------------------------------------------------------------- + +def get_video_sampler_cls(name: str) -> Type: + """Return the registered class for a sampler `name`.""" + try: + return VIDEO_SAMPLER_REGISTRY[name] + except KeyError as e: + known = ", ".join(sorted(VIDEO_SAMPLER_REGISTRY)) or "" + raise KeyError(f"Unknown video sampler {name!r}. Known: {known}") from e + +# ---- Optional: auto-discover submodules so decorators run ------------------ + +def _auto_import_submodules(package_name: str, exclude: Iterable[str] = ()): + """Import all submodules of this package so @register_... executes. + + Call once at import time. Skips names in `exclude`. + """ + pkg = importlib.import_module(package_name) + if not hasattr(pkg, "__path__"): # not a package + return + for m in pkgutil.iter_modules(pkg.__path__): + mod_name = f"{package_name}.{m.name}" + if m.name in exclude: + continue + importlib.import_module(mod_name) + +# Import all samplers on package import (edit excludes as needed) +_auto_import_submodules(__name__, exclude=("base", "__init__")) + +__all__ = [ + "register_video_sampler", + "get_video_sampler", + "get_video_sampler_cls", + "VIDEO_SAMPLER_REGISTRY", +] diff --git a/lmms_eval/video_samplers/aks_sampler.py b/lmms_eval/video_samplers/aks_sampler.py new file mode 100644 index 000000000..a5e806d40 --- /dev/null +++ b/lmms_eval/video_samplers/aks_sampler.py @@ -0,0 +1,221 @@ +from .base import BaseVideoSampler +from . 
import register_video_sampler +from typing import Any, Dict, Optional +import torchvision.transforms as T +import sys +sys.path.append("./external") +from lavis.models import load_model_and_preprocess +from torchvision.transforms.functional import InterpolationMode +import numpy as np +import heapq +import torch +from PIL import Image + + +@register_video_sampler("aks") +class AKSVideoSampler(BaseVideoSampler): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = self.__class__.__name__ + self.extract_feature_model = kwargs.get("extract_feature_model", "blip") + self.load_feature_model() + self.max_num_frames = kwargs.get("max_num_frames", 64) + self.ratio = kwargs.get("ratio", 1) + self.t1 = kwargs.get("t1", 0.8) + self.t2 = kwargs.get("t2", -100) + self.all_depth = kwargs.get("all_depth", 5) + + def load_feature_model(self): + if self.extract_feature_model == 'blip': + self.model, self.vis_processors, self.text_processors = load_model_and_preprocess("blip_image_text_matching", "large", device=self.device, is_eval=True) + self.vis_processors['eval_tensor'] = self.compose_to_tensor_transform(self.vis_processors["eval"].transform) + elif self.extract_feature_model == 'clip': + from transformers import CLIPModel, CLIPProcessor + self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + self.model.to(self.device) + self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + elif self.extract_feature_model == 'sevila': + self.model, self.vis_processors, self.text_processors = load_model_and_preprocess(name="sevila", model_type="pretrain_flant5xl", is_eval=True, device=self.device) + self.vis_processors['eval_tensor'] = self.compose_to_tensor_transform(self.vis_processors["eval"].transform) + else: + raise ValueError(f"model {self.extract_feature_model} not supported") + + def meanstd(self, len_scores, dic_scores, n, fns,t1,t2,all_depth): + split_scores = [] + split_fn = [] + no_split_scores = [] + no_split_fn = [] + i= 0 + for dic_score, fn in zip(dic_scores, fns): + # normalized_data = (score - np.min(score)) / (np.max(score) - np.min(score)) + score = dic_score['score'] + depth = dic_score['depth'] + mean = np.mean(score) + std = np.std(score) + + top_n = heapq.nlargest(n, range(len(score)), score.__getitem__) + top_score = [score[t] for t in top_n] + # print(f"split {i}: ",len(score)) + i += 1 + mean_diff = np.mean(top_score) - mean + if mean_diff > t1 and std > t2: + no_split_scores.append(dic_score) + no_split_fn.append(fn) + elif depth < all_depth: + # elif len(score)>(len_scores/n)*2 and len(score) >= 8: + score1 = score[:len(score)//2] + score2 = score[len(score)//2:] + fn1 = fn[:len(score)//2] + fn2 = fn[len(score)//2:] + split_scores.append(dict(score=score1,depth=depth+1)) + split_scores.append(dict(score=score2,depth=depth+1)) + split_fn.append(fn1) + split_fn.append(fn2) + else: + no_split_scores.append(dic_score) + no_split_fn.append(fn) + if len(split_scores) > 0: + all_split_score, all_split_fn = self.meanstd(len_scores, split_scores, n, split_fn,t1,t2,all_depth) + else: + all_split_score = [] + all_split_fn = [] + all_split_score = no_split_scores + all_split_score + all_split_fn = no_split_fn + all_split_fn + + + return all_split_score, all_split_fn + + def compose_to_tensor_transform(self, pil_compose: T.Compose) -> T.Compose: + """ + Convert a PIL-based torchvision Compose (e.g., from LAVIS vis_processors['eval']) + into a tensor-native Compose.
Common ops (Resize/CenterCrop/Normalize/ToTensor) + are translated to their tensor-friendly counterparts. + + Input expectation: + - Tensor images as (C,H,W) or (B,C,H,W), dtype uint8 or float. + - RGB order; handle permutes before calling if needed. + + Returns: + - A torchvision.transforms.Compose that works on torch.Tensors. + """ + mapped = [] + for t in pil_compose.transforms: + name = t.__class__.__name__ + + if name == "Resize": + mapped.append(T.Resize( + size=t.size, + interpolation=getattr(t, "interpolation", InterpolationMode.BILINEAR), + antialias=getattr(t, "antialias", True), + )) + + elif name == "CenterCrop": + mapped.append(T.CenterCrop(size=t.size)) + + elif name == "RandomResizedCrop": + mapped.append(T.RandomResizedCrop( + size=t.size, + scale=t.scale, + ratio=t.ratio, + interpolation=getattr(t, "interpolation", InterpolationMode.BILINEAR), + antialias=getattr(t, "antialias", True), + )) + + elif name == "ToTensor": + # For tensor input, we only need dtype/scale (PIL->Tensor step not needed). + mapped.append(T.ConvertImageDtype(torch.float32)) + + elif name == "Normalize": + mapped.append(T.Normalize(mean=t.mean, std=t.std, inplace=getattr(t, "inplace", False))) + + else: + # Keep unknown transforms as-is (may still expect PIL). + mapped.append(t) + + return T.Compose(mapped) + + def get_frames(self, vr, frame_num, backend): # (T, C, H, W) + if backend == 'torchcodec': + full_raw_image_tensors = vr.get_frames_at(indices=frame_num).data + elif backend == 'decord': + full_raw_image_tensors = torch.from_numpy(vr.get_batch(frame_num).asnumpy()).permute(0, 3, 1, 2) + elif backend == 'torchvision': + full_raw_image_tensors = vr[frame_num] + else: + raise ValueError(f"backend {backend} not supported") + return full_raw_image_tensors + + def sample(self, ele: Any, **kwargs) -> Optional[Dict[str, Any]]: + # AKS sampling: score frames against the question, then select frames via adaptive splitting (meanstd) + video_path = ele["video"] + text = ele['question'] + vr = ele['video_reader'] + fps = ele['video_fps'] + frame_nums = int(ele["total_frames"]/int(fps)) + frame_num = [j*int(fps) for j in range(frame_nums)] + score = [] + + if self.extract_feature_model == 'blip': + txt = self.text_processors["eval"](text) + with torch.no_grad(): + for i in range(0, len(frame_num), self.batch_size): + batch = frame_num[i:i+self.batch_size] + full_raw_image_tensors = self.get_frames(vr, batch, ele['video_reader_backend']) + imgs = self.vis_processors['eval_tensor'](full_raw_image_tensors).to(self.device) + blip_output = self.model({"image": imgs, "text_input": [txt]*imgs.shape[0]}, match_head="itm") + blip_scores = torch.nn.functional.softmax(blip_output, dim=1) + score.extend(blip_scores[:, 1].tolist()) + elif self.extract_feature_model == 'clip': + inputs_text = self.processor(text=text, return_tensors="pt", padding=True,truncation=True).to(self.device) + text_features = self.model.get_text_features(**inputs_text) + for j in range(frame_nums): + raw_image_tensor = vr[j*int(fps)] + if ele['video_reader_backend'] != 'decord': + raw_image_tensor = raw_image_tensor.permute(1,2,0) + raw_image = np.array(raw_image_tensor.cpu()) + raw_image = Image.fromarray(raw_image) + inputs_image = self.processor(images=raw_image, return_tensors="pt", padding=True).to(self.device) + with torch.no_grad(): + image_features = self.model.get_image_features(**inputs_image) + clip_score = torch.nn.CosineSimilarity(dim=-1)(text_features, image_features) + score.append(clip_score.item()) + frame_num.append(j*int(fps)) + + elif self.extract_feature_model == 'sevila': + text = 'Question: ' +
ele['question'] + ' Candidate: ' + for j,cad in enumerate(ele['answer_choices']): + text = text + ". ".join([chr(ord("A")+j), cad]) + ' ' + text = text + '. Is this a good frame that can answer the question?' + txt = self.text_processors["eval"](text) + full_raw_image_tensors = self.get_frames(vr, frame_num, ele['video_reader_backend']) + with torch.no_grad(): + for batch in full_raw_image_tensors: + imgs = self.vis_processors['eval_tensor'](batch).unsqueeze(1).to(self.device) + samples = {'video':imgs,'loc_input':[txt]*imgs.shape[0]} + sevila_score = self.model.generate_score(samples).squeeze(1) + score.append(sevila_score) + score = torch.cat(score, dim=0).detach().cpu().numpy() + else: + raise ValueError(f"model {self.extract_feature_model} not supported") + + nums = int(len(score)/self.ratio) + new_score = [score[num*self.ratio] for num in range(nums)] + new_fnum = [frame_num[num*self.ratio] for num in range(nums)] + score = new_score + fn = new_fnum + num = self.max_num_frames + if len(score) >= num: + normalized_data = (score - np.min(score)) / (np.max(score) - np.min(score)) + a, b = self.meanstd(len(score), [dict(score=normalized_data,depth=0)], num, [fn], self.t1, self.t2, self.all_depth) + out = [] + if len(score) >= num: + for s,f in zip(a,b): + f_num = int(num / 2**(s['depth'])) + topk = heapq.nlargest(f_num, range(len(s['score'])), s['score'].__getitem__) + f_nums = [f[t] for t in topk] + out.extend(f_nums) + out.sort() + return out + else: + return fn \ No newline at end of file diff --git a/lmms_eval/video_samplers/base.py b/lmms_eval/video_samplers/base.py new file mode 100644 index 000000000..4f79780b1 --- /dev/null +++ b/lmms_eval/video_samplers/base.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional +from lmms_eval import utils +from accelerate import Accelerator + + +class BaseVideoSampler(ABC): + """Abstract base class for video samplers used by multimodal models. + + Samplers take a raw ``visual`` item (e.g. a video path or frame list) and + return a dictionary payload understood by the downstream processor. + """ + + def __init__(self, *args, **kwargs): + self.batch_size = int(kwargs.get("batch_size", 128)) + device_pref = kwargs.get("device", None) + accelerator = Accelerator(cpu=(device_pref == "cpu")) + self.accelerator = accelerator + self.device = accelerator.device + self.return_frames = False + self.will_process_messages = False + + @classmethod + def create_from_arg_string(cls, arg_string: str, additional_config: Optional[dict] = None) -> "BaseVideoSampler": + """ + Creates an instance of the video sampler class using the given argument string and additional config. + + Parameters: + - arg_string: A string containing arguments in the format key1=value1,key2=value2. + - additional_config: Optional dictionary containing additional configuration parameters. + + Returns: + - Instance of the video sampler class. + """ + additional_config = {} if additional_config is None else additional_config + args = utils.simple_parse_args_string(arg_string) + args2 = {k: v for k, v in additional_config.items() if v is not None} + return cls(**args, **args2) + + @abstractmethod + def sample( + self, + ele: Any, + **kwargs + ) -> Optional[Dict[str, Any]]: + """Return a processed representation for ``visual``. + + Implementations may return ``None`` to signal that the input should be + skipped.
+ """ + + def __call__( + self, + ele: Any, + **kwargs + ) -> Optional[Dict[str, Any]]: + return self.sample( + ele, + **kwargs + ) + diff --git a/lmms_eval/video_samplers/dkts_sampler.py b/lmms_eval/video_samplers/dkts_sampler.py new file mode 100644 index 000000000..8c479599f --- /dev/null +++ b/lmms_eval/video_samplers/dkts_sampler.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + +import sys +sys.path.append("./external/D-KTS") +import torchvision.models as models +from feature_extractor import dhash, hanming +import torch +from torch.autograd import Variable +import numpy as np +from kts.auto_alter import cpd_auto2 +from kts.nonlin_alter import kernel + +from .base import BaseVideoSampler +from . import register_video_sampler + +from qwen_vl_utils.vision_process import smart_nframes + +@register_video_sampler("dkts") +class DKTSVideoSampler(BaseVideoSampler): + """Sampler that detects shot change points with D-KTS and picks the middle frame of each segment.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = self.__class__.__name__ + self.threshold = kwargs.get("threshold", 4.0) + self.v = kwargs.get("v", 1.0) # vmax in penalty + self.use_cpu = kwargs.get("use_cpu", False) + self.extract_frequency = kwargs.get("extract_frequency", 1) + googlenet = models.googlenet(pretrained=True) + self.googlenet = torch.nn.Sequential(*list(googlenet.children())[:-2]) + self.googlenet.eval() + if not self.use_cpu: + self.googlenet = self.googlenet.cuda() + else: + print('Using CPU......') + + def get_frames(self, vr, frame_num, backend): # (T, C, H, W) + if backend == 'torchcodec': + full_raw_image_tensors = vr.get_frames_at(indices=frame_num).data + elif backend == 'decord': + full_raw_image_tensors = torch.from_numpy(vr.get_batch(frame_num).asnumpy()).permute(0, 3, 1, 2) + elif backend == 'torchvision': + full_raw_image_tensors = vr[frame_num] + else: + raise ValueError(f"backend {backend} not supported") + return full_raw_image_tensors + + def get_features(self, ele: Any): # from one_video_hash + vr = ele['video_reader'] + fps = ele['video_fps'] + frame_count = ele["total_frames"] + if self.extract_frequency == "fps": + self.extract_frequency = int(fps) + + frames=[] + video_features = [] + count = 0 + skip_count = 0 + arr = [] + + with torch.no_grad(): + base = None + hash1 = None + for count in range(0, frame_count, self.extract_frequency): + if count % self.extract_frequency == 0: + fr = self.get_frames(vr, [count], ele['video_reader_backend'])[0].permute(1,2,0).numpy() + hash2=dhash(fr) + if hash1 is not None: + dist = hanming(hash1,hash2) + if base is None or dist > self.threshold: + base = fr + hash1 = hash2 + frames.append(np.rollaxis(fr, 2)) + arr.append(skip_count) + skip_count = 0 + else: + skip_count += 1 + frames.append(np.rollaxis(base, 2)) + if (len(frames) == self.batch_size) or (count >= frame_count and len(frames) > 0): + batch = np.array(frames) + if self.use_cpu: + variable = Variable(torch.from_numpy(batch).float()) + feature = self.googlenet(variable).detach().numpy() + else: + variable = Variable(torch.from_numpy(batch).float()).cuda() + feature = self.googlenet(variable).cpu().detach().numpy() + video_features.extend(feature) + frames.clear() + video_features = np.squeeze(np.array(video_features)) + duration = frame_count/fps + picks = np.arange(0, video_features.shape[0]) * self.extract_frequency + return { + "n_frames": int(frame_count), + "features": video_features, + "picks": picks, + "duration": duration, + "skip_arr":
arr, + } + + def kts_run(self, ele: Any): + features = self.get_features(ele) + X = features['features'] + n_frames = features['n_frames'] + n = X.shape[0] + n1 = min(n, 338) # 95% + m = round(n_frames / 106 * 2) + + K1 = kernel(X, X.T, n1) + cps1, scores1 = cpd_auto2(K1, m, self.v,self.extract_frequency) + cps1 *= self.extract_frequency + cps1 = np.hstack((0, cps1, n_frames)) + begin_frames = cps1[:-1] + end_frames = cps1[1:] + cps1 = np.vstack((begin_frames, end_frames - 1)).T + return cps1 + + def sample( + self, + ele: Any, + **kwargs + ) -> list[int]: + cps = self.kts_run(ele) + frame_indices = [] + for i in range(cps.shape[0]): + start_frame = cps[i, 0] + end_frame = cps[i, 1] + frame_idx = int((start_frame+end_frame)/2) + frame_indices.append(frame_idx) + return frame_indices + diff --git a/lmms_eval/video_samplers/fps_sampler.py b/lmms_eval/video_samplers/fps_sampler.py new file mode 100644 index 000000000..9abe6ca06 --- /dev/null +++ b/lmms_eval/video_samplers/fps_sampler.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + + +import torch + +from .base import BaseVideoSampler +from . import register_video_sampler + +from qwen_vl_utils.vision_process import smart_nframes + +@register_video_sampler("fps") +class FPSVideoSampler(BaseVideoSampler): + """Sampler that samples frames by fps from a video.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = self.__class__.__name__ + self.fps = kwargs.get("fps", 1) + + def sample( + self, + ele: Any, + **kwargs + ) -> list[int]: + ele.pop("nframes", None) + ele["fps"] = self.fps + nframes = smart_nframes(ele, total_frames=ele["total_frames"], video_fps=ele["video_fps"]) + idx = torch.linspace(ele["start_frame"], ele["end_frame"], nframes).round().long().tolist() + return idx + diff --git a/lmms_eval/video_samplers/mg_sampler.py b/lmms_eval/video_samplers/mg_sampler.py new file mode 100644 index 000000000..e3641147b --- /dev/null +++ b/lmms_eval/video_samplers/mg_sampler.py @@ -0,0 +1,139 @@ +# Infeasible for long videos because of the memory usage +from __future__ import annotations + +from typing import Any, List, Tuple + +import imageio +import numpy as np +from skimage.metrics import structural_similarity as compare_ssim +import cv2 +from multiprocessing.dummy import Pool as ThreadPool +import random + +from .base import BaseVideoSampler +from . import register_video_sampler + +from qwen_vl_utils.vision_process import smart_nframes + +@register_video_sampler("mgsampler") +class MGVideoSampler(BaseVideoSampler): + """Sample frames from the video. + Required keys are "filename", "total_frames", "start_index" , added or + modified keys are "frame_inds", "frame_interval" and "num_clips". + Args: + clip_len (int): Frames of each sampled output clip. + frame_interval (int): Temporal interval of adjacent sampled frames. + Default: 1. + num_clips (int): Number of clips to be sampled. Default: 1. + temporal_jitter (bool): Whether to apply temporal jittering. + Default: False. + twice_sample (bool): Whether to use twice sample when testing. + If set to True, it will sample frames with and without fixed shift, + which is commonly used for testing in TSM model. Default: False. + out_of_bound_opt (str): The way to deal with out of bounds frame + indexes. Available options are 'loop', 'repeat_last'. + Default: 'loop'. + test_mode (bool): Store True when building test or validation dataset. + Default: False.
+ start_index (None): This argument is deprecated and moved to dataset + class (``BaseDataset``, ``VideoDatset``, ``RawframeDataset``, etc), + see this: https://github.com/open-mmlab/mmaction2/pull/89. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = self.__class__.__name__ + self.num_frames = kwargs.get("num_frames", None) + self.test_mode = kwargs.get("test_mode", True) + + def multiplication(self,video_path): + + img = list() + img_diff = [] + try: + + vid = imageio.get_reader(video_path, 'ffmpeg') + for num, im in enumerate[Any](vid): + img.append(im) + for i in range(len(img) - 1): + tmp1 = cv2.cvtColor(img[i], cv2.COLOR_RGB2GRAY) + tmp2 = cv2.cvtColor(img[i + 1], cv2.COLOR_RGB2GRAY) + (score, diff) = compare_ssim(tmp1, tmp2, full=True) + score = 1 - score + img_diff.append(score) + except(OSError): + video_name = (video_path.split('/')[-1]).split('.')[0] + raise ValueError(f"error! {video_name}") + return img_diff + + def sample( + self, + ele: Any, + **kwargs + ) -> Tuple[List[int], int]: + video_path = ele["video"] + + def find_nearest(array, value): + array = np.asarray(array) + try: + idx = (np.abs(array - value)).argmin() + return int(idx + 1) + except(ValueError): + raise ValueError(f"error! {video_path}") + + diff_score = self.multiplication(video_path) + diff_score = np.power(diff_score, 0.5) + sum_num = np.sum(diff_score) + diff_score = diff_score / sum_num + + count = 0 + pic_diff = list() + for i in range(len(diff_score)): + count = count + diff_score[i] + pic_diff.append(count) + + choose_index = list() + + if self.test_mode: + choose_index.append(find_nearest(pic_diff, 1 / 32)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 1 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 2 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 3 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 4 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 5 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 6 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 7 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 8 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 9 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 10 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 11 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 12 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 13 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 14 / 16)) + choose_index.append(find_nearest(pic_diff, 1 / 32 + 15 / 16)) + + else: + choose_index.append(find_nearest(pic_diff, random.uniform(0, 1 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(1 / 16, 2 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(2 / 16, 3 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(3 / 16, 4 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(4 / 16, 5 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(5 / 16, 6 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(6 / 16, 7 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(7 / 16, 8 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(8 / 16, 9 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(9 / 16, 10 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(10 / 16, 11 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(11 / 
16, 12 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(12 / 16, 13 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(13 / 16, 14 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(14 / 16, 15 / 16))) + choose_index.append(find_nearest(pic_diff, random.uniform(15 / 16, 16 / 16))) + + return choose_index + + + + + + diff --git a/lmms_eval/video_samplers/qframe_sampler.py b/lmms_eval/video_samplers/qframe_sampler.py new file mode 100644 index 000000000..66f78a945 --- /dev/null +++ b/lmms_eval/video_samplers/qframe_sampler.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + + +import torch + +from .base import BaseVideoSampler +from . import register_video_sampler + +from qwen_vl_utils.vision_process import smart_nframes, smart_resize + +import random +from PIL import Image + +import sys +import cv2 +from io import BytesIO +import base64 +from external.longclip.model import longclip +import numpy as np +from typing import List, Tuple +import decord + +@register_video_sampler("qframe")
+class QFrameVideoSampler(BaseVideoSampler): + """Sampler that ranks frames with LongCLIP against the question and keeps them at mixed resolutions (QFrame).""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = self.__class__.__name__ + self.num_frames = int(kwargs.get("num_frames", 32)) + self.model_path = kwargs.get("qframe_model_path", "external/longclip/checkpoints/longclip-L.pt") + self.clip_model, self.clip_processor = longclip.load(self.model_path, device=self.device) + self.tau = float(kwargs.get("tau", 0.8)) + self.high_frames = int(kwargs.get("high_frames", 6)) + self.mid_frames = int(kwargs.get("mid_frames", 6)) + self.low_frames = int(kwargs.get("low_frames", 8)) + self.return_frames = True + self.baseline = kwargs.get("baseline", False) + self.will_process_messages = True + + def text_image_matching(self, question, images, tau=1.0): + + # print(f"{text}\n{'-'*100}\n{question}") + with torch.no_grad(), torch.cuda.amp.autocast(): + text = longclip.tokenize([question]).to(self.device) + images = torch.stack([self.clip_processor(Image.fromarray(image)) for image in images]).to(self.device) + + image_features = self.clip_model.encode_image(images) + text_features = self.clip_model.encode_text(text) + logits_per_text = text_features @ image_features.T # this is the image-text similarity score + + probs = (logits_per_text / tau).softmax(dim=1)[0] + + probs = torch.log(probs) - torch.log(-torch.log(torch.rand(len(images), device=probs.device) + 1e-10) + 1e-10) # gumbel + + indices = np.argsort(-probs.cpu().detach().numpy()) + + return indices + + def load_video(self, video_path, num_frames, fps=1, force_sample=False): + if num_frames == 0: + return np.zeros((1, 336, 336, 3)) + vr = decord.VideoReader(video_path, ctx=decord.cpu(0), num_threads=1) + total_frame_num = len(vr) + video_time = total_frame_num / vr.get_avg_fps() + fps = round(vr.get_avg_fps() / fps) + frame_idx = [i for i in range(0, len(vr), fps)] + frame_time = [i / fps for i in frame_idx] + if len(frame_idx) > num_frames or force_sample: + sample_fps = num_frames + uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) + frame_idx = uniform_sampled_frames.tolist() + frame_time = [i / vr.get_avg_fps() for i in frame_idx] + frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) + spare_frames = vr.get_batch(frame_idx).asnumpy() + + video_metadata = dict( + fps=vr.get_avg_fps(), + frames_indices=frame_idx, + total_num_frames=len(vr),
+ video_backend="decord", + ) + + return spare_frames, frame_idx, frame_time, video_time, video_metadata + + def process_messages(self, chat_message, eval_logger): + messages = chat_message[0]["content"] + new_messages = [] + for i, message in enumerate(messages): + if message["type"] in ["video", "image"]: + visual = message["url"] + question = message["question"] + if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")): # Video file + # modify the video processing to multi-image processing + """ + vr = decord.VideoReader(visual) + first_frame = vr[0].asnumpy() + height, width = first_frame.shape[:2] + # max_pixels = height * width + message.append({"role": "user", "content": [{"type": "video", "video": visual, "max_pixels": self.max_pixels}, {"type": "text", "text": context}]}) + """ + visual, frame_idx, frame_time, video_time, video_metadata = self.load_video(visual, self.num_frames) + + try: + indices = self.text_image_matching(question, visual, tau=self.tau) + + visual = [Image.fromarray(v).convert("RGB") for v in visual] + if not self.baseline: + visual_tmp = [None] * len(visual) + width, height = visual[0].size + for idx in indices[:self.high_frames]: + visual_tmp[idx] = visual[idx].resize((width // 2, height // 2), Image.Resampling.LANCZOS) + for idx in indices[self.high_frames: self.high_frames+self.mid_frames]: + visual_tmp[idx] =visual[idx].resize((width // 4, height // 4), Image.Resampling.LANCZOS) + for idx in indices[self.high_frames+self.mid_frames: self.high_frames+self.mid_frames+self.low_frames]: + visual_tmp[idx] =visual[idx].resize((width // 8, height // 8), Image.Resampling.LANCZOS) + visual = [v for v in visual_tmp if v is not None ] + else: + visual_tmp = [None] * len(visual) + for idx in indices: + visual_tmp[idx] = visual[idx] + visual = [v for v in visual_tmp if v is not None ] + except Exception as e: + eval_logger.info(f"{e}") + if len(visual) >= self.sample_frames: + visual = visual[sorted(random.sample(range(len(visual)), self.sample_frames))] + height, width, _ = visual[0].shape + visual = [Image.fromarray(v).convert("RGB").resize((width // 2, height // 2), Image.Resampling.LANCZOS) for v in visual] + + image_content = [] + for base64_image in visual: + # base64_image = Image.fromarray(v).convert("RGB") + buffer = BytesIO() + base64_image.save(buffer, format="JPEG") + base64_bytes = base64.b64encode(buffer.getvalue()) + base64_string = base64_bytes.decode("utf-8") + new_messages.append({"type": "image", "url": f"data:image/jpeg;base64,{base64_string}"}) + + elif isinstance(visual, Image.Image): # Single image + base64_image = visual.convert("RGB") + buffer = BytesIO() + base64_image.save(buffer, format="JPEG") + base64_bytes = base64.b64encode(buffer.getvalue()) + base64_string = base64_bytes.decode("utf-8") + new_messages.append({"type": "image", "url": f"data:image/jpeg;base64,{base64_string}"}) + elif isinstance(visual, (list, tuple)) and all(isinstance(v, Image.Image) for v in visual): # Multiple images + for v in visual: + base64_image = v.convert("RGB") + buffer = BytesIO() + base64_image.save(buffer, format="JPEG") + base64_bytes = base64.b64encode(buffer.getvalue()) + base64_string = base64_bytes.decode("utf-8") + new_messages.append({"type": "image", "url": f"data:image/jpeg;base64,{base64_string}"}) + else: + raise ValueError(f"Invalid visual type: {type(visual)}") + else: + new_messages.append(message) + chat_message[0]["content"] = new_messages + return chat_message, video_metadata + + def sample( + self, + ele: Any, + **kwargs + 
) -> Tuple[List[int], int]: + if self.num_frames is not None: + ele["nframes"] = self.num_frames + nframes = smart_nframes(ele, total_frames=ele["total_frames"], video_fps=ele["video_fps"]) + idx = torch.linspace(ele["start_frame"], ele["end_frame"], nframes).round().long().tolist() + return idx + + # def load_video(self, ele, max_num_frames, fps=1, force_sample=False): + # if max_num_frames == 0: + # return np.zeros((1, 336, 336, 3)) + # vr = ele["video_reader"] + # total_frame_num = ele["total_frames"] + # video_fps = ele["video_fps"] + # video_time = total_frame_num / video_fps + # fps = round(video_fps / fps) + # frame_idx = [i for i in range(0, total_frame_num, fps)] + # frame_time = [i / fps for i in frame_idx] + # if len(frame_idx) > max_num_frames or force_sample: + # sample_fps = max_num_frames + # uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int) + # frame_idx = uniform_sampled_frames.tolist() + # frame_time = [i / video_fps for i in frame_idx] + # frame_time = ",".join([f"{i:.2f}s" for i in frame_time]) + # spare_frames = vr.get_batch(frame_idx).asnumpy() + # return spare_frames, frame_idx, frame_time, video_time + + + # def sample( + # self, + # ele: Any, + # **kwargs + # ) -> Tuple[List[int], int]: + # visual, frame_idx, frame_time, video_time = self.load_video(ele, self.max_num_frames) + + # indices = self.text_image_matching(ele["question"], visual, tau=self.tau) + + # if ele["video_reader_backend"] == "decord": + # T, H, W, C = visual.shape + # else: + # T, C, H, W = visual.shape + # visual_tmp = [None] * len(visual) + # for idx in indices[:self.high_frames]: + # visual_tmp[idx] = cv2.resize(visual[idx], (W//2, H//2), interpolation=cv2.INTER_LANCZOS4) + + # for idx in indices[self.high_frames:self.high_frames+self.mid_frames]: + # visual_tmp[idx] = cv2.resize(visual[idx], (W//4, H//4), interpolation=cv2.INTER_LANCZOS4) + + # for idx in indices[self.high_frames+self.mid_frames:self.high_frames+self.mid_frames+self.low_frames]: + # visual_tmp[idx] = cv2.resize(visual[idx], (W//8, H//8), interpolation=cv2.INTER_LANCZOS4) + + # # TODO DEBUG THIS + # patch_factor = int(image_patch_size * SPATIAL_MERGE_SIZE) + # for v in visual_tmp: + # if v is None: + # continue + # if "resized_height" in ele and "resized_width" in ele: + # resized_height, resized_width = smart_resize( + # ele["resized_height"], + # ele["resized_width"], + # factor=patch_factor, + # ) + # else: + # width, height = W, H + # min_pixels = ele.get("min_pixels", IMAGE_MIN_TOKEN_NUM * patch_factor ** 2) + # max_pixels = ele.get("max_pixels", IMAGE_MAX_TOKEN_NUM * patch_factor ** 2) + # resized_height, resized_width = smart_resize( + # height, + # width, + # factor=patch_factor, + # min_pixels=min_pixels, + # max_pixels=max_pixels, + # ) + # image = cv2.resize((resized_width, resized_height)) + # visual.append(image) + + # nframes = smart_nframes(ele, total_frames=ele["total_frames"], video_fps=ele["video_fps"]) + # idx = torch.linspace(ele["start_frame"], ele["end_frame"], nframes).round().long().tolist() + # return idx, visual + diff --git a/lmms_eval/video_samplers/uniform_sampler.py b/lmms_eval/video_samplers/uniform_sampler.py new file mode 100644 index 000000000..68615158b --- /dev/null +++ b/lmms_eval/video_samplers/uniform_sampler.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional + + +import torch + +from .base import BaseVideoSampler +from . 
import register_video_sampler + +from qwen_vl_utils.vision_process import smart_nframes + +@register_video_sampler("uniform") +class UniformVideoSampler(BaseVideoSampler): + """Sampler that samples frames uniformly from a video.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = self.__class__.__name__ + self.num_frames = kwargs.get("num_frames", None) + + def sample( + self, + ele: Any, + **kwargs + ) -> list[int]: + if self.num_frames is not None: + ele["nframes"] = self.num_frames + nframes = smart_nframes(ele, total_frames=ele["total_frames"], video_fps=ele["video_fps"]) + idx = torch.linspace(ele["start_frame"], ele["end_frame"], nframes).round().long().tolist() + return idx +
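Usage sketch (illustrative, not part of the patch): how a registered sampler is resolved and invoked outside the evaluator. The `ele` values below are made-up placeholders for what the model's video-loading code would normally supply, and the import only succeeds if the optional external dependencies pulled in by the auto-imported sampler modules (lavis, longclip, D-KTS) are importable.

from lmms_eval.video_samplers import get_video_sampler_cls

# Look up the class registered under "uniform" and build it from a
# key1=value1,key2=value2 string, the format create_from_arg_string expects.
sampler_cls = get_video_sampler_cls("uniform")
sampler = sampler_cls.create_from_arg_string("num_frames=32")

# Keys mirror what UniformVideoSampler.sample() reads; the values are illustrative.
ele = {
    "video": "example.mp4",  # hypothetical path, only used by some samplers
    "total_frames": 3000,
    "video_fps": 30.0,
    "start_frame": 0,
    "end_frame": 2999,
}
frame_indices = sampler(ele)  # BaseVideoSampler.__call__ forwards to .sample()
print(len(frame_indices), frame_indices[:4])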