55"""
66
77# Standard
8- import os , shutil
9- import yaml
108from uuid import uuid4
9+ import os
10+ import shutil
1111
1212# Third Party
1313from lm_eval .tasks .unitxt import task
14+ import yaml
1415
1516# First Party
1617from instructlab .eval .mmlu import MMLUBranchEvaluator
2021
logger = setup_logger(__name__)

# Prefix shared by every temporary tasks directory this module creates.
# UnitxtEvaluator.remove_unitxt_files only deletes directories whose name
# starts with this prefix.
TEMP_DIR_PREFIX = "unitxt_temp"

2426
2527class UnitxtEvaluator (MMLUBranchEvaluator ):
2628 """
@@ -29,45 +31,51 @@ class UnitxtEvaluator(MMLUBranchEvaluator):
2931 Attributes:
3032 model_path absolute path to or name of a huggingface model
3133 unitxt_recipe unitxt recipe (see unitxt.ai for more information)
32- A Recipe holds a complete specification of a unitxt pipeline
34+ A Recipe holds a complete specification of a unitxt pipeline
3335 Example: card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10
34-
36+
3537 """
38+
3639 name = "unitxt"
40+
def __init__(
    self,
    model_path,
    unitxt_recipe: str,
):
    """
    Set up an evaluator for a single unitxt recipe.

    A fresh task name is generated and a tasks directory name is derived
    from it; both are handed to the parent constructor together with
    ``few_shots=0``.

    Attributes:
        model_path      absolute path to or name of a huggingface model
        unitxt_recipe   unitxt recipe string describing the pipeline
    """
    # Named unitxt_task (not "task") so the local does not shadow the
    # module-level `task` import used by create_unitxt_pointer.
    unitxt_task = self.assign_task_name()
    tasks_dir = self.assign_tasks_dir(unitxt_task)
    super().__init__(
        model_path=model_path,
        tasks_dir=tasks_dir,
        tasks=[unitxt_task],
        few_shots=0,
    )
    self.unitxt_recipe = unitxt_recipe
5152
def assign_tasks_dir(self, task_name):
    """
    Return the tasks directory name for ``task_name``.

    The name is ``TEMP_DIR_PREFIX`` and ``task_name`` joined by an
    underscore, so it always starts with the prefix that
    remove_unitxt_files checks before deleting.
    """
    return f"{TEMP_DIR_PREFIX}_{task_name}"
5455
def assign_task_name(self):
    """Return a new random UUID (uuid4) rendered as a string, used as the task name."""
    return str(uuid4())
5758
def prepare_unitxt_files(self) -> None:
    """
    Write the files describing this evaluator's unitxt task.

    Creates the ``unitxt`` pointer file in ``self.tasks_dir`` and a
    ``<taskname>.yaml`` file next to it holding the recipe. Nothing is
    returned; the result is the files on disk.
    """
    taskname = self.tasks[0]
    # str() because tasks_dir is joined as a path component here.
    yaml_file = os.path.join(str(self.tasks_dir), f"{taskname}.yaml")
    # The pointer is written first: create_unitxt_pointer also creates the
    # directory that the yaml file is then written into.
    create_unitxt_pointer(self.tasks_dir)
    create_unitxt_yaml(
        yaml_file=yaml_file,
        unitxt_recipe=self.unitxt_recipe,
        task_name=taskname,
    )
6366
def remove_unitxt_files(self):
    """
    Delete the temporary tasks directory created for this evaluation.

    The directory is removed only when its name starts with
    ``TEMP_DIR_PREFIX``; otherwise a warning is logged and nothing is
    deleted.
    """
    # Guard to avoid unintended deletion if this class is inherited and
    # tasks_dir points somewhere else. str() mirrors the coercion used in
    # prepare_unitxt_files so a non-str tasks_dir does not break the check.
    if str(self.tasks_dir).startswith(TEMP_DIR_PREFIX):
        shutil.rmtree(self.tasks_dir)
    else:
        logger.warning(
            "unitxt tasks dir did not start with '%s' and therefor was not deleted",
            TEMP_DIR_PREFIX,
        )
6977
70- def run (self ,server_url : str | None = None ) -> tuple :
78+ def run (self , server_url : str | None = None ) -> tuple :
7179 """
7280 Runs evaluation
7381
@@ -80,40 +88,40 @@ def run(self,server_url: str | None = None) -> tuple:
8088 os .environ ["TOKENIZERS_PARALLELISM" ] = "true"
8189 results = self ._run_mmlu (server_url = server_url , return_all_results = True )
8290 taskname = self .tasks [0 ]
83- global_scores = results [' results' ][taskname ]
84- global_scores .pop (' alias' )
91+ global_scores = results [" results" ][taskname ]
92+ global_scores .pop (" alias" )
8593 try :
86- instances = results [' samples' ][taskname ]
94+ instances = results [" samples" ][taskname ]
8795 instance_scores = {}
88- metrics = [metric .replace ('metrics.' ,'' ) for metric in instances [0 ]['doc' ]['metrics' ]]
89- for i ,instance in enumerate (instances ):
96+ metrics = [
97+ metric .replace ("metrics." , "" )
98+ for metric in instances [0 ]["doc" ]["metrics" ]
99+ ]
100+ for i , instance in enumerate (instances ):
90101 scores = {}
91102 for metric in metrics :
92103 scores [metric ] = instance [metric ][0 ]
93104 instance_scores [i ] = scores
94- except Exception as e :
105+ except KeyError as e :
95106 logger .error ("Error in extracting single instance scores" )
96107 logger .error (e )
97108 logger .error (e .__traceback__ )
98109 instance_scores = None
99110 self .remove_unitxt_files ()
100- return global_scores ,instance_scores
111+ return global_scores , instance_scores
101112
102113
def create_unitxt_yaml(yaml_file, unitxt_recipe, task_name):
    """
    Write the yaml file that defines a unitxt task.

    Attributes:
        yaml_file       path of the yaml file to write (overwritten if present)
        unitxt_recipe   recipe string stored under the ``recipe`` key
        task_name       task name stored under the ``task`` key
    """
    data = {
        "task": str(task_name),
        "include": "unitxt",
        "recipe": str(unitxt_recipe),
    }
    with open(yaml_file, "w", encoding="utf-8") as file:
        yaml.dump(data, file, default_flow_style=False)
    # Log task_name, not `task`: at module level `task` is the imported
    # lm_eval unitxt module, not the name of the task being written.
    logger.debug("task %s unitxt recipe written to %s", task_name, yaml_file)
119+
112120
def create_unitxt_pointer(tasks_dir):
    """
    Write the ``unitxt`` pointer file into ``tasks_dir``.

    The file holds a single ``class: !function <path>`` line, where the path
    is the imported ``task`` module's file location with ``task.py`` replaced
    by ``task.Unitxt``. The directory is created if it does not exist.

    Attributes:
        tasks_dir   directory in which the ``unitxt`` file is written
    """
    class_line = "class: !function " + task.__file__.replace("task.py", "task.Unitxt")
    output_file = os.path.join(tasks_dir, "unitxt")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(class_line)
    logger.debug("Unitxt task pointer written to %s", output_file)
0 commit comments