Fix unused imports and trailing whitespaces with ruff
perkfly committed Jun 5, 2024
1 parent bc4a97e commit fed0f0f
Showing 31 changed files with 345 additions and 373 deletions.
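The exact command isn't recorded in the commit, but ruff's F401 rule flags unused imports and W291/W293 flag trailing whitespace, so an invocation along the following lines would reproduce this kind of diff. A sketch under those assumptions:

```python
# Hypothetical reconstruction -- the commit message only says "with ruff".
# F401 = unused import, W291 = trailing whitespace,
# W293 = whitespace on a blank line; all three are auto-fixable.
import subprocess

subprocess.run(
    ["ruff", "check", ".", "--select", "F401,W291,W293", "--fix"],
    check=True,  # raises CalledProcessError if unfixable violations remain
)
```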
18 changes: 9 additions & 9 deletions examples/methods/run_exp.py
@@ -16,7 +16,7 @@ def naive(args):

pred_process_fun = lambda x: x.split("\n")[0]
pipeline = SequentialPipeline(config)

result = pipeline.run(test_data)

def zero_shot(args):
@@ -33,7 +33,7 @@ def zero_shot(args):
from flashrag.pipeline import SequentialPipeline
from flashrag.prompt import PromptTemplate
templete = PromptTemplate(
- config = config,
+ config = config,
system_prompt = "Answer the question based on your own knowledge. Only give me the answer and do not output any other words.",
user_prompt = "Question: {question}"
)
@@ -53,9 +53,9 @@ def aar(args):
# index path of this retriever
retrieval_method = args.method_name
if 'contriever' in retrieval_method:
index_path = "aar-contriever_Flat.index"
index_path = "aar-contriever_Flat.index"
else:
index_path = "aar-ance_Flat.index"
index_path = "aar-ance_Flat.index"

model2path = {"AAR-contriever": "model/AAR-Contriever-KILT",
"AAR-ANCE": "model/AAR-ANCE"}
@@ -92,7 +92,7 @@ def llmlingua(args):
in ICLR MEFoMo 2024.
Official repo: https://github.com/microsoft/LLMLingua
"""
refiner_name = "longllmlingua" #
refiner_name = "longllmlingua" #
refiner_model_path = "model/llama-2-7b-hf"

config_dict = {
@@ -183,7 +183,7 @@ def sc(args):
pip install en_core_web_sm-3.6.0.tar.gz
```
"""
refiner_name = "selective-context"
refiner_name = "selective-context"
refiner_model_path = "model/gpt2"

config_dict = {
@@ -239,7 +239,7 @@ def retrobust(args):
from flashrag.pipeline import SelfAskPipeline
from flashrag.utils import selfask_pred_parse
pipeline = SelfAskPipeline(config, max_iter=5, single_hop=False)
- # use specify prediction parse function
+ # use specify prediction parse function
result = pipeline.run(test_data, pred_process_fun=selfask_pred_parse)

def sure(args):
@@ -349,7 +349,7 @@ def selfrag(args):
test_data = all_split[args.split]

from flashrag.pipeline import SelfRAGPipeline
- pipeline = SelfRAGPipeline(config, threhsold=0.2, max_depth=2, beam_width=2,
+ pipeline = SelfRAGPipeline(config, threhsold=0.2, max_depth=2, beam_width=2,
w_rel=1.0, w_sup=1.0, w_use=1.0,
use_grounding=True, use_utility=True, use_seqscore=True, ignore_cont=True,
mode='adaptive_retrieval')
@@ -363,7 +363,7 @@ def flare(args):
Official repo: https://github.com/bbuing9/ICLR24_SuRe
"""
- config_dict={'save_note':'flare', 'gpu_id':args.gpu_id,
+ config_dict={'save_note':'flare', 'gpu_id':args.gpu_id,
'dataset_name':args.dataset_name}
config = Config('my_config.yaml',config_dict)
all_split = get_dataset(config)
6 changes: 3 additions & 3 deletions examples/quick_start/simple_pipeline.py
@@ -7,9 +7,9 @@
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str)
parser.add_argument('--retriever_path', type=str)
- args = parser.parse_args()
+ args = parser.parse_args()

- config_dict = {
+ config_dict = {
'data_dir': 'dataset/',
'index_path': 'indexes/e5_flat_sample.index',
'corpus_path': 'indexes/sample_data.jsonl',
@@ -26,7 +26,7 @@
all_split = get_dataset(config)
test_data = all_split['test']
prompt_templete = PromptTemplate(
- config,
+ config,
system_prompt = "Answer the question based on the given document. \
Only give me the answer and do not output any other words. \
\nThe following are given documents.\n\n{reference}",
3 changes: 0 additions & 3 deletions flashrag/__init__.py
@@ -1,3 +0,0 @@
- from __future__ import absolute_import
- from __future__ import print_function
- from __future__ import division
30 changes: 13 additions & 17 deletions flashrag/config/config.py
@@ -1,11 +1,7 @@
import re
import os
import yaml
- import json
import random
- import copy
- import importlib
- import sys
import datetime

class Config:
@@ -60,7 +56,7 @@ def _load_file_config(self, config_file_path:str):
def _update_dict(old_dict: dict, new_dict: dict):
# Update the original update method of the dictionary:
# If there is the same key in `old_dict` and `new_dict`, and value is of type dict, update the key in dict

same_keys = []
for key,value in new_dict.items():
if key in old_dict and isinstance(value, dict):
@@ -73,7 +69,7 @@ def _update_dict(old_dict: dict, new_dict: dict):

old_dict.update(new_dict)
return old_dict


def _merge_external_config(self):
external_config = dict()
@@ -84,11 +80,11 @@ def _merge_external_config(self):

def _get_internal_config(self):
current_path = os.path.dirname(os.path.realpath(__file__))
- init_config_path = os.path.join(current_path, "basic_config.yaml")
+ init_config_path = os.path.join(current_path, "basic_config.yaml")
internal_config = self._load_file_config(init_config_path)

return internal_config

def _get_final_config(self):
final_config = dict()
final_config = self._update_dict(final_config, self.internal_config)
@@ -114,7 +110,7 @@ def _init_device(self):
else:
import torch
self.final_config['device'] = torch.device('cpu')


def _set_additional_key(self):
# set dataset
@@ -135,7 +131,7 @@ def _set_additional_key(self):
self.final_config['index_path'] = method2index[retrieval_method]
except:
print("Index is empty!!")
- assert False
+ assert False

self.final_config['retrieval_model_path'] = model2path.get(retrieval_method, retrieval_method)
# TODO: not support when `retrieval_model` is path
@@ -148,12 +144,12 @@ def set_pooling_method(method, model2pooling):

if self.final_config.get('retrieval_pooling_method') is None:
self.final_config['retrieval_pooling_method'] = set_pooling_method(retrieval_method, model2pooling)


rerank_model_name = self.final_config['rerank_model_name']
if self.final_config.get('rerank_model_path') is None:
if rerank_model_name is not None:
- self.final_config['rerank_model_path'] = model2path.get(rerank_model_name, rerank_model_name)
+ self.final_config['rerank_model_path'] = model2path.get(rerank_model_name, rerank_model_name)
if self.final_config['rerank_pooling_method'] is None:
if rerank_model_name is not None:
self.final_config['rerank_pooling_method'] = set_pooling_method(
@@ -171,7 +167,7 @@ def set_pooling_method(method, model2pooling):
def _prepare_dir(self):
save_note = self.final_config['save_note']
current_time = datetime.datetime.now()
- self.final_config['save_dir'] = os.path.join(self.final_config['save_dir'],
+ self.final_config['save_dir'] = os.path.join(self.final_config['save_dir'],
f"{self.final_config['dataset_name']}_{current_time.strftime('%Y_%m_%d_%H_%M')}_{save_note}")
os.makedirs(self.final_config['save_dir'], exist_ok=True)
# save config parameters
@@ -190,9 +186,9 @@ def _set_seed(self):
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True





def __setitem__(self, key, value):
if not isinstance(key, str):
raise TypeError("index must be a str.")
@@ -201,7 +197,7 @@ def __setitem__(self, key, value):
def __getattr__(self, item):
if "final_config" not in self.__dict__:
raise AttributeError(
f"'Config' object has no attribute 'final_config'"
"'Config' object has no attribute 'final_config'"
)
if item in self.final_config:
return self.final_config[item]
@@ -214,6 +210,6 @@ def __contains__(self, key):
if not isinstance(key, str):
raise TypeError("index must be a str.")
return key in self.final_config

def __repr__(self):
return self.final_config.__str__()
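For reference, the nested-merge behavior described by the `_update_dict` comment above can be demonstrated standalone. This is a sketch reconstructed from the visible hunk, not a verbatim copy of FlashRAG's code, and the config keys in the example are made up:

```python
def update_dict(old_dict: dict, new_dict: dict) -> dict:
    # When a key exists in both dicts and both values are dicts,
    # merge recursively instead of letting new_dict overwrite wholesale.
    new_dict = dict(new_dict)  # shallow copy so the caller's dict is untouched
    for key, value in new_dict.items():
        if key in old_dict and isinstance(value, dict):
            new_dict[key] = update_dict(old_dict[key], value)
    old_dict.update(new_dict)
    return old_dict

base = {"retriever": {"topk": 5, "model": "e5"}, "seed": 0}
override = {"retriever": {"topk": 10}, "save_note": "exp1"}
print(update_dict(base, override))
# -> {'retriever': {'topk': 10, 'model': 'e5'}, 'seed': 0, 'save_note': 'exp1'}
```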
21 changes: 10 additions & 11 deletions flashrag/dataset/dataset.py
@@ -33,7 +33,7 @@ def update_evaluation_score(self, metric_name, metric_score):
if 'metric_score' not in self.output:
self.output['metric_score'] = {}
self.output['metric_score'][metric_name] = metric_score

def __getattr__(self, attr_name):
if attr_name in ['id','question','golden_answers','metadata','output']:
return super().__getattribute__(attr_name)
@@ -62,14 +62,14 @@ def to_dict(self):
output['metadata'] = self.metadata

return output


class Dataset:
"""A container class used to store the whole dataset. Inside the class, each data sample will be stored
in ```Item``` class.
The properties of the dataset represent the list of attributes corresponding to each item in the dataset.
"""

def __init__(self, config=None, dataset_path=None, data=None, sample_num = None, random_sample = False):
self.config = config
self.dataset_name = config['dataset_name']
@@ -104,7 +104,7 @@ def _load_data(self, dataset_name, dataset_path):
data = data[:self.sample_num]

return data

def update_output(self, key, value_list):
"""Update the overall output field for each sample in the dataset."""

@@ -131,7 +131,7 @@ def get_batch_data(self, attr_name:str, batch_size: int):
for i in range(0, len(self.data), batch_size):
batch_items = self.data[i:i+batch_size]
yield [item[attr_name] for item in batch_items]

def __getattr__(self, attr_name):
return [item.__getattr__(attr_name) for item in self.data]

@@ -141,24 +141,23 @@ def get_attr_data(self, attr_name):
obtain a list of this attribute in the entire dataset.
"""
return [item[attr_name] for item in self.data]

def __getitem__(self, index):
return self.data[index]

def __len__(self):
return len(self.data)

def save(self, save_path):
"""Save the dataset into the original format."""

save_data = [item.to_dict() for item in self.data]
with open(save_path,"w") as f:
json.dump(save_data, f, indent=4)








10 changes: 5 additions & 5 deletions flashrag/dataset/utils.py
@@ -11,7 +11,7 @@ def filter_dataset(dataset: Dataset, filter_func = None):

def split_dataset(dataset: Dataset, split_bool:list):
assert len(split_bool) == len(dataset)

data = dataset.data
pos_data = [x for x,flag in zip(data,split_bool) if flag]
neg_data = [x for x,flag in zip(data,split_bool) if not flag]
@@ -26,17 +26,17 @@ def merge_dataset(pos_dataset: Dataset, neg_dataset: Dataset, merge_bool: list):

pos_data_iter = iter(pos_dataset.data)
neg_data_iter = iter(neg_dataset.data)

final_data = []

for is_pos in merge_bool:
if is_pos:
final_data.append(next(pos_data_iter))
else:
final_data.append(next(neg_data_iter))

final_dataset = Dataset(config=pos_dataset.config, data=final_data)

return final_dataset

def get_batch_dataset(dataset: Dataset, batch_size=16):
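The `split_dataset`/`merge_dataset` pair above partitions samples with a boolean mask and later re-interleaves the two halves in their original order. A minimal sketch of the same round-trip using plain lists in place of FlashRAG's `Dataset` (the names here are illustrative):

```python
def split(data: list, mask: list) -> tuple:
    pos = [x for x, flag in zip(data, mask) if flag]
    neg = [x for x, flag in zip(data, mask) if not flag]
    return pos, neg

def merge(pos: list, neg: list, mask: list) -> list:
    # Consume each half in order, picking the source stream per original flag.
    pos_iter, neg_iter = iter(pos), iter(neg)
    return [next(pos_iter) if flag else next(neg_iter) for flag in mask]

mask = [True, False, False, True]
pos, neg = split(["a", "b", "c", "d"], mask)
assert merge(pos, neg, mask) == ["a", "b", "c", "d"]
```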
16 changes: 7 additions & 9 deletions flashrag/evaluator/evaluator.py
@@ -1,6 +1,4 @@
import os
- import json
- import sys
from flashrag.evaluator.metrics import BaseMetric

class Evaluator:
@@ -9,7 +7,7 @@ class Evaluator:
def __init__(self, config):
self.config = config
self.save_dir = config['save_dir']

self.save_metric_flag = config['save_metric_score']
self.save_data_flag = config['save_intermediate_data']
self.metrics = [metric.lower() for metric in self.config['metrics']]
@@ -23,7 +21,7 @@ def __init__(self, config):
else:
print(f"{metric} has not been implemented!")
raise NotImplementedError

def _collect_metrics(self):
"""Collect all classes based on ```BaseMetric```."""

@@ -37,7 +35,7 @@ def find_descendants(base_class, subclasses=None):
subclasses.add(subclass)
find_descendants(subclass, subclasses)
return subclasses

avaliable_metrics = {}
for cls in find_descendants(BaseMetric):
metric_name = cls.metric_name
@@ -52,23 +50,23 @@ def evaluate(self, data):
try:
metric_result, metric_scores = self.metric_class[metric].calculate_metric(data)
result_dict.update(metric_result)

for metric_score, item in zip(metric_scores, data):
item.update_evaluation_score(metric, metric_score)
except Exception as e:
print(f'Error in {metric}!')
print(e)
continue

if self.save_metric_flag:
self.save_metric_score(result_dict)

if self.save_data_flag:
self.save_data(data)


return result_dict

def save_metric_score(self, result_dict):
file_name = "metric_score.txt"
save_path = os.path.join(self.save_dir, file_name)
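The `_collect_metrics` hunk above discovers metric implementations by walking `__subclasses__()` recursively rather than keeping a manual registry. A self-contained sketch of that pattern with stand-in classes (FlashRAG's real metrics live in `flashrag.evaluator.metrics`):

```python
class BaseMetric:
    metric_name = "base"

class ExactMatch(BaseMetric):
    metric_name = "em"

class SubExactMatch(ExactMatch):  # grandchild classes are found too
    metric_name = "sub_em"

def find_descendants(base_class, subclasses=None):
    # Recursively collect every direct and transitive subclass.
    if subclasses is None:
        subclasses = set()
    for subclass in base_class.__subclasses__():
        if subclass not in subclasses:
            subclasses.add(subclass)
            find_descendants(subclass, subclasses)
    return subclasses

print({cls.metric_name for cls in find_descendants(BaseMetric)})
# -> {'em', 'sub_em'}
```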
(Diffs for the remaining 24 changed files were not loaded.)
