Commit ebea5e9

Add automatically generated "score" and "evaluate" commands for each metric (#7)
1 parent 737b50e commit ebea5e9

File tree: 8 files changed (+727 -82 lines)

sacrerouge/__main__.py (+4 -1)

@@ -1,6 +1,6 @@
 import argparse
 
-from sacrerouge.commands import correlate, evaluate, score, setup_dataset, setup_metric
+from sacrerouge.commands import correlate, evaluate, metric_command, score, setup_dataset, setup_metric
 
 
 def main():
@@ -17,6 +17,9 @@ def main():
     for subcommand in subcommands:
         subcommand.add_subparser(subparsers)
 
+    # Add a command for each individual metric
+    metric_command.add_metric_subcommands(subparsers)
+
     args = parser.parse_args()
     if 'func' in dir(args):
         args.func(args)
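The new metric_command.add_metric_subcommands(subparsers) call (defined later in this commit) attaches one subcommand per registered metric to the top-level parser. Below is a minimal, self-contained sketch of that registry-driven pattern, with a made-up TOY_METRIC_REGISTRY standing in for sacrerouge's Registrable._registry[Metric] and print statements in place of the real handlers:

import argparse

# Toy stand-in for Registrable._registry[Metric]; the real registry maps each
# registered metric name to its Metric class (plus extra registration info).
TOY_METRIC_REGISTRY = {
    'metric-a': object,
    'metric-b': object,
}


def add_metric_subcommands(subparsers: argparse._SubParsersAction) -> None:
    # Mirrors the commit's pattern: one top-level subcommand per registered metric.
    for name in sorted(TOY_METRIC_REGISTRY):
        metric_parser = subparsers.add_parser(name, help=f'Run the "{name}" metric')
        metric_parser.set_defaults(func=lambda args, name=name: print(f'would run {name}'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog='sacrerouge')
    subparsers = parser.add_subparsers()
    add_metric_subcommands(subparsers)

    args = parser.parse_args(['metric-a'])
    if 'func' in dir(args):
        args.func(args)  # prints: would run metric-a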

sacrerouge/commands/evaluate.py (+59 -45)

@@ -55,46 +55,72 @@ def evaluate_instances(instances: List[EvalInstance], metrics: List[Metric]) ->
     return macro, micro_list
 
 
-class EvaluateSubcommand(Subcommand):
-    @overrides
-    def add_subparser(self, parser: argparse._SubParsersAction):
-        description = 'Evaluate a summarization model'
-        self.parser = parser.add_parser('evaluate', description=description, help=description)
-        self.parser.add_argument(
+def save_evaluation_results(macro_results: MetricsDict,
+                            micro_results_list: List[Metrics],
+                            macro_output_json: str,
+                            micro_output_jsonl: str,
+                            silent: bool) -> None:
+    dirname = os.path.dirname(macro_output_json)
+    if dirname:
+        os.makedirs(dirname, exist_ok=True)
+
+    serialized_macro = jsons.dumps({'metrics': macro_results}, jdkwargs={'indent': 2})
+    with open(macro_output_json, 'w') as out:
+        out.write(serialized_macro)
+    if not silent:
+        logger.info(serialized_macro)
+
+    with JsonlWriter(micro_output_jsonl) as out:
+        for metrics_dict in micro_results_list:
+            out.write(metrics_dict)
+
+
+def add_evaluate_arguments(parser: argparse.ArgumentParser, include_config_arguments: bool) -> None:
+    if include_config_arguments:
+        parser.add_argument(
             'config',
             type=str,
             help='The config file that specifies the dataset reader and metrics'
         )
-        self.parser.add_argument(
-            'macro_output_json',
-            type=str,
-            help='The path to where the system-level metrics should be written'
-        )
-        self.parser.add_argument(
-            'micro_output_jsonl',
-            type=str,
-            help='The path to where the input-level metrics should be written'
-        )
-        self.parser.add_argument(
-            '--log-file',
-            type=str,
-            help='The file where the log should be written'
-        )
-        self.parser.add_argument(
-            '--silent',
-            action='store_true',
-            help='Controls whether the log should be written to stdout'
-        )
-        self.parser.add_argument(
+        parser.add_argument(
             '--overrides',
             type=str,
             help='A serialized json that will override the parameters passed in "config"'
         )
-        self.parser.add_argument(
-            '--include-packages',
-            nargs='+',
-            help='A list of additional packages to include'
-        )
+
+    parser.add_argument(
+        'macro_output_json',
+        type=str,
+        help='The path to where the system-level metrics should be written'
+    )
+    parser.add_argument(
+        'micro_output_jsonl',
+        type=str,
+        help='The path to where the input-level metrics should be written'
+    )
+    parser.add_argument(
+        '--log-file',
+        type=str,
+        help='The file where the log should be written'
+    )
+    parser.add_argument(
+        '--silent',
+        action='store_true',
+        help='Controls whether the log should be written to stdout'
+    )
+    parser.add_argument(
+        '--include-packages',
+        nargs='+',
+        help='A list of additional packages to include'
+    )
+
+
+class EvaluateSubcommand(Subcommand):
+    @overrides
+    def add_subparser(self, parser: argparse._SubParsersAction):
+        description = 'Evaluate a summarization model'
+        self.parser = parser.add_parser('evaluate', description=description, help=description)
+        add_evaluate_arguments(self.parser, True)
         self.parser.set_defaults(func=self.run)
 
     @overrides
@@ -117,16 +143,4 @@ def run(self, args):
         instances = dataset_reader.read(*input_files)
         macro, micro_list = evaluate_instances(instances, metrics)
 
-        dirname = os.path.dirname(args.macro_output_json)
-        if dirname:
-            os.makedirs(dirname, exist_ok=True)
-
-        serialized_macro = jsons.dumps({'metrics': macro}, jdkwargs={'indent': 2})
-        with open(args.macro_output_json, 'w') as out:
-            out.write(serialized_macro)
-        if not args.silent:
-            logger.info(serialized_macro)
-
-        with JsonlWriter(args.micro_output_jsonl) as out:
-            for metrics_dict in micro_list:
-                out.write(metrics_dict)
+        save_evaluation_results(macro, micro_list, args.macro_output_json, args.micro_output_jsonl, args.silent)
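This refactoring pulls the argument definitions into add_evaluate_arguments and the output-writing code into save_evaluation_results so the auto-generated per-metric commands (in metric_command.py below) can reuse them. The include_config_arguments flag controls whether the config positional and --overrides flag are added, since the per-metric commands configure the metric from command-line arguments rather than a config file. A condensed, runnable sketch of that toggle, with most flags and help strings omitted and made-up file paths in the example invocations:

import argparse


def add_evaluate_arguments(parser: argparse.ArgumentParser, include_config_arguments: bool) -> None:
    # Condensed version of the shared argument set from the diff above.
    if include_config_arguments:
        parser.add_argument('config', type=str)
        parser.add_argument('--overrides', type=str)
    parser.add_argument('macro_output_json', type=str)
    parser.add_argument('micro_output_jsonl', type=str)
    parser.add_argument('--silent', action='store_true')


# The config-driven "evaluate" command keeps the config arguments...
config_parser = argparse.ArgumentParser()
add_evaluate_arguments(config_parser, True)
print(config_parser.parse_args(['config.json', 'macro.json', 'micro.jsonl']))

# ...while an auto-generated per-metric command reuses everything else.
metric_parser = argparse.ArgumentParser()
add_evaluate_arguments(metric_parser, False)
print(metric_parser.parse_args(['macro.json', 'micro.jsonl', '--silent']))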

sacrerouge/commands/metric_command.py (+84, new file)

@@ -0,0 +1,84 @@
+import argparse
+from overrides import overrides
+from typing import Type
+
+from sacrerouge.commands import Subcommand
+from sacrerouge.commands.evaluate import add_evaluate_arguments, evaluate_instances, save_evaluation_results
+from sacrerouge.commands.score import add_score_arguments, save_score_results, score_instances
+from sacrerouge.common import Registrable
+from sacrerouge.common.arguments import add_metric_arguments, get_dataset_reader_from_argument, get_metric_from_arguments
+from sacrerouge.common.logging import prepare_global_logging
+from sacrerouge.metrics import Metric
+
+
+def add_metric_subcommands(subparsers: argparse._SubParsersAction) -> None:
+    """Adds a MetricSubcommand for every registered metric."""
+    for name, (metric, _) in Registrable._registry[Metric].items():
+        command = MetricSubcommand(name, metric)
+        command.add_subparser(subparsers)
+
+
+def add_dataset_reader_arguments(parser: argparse.ArgumentParser) -> None:
+    parser.add_argument(
+        '--dataset-reader',
+        type=str,
+        required=True,
+        help='The name or the parameters as a serialized json for the dataset reader'
+    )
+    parser.add_argument(
+        '--input-files',
+        nargs='+',
+        required=True,
+        help='The input files to be passed to the dataset reader'
+    )
+
+
+class MetricSubcommand(Subcommand):
+    def __init__(self, name: str, metric_type: Type) -> None:
+        super().__init__()
+        self.name = name
+        self.metric_type = metric_type
+
+    @overrides
+    def add_subparser(self, parser: argparse._SubParsersAction):
+        description = f'Run "evaluate" or "score" with the "{self.name}" metric.'
+        self.parser = parser.add_parser(self.name, description=description, help=description)
+        subparsers = self.parser.add_subparsers()
+
+        description = f'Run "evaluate" with the "{self.name}" metric.'
+        self.evaluate_parser = subparsers.add_parser('evaluate', description=description, help=description)
+        add_evaluate_arguments(self.evaluate_parser, False)
+        add_metric_arguments(self.evaluate_parser, self.metric_type)
+        add_dataset_reader_arguments(self.evaluate_parser)
+        self.evaluate_parser.set_defaults(func=self.run_evaluate)
+
+        description = f'Run "score" with the "{self.name}" metric.'
+        self.score_parser = subparsers.add_parser('score', description=description, help=description)
+        add_score_arguments(self.score_parser, False)
+        add_metric_arguments(self.score_parser, self.metric_type)
+        add_dataset_reader_arguments(self.score_parser)
+        self.score_parser.set_defaults(func=self.run_score)
+
+    def run_evaluate(self, args: argparse.Namespace) -> None:
+        prepare_global_logging(file_path=args.log_file, silent=args.silent)
+
+        dataset_reader = get_dataset_reader_from_argument(args.dataset_reader)
+        metric = get_metric_from_arguments(self.metric_type, args)
+        input_files = args.input_files
+
+        instances = dataset_reader.read(*input_files)
+        macro, micro_list = evaluate_instances(instances, [metric])
+
+        save_evaluation_results(macro, micro_list, args.macro_output_json, args.micro_output_jsonl, args.silent)
+
+    def run_score(self, args: argparse.Namespace) -> None:
+        prepare_global_logging(file_path=args.log_file, silent=args.silent)
+
+        dataset_reader = get_dataset_reader_from_argument(args.dataset_reader)
+        metric = get_metric_from_arguments(self.metric_type, args)
+        input_files = args.input_files
+
+        instances = dataset_reader.read(*input_files)
+        metrics_dicts = score_instances(instances, [metric])
+
+        save_score_results(metrics_dicts, args.output_jsonl, args.silent)
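Taken together, each registered metric now gets its own top-level command with nested "evaluate" and "score" subcommands, combining the shared evaluate/score arguments, the metric-specific arguments from add_metric_arguments, and the --dataset-reader/--input-files options defined here. A rough sketch of the resulting parser tree and the set_defaults(func=...) dispatch, using a hypothetical metric name and invented argument values; only a few of the real flags are reproduced:

import argparse


def build_cli() -> argparse.ArgumentParser:
    # Skeleton of the generated tree: sacrerouge <metric-name> {evaluate,score} ...
    parser = argparse.ArgumentParser(prog='sacrerouge')
    subparsers = parser.add_subparsers()

    metric_parser = subparsers.add_parser('my-metric')  # hypothetical metric name
    metric_subparsers = metric_parser.add_subparsers()

    score_parser = metric_subparsers.add_parser('score')
    score_parser.add_argument('output_jsonl', type=str)
    score_parser.add_argument('--dataset-reader', type=str, required=True)
    score_parser.add_argument('--input-files', nargs='+', required=True)
    score_parser.set_defaults(func=lambda args: print('scoring ->', args.output_jsonl))
    return parser


parser = build_cli()
args = parser.parse_args(['my-metric', 'score', 'scores.jsonl',
                          '--dataset-reader', 'my-reader',  # invented reader name
                          '--input-files', 'summaries.jsonl'])
args.func(args)  # dispatched exactly as __main__.py does with args.func(args)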

sacrerouge/commands/score.py (+44 -35)

@@ -16,6 +16,41 @@
 logger = logging.getLogger(__name__)
 
 
+def add_score_arguments(parser: argparse.ArgumentParser, include_config_arguments: bool) -> None:
+    if include_config_arguments:
+        parser.add_argument(
+            'config',
+            type=str,
+            help='The config file that specifies the dataset reader and metrics'
+        )
+        parser.add_argument(
+            '--overrides',
+            type=str,
+            help='A serialized json that will override the parameters passed in "config"'
+        )
+
+    parser.add_argument(
+        'output_jsonl',
+        type=str,
+        help='The path to where the input-level metrics should be written'
+    )
+    parser.add_argument(
+        '--log-file',
+        type=str,
+        help='The file where the log should be written'
+    )
+    parser.add_argument(
+        '--silent',
+        action='store_true',
+        help='Controls whether the log should be written to stdout'
+    )
+    parser.add_argument(
+        '--include-packages',
+        nargs='+',
+        help='A list of additional packages to include'
+    )
+
+
 def _load_metrics(params: Params) -> List[Metric]:
     metrics = []
     for metric_params in params.pop('metrics'):
@@ -106,41 +141,19 @@ def score_instances(instances: List[EvalInstance], metrics: List[Metric]) -> Dic
     return metrics_dicts
 
 
+def save_score_results(metrics_dicts: Dict[str, Dict[str, Metrics]], output_file: str, silent: bool) -> None:
+    with JsonlWriter(output_file) as out:
+        for instance_id in sorted(metrics_dicts.keys()):
+            for summarizer_id in sorted(metrics_dicts[instance_id].keys()):
+                out.write(metrics_dicts[instance_id][summarizer_id])
+
+
 class ScoreSubcommand(Subcommand):
     @overrides
     def add_subparser(self, parser: argparse._SubParsersAction):
         description = 'Score all of the inputs to evaluate a metric'
         self.parser = parser.add_parser('score', description=description, help=description)
-        self.parser.add_argument(
-            'config',
-            type=str,
-            help='The config file that specifies the dataset reader and metrics'
-        )
-        self.parser.add_argument(
-            'output_jsonl',
-            type=str,
-            help='The path to where the input-level metrics should be written'
-        )
-        self.parser.add_argument(
-            '--log-file',
-            type=str,
-            help='The file where the log should be written'
-        )
-        self.parser.add_argument(
-            '--silent',
-            action='store_true',
-            help='Controls whether the log should be written to stdout'
-        )
-        self.parser.add_argument(
-            '--overrides',
-            type=str,
-            help='A serialized json that will override the parameters passed in "config"'
-        )
-        self.parser.add_argument(
-            '--include-packages',
-            nargs='+',
-            help='A list of additional packages to include'
-        )
+        add_score_arguments(self.parser, True)
         self.parser.set_defaults(func=self.run)
 
     @overrides
@@ -163,8 +176,4 @@ def run(self, args):
         instances = dataset_reader.read(*input_files)
         metrics_dicts = score_instances(instances, metrics)
 
-        # Save the results to the output file
-        with JsonlWriter(args.output_jsonl) as out:
-            for instance_id in sorted(metrics_dicts.keys()):
-                for summarizer_id in sorted(metrics_dicts[instance_id].keys()):
-                    out.write(metrics_dicts[instance_id][summarizer_id])
+        save_score_results(metrics_dicts, args.output_jsonl, args.silent)
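save_score_results flattens the nested {instance_id: {summarizer_id: Metrics}} dictionary produced by score_instances into a JSONL file, one record per (instance, summarizer) pair, sorted for deterministic output. A toy illustration of that shape with plain dicts and the standard json module in place of sacrerouge's Metrics objects and JsonlWriter (the IDs and scores are invented):

import json

# Invented example of the nested structure returned by score_instances.
metrics_dicts = {
    'instance-2': {'system-a': {'my-metric': 0.38}},
    'instance-1': {'system-b': {'my-metric': 0.44}, 'system-a': {'my-metric': 0.41}},
}

# Same iteration order as save_score_results: sorted instances, then sorted summarizers.
with open('scores.jsonl', 'w') as out:
    for instance_id in sorted(metrics_dicts.keys()):
        for summarizer_id in sorted(metrics_dicts[instance_id].keys()):
            out.write(json.dumps(metrics_dicts[instance_id][summarizer_id]) + '\n')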
