diff --git a/openstef/tasks/train_model.py b/openstef/tasks/train_model.py index 98e078268..855102e0e 100644 --- a/openstef/tasks/train_model.py +++ b/openstef/tasks/train_model.py @@ -53,6 +53,8 @@ def train_model_task( datetime_start: Optional[datetime] = None, datetime_end: Optional[datetime] = None, ignore_existing_models: bool = DEFAULT_IGNORE_EXISTING_MODELS, + train_period_days: int = TRAINING_PERIOD_DAYS, + maximum_model_age: int = MAXIMUM_MODEL_AGE, ) -> None: """Train model task. @@ -68,6 +70,9 @@ def train_model_task( check_old_model_age: check if model is too young to be retrained datetime_start: Start datetime_end: End + ignore_existing_models: Ignore existing models when training + train_period_days: Number of days to fetch for training + maximum_model_age: Maximum model age in days to skip retraining Raises: SkipSaveTrainingForecasts: If old model is better or too young, you don't need to save the traing forcast. @@ -113,9 +118,9 @@ def train_model_task( ) # Check old model age and continue yes/no - if (old_model_age < MAXIMUM_MODEL_AGE) and check_old_model_age: + if (old_model_age < maximum_model_age) and check_old_model_age: context.perf_meter.checkpoint( - f"Old model is younger than {MAXIMUM_MODEL_AGE} days, skip training" + f"Old model is younger than {maximum_model_age} days, skip training" ) if pj.save_train_forecasts: raise SkipSaveTrainingForecasts @@ -123,9 +128,9 @@ def train_model_task( # Define start and end of the training input data training_period_days_to_fetch = ( - TRAINING_PERIOD_DAYS + train_period_days if pj.data_balancing_ratio is None - else int(pj.data_balancing_ratio * TRAINING_PERIOD_DAYS) + else int(pj.data_balancing_ratio * train_period_days) ) if datetime_end is None: