Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions openstef/tasks/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def train_model_task(
datetime_start: Optional[datetime] = None,
datetime_end: Optional[datetime] = None,
ignore_existing_models: bool = DEFAULT_IGNORE_EXISTING_MODELS,
train_period_days: int = TRAINING_PERIOD_DAYS,
maximum_model_age: int = MAXIMUM_MODEL_AGE,
) -> None:
"""Train model task.

Expand All @@ -68,6 +70,9 @@ def train_model_task(
check_old_model_age: check if model is too young to be retrained
datetime_start: Start
datetime_end: End
ignore_existing_models: Ignore existing models when training
train_period_days: Number of days to fetch for training
maximum_model_age: Maximum model age in days to skip retraining

Raises:
SkipSaveTrainingForecasts: If old model is better or too young, you don't need to save the traing forcast.
Expand Down Expand Up @@ -113,19 +118,19 @@ def train_model_task(
)

# Check old model age and continue yes/no
if (old_model_age < MAXIMUM_MODEL_AGE) and check_old_model_age:
if (old_model_age < maximum_model_age) and check_old_model_age:
context.perf_meter.checkpoint(
f"Old model is younger than {MAXIMUM_MODEL_AGE} days, skip training"
f"Old model is younger than {maximum_model_age} days, skip training"
)
if pj.save_train_forecasts:
raise SkipSaveTrainingForecasts
return

# Define start and end of the training input data
training_period_days_to_fetch = (
TRAINING_PERIOD_DAYS
train_period_days
if pj.data_balancing_ratio is None
else int(pj.data_balancing_ratio * TRAINING_PERIOD_DAYS)
else int(pj.data_balancing_ratio * train_period_days)
)

if datetime_end is None:
Expand Down