Skip to content

Commit

Permalink
solved nondeterminism i think
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Dec 30, 2024
1 parent a4d1df9 commit 0616db0
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 9 deletions.
42 changes: 33 additions & 9 deletions src/levanter/eval_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from levanter.models.loss import next_token_loss
from levanter.utils.background_iterable import BackgroundIterator
from levanter.utils.hf_utils import HfTokenizer
from levanter.utils.py_utils import set_global_rng_seeds


try:
Expand Down Expand Up @@ -546,23 +547,27 @@ def _actually_run_eval_harness(
worker = _LmEvalHarnessWorker(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp, max_packed_segments=64)

if jax.process_index() == 0:
print("Running eval harness on process 0", flush=True)
logger.info("Process 0 is running the eval harness.")
harness = worker.make_harness_lm()
outputs = evaluator.evaluate(
harness,
tasks_to_run,
limit=max_examples,
log_samples=config.log_samples,
bootstrap_iters=config.bootstrap_iters,
)

# eval_harness only sets seeds in simple_evaluate, which we can't use (I think?)
tasks_to_run = _adjust_config(tasks_to_run, 0)
with set_global_rng_seeds(0):
outputs = evaluator.evaluate(
harness,
tasks_to_run,
limit=max_examples,
log_samples=config.log_samples,
bootstrap_iters=config.bootstrap_iters,
)
worker.stop()

averages = _compute_averages(outputs)
outputs["averages"] = averages

return outputs
else:
print("Running worker message loop", flush=True)
logger.info(f"Process {jax.process_index()} is waiting for eval harness requests from process 0.")
worker.worker_message_loop()

logger.info("Finished running eval harness.")
Expand Down Expand Up @@ -745,6 +750,25 @@ def lm_eval_harness(step: StepInfo, force=False):
return lm_eval_harness


# lifted from lm-eval simple_evaluate
def _adjust_config(task_dict, fewshot_random_seed=0):
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
adjusted_task_dict = {
**adjusted_task_dict,
**{task_name: _adjust_config(task_obj, fewshot_random_seed=fewshot_random_seed)},
}

else:
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
task_obj.set_fewshot_seed(seed=fewshot_random_seed)

adjusted_task_dict[task_name] = task_obj

return adjusted_task_dict


# Script entry point: wrap run_eval_harness_main with levanter.config.main, call the result, then print "Done".
if __name__ == "__main__":
levanter.config.main(run_eval_harness_main)()
print("Done", flush=True)
32 changes: 32 additions & 0 deletions src/levanter/utils/py_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import os
import sys
import time
Expand Down Expand Up @@ -121,3 +122,34 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()


@contextlib.contextmanager
def set_global_rng_seeds(seed):
    """Temporarily seed the global NumPy, ``random`` and (if importable) PyTorch RNGs.

    On entry the current state of each generator is captured and the generator
    is reseeded with ``seed``. On exit, whether the body finished normally or
    raised, the captured states are put back.

    Args:
        seed: Value passed to ``np.random.seed``, ``random.seed`` and
            ``torch.manual_seed``.

    Yields:
        None. The body of the ``with`` statement runs with the seeded generators.
    """
    import random

    import numpy as np

    try:
        import torch
    except ImportError:
        # torch is optional; without it only NumPy and ``random`` are managed.
        torch = None

    # Capture every state before reseeding anything, so the ``finally`` below can
    # always undo a partially applied reseed (e.g. a seed NumPy accepts but
    # ``random.seed`` rejects).
    saved_np_state = np.random.get_state()
    saved_random_state = random.getstate()
    saved_torch_state = torch.random.get_rng_state() if torch is not None else None

    try:
        np.random.seed(seed)
        random.seed(seed)
        if torch is not None:
            # NOTE(review): only the state returned by torch.random.get_rng_state() is
            # restored; if manual_seed also reseeds accelerator generators, those are
            # left reseeded - confirm whether that matters for callers.
            torch.manual_seed(seed)
        yield
    finally:
        np.random.set_state(saved_np_state)
        random.setstate(saved_random_state)
        if saved_torch_state is not None:
            torch.random.set_rng_state(saved_torch_state)

0 comments on commit 0616db0

Please sign in to comment.