Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Jan 17, 2025
1 parent 38b4a4a commit db0b1cd
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 9 deletions.
6 changes: 5 additions & 1 deletion config/gpt2_nano_harness.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
eval_harness:
task_spec: ["piqa", "hellaswag"]
task_spec:
- mmlu
# - task: mmlu
# task_alias: mmlu_0shot
# num_fewshot: 0
max_examples: 32
eval_harness_steps: 50
data:
Expand Down
4 changes: 3 additions & 1 deletion config/harness/harness_nano.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
eval_harness:
# task_spec: ["hellaswag"]
task_spec:
- task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios
- mmlu
- task: mmlu
num_fewshot: 1
task_alias: mmlu_1shot
tokenizer: "gpt2"
model:
type: gpt2
Expand Down
89 changes: 82 additions & 7 deletions src/levanter/eval_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
result_greedy = np.zeros(len(requests))
covered_points = np.zeros(len(requests), dtype=bool)

return [ (0.0, False) ] * len(requests)

total_padding = 0
total_tokens = 0
pbar = tqdm(total=len(requests), desc="Loglikelihood", unit="req")
Expand Down Expand Up @@ -420,11 +422,14 @@ def to_task_dict(self) -> dict:
else:
our_name = task.get("task_alias", task["task"]) if isinstance(task, dict) else task
our_name = our_name.replace(" ", "_")
this_task = self._get_task_and_rename(manager, our_name, task)
this_tasks[our_name] = this_task
except Exception:
tasks_for_this_task_spec = self._get_task_and_rename(manager, our_name, task)
for k, v in tasks_for_this_task_spec.items():
if k in this_tasks:
raise ValueError(f"Task {k} already exists")
this_tasks[k] = v
except Exception as e:
logger.exception(f"Failed to load task {task}")
raise ValueError(f"Failed to load task {task}")
raise ValueError(f"Failed to load task {task}") from e

logger.info(f"Loaded {len(this_tasks)} tasks")
return this_tasks
Expand All @@ -437,12 +442,82 @@ def _get_task_and_rename(self, manager, our_name, task: dict | str):
"""
import lm_eval.tasks as tasks

task_name = task if isinstance(task, str) else task["task"]

task_dict = tasks.get_task_dict([task], manager)
this_task = task_dict.popitem()[1]
# hacky, but this allows us to run multiple instances of the same task with different fewshot settings
this_task.config.task = our_name
assert len(task_dict) == 1, f"Expected 1 task, got {len(task_dict)}"
try:
this_task = self._rename_tasks_for_eval_harness(task_dict, task_name, our_name)
except AttributeError:
logger.exception(f"Failed to rename task {task}: {task_dict}")
raise ValueError(f"Failed to rename task {task}: {task_dict}")
return this_task

def _rename_tasks_for_eval_harness(self, this_task, lm_eval_task_name, our_name):
import lm_eval.tasks as tasks
# hacky, but this allows us to run multiple instances of the same task with different fewshot settings
if isinstance(this_task, dict):
out = {}
for k, v in this_task.items():
v = self._rename_tasks_for_eval_harness(v, lm_eval_task_name, our_name)

if isinstance(k, tasks.ConfigurableGroup):
k._config.group = self._replace_name_with_our_name(k.group, lm_eval_task_name, our_name)
out[k] = v
elif isinstance(k, str):
k = self._replace_name_with_our_name(k, lm_eval_task_name, our_name)
if isinstance(v, dict):
# ok so inexplicably, lm_eval_harness doesn't wrap the key in a ConfigurableGroup when you pass
# in a task dict (it seems like a mistake), so we need to do that here
# subtask is the name of all of the child tasks in v
subtask_list = self._get_child_tasks(v)
group = tasks.ConfigurableGroup(
config={"group": k, "task": subtask_list}
)
out[group] = v
else:
out[k] = v
else:
raise ValueError(f"Unknown key type: {k}")

return out

elif isinstance(this_task, tasks.ConfigurableTask):
this_task.config.task = self._replace_name_with_our_name(this_task.config.task, lm_eval_task_name, our_name)
return this_task
else:
raise ValueError(f"Unknown task type: {this_task}")

def _replace_name_with_our_name(self, lm_eval_name, lm_eval_prefix, our_name_prefix):
if our_name_prefix.startswith(lm_eval_prefix):
suffix = our_name_prefix[len(lm_eval_prefix) :]
prefix = lm_eval_prefix
else:
suffix = ""
prefix = our_name_prefix
if lm_eval_prefix in lm_eval_name:
lm_eval_name = lm_eval_name.replace(lm_eval_prefix, prefix) + suffix
else:
lm_eval_name = prefix + "_" + lm_eval_name + suffix
return lm_eval_name

def _get_child_tasks(self, task_group):
import lm_eval.tasks as tasks
out = []
for k, v in task_group.items():
if isinstance(k, tasks.ConfigurableGroup):
subtask_or_tasks = k.config.task
if isinstance(subtask_or_tasks, str):
out.append(subtask_or_tasks)
else:
out.extend(subtask_or_tasks)
elif isinstance(k, str):
out.append(k)
else:
raise ValueError(f"Unknown key type: {k}")

return out


@dataclass(frozen=True)
class EvalHarnessMainConfig:
Expand Down

0 comments on commit db0b1cd

Please sign in to comment.