diff --git a/gbmi/utils/images.py b/gbmi/utils/images.py index 355ce70c..2edf82d8 100644 --- a/gbmi/utils/images.py +++ b/gbmi/utils/images.py @@ -184,6 +184,8 @@ def optipng( save_bak: bool = True, fix: bool = True, trim_printout: bool = False, + stdout_write: Optional[Callable] = None, + stderr_write: Optional[Callable] = None, ): if not images: return @@ -197,6 +199,8 @@ def optipng( *images, check=True, trim_printout=trim_optipng if trim_printout else None, + stdout_write=stdout_write, + stderr_write=stderr_write, ) @@ -210,6 +214,8 @@ def pngcrush( tmpdir: Optional[Union[str, Path]] = None, cleanup: Optional[bool] = None, trim_printout: bool = False, + stdout_write: Optional[Callable] = None, + stderr_write: Optional[Callable] = None, ): if not images: return @@ -242,6 +248,8 @@ def pngcrush( *map(str, images), check=True, trim_printout=trim_pngcrush if trim_printout else None, + stdout_write=stdout_write, + stderr_write=stderr_write, ) # Replace original images with crushed images if they are smaller @@ -262,12 +270,18 @@ def optimize( tmpdir: Optional[Union[str, Path]] = None, cleanup: Optional[bool] = None, trim_printout: bool = False, + tqdm_position: Optional[int] = None, + tqdm_leave: Optional[bool] = None, + stdout_write: Optional[Callable] = None, + stderr_write: Optional[Callable] = None, ): cur_images = images cur_sizes = [Path(image).stat().st_size for image in cur_images] while cur_images: if shutil.which("ect"): - for img in tqdm(cur_images, desc="ect"): + for img in tqdm( + cur_images, desc="ect", position=tqdm_position, leave=tqdm_leave + ): ect( img, exhaustive=exhaustive, @@ -275,13 +289,21 @@ def optimize( stdout_write=partial(tqdm.write, file=sys.stdout), stderr_write=partial(tqdm.write, file=sys.stderr), ) - optipng(*cur_images, exhaustive=exhaustive, trim_printout=trim_printout) + optipng( + *cur_images, + exhaustive=exhaustive, + trim_printout=trim_printout, + stdout_write=stdout_write, + stderr_write=stderr_write, + ) pngcrush( *cur_images, brute=exhaustive, tmpdir=tmpdir, cleanup=cleanup, trim_printout=trim_printout, + stdout_write=stdout_write, + stderr_write=stderr_write, ) new_sizes = [Path(image).stat().st_size for image in cur_images] cur_images = [ diff --git a/notebooks_jason/max_of_K_all_models.py b/notebooks_jason/max_of_K_all_models.py index d02f7564..11ec3f95 100644 --- a/notebooks_jason/max_of_K_all_models.py +++ b/notebooks_jason/max_of_K_all_models.py @@ -160,8 +160,11 @@ # %% import subprocess import sys +from functools import cache, partial from pathlib import Path +from tqdm.auto import tqdm + import gbmi.utils.images as image_utils seq_len: int = cli_args.K @@ -328,12 +331,18 @@ def optimize_pngs(errs: list[Exception] = []): ) if not opt_success: - for f in LATEX_FIGURE_PATH.glob("*.png"): + for f in tqdm( + list(LATEX_FIGURE_PATH.glob("*.png")), desc="figures", position=0 + ): wrap_err( image_utils.optimize, f, exhaustive=True, trim_printout=COMPACT_IMAGE_OPTIMIZE_OUTPUT, + tqdm_position=1, + tqdm_leave=False, + stdout_write=partial(tqdm.write, file=sys.stdout), + stderr_write=partial(tqdm.write, file=sys.stderr), ) @@ -353,7 +362,6 @@ def optimize_pngs(errs: list[Exception] = []): from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from functools import cache, partial from itertools import chain from typing import Any, Callable, Collection, Iterator, Literal, Optional, Tuple, Union @@ -371,7 +379,6 @@ def optimize_pngs(errs: list[Exception] = []): from sklearn.linear_model import LinearRegression from sklearn.metrics import r2_score from torch import Tensor -from tqdm.auto import tqdm from transformer_lens import HookedTransformer import gbmi.exp_max_of_n.analysis.quadratic as analysis_quadratic