Skip to content

Commit 4c5ae81

Browse files
authored
Fix in plots (#21)
* Fix in plots * Fix in plots * Fix CI/CD * Fix CI/CD * Fix CI/CD
1 parent 149193f commit 4c5ae81

8 files changed

+112
-75
lines changed

.github/workflows/ci.yml

+5-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
fail-fast: true
1515
matrix:
16-
python-version: ["3.12"]
16+
python-version: ["3.13"]
1717
steps:
1818
- name: Checkout
1919
uses: actions/checkout@v4
@@ -26,4 +26,7 @@ jobs:
2626
run: |
2727
python -m pip install --upgrade uv
2828
uv pip install --system -r requirements-dev.txt
29-
uv pip install --system .
29+
30+
- name: Check code formatting with pre-commit
31+
run: |
32+
pre-commit run --all-files

benchmark/runner.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import importlib
22
import json
3+
import logging
34
import os
45
import time
56
from pathlib import Path
@@ -12,6 +13,13 @@
1213
from .transforms.specs import TRANSFORM_SPECS
1314
from .utils import get_image_loader, get_library_versions, get_system_info, time_transform, verify_thread_settings
1415

16+
# Configure logging
17+
logging.basicConfig(
18+
level=logging.INFO,
19+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
20+
)
21+
logger = logging.getLogger(__name__)
22+
1523
# Environment variables for various libraries
1624
os.environ["OMP_NUM_THREADS"] = "1"
1725
os.environ["OPENBLAS_NUM_THREADS"] = "1"
@@ -89,7 +97,7 @@ def load_images(self) -> list[Any]:
8997
raise ValueError("No valid RGB images found in the directory")
9098

9199
if len(rgb_images) < self.num_images:
92-
print(f"Warning: Only found {len(rgb_images)} valid RGB images, requested {self.num_images}")
100+
logger.warning("Only found %d valid RGB images, requested %d", len(rgb_images), self.num_images)
93101

94102
return rgb_images
95103

@@ -228,7 +236,7 @@ def run_transform(self, transform_spec: Any, images: list[Any]) -> dict[str, Any
228236

229237
def run(self, output_path: Path | None = None) -> dict[str, Any]:
230238
"""Run all benchmarks"""
231-
print(f"\nRunning benchmarks for {self.library}")
239+
logger.info("Running benchmarks for %s", self.library)
232240
images = self.load_images()
233241

234242
# Collect metadata
@@ -263,7 +271,7 @@ def run(self, output_path: Path | None = None) -> dict[str, Any]:
263271
if output_path:
264272
output_path = Path(output_path)
265273
output_path.parent.mkdir(parents=True, exist_ok=True)
266-
with open(output_path, "w") as f:
274+
with output_path.open("w") as f:
267275
json.dump(full_results, f, indent=2)
268276

269277
return full_results
4.72 KB
Loading
1.94 KB
Loading

pyproject.toml

+5-5
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ license = { file = "LICENSE" }
3434
maintainers = [ { name = "Vladimir Iglovikov" } ]
3535

3636
authors = [ { name = "Vladimir Iglovikov" } ]
37-
requires-python = ">=3.12"
37+
requires-python = ">=3.13"
3838

3939
classifiers = [
4040
"Development Status :: 5 - Production/Stable",
@@ -44,7 +44,6 @@ classifiers = [
4444
"Operating System :: OS Independent",
4545
"Programming Language :: Python",
4646
"Programming Language :: Python :: 3 :: Only",
47-
"Programming Language :: Python :: 3.12",
4847
"Programming Language :: Python :: 3.13",
4948
"Topic :: Scientific/Engineering :: Artificial Intelligence",
5049
"Topic :: Scientific/Engineering :: Image Processing",
@@ -63,12 +62,12 @@ exclude = [ "output*", "output_videos*", ".venv*", "tests*" ]
6362

6463
[tool.ruff]
6564
# Exclude a variety of commonly ignored directories.
66-
target-version = "py312"
65+
target-version = "py313"
6766

6867
line-length = 120
6968
indent-width = 4
7069

71-
# Assume Python 3.12
70+
# Assume Python 3.13
7271
exclude = [
7372
".bzr",
7473
".direnv",
@@ -121,6 +120,7 @@ lint.ignore = [
121120
"D101",
122121
"D102",
123122
"D103",
123+
"D104",
124124
"D105",
125125
"D107",
126126
"D415",
@@ -152,7 +152,7 @@ lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
152152
lint.pydocstyle.convention = "google"
153153

154154
[tool.mypy]
155-
python_version = "3.12"
155+
python_version = "3.13"
156156
ignore_missing_imports = true
157157
follow_imports = "silent"
158158
warn_redundant_casts = true

requirements-dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ matplotlib
22
pandas
33
pre_commit>=3.5.0
44
pytest>=8.3.3
5+
seaborn
56
tabulate

tools/generate_speedup_plots.py

+55-46
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from typing import Any
99

1010
import matplotlib.pyplot as plt
11-
import numpy as np
1211
import pandas as pd
12+
import seaborn as sns
1313

1414
# Configure logging
1515
logging.basicConfig(
@@ -138,37 +138,31 @@ def plot_speedup_distribution(
138138
plt.close()
139139
return
140140

141+
# Set seaborn style for better aesthetics
142+
sns.set_style("whitegrid")
143+
sns.set_context("paper", font_scale=1.1)
144+
145+
palette = sns.color_palette("tab10", 4)
146+
hist_color = palette[0] # Blue
147+
top_color = palette[2] # Green
148+
bottom_color = palette[3] # Red
149+
141150
# Create figure with three subplots
142-
fig = plt.figure(figsize=(15, 5))
151+
fig = plt.figure(figsize=(15, 6.5)) # Increased height for better spacing
143152
gs = plt.GridSpec(1, 3, width_ratios=[1.5, 1, 1])
144153
ax1 = fig.add_subplot(gs[0])
145154
ax2 = fig.add_subplot(gs[1])
146155
ax3 = fig.add_subplot(gs[2])
147156

148-
# Use colorblind-friendly colors
149-
hist_color = "#4878D0" # Blue
150-
top_color = "#60BD68" # Green
151-
bottom_color = "#EE6677" # Red
152-
153157
# 1. Histogram of typical speedups (< max_speedup)
154158
typical_speedups = comparison_df[comparison_df["Speedup"] < max_speedup]["Speedup"]
155159

156160
if len(typical_speedups) > 0:
157-
ax1.hist(typical_speedups, bins=15, color=hist_color, alpha=0.7, edgecolor="black")
158-
ax1.axvline(
159-
typical_speedups.median(),
160-
color="#404040",
161-
linestyle="--",
162-
linewidth=1.5,
163-
label=f"Median: {typical_speedups.median():.2f}x",
164-
)
165-
ax1.axvline(1, color="#404040", linestyle=":", linewidth=1.5, alpha=0.7, label="No speedup (1x)")
166-
ax1.grid(True, alpha=0.3)
161+
sns.histplot(typical_speedups, bins=15, color=hist_color, alpha=0.7, edgecolor="black", ax=ax1)
167162

168163
ax1.set_xlabel("Speedup (x)", fontsize=12)
169164
ax1.set_ylabel("Number of transforms", fontsize=12)
170-
ax1.set_title(f"(a) Distribution of Typical Speedups\n(< {max_speedup}x)", fontsize=14)
171-
ax1.legend(fontsize=10)
165+
ax1.set_title(f"(a) Distribution of Speedups < {max_speedup}x", fontsize=14)
172166
else:
173167
ax1.text(0.5, 0.5, "No speedup data < 20x", ha="center", va="center", fontsize=12)
174168
ax1.set_axis_off()
@@ -177,24 +171,32 @@ def plot_speedup_distribution(
177171
top_n = min(10, len(comparison_df))
178172
if top_n > 0:
179173
top_10 = comparison_df.nlargest(top_n, "Speedup")
180-
bars = ax2.barh(np.arange(len(top_10)), top_10["Speedup"], color=top_color, alpha=0.7, edgecolor="black")
181-
ax2.set_yticks(np.arange(len(top_10)))
182-
ax2.set_yticklabels(top_10["Transform"], fontsize=10)
174+
sns.barplot(
175+
x="Speedup",
176+
y="Transform",
177+
data=top_10,
178+
color=top_color,
179+
alpha=0.7,
180+
edgecolor="black",
181+
ax=ax2,
182+
)
183183
ax2.grid(True, alpha=0.3)
184184

185-
for _, bar in enumerate(bars):
186-
width = bar.get_width()
185+
# Add text labels for speedup values
186+
for i, v in enumerate(top_10["Speedup"]):
187187
ax2.text(
188-
width + 0.05,
189-
bar.get_y() + bar.get_height() / 2,
190-
f"{width:.2f}x",
188+
v + 0.05,
189+
i,
190+
f"{v:.2f}x",
191191
ha="left",
192192
va="center",
193193
fontsize=10,
194194
bbox={"facecolor": "white", "alpha": 0.8, "edgecolor": "none"},
195195
)
196196

197197
ax2.set_xlabel("Speedup (x)", fontsize=12)
198+
# Remove y-label "Transform"
199+
ax2.set_ylabel("")
198200
ax2.set_title("(b) Top 10 Speedups", fontsize=14)
199201
else:
200202
ax2.text(0.5, 0.5, "No speedup data", ha="center", va="center", fontsize=12)
@@ -204,30 +206,32 @@ def plot_speedup_distribution(
204206
bottom_n = min(10, len(comparison_df))
205207
if bottom_n > 0:
206208
bottom_10 = comparison_df.nsmallest(bottom_n, "Speedup")
207-
bars = ax3.barh(
208-
np.arange(len(bottom_10)),
209-
bottom_10["Speedup"],
209+
sns.barplot(
210+
x="Speedup",
211+
y="Transform",
212+
data=bottom_10,
210213
color=bottom_color,
211214
alpha=0.7,
212215
edgecolor="black",
216+
ax=ax3,
213217
)
214-
ax3.set_yticks(np.arange(len(bottom_10)))
215-
ax3.set_yticklabels(bottom_10["Transform"], fontsize=10)
216218
ax3.grid(True, alpha=0.3)
217219

218-
for _, bar in enumerate(bars):
219-
width = bar.get_width()
220+
# Add text labels for speedup values
221+
for i, v in enumerate(bottom_10["Speedup"]):
220222
ax3.text(
221-
width + 0.05,
222-
bar.get_y() + bar.get_height() / 2,
223-
f"{width:.2f}x",
223+
v + 0.05,
224+
i,
225+
f"{v:.2f}x",
224226
ha="left",
225227
va="center",
226228
fontsize=10,
227229
bbox={"facecolor": "white", "alpha": 0.8, "edgecolor": "none"},
228230
)
229231

230232
ax3.set_xlabel("Speedup (x)", fontsize=12)
233+
# Remove y-label "Transform"
234+
ax3.set_ylabel("")
231235
ax3.set_title("(c) Bottom 10 Speedups", fontsize=14)
232236
else:
233237
ax3.text(0.5, 0.5, "No speedup data", ha="center", va="center", fontsize=12)
@@ -256,23 +260,28 @@ def plot_speedup_distribution(
256260
f"{total_transforms} transforms with multiple library support"
257261
)
258262

259-
# Add the stats text to the bottom right of the figure
263+
# Add the stats text to the right side of the left plot with larger font
264+
ax1_pos = ax1.get_position()
265+
# Calculate 10% of the plot width
266+
plot_width = ax1_pos.x1 - ax1_pos.x0
267+
shift_amount = plot_width * 0.1
268+
260269
plt.figtext(
261-
0.98,
262-
0.02,
270+
ax1_pos.x1 - 0.02 - shift_amount, # Shifted left by 10% of plot width
271+
ax1_pos.y1 - 0.02, # Slightly below the top edge of the left plot
263272
stats_text,
264273
ha="right",
265-
va="bottom",
266-
bbox={"facecolor": "white", "alpha": 0.9, "edgecolor": "none"},
267-
fontsize=10,
274+
va="top",
275+
bbox={"facecolor": "white", "alpha": 0.9, "edgecolor": "lightgray", "boxstyle": "round,pad=0.5"},
276+
fontsize=14, # Significantly increased font size
268277
)
269278

270-
# Add title with information about the reference library
271-
plt.suptitle(f"Speedup Analysis: {reference_library.capitalize()} vs Other Libraries", fontsize=16)
279+
# Add title with information about the reference library with more space
280+
plt.suptitle(f"Speedup Analysis: {reference_library.capitalize()} vs Other Libraries", fontsize=16, y=1.02)
272281

273282
# Adjust layout and save
274283
plt.tight_layout()
275-
plt.subplots_adjust(top=0.9, bottom=0.15) # Make room for the suptitle and stats
284+
plt.subplots_adjust(top=0.88) # Increased top margin for suptitle
276285
plt.savefig(output_path, dpi=300, bbox_inches="tight")
277286
plt.close()
278287

0 commit comments

Comments
 (0)