From 0e591ef88dcad2fbde230601e7ed5220eefa23db Mon Sep 17 00:00:00 2001
From: ignorejjj <1009979434@qq.com>
Date: Thu, 23 May 2024 20:08:05 +0800
Subject: [PATCH] Update

---
 examples/methods/my_config.yaml   | 164 ++++++++++++++++++------------
 examples/methods/run_exp.py       |   5 +-
 flashrag/config/basic_config.yaml |   2 +-
 flashrag/generator/generator.py   |  48 +++++++--
 flashrag/utils/utils.py           |  20 ++--
 5 files changed, 150 insertions(+), 89 deletions(-)

diff --git a/examples/methods/my_config.yaml b/examples/methods/my_config.yaml
index 71a60ba..5f1b0f6 100644
--- a/examples/methods/my_config.yaml
+++ b/examples/methods/my_config.yaml
@@ -1,90 +1,120 @@
-# ----Global Paths----
+# ------------------------------------------------Global Paths------------------------------------------------#
 # Paths to retrieval models
 retriever_model2path:
-  e5: model/e5-base-v2
-  contriever: "model/contriever-msmarco"
+  e5: "intfloat/e5-base-v2"
+  bge: "BAAI/bge-base-en-v1.5"
+  contriever: "facebook/contriever"

 # Paths to generation models
 generator_model2path:
-  llama2-13B: model/llama-2-13b-hf
-  llama2-13B-chat: model/llama2-13b-chat
-  llama2-7B: model/llama-2-7b-hf
-  llama2-7B-chat: model/llama-2-7b-chat-hf
-  llama3-8B-instruct: model/LLaMA-3-8b-Instruct/
-
+  llama2-7B-chat: "meta-llama/Llama-2-7b-chat-hf"
+  llama2-7B: "meta-llama/Llama-2-7b-hf"
+  llama2-13B: "meta-llama/Llama-2-13b-hf"
+  llama2-13B-chat: "meta-llama/Llama-2-13b-chat-hf"
+
 # Pooling methods for each embedding model
 model2pooling:
-  default: "pooler"
   e5: "mean"
   bge: "cls"
   contriever: "mean"
+  jina: "mean"
+  dpr: "cls"

 # Indexes path for retrieval models
 method2index:
-  e5: "index/e5_flat_inner.index"
-  bm25: "index/bm25"
-  contriever: "index/contriever.index"
+  e5: ~
+  bm25: ~
+  contriever: ~
+
+# ------------------------------------------------Environment Settings------------------------------------------------#
+# Directory paths for data and outputs
+data_dir: "dataset/"
+save_dir: "output/"

-# ----Environment Settings----
-gpu_id: "0,1"
-dataset_name: "nq"
-split: ["dev",'test']
+gpu_id: "0,1,2,3"
+dataset_name: "nq" # name of the dataset in data_dir
+split: ["test"] # dataset split to load (e.g. train, dev, test)

 # Sampling configurations for testing
-test_sample_num: 5
-random_sample: False
-save_intermediate_data: True
+test_sample_num: ~ # number of samples to test (only works for the dev/test split); if None, test all samples
+random_sample: False # whether to randomly sample the test samples
+
 # Seed for reproducibility
 seed: 2024

-# Directory paths for data and outputs
-data_dir: "datasets/"
-#save_dir: "/data00/jiajie_jin/test_project/output"
-save_dir: "output/"
+# Whether to save intermediate data
+save_intermediate_data: True
+save_note: 'experiment'
+
+# -------------------------------------------------Retrieval Settings------------------------------------------------#
+# If a name is given, the model path is looked up in the global paths above
+retrieval_method: "e5" # name or path of the retrieval model
+index_path: ~ # set automatically if not provided
+faiss_gpu: False # whether to use GPU to hold the index
+corpus_path: ~ # path to the corpus in '.jsonl' format that stores the documents

-# ----Retrieval Settings----
-retrieval_method: "e5" # name or path of the retrieval model
-index_path: ~ # Set automatically if not provided
-corpus_path: "index/wiki_dump.jsonl"
-retrieval_pooling_method: ~
-
-retrieval_topk: 5
-retrieval_batch_size: 256
-retrieval_use_fp16: True
-retrieval_query_max_length: 128
-save_retrieval_cache: False
-use_retrieval_cache: False
-retrieval_cache_path: ~
-
-use_reranker: False
-rerank_model_name: e5
-rerank_model_path: ~
+retrieval_topk: 5 # number of retrieved documents
+retrieval_batch_size: 256 # batch size for retrieval
+retrieval_use_fp16: True # whether to use fp16 for the retrieval model
+retrieval_query_max_length: 128 # max length of the query
+save_retrieval_cache: True # whether to save the retrieval cache
+use_retrieval_cache: False # whether to use the retrieval cache
+retrieval_cache_path: ~ # path to the retrieval cache
+retrieval_pooling_method: ~ # set automatically if not provided
+
+use_reranker: False # whether to use a reranker
+rerank_model_name: ~ # same as retrieval_method
+rerank_model_path: ~ # path to the reranker model; looked up automatically in `retriever_model2path`
 rerank_pooling_method: ~
+rerank_topk: 5 # number of documents to keep after reranking
+rerank_max_length: 512
+rerank_batch_size: 256 # batch size for the reranker
 rerank_use_fp16: True
-rerank_topk: 5
-rerank_max_length: 512
-rerank_batch_size: 256
-
-# ----Generator Settings----
-use_vllm: False
-generator_model: "llama3-8B-instruct" # name or path of the generator
-generator_max_input_len: 4096
-generator_batch_size: 4
-generation_params:
-  do_sample: False
-  max_tokens: 32
-  temperature: 0.1
+
+# -------------------------------------------------Generator Settings------------------------------------------------#
+framework: hf # inference framework for the LLM; supports 'hf', 'vllm', 'fschat'
+generator_model: "llama2-7B-chat" # name or path of the generator model
+generator_max_input_len: 1024 # max length of the input
+generator_batch_size: 2 # batch size for generation, not used by vllm
+generation_params:
+  max_tokens: 64
+  temperature: 1.0
   top_p: 1.0
-vllm_gpu_memory_utilization: 0.8
-
-# ----Evaluation Settings----
-#metrics: ['em','f1','sub_em','precision','recall','retrieval_recall','rouge-1','rouge-l', 'bleu']
-metrics: ['em','f1','sub_em','precision','recall']
-save_metric_score: True
-
-# ---index building
-index_doc_max_length: 256
-index_batch_size: 4096
-index_use_fp16: True
-index_save_dir: "indexes/"
-index_corpus_path: ~ # path to jsonl file, only used in building index
+use_fid: False # whether to use FID, only valid for encoder-decoder models
+
+
+# -------------------------------------------------Refiner Settings------------------------------------------------#
+# If set, the refiner will be used to refine the retrieved documents.
+refiner_name: ~
+refiner_model_path: ~
+
+# Used for extractive methods (e.g. embedding models)
+refiner_topk: 5 # number of sentences to keep after refining
+refiner_pooling_method: 'mean' # pooling method of the refiner model
+refiner_encode_max_length: 256
+# Used for abstractive methods (e.g. generation models like bart-large-cnn)
+refiner_max_input_length: 1024
+refiner_max_output_length: 512
+
+# Specify settings for llmlingua
+llmlingua_config:
+    'rate': 0.55
+    'condition_in_question': 'after_condition'
+    'reorder_context': 'sort'
+    'dynamic_context_compression_ratio': 0.3
+    'condition_compare': True
+    'context_budget': "+100"
+    'rank_method': 'longllmlingua'
+sc_config:
+    'reduce_ratio': 0.5
+
+# -------------------------------------------------Evaluation Settings------------------------------------------------#
+# Metrics to evaluate the results
+metrics: ['em','f1','sub_em','precision','recall']
+# Specify settings for each metric; used by certain metrics
+metric_setting:
+    retrieval_recall_topk: 5
+save_metric_score: True # whether to save the metric score into a txt file
+
+
+
diff --git a/examples/methods/run_exp.py b/examples/methods/run_exp.py
index a10a4f4..d281145 100644
--- a/examples/methods/run_exp.py
+++ b/examples/methods/run_exp.py
@@ -186,7 +186,6 @@ def retrobust(args):
                    'generator_lora_path': lora_path,
                    'generation_params':{"max_tokens":100},
                    'gpu_id':args.gpu_id,
-                   'use_vllm':False,
                    'generator_max_input_len': 4096,
                    'dataset_name':args.dataset_name}
     config = Config('my_config.yaml',config_dict)
@@ -258,7 +257,7 @@ def skr(args):
 def selfrag(args):
     config_dict = {'generator_model':'selfrag-llama2-7B',
                    'generator_model_path': 'model/selfrag_llama2_7b',
-                   'use_vllm': True,
+                   'framework': 'vllm',
                    'save_note':'self-rag',
                    'gpu_id':args.gpu_id,
                    'generation_params':{'max_new_tokens':100,'temperature':0.0,'top_p':1.0,'skip_special_tokens':False},
@@ -277,7 +276,7 @@ def selfrag(args):
     result = pipeline.run(test_data, batch_size=256)
 
 def flare(args):
-    config_dict={'save_note':'flare', 'gpu_id':args.gpu_id, 'use_vllm':True,
+    config_dict={'save_note':'flare', 'gpu_id':args.gpu_id,
                 'dataset_name':args.dataset_name}
     config = Config('my_config.yaml',config_dict)
     all_split = get_dataset(config)
diff --git a/flashrag/config/basic_config.yaml b/flashrag/config/basic_config.yaml
index bff6160..5f1b0f6 100644
--- a/flashrag/config/basic_config.yaml
+++ b/flashrag/config/basic_config.yaml
@@ -72,7 +72,7 @@ rerank_batch_size: 256 # batch size for reranker
 rerank_use_fp16: True
 
 # -------------------------------------------------Generator Settings------------------------------------------------#
-use_vllm: False
+framework: hf # inference framework for the LLM; supports 'hf', 'vllm', 'fschat'
 generator_model: "llama2-7B-chat" # name or path of the generator model
 generator_max_input_len: 1024 # max length of the input
 generator_batch_size: 2 # batch size for generation, invalid for vllm
diff --git a/flashrag/generator/generator.py b/flashrag/generator/generator.py
index bb004c7..98d159d 100644
--- a/flashrag/generator/generator.py
+++ b/flashrag/generator/generator.py
@@ -215,7 +215,7 @@ def generate(self, input_list: List[str], return_raw_output=False, return_scores
 
         return base_output
 
-class CausalLMGenerator(BaseGenerator):
+class HFCausalLMGenerator(BaseGenerator):
     """Class for decoder-only generator, based on hf.
""" def __init__(self, config, model=None): @@ -233,15 +233,12 @@ def _load_model(self, model=None): """ if model is None: - from fastchat.model import load_model - model, tokenizer = load_model(self.model_path, - device = 'cuda', - num_gpus = self.gpu_num, - load_8bit = False, - cpu_offloading = False, - debug = False,) - #model = AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype="auto", device_map="auto") - #tokenizer = AutoTokenizer.from_pretrained(self.model_path) + model = AutoModelForCausalLM.from_pretrained( + self.model_path, + torch_dtype="auto", + device_map="auto" + ) + tokenizer = AutoTokenizer.from_pretrained(self.model_path) else: model.cuda() @@ -288,7 +285,7 @@ def generate(self, input_list: List[str], batch_size=None, return_scores=False, padding=True, truncation=True, max_length=self.max_input_len - ).to(self.model.device) + ).to('cuda') outputs = self.model.generate( **inputs, output_scores=True, @@ -357,3 +354,32 @@ def cal_gen_probs(self, prev, next): logits = logits[range(len(target_ids)), target_ids] return target_ids, logits + + +class FastChatGenerator(HFCausalLMGenerator): + def __init__(self, config, model=None): + super().__init__(config) + + def _load_model(self, model=None): + r"""Load model and tokenizer for generator. + + """ + if model is None: + from fastchat.model import load_model + model, tokenizer = load_model(self.model_path, + device = 'cuda', + num_gpus = self.gpu_num, + load_8bit = False, + cpu_offloading = False, + debug = False,) + + else: + model.cuda() + tokenizer = AutoTokenizer.from_pretrained(self.model_path) + model.eval() + if 'qwen' not in self.model_name: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + return model, tokenizer + \ No newline at end of file diff --git a/flashrag/utils/utils.py b/flashrag/utils/utils.py index 5de0836..68492d9 100644 --- a/flashrag/utils/utils.py +++ b/flashrag/utils/utils.py @@ -27,23 +27,29 @@ def get_dataset(config): def get_generator(config, **params): r"""Automatically select generator class based on config.""" - if "t5" in config['generator_model'] or "bart" in config['generator_model']: + if config['framework'] == 'vllm': return getattr( - importlib.import_module("flashrag.generator"), - "EncoderDecoderGenerator" - )(config, **params) + importlib.import_module("flashrag.generator"), + "VLLMGenerator" + )(config, **params) + elif config['framework'] == 'fschat': + return getattr( + importlib.import_module("flashrag.generator"), + "FastChatGenerator" + )(config, **params) else: - if config['use_vllm']: + if "t5" in config['generator_model'] or "bart" in config['generator_model']: return getattr( importlib.import_module("flashrag.generator"), - "VLLMGenerator" + "EncoderDecoderGenerator" )(config, **params) else: return getattr( importlib.import_module("flashrag.generator"), - "CausalLMGenerator" + "HFCausalLMGenerator" )(config, **params) + def get_retriever(config): r"""Automatically select retriever class based on config's retrieval method