2020sys .path .append ("../language-modeling/" )
2121from run_time_clm import get_special_tokens
2222
def get_classification_model(model_args):
    """Load the fine-tuned section-classification model and its tokenizer.

    Args:
        model_args: object exposing a ``classification_model`` attribute
            (local path or hub id of the fine-tuned BERT checkpoint) and a
            ``device`` attribute the model is moved to.

    Returns:
        tuple: ``(tokenizer, model)`` where ``tokenizer`` is a
        ``PreTrainedTokenizerFast`` and ``model`` a
        ``BertForSequenceClassification`` placed on ``model_args.device``.
    """
    # Cleanup: dropped the large commented-out AutoTokenizer/AutoConfig/
    # AutoModel loading path and the unused, user-specific `cache_dir`
    # constant that pointed at a hard-coded scratch directory.
    model_path = model_args.classification_model
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path)
    model.to(model_args.device)
    return tokenizer, model
56-
5723class GenerationMetrics :
5824
5925 def __init__ (self , model , device , tokenizer , dataset_name , fname ,
@@ -75,12 +41,10 @@ def __init__(self, model, device, tokenizer, dataset_name, fname,
7541 self .section_ids = self .section_ids [:- 1 ]
7642 self ._info = []
7743 self ._examples = []
78- self ._classification_examples = dict ()
7944 self .metrics = defaultdict (lambda : [])
8045 self .examples = {}
8146 self .fname = fname
8247
83- self .classification_tokenizer , self .classification_model = get_classification_model (model_args )
8448 self .mode = "section" if ("splitsection" in dataset_name ) else "doc"
8549
8650 def calculate (self , input_ids , raw_seq , section_name = None ,
@@ -97,7 +61,6 @@ def calculate(self, input_ids, raw_seq, section_name=None,
9761 self ._examples .append ({'text' : raw_seq })
9862
9963 def _stories (self , input_ids , raw_seq ):
100- # TODO get story classification
10164 # Check for redundancy in WP and Prompt
10265 info = {}
10366 for special_tok , name in zip ([50257 , 50258 ], ['[ WP ]' , '[ RESPONSE ]' ]):
@@ -132,9 +95,6 @@ def _stories(self, input_ids, raw_seq):
13295 def _track_doc_examples (self , raw_seq ):
13396 self .examples ['ordering = {}' .format (self .metrics ['ordering' ][- 1 ])] = raw_seq
13497
135- for k , v in self ._classification_examples .items ():
136- self .examples [k ] = v
137-
13898 for section_i , section_name in enumerate (self .section_names ):
13999 is_present = self .metrics ['{} present' .format (section_name )]
140100 is_redundant = self .metrics ['{} redundant' .format (section_name )]
@@ -186,19 +146,6 @@ def _check_total_length(self, input_ids, info):
186146 info ['total length' ] = input_ids .shape [- 1 ]
187147 return info
188148
189- def _check_classification (self , raw_seq , info ):
190- classification_results = self ._get_classification (raw_seq )
191- histograms = dict ()
192- for k , v in classification_results .items ():
193- # if list, create a histogram and include mean
194- if isinstance (v , list ):
195- histograms [self .prepend_ + k + " hist" ] = wandb .Histogram (v )
196- v = np .mean (v )
197- info [k ] = v
198- wandb .log (histograms )
199- return info
200-
201-
202149 def _taskmaster_section_length (self , input_ids , idxs , section_name , info ):
203150 lengths = []
204151 other_id = 50258 if 'USER' in section_name else 50257
@@ -300,8 +247,6 @@ def _document(self, input_ids, raw_seq, gt_raw_seq):
300247 info = {}
301248
302249 info = self ._check_total_length (input_ids = input_ids , info = info )
303- if 'taskmaster' not in self .dataset_name :
304- info = self ._check_classification (raw_seq = raw_seq , info = info )
305250 info = self ._check_ordering (input_ids = input_ids , raw_seq = raw_seq , info = info )
306251 for section_id , section_name in zip (self .section_ids , self .section_names ):
307252 idxs = (input_ids == section_id ).nonzero (as_tuple = True )
@@ -385,68 +330,6 @@ def _document(self, input_ids, raw_seq, gt_raw_seq):
385330
386331 wandb .log (most_recent )
387332
388- def _get_classification (self , raw_seq ):
389- results = defaultdict (lambda : [])
390- self ._classification_examples = dict ()
391- raw_seq = raw_seq .replace ("<|endoftext|> " , "" )
392- split_seq = raw_seq .split (". " )
393- sec_id = 0
394- seq_idxs = []
395- for seq_idx , seq in enumerate (split_seq ):
396- if not seq :
397- continue
398- seq_idxs .append (seq_idx )
399- seq += "."
400- for tok in self .section_names :
401- if tok in seq :
402- sec_id = self .section_names .index (tok )
403- seq = seq .replace (tok + " " , "" )
404- try :
405- assert tok not in seq
406- except :
407- seq = seq .replace (tok , "" )
408-
409- tokenized_seq = self .classification_tokenizer (seq , return_tensors = 'pt' ).to (
410- self .classification_model .device
411- )
412- result = self .classification_model (input_ids = tokenized_seq ['input_ids' ][:, :512 ])
413- probs = torch .nn .functional .softmax (result .logits , dim = 1 )
414-
415- acc = int (torch .argmax (probs ) == sec_id )
416- entropy = - torch .sum (probs * torch .log (probs )).detach ().cpu ().numpy ()
417- prob_sec_id = probs [0 , sec_id ].detach ().cpu ().numpy ()
418-
419- # uniform_p = torch.tensor([0.25]*4)
420- # y_entropy = -torch.sum(uniform_p * torch.log(uniform_p))
421- # mi = float(y_entropy - entropy)
422-
423- self .metrics ["{} class acc" .format (self .section_names [sec_id ])].append (acc )
424- self .metrics ["{} class entropy" .format (self .section_names [sec_id ])].append (entropy )
425- # self.metrics["{} MI".format(self.section_names[sec_id])].append(mi)
426- self .metrics ["{} p(section_id*|x)" .format (self .section_names [sec_id ])].append (prob_sec_id )
427- results ["{} class acc" .format (self .section_names [sec_id ])].append (acc )
428- results ["{} class entropy" .format (self .section_names [sec_id ])].append (entropy )
429- # results["{} MI".format(self.section_names[sec_id])].append(mi)
430- results ["{} p(section_id*|x)" .format (self .section_names [sec_id ])].append (prob_sec_id )
431-
432- # sentences that are induce high/low acc/entropy/mi
433- for key , metric in zip (["{} class acc" , "{} class entropy" ], [acc , entropy ,]):
434- key = key .format (self .section_names [sec_id ])
435- if results [key ] and max (results [key ]) == metric :
436- self ._classification_examples [key + " MAX" ] = seq
437- self .metrics [key + " MAX IDX" ].append (
438- len (results ["{} class acc" .format (self .section_names [sec_id ])]))
439- results [key + " MAX IDX" ].append (
440- len (results ["{} class acc" .format (self .section_names [sec_id ])]))
441- if results [key ] and min (results [key ]) == metric :
442- self ._classification_examples [key + " MIN" ] = seq
443- self .metrics [key + " MIN IDX" ].append (
444- len (results ["{} class acc" .format (self .section_names [sec_id ])]))
445- results [key + " MIN IDX" ].append (
446- len (results ["{} class acc" .format (self .section_names [sec_id ])]))
447-
448- return results
449-
450333 def print_results (self ):
451334 print ("Examples" )
452335 extreme_ex = []
0 commit comments