[Support] Multi-process parallel inference of large models on multiple GPUs #298

Open · wants to merge 13 commits into base: main
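For orientation: this PR targets runs where several evaluation processes are launched on one node and each process shards its own model replica across a disjoint, interleaved slice of the visible GPUs. A minimal sketch of that setup follows; the launch command, script name, and environment variables are assumptions for illustration, not taken from this diff.

# Hedged sketch: assume a launch like `torchrun --nproc-per-node=2 run.py --model <model> --data <dataset>`,
# so every process sees LOCAL_RANK / WORLD_SIZE and can claim its own slice of the GPUs.
import os

import torch

rank = int(os.environ.get('LOCAL_RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))
gpus_per_process = torch.cuda.device_count() // max(world_size, 1)
print(f'rank {rank}/{world_size} can shard its model over {gpus_per_process} GPU(s)')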
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
accelerate
einops
gradio
huggingface_hub
45 changes: 45 additions & 0 deletions vlmeval/smp/vlm.py
@@ -8,6 +8,10 @@
import base64
from PIL import Image
import sys
import torch
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model
from .misc import get_rank_and_world_size


Image.MAX_IMAGE_PIXELS = 1e9

@@ -179,3 +183,44 @@ def circular_pred(df, extract_func=None):
    flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
    flags = list(flag_map.values())
    return np.mean(flags)


def get_memory():
    # Total memory of GPU 0 in GiB; all visible GPUs are assumed to be identical.
    total_memory = torch.cuda.get_device_properties(0).total_memory
    total_mem = total_memory / 1024 / 1024 / 1024
    return total_mem


def build_device_map(model, default_map=None, no_split=None, alpha=0.97, beta=0.9):
    total_num_gpus = torch.cuda.device_count()
    rank, world_size = get_rank_and_world_size()
    if world_size == total_num_gpus:
        # One GPU per process: no sharding needed, keep the whole model on this rank's GPU.
        return model.cuda(), None

    # Each rank shards its model over an interleaved slice of the visible GPUs:
    # rank, rank + world_size, rank + 2 * world_size, ...
    num_gpus = total_num_gpus // world_size
    memory_map = {}
    per_gpu_mem = get_memory() * alpha
    # Reserve extra headroom (beta) on the first GPU of this rank's slice.
    memory_map.update({rank: f'{beta * per_gpu_mem:.2f}GiB'})
    for gpu_id in range(1, num_gpus):
        memory_map.update({rank + gpu_id * world_size: f'{per_gpu_mem:.2f}GiB'})
    if hasattr(model, '_no_split_modules'):
        no_split_module = model._no_split_modules or []
    else:
        no_split_module = []
    if no_split is not None:
        no_split_module = list(set(no_split_module + no_split))
    device_map = infer_auto_device_map(
        model,
        max_memory=memory_map,
        no_split_module_classes=no_split_module
    )
    if default_map is not None:
        # Pin these modules to the first GPU so inputs and output logits share a device.
        for i in default_map:
            device_map[i] = rank
    for value in device_map.values():
        assert value != 'disk', 'Not enough GPU memory to load the model: some weights were mapped to disk.'

    model = dispatch_model(
        model,
        device_map=device_map).eval()
    return model, device_map
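To make the mapping above concrete, here is a small self-contained sketch of the max_memory dict that build_device_map would hand to infer_auto_device_map. The numbers (8 visible GPUs, 2 processes, roughly 80 GiB per GPU) are assumptions for illustration, not values taken from this PR.

# Assumed: total_num_gpus = 8, world_size = 2, so this rank shards over 4 interleaved GPUs.
rank, world_size = 0, 2              # normally returned by get_rank_and_world_size()
num_gpus = 8 // world_size           # 4 GPUs for this process
per_gpu_mem = 80 * 0.97              # get_memory() * alpha, in GiB
memory_map = {rank: f'{0.9 * per_gpu_mem:.2f}GiB'}   # beta headroom on the first GPU
for gpu_id in range(1, num_gpus):
    memory_map[rank + gpu_id * world_size] = f'{per_gpu_mem:.2f}GiB'
print(memory_map)  # {0: '69.84GiB', 2: '77.60GiB', 4: '77.60GiB', 6: '77.60GiB'}
# Rank 1 would use GPUs 1, 3, 5 and 7, so the two processes never contend for a device.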
19 changes: 12 additions & 7 deletions vlmeval/vlm/cogvlm.py
@@ -27,9 +27,9 @@ def __init__(self, model_path='THUDM/glm-4v-9b', **kwargs):
        self.end_text_token = '<|endoftext|>'

    def generate_inner(self, message, dataset=None):
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
        prompt, image_path = self.message_to_promptimg(message)
        image = Image.open(image_path).convert('RGB')
        if dataset is not None and DATASET_TYPE(dataset) in ['MCQ', 'Y/N']:
        if dataset is not None and DATASET_TYPE(dataset) in ['multi-choice', 'Y/N']:
            prompt += '\nShort Answer.'
        inputs = self.tokenizer.apply_chat_template(
            [{'role': 'user', 'image': image, 'content': prompt}],
@@ -51,11 +51,16 @@ class CogVlm(BaseModel):

    def __init__(self, model_path='THUDM/cogvlm2-llama3-chat-19B', tokenizer_name=None, **kwargs):
        assert model_path is not None
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).to('cuda').eval()
        from accelerate import init_empty_weights

        with init_empty_weights():
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
        model, _ = build_device_map(model)

        self.kwargs = kwargs
        if tokenizer_name:
25 changes: 4 additions & 21 deletions vlmeval/vlm/emu.py
@@ -21,18 +21,12 @@ def __init__(self,
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model

        local_rank = os.environ.get('LOCAL_RANK', 0)
        local_rank = int(os.environ.get('LOCAL_RANK', 0))

        device_num = torch.cuda.device_count()
        assert local_rank * 2 <= device_num, 'The number of devices does not match the world size'
        assert device_num >= 2, 'You need at least 2 GPUs to use EMU'

        device_1 = local_rank
        device_2 = local_rank + device_num // 2

        torch.cuda.set_device(device_1)
        torch.cuda.set_device(device_2)

        tokenizer = AutoTokenizer.from_pretrained(model_path)  # "BAAI/Emu2-Chat"
        self.tokenizer = tokenizer
        with init_empty_weights():
@@ -42,20 +36,9 @@
                low_cpu_mem_usage=True,
                trust_remote_code=True)

        device_map = infer_auto_device_map(
            model,
            max_memory={
                device_1: '38GiB',
                device_2: '38GiB'
            },
            no_split_module_classes=['Block', 'LlamaDecoderLayer'])

        # input and output logits should be on same device
        device_map['model.decoder.lm.lm_head'] = device_1

        model = dispatch_model(
            model,
            device_map=device_map).eval()
        no_split = ['Block', 'LlamaDecoderLayer']
        default_map = ['model.decoder.lm.lm_head']
        model, _ = build_device_map(model, default_map, no_split)

        self.model = model
        kwargs_default = dict(max_new_tokens=512, length_penalty=-1)
26 changes: 14 additions & 12 deletions vlmeval/vlm/internvl_chat.py
@@ -137,26 +137,28 @@ def __init__(self, model_path='OpenGVLab/InternVL-Chat-V1-5', load_in_8bit=False
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)

        if listinstr(['InternVL2-Llama3-76B'], model_path):
            device_map = split_model(model_path.split('/')[1])
            self.model = AutoModel.from_pretrained(
        if not load_in_8bit:
            model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                load_in_8bit=load_in_8bit,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                device_map=device_map).eval()
                device_map='cpu').eval()
            default_map = [
                'vision_model', 'mlp1', 'language_model.model.tok_embeddings',
                'language_model.model.embed_tokens', 'language_model.output',
                'language_model.model.norm', 'language_model.lm_head'
            ]
            model, _ = build_device_map(model, default_map)
        else:
            device = torch.cuda.current_device()
            self.device = device
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
            model = AutoModel.from_pretrained(
                model_path, torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                load_in_8bit=load_in_8bit).eval()
            if not load_in_8bit:
                self.model = self.model.to(device)

        self.device = torch.cuda.current_device()
        self.model_path = model_path
        self.model = model
        self.image_size = self.model.config.vision_config.image_size
        self.version = version
        self.kwargs = kwargs
9 changes: 8 additions & 1 deletion vlmeval/vlm/omnilmm.py
@@ -5,6 +5,7 @@
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model


DEFAULT_IMAGE_TOKEN = '<image>'
@@ -97,7 +98,13 @@ class OmniLMM12B(BaseModel):

    def __init__(self, model_path, root, **kwargs) -> None:
        sys.path.append(root)
        model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)
        with init_empty_weights():
            model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)

        default_map = ['lm_head', 'model.norm', 'model.resampler', 'model.layers']
        no_split = ['Eva', 'MistralDecoderLayer', 'ModuleList', 'Resampler']
        model, _ = build_device_map(model, default_map, no_split)

        self.model = model
        self.image_token_len = image_token_len
        self.image_transform = img_processor
8 changes: 7 additions & 1 deletion vlmeval/vlm/pandagpt.py
@@ -3,6 +3,7 @@
import os.path as osp
import warnings
from .base import BaseModel
from ..smp import *


class PandaGPT(BaseModel):
@@ -40,7 +41,12 @@ def __init__(self, name, root=None, **kwargs):
        delta_ckpt = torch.load(self.args['delta_ckpt_path'], map_location=torch.device('cpu'))
        model.load_state_dict(delta_ckpt, strict=False)
        torch.cuda.empty_cache()
        self.model = model.eval().half().cuda()

        default_map = ['llama_model.base_model.model.lm_head', 'llama_proj']
        no_split_list = ['LlamaDecoderLayer', 'VisionTransformer']
        model, _ = build_device_map(model, default_map, no_split_list)

        self.model = model.eval()
        kwargs_default = {'top_p': 0.9, 'do_sample': False, 'max_tgt_len': 128, 'temperature': 0.001}
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default