Implemented alternative experiment tracking functionality #102

Merged: 9 commits, Nov 7, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -111,8 +111,9 @@ data/
# Models
models/

# Weights and Biases experiment tracking
# Experiment tracking
wandb/
mlruns/

# Data files
*.xlsx
3 changes: 3 additions & 0 deletions README.md
@@ -13,7 +13,10 @@ ______________________________________________________________________

Developers:

- Anders Jess Pedersen ([email protected])
- Dan Saattrup Nielsen ([email protected])
- Simon Leminen Madsen ([email protected])



## Installation
9 changes: 4 additions & 5 deletions config/asr_finetuning.yaml
@@ -1,16 +1,19 @@
defaults:
- model: wav2vec2-small
- model: whisper-xxsmall
- datasets:
- coral
- decoder_datasets:
- wikipedia
- common_voice
- reddit
- experiment_tracking: wandb
- override hydra/job_logging: custom
- _self_

seed: 4242

experiment_tracking: null

evaluation_dataset:
id: alexandrainst/coral
subset: read_aloud
@@ -48,10 +51,6 @@ fp16_allowed: true
bf16_allowed: true

# Training parameters
wandb: false
wandb_project: CoRal
wandb_group: default
wandb_name: ${model_id}
resume_from_checkpoint: false
ignore_data_skip: false
save_total_limit: 0 # Will automatically be set to >=1 if `early_stopping` is enabled
3 changes: 3 additions & 0 deletions config/experiment_tracking/mlflow.yaml
@@ -0,0 +1,3 @@
type: mlflow
name_experiment: CoRal
name_run: ${model_id}
4 changes: 4 additions & 0 deletions config/experiment_tracking/wandb.yaml
@@ -0,0 +1,4 @@
type: wandb
name_experiment: CoRal
name_run: ${model_id}
name_group: default
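
Together, these two group files make the tracker a Hydra config group: the `- experiment_tracking: wandb` entry in the defaults list selects a backend, while `experiment_tracking: null` in the main config disables tracking (finetune.py below checks `bool(config.experiment_tracking)`). As a hedged example, assuming the usual Hydra override syntax and an entry-point script such as src/scripts/finetune_asr_model.py (the entry point itself is not part of this diff), the backend can be swapped at launch time:

python src/scripts/finetune_asr_model.py experiment_tracking=mlflow
python src/scripts/finetune_asr_model.py experiment_tracking=wandb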
2,328 changes: 1,552 additions & 776 deletions poetry.lock

1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ gradio = {version = "^5.5.0", optional=true}
samplerate = {version="^0.2.1", optional=true}
punctfix = {version="^0.11.1", optional=true}
matplotlib = {version = "^3.9.2", optional = true}
mlflow = "^2.17.2"

[tool.poetry.group.dev.dependencies]
pytest = ">=8.1.1"
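
Worth noting: the dependency is added unconditionally rather than as an optional extra, which seems deliberate, since the factory module below imports both backends at import time; an optional mlflow would break `load_extracking_setup` for wandb users too. The line plausibly came from an invocation like the following (assumed, not shown in the PR), which would also account for the large poetry.lock churn above:

poetry add "mlflow@^2.17.2"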
4 changes: 4 additions & 0 deletions src/coral/experiment_tracking/__init__.py
@@ -0,0 +1,4 @@
"""The CoRal project.

Experiment tracking.
"""
28 changes: 28 additions & 0 deletions src/coral/experiment_tracking/extracking_factory.py
@@ -0,0 +1,28 @@
"""Factory for experiment tracking setup."""

from omegaconf import DictConfig

from .extracking_setup import ExTrackingSetup
from .mlflow_setup import MLFlowSetup
from .wandb_setup import WandbSetup


def load_extracking_setup(config: DictConfig) -> ExTrackingSetup:
"""Return the experiment tracking setup.

Args:
config:
The configuration object.

Returns:
The experiment tracking setup.
"""
match config.experiment_tracking.type:
case "wandb":
return WandbSetup(config=config)
case "mlflow":
return MLFlowSetup(config=config)
case _:
raise ValueError(
f"Unknown experiment tracking type: {config.experiment_tracking.type}"
)
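
A minimal usage sketch of the factory, using only the config keys introduced in this PR (the run name is illustrative):

from omegaconf import OmegaConf

from coral.experiment_tracking.extracking_factory import load_extracking_setup

# Illustrative config mirroring config/experiment_tracking/mlflow.yaml.
config = OmegaConf.create(
    {
        "experiment_tracking": {
            "type": "mlflow",
            "name_experiment": "CoRal",
            "name_run": "demo-run",
        }
    }
)

setup = load_extracking_setup(config=config)
setup.run_initialization()  # starts the MLflow run
# ... training would happen here ...
setup.run_finalization()  # ends the MLflow run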
34 changes: 34 additions & 0 deletions src/coral/experiment_tracking/extracking_setup.py
@@ -0,0 +1,34 @@
"""This module contains the base class for an experiment tracking setup."""

from abc import ABC, abstractmethod

from omegaconf import DictConfig


class ExTrackingSetup(ABC):
"""Base class for an experiment tracking setup."""

@abstractmethod
def __init__(self, config: DictConfig) -> None:
"""Initialise the experiment tracking setup.

Args:
config:
The configuration object.
"""

    @abstractmethod
    def run_initialization(self) -> None:
        """Run the initialization of the experiment tracking setup."""

    @abstractmethod
    def run_finalization(self) -> None:
        """Run the finalization of the experiment tracking setup."""
33 changes: 33 additions & 0 deletions src/coral/experiment_tracking/mlflow_setup.py
@@ -0,0 +1,33 @@
"""MLFlow experiment tracking setup class."""

import os

import mlflow
from omegaconf import DictConfig

from .extracking_setup import ExTrackingSetup


class MLFlowSetup(ExTrackingSetup):
"""MLFlow setup class."""

def __init__(self, config: DictConfig) -> None:
"""Initialise the MLFlow setup.

Args:
config:
The configuration object.
"""
self.config = config
self.is_main_process = os.getenv("RANK", "0") == "0"

def run_initialization(self) -> None:
"""Run the initialization of the experiment tracking setup."""
mlflow.set_experiment(self.config.experiment_tracking.name_experiment)
mlflow.start_run(run_name=self.config.experiment_tracking.name_run)
return

def run_finalization(self) -> None:
"""Run the finalization of the experiment tracking setup."""
mlflow.end_run()
return
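
One operational note: without further configuration, MLflow writes runs to a local ./mlruns directory, which is exactly what the new .gitignore entry covers. Pointing runs at a shared tracking server is a one-line addition; the URI handling below is an assumption about deployment, not something this PR configures:

import os

import mlflow

# Use a remote tracking server when MLFLOW_TRACKING_URI is set; otherwise
# fall back to the default local ./mlruns file store.
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", "file:./mlruns"))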
37 changes: 37 additions & 0 deletions src/coral/experiment_tracking/wandb_setup.py
@@ -0,0 +1,37 @@
"""wandb experiment tracking setup class."""

import os

import wandb
from omegaconf import DictConfig

from .extracking_setup import ExTrackingSetup


class WandbSetup(ExTrackingSetup):
"""Wandb setup class."""

def __init__(self, config: DictConfig) -> None:
"""Initialise the Wandb setup.

Args:
config:
The configuration object.
"""
self.config = config
self.is_main_process = os.getenv("RANK", "0") == "0"

def run_initialization(self) -> None:
"""Run the initialization of the experiment tracking setup."""
wandb.init(
project=self.config.experiment_tracking.name_experiment,
name=self.config.experiment_tracking.name_run,
group=self.config.experiment_tracking.name_group,
config=dict(self.config),
)
return

def run_finalization(self) -> None:
"""Run the finalization of the experiment tracking setup."""
wandb.finish()
return
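
Both setup classes record is_main_process from the RANK environment variable, matching the guard used in finetune.py below so that only rank 0 opens a run during distributed training. Separately, for machines without network access, wandb supports offline logging; a hedged sketch using the client's standard environment variable, not something this PR sets:

import os

# Store runs locally for a later `wandb sync`; assumes the standard
# WANDB_MODE variable understood by the wandb client.
os.environ["WANDB_MODE"] = "offline"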
18 changes: 7 additions & 11 deletions src/coral/finetune.py
@@ -5,11 +5,10 @@

from omegaconf import DictConfig
from transformers import EarlyStoppingCallback, TrainerCallback
from wandb import finish as wandb_finish
from wandb.sdk.wandb_init import init as wandb_init

from .data import load_data_for_finetuning
from .data_models import ModelSetup
from .experiment_tracking.extracking_factory import load_extracking_setup
from .model_setup import load_model_setup
from .ngram import train_and_store_ngram_model
from .utils import block_terminal_output, disable_tqdm, push_model_to_hub
@@ -33,13 +32,9 @@ def finetune(config: DictConfig) -> None:
model = model_setup.load_model()
dataset = load_data_for_finetuning(config=config, processor=processor)

if config.wandb and is_main_process:
wandb_init(
project=config.wandb_project,
group=config.wandb_group,
name=config.wandb_name,
config=dict(config),
)
if bool(config.experiment_tracking) and is_main_process:
extracking_setup = load_extracking_setup(config=config)
extracking_setup.run_initialization()

if "val" not in dataset and is_main_process:
logger.info("No validation set found. Disabling early stopping.")
@@ -58,8 +53,9 @@ def finetune(config: DictConfig) -> None:
block_terminal_output()
with disable_tqdm():
trainer.train(resume_from_checkpoint=config.resume_from_checkpoint)
if config.wandb and is_main_process:
wandb_finish()

if bool(config.experiment_tracking) and is_main_process:
extracking_setup.run_finalization()

model.save_pretrained(save_directory=config.model_dir)

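
The truthiness guard works because Hydra resolves experiment_tracking: null to Python's None, while a selected config group arrives as a DictConfig. A quick illustration, assuming OmegaConf's standard semantics:

from omegaconf import OmegaConf

disabled = OmegaConf.create({"experiment_tracking": None})
assert not bool(disabled.experiment_tracking)  # tracking is off

enabled = OmegaConf.create({"experiment_tracking": {"type": "wandb"}})
assert bool(enabled.experiment_tracking)  # tracking is on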
4 changes: 3 additions & 1 deletion src/coral/wav2vec2.py
@@ -206,7 +206,9 @@ def load_training_arguments(self) -> TrainingArguments:
optim=OptimizerNames.ADAMW_TORCH,
adam_beta1=self.config.adam_first_momentum,
adam_beta2=self.config.adam_second_momentum,
report_to=["wandb"] if self.config.wandb else [],
report_to=[self.config.experiment_tracking.type]
if self.config.experiment_tracking
else [],
ignore_data_skip=self.config.ignore_data_skip,
save_safetensors=True,
use_cpu=hasattr(sys, "_called_from_test"),
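
Reusing experiment_tracking.type as the report_to value relies on the Hugging Face Trainer recognising the same strings; "wandb" and "mlflow" are both built-in Trainer reporting integrations, so the two shipped configs line up. Any new tracker type registered in the factory would also need a matching Trainer integration. A defensive mapping, sketched as an assumption rather than anything the PR implements:

# Hypothetical guard: only forward types the HF Trainer understands.
KNOWN_TRAINER_INTEGRATIONS = {"wandb", "mlflow"}


def report_to_for(tracking_type: str | None) -> list[str]:
    """Map the config's tracker type onto the Trainer's report_to list."""
    if tracking_type in KNOWN_TRAINER_INTEGRATIONS:
        return [tracking_type]
    return []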
4 changes: 3 additions & 1 deletion src/coral/whisper.py
@@ -190,7 +190,9 @@ def load_training_arguments(self) -> TrainingArguments:
optim=OptimizerNames.ADAMW_TORCH,
adam_beta1=self.config.adam_first_momentum,
adam_beta2=self.config.adam_second_momentum,
report_to=["wandb"] if self.config.wandb else [],
report_to=[self.config.experiment_tracking.type]
if self.config.experiment_tracking
else [],
ignore_data_skip=self.config.ignore_data_skip,
save_safetensors=True,
predict_with_generate=True,