diff --git a/requirements.txt b/requirements.txt
index cfec36885..32f09ab51 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+accelerate
 einops
 gradio
 huggingface_hub
diff --git a/vlmeval/smp/vlm.py b/vlmeval/smp/vlm.py
index b01fb73f4..c65ea239e 100644
--- a/vlmeval/smp/vlm.py
+++ b/vlmeval/smp/vlm.py
@@ -8,6 +8,10 @@ import base64
 from PIL import Image
 import sys
 
+import torch
+from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model
+from .misc import get_rank_and_world_size
+
 Image.MAX_IMAGE_PIXELS = 1e9
 
 
@@ -179,3 +183,44 @@ def circular_pred(df, extract_func=None):
     flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
     flags = list(flag_map.values())
     return np.mean(flags)
+
+
+def get_memory():
+    total_memory = torch.cuda.get_device_properties(0).total_memory
+    total_mem = total_memory / 1024 / 1024 / 1024
+    return total_mem
+
+
+def build_device_map(model, default_map=None, no_split=None, alpha=0.97, beta=0.9):
+    total_num_gpus = torch.cuda.device_count()
+    rank, world_size = get_rank_and_world_size()
+    if world_size == total_num_gpus:
+        return model.cuda(), None
+
+    num_gpus = total_num_gpus // world_size
+    memory_map = {}
+    per_gpu_mem = get_memory() * alpha
+    memory_map.update({rank: f'{beta * per_gpu_mem:.2f}GiB'})
+    for gpu_id in range(1, num_gpus):
+        memory_map.update({rank + gpu_id * world_size: f'{per_gpu_mem:.2f}GiB'})
+    if hasattr(model, '_no_split_modules'):
+        no_split_module = model._no_split_modules
+    else:
+        no_split_module = []
+    if no_split is not None:
+        no_split_module = list(set(no_split_module + no_split))
+    device_map = infer_auto_device_map(
+        model,
+        max_memory=memory_map,
+        no_split_module_classes=no_split_module
+    )
+    if default_map is not None:
+        for i in default_map:
+            device_map[i] = rank
+    for value in device_map.values():
+        assert value != 'disk', 'Please check and make sure to have enough memory to load the model.'
+
+    model = dispatch_model(
+        model,
+        device_map=device_map).eval()
+    return model, device_map
diff --git a/vlmeval/vlm/cogvlm.py b/vlmeval/vlm/cogvlm.py
index d5d1ece94..ddec88960 100644
--- a/vlmeval/vlm/cogvlm.py
+++ b/vlmeval/vlm/cogvlm.py
@@ -27,9 +27,9 @@ def __init__(self, model_path='THUDM/glm-4v-9b', **kwargs):
         self.end_text_token = '<|endoftext|>'
 
     def generate_inner(self, message, dataset=None):
-        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
+        prompt, image_path = self.message_to_promptimg(message)
         image = Image.open(image_path).convert('RGB')
-        if dataset is not None and DATASET_TYPE(dataset) in ['MCQ', 'Y/N']:
+        if dataset is not None and DATASET_TYPE(dataset) in ['multi-choice', 'Y/N']:
             prompt += '\nShort Answer.'
         inputs = self.tokenizer.apply_chat_template(
             [{'role': 'user', 'image': image, 'content': prompt}],
@@ -51,11 +51,16 @@ class CogVlm(BaseModel):
 
     def __init__(self, model_path='THUDM/cogvlm2-llama3-chat-19B', tokenizer_name=None, **kwargs):
         assert model_path is not None
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to('cuda').eval()
+        from accelerate import init_empty_weights
+
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path,
+                torch_dtype=torch.bfloat16,
+                low_cpu_mem_usage=True,
+                trust_remote_code=True,
+            )
+        model, _ = build_device_map(model)
 
         self.kwargs = kwargs
         if tokenizer_name:
diff --git a/vlmeval/vlm/emu.py b/vlmeval/vlm/emu.py
index 1051c799b..0088fe5f8 100644
--- a/vlmeval/vlm/emu.py
+++ b/vlmeval/vlm/emu.py
@@ -21,18 +21,12 @@ def __init__(self,
         from transformers import AutoModelForCausalLM, AutoTokenizer
         from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model
 
-        local_rank = os.environ.get('LOCAL_RANK', 0)
+        local_rank = int(os.environ.get('LOCAL_RANK', 0))
 
         device_num = torch.cuda.device_count()
         assert local_rank * 2 <= device_num, 'The number of devices does not match the world size'
         assert device_num >= 2, 'You need at least 2 GPUs to use EMU'
 
-        device_1 = local_rank
-        device_2 = local_rank + device_num // 2
-
-        torch.cuda.set_device(device_1)
-        torch.cuda.set_device(device_2)
-
         tokenizer = AutoTokenizer.from_pretrained(model_path) # "BAAI/Emu2-Chat"
         self.tokenizer = tokenizer
         with init_empty_weights():
@@ -42,20 +36,9 @@ def __init__(self,
                 low_cpu_mem_usage=True,
                 trust_remote_code=True)
 
-        device_map = infer_auto_device_map(
-            model,
-            max_memory={
-                device_1: '38GiB',
-                device_2: '38GiB'
-            },
-            no_split_module_classes=['Block', 'LlamaDecoderLayer'])
-
-        # input and output logits should be on same device
-        device_map['model.decoder.lm.lm_head'] = device_1
-
-        model = dispatch_model(
-            model,
-            device_map=device_map).eval()
+        no_split = ['Block', 'LlamaDecoderLayer']
+        default_map = ['model.decoder.lm.lm_head']
+        model, _ = build_device_map(model, default_map, no_split)
 
         self.model = model
         kwargs_default = dict(max_new_tokens=512, length_penalty=-1)
diff --git a/vlmeval/vlm/internvl_chat.py b/vlmeval/vlm/internvl_chat.py
index 563714357..c3ed34358 100644
--- a/vlmeval/vlm/internvl_chat.py
+++ b/vlmeval/vlm/internvl_chat.py
@@ -137,26 +137,28 @@ def __init__(self, model_path='OpenGVLab/InternVL-Chat-V1-5', load_in_8bit=False
         self.model_path = model_path
         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
 
-        if listinstr(['InternVL2-Llama3-76B'], model_path):
-            device_map = split_model(model_path.split('/')[1])
-            self.model = AutoModel.from_pretrained(
+        if not load_in_8bit:
+            model = AutoModel.from_pretrained(
                 model_path,
                 torch_dtype=torch.bfloat16,
                 load_in_8bit=load_in_8bit,
                 trust_remote_code=True,
                 low_cpu_mem_usage=True,
-                device_map=device_map).eval()
+                device_map='cpu').eval()
+            default_map = [
+                'vision_model', 'mlp1', 'language_model.model.tok_embeddings',
+                'language_model.model.embed_tokens', 'language_model.output',
+                'language_model.model.norm', 'language_model.lm_head'
+            ]
+            model, _ = build_device_map(model, default_map)
         else:
-            device = torch.cuda.current_device()
-            self.device = device
-            self.model = AutoModel.from_pretrained(
-                model_path,
-                torch_dtype=torch.bfloat16,
+            model = AutoModel.from_pretrained(
+                model_path, torch_dtype=torch.bfloat16,
                 trust_remote_code=True,
                 load_in_8bit=load_in_8bit).eval()
-            if not load_in_8bit:
-                self.model = self.model.to(device)
-
+        self.device = torch.cuda.current_device()
+        self.model_path = model_path
+        self.model = model
         self.image_size = self.model.config.vision_config.image_size
         self.version = version
         self.kwargs = kwargs
diff --git a/vlmeval/vlm/omnilmm.py b/vlmeval/vlm/omnilmm.py
index 12971cd77..afdefe28a 100644
--- a/vlmeval/vlm/omnilmm.py
+++ b/vlmeval/vlm/omnilmm.py
@@ -5,6 +5,7 @@
 from .base import BaseModel
 from ..smp import *
 from ..dataset import DATASET_TYPE
+from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model
 
 DEFAULT_IMAGE_TOKEN = '<image>'
 
@@ -97,7 +98,13 @@ class OmniLMM12B(BaseModel):
 
     def __init__(self, model_path, root, **kwargs) -> None:
         sys.path.append(root)
-        model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)
+        with init_empty_weights():
+            model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)
+
+        default_map = ['lm_head', 'model.norm', 'model.resampler', 'model.layers']
+        no_split = ['Eva', 'MistralDecoderLayer', 'ModuleList', 'Resampler']
+        model, _ = build_device_map(model, default_map, no_split)
+
         self.model = model
         self.image_token_len = image_token_len
         self.image_transform = img_processor
diff --git a/vlmeval/vlm/pandagpt.py b/vlmeval/vlm/pandagpt.py
index 47821de7e..805a55d22 100644
--- a/vlmeval/vlm/pandagpt.py
+++ b/vlmeval/vlm/pandagpt.py
@@ -3,6 +3,7 @@
 import os.path as osp
 import warnings
 from .base import BaseModel
+from ..smp import *
 
 
 class PandaGPT(BaseModel):
@@ -40,7 +41,12 @@ def __init__(self, name, root=None, **kwargs):
         delta_ckpt = torch.load(self.args['delta_ckpt_path'], map_location=torch.device('cpu'))
         model.load_state_dict(delta_ckpt, strict=False)
         torch.cuda.empty_cache()
-        self.model = model.eval().half().cuda()
+
+        default_map = ['llama_model.base_model.model.lm_head', 'llama_proj']
+        no_split_list = ['LlamaDecoderLayer', 'VisionTransformer']
+        model, _ = build_device_map(model, default_map, no_split_list)
+
+        self.model = model.eval()
         kwargs_default = {'top_p': 0.9, 'do_sample': False, 'max_tgt_len': 128, 'temperature': 0.001}
         kwargs_default.update(kwargs)
         self.kwargs = kwargs_default