diff --git a/config/gpt2_small_fast_eval.yaml b/config/gpt2_small_fast_eval.yaml
index 14638db1b..11245c7b3 100644
--- a/config/gpt2_small_fast_eval.yaml
+++ b/config/gpt2_small_fast_eval.yaml
@@ -24,6 +24,11 @@ supervised_data:
   cache_dir: "gs://levanter-data/tokenized-gpt2/arc_easy/"
   tags: [ "arc", "e"]

+eval_harness:
+  task_spec: ["piqa", "hellaswag"]
+  max_examples: 2048
+eval_harness_steps: 1000
+
 model:
   type: gpt2
   hidden_dim: 768
diff --git a/docker/tpu/Dockerfile.incremental b/docker/tpu/Dockerfile.incremental
index 64c14b4c9..45a5b1aa6 100644
--- a/docker/tpu/Dockerfile.incremental
+++ b/docker/tpu/Dockerfile.incremental
@@ -18,6 +18,7 @@ WORKDIR /opt/levanter
 ADD pyproject.toml README.md /opt/levanter/
 RUN mkdir -p /opt/levanter/src/levanter
 RUN pip install -e '.[test]'
+RUN pip install "lm-eval@git+https://github.com/dlwh/lm-evaluation-harness.git@no_torch"
 ADD . /opt/levanter

 # Add $EXTRA_CTX to the same location as in local machine.
diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py
index cb15125e0..4821f4d82 100644
--- a/src/levanter/eval_harness.py
+++ b/src/levanter/eval_harness.py
@@ -220,6 +220,8 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         # skip padding
         result = result[:initial_length]

+        logger.info(f"Finished running {len(requests)} loglikelihoods.")
+
         return result

     def _pad_dataset_to_batch_size(self, requests):
@@ -295,6 +297,7 @@ class LmEvalHarnessConfig:
     max_examples: int | None = None
     max_eval_length: int | None = None
     log_samples: bool = False
+    bootstrap_iters: int = 0  # set to 0 to see if this keeps it from hanging randomly

     def to_task_spec(self) -> list[str | dict]:
         return [task.to_dict() if isinstance(task, TaskConfig) else task for task in self.task_spec]
@@ -307,12 +310,13 @@ def to_task_dict(self) -> dict:
         run, and LM Eval Harness doesn't seem to want to do that by default. So we need to do some hacky stuff to make
         it work.
         """
+        logger.info("Loading tasks...")
         import lm_eval.tasks as tasks

         manager = tasks.TaskManager()
         # we need to do it this way b/c i can't figure out how to run e.g. hellaswag 0 shot and 10 shot in a single run
         this_tasks = {}
-        for task in self.to_task_spec():
+        for task in tqdm(self.to_task_spec()):
             try:
                 if isinstance(task, str):
                     this_tasks.update(tasks.get_task_dict(task, manager))
@@ -324,6 +328,8 @@
             except Exception:
                 logger.exception(f"Failed to load task {task}")
                 raise ValueError(f"Failed to load task {task}")
+
+        logger.info(f"Loaded {len(this_tasks)} tasks")
         return this_tasks

     def _get_task_and_rename(self, manager, our_name, task: dict | str):
@@ -397,16 +403,25 @@ def _actually_run_eval_harness(
     max_eval_length = config.max_eval_length
     EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length)

+    num_parameters = levanter.utils.jax_utils.parameter_count(model)
+    logger.info(
+        f"Evaluating with max eval length {EvalPos.size} and batch size {EvalBatch.size}. There are"
+        f" {num_parameters} parameters in the model."
+    )
     harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp)
-    # we always set log_samples here and filter out the samples later if we don't want them
-    outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True)
+    logger.info("Running eval harness...")
+    outputs = evaluator.evaluate(
+        harness,
+        tasks_to_run,
+        limit=max_examples,
+        log_samples=config.log_samples,
+        bootstrap_iters=config.bootstrap_iters,
+    )
+    logger.info("Finished running eval harness.")

     averages = _compute_averages(outputs)
     outputs["averages"] = averages

-    if not config.log_samples:
-        del outputs["samples"]
-
     return outputs
@@ -417,7 +432,9 @@ def _compute_averages(outputs):
     Args:
         outputs: Dictionary with results and samples:
            - "results": Dictionary of task-level results.
-           - "samples": Dictionary of task-level sample counts.
+           - "n-samples": Dictionary of task-level sample counts.
+
+
     Returns:
         Averages dictionary with macro and micro averages for all metrics.
@@ -429,7 +446,7 @@
     for task_results in outputs["results"].values():
         metric_keys.update(k for k in task_results.keys() if "stderr" not in k and k != "alias")

-    examples_per_task = [len(task_samples) for task_samples in outputs["samples"].values()]
+    examples_per_task = [task_samples["effective"] for task_samples in outputs["n-samples"].values()]

     # Compute macro and micro averages
     for metric in metric_keys:
@@ -448,7 +465,8 @@
         # Compute macro and micro averages
         averages["macro_avg_" + metric] = np.mean(metric_values)
-        averages["micro_avg_" + metric] = np.average(metric_values, weights=this_examples_per_task)
+        if sum(this_examples_per_task) > 0:
+            averages["micro_avg_" + metric] = np.average(metric_values, weights=this_examples_per_task)

     return averages
@@ -591,18 +609,24 @@ def run_eval_harness_main(config: EvalHarnessMainConfig):
     )

     logger.info("Finished running LM eval harness")
+
+    # log the results
+    logger.info("Logging results to tracker")
+    log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
+    logger.info("Finished logging results to tracker")
+
     # log the results as json
+    logger.info("Uploading artifacts...")
     with open("lm_eval_harness_results.json", "w") as f:
         json.dump(outputs, f, indent=2)
+        f.flush()
+        f_path = f.name
+        levanter.tracker.current_tracker().log_artifact(f_path, name="lm_eval_harness_results")

     # also write to stdout
     if jax.process_index() == 0:
         print(json.dumps(outputs, indent=2), flush=True)

-    # also log the results
-    levanter.tracker.current_tracker().log_artifact("lm_eval_harness_results.json", name="lm_eval_harness_results")
-    log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
-
     return outputs
@@ -639,6 +663,7 @@ def lm_eval_harness(step: StepInfo, force=False):
             return  # don't run eval on the first step

         model = inference_mode(step.model, True)
+        logger.info("Running eval harness...")
         outputs = _actually_run_eval_harness(
             config,
             model,
@@ -648,18 +673,22 @@
             axis_resources,
             mp,
         )
+        logger.info("Finished running eval harness.")

-        if jax.process_index() == 0:
-            log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
+        log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker())
+        logger.info("Logged report to tracker")

+        if jax.process_index() == 0:
             # don't delete b/c wandb will sometimes defer upload
             with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f:
                 import json

                 json.dump(outputs, f)
+                f.flush()
                 levanter.tracker.current_tracker().log_artifact(
                     f.name, name=f"lm_eval_harness_results.{step.step}.json", type="lm_eval_output"
                 )
+            logger.info("Uploaded results to tracker")

     return lm_eval_harness
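
For context, a sketch (not part of the patch) of how the in-loop harness options introduced above might sit together in a training config. The eval_harness and eval_harness_steps keys come straight from the YAML hunk; exposing log_samples, max_eval_length, and bootstrap_iters from YAML is an assumption based on the LmEvalHarnessConfig fields, and eval_harness_steps is read as the run cadence in training steps.

    eval_harness:
      task_spec: ["piqa", "hellaswag"]
      max_examples: 2048      # forwarded to evaluator.evaluate(limit=...)
      max_eval_length: 1024   # assumed: EvalPos = model.Pos.resize(max_eval_length)
      log_samples: false      # drop per-example samples from the report
      bootstrap_iters: 0      # skip bootstrap stderr estimation (hang workaround)
    eval_harness_steps: 1000  # assumed: run the harness every 1000 training steps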