diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index 2616057a..878ba016 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -9,7 +9,7 @@ from pprint import pformat, pprint from tempfile import TemporaryDirectory from typing import Dict, Optional, Tuple, Union - +import re import click import math import numpy as np @@ -1166,7 +1166,10 @@ def best_training_run( config = yaml.safe_load(f) with open(config_file_out, "w") as f: - yaml.dump({"model": config["model"]}, f) + yaml.dump({"model": config["model"], + "rare_variant_annotations": config["training_data"]["dataset_config"]["rare_embedding"]["config"]["annotations"], + "training_data_thresholds": {k: str(re.sub(f"^{k} ", "", v)) for k,v in config["training_data"]["dataset_config"]["rare_embedding"]["config"]["thresholds"].items()} + }, f) n_bags = config["training"]["n_bags"] if not debug else 3 for k in range(n_bags):