Skip to content

Commit

Permalink
fix(config/dataset): fix save directory and seed setting issues
Browse files Browse the repository at this point in the history
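- Fix save directory path handling in the Config class (`_prepare_dir`): make sure save_dir ends with "/" before joining the per-run sub-directory
- Add seed validation in the Config class (`_set_seed`): convert the seed to an integer, falling back to 2025 if the conversion fails
- Ensure sample_num is converted to an integer before sampling in the Dataset class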
  • Loading branch information
ignorejjj committed Jan 8, 2025
1 parent 82dca4c commit 20cf8f6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
15 changes: 12 additions & 3 deletions flashrag/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,14 @@ def set_retrieval_keys(model2path, model2pooling, method2index, config):

def _prepare_dir(self):
save_note = self.final_config["save_note"]
save_dir = self.final_config['save_dir']
if not save_dir.endswith("/"):
save_dir += "/"

current_time = datetime.datetime.now()

self.final_config["save_dir"] = os.path.join(
self.final_config["save_dir"],
save_dir,
f"{self.final_config['dataset_name']}_{current_time.strftime('%Y_%m_%d_%H_%M')}_{save_note}",
)
os.makedirs(self.final_config["save_dir"], exist_ok=True)
Expand All @@ -251,8 +256,12 @@ def _prepare_dir(self):
def _set_seed(self):
import torch
import numpy as np

seed = self.final_config["seed"]
seed = self.final_config['seed']
try:
seed = int(seed)
except:
seed = 2025
self.final_config['seed'] = seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
Expand Down
1 change: 1 addition & 0 deletions flashrag/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _load_data(self, dataset_name: str, dataset_path: str) -> List[Item]:
raise NotImplementedError

if self.sample_num is not None:
self.sample_num = int(self.sample_num)
if self.random_sample:
print(f"Random sample {self.sample_num} items in test set.")
data = random.sample(data, self.sample_num)
Expand Down

0 comments on commit 20cf8f6

Please sign in to comment.