From 8b8869629f25552a2321632912415bb43fc9176a Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 21 Jan 2025 16:09:11 -0800 Subject: [PATCH] eval_model, fix serialization --- config/gpt2_small_fast_ema.yaml | 34 +++++++++++++++ src/levanter/callbacks.py | 5 ++- src/levanter/compat/hf_checkpoints.py | 2 +- src/levanter/eval.py | 2 +- src/levanter/eval_harness.py | 2 +- src/levanter/lora.py | 4 +- src/levanter/optim/model_averaging.py | 59 +++++++++++++++++++++++++++ src/levanter/trainer.py | 4 +- src/levanter/trainer_state.py | 46 +++++++++++++++++++-- 9 files changed, 146 insertions(+), 12 deletions(-) create mode 100644 config/gpt2_small_fast_ema.yaml create mode 100644 src/levanter/optim/model_averaging.py diff --git a/config/gpt2_small_fast_ema.yaml b/config/gpt2_small_fast_ema.yaml new file mode 100644 index 000000000..f40dc79b1 --- /dev/null +++ b/config/gpt2_small_fast_ema.yaml @@ -0,0 +1,34 @@ +data: !include data/openwebtext_source.yaml +model: + type: gpt2 + hidden_dim: 768 + num_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true + scale_attn_by_inverse_layer_idx: true +trainer: + tracker: + - type: wandb + project: "levanter" + tags: [ "openwebtext", "gpt2", "itest"] + + mp: p=f32,c=bfloat16 + model_averaging: + type: ema + beta: 0.995 + + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 256 + num_train_steps: 20000 + +# tensor_parallel_axes: ["position", "key_position"] +# tensor_parallel_axes: ["heads", "mlp"] +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 + decay: 200 # no decay b/c EMA + lr_schedule: inv diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 135d10dd5..b6bb49664 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -56,6 +56,7 @@ class StepInfo(Generic[S]): model = property(lambda self: self.state.model) opt_state = property(lambda self: self.state.opt_state) + eval_model = property(lambda self: self.state.eval_model) step = property(lambda self: int(self.state.step) - 1) """ @@ -190,7 +191,7 @@ def compute_validation_loss( name: Optional[str] = None, ): def compute_loss(info: StepInfo): - loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name) + loss = eval_loss_loop(loss_fn, info.eval_model, dataset, max_batches=max_batches, name=name) prefix = "eval" if name: @@ -372,7 +373,7 @@ def compute_and_visualize_log_probs(test_data, tokenizer, log_prob_fn, html_dir: """ def compute_and_viz_log_probs(step: StepInfo): - model = step.model + model = step.eval_model os.makedirs(html_dir, exist_ok=True) path = os.path.join(html_dir, f"step_{step.step}.html") diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index c61e3ac15..042f35e35 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -851,7 +851,7 @@ def cb(step: StepInfo): else: my_upload_kwargs = hf_upload_kwargs converter.save_pretrained( - cast(ModelWithHfSerializationMixin, step.model), + cast(ModelWithHfSerializationMixin, step.eval_model), os.path.join(base_path, f"step-{step.step}"), upload_to_hf=upload_to_hf, **my_upload_kwargs, diff --git a/src/levanter/eval.py b/src/levanter/eval.py index ada22bc14..ffbbe37b6 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -194,7 +194,7 @@ def cb_tagged_lm_evaluate( def eval_callback(step: StepInfo): with levanter.tracker.capture_time() as time_fn: - result = evaluator.evaluate(step.model) + result = evaluator.evaluate(step.eval_model) log_dict = { # log micro average as just "loss" diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 8c87c0530..3623aba9c 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -689,7 +689,7 @@ def lm_eval_harness(step: StepInfo, force=False): if step.step == 0 and not force: return - model = inference_mode(step.model, True) + model = step.eval_model logger.info("Running eval harness...") outputs = _actually_run_eval_harness( config, diff --git a/src/levanter/lora.py b/src/levanter/lora.py index cdabee3a5..a0c3c1cdd 100644 --- a/src/levanter/lora.py +++ b/src/levanter/lora.py @@ -390,7 +390,7 @@ def cb(step: StepInfo): logger.info(f"Saving PEFT checkpoint for step {step.step} to {base_path}") save_peft_pretrained( - step.model, + step.eval_model, config, base_model_name_or_path, os.path.join(base_path, f"step-{step.step}"), @@ -441,7 +441,7 @@ def save_merged_hf_model_cb(step: StepInfo): logger.info(f"Saving merged HF model for step {step.step} to {base_path}") path = os.path.join(base_path, f"step-{step.step}") - model = step.model + model = step.eval_model save_merged_hf_model(model, converter, path, upload_to_hf=upload_to_hf, **my_upload_kwargs) diff --git a/src/levanter/optim/model_averaging.py b/src/levanter/optim/model_averaging.py new file mode 100644 index 000000000..f5485b979 --- /dev/null +++ b/src/levanter/optim/model_averaging.py @@ -0,0 +1,59 @@ +import abc +import dataclasses +from typing import Generic, TypeVar + +import draccus +import equinox as eqx +import optax + + +S = TypeVar("S") +M = TypeVar("M") + + +class ModelAveraging(eqx.Module, Generic[M]): + """ + This is the interface for model averaging algorithms. Model averaging algorithms are used to average + the parameters of a model over multiple training steps. This is useful for improving generalization + """ + + @abc.abstractmethod + def update(self: S, model: M, step: int) -> S: + pass + + @property + @abc.abstractmethod + def model_params(self) -> M: + pass + + +class EmaModelAveraging(ModelAveraging[M]): + """ + Exponential moving average model averaging + """ + + model: M + beta: float = eqx.static_field() + + def update(self: S, new_model: M, step: int) -> S: + del step + return dataclasses.replace(self, model=optax.incremental_update(new_model, self.model, self.beta)) # type: ignore + + @property + def model_params(self) -> M: + return self.model + + +class ModelAveragingConfig(abc.ABC, draccus.ChoiceRegistry, Generic[M]): + @abc.abstractmethod + def create(self, model: M) -> ModelAveraging[M]: + pass + + +@ModelAveragingConfig.register_subclass("ema") +@dataclasses.dataclass +class EmaModelAveragingConfig(ModelAveragingConfig[M]): + beta: float = 0.999 + + def create(self, model: M) -> EmaModelAveraging[M]: + return EmaModelAveraging(model=model, beta=self.beta) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 82f32422a..b37acdc5d 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -38,6 +38,7 @@ from levanter.data import AsyncDataset, DataLoader from levanter.distributed import DistributedConfig, RayConfig from levanter.grad_accum import microbatched +from levanter.optim.model_averaging import ModelAveragingConfig from levanter.tracker import TrackerConfig, capture_time from levanter.trainer_state import TrainerState, saveable_training_mask from levanter.utils import cloud_utils, fsspec_utils @@ -356,6 +357,7 @@ def init_state_and_model(model_init, training_key): is_trainable=is_trainable, mp=self.mp, fp8=self.fp8, + model_averaging=self.config.model_averaging, ) return state @@ -472,7 +474,6 @@ def add_eval_hook(self, eval_dataset, name: Optional[str] = None): @eqx.filter_jit def eval_loss(model, *batch, **batch_kwargs): - model = inference_mode(model, True) model = self.mp.cast_to_compute(model) return self.loss_fn(model, *batch, **batch_kwargs, key=None) @@ -577,6 +578,7 @@ class TrainerConfig: seed: int = 0 # random seed mp: jmp.Policy = jmp.get_policy("f32") # mixed precision policy fp8: Optional[bool | Fp8Config] = None + model_averaging: ModelAveragingConfig | None = None wandb: Optional[tracker.wandb.WandbConfig] = None log_dir: Path = Path("logs/") diff --git a/src/levanter/trainer_state.py b/src/levanter/trainer_state.py index 549267681..17563cf9c 100644 --- a/src/levanter/trainer_state.py +++ b/src/levanter/trainer_state.py @@ -12,7 +12,9 @@ from haliax.quantization import Fp8Config, apply_updates, fp8_linear_layers, partition_for_grad_overwrite from haliax.types import IntScalar, Scalar +from levanter.optim.model_averaging import ModelAveraging, ModelAveragingConfig from levanter.utils.jax_utils import is_inexact_arrayish +from levanter.utils.tree_utils import inference_mode from levanter.utils.types import FilterTree @@ -51,6 +53,7 @@ class TrainerState(eqx.Module, Generic[M]): optimizer: GradientTransformation = eqx.field(static=True) opt_state: OptState training_key: PRNGKeyArray + model_averaging: ModelAveraging[M] | None is_trainable: FilterTree = eqx.field(static=True) mp: jmp.Policy = eqx.field(static=True) @@ -70,6 +73,19 @@ def trainable_model(self) -> M: def saveable_state(self) -> FilterTree: return eqx.filter(self, saveable_training_mask(self, self.is_trainable)) + @property + def eval_model(self) -> M: + """ + Returns the model in evaluation mode, using the inference mode of the model averaging if it exists. + Otherwise, it uses the inference mode of the model. + """ + if self.model_averaging is not None: + m = self.model_averaging.model_params + else: + m = self.model + + return inference_mode(m, True) + @classmethod def init( cls, @@ -80,6 +96,7 @@ def init( is_trainable: FilterTree = True, mp: Optional[jmp.Policy] = None, fp8: Fp8Config = None, + model_averaging: ModelAveragingConfig[M] | None = None, **kwargs, ) -> "TrainerState[M]": if mp is not None: @@ -90,8 +107,22 @@ def init( if fp8 is not None: model = fp8_linear_layers(model, fp8) + if model_averaging is not None: + model_averaging = model_averaging.create(model) + opt_state = init_optimizer_for_trainables(optimizer, model, is_trainable) - return cls(0, model, optimizer, opt_state, key, is_trainable=is_trainable, mp=mp, *args, **kwargs) + return cls( + 0, + model, + optimizer, + opt_state, + key, + is_trainable=is_trainable, + mp=mp, + model_averaging=model_averaging, + *args, + **kwargs, + ) def take_step(self: S, grads: PyTree, obj_fun: Optional[Callable[[M], Scalar]] = None) -> S: assert isinstance(self, TrainerState) # make mypy happy @@ -103,7 +134,13 @@ def take_step(self: S, grads: PyTree, obj_fun: Optional[Callable[[M], Scalar]] = obj_fun=obj_fun, is_trainable=self.is_trainable, ) - return dataclasses.replace(self, model=model, opt_state=opt_state, step=self.step + 1) + + if self.model_averaging is not None: + ma = self.model_averaging.update(model, self.step) + else: + ma = None + + return dataclasses.replace(self, model=model, opt_state=opt_state, model_averaging=ma, step=self.step + 1) def init_optimizer_for_trainables(optimizer, model, is_trainable): @@ -164,8 +201,9 @@ def saveable_training_mask(trainer_state: S, is_trainable_param: FilterTree = Tr is_trainable_param = make_floating_point_trainable_filter(is_trainable_param) - trainer_state = jax.tree_util.tree_map(lambda x: True, trainer_state) - saveable_state = dataclasses.replace(trainer_state, model=is_trainable_param) # type: ignore + trainer_state = jax.tree_util.tree_map(lambda x: is_inexact_arrayish, trainer_state) + saveable_state = dataclasses.replace(trainer_state, step=True, training_key=True) # type: ignore + saveable_state = dataclasses.replace(saveable_state, model=is_trainable_param) # type: ignore return saveable_state # type: ignore