Skip to content

Commit

Permalink
Added plot script for benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasimi committed Oct 13, 2024
1 parent 30ff7db commit 0a72038
Showing 1 changed file with 238 additions and 0 deletions.
238 changes: 238 additions & 0 deletions benchmarks/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import math

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, LogLocator


alpha = 0.75


df_colors = pd.DataFrame({
'bench': ['tda-mapper', 'giotto-tda', 'kepler-mapper'],
'color': ['tab:blue', 'tab:orange', 'tab:green'],
})


df_styles = pd.DataFrame({
'bench': ['tda-mapper', 'giotto-tda', 'kepler-mapper'],
'style': ['-', '--', '-.'],
})


df_markers = pd.DataFrame({
'bench': ['tda-mapper', 'giotto-tda', 'kepler-mapper'],
'marker': ['D', 's', 'o'],
})


df_zorder = pd.DataFrame({
'bench': ['tda-mapper', 'giotto-tda', 'kepler-mapper'],
'zorder': [3, 1, 2],
})


df_titles = pd.DataFrame({
'dataset': ['line', 'digits', 'mnist', 'cifar10', 'fashion_mnist'],
'title': ['Line', 'Digits', 'MNIST', 'Cifar-10', 'Fashion-MNIST'],
})


def load_benchmark(path):
df = pd.read_csv(path)
df = pd.merge(
left=df,
right=df_colors,
on='bench'
)
df = pd.merge(
left=df,
right=df_styles,
on='bench'
)
df = pd.merge(
left=df,
right=df_markers,
on='bench'
)
df = pd.merge(
left=df,
right=df_zorder,
on='bench'
)
df = pd.merge(
left=df,
right=df_titles,
on='dataset'
)
df.sort_values(by=['dataset', 'bench', 'p', 'k'], inplace=True)
df.reset_index(drop=True, inplace=True)
return df


def get_inset(ax):
inset_ax = ax.inset_axes([0.2, 0.55, 0.4, 0.4])
inset_ax.set_yscale('log')
inset_ax.yaxis.set_label_position('left')
inset_ax.tick_params(axis='y', direction='out')
inset_ax.tick_params(axis='x', direction='out')
inset_ax.yaxis.tick_left()
inset_ax.xaxis.set_major_locator(MaxNLocator(integer=True))
inset_ax.yaxis.set_major_locator(LogLocator(base=10.0))
for _, spine in inset_ax.spines.items():
spine.set_alpha(alpha)
inset_ax.xaxis.label.set_alpha(alpha)
inset_ax.yaxis.label.set_alpha(alpha)
for tickline in inset_ax.get_xticklines():
tickline.set_alpha(alpha)
for label in inset_ax.get_xticklabels():
label.set_alpha(alpha)
label.set_fontsize(8)
for tickline in inset_ax.get_yticklines():
tickline.set_alpha(alpha)
for label in inset_ax.get_yticklabels():
label.set_alpha(alpha)
label.set_fontsize(8)
return inset_ax


def plot_library(df_bench, ax, ax_log):
bench = df_bench.bench.values[0]
color = df_bench.color.values[0]
style = df_bench['style'].values[0]
marker = df_bench.marker.values[0]
zorder = df_bench.zorder.values[0]
df_plot = df_bench.sort_values(by='k')
ax_log.plot(
df_plot.k,
df_plot.time,
label=bench,
color=color,
linestyle=style,
marker=marker,
markersize=2,
markerfacecolor='white',
markeredgecolor=color,
markeredgewidth=1.0,
alpha=alpha,
linewidth=1.0,
zorder=zorder,
)
line = ax.plot(
df_plot.k,
df_plot.time,
label=bench,
color=color,
linestyle=style,
marker=marker,
markersize=4,
markerfacecolor='white',
markeredgecolor=color,
markeredgewidth=1.5,
alpha=1.0,
linewidth=1.5,
zorder=zorder,
)[0]
return line


def plot_benchmark(df_p, ax, ax_log):
p = df_p.p.values[0]
max_time = df_p.time.max()
min_time = df_p.time.min()
max_time_log = math.ceil(math.log10(max_time))
min_time_log = math.floor(math.log10(min_time))
ax_log.set_yticks([10**i for i in range(min_time_log, max_time_log + 1)])
ax.set_title(f'p = {p}')
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.yaxis.set_label_position('left')
ax.yaxis.tick_left()
for bench in df_p.bench.unique():
df_bench = df_p[df_p.bench == bench]
plot_library(df_bench, ax, ax_log)


def plot_experiment(df_dataset, fig, axes, axes_log):
for j, p in enumerate(df_dataset.p.unique()):
df_p = df_dataset[df_dataset.p == p]
plot_benchmark(df_p, axes[j], axes_log[j])
lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
unique = dict(zip(labels, lines))
fig.legend(unique.values(), unique.keys(), loc='upper center', bbox_to_anchor=(0.5, 0.0), ncol=3)


def plot_split(df):
for dataset in df.dataset.unique():
df_dataset = df[df.dataset == dataset]
p_count = df_dataset.p.nunique()
fig, axes = plt.subplots(nrows=1, ncols=p_count, figsize=(8, 3.2), dpi=300, sharex=False, sharey=False)
axes_log = []
for ax in axes:
ax.set_xlabel('k')
ax.set_ylabel('time (s)')
inset_ax = get_inset(ax)
axes_log.append(inset_ax)
plot_experiment(df_dataset, fig, axes, axes_log)
title = df_dataset.title.values[0]
plt.suptitle(f'{title}', fontsize=14)
plt.tight_layout()
fig.savefig(f'benchmark_{dataset}.png', bbox_inches='tight')
plt.show()


def plot_full(df):
p_count = df.p.nunique()
fig, axes = plt.subplots(nrows=3, ncols=p_count, figsize=(8, 8), dpi=300, sharex=False, sharey=False)
for i, dataset in enumerate(df.dataset.unique()):
df_dataset = df[df.dataset == dataset]
p_count = df_dataset.p.nunique()
for j, p in enumerate(df_dataset.p.unique()):
df_p = df_dataset[df_dataset.p == p]
title = df_titles[df_titles.dataset == dataset].title.values[0]
axes[i, j].set_title(f'{title}, p = {p}')
axes[i, j].set_xlabel('k')
axes[i, j].set_ylabel('time (s)')
axes[i, j].xaxis.set_major_locator(MaxNLocator(integer=True))
for bench in df_p.bench.unique():
df_bench = df_p[df_p.bench == bench]
color = df_bench.color.values[0]
style = df_bench['style'].values[0]
marker = df_bench.marker.values[0]
zorder = df_bench.zorder.values[0]
axes[i, j].plot(
df_bench.k,
df_bench.time,
label=bench,
color=color,
linestyle=style,
marker=marker,
markersize=4,
markerfacecolor='white',
markeredgecolor=color,
markeredgewidth=1.5,
alpha=1.0,
linewidth=1.5,
zorder=zorder,
)
lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
unique = dict(zip(labels, lines))
unique_size = len(unique.keys())
fig.legend(unique.values(), unique.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=unique_size)
plt.tight_layout()
fig.savefig(f'benchmark_{unique_size}.png', bbox_inches='tight')
plt.show()


if __name__ == '__main__':
df_benchmark = load_benchmark('./benchmark.csv')

plt.rcParams.update({'font.size': 11, 'font.family': 'serif'})
plot_split(df_benchmark)

plt.rcParams.update({'font.size': 11, 'font.family': 'sans-serif'})
df_sel_2 = df_benchmark[(df_benchmark['bench'].isin(['giotto-tda', 'kepler-mapper'])) & (df_benchmark['dataset'].isin(['mnist', 'cifar10', 'fashion_mnist']))]
plot_full(df_sel_2)
df_sel_3 = df_benchmark[df_benchmark['dataset'].isin(['mnist', 'cifar10', 'fashion_mnist'])]
plot_full(df_sel_3)

0 comments on commit 0a72038

Please sign in to comment.