Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions env.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import functools
import evaluate

# fix: GPU OOM (TF exhausts GPU memory, crashing PyTorch)
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
import pseudo_func

datasets = {
"newsroom": {
Expand All @@ -14,7 +9,7 @@
"document_column": "ArticleText",
"system_summary_column": "SystemSummary",
"reference_summary_column": "ReferenceSummary",
"approaches": ["trad", "new"],
"approaches": ["new"],
"human_eval_only_path": "dataloader/newsroom-human-eval.csv", # you need to get this file. See ReadMe.
"refs_path": "dataloader/test.jsonl", # you need to get this file. See ReadMe.
"human_eval_w_refs_path": "dataloader/newsroom_human_eval_with_refs.csv"
Expand Down Expand Up @@ -57,6 +52,8 @@
"bleurt": evaluate.load('bleurt', config_name='BLEURT-20', module_type='metric').compute,
"rouge": functools.partial(evaluate.load("rouge").compute, use_aggregator=False),
"bertscore": functools.partial(evaluate.load("bertscore").compute, lang='en', use_fast_tokenizer=True),
"pseudo_metric":functools.partial(pseudo_func.pseudo_func, model_name = 'google/pegasus-xsum',ref_based_metric_name='rouge')

}


Expand Down
49 changes: 49 additions & 0 deletions pseudo_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
"""PegasusDemo.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1j9LyFcJLiv37zZloRudYq7Wk9YrgPbfv
"""


"""# New Section"""


from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
import evaluate


"""## pseudo_metric function"""

import pandas as pd
import csv



def pseudo_func(predictions, references, model_name, ref_based_metric_name):
    """Score system summaries with a reference-based metric against
    machine-generated pseudo references.

    For each source document a pseudo reference summary is generated with a
    seq2seq model (e.g. Pegasus); the chosen `evaluate` metric is then
    computed between the pseudo references and the system summaries.

    Args:
        predictions: system summaries to be scored.
        references: source documents (NOT gold references) — each one is fed
            to the model to generate a pseudo reference summary.
        model_name: Hugging Face model id used for pseudo-reference
            generation (e.g. 'google/pegasus-xsum').
        ref_based_metric_name: name of an `evaluate` metric. For 'rouge',
            per-example scores are returned (use_aggregator=False).

    Returns:
        The metric's `compute` result.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = PegasusTokenizer.from_pretrained(model_name)
    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)

    pseudo_ref_sum = []
    # Inference only: disable autograd to avoid tracking gradients, which
    # otherwise wastes GPU memory during generation.
    with torch.no_grad():
        # zip() preserves the original truncate-to-shortest pairing behavior.
        for _system_sum, src_text in zip(predictions, references):
            batch = tokenizer(src_text, truncation=True, padding="longest",
                              return_tensors="pt").to(device)
            generated = model.generate(**batch)
            decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
            pseudo_ref_sum.append(decoded[0])

    # NOTE(review): the arguments look swapped — pseudo references are passed
    # as `predictions` and the system summaries as `references`. Kept as-is to
    # preserve existing scores; confirm the intended direction.
    metric = evaluate.load(ref_based_metric_name)
    extra_kwargs = (
        {"use_aggregator": False} if ref_based_metric_name == "rouge" else {}
    )
    results = metric.compute(predictions=pseudo_ref_sum,
                             references=predictions,
                             **extra_kwargs)
    return results