sillm/dpo.py: 6 additions & 0 deletions

@@ -37,6 +37,9 @@
parser.add_argument("-q4", default=False, action="store_true", help="Quantize the model to 4 bits")
parser.add_argument("-q8", default=False, action="store_true", help="Quantize the model to 8 bits")
parser.add_argument("-v", "--verbose", default=1, action="count", help="Increase output verbosity")
parser.add_argument("--wandb", default=False, action="store_true", help="Log training to Weights & Biases")
parser.add_argument("--wandb_project", default=None, type=str, help="Weights & Biases project name")
parser.add_argument("--wandb_report_steps", default=5, type=int, help="Number of batch iterations per Weights & Biases report (default: 5)")
args = parser.parse_args()

# Change working directory
@@ -133,6 +136,9 @@ def eval_callback(i, val_loss):
"report_steps": args.report_steps,
"eval_steps": args.eval_steps,
"validation_samples": args.validation_samples,
"wandb": args.wandb,
"wandb_project": args.wandb_project,
"wandb_report_steps": args.wandb_report_steps,
}
model.train(dataset_training,
dataset_validation,
sillm/lora.py: 6 additions & 0 deletions

@@ -43,6 +43,9 @@
parser.add_argument("-q4", default=False, action="store_true", help="Quantize the model to 4 bits")
parser.add_argument("-q8", default=False, action="store_true", help="Quantize the model to 8 bits")
parser.add_argument("-v", "--verbose", default=1, action="count", help="Increase output verbosity")
parser.add_argument("--wandb", default=False, action="store_true", help="Log training to Weights & Biases")
parser.add_argument("--wandb_project", default=None, type=str, help="Weights & Biases project name")
parser.add_argument("--wandb_report_steps", default=5, type=int, help="Number of batch iterations per Weights & Biases report (default: 5)")
args = parser.parse_args()

# Change working directory
@@ -146,6 +149,9 @@ def eval_callback(i, val_loss):
"report_steps": args.report_steps,
"eval_steps": args.eval_steps,
"validation_samples": args.validation_samples,
"wandb": args.wandb,
"wandb_project": args.wandb_project,
"wandb_report_steps": args.wandb_report_steps,
}
model.train(dataset_training,
dataset_validation,
sillm/reporting/__init__.py (new file): 7 additions

@@ -0,0 +1,7 @@
from sillm.reporting import wandb
from sillm.reporting.wandb import WandBLogger

__all__ = [
"wandb",
"WandBLogger"
]
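
For reference, the new package re-exports the logger, so either import path below resolves to the same class. A minimal sketch, assuming the external `wandb` package is installed (the submodule imports it at load time):

```python
# Both paths resolve to the same class, per the re-exports in __init__.py above.
from sillm.reporting import WandBLogger
import sillm.reporting.wandb as reporting_wandb

assert WandBLogger is reporting_wandb.WandBLogger
```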
sillm/reporting/wandb.py (new file): 135 additions

@@ -0,0 +1,135 @@
import os
import logging
from typing import Any, Dict, Optional, Union

import wandb
import numpy as np
import mlx.core as mx

logger = logging.getLogger("sillm")

class WandBLogger:
    """
    Weights & Biases logger for tracking training metrics and experiments.
    """
    def __init__(self,
                 project: str = "sillm",
                 name: Optional[str] = None,
                 config: Optional[Dict[str, Any]] = None,
                 tags: Optional[list] = None,
                 resume: bool = False,
                 id: Optional[str] = None
                 ):
        """
        Initialize WandB logger.

        Args:
            project: WandB project name
            name: Run name (optional)
            config: Run configuration/hyperparameters (optional)
            tags: Run tags (optional)
            resume: Whether to resume a previous run
            id: Run ID to resume (optional)
        """
        self.run = wandb.init(
            project=project,
            name=name,
            config=config,
            tags=tags,
            resume=resume,
            id=id
        )

        self._step = 0
        logger.info(f"Initialized WandB logger - Project: {project}, Run: {self.run.name}")

    @property
    def step(self) -> int:
        """Current logging step."""
        return self._step

    def _convert_value(self, value: Any) -> Any:
        """Convert MLX arrays and other types to standard Python types."""
        if isinstance(value, mx.array):
            return value.item()
        elif isinstance(value, np.ndarray):
            return value.tolist()
        return value

    def log(self,
            metrics: Dict[str, Any],
            step: Optional[int] = None,
            commit: bool = True
            ):
        """
        Log metrics to WandB.

        Args:
            metrics: Dictionary of metric names and values
            step: Optional step number (default: auto-increment)
            commit: Whether to commit the logs immediately
        """
        if step is not None:
            self._step = step

        # Convert any MLX arrays to Python types
        metrics = {k: self._convert_value(v) for k, v in metrics.items()}

        # Log metrics
        wandb.log(metrics, step=self._step, commit=commit)

        if commit:
            self._step += 1

    def log_hyperparams(self, params: Dict[str, Any]):
        """
        Log hyperparameters to WandB config.

        Args:
            params: Dictionary of hyperparameter names and values
        """
        # Convert any MLX arrays to Python types
        params = {k: self._convert_value(v) for k, v in params.items()}

        # Update wandb config
        wandb.config.update(params, allow_val_change=True)

    def log_model(self,
                  artifact_name: str,
                  path: str,
                  metadata: Optional[Dict[str, Any]] = None,
                  type: str = "model"
                  ):
        """
        Log model files/weights as a WandB artifact.

        Args:
            artifact_name: Name for the artifact
            path: Path to model file/directory
            metadata: Optional metadata to attach to artifact
            type: Artifact type (default: "model")
        """
        artifact = wandb.Artifact(
            name=artifact_name,
            type=type,
            metadata=metadata
        )

        if os.path.isfile(path):
            artifact.add_file(path)
        else:
            artifact.add_dir(path)

        self.run.log_artifact(artifact)

    def finish(self):
        """End the WandB run."""
        if self.run is not None:
            self.run.finish()
            self.run = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.finish()
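
As a usage illustration of the class above, here is a minimal sketch. The run name, metric keys, and checkpoint path are hypothetical, and it assumes the `wandb` package is installed and authenticated (e.g. via `wandb login`):

```python
from sillm.reporting import WandBLogger

# Hypothetical run name, metrics, and checkpoint path, for illustration only.
with WandBLogger(project="sillm",
                 name="lora-demo",
                 config={"learning_rate": 1e-5, "batch_size": 4}) as wb:
    wb.log({"train/loss": 1.234})           # step auto-increments when commit=True
    wb.log({"val/loss": 1.001}, step=100)   # or pin an explicit step
    wb.log_model("adapter-demo",
                 "checkpoints/adapter.safetensors",
                 metadata={"iterations": 100})
# Leaving the `with` block calls finish(), which ends the run.
```

The context-manager form mirrors how a caller can ensure the run is closed even if training raises.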
sillm/training/trainer.py: 60 additions & 6 deletions

@@ -14,6 +14,7 @@

from sillm.core.llm import LLM
from sillm.training.dataset import Dataset
from sillm.reporting import WandBLogger

logger = logging.getLogger("sillm")

@@ -102,7 +103,10 @@ def train(self,
report_callback: callable = None,
eval_steps: int = 100,
eval_callback: callable = None,
validation_samples: int = 40
validation_samples: int = 40,
wandb: bool = False,
wandb_project: str = "sillm",
wandb_report_steps: int = 5
):
"""
Train model.
@@ -118,6 +122,9 @@
eval_callback: Callback after eval.
validation_samples: Number of validation samples.
debug: Whether to enable debug mode.
wandb: Whether to enable WandB logging.
wandb_project: WandB project name.
wandb_report_steps: Report every `wandb_report_steps` iterations to WandB.
"""
# Calculate number of iterations
if iterations == 0:
@@ -131,6 +138,24 @@

# Get system memory
system_memory = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")

if wandb:
# Initialize WandB logger
wandb_config = {
"optimizer": optimizer_type,
"learning_rate": learning_rate,
"learning_decay": learning_decay,
"learning_warmup": learning_warmup,
"batch_size": batch_size,
"gradient_checkpointing": gradient_checkpointing,
"gradient_accumulation_steps": gradient_accumulation_steps,
"epochs": epochs,
"iterations": iterations
}
wandb_logger = WandBLogger(
project=wandb_project,
config=wandb_config
)

# Initialize scheduler
if learning_decay > 0.0:
@@ -156,6 +181,8 @@
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")

optimizer = optimizer.init(self.model.trainable_parameters())

# Initialize gradient accumulation
if gradient_accumulation_steps > 1:
accum_grad = None
@@ -253,10 +280,24 @@ def step(batch):
else:
rewards = np.vstack([rewards, reward])

# Calculate loss and timings
train_loss = np.mean(losses)
stop = time.perf_counter()

# Log to WandB if needed
if wandb and (n + 1) % wandb_report_steps == 0:
metrics = {
"train/loss": train_loss,
"train/tokens_per_sec": float(intv_tokens) / (stop - start),
"train/learning_rate": optimizer.learning_rate.item()
}
if rewards is not None:
metrics["train/reward_chosen"] = np.mean(rewards, axis=0)[0]
metrics["train/reward_rejected"] = np.mean(rewards, axis=0)[1]
wandb_logger.log(metrics)

# Report training loss if needed
if (n + 1) % report_steps == 0:
train_loss = np.mean(losses)
stop = time.perf_counter()

# Print training loss and timings
pbar_epochs.write(f"#{n + 1}:\tTraining loss {train_loss:.3f}\t{float(intv_tokens) / (stop - start):.3f} tok/sec (learning rate: {optimizer.learning_rate.item():.3e})")
@@ -286,6 +327,12 @@ def step(batch):
# Print validation loss and timings
stop = time.perf_counter()
val_loss = self.evaluate(dataset_validation, batch_size, validation_batches)

# Log to WandB
if wandb:
wandb_logger.log({
"val/loss": val_loss
})
start = time.perf_counter()
pbar_epochs.write(f"#{n + 1}:\tValidation loss {val_loss:.3f}\t{(start - stop):.3f} sec")

@@ -300,8 +347,15 @@ def step(batch):
# Evaluate test dataset
if dataset_test is not None:
test_batches = validation_samples // batch_size

stop = time.perf_counter()
test_loss = self.evaluate(dataset_test, batch_size, test_batches)

stop = time.perf_counter()
start = time.perf_counter()
logger.info(f"Test loss: {test_loss:.3f}\t{start - stop:.3f} sec")
logger.info(f"Test loss: {test_loss:.3f}\t{start - stop:.3f} sec")

# Log to WandB
if wandb:
wandb_logger.log({
"test/loss": test_loss
})
wandb_logger.finish()
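
Putting the pieces together, the new CLI flags are forwarded through the training config into `train()`. Below is a minimal sketch of enabling the logger programmatically; `model`, `dataset_training`, and `dataset_validation` are placeholders for an already-loaded trainable model and datasets, and only the wandb-related keyword arguments come from the new signature above:

```python
# Placeholders: `model` is a trainable SiLLM model and the datasets come from
# the usual SiLLM data-loading path; both are assumed to exist already.
model.train(dataset_training,
            dataset_validation,
            batch_size=4,
            epochs=1,
            wandb=True,                  # enable Weights & Biases logging
            wandb_project="sillm",       # W&B project name
            wandb_report_steps=5)        # log training metrics every 5 iterations
```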