diff --git a/eval_mm/mmbench/MMBENCH.md b/eval_mm/mmbench/MMBENCH.md new file mode 100644 index 0000000..1a87cad --- /dev/null +++ b/eval_mm/mmbench/MMBENCH.md @@ -0,0 +1,61 @@ +# MMBench Evaluation + +## Data + +```bash +/cpfs01/shared/public/shusheng.yss/workspace/23082502_qwenvl_eval_test/eval_mm/data/mmbench +``` + +## Dev + +```bash +checkpoint=/PATH/TO/CHECKPOINT +ds=mmbench_dev_20230712 +python -m torch.distributed.launch --use-env \ + --nproc_per_node ${NPROC_PER_NODE:-8} \ + --nnodes ${WORLD_SIZE:-1} \ + --node_rank ${RANK:-0} \ + --master_addr ${MASTER_ADDR:-127.0.0.1} \ + --master_port ${MASTER_PORT:-12345} \ + evaluate_multiple_choice_mmbench.py \ + --checkpoint $checkpoint \ + --dataset $ds \ + --batch-size 2 \ + --num-workers 2 + +# the results will be saved to mmbench_dev_20230712.json + +# without consistency constrain + +python mmbench_evaluation.py + +# with consistency constrain + +python mmbench_evaluation_tricky.py + +``` + +## Test + +```bash +checkpoint=/PATH/TO/CHECKPOINT +ds=mmbench_test_20230712 +python -m torch.distributed.launch --use-env \ + --nproc_per_node ${NPROC_PER_NODE:-8} \ + --nnodes ${WORLD_SIZE:-1} \ + --node_rank ${RANK:-0} \ + --master_addr ${MASTER_ADDR:-127.0.0.1} \ + --master_port ${MASTER_PORT:-12345} \ + evaluate_multiple_choice_mmbench.py \ + --checkpoint $checkpoint \ + --dataset $ds \ + --batch-size 2 \ + --num-workers 2 + +# the results will be saved to mmbench_test_20230712.json + +# convert to submission format with consistency constrain + +python mmbench_predict_to_submission.py + +``` diff --git a/eval_mm/mmbench/evaluate_multiple_choice_mmbench.py b/eval_mm/mmbench/evaluate_multiple_choice_mmbench.py new file mode 100644 index 0000000..5ac4a56 --- /dev/null +++ b/eval_mm/mmbench/evaluate_multiple_choice_mmbench.py @@ -0,0 +1,189 @@ +import argparse +import itertools +import json +import os +from functools import partial + +import torch +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +multiple_choices = ['A', 'B', 'C', 'D', 'E'] + +ds_collections = { + 'mmbench_dev_20230712': { + 'test': 'data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.jsonl', + }, + 'mmbench_test_20230712': { + 'test': 'data/mmbench/mmbench_test_20230712/mmbench_test_20230712.jsonl', + } +} + +def collate_fn(batches, pad_token_id): + + indexes = [_['index'] for _ in batches] + + input_tokens = [_['input_tokens'] for _ in batches] + target_lengths = [_['target_lengths'] for _ in batches] + + chunk_sizes = [len(_) for _ in input_tokens] + + input_tokens = [_ for _ in itertools.chain.from_iterable(input_tokens)] + + max_lengths = max([len(_) for _ in input_tokens]) + input_tokens = [[pad_token_id] * (max_lengths - len(_)) + _ + for _ in input_tokens] + input_tokens = torch.LongTensor(input_tokens) + + attention_mask = 1 - input_tokens.eq(pad_token_id).float() + + return input_tokens, attention_mask, target_lengths, chunk_sizes, indexes + + +class MultipleChoiceDataste(torch.utils.data.Dataset): + + def __init__(self, test, prompt, tokenizer): + self.datas = open(test).readlines() + self.prompt = prompt + self.tokenizer = tokenizer + + def __len__(self): + return len(self.datas) + + def __getitem__(self, idx): + + data = json.loads(self.datas[idx].strip()) + index = data['index'] + image = data['image'] + hint = data['hint'] if data['hint'] else 'N/A' + question = data['question'] + + choices = data['choices'] + choice_list = [] + for i, c in enumerate(choices): + choice_list.append('{}. {}'.format(multiple_choices[i], c)) + choice_txt = '\n'.join(choice_list) + + prompt = self.prompt.format(image, hint, question, choice_txt) + + prompt_tokens = self.tokenizer(prompt).input_ids + target_tokens = [ + self.tokenizer(' ' + _).input_ids + for _ in multiple_choices[:len(choices)] + ] + + return { + 'index': index, + 'input_tokens': [prompt_tokens + _ for _ in target_tokens], + 'target_lengths': [len(_) for _ in target_tokens], + # 'answer': data['answer'], + } + + +class InferenceSampler(torch.utils.data.sampler.Sampler): + + def __init__(self, size): + self._size = int(size) + assert size > 0 + self._rank = torch.distributed.get_rank() + self._world_size = torch.distributed.get_world_size() + self._local_indices = self._get_local_indices(size, self._world_size, + self._rank) + + @staticmethod + def _get_local_indices(total_size, world_size, rank): + shard_size = total_size // world_size + left = total_size % world_size + shard_sizes = [shard_size + int(r < left) for r in range(world_size)] + + begin = sum(shard_sizes[:rank]) + end = min(sum(shard_sizes[:rank + 1]), total_size) + return range(begin, end) + + def __iter__(self): + yield from self._local_indices + + def __len__(self): + return len(self._local_indices) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--checkpoint', type=str, default='') + parser.add_argument('--dataset', type=str, default='') + parser.add_argument('--batch-size', type=int, default=1) + parser.add_argument('--num-workers', type=int, default=1) + args = parser.parse_args() + + torch.distributed.init_process_group( + backend='nccl', + world_size=int(os.getenv('WORLD_SIZE', '1')), + rank=int(os.getenv('RANK', '0')), + ) + + torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) + + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint, device_map='cuda', trust_remote_code=True).eval() + + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, + trust_remote_code=True) + + prompt = '{}Context: {}\nQuestion: {}\nOptions: {}\nAnswer:' + + dataset = MultipleChoiceDataste(test=ds_collections[args.dataset]['test'], + prompt=prompt, + tokenizer=tokenizer) + dataloader = torch.utils.data.DataLoader( + dataset=dataset, + sampler=InferenceSampler(len(dataset)), + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=True, + drop_last=False, + collate_fn=partial(collate_fn, pad_token_id=tokenizer.eod_id), + ) + + results = [] + with torch.no_grad(): + for _, (input_tokens, attention_mask, target_lengths, + chunk_sizes, indexes) in tqdm(enumerate(dataloader)): + + outputs = model( + input_ids=input_tokens[:, :-1].cuda(), + attention_mask=attention_mask[:, :-1].cuda(), + return_dict=True, + ) + losses = torch.nn.functional.cross_entropy(outputs.logits.permute( + 0, 2, 1), + input_tokens[:, + 1:].cuda(), + reduction='none') + + losses = losses.split(chunk_sizes, dim=0) + + for loss, target_length, index in zip(losses, target_lengths, indexes): + + target_loss = loss.mean(-1) + for _ in range(len(target_length)): + target_loss[_] = loss[_, -target_length[_]:].mean() + pred = target_loss.argmin().item() + + results.append({ + "index": index, + "prediction": pred, + }) + + torch.distributed.barrier() + + world_size = torch.distributed.get_world_size() + merged_results = [None for _ in range(world_size)] + torch.distributed.all_gather_object(merged_results, results) + + merged_results = [_ for _ in itertools.chain.from_iterable(merged_results)] + + if torch.distributed.get_rank() == 0: + json.dump(merged_results, open(f"{args.dataset}.json", "w")) + + torch.distributed.barrier() diff --git a/eval_mm/mmbench/mmbench_converter_dev.py b/eval_mm/mmbench/mmbench_converter_dev.py new file mode 100644 index 0000000..a1eb9c5 --- /dev/null +++ b/eval_mm/mmbench/mmbench_converter_dev.py @@ -0,0 +1,48 @@ +import pandas as pd +import io +import base64 +import json +from PIL import Image + +''' +This scripts convert mmbench_dev tsv file to jsonl +''' + +datas = pd.read_csv("data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.tsv", sep='\t') + +global_choices = ['A', 'B', 'C', 'D'] + +def decode_base64_to_image(base64_string): + image_data = base64.b64decode(base64_string) + image = Image.open(io.BytesIO(image_data)) + return image + + +with open('./data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.jsonl', 'w') as f: + for idx in range(len(datas)): + data = datas.iloc[idx] + + index = int(data['index']) + question = data['question'] + hint = data['hint'] if not pd.isna(data['hint']) else 'N/A' + + choices = [] + for opt in global_choices: + if pd.isna(data[opt]): + continue + choices.append(data[opt]) + + answer = global_choices.index(data['answer']) + + image = decode_base64_to_image(data['image']) + image.save("data/mmbench/mmbench_dev_20230712/images/%d.jpg" % index) + + f.write(json.dumps({ + "index": index, + "image": "data/mmbench/mmbench_dev_20230712/images/%d.jpg" % index, + "hint": hint, + "question": question, + "choices": choices, + "answer": answer, + }) + "\n") + diff --git a/eval_mm/mmbench/mmbench_converter_test.py b/eval_mm/mmbench/mmbench_converter_test.py new file mode 100644 index 0000000..894e766 --- /dev/null +++ b/eval_mm/mmbench/mmbench_converter_test.py @@ -0,0 +1,49 @@ +import pandas as pd +import io +import base64 +import json +from PIL import Image + +''' +This script convert mmbench_test tsv file to jsonl +This script is very similar to mmbench_converter_dev except there's no answer for accuracy calculation +''' + +datas = pd.read_csv("data/mmbench/mmbench_test_20230712/mmbench_test_20230712.tsv", sep='\t') + +global_choices = ['A', 'B', 'C', 'D'] + +def decode_base64_to_image(base64_string): + image_data = base64.b64decode(base64_string) + image = Image.open(io.BytesIO(image_data)) + return image + + +with open('./data/mmbench/mmbench_test_20230712/mmbench_test_20230712.jsonl', 'w') as f: + for idx in range(len(datas)): + data = datas.iloc[idx] + + index = int(data['index']) + question = data['question'] + hint = data['hint'] if not pd.isna(data['hint']) else 'N/A' + + choices = [] + for opt in global_choices: + if pd.isna(data[opt]): + continue + choices.append(data[opt]) + + # answer = global_choices.index(data['answer']) + + image = decode_base64_to_image(data['image']) + image.save("data/mmbench/mmbench_test_20230712/images/%d.jpg" % index) + + f.write(json.dumps({ + "index": index, + "image": "data/mmbench/mmbench_test_20230712/images/%d.jpg" % index, + "hint": hint, + "question": question, + "choices": choices, + # "answer": answer, + }) + "\n") + diff --git a/eval_mm/mmbench/mmbench_evaluation.py b/eval_mm/mmbench/mmbench_evaluation.py new file mode 100644 index 0000000..c753e2f --- /dev/null +++ b/eval_mm/mmbench/mmbench_evaluation.py @@ -0,0 +1,39 @@ +import pandas as pd +import json + +''' +This script provides `global top-1 accuracy` metric calculation for mmbench_dev. +''' + +predictions = json.load(open('mmbench_dev_20230712.json')) + +index2predictions = {} +for pred in predictions: + index2predictions[pred['index']] = pred['prediction'] + +datas = pd.read_csv("data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.tsv", sep='\t') + +glb_opts = ['A', 'B', 'C', 'D'] +index2answer = {} +for idx in range(len(datas)): + data = datas.iloc[idx] + index2answer[data['index']] = glb_opts.index(data['answer']) + +identity_indexes = list(set([int(_ % 1e6) for _ in index2predictions.keys()])) + +correct = 0 +total = 0 +for index in identity_indexes: + for _ in range(4): + cycle_index = int(_ * 1e6 + index) + if index2predictions.get(cycle_index, None) is not None: + if index2predictions[cycle_index] == index2answer[cycle_index]: + continue + else: + print(cycle_index) + break + else: + correct += 1 + total += 1 + +print(correct, total) diff --git a/eval_mm/mmbench/mmbench_evaluation_tricky.py b/eval_mm/mmbench/mmbench_evaluation_tricky.py new file mode 100644 index 0000000..237da51 --- /dev/null +++ b/eval_mm/mmbench/mmbench_evaluation_tricky.py @@ -0,0 +1,66 @@ +import pandas as pd +import json +import random + +''' +This script provides metric calculation for mmbench_dev with the same accuarcy algo as OpenCompass server +''' + +predictions = json.load(open('mmbench_dev_20230712.json')) + +index2predictions = {} +for pred in predictions: + index2predictions[pred['index']] = pred['prediction'] + + +from collections import Counter + +def most_common_elements(lst): + counter = Counter(lst) + max_count = max(counter.values()) + most_common = [element for element, count in counter.items() if count == max_count] + return random.choice(most_common) # random sample from random choice + +datas = pd.read_csv("data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.tsv", sep='\t') + +glb_opts = ['A', 'B', 'C', 'D'] +index2answer = {} +index2choices = {} +index2rawanswer = {} +for idx in range(len(datas)): + data = datas.iloc[idx] + + choices = [] + for opt in glb_opts: + if not pd.isna(data[opt]): + choices.append(data[opt]) + index2choices[data['index']] = choices + + index2answer[data['index']] = glb_opts.index(data['answer']) + index2rawanswer[data['index']] = choices[glb_opts.index(data['answer'])] + +identity_indexes = list(set([int(_ % 1e6) for _ in index2predictions.keys()])) + +correct = 0 +total = 0 +for index in identity_indexes: + raw_preds = [] + raw_answer = [] + for _ in range(4): + cycle_index = int(_ * 1e6 + index) + if index2predictions.get(cycle_index, None) is not None: + raw_answer = index2rawanswer[cycle_index] + raw_pred = index2choices[cycle_index][index2predictions[cycle_index]] + raw_preds.append(raw_pred) + + if len(set(raw_preds)) == 1: + if raw_preds[0] == raw_answer: + correct += 1 + else: + result = most_common_elements(raw_preds) + if result == raw_answer: + correct += 1 + + total += 1 + +print(correct, total, correct / total * 100.) diff --git a/eval_mm/mmbench/mmbench_predict_to_submission.py b/eval_mm/mmbench/mmbench_predict_to_submission.py new file mode 100644 index 0000000..baa0db8 --- /dev/null +++ b/eval_mm/mmbench/mmbench_predict_to_submission.py @@ -0,0 +1,73 @@ +import pandas as pd +import json +import random + +''' +This script convert the output file of our inference processor to target formation of OpenCompass evaluator server +''' + +predictions = json.load(open('mmbench_test_20230712.json')) + +index2predictions = {} +for pred in predictions: + index2predictions[pred['index']] = pred['prediction'] + +from collections import Counter + +def most_common_elements(lst): + counter = Counter(lst) + max_count = max(counter.values()) + most_common = [element for element, count in counter.items() if count == max_count] + print(most_common) + return random.choice(most_common) + # return most_common + +datas = pd.read_csv("data/mmbench/mmbench_test_20230712/mmbench_test_20230712.tsv", sep='\t') + +datas = datas.drop('image', axis=1) + +glb_opts = ['A', 'B', 'C', 'D'] +index2choices = {} +for idx in range(len(datas)): + data = datas.iloc[idx] + + choices = [] + for opt in glb_opts: + if not pd.isna(data[opt]): + choices.append(data[opt]) + index2choices[data['index']] = choices + +identity_indexes = list(set([int(_ % 1e6) for _ in index2predictions.keys()])) + + +processed_index2predictions = {} +for index in identity_indexes: + raw_preds = [] + for _ in range(4): + cycle_index = int(_ * 1e6 + index) + if index2predictions.get(cycle_index, None) is not None: + raw_pred = index2choices[cycle_index][index2predictions[cycle_index]] + raw_preds.append(raw_pred) + + if len(set(raw_preds)) == 1: + pred_answer = raw_preds[0] + else: + pred_answer = most_common_elements(raw_preds) + + print(index, pred_answer) + for _ in range(4): + cycle_index = int(_ * 1e6 + index) + if index2predictions.get(cycle_index, None) is not None: + processed_index2predictions[cycle_index] = index2choices[cycle_index].index(pred_answer) + + +predictions = [] +for idx in range(len(datas)): + data = datas.iloc[idx] + index = data['index'] + prediction = glb_opts[processed_index2predictions[index]] + predictions.append(prediction) + +datas['prediction'] = predictions +datas.to_excel("mmbench_test_20230712_230831_constrained.xlsx", index=False) +# constrained means we force the model predict same answer when tested on a question for multiple times