add mmbench codes
vealocia committed Sep 6, 2023
1 parent f695b79 commit b553964
Showing 7 changed files with 525 additions and 0 deletions.
61 changes: 61 additions & 0 deletions eval_mm/mmbench/MMBENCH.md
@@ -0,0 +1,61 @@
# MMBench Evaluation

## Data

```bash
/cpfs01/shared/public/shusheng.yss/workspace/23082502_qwenvl_eval_test/eval_mm/data/mmbench
```

## Dev

```bash
checkpoint=/PATH/TO/CHECKPOINT
ds=mmbench_dev_20230712
python -m torch.distributed.launch --use-env \
--nproc_per_node ${NPROC_PER_NODE:-8} \
--nnodes ${WORLD_SIZE:-1} \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT:-12345} \
evaluate_multiple_choice_mmbench.py \
--checkpoint $checkpoint \
--dataset $ds \
--batch-size 2 \
--num-workers 2

# The results will be saved to mmbench_dev_20230712.json

# without the consistency constraint

python mmbench_evaluation.py

# with the consistency constraint

python mmbench_evaluation_tricky.py

```
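The saved file is a flat JSON list with one entry per (circularly shifted) question, where `prediction` is the zero-based index of the option the model assigned the lowest loss. An illustrative excerpt (values made up):

```json
[{"index": 42, "prediction": 0}, {"index": 1000042, "prediction": 2}]
```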

## Test

```bash
checkpoint=/PATH/TO/CHECKPOINT
ds=mmbench_test_20230712
python -m torch.distributed.launch --use-env \
--nproc_per_node ${NPROC_PER_NODE:-8} \
--nnodes ${WORLD_SIZE:-1} \
--node_rank ${RANK:-0} \
--master_addr ${MASTER_ADDR:-127.0.0.1} \
--master_port ${MASTER_PORT:-12345} \
evaluate_multiple_choice_mmbench.py \
--checkpoint $checkpoint \
--dataset $ds \
--batch-size 2 \
--num-workers 2

# The results will be saved to mmbench_test_20230712.json

# convert to submission format with the consistency constraint

python mmbench_predict_to_submission.py

```
189 changes: 189 additions & 0 deletions eval_mm/mmbench/evaluate_multiple_choice_mmbench.py
@@ -0,0 +1,189 @@
import argparse
import itertools
import json
import os
from functools import partial

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

multiple_choices = ['A', 'B', 'C', 'D', 'E']

ds_collections = {
    'mmbench_dev_20230712': {
        'test': 'data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.jsonl',
    },
    'mmbench_test_20230712': {
        'test': 'data/mmbench/mmbench_test_20230712/mmbench_test_20230712.jsonl',
    }
}

def collate_fn(batches, pad_token_id):

    indexes = [_['index'] for _ in batches]

    input_tokens = [_['input_tokens'] for _ in batches]
    target_lengths = [_['target_lengths'] for _ in batches]

    # number of candidate options per sample, used to split the flat batch back
    chunk_sizes = [len(_) for _ in input_tokens]

    # flatten the per-sample option sequences into one big batch
    input_tokens = [_ for _ in itertools.chain.from_iterable(input_tokens)]

    max_lengths = max([len(_) for _ in input_tokens])
    # left-pad every sequence to the longest one in the batch
    input_tokens = [[pad_token_id] * (max_lengths - len(_)) + _
                    for _ in input_tokens]
    input_tokens = torch.LongTensor(input_tokens)

    attention_mask = 1 - input_tokens.eq(pad_token_id).float()

    return input_tokens, attention_mask, target_lengths, chunk_sizes, indexes


class MultipleChoiceDataset(torch.utils.data.Dataset):

    def __init__(self, test, prompt, tokenizer):
        self.datas = open(test).readlines()
        self.prompt = prompt
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):

        data = json.loads(self.datas[idx].strip())
        index = data['index']
        image = data['image']
        hint = data['hint'] if data['hint'] else 'N/A'
        question = data['question']

        choices = data['choices']
        choice_list = []
        for i, c in enumerate(choices):
            choice_list.append('{}. {}'.format(multiple_choices[i], c))
        choice_txt = '\n'.join(choice_list)

        prompt = self.prompt.format(image, hint, question, choice_txt)

        prompt_tokens = self.tokenizer(prompt).input_ids
        target_tokens = [
            self.tokenizer(' ' + _).input_ids
            for _ in multiple_choices[:len(choices)]
        ]

        return {
            'index': index,
            'input_tokens': [prompt_tokens + _ for _ in target_tokens],
            'target_lengths': [len(_) for _ in target_tokens],
            # 'answer': data['answer'],
        }
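# For illustration only: with the prompt template defined in __main__ below,
# a formatted item looks like (hypothetical values)
#
#   <img>data/mmbench/.../images/42.jpg</img>Context: N/A
#   Question: What animal is shown in the picture?
#   Options: A. cat
#   B. dog
#   Answer:
#
# and the candidate continuations scored against it are ' A' and ' B'.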


class InferenceSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size,
                                                      self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
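# Example of the sharding arithmetic: with total_size=10 and world_size=4 the
# shard sizes are [3, 3, 2, 2], so rank 0 evaluates indices 0-2, rank 1
# indices 3-5, rank 2 indices 6-7 and rank 3 indices 8-9.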


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='')
    parser.add_argument('--dataset', type=str, default='')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=1)
    args = parser.parse_args()

    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )

    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint, device_map='cuda', trust_remote_code=True).eval()

    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint,
                                              trust_remote_code=True)

    prompt = '<img>{}</img>Context: {}\nQuestion: {}\nOptions: {}\nAnswer:'

    dataset = MultipleChoiceDataset(test=ds_collections[args.dataset]['test'],
                                    prompt=prompt,
                                    tokenizer=tokenizer)
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(collate_fn, pad_token_id=tokenizer.eod_id),
    )

    results = []
    with torch.no_grad():
        for _, (input_tokens, attention_mask, target_lengths, chunk_sizes,
                indexes) in tqdm(enumerate(dataloader)):

            outputs = model(
                input_ids=input_tokens[:, :-1].cuda(),
                attention_mask=attention_mask[:, :-1].cuda(),
                return_dict=True,
            )
            # per-token cross-entropy between the shifted logits and targets
            losses = torch.nn.functional.cross_entropy(
                outputs.logits.permute(0, 2, 1),
                input_tokens[:, 1:].cuda(),
                reduction='none')

            # split the flat batch back into per-sample chunks of options
            losses = losses.split(chunk_sizes, dim=0)

            for loss, target_length, index in zip(losses, target_lengths,
                                                  indexes):

                target_loss = loss.mean(-1)
                for _ in range(len(target_length)):
                    # score each option by the loss on its answer tokens only
                    target_loss[_] = loss[_, -target_length[_]:].mean()
                pred = target_loss.argmin().item()

                results.append({
                    "index": index,
                    "prediction": pred,
                })

    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_results = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_results, results)

    merged_results = [_ for _ in itertools.chain.from_iterable(merged_results)]

    if torch.distributed.get_rank() == 0:
        json.dump(merged_results, open(f"{args.dataset}.json", "w"))

    torch.distributed.barrier()
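The core of the script is a likelihood-ranking rule: every option letter is appended to the prompt, and the option whose answer tokens receive the lowest cross-entropy loss wins. A minimal single-GPU sketch of that rule, assuming `model` and `tokenizer` are loaded as above and `prompt` is the already formatted question text:

```python
import torch


def rank_choices(model, tokenizer, prompt, num_choices):
    letters = ['A', 'B', 'C', 'D', 'E'][:num_choices]
    prompt_ids = tokenizer(prompt).input_ids
    scores = []
    for letter in letters:
        target_ids = tokenizer(' ' + letter).input_ids
        ids = torch.LongTensor([prompt_ids + target_ids]).cuda()
        with torch.no_grad():
            logits = model(input_ids=ids[:, :-1], return_dict=True).logits
        # per-token cross-entropy against the shifted targets
        loss = torch.nn.functional.cross_entropy(
            logits.permute(0, 2, 1), ids[:, 1:], reduction='none')
        # only the loss on the answer token(s) decides the ranking
        scores.append(loss[0, -len(target_ids):].mean().item())
    return min(range(num_choices), key=scores.__getitem__)
```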
48 changes: 48 additions & 0 deletions eval_mm/mmbench/mmbench_converter_dev.py
@@ -0,0 +1,48 @@
import pandas as pd
import io
import base64
import json
import os
from PIL import Image

'''
This script converts the mmbench_dev TSV file to JSONL.
'''

datas = pd.read_csv("data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.tsv", sep='\t')

global_choices = ['A', 'B', 'C', 'D']

def decode_base64_to_image(base64_string):
    image_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(image_data))
    return image


# make sure the image output directory exists before saving
os.makedirs('data/mmbench/mmbench_dev_20230712/images', exist_ok=True)

with open('./data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.jsonl', 'w') as f:
    for idx in range(len(datas)):
        data = datas.iloc[idx]

        index = int(data['index'])
        question = data['question']
        hint = data['hint'] if not pd.isna(data['hint']) else 'N/A'

        choices = []
        for opt in global_choices:
            if pd.isna(data[opt]):
                continue
            choices.append(data[opt])

        answer = global_choices.index(data['answer'])

        image = decode_base64_to_image(data['image'])
        image.save("data/mmbench/mmbench_dev_20230712/images/%d.jpg" % index)

        f.write(json.dumps({
            "index": index,
            "image": "data/mmbench/mmbench_dev_20230712/images/%d.jpg" % index,
            "hint": hint,
            "question": question,
            "choices": choices,
            "answer": answer,
        }) + "\n")

49 changes: 49 additions & 0 deletions eval_mm/mmbench/mmbench_converter_test.py
@@ -0,0 +1,49 @@
import pandas as pd
import io
import base64
import json
import os
from PIL import Image

'''
This script converts the mmbench_test TSV file to JSONL.
It is almost identical to mmbench_converter_dev.py, except that the test
split carries no answers for accuracy calculation, so no answer field is
written.
'''

datas = pd.read_csv("data/mmbench/mmbench_test_20230712/mmbench_test_20230712.tsv", sep='\t')

global_choices = ['A', 'B', 'C', 'D']

def decode_base64_to_image(base64_string):
    image_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(image_data))
    return image


# make sure the image output directory exists before saving
os.makedirs('data/mmbench/mmbench_test_20230712/images', exist_ok=True)

with open('./data/mmbench/mmbench_test_20230712/mmbench_test_20230712.jsonl', 'w') as f:
    for idx in range(len(datas)):
        data = datas.iloc[idx]

        index = int(data['index'])
        question = data['question']
        hint = data['hint'] if not pd.isna(data['hint']) else 'N/A'

        choices = []
        for opt in global_choices:
            if pd.isna(data[opt]):
                continue
            choices.append(data[opt])

        # the test split has no 'answer' column
        # answer = global_choices.index(data['answer'])

        image = decode_base64_to_image(data['image'])
        image.save("data/mmbench/mmbench_test_20230712/images/%d.jpg" % index)

        f.write(json.dumps({
            "index": index,
            "image": "data/mmbench/mmbench_test_20230712/images/%d.jpg" % index,
            "hint": hint,
            "question": question,
            "choices": choices,
            # "answer": answer,
        }) + "\n")

39 changes: 39 additions & 0 deletions eval_mm/mmbench/mmbench_evaluation.py
@@ -0,0 +1,39 @@
import pandas as pd
import json

'''
This script provides the `global top-1 accuracy` metric calculation for mmbench_dev.
'''

predictions = json.load(open('mmbench_dev_20230712.json'))

index2predictions = {}
for pred in predictions:
    index2predictions[pred['index']] = pred['prediction']

datas = pd.read_csv("data/mmbench/mmbench_dev_20230712/mmbench_dev_20230712.tsv", sep='\t')

glb_opts = ['A', 'B', 'C', 'D']
index2answer = {}
for idx in range(len(datas)):
    data = datas.iloc[idx]
    index2answer[data['index']] = glb_opts.index(data['answer'])

identity_indexes = list(set([int(_ % 1e6) for _ in index2predictions.keys()]))

correct = 0
total = 0
for index in identity_indexes:
    # an identity counts as correct only if every available circular
    # variant of the question is predicted correctly
    for _ in range(4):
        cycle_index = int(_ * 1e6 + index)
        if index2predictions.get(cycle_index, None) is not None:
            if index2predictions[cycle_index] == index2answer[cycle_index]:
                continue
            else:
                print(cycle_index)
                break
    else:
        correct += 1
    total += 1

print(correct, total)
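The `% 1e6` / `* 1e6` arithmetic reflects the index convention this script assumes: circularly rotated copies of a question share an identity and are offset by multiples of 1e6. A small illustration (values hypothetical):

```python
index = 2000042                    # as stored in the predictions file
identity = int(index % 1e6)        # -> 42, shared by all rotations
cycle1 = int(1 * 1e6 + identity)   # -> 1000042, the cycle-1 variant
```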
