diff --git a/libai/inference/generator/generation_utils.py b/libai/inference/generator/generation_utils.py
index 3b3a94adc..31aa82c76 100644
--- a/libai/inference/generator/generation_utils.py
+++ b/libai/inference/generator/generation_utils.py
@@ -526,8 +526,8 @@ def greedy_search(
 
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id is not None:
-                unfinished_sequences = flow.mul(
-                    unfinished_sequences, (next_tokens != eos_token_id).long()
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.ne(eos_token_id).prod(dim=0)
                 )
 
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
diff --git a/projects/BLOOM/configs/bloom_inference.py b/projects/BLOOM/configs/bloom_inference.py
index 6ee9a1e37..998c00306 100644
--- a/projects/BLOOM/configs/bloom_inference.py
+++ b/projects/BLOOM/configs/bloom_inference.py
@@ -6,6 +6,7 @@
 cfg = dict(
     # model
     vocab_size=250880,
+    max_position_embeddings=512,
     hidden_size=64,
     hidden_layers=2,
     n_head=8,
diff --git a/projects/BLOOM/utils/model_loader.py b/projects/BLOOM/utils/model_loader.py
index b292a362f..0580aa4d3 100644
--- a/projects/BLOOM/utils/model_loader.py
+++ b/projects/BLOOM/utils/model_loader.py
@@ -43,7 +43,7 @@ def _convert_state_dict(self, flow_state_dict, cfg):
 
         # prefix
         has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
-        prefix2 = "transformer." if has_prefix else ""
+        prefix2 = "transformer." if not has_prefix else ""
 
         # Convert layers.
         for key in old_keys:
@@ -61,8 +61,13 @@ def _load_config_from_json(self, config_file):
            cfg_dict = json.load(f)
 
         self._update_cfg("hidden_layers", cfg_dict["n_layer"])
-        self._update_cfg("hidden_size", cfg_dict["n_embed"])
-        self._update_cfg("n_head", cfg_dict["num_attention_heads"])
+
+        if "n_embed" in cfg_dict.keys():
+            self._update_cfg("hidden_size", cfg_dict["n_embed"])
+            self._update_cfg("n_head", cfg_dict["num_attention_heads"])
+        else:
+            self._update_cfg("hidden_size", cfg_dict["hidden_size"])
+            self._update_cfg("n_head", cfg_dict["n_head"])
 
         # update libai_cfg by config.json
         for k, v in cfg_dict.items():
diff --git a/projects/ChatGLM/configs/chatglm_config.py b/projects/ChatGLM/configs/chatglm_config.py
index aa97363af..1192a3629 100644
--- a/projects/ChatGLM/configs/chatglm_config.py
+++ b/projects/ChatGLM/configs/chatglm_config.py
@@ -23,6 +23,7 @@
     layernorm_epsilon=1e-05,
     multi_query_attention=True,
     multi_query_group_num=2,
+    max_position_embeddings=2048,
     num_attention_heads=32,
     num_layers=28,
     padded_vocab_size=65024,
diff --git a/projects/Eval_LLM/README.md b/projects/Eval_LLM/README.md
new file mode 100644
index 000000000..7cf0af530
--- /dev/null
+++ b/projects/Eval_LLM/README.md
@@ -0,0 +1,49 @@
+# LLM Evaluation
+
+A tool for evaluating OneFlow models, based on [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/).
+
+## Environment
+
+Follow the [Installation Instructions](https://libai.readthedocs.io/en/latest/tutorials/get_started/Installation.html) to install oneflow (1.0.0) and libai first. Conda is recommended.
+**Make sure you have Python >= 3.10 to run evaluation for GLM.**
+Then run `pip install -r ./projects/Eval_LLM/requirements.txt` to install the dependencies.
+
+## Run Eval
+
+### Set the parameters in ./projects/Eval_LLM/config.py
+
+> pretrained_model_path: The path of your model weights; either huggingface or libai weights are supported.
+> hf_tokenizer_path: The path of the huggingface tokenizer.
+> model_type: The type of your model; this argument is needed to load the model. All choices are listed in ./projects/Eval_LLM/special_arguments.json
+> model_weight_type: Whether your weights are huggingface weights or libai weights.
+> eval_tasks: The tasks you want to evaluate your model on.
+> batch_size_per_gpu: Batch size on a single gpu; set it larger to speed up evaluation, but this may lead to OOM errors.
+
+Tasks available for evaluation are listed [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks).
+
+### Run the following command to start evaluation
+```
+bash tools/infer.sh projects/Eval_LLM/main.py 1
+```
+Note: The number stands for how many GPUs you want to use.
+
+If you want to evaluate GLM (ChatGLM), run this:
+```
+CHATGLM_HF_DIR=YOUR_MODEL_PATH bash tools/infer.sh projects/Eval_LLM/main.py 1
+```
+
+Note: Running a model with 6B parameters requires more than 24GB of VRAM. You can use tensor or pipeline parallelism across multiple devices.
+
+To learn more about distributed inference, see https://docs.oneflow.org/en/master/parallelism/04_launch.html
+
+## Example of Eval Result
+Using Llama2-7b:
+```
+{'sciq':
+  {'acc,none': 0.794, 'acc_stderr,none': 0.012795613612786583, 'acc_norm,none': 0.707, 'acc_norm_stderr,none': 0.014399942998441271, 'alias': 'sciq'},
+'lambada_openai':
+  {'perplexity,none': 28.778403569948463, 'perplexity_stderr,none': 1.0792474430271395, 'acc,none': 0.33980205705414324, 'acc_stderr,none': 0.006598757339311441, 'alias': 'lambada_openai'},
+'gsm8k':
+  {'exact_match,strict-match': 0.001516300227445034, 'exact_match_stderr,strict-match': 0.0010717793485492675, 'exact_match,flexible-extract': 0.01061410159211524, 'exact_match_stderr,flexible-extract': 0.002822713322387704, 'alias': 'gsm8k'}
+}
+```
\ No newline at end of file
diff --git a/projects/Eval_LLM/config.py b/projects/Eval_LLM/config.py
new file mode 100644
index 000000000..cb1d180a3
--- /dev/null
+++ b/projects/Eval_LLM/config.py
@@ -0,0 +1,22 @@
+from omegaconf import DictConfig
+
+parallel_config = DictConfig(
+    dict(
+        data_parallel_size=1,
+        tensor_parallel_size=1,
+        pipeline_parallel_size=1,
+        pipeline_num_layers=32,
+        device_type="cuda",
+    )
+)
+
+eval_config = DictConfig(
+    dict(
+        pretrained_model_path="",
+        hf_tokenizer_path="",
+        model_type="llama",
+        model_weight_type="libai",  # libai or huggingface
+        eval_tasks=["lambada_openai", "gsm8k"],
+        batch_size_per_gpu=1,
+    )
+)
diff --git a/projects/Eval_LLM/eval_harness.py b/projects/Eval_LLM/eval_harness.py
new file mode 100644
index 000000000..814e0a01e
--- /dev/null
+++ b/projects/Eval_LLM/eval_harness.py
@@ -0,0 +1,342 @@
+import json
+import os
+from pathlib import Path
+from typing import Dict, List, Optional, TypeVar
+
+import oneflow as flow
+import oneflow.nn.functional as F
+
+flow.mock_torch.enable(lazy=True)
+
+import oneflow as torch  # noqa
+from lm_eval import evaluator, tasks, utils  # noqa
+from lm_eval.api.model import LM  # noqa
+from lm_eval.models.utils import chunks  # noqa
+from tqdm import tqdm  # noqa
+
+import libai.utils.distributed as dist  # noqa
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+T = TypeVar("T")
+
+
+class EvalHarnessBase(LM):
+    def __init__(self, model, tokenizer, model_name, batch_size: int, cfg: dict):
+        super().__init__()
+        self.model = model
+        self.tokenizer = tokenizer
+        self.model_name = model_name
+        self.batch_size_per_gpu = batch_size
+        self.cfg = cfg
+
+    @classmethod
+    def create_from_arg_string(cls, arg_string, additional_config=None):
+        pass
+
+    @property
+    def eos_token_id(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def pad_token_id(self):
+        return self.tokenizer.pad_token_id
+
+    @property
+    def max_length(self):
+        return self.cfg.max_position_embeddings
+
+    @property
+    def vocab_size(self):
+        return self.cfg.vocab_size
+
+    @property
+    def max_gen_toks(self):
+        return self.cfg.get("max_length", 64)
+
+    @property
+    def batch_size(self):
+        return self.batch_size_per_gpu * dist.get_world_size()
+
+    @property
+    def device(self):
+        return flow.device("cuda:0")
+
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string, add_special_tokens=False)
+
+    def tok_decode(self, tokens: List[int]) -> str:
+        return self.tokenizer.decode(tokens)
+
+    def batch_encode(self, strings: List[str]) -> Dict:
+        return self.tokenizer.batch_encode_plus(strings, padding=True)
+
+    @flow.inference_mode()
+    def _model_call(self, inps):
+        inps = inps.to_global(
+            sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
+            placement=dist.get_layer_placement(0),
+        )
+        return self.model(inps)["logits"].to_local().to(flow.float32)
+
+    def _model_generate(self, context, max_length, eos_token_id) -> flow.Tensor:
+        context = dist.convert_to_distributed_default_setting(context)
+        out = self.model.generate(
+            context,
+            max_length,
+            eos_token_id=eos_token_id,
+        )
+        return out.unsqueeze(0)
+
+    def loglikelihood(self, requests, disable_tqdm=False):
+        new_reqs = []
+        for request in tqdm(requests, disable=disable_tqdm):
+            context, continuation = request.arguments
+            if context == "":
+                # end of text as context
+                context_enc = [self.eos_token_id]
+            else:
+                context_enc = self.tok_encode(context)
+
+            continuation_enc = self.tok_encode(continuation)[: self.max_length]
+
+            new_reqs.append(((context, continuation), context_enc, continuation_enc))
+        return self._loglikelihood_tokens(new_reqs)
+
+    def loglikelihood_rolling(self, requests):
+        # TODO: Implement caching once we've confirmed the perplexity implementation
+        # TODO: automatic batch size detection for vectorization
+
+        loglikelihoods = []
+        for (string,) in tqdm(requests):
+            rolling_token_windows = list(
+                map(
+                    utils.make_disjoint_window,
+                    utils.get_rolling_token_windows(
+                        token_list=self.tok_encode(string),
+                        prefix_token=self.eos_token_id,
+                        max_seq_len=self.max_length,
+                        context_len=1,
+                    ),
+                )
+            )
+
+            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+
+            string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
+
+            # discard is_greedy
+            string_nll = [x[0] for x in string_nll]
+
+            string_nll = sum(string_nll)
+            loglikelihoods.append(string_nll)
+
+        return loglikelihoods
+
+    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+        res = []
+
+        def _collate(x):
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be over not underestimates,
+            #   which is more useful for planning
+            # - to know the size of a batch when going through the list,
+            #   you know the first one is always the batch
+            #   padded context length. this is useful to simplify
+            #   the batching logic and more importantly to make
+            #   automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+
+            toks = x[1] + x[2]
+            return -len(toks), tuple(toks)
+
+        # TODO: automatic (variable) batch size detection for vectorization
+        re_ord = utils.Reorderer(requests, _collate)
+        for chunk in chunks(tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size):
+            inps = []
+            cont_toks_list = []
+            inplens = []
+
+            padding_length = None
+
+            # because vectorizing is annoying,
+            # we first convert each (context, continuation) pair to padded tensors,
+            # then we pack them together into a batch, call the model,
+            # and then pick it all apart again because vectorizing is annoying
+
+            for _, context_enc, continuation_enc in chunk:
+                # sanity check
+                assert len(context_enc) > 0
+                assert len(continuation_enc) > 0
+                assert len(continuation_enc) <= self.max_length
+
+                # how this all works:
+                #          CTX      CONT
+                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
+                # gpt2    \               \
+                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
+                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice
+
+                # when too long to fit in context, truncate from the left
+                inp = torch.tensor(
+                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
+                    dtype=torch.long,
+                ).to(self.device)
+                (inplen,) = inp.shape
+
+                cont = continuation_enc
+
+                # since in _collate we make sure length is descending,
+                # the longest is always the first one.
+                padding_length = padding_length if padding_length is not None else inplen
+
+                # pad length from seq to padding_length
+                inp = torch.cat(
+                    [
+                        inp,  # [seq]
+                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
+                            inp.device
+                        ),  # [padding_length - seq]
+                    ],
+                    dim=0,
+                )
+
+                inps.append(inp.unsqueeze(0))  # [1, padding_length]
+                cont_toks_list.append(cont)
+                inplens.append(inplen)
+
+            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
+            multi_logits = F.log_softmax(
+                self._model_call(batched_inps), dim=-1
+            ).cpu()  # [batch, padding_length, vocab]
+
+            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
+                chunk, multi_logits, inps, inplens, cont_toks_list
+            ):
+
+                # Slice to original seq length
+                contlen = len(cont_toks)
+                logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]
+
+                # Check if per-token argmax is exactly equal to continuation
+                greedy_tokens = logits.argmax(dim=-1)
+                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
+                max_equal = (greedy_tokens == cont_toks).all()
+
+                # Obtain log-probs at the corresponding continuation token indices
+                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
+
+                # Answer: (log prob, is-exact-match)
+                answer = (float(logits.sum()), bool(max_equal))
+
+                # partial caching
+                if cache_key is not None:
+                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
+
+                res.append(answer)
+
+        return re_ord.get_original(res)
+
+    def generate_until(self, requests, disable_tqdm=False) -> List[str]:
+        res = []
+
+        for chunk in chunks(
+            tqdm(requests, disable=disable_tqdm, desc="Running generate_until requests"),
+            self.batch_size,
+        ):
+            _, until = chunk[0].arguments
+            if isinstance(until, dict):
+                until = until["until"]
+            if isinstance(until, str):
+                until = [until]
+            primary_until = self.tok_encode(until[0])
+            reqs = []
+            for request in chunk:
+                reqs.append(request.arguments[0])
+            context_enc = torch.tensor(self.batch_encode(reqs)["input_ids"]).to(self.device)[
+                :, self.max_gen_toks - self.max_length :
+            ]
+            cont = self._model_generate(
+                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until[0]
+            )
+
+            for i in range(cont[0].shape[0]):
+                s = self.tok_decode(cont[0].tolist()[i][context_enc.shape[1] :])
+                for term in until:
+                    s = s.split(term)[0]
+
+                res.append(s)
+        return res
+
+    @flow.inference_mode()
+    def run_eval(
+        self,
+        eval_tasks: List[str],
+        limit: Optional[int],
+        bootstrap_iters: int,
+    ) -> Dict:
+        import fnmatch
+
+        task_manager = tasks.TaskManager()
+        all_tasks = task_manager.all_tasks
+
+        def pattern_match(patterns, source_list):
+            task_names = set()
+            for pattern in patterns:
+                for matching in fnmatch.filter(source_list, pattern):
+                    task_names.add(matching)
+            task_names = list(task_names)
+            task_names.sort()
+            return task_names
+
+        eval_tasks = pattern_match(eval_tasks, all_tasks)
+        print(f"Found tasks: {eval_tasks}")
+
+        # let the main process fetch the task data first, then the others reuse it
+        if dist.is_main_process():
+            tasks.get_task_dict(eval_tasks)
+        dist.synchronize()
+
+        lm = self
+        results = evaluator.evaluate(
+            lm=lm,
+            task_dict=tasks.get_task_dict(task_name_list=eval_tasks),
+            limit=limit,
+            bootstrap_iters=bootstrap_iters,
+        )
+        results["config"] = dict(
+            model=self.model_name,
+            batch_size=self.batch_size,
+            device=str(self.device),
+            limit=limit,
+            bootstrap_iters=bootstrap_iters,
+        )
+        return results
+
+
+@flow.inference_mode()
+def run_eval_harness(
+    model,
+    tokenizer,
+    model_name,
+    eval_tasks: List[str] = [
+        "hellaswag",
+    ],
+    batch_size_per_gpu: int = 1,
+    save_filepath: Optional[Path] = None,
+    limit: Optional[int] = None,
+    bootstrap_iters: int = 100000,
+    dtype=flow.float16,
+    cfg=None,
+):
+    model.eval()
+    model = model.to(dtype)
+    with flow.no_grad():
+        eval_harness = EvalHarnessBase(model, tokenizer, model_name, batch_size_per_gpu, cfg)
+        results = eval_harness.run_eval(eval_tasks, limit, bootstrap_iters)
+    if save_filepath is None:
+        print(results["results"])
+    else:
+        print(f"Saving results to {str(save_filepath)!r}")
+        data = json.dumps(results)
+        with open(save_filepath, "w") as fw:
+            fw.write(data)
diff --git a/projects/Eval_LLM/main.py b/projects/Eval_LLM/main.py
new file mode 100644
index 000000000..487dfe975
--- /dev/null
+++ b/projects/Eval_LLM/main.py
@@ -0,0 +1,86 @@
+import importlib
+import json
+
+from transformers import AutoTokenizer as HF_AutoTokenizer
+
+import libai.utils.distributed as dist  # noqa
+from libai.config import LazyConfig
+from libai.models.utils.model_loader.base_loader import ModelLoaderLiBai  # noqa
+
+
+class LLMLoaderLibai(ModelLoaderLiBai):
+    def __init__(self, model, libai_cfg, pretrained_model_path, base_model_prefix, **kwargs):
+        super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
+        self.base_model_prefix_2 = base_model_prefix
+
+
+def get_special_arguments(cfg):
+    with open("./projects/Eval_LLM/special_arguments.json", "r") as f:
+        arguments = json.load(f)
+    special_arguments = arguments[cfg.eval_config.model_type]
+    return special_arguments
+
+
+def main():
+    cfg = LazyConfig.load("./projects/Eval_LLM/config.py")
+    dist.setup_dist_util(cfg.parallel_config)
+    special_arguments = get_special_arguments(cfg)
+    print("Loading Model...")
+    model_cfg = LazyConfig.load(special_arguments["config_path"])
+    if model_cfg.cfg.max_position_embeddings is None:
+        model_cfg.cfg.max_position_embeddings = 1024
+
+    model_class = getattr(
+        importlib.import_module(special_arguments["model_class_prefix"]),
+        special_arguments["model_class"],
+    )
+
+    assert cfg.eval_config.model_weight_type in [
+        "huggingface",
+        "libai",
+    ], "model_weight_type must be huggingface or libai"
+    if cfg.eval_config.model_weight_type == "huggingface":
+        huggingface_loader = getattr(
+            importlib.import_module(special_arguments["huggingface_loader_prefix"]),
+            special_arguments["huggingface_loader"],
+        )
+        load_func = huggingface_loader(
+            model=model_class,
+            libai_cfg=model_cfg.cfg,
+            pretrained_model_path=cfg.eval_config.pretrained_model_path,
+        )
+    else:
+        load_func = LLMLoaderLibai(
+            model=model_class,
+            libai_cfg=model_cfg.cfg,
+            pretrained_model_path=cfg.eval_config.pretrained_model_path,
+            base_model_prefix=special_arguments["base_model_prefix_2"],
+        )
+
+    tokenizer = HF_AutoTokenizer.from_pretrained(
+        cfg.eval_config.hf_tokenizer_path, trust_remote_code=True
+    )
+    with open(cfg.eval_config.hf_tokenizer_path + "/config.json", "r") as f:
+        generation_config = json.load(f)
+
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = generation_config["pad_token_id"]
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token_id = generation_config["eos_token_id"]
+    model = load_func.load()
+    print("Model Loaded!")
+
+    from projects.Eval_LLM.eval_harness import run_eval_harness  # noqa
+
+    run_eval_harness(
+        model,
+        tokenizer,
+        cfg.eval_config.model_type,
+        eval_tasks=cfg.eval_config.eval_tasks,
+        batch_size_per_gpu=cfg.eval_config.batch_size_per_gpu,
+        cfg=model_cfg.cfg,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/projects/Eval_LLM/requirements.txt b/projects/Eval_LLM/requirements.txt
new file mode 100644
index 000000000..6785b5a10
--- /dev/null
+++ b/projects/Eval_LLM/requirements.txt
@@ -0,0 +1,6 @@
+torch>=2.0.0
+tokenizers
+transformers
+datasets
+huggingface-hub
+lm-eval==0.4.2
\ No newline at end of file
diff --git a/projects/Eval_LLM/special_arguments.json b/projects/Eval_LLM/special_arguments.json
new file mode 100644
index 000000000..2c863fb6f
--- /dev/null
+++ b/projects/Eval_LLM/special_arguments.json
@@ -0,0 +1,32 @@
+{
+    "llama":{
+        "n_layers_hf":"num_hidden_layers",
+        "n_layer_libai":"hidden_layers",
+        "base_model_prefix_2":"model",
+        "config_path":"./projects/Llama/configs/llama_config.py",
+        "model_class_prefix":"projects.Llama.llama",
+        "model_class":"LlamaForCausalLM",
+        "huggingface_loader_prefix":"projects.Llama.utils.llama_loader",
+        "huggingface_loader":"LlamaLoaderHuggerFace"
+    },
+    "bloom":{
+        "n_layers_hf":"n_layer",
+        "n_layer_libai":"hidden_layers",
+        "base_model_prefix_2":"transformer",
+        "config_path":"./projects/BLOOM/configs/bloom_inference.py",
+        "model_class_prefix":"projects.BLOOM.modeling.bloom_model",
+        "model_class":"BloomForCausalLM",
+        "huggingface_loader_prefix":"projects.BLOOM.utils.model_loader",
+        "huggingface_loader":"BlooMLoaderHuggerFace"
+    },
+    "glm":{
+        "n_layers_hf":"num_layers",
+        "n_layer_libai":"num_layers",
+        "base_model_prefix_2":"model",
+        "config_path":"./projects/ChatGLM/configs/chatglm_config.py",
+        "model_class_prefix":"projects.ChatGLM.chatglm",
+        "model_class":"ChatGLMForConditionalGeneration",
+        "huggingface_loader_prefix":"projects.ChatGLM.utils.chatglm_loader",
+        "huggingface_loader":"ChatGLMLoaderHuggerFace"
+    }
+}
\ No newline at end of file
diff --git a/projects/Llama/adapter/adapter_config.py b/projects/Llama/adapter/adapter_config.py
index 7381e64af..80f13cb71 100644
--- a/projects/Llama/adapter/adapter_config.py
+++ b/projects/Llama/adapter/adapter_config.py
@@ -11,7 +11,7 @@
     hidden_size=4096,
     initializer_range=0.02,
     intermediate_size=11008,
-    max_position_embeddings=4096,
+    max_position_embeddings=2048,
     num_attention_heads=32,
     hidden_layers=32,
     pretraining_tp=1,
diff --git a/projects/Llama/configs/llama_config.py b/projects/Llama/configs/llama_config.py
index 58b86ecd6..01d208016 100644
--- a/projects/Llama/configs/llama_config.py
+++ b/projects/Llama/configs/llama_config.py
@@ -12,7 +12,7 @@
     hidden_size=4096,
    initializer_range=0.02,
     intermediate_size=11008,
-    max_position_embeddings=4096,
+    max_position_embeddings=2048,
     num_attention_heads=32,
     hidden_layers=32,
     pretraining_tp=1,