Commit 8b88696

eval_model, fix serialization

dlwh committed Jan 22, 2025
1 parent 5581062 commit 8b88696
Showing 9 changed files with 146 additions and 12 deletions.
34 changes: 34 additions & 0 deletions config/gpt2_small_fast_ema.yaml
@@ -0,0 +1,34 @@
data: !include data/openwebtext_source.yaml
model:
  type: gpt2
  hidden_dim: 768
  num_heads: 12
  num_layers: 12
  seq_len: 1024
  gradient_checkpointing: true
  scale_attn_by_inverse_layer_idx: true
trainer:
  tracker:
    - type: wandb
      project: "levanter"
      tags: [ "openwebtext", "gpt2", "itest"]

  mp: p=f32,c=bfloat16
  model_averaging:
    type: ema
    beta: 0.995

  model_axis_size: 1
  per_device_parallelism: -1

  train_batch_size: 256
  num_train_steps: 20000

#  tensor_parallel_axes: ["position", "key_position"]
#  tensor_parallel_axes: ["heads", "mlp"]
optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  warmup: 0.01
  decay: 200 # no decay b/c EMA
  lr_schedule: inv
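
For orientation, the `model_averaging` block above is decoded through the `ModelAveragingConfig` choice registry added in `src/levanter/optim/model_averaging.py` below (`type: ema` selects `EmaModelAveragingConfig`). A minimal sketch of the equivalent construction in Python, with a made-up scalar pytree standing in for a real model:

# Sketch only: what the YAML `model_averaging` block corresponds to in code.
# The jnp scalar below is a hypothetical stand-in for the real model pytree.
import jax.numpy as jnp

from levanter.optim.model_averaging import EmaModelAveragingConfig

config = EmaModelAveragingConfig(beta=0.995)   # `type: ema`, `beta: 0.995`
averager = config.create(jnp.zeros(()))        # wraps a copy of the initial parameters
params = averager.model_params                 # averaged parameters (== initial params before any update)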
5 changes: 3 additions & 2 deletions src/levanter/callbacks.py
@@ -56,6 +56,7 @@ class StepInfo(Generic[S]):

    model = property(lambda self: self.state.model)
    opt_state = property(lambda self: self.state.opt_state)
    eval_model = property(lambda self: self.state.eval_model)

    step = property(lambda self: int(self.state.step) - 1)
    """
@@ -190,7 +191,7 @@ def compute_validation_loss(
    name: Optional[str] = None,
):
    def compute_loss(info: StepInfo):
        loss = eval_loss_loop(loss_fn, info.model, dataset, max_batches=max_batches, name=name)
        loss = eval_loss_loop(loss_fn, info.eval_model, dataset, max_batches=max_batches, name=name)

        prefix = "eval"
        if name:
@@ -372,7 +373,7 @@ def compute_and_visualize_log_probs(test_data, tokenizer, log_prob_fn, html_dir:
    """

    def compute_and_viz_log_probs(step: StepInfo):
        model = step.model
        model = step.eval_model
        os.makedirs(html_dir, exist_ok=True)
        path = os.path.join(html_dir, f"step_{step.step}.html")

2 changes: 1 addition & 1 deletion src/levanter/compat/hf_checkpoints.py
@@ -851,7 +851,7 @@ def cb(step: StepInfo):
        else:
            my_upload_kwargs = hf_upload_kwargs
        converter.save_pretrained(
            cast(ModelWithHfSerializationMixin, step.model),
            cast(ModelWithHfSerializationMixin, step.eval_model),
            os.path.join(base_path, f"step-{step.step}"),
            upload_to_hf=upload_to_hf,
            **my_upload_kwargs,
2 changes: 1 addition & 1 deletion src/levanter/eval.py
@@ -194,7 +194,7 @@ def cb_tagged_lm_evaluate(

    def eval_callback(step: StepInfo):
        with levanter.tracker.capture_time() as time_fn:
            result = evaluator.evaluate(step.model)
            result = evaluator.evaluate(step.eval_model)

        log_dict = {
            # log micro average as just "loss"
2 changes: 1 addition & 1 deletion src/levanter/eval_harness.py
@@ -689,7 +689,7 @@ def lm_eval_harness(step: StepInfo, force=False):
        if step.step == 0 and not force:
            return

        model = inference_mode(step.model, True)
        model = step.eval_model
        logger.info("Running eval harness...")
        outputs = _actually_run_eval_harness(
            config,
4 changes: 2 additions & 2 deletions src/levanter/lora.py
@@ -390,7 +390,7 @@ def cb(step: StepInfo):
        logger.info(f"Saving PEFT checkpoint for step {step.step} to {base_path}")

        save_peft_pretrained(
            step.model,
            step.eval_model,
            config,
            base_model_name_or_path,
            os.path.join(base_path, f"step-{step.step}"),
@@ -441,7 +441,7 @@ def save_merged_hf_model_cb(step: StepInfo):
        logger.info(f"Saving merged HF model for step {step.step} to {base_path}")
        path = os.path.join(base_path, f"step-{step.step}")

        model = step.model
        model = step.eval_model

        save_merged_hf_model(model, converter, path, upload_to_hf=upload_to_hf, **my_upload_kwargs)

59 changes: 59 additions & 0 deletions src/levanter/optim/model_averaging.py
@@ -0,0 +1,59 @@
import abc
import dataclasses
from typing import Generic, TypeVar

import draccus
import equinox as eqx
import optax


S = TypeVar("S")
M = TypeVar("M")


class ModelAveraging(eqx.Module, Generic[M]):
    """
    This is the interface for model averaging algorithms. Model averaging algorithms are used to average
    the parameters of a model over multiple training steps. This is useful for improving generalization
    """

    @abc.abstractmethod
    def update(self: S, model: M, step: int) -> S:
        pass

    @property
    @abc.abstractmethod
    def model_params(self) -> M:
        pass


class EmaModelAveraging(ModelAveraging[M]):
    """
    Exponential moving average model averaging
    """

    model: M
    beta: float = eqx.static_field()

    def update(self: S, new_model: M, step: int) -> S:
        del step
        return dataclasses.replace(self, model=optax.incremental_update(new_model, self.model, self.beta)) # type: ignore

    @property
    def model_params(self) -> M:
        return self.model


class ModelAveragingConfig(abc.ABC, draccus.ChoiceRegistry, Generic[M]):
    @abc.abstractmethod
    def create(self, model: M) -> ModelAveraging[M]:
        pass


@ModelAveragingConfig.register_subclass("ema")
@dataclasses.dataclass
class EmaModelAveragingConfig(ModelAveragingConfig[M]):
    beta: float = 0.999

    def create(self, model: M) -> EmaModelAveraging[M]:
        return EmaModelAveraging(model=model, beta=self.beta)
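
The `update` rule above delegates to `optax.incremental_update`, which, assuming optax's documented polyak-averaging convention, computes `step_size * new + (1 - step_size) * old`. A small sketch of a single averaging step, again with a toy scalar in place of a real model pytree:

# Toy illustration of one EMA step; the scalar "parameters" are hypothetical.
import jax.numpy as jnp

from levanter.optim.model_averaging import EmaModelAveraging

ema = EmaModelAveraging(model=jnp.array(0.0), beta=0.995)
ema = ema.update(jnp.array(1.0), step=0)   # functional update: returns a new EmaModelAveraging
print(ema.model_params)                    # beta * new + (1 - beta) * old, under the stated optax assumption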
4 changes: 3 additions & 1 deletion src/levanter/trainer.py
@@ -38,6 +38,7 @@
from levanter.data import AsyncDataset, DataLoader
from levanter.distributed import DistributedConfig, RayConfig
from levanter.grad_accum import microbatched
from levanter.optim.model_averaging import ModelAveragingConfig
from levanter.tracker import TrackerConfig, capture_time
from levanter.trainer_state import TrainerState, saveable_training_mask
from levanter.utils import cloud_utils, fsspec_utils
@@ -356,6 +357,7 @@ def init_state_and_model(model_init, training_key):
                is_trainable=is_trainable,
                mp=self.mp,
                fp8=self.fp8,
                model_averaging=self.config.model_averaging,
            )
            return state

@@ -472,7 +474,6 @@ def add_eval_hook(self, eval_dataset, name: Optional[str] = None):

        @eqx.filter_jit
        def eval_loss(model, *batch, **batch_kwargs):
            model = inference_mode(model, True)
            model = self.mp.cast_to_compute(model)
            return self.loss_fn(model, *batch, **batch_kwargs, key=None)

@@ -577,6 +578,7 @@ class TrainerConfig:
    seed: int = 0 # random seed
    mp: jmp.Policy = jmp.get_policy("f32") # mixed precision policy
    fp8: Optional[bool | Fp8Config] = None
    model_averaging: ModelAveragingConfig | None = None

    wandb: Optional[tracker.wandb.WandbConfig] = None
    log_dir: Path = Path("logs/")
46 changes: 42 additions & 4 deletions src/levanter/trainer_state.py
@@ -12,7 +12,9 @@
from haliax.quantization import Fp8Config, apply_updates, fp8_linear_layers, partition_for_grad_overwrite
from haliax.types import IntScalar, Scalar

from levanter.optim.model_averaging import ModelAveraging, ModelAveragingConfig
from levanter.utils.jax_utils import is_inexact_arrayish
from levanter.utils.tree_utils import inference_mode
from levanter.utils.types import FilterTree


@@ -51,6 +53,7 @@ class TrainerState(eqx.Module, Generic[M]):
    optimizer: GradientTransformation = eqx.field(static=True)
    opt_state: OptState
    training_key: PRNGKeyArray
    model_averaging: ModelAveraging[M] | None

    is_trainable: FilterTree = eqx.field(static=True)
    mp: jmp.Policy = eqx.field(static=True)
@@ -70,6 +73,19 @@ def trainable_model(self) -> M:
    def saveable_state(self) -> FilterTree:
        return eqx.filter(self, saveable_training_mask(self, self.is_trainable))

    @property
    def eval_model(self) -> M:
        """
        Returns the model to use for evaluation, in inference mode. If model averaging is enabled,
        this is the averaged model; otherwise, it is the current model.
        """
        if self.model_averaging is not None:
            m = self.model_averaging.model_params
        else:
            m = self.model

        return inference_mode(m, True)

    @classmethod
    def init(
        cls,
@@ -80,6 +96,7 @@ def init(
        is_trainable: FilterTree = True,
        mp: Optional[jmp.Policy] = None,
        fp8: Fp8Config = None,
        model_averaging: ModelAveragingConfig[M] | None = None,
        **kwargs,
    ) -> "TrainerState[M]":
        if mp is not None:
@@ -90,8 +107,22 @@
        if fp8 is not None:
            model = fp8_linear_layers(model, fp8)

        if model_averaging is not None:
            model_averaging = model_averaging.create(model)

        opt_state = init_optimizer_for_trainables(optimizer, model, is_trainable)
        return cls(0, model, optimizer, opt_state, key, is_trainable=is_trainable, mp=mp, *args, **kwargs)
        return cls(
            0,
            model,
            optimizer,
            opt_state,
            key,
            is_trainable=is_trainable,
            mp=mp,
            model_averaging=model_averaging,
            *args,
            **kwargs,
        )

    def take_step(self: S, grads: PyTree, obj_fun: Optional[Callable[[M], Scalar]] = None) -> S:
        assert isinstance(self, TrainerState) # make mypy happy
@@ -103,7 +134,13 @@ def take_step(self: S, grads: PyTree, obj_fun: Optional[Callable[[M], Scalar]] =
            obj_fun=obj_fun,
            is_trainable=self.is_trainable,
        )
        return dataclasses.replace(self, model=model, opt_state=opt_state, step=self.step + 1)

        if self.model_averaging is not None:
            ma = self.model_averaging.update(model, self.step)
        else:
            ma = None

        return dataclasses.replace(self, model=model, opt_state=opt_state, model_averaging=ma, step=self.step + 1)


def init_optimizer_for_trainables(optimizer, model, is_trainable):
@@ -164,8 +201,9 @@ def saveable_training_mask(trainer_state: S, is_trainable_param: FilterTree = Tr

    is_trainable_param = make_floating_point_trainable_filter(is_trainable_param)

    trainer_state = jax.tree_util.tree_map(lambda x: True, trainer_state)
    saveable_state = dataclasses.replace(trainer_state, model=is_trainable_param) # type: ignore
    trainer_state = jax.tree_util.tree_map(lambda x: is_inexact_arrayish, trainer_state)
    saveable_state = dataclasses.replace(trainer_state, step=True, training_key=True) # type: ignore
    saveable_state = dataclasses.replace(saveable_state, model=is_trainable_param) # type: ignore
    return saveable_state # type: ignore


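
Putting the pieces together: `TrainerState.take_step` now folds each new set of weights into the averager, and `eval_model` is what the evaluation callbacks above consume. A hypothetical end-to-end sketch with a scalar stand-in for the model (real usage goes through `Trainer` and the YAML config at the top of this commit):

# Hypothetical flow mirroring TrainerState.take_step / eval_model; not the real training loop.
import jax.numpy as jnp

from levanter.optim.model_averaging import EmaModelAveragingConfig

params = jnp.array(0.0)                          # stand-in for the model pytree
averager = EmaModelAveragingConfig(beta=0.995).create(params)

for step, params in enumerate([jnp.array(0.5), jnp.array(1.0)]):
    averager = averager.update(params, step)     # what take_step does when model_averaging is set

eval_params = averager.model_params              # the averaged weights that eval_model wraps in inference mode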
