Commit adfa8db

see if this works better with lm_eval
dlwh committed Dec 15, 2024
1 parent 1d646c7 commit adfa8db
Showing 3 changed files with 50 additions and 15 deletions.
5 changes: 5 additions & 0 deletions config/gpt2_small_fast_eval.yaml
@@ -24,6 +24,11 @@ supervised_data:
cache_dir: "gs://levanter-data/tokenized-gpt2/arc_easy/"
tags: [ "arc", "e"]

eval_harness:
  task_spec: ["piqa", "hellaswag"]
  max_examples: 2048
eval_harness_steps: 1000

model:
  type: gpt2
  hidden_dim: 768
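For context, the new `eval_harness` block maps onto the `LmEvalHarnessConfig` fields touched later in this commit (`task_spec`, `max_examples`, and the new `bootstrap_iters`), while `eval_harness_steps` presumably controls how often the callback fires. A minimal sketch of the equivalent in-code setup, assuming the usual dataclass-style keyword constructor:

```python
from levanter.eval_harness import LmEvalHarnessConfig

# Sketch only: mirrors the YAML block above. Field names come from the diff
# below; the keyword-style constructor is an assumption (the class looks like
# a plain dataclass).
harness_config = LmEvalHarnessConfig(
    task_spec=["piqa", "hellaswag"],  # bare lm-eval task names
    max_examples=2048,                # cap the number of examples per task
    bootstrap_iters=0,                # new default in this commit: skip bootstrapped stderrs
)
tasks_to_run = harness_config.to_task_dict()  # resolves the spec via lm_eval.tasks
```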
1 change: 1 addition & 0 deletions docker/tpu/Dockerfile.incremental
@@ -18,6 +18,7 @@ WORKDIR /opt/levanter
ADD pyproject.toml README.md /opt/levanter/
RUN mkdir -p /opt/levanter/src/levanter
RUN pip install -e '.[test]'
RUN pip install "lm-eval@git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch"
ADD . /opt/levanter

# Add $EXTRA_CTX to the same location as in local machine.
59 changes: 44 additions & 15 deletions src/levanter/eval_harness.py
@@ -220,6 +220,8 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
# skip padding
result = result[:initial_length]

logger.info(f"Finished running {len(requests)} loglikelihoods.")

return result

def _pad_dataset_to_batch_size(self, requests):
@@ -295,6 +297,7 @@ class LmEvalHarnessConfig:
max_examples: int | None = None
max_eval_length: int | None = None
log_samples: bool = False
bootstrap_iters: int = 0  # set to 0 to see if this stops it hanging randomly

def to_task_spec(self) -> list[str | dict]:
return [task.to_dict() if isinstance(task, TaskConfig) else task for task in self.task_spec]
@@ -307,12 +310,13 @@ def to_task_dict(self) -> dict:
run, and LM Eval Harness doesn't seem to want to do that by default. So we need to do some hacky stuff to make
it work.
"""
logger.info("Loading tasks...")
import lm_eval.tasks as tasks

manager = tasks.TaskManager()
# we need to do it this way b/c i can't figure out how to run e.g. hellaswag 0 shot and 10 shot in a single run
this_tasks = {}
for task in self.to_task_spec():
for task in tqdm(self.to_task_spec()):
try:
if isinstance(task, str):
this_tasks.update(tasks.get_task_dict(task, manager))
@@ -324,6 +328,8 @@ def to_task_dict(self) -> dict:
except Exception:
logger.exception(f"Failed to load task {task}")
raise ValueError(f"Failed to load task {task}")

logger.info(f"Loaded {len(this_tasks)} tasks")
return this_tasks

def _get_task_and_rename(self, manager, our_name, task: dict | str):
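As the docstring above explains, `to_task_dict` copies and renames tasks so the same benchmark can run at several few-shot settings in one pass. Purely as an illustration of that idea (the dict keys below are assumptions, not taken from this commit), a mixed `task_spec` might look like:

```python
# Illustration only: one bare task name plus a dict-style entry. The keys
# "num_fewshot" and "task_alias" are assumed here and may not match the
# actual TaskConfig fields in this file.
task_spec = [
    "hellaswag",  # run 0-shot under its usual name
    {"task": "hellaswag", "num_fewshot": 10, "task_alias": "hellaswag_10shot"},
]
```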
@@ -397,16 +403,25 @@ def _actually_run_eval_harness(
max_eval_length = config.max_eval_length

EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)
num_parameters = levanter.utils.jax_utils.parameter_count(model)
logger.info(
f"Evaluating with max eval length {EvalPos.size} and batch size {EvalBatch.size}. There are"
f" {num_parameters} parameters in the model."
)
harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp)
# we always set log_samples here and filter out the samples later if we don't want them
outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True)
logger.info("Running eval harness...")
outputs = evaluator.evaluate(
harness,
tasks_to_run,
limit=max_examples,
log_samples=config.log_samples,
bootstrap_iters=config.bootstrap_iters,
)
logger.info("Finished running eval harness.")

averages = _compute_averages(outputs)
outputs["averages"] = averages

if not config.log_samples:
del outputs["samples"]

return outputs


@@ -417,7 +432,9 @@ def _compute_averages(outputs):
Args:
outputs: Dictionary with results and samples:
- "results": Dictionary of task-level results.
- "samples": Dictionary of task-level sample counts.
- "n-samples" : Dictionary of task-level sample counts.
Returns:
Averages dictionary with macro and micro averages for all metrics.
@@ -429,7 +446,7 @@ def _compute_averages(outputs):
for task_results in outputs["results"].values():
metric_keys.update(k for k in task_results.keys() if "stderr" not in k and k != "alias")

examples_per_task = [len(task_samples) for task_samples in outputs["samples"].values()]
examples_per_task = [task_samples["effective"] for task_samples in outputs["n-samples"].values()]

# Compute macro and micro averages
for metric in metric_keys:
@@ -448,7 +465,8 @@ def _compute_averages(outputs):

# Compute macro and micro averages
averages["macro_avg_" + metric] = np.mean(metric_values)
averages["micro_avg_" + metric] = np.average(metric_values, weights=this_examples_per_task)
if sum(this_examples_per_task) > 0:
averages["micro_avg_" + metric] = np.average(metric_values, weights=this_examples_per_task)

return averages
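To make the macro/micro distinction in the docstring concrete, here is a toy example (task names, metric key, and counts are made up) of how the "effective" counts from `outputs["n-samples"]` weight the micro average:

```python
import numpy as np

# Toy data shaped like outputs["results"] / outputs["n-samples"] above;
# the numbers are invented for illustration.
results = {"piqa": {"acc": 0.70}, "hellaswag": {"acc": 0.50}}
n_samples = {"piqa": {"effective": 2048}, "hellaswag": {"effective": 512}}

accs = [results[t]["acc"] for t in results]
weights = [n_samples[t]["effective"] for t in results]

macro = np.mean(accs)                      # 0.60 -- each task counts equally
micro = np.average(accs, weights=weights)  # 0.66 -- weighted by example count
```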

@@ -591,18 +609,24 @@ def run_eval_harness_main(config: EvalHarnessMainConfig):
)

logger.info("Finished running LM eval harness")

# log the results
logger.info("Logging results to tracker")
log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
logger.info("Finished logging results to tracker")

# log the results as json
logger.info("uploading artifacts...")
with open("lm_eval_harness_results.json", "w") as f:
json.dump(outputs, f, indent=2)
f.flush()
f_path = f.name
levanter.tracker.current_tracker().log_artifact(f_path, name="lm_eval_harness_results")

# also write to stdout
if jax.process_index() == 0:
print(json.dumps(outputs, indent=2), flush=True)

# also log the results
levanter.tracker.current_tracker().log_artifact("lm_eval_harness_results.json", name="lm_eval_harness_results")
log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())

return outputs


@@ -639,6 +663,7 @@ def lm_eval_harness(step: StepInfo, force=False):
return # don't run eval on the first step

model = inference_mode(step.model, True)
logger.info("Running eval harness...")
outputs = _actually_run_eval_harness(
config,
model,
@@ -648,18 +673,22 @@
axis_resources,
mp,
)
logger.info("Finished running eval harness.")

if jax.process_index() == 0:
log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
logger.info("Logged report to tracker")

if jax.process_index() == 0:
# don't delete b/c wandb will sometimes defer upload
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
import json

json.dump(outputs, f)
f.flush()
levanter.tracker.current_tracker().log_artifact(
f.name, name=f"lm_eval_harness_results.{step.step}.json", type="lm_eval_output"
)
logger.info("Uploaded results to tracker")

return lm_eval_harness

