Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utility to add orion tags to the final best_hparams.yaml file #45

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmarks/MOABB/run_hparam_optimization.sh
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,5 @@ scp $best_yaml_file $final_yaml_file

echo "The test performance with best hparams is available at $output_folder/best"

# add the orion flags to the best_hparams.yaml file
python utils/rewrite.py $hparams $final_yaml_file
78 changes: 78 additions & 0 deletions benchmarks/MOABB/utils/rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/python3
"""
Yaml file rewriter to add orion tags to the best hparams file from original yaml file.
Author
------
Victor Cruz, 2024
"""
import argparse
import yaml
import re

def readargs():
parser = argparse.ArgumentParser()
parser.add_argument("original_yaml_file", type=str, help="Original yaml file")
parser.add_argument("best_hparams_file", type=str, help="Best hparams file")
args = parser.parse_args()

# Check if the file paths are valid
if not args.original_yaml_file.endswith(".yaml"):
raise ValueError("Original yaml file must be a yaml file")
if not args.best_hparams_file.endswith(".yaml"):
raise ValueError("Best hparams file must be a yaml file")
return args

def extract_orion_tags(original_yaml_file):
"""
Function to extract orion tags and variable names from the original yaml file.
Orion tags are comments that start with '# @orion_step<stepid>'.
"""
orion_tags = {}
tag_pattern = re.compile(r"# @orion_step(\d+):\s*(.*)")

with open(original_yaml_file, "r") as og_f:
for line in og_f:
# Extract lines that contain Orion tags
tag_match = tag_pattern.search(line.strip())
if tag_match:
variable_name = line.split(":")[0].strip() # Get the variable name before ":"
tag_info = tag_match.group(0) # Full tag line
orion_tags[variable_name] = tag_info # Store variable and tag info
return orion_tags

def rewrite_with_orion_tags(original_yaml_file, best_hparams_file):
"""
Function to add orion tags to the best hparams file.
Matches based on the variable name from the original file to the target file.
"""
orion_tags = extract_orion_tags(original_yaml_file)

# Read the best_hparams YAML file
with open(best_hparams_file, "r") as best_f:
best_hparams_lines = best_f.readlines()

# Add orion tags to the appropriate lines in the new file
new_best_hparams_lines = []
for line in best_hparams_lines:
stripped_line = line.strip()
# Extract variable name from the line in the best hparams file
if ":" in stripped_line:
variable_name = stripped_line.split(":")[0].strip()

# Check if this variable has a corresponding orion tag
if variable_name in orion_tags:
# Append the orion tag to the same line, ensuring there's a space before the comment
line = line.rstrip() + " " + orion_tags[variable_name] + "\n"
new_best_hparams_lines.append(line)
else:
new_best_hparams_lines.append(line)
else:
new_best_hparams_lines.append(line)

# Write the modified content back to the best_hparams file
with open(best_hparams_file, "w") as best_f:
best_f.writelines(new_best_hparams_lines)

if __name__ == "__main__":
args = readargs()
rewrite_with_orion_tags(args.original_yaml_file, args.best_hparams_file)