diff --git a/sillm/dpo.py b/sillm/dpo.py index 97c0cfc..fd6ea99 100644 --- a/sillm/dpo.py +++ b/sillm/dpo.py @@ -37,6 +37,9 @@ parser.add_argument("-q4", default=False, action="store_true", help="Quantize the model to 4 bits") parser.add_argument("-q8", default=False, action="store_true", help="Quantize the model to 8 bits") parser.add_argument("-v", "--verbose", default=1, action="count", help="Increase output verbosity") + parser.add_argument("--wandb", default=False, action="store_true", help="Log training to Weights & Biases") + parser.add_argument("--wandb_project", default=None, type=str, help="Weights & Biases project name") + parser.add_argument("--wandb_report_steps", default=5, type=int, help="Number of batch iterations per Weights & Biases report (default: 5)") args = parser.parse_args() # Change working directory @@ -133,6 +136,9 @@ def eval_callback(i, val_loss): "report_steps": args.report_steps, "eval_steps": args.eval_steps, "validation_samples": args.validation_samples, + "wandb": args.wandb, + "wandb_project": args.wandb_project, + "wandb_report_steps": args.wandb_report_steps, } model.train(dataset_training, dataset_validation, diff --git a/sillm/lora.py b/sillm/lora.py index 1d5d7fe..fad55ef 100644 --- a/sillm/lora.py +++ b/sillm/lora.py @@ -43,6 +43,9 @@ parser.add_argument("-q4", default=False, action="store_true", help="Quantize the model to 4 bits") parser.add_argument("-q8", default=False, action="store_true", help="Quantize the model to 8 bits") parser.add_argument("-v", "--verbose", default=1, action="count", help="Increase output verbosity") + parser.add_argument("--wandb", default=False, action="store_true", help="Log training to Weights & Biases") + parser.add_argument("--wandb_project", default=None, type=str, help="Weights & Biases project name") + parser.add_argument("--wandb_report_steps", default=5, type=int, help="Number of batch iterations per Weights & Biases report (default: 5)") args = parser.parse_args() # Change working 
directory @@ -146,6 +149,9 @@ def eval_callback(i, val_loss): "report_steps": args.report_steps, "eval_steps": args.eval_steps, "validation_samples": args.validation_samples, + "wandb": args.wandb, + "wandb_project": args.wandb_project, + "wandb_report_steps": args.wandb_report_steps, } model.train(dataset_training, dataset_validation, diff --git a/sillm/reporting/__init__.py b/sillm/reporting/__init__.py new file mode 100644 index 0000000..bfed9ca --- /dev/null +++ b/sillm/reporting/__init__.py @@ -0,0 +1,7 @@ +from sillm.reporting import wandb +from sillm.reporting.wandb import WandBLogger + +__all__ = [ + "wandb", + "WandBLogger" +] diff --git a/sillm/reporting/wandb.py b/sillm/reporting/wandb.py new file mode 100644 index 0000000..6ef7c41 --- /dev/null +++ b/sillm/reporting/wandb.py @@ -0,0 +1,135 @@ +import os +import logging +from typing import Any, Dict, Optional, Union + +import wandb +import numpy as np +import mlx.core as mx + +logger = logging.getLogger("sillm") + +class WandBLogger: + """ + Weights & Biases logger for tracking training metrics and experiments. + """ + def __init__(self, + project: str = "sillm", + name: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + tags: Optional[list] = None, + resume: bool = False, + id: Optional[str] = None + ): + """ + Initialize WandB logger. 
+ + Args: + project: WandB project name + name: Run name (optional) + config: Run configuration/hyperparameters (optional) + tags: Run tags (optional) + resume: Whether to resume a previous run + id: Run ID to resume (optional) + """ + self.run = wandb.init( + project=project, + name=name, + config=config, + tags=tags, + resume=resume, + id=id + ) + + self._step = 0 + logger.info(f"Initialized WandB logger - Project: {project}, Run: {self.run.name}") + + @property + def step(self) -> int: + """Current logging step.""" + return self._step + + def _convert_value(self, value: Any) -> Any: + """Convert MLX arrays and other types to standard Python types.""" + if isinstance(value, mx.array): + return value.item() + elif isinstance(value, np.ndarray): + return value.tolist() + return value + + def log(self, + metrics: Dict[str, Any], + step: Optional[int] = None, + commit: bool = True + ): + """ + Log metrics to WandB. + + Args: + metrics: Dictionary of metric names and values + step: Optional step number (default: auto-increment) + commit: Whether to commit the logs immediately + """ + if step is not None: + self._step = step + + # Convert any MLX arrays to Python types + metrics = {k: self._convert_value(v) for k, v in metrics.items()} + + # Log metrics + wandb.log(metrics, step=self._step, commit=commit) + + if commit: + self._step += 1 + + def log_hyperparams(self, params: Dict[str, Any]): + """ + Log hyperparameters to WandB config. + + Args: + params: Dictionary of hyperparameter names and values + """ + # Convert any MLX arrays to Python types + params = {k: self._convert_value(v) for k, v in params.items()} + + # Update wandb config + wandb.config.update(params, allow_val_change=True) + + def log_model(self, + artifact_name: str, + path: str, + metadata: Optional[Dict[str, Any]] = None, + type: str = "model" + ): + """ + Log model files/weights as a WandB artifact. 
+ + Args: + artifact_name: Name for the artifact + path: Path to model file/directory + metadata: Optional metadata to attach to artifact + type: Artifact type (default: "model") + """ + artifact = wandb.Artifact( + name=artifact_name, + type=type, + metadata=metadata + ) + + if os.path.isfile(path): + artifact.add_file(path) + else: + artifact.add_dir(path) + + self.run.log_artifact(artifact) + + def finish(self): + """End the WandB run.""" + if self.run is not None: + self.run.finish() + self.run = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.finish() diff --git a/sillm/training/trainer.py b/sillm/training/trainer.py index 45f48ba..deda7d4 100644 --- a/sillm/training/trainer.py +++ b/sillm/training/trainer.py @@ -14,6 +14,7 @@ from sillm.core.llm import LLM from sillm.training.dataset import Dataset +from sillm.reporting import WandBLogger logger = logging.getLogger("sillm") @@ -102,7 +103,10 @@ def train(self, report_callback: callable = None, eval_steps: int = 100, eval_callback: callable = None, - validation_samples: int = 40 + validation_samples: int = 40, + wandb: bool = False, + wandb_project: str = "sillm", + wandb_report_steps: int = 5 ): """ Train model. @@ -118,6 +122,9 @@ def train(self, eval_callback: Callback after eval. validation_samples: Number of validation samples. debug: Whether to enable debug mode. + wandb: Whether to enable WandB logging. + wandb_project: WandB project name. + wandb_report_steps: Report every `wandb_report_steps` iterations to WandB. 
""" # Calculate number of iterations if iterations == 0: @@ -131,6 +138,24 @@ def train(self, # Get system memory system_memory = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") + + if wandb: + # Initialize WandB logger + wandb_config = { + "optimizer": optimizer_type, + "learning_rate": learning_rate, + "learning_decay": learning_decay, + "learning_warmup": learning_warmup, + "batch_size": batch_size, + "gradient_checkpointing": gradient_checkpointing, + "gradient_accumulation_steps": gradient_accumulation_steps, + "epochs": epochs, + "iterations": iterations + } + wandb_logger = WandBLogger( + project=wandb_project, + config=wandb_config + ) # Initialize scheduler if learning_decay > 0.0: @@ -156,6 +181,8 @@ def train(self, else: raise ValueError(f"Unknown optimizer type: {optimizer_type}") + optimizer = optimizer.init(self.model.trainable_parameters()) + # Initialize gradient accumulation if gradient_accumulation_steps > 1: accum_grad = None @@ -253,10 +280,24 @@ def step(batch): else: rewards = np.vstack([rewards, reward]) + # Calculate loss and timings + train_loss = np.mean(losses) + stop = time.perf_counter() + + # Log to WandB if needed + if wandb and (n + 1) % wandb_report_steps == 0: + metrics = { + "train/loss": train_loss, + "train/tokens_per_sec": float(intv_tokens) / (stop - start), + "train/learning_rate": optimizer.learning_rate.item() + } + if rewards is not None: + metrics["train/reward_chosen"] = np.mean(rewards, axis=0)[0] + metrics["train/reward_rejected"] = np.mean(rewards, axis=0)[1] + wandb_logger.log(metrics) + # Report training loss if needed if (n + 1) % report_steps == 0: - train_loss = np.mean(losses) - stop = time.perf_counter() # Print training loss and timings pbar_epochs.write(f"#{n + 1}:\tTraining loss {train_loss:.3f}\t{float(intv_tokens) / (stop - start):.3f} tok/sec (learning rate: {optimizer.learning_rate.item():.3e})") @@ -286,6 +327,12 @@ def step(batch): # Print validation loss and timings stop = time.perf_counter() 
val_loss = self.evaluate(dataset_validation, batch_size, validation_batches) + + # Log to WandB + if wandb: + wandb_logger.log({ + "val/loss": val_loss + }) start = time.perf_counter() pbar_epochs.write(f"#{n + 1}:\tValidation loss {val_loss:.3f}\t{(start - stop):.3f} sec") @@ -300,8 +347,18 @@ def step(batch): # Evaluate test dataset if dataset_test is not None: test_batches = validation_samples // batch_size  stop = time.perf_counter() test_loss = self.evaluate(dataset_test, batch_size, test_batches) start = time.perf_counter() - logger.info(f"Test loss: {test_loss:.3f}\t{start - stop:.3f} sec") \ No newline at end of file + logger.info(f"Test loss: {test_loss:.3f}\t{start - stop:.3f} sec") + + # Log to WandB + if wandb: + wandb_logger.log({ + "test/loss": test_loss + }) + + # Finish WandB run + if wandb: + wandb_logger.finish()