Commit

Update
ignorejjj committed May 23, 2024
1 parent 20038eb commit 0e591ef
Showing 5 changed files with 150 additions and 89 deletions.
164 changes: 97 additions & 67 deletions examples/methods/my_config.yaml
@@ -1,90 +1,120 @@
# ----Global Paths----
# ------------------------------------------------Global Paths------------------------------------------------#
# Paths to retrieval models
retriever_model2path:
e5: model/e5-base-v2
contriever: "model/contriever-msmarco"
e5: "intfloat/e5-base-v2"
bge: "intfloat/e5-base-v2"
contriever: "facebook/contriever"

# Paths to generation models
generator_model2path:
llama2-13B: model/llama-2-13b-hf
llama2-13B-chat: model/llama2-13b-chat
llama2-7B: model/llama-2-7b-hf
llama2-7B-chat: model/llama-2-7b-chat-hf
llama3-8B-instruct: model/LLaMA-3-8b-Instruct/

llama2-7B-chat: "meta-llama/Llama-2-7b-chat-hf"
llama2-7B: "meta-llama/Llama-2-7b-hf"
llama2-13B: "meta-llama/Llama-2-13b-hf"
llama2-13B-chat: "meta-llama/Llama-2-13b-chat-hf"

# Pooling methods for each embedding model
model2pooling:
default: "pooler"
e5: "mean"
bge: "cls"
contriever: "mean"
jina: 'mean'
dpr: cls

# Indexes path for retrieval models
method2index:
e5: "index/e5_flat_inner.index"
bm25: "index/bm25"
contriever: "index/contriever.index"
e5: ~
bm25: ~
contriever: ~

# ------------------------------------------------Environment Settings------------------------------------------------#
# Directory paths for data and outputs
data_dir: "dataset/"
save_dir: "output/"

# ----Environment Settings----
gpu_id: "0,1"
dataset_name: "nq"
split: ["dev",'test']
gpu_id: "0,1,2,3"
dataset_name: "nq" # name of the dataset in data_dir
split: ["test"] # dataset split to load (e.g. train,dev,test)

# Sampling configurations for testing
test_sample_num: 5
random_sample: False
save_intermediate_data: True
test_sample_num: ~ # number of samples to test (only works for the dev/test split); if None, test all samples
random_sample: False # whether to randomly sample the test samples

# Seed for reproducibility
seed: 2024

# Directory paths for data and outputs
data_dir: "datasets/"
#save_dir: "/data00/jiajie_jin/test_project/output"
save_dir: "output/"
# Whether to save intermediate data
save_intermediate_data: True
save_note: 'experiment'

# -------------------------------------------------Retrieval Settings------------------------------------------------#
# If a model name is set here, its path is looked up in the global paths above
retrieval_method: "e5" # name or path of the retrieval model.
index_path: ~ # set automatically if not provided.
faiss_gpu: False # whether to use GPU to hold the index
corpus_path: ~ # path to corpus in '.jsonl' format that stores the documents

# ----Retrieval Settings----
retrieval_method: "e5" # name or path of the retrieval model
index_path: ~ # Set automatically if not provided
corpus_path: "index/wiki_dump.jsonl"
retrieval_pooling_method: ~

retrieval_topk: 5
retrieval_batch_size: 256
retrieval_use_fp16: True
retrieval_query_max_length: 128
save_retrieval_cache: False
use_retrieval_cache: False
retrieval_cache_path: ~

use_reranker: False
rerank_model_name: e5
rerank_model_path: ~
retrieval_topk: 5 # number of retrieved documents
retrieval_batch_size: 256 # batch size for retrieval
retrieval_use_fp16: True # whether to use fp16 for retrieval model
retrieval_query_max_length: 128 # max length of the query
save_retrieval_cache: True # whether to save the retrieval cache
use_retrieval_cache: False # whether to use the retrieval cache
retrieval_cache_path: ~ # path to the retrieval cache
retrieval_pooling_method: ~ # set automatically if not provided

use_reranker: False # whether to use reranker
rerank_model_name: ~ # same as retrieval_method
rerank_model_path: ~ # path to the reranker model; if not set, it is looked up automatically in `retriever_model2path`
rerank_pooling_method: ~
rerank_topk: 5 # number of documents to keep after reranking
rerank_max_length: 512
rerank_batch_size: 256 # batch size for reranker
rerank_use_fp16: True
rerank_topk: 5
rerank_max_length: 512
rerank_batch_size: 256

# ----Generator Settings----
use_vllm: False
generator_model: "llama3-8B-instruct" # name or path of the generator
generator_max_input_len: 4096
generator_batch_size: 4
generation_params:
do_sample: False
max_tokens: 32
temperature: 0.1

# -------------------------------------------------Generator Settings------------------------------------------------#
framework: hf # inference framework of the LLM, supporting: 'hf','vllm','fschat'
generator_model: "llama2-7B-chat" # name or path of the generator model
generator_max_input_len: 1024 # max length of the input
generator_batch_size: 2 # batch size for generation, ignored when using vllm
generation_params:
max_tokens: 64
temperature: 1.0
top_p: 1.0
vllm_gpu_memory_utilization: 0.8

# ----Evaluation Settings----
#metrics: ['em','f1','sub_em','precision','recall','retrieval_recall','rouge-1','rouge-l', 'bleu']
metrics: ['em','f1','sub_em','precision','recall']
save_metric_score: True

# ---index building
index_doc_max_length: 256
index_batch_size: 4096
index_use_fp16: True
index_save_dir: "indexes/"
index_corpus_path: ~ # path to jsonl file, only used in building index
use_fid: False # whether to use FiD, only valid for encoder-decoder models


# -------------------------------------------------Refiner Settings------------------------------------------------#
# If set, the refiner will be used to refine the retrieved documents.
refiner_name: ~
refiner_model_path: ~

# Used for extractive methods (e.g. embedding models)
refiner_topk: 5 # number of sentences to keep after refining
refiner_pooling_method: 'mean' # pooling method of refiner model
refiner_encode_max_length: 256
# Used for abstractive methods (e.g. generation models like bart-large-cnn)
refiner_max_input_length: 1024
refiner_max_output_length: 512

# Specify settings for llmlingua
llmlingua_config:
'rate': 0.55
'condition_in_question': 'after_condition'
'reorder_context': 'sort'
'dynamic_context_compression_ratio': 0.3
'condition_compare': True
'context_budget': "+100"
'rank_method': 'longllmlingua'
sc_config:
'reduce_ratio': 0.5

# -------------------------------------------------Evaluation Settings------------------------------------------------#
# Metrics to evaluate the result
metrics: ['em','f1','sub_em','precision','recall']
# Settings for specific metrics, used inside the corresponding metric computations
metric_setting:
retrieval_recall_topk: 5
save_metric_score: True # whether to save the metric scores to a txt file



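For context, the method scripts below consume this file by passing an override dictionary to Config, so any key in my_config.yaml can be replaced at call time. A minimal sketch of that pattern with the updated keys; the override values are illustrative and the import paths are assumed from the repository layout shown in this diff:

from flashrag.config import Config
from flashrag.utils import get_dataset

# Override a few of the updated keys at run time (illustrative values).
config_dict = {
    'framework': 'vllm',                  # replaces the removed boolean `use_vllm`
    'generator_model': 'llama2-7B-chat',  # resolved through generator_model2path
    'generation_params': {'max_tokens': 64, 'temperature': 1.0},
    'gpu_id': '0,1',
}

config = Config('my_config.yaml', config_dict)
all_split = get_dataset(config)
test_data = all_split['test']  # matches split: ["test"] above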
5 changes: 2 additions & 3 deletions examples/methods/run_exp.py
@@ -186,7 +186,6 @@ def retrobust(args):
'generator_lora_path': lora_path,
'generation_params':{"max_tokens":100},
'gpu_id':args.gpu_id,
'use_vllm':False,
'generator_max_input_len': 4096,
'dataset_name':args.dataset_name}
config = Config('my_config.yaml',config_dict)
@@ -258,7 +257,7 @@ def skr(args):
def selfrag(args):
config_dict = {'generator_model':'selfrag-llama2-7B',
'generator_model_path': 'model/selfrag_llama2_7b',
'use_vllm': True,
'framework': 'vllm',
'save_note':'self-rag',
'gpu_id':args.gpu_id,
'generation_params':{'max_new_tokens':100,'temperature':0.0,'top_p':1.0,'skip_special_tokens':False},
@@ -277,7 +276,7 @@ def selfrag(args):
result = pipeline.run(test_data, batch_size=256)

def flare(args):
config_dict={'save_note':'flare', 'gpu_id':args.gpu_id, 'use_vllm':True,
config_dict={'save_note':'flare', 'gpu_id':args.gpu_id,
'dataset_name':args.dataset_name}
config = Config('my_config.yaml',config_dict)
all_split = get_dataset(config)
2 changes: 1 addition & 1 deletion flashrag/config/basic_config.yaml
@@ -72,7 +72,7 @@ rerank_batch_size: 256 # batch size for reranker
rerank_use_fp16: True

# -------------------------------------------------Generator Settings------------------------------------------------#
use_vllm: False
framework: hf # inference framework of the LLM, supporting: 'hf','vllm','fschat'
generator_model: "llama2-7B-chat" # name or path of the generator model
generator_max_input_len: 1024 # max length of the input
generator_batch_size: 2 # batch size for generation, ignored when using vllm
48 changes: 37 additions & 11 deletions flashrag/generator/generator.py
@@ -215,7 +215,7 @@ def generate(self, input_list: List[str], return_raw_output=False, return_scores
return base_output


class CausalLMGenerator(BaseGenerator):
class HFCausalLMGenerator(BaseGenerator):
"""Class for decoder-only generator, based on hf. """

def __init__(self, config, model=None):
@@ -233,15 +233,12 @@ def _load_model(self, model=None):
"""
if model is None:
from fastchat.model import load_model
model, tokenizer = load_model(self.model_path,
device = 'cuda',
num_gpus = self.gpu_num,
load_8bit = False,
cpu_offloading = False,
debug = False,)
#model = AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype="auto", device_map="auto")
#tokenizer = AutoTokenizer.from_pretrained(self.model_path)
model = AutoModelForCausalLM.from_pretrained(
self.model_path,
torch_dtype="auto",
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(self.model_path)

else:
model.cuda()
@@ -288,7 +285,7 @@ def generate(self, input_list: List[str], batch_size=None, return_scores=False,
padding=True,
truncation=True,
max_length=self.max_input_len
).to(self.model.device)
).to('cuda')
outputs = self.model.generate(
**inputs,
output_scores=True,
@@ -357,3 +354,32 @@ def cal_gen_probs(self, prev, next):
logits = logits[range(len(target_ids)), target_ids]

return target_ids, logits


class FastChatGenerator(HFCausalLMGenerator):
def __init__(self, config, model=None):
super().__init__(config)

def _load_model(self, model=None):
r"""Load model and tokenizer for generator.
"""
if model is None:
from fastchat.model import load_model
model, tokenizer = load_model(self.model_path,
device = 'cuda',
num_gpus = self.gpu_num,
load_8bit = False,
cpu_offloading = False,
debug = False,)

else:
model.cuda()
tokenizer = AutoTokenizer.from_pretrained(self.model_path)
model.eval()
if 'qwen' not in self.model_name:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

return model, tokenizer
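The new FastChatGenerator only overrides _load_model; batching and generation are inherited from HFCausalLMGenerator. A brief usage sketch, assuming the config keys from the YAML above and the get_generator dispatch shown in flashrag/utils/utils.py below:

from flashrag.config import Config
from flashrag.utils import get_generator

# 'fschat' routes to FastChatGenerator, 'hf' to HFCausalLMGenerator (see utils.py below).
config = Config('my_config.yaml', {'framework': 'fschat',
                                   'generator_model': 'llama2-7B-chat'})
generator = get_generator(config)

# generate() takes a list of prompts, per the signature shown in this diff.
answers = generator.generate(["Who wrote The Old Man and the Sea?"])
print(answers[0])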

20 changes: 13 additions & 7 deletions flashrag/utils/utils.py
@@ -27,23 +27,29 @@ def get_dataset(config):

def get_generator(config, **params):
r"""Automatically select generator class based on config."""
if "t5" in config['generator_model'] or "bart" in config['generator_model']:
if config['framework'] == 'vllm':
return getattr(
importlib.import_module("flashrag.generator"),
"EncoderDecoderGenerator"
)(config, **params)
importlib.import_module("flashrag.generator"),
"VLLMGenerator"
)(config, **params)
elif config['framework'] == 'fschat':
return getattr(
importlib.import_module("flashrag.generator"),
"FastChatGenerator"
)(config, **params)
else:
if config['use_vllm']:
if "t5" in config['generator_model'] or "bart" in config['generator_model']:
return getattr(
importlib.import_module("flashrag.generator"),
"VLLMGenerator"
"EncoderDecoderGenerator"
)(config, **params)
else:
return getattr(
importlib.import_module("flashrag.generator"),
"CausalLMGenerator"
"HFCausalLMGenerator"
)(config, **params)


def get_retriever(config):
r"""Automatically select retriever class based on config's retrieval method
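For readability, the dispatch that get_generator performs after this change can be restated on its own (a simplified mirror of the code above, not part of the library):

def select_generator_class(framework: str, generator_model: str) -> str:
    """Mirror of get_generator's dispatch in flashrag/utils/utils.py after this commit."""
    if framework == 'vllm':
        return "VLLMGenerator"
    if framework == 'fschat':
        return "FastChatGenerator"
    # 'hf' framework: encoder-decoder models get a dedicated generator class
    if "t5" in generator_model or "bart" in generator_model:
        return "EncoderDecoderGenerator"
    return "HFCausalLMGenerator"

assert select_generator_class('hf', 'llama2-7B-chat') == "HFCausalLMGenerator"
assert select_generator_class('hf', 'bart-large-cnn') == "EncoderDecoderGenerator"
assert select_generator_class('vllm', 'llama2-7B-chat') == "VLLMGenerator"
assert select_generator_class('fschat', 'flan-t5-large') == "FastChatGenerator"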
