-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathplot.py
More file actions
113 lines (94 loc) · 3.9 KB
/
plot.py
File metadata and controls
113 lines (94 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import matplotlib.pyplot as plt
from collections import defaultdict
import argparse
# Command-line interface: choose the metric to draw and the result files to read.
parser = argparse.ArgumentParser(description="Plot metrics from input")
parser.add_argument(
    '--metric',
    default='all',
    choices=['Precision', 'Recall', 'F1-Score', 'all'],
    help="Metric to plot (default: all)",
)
parser.add_argument(
    '--input',
    required=True,
    nargs='+',
    # argparse opens each path for reading and hands over the file objects.
    type=argparse.FileType('r'),
    help="Paths to input files (tab-separated format)",
)
args = parser.parse_args()
def process_file(input_file):
    """Collect per-key metric values from a tab-separated stream.

    Each usable line has at least four tab-separated fields: a key followed
    by three numbers that are stored as Precision (field 1), Recall (field 2)
    and F1-Score (field 3).  Lines with fewer than four fields are ignored.
    Lines whose numeric fields cannot be parsed are reported on stdout and
    skipped as a whole.

    Args:
        input_file: An iterable of text lines (for example an open file).

    Returns:
        A ``defaultdict`` mapping each key to a dict with the lists
        ``'Recall'``, ``'Precision'`` and ``'F1-Score'``; values are appended
        in input order, so repeated keys accumulate several entries.
    """
    scores = defaultdict(lambda: {'Recall': [], 'Precision': [], 'F1-Score': []})
    for line in input_file:
        line = line.strip()
        parts = line.split('\t')
        if len(parts) < 4:
            continue
        # Convert all three numbers before storing any of them, so a bad field
        # can never leave a half-filled entry with lists of unequal length.
        try:
            precision, recall, f1_score = (float(field) for field in parts[1:4])
        except ValueError:
            print(f"Skipping line due to invalid float: {line}")
            continue
        entry = scores[parts[0]]
        entry['Recall'].append(recall)
        entry['Precision'].append(precision)
        entry['F1-Score'].append(f1_score)
    return scores
# Parse every input file and separate its baseline row from the other scores.
all_scores = []
for input_file in args.input:
    scores = process_file(input_file)
    # 'avg_segments' is the baseline entry.  Default to None so a file without
    # that row still loads; the plotting code already checks `if baseline:`
    # before drawing the baseline line.
    baseline = scores.pop('avg_segments', None)
    all_scores.append((scores, baseline))
# Human-readable display names for the supported metrics.
metric_titles = {
    'Recall': 'Recall',
    'Precision': 'Precision',
    'F1-Score': 'F-score',
}

# Colour is cycled per score type and line style per mode when curves are drawn.
colors = ['blue', 'green', 'red', 'purple', 'orange', 'brown', 'gray']
linestyles = ['-', '--', ':', '-.']

metrics_to_plot = (
    ['Recall', 'Precision', 'F1-Score']
    if args.metric == 'all' else [args.metric]
)

# One panel per (file, metric) pair, three panels per row.  Sizing the grid
# from the panel count keeps the 3x3 / 16x16 layout for nine panels while
# avoiding dropped panels (more than nine) or rows of empty axes (fewer).
n_cols = 3
n_panels = len(all_scores) * len(metrics_to_plot)
n_rows = max(1, -(-n_panels // n_cols))  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols,
    sharex=True,
    squeeze=False,  # always return a 2-D array so flatten() works for one row
    figsize=(16, 16 * n_rows / n_cols),
)
axes = axes.flatten()

# Three-letter language code (taken from the input file name) -> ISO 639-1
# two-letter code shown as the panel title.
lang_map = {
    'ces': 'cs',
    'deu': 'de',
    'eng': 'en',
    'nld': 'nl',
    'fin': 'fi',
    'hbs': 'sh',
    'hye': 'hy',
    'kan': 'kn',
    'slk': 'sk',
}
# Draw one panel per (input file, metric): a score-vs-threshold curve for every
# type/mode combination, plus the file's baseline as a horizontal line.
for file_idx, (scores, baseline) in enumerate(all_scores):
    # The language code is the leading part of the file name: "<lang>-...".
    file_label = args.input[file_idx].name.split('/')[-1].split('.')[0].split('-')[0]
    # Fall back to the raw label so an unmapped language does not abort the plot.
    panel_title = lang_map.get(file_label, file_label)

    # Group the scores once per file (the grouping does not depend on the
    # metric being drawn): metric -> type -> mode -> {threshold: value}.
    scores_by_type = {
        m: defaultdict(lambda: defaultdict(dict)) for m in metrics_to_plot
    }
    for key, values in scores.items():
        # Keys end in "...-<type>-<mode>-<threshold>".
        key_parts = key.split('-')
        threshold = float(key_parts[-1])
        mode = key_parts[-2]
        type_ = key_parts[-3]
        if type_ in ('harmonic', 'geometric'):
            continue
        for m in metrics_to_plot:
            # Only the first recorded value per key is plotted; skip metrics
            # for which nothing was recorded instead of raising IndexError.
            if values[m]:
                scores_by_type[m][type_][mode][threshold] = values[m][0]

    for metric_idx, metric in enumerate(metrics_to_plot):
        ax_idx = file_idx * len(metrics_to_plot) + metric_idx
        if ax_idx >= len(axes):
            # No subplot left for this panel.
            continue
        ax = axes[ax_idx]

        for ti, (type_, modes_dict) in enumerate(scores_by_type[metric].items()):
            for mi, (mode, threshold_dict) in enumerate(modes_dict.items()):
                sorted_items = sorted(threshold_dict.items())
                thresholds = [t for t, _ in sorted_items]
                ys = [v for _, v in sorted_items]
                ax.plot(
                    thresholds, ys,
                    marker='o',
                    label=f'{type_}-{mode}',
                    color=colors[ti % len(colors)],
                    linestyle=linestyles[mi % len(linestyles)],
                )

        if baseline and baseline[metric]:
            ax.axhline(baseline[metric][0], linewidth=2.5, linestyle='--', color='red', label='Baseline')
        # Solid reference line at zero.
        ax.axhline(0.0, linewidth=2.5, linestyle='-', color='black')
        ax.set_title(f'{panel_title}', fontsize=18)
        ax.tick_params(axis='both', labelsize=12)
        ax.grid(True)
# Put the legend on the last subplot that actually has labelled artists.
# plt.legend() would target the last-created axes, which is empty whenever
# fewer panels than subplots were drawn; for a fully used grid the two agree.
legend_ax = next(
    (ax for ax in axes[::-1] if ax.get_legend_handles_labels()[0]),
    None,
)
if legend_ax is not None:
    legend_ax.legend(
        bbox_to_anchor=(1.05, 1.05),
        loc='best',
        fontsize=15,
        frameon=False,
    )

fig.supxlabel('Threshold', fontsize=18)
# With --metric all the panels show different metrics, so use a generic label
# rather than the literal word "all".
fig.supylabel('Score' if args.metric == 'all' else args.metric, fontsize=18)
plt.tight_layout()
plt.show()