diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml deleted file mode 100644 index c8a2e5f3..00000000 --- a/.github/workflows/code-quality-pr.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# This workflow finds which files were changed, prints them, -# and runs `pre-commit` on those files. - -# Inspired by the sktime library: -# https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml - -name: Code Quality PR - -env: - SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL: True - -on: - pull_request: - branches: [main, "release/*"] - -jobs: - code-quality: - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v3 - with: - python-version: "3.10" - - - name: Find modified files - id: file_changes - uses: trilom/file-changes-action@v1.2.4 - with: - output: " " - - - name: List modified files - run: echo '${{ steps.file_changes.outputs.files}}' - - - name: Run pre-commits - uses: pre-commit/action@v2.0.3 - with: - extra_args: --files ${{ steps.file_changes.outputs.files}} diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality.yaml similarity index 64% rename from .github/workflows/code-quality-main.yaml rename to .github/workflows/code-quality.yaml index ce95976b..7f7431a7 100644 --- a/.github/workflows/code-quality-main.yaml +++ b/.github/workflows/code-quality.yaml @@ -6,6 +6,8 @@ name: Code Quality Main on: push: branches: [main] + pull_request: + branches: [main, "release/*"] jobs: code-quality: @@ -13,12 +15,12 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: - python-version: "3.10" + python-version: "3.13" - name: Run pre-commits - uses: pre-commit/action@v2.0.3 + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a846507a..a320513f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -13,15 +13,15 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest] - python-version: ["3.9", "3.10", "3.11", "3.12"] + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v5 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} @@ -46,10 +46,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v5 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.10" diff --git a/.github/workflows/test_runner.yaml b/.github/workflows/test_runner.yaml index 4bc58d36..e2cf62e2 100644 --- a/.github/workflows/test_runner.yaml +++ b/.github/workflows/test_runner.yaml @@ -1,10 +1,10 @@ name: Runner Tests -on: - push: - branches: [main] - pull_request: - branches: [main, "release/*"] +#on: +# push: +# branches: [main] +# pull_request: +# branches: [main, "release/*"] jobs: run_tests_ubuntu: @@ -18,10 +18,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v5 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} 
@@ -48,10 +48,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v5 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.10" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46e35536..29174a11 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,6 @@ -default_language_version: - python: python3 - node: 16.14.2 - repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: # list of supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace @@ -29,55 +25,26 @@ repos: - id: check-added-large-files require_serial: true - # python code formatting - - repo: https://github.com/psf/black - rev: 23.7.0 - hooks: - - id: black - require_serial: true - args: [--line-length, "99"] - - # python import sorting - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - require_serial: true - args: ["--profile", "black", "--filter-files"] - # python upgrading syntax to newer version - repo: https://github.com/asottile/pyupgrade - rev: v3.9.0 + rev: v3.17.0 hooks: - id: pyupgrade require_serial: true args: [--py38-plus] - # python docstring formatting - - repo: https://github.com/myint/docformatter - rev: master + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.5 hooks: - - id: docformatter - require_serial: true - args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] - - # python check (PEP8), programming errors and code complexity - - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 - hooks: - - id: flake8 - require_serial: true - entry: pflake8 - additional_dependencies: ["pyproject-flake8"] + - id: ruff + args: [--fix] + - id: ruff-format # python security linter - - repo: https://github.com/PyCQA/bandit - rev: "1.7.5" + - repo: https://github.com/gitleaks/gitleaks + rev: v8.18.2 hooks: - - id: bandit - require_serial: true - args: ["-c", "pyproject.toml"] - additional_dependencies: ["bandit[toml]"] + - id: gitleaks # yaml formatting - repo: https://github.com/pre-commit/mirrors-prettier @@ -95,19 +62,13 @@ repos: require_serial: true args: ["-e", "SC2102"] - # md formatting - - repo: https://github.com/executablebooks/mdformat - rev: 0.7.16 + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 hooks: - - id: mdformat - require_serial: true - args: ["--number"] - additional_dependencies: - - mdformat-gfm - - mdformat-tables - - mdformat_frontmatter - # - mdformat-toc - # - mdformat-black + - id: prettier + # To avoid conflicts, tell prettier to ignore file types + # that ruff already handles. + exclude_types: [python] # word spelling linter - repo: https://github.com/codespell-project/codespell diff --git a/README.md b/README.md index c029a913..321e93e9 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ A. Tong, N. Malkin, K. Fatras, L. Atanackovic, Y. Zhang, G. Huguet, G. Wolf, Y. 
Major Changes: -- __Added cifar10 examples with an FID of 3.5__ +- **Added cifar10 examples with an FID of 3.5** - Added code for the new Simulation-free Score and Flow Matching (SF)2M preprint - Created `torchcfm` pip installable package - Moved `pytorch-lightning` implementation and experiments to `runner` directory diff --git a/examples/images/cifar10/README.md b/examples/images/cifar10/README.md index edc5a53e..e2eb5927 100644 --- a/examples/images/cifar10/README.md +++ b/examples/images/cifar10/README.md @@ -1,6 +1,6 @@ # CIFAR-10 experiments using TorchCFM -This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset. +This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an **FID of 3.5** on the Cifar10 dataset.
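Since the README advertises the pip-installable `torchcfm` package, a minimal training-step sketch may help orient readers. The `ConditionalFlowMatcher`, `MLP`, and `sample_location_and_conditional_flow` names come from files touched later in this diff; the toy data, learning rate, and the concatenated `[x, t]` input layout are illustrative assumptions, not prescribed by the patch.

```python
import torch
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher
from torchcfm.models import MLP

# Illustrative 2-D toy setup; batch size, sigma, and lr are arbitrary choices.
model = MLP(dim=2, time_varying=True, w=64)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
fm = ConditionalFlowMatcher(sigma=0.0)

x0 = torch.randn(256, 2)          # source samples (e.g. a Gaussian prior)
x1 = torch.randn(256, 2) + 4.0    # target samples (stand-in for real data)

# Sample t ~ U[0, 1], a point x_t on the conditional path, and the target velocity u_t.
t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)

optimizer.zero_grad()
# Input layout [x_t, t] is assumed from the package's examples.
vt = model(torch.cat([xt, t[:, None]], dim=-1))
loss = torch.mean((vt - ut) ** 2)
loss.backward()
optimizer.step()
```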

diff --git a/examples/images/cifar10/compute_fid.py b/examples/images/cifar10/compute_fid.py index 7596699c..565e4ba4 100644 --- a/examples/images/cifar10/compute_fid.py +++ b/examples/images/cifar10/compute_fid.py @@ -3,12 +3,10 @@ # Authors: Kilian Fatras # Alexander Tong -import os import sys -import matplotlib.pyplot as plt import torch -from absl import app, flags +from absl import flags from cleanfid import fid from torchdiffeq import odeint from torchdyn.core import NeuralODE @@ -81,7 +79,12 @@ def gen_1_img(unused_latent): print("Use method: ", FLAGS.integration_method) t_span = torch.linspace(0, 1, 2, device=device) traj = odeint( - new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method + new_net, + x, + t_span, + rtol=FLAGS.tol, + atol=FLAGS.tol, + method=FLAGS.integration_method, ) traj = traj[-1, :] # .view([-1, 3, 32, 32]).clip(-1, 1) img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8) # .permute(1, 2, 0) diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py index 14b8b04d..1cbad83b 100644 --- a/examples/images/cifar10/train_cifar10.py +++ b/examples/images/cifar10/train_cifar10.py @@ -8,7 +8,6 @@ import torch from absl import app, flags -from torchdyn.core import NeuralODE from torchvision import datasets, transforms from tqdm import trange from utils_cifar import ema, generate_samples, infiniteloop @@ -98,9 +97,7 @@ def train(argv): num_head_channels=64, attention_resolutions="16", dropout=0.1, - ).to( - device - ) # new dropout + bs of 128 + ).to(device) # new dropout + bs of 128 ema_model = copy.deepcopy(net_model) optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr) diff --git a/examples/images/cifar10/train_cifar10_ddp.py b/examples/images/cifar10/train_cifar10_ddp.py index 851f28ca..21b09683 100644 --- a/examples/images/cifar10/train_cifar10_ddp.py +++ b/examples/images/cifar10/train_cifar10_ddp.py @@ -12,7 +12,6 @@ from absl import app, flags from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DistributedSampler -from torchdyn.core import NeuralODE from torchvision import datasets, transforms from tqdm import trange from utils_cifar import ema, generate_samples, infiniteloop, setup @@ -116,9 +115,7 @@ def train(rank, total_num_gpus, argv): num_head_channels=64, attention_resolutions="16", dropout=0.1, - ).to( - rank - ) # new dropout + bs of 128 + ).to(rank) # new dropout + bs of 128 ema_model = copy.deepcopy(net_model) optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr) @@ -181,7 +178,11 @@ def train(rank, total_num_gpus, argv): # sample and Saving the weights if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0: generate_samples( - net_model, FLAGS.parallel, savedir, global_step, net_="normal" + net_model, + FLAGS.parallel, + savedir, + global_step, + net_="normal", ) generate_samples( ema_model, FLAGS.parallel, savedir, global_step, net_="ema" diff --git a/examples/images/cifar10/utils_cifar.py b/examples/images/cifar10/utils_cifar.py index cfa36b8f..818df68a 100644 --- a/examples/images/cifar10/utils_cifar.py +++ b/examples/images/cifar10/utils_cifar.py @@ -6,7 +6,7 @@ from torchdyn.core import NeuralODE # from torchvision.transforms import ToPILImage -from torchvision.utils import make_grid, save_image +from torchvision.utils import save_image use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") @@ -28,7 +28,6 @@ def setup( master_port: Port number of the master node. backend: Backend to use. 
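The reformatted `odeint` call in `compute_fid.py` is easier to read in isolation. The sketch below is a hedged stand-alone version: the `velocity` function is a hypothetical placeholder for the trained UNet wrapper, and the tolerance and solver values are example settings rather than the script's defaults.

```python
import torch
from torchdiffeq import odeint

# Stand-in velocity field; in compute_fid.py this is the trained model wrapped to accept (t, x).
def velocity(t, x):
    return -x  # hypothetical field, used only to make the sketch runnable

x = torch.randn(64, 3, 32, 32)        # latent noise batch, as in the script
t_span = torch.linspace(0, 1, 2)      # integrate from t=0 to t=1
traj = odeint(velocity, x, t_span, rtol=1e-5, atol=1e-5, method="dopri5")

# Map the final state to uint8 images the same way the script does.
img = (traj[-1] * 127.5 + 128).clip(0, 255).to(torch.uint8)
```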
""" - os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = master_port diff --git a/pyproject.toml b/pyproject.toml index 94e48ab3..2424da72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,17 +24,16 @@ exclude_lines = [ "if __name__ == .__main__.:", ] -[tool.flake8] -extend-ignore = ["E203", "E402", "E501", "F401", "F841", "E741", "F403"] -exclude = ["logs/*","data/*"] -per-file-ignores = [ - '__init__.py:F401', -] -max-line-length = 99 -count = true +[tool.ruff] +line-length = 99 + +[tool.ruff.lint] +ignore = ["C901", "E501", "E741", "W605", "C408", "E402"] +select = ["C", "E", "F", "I", "W"] -[tool.bandit] -skips = ["B101", "B311"] +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] -[tool.isort] -known_first_party = ["tests", "src"] +[tool.ruff.lint.isort] +known-first-party = ["src"] +known-third-party = ["torch", "transformers", "wandb"] diff --git a/runner/src/datamodules/cifar10_datamodule.py b/runner/src/datamodules/cifar10_datamodule.py index e7ec3633..fe680033 100644 --- a/runner/src/datamodules/cifar10_datamodule.py +++ b/runner/src/datamodules/cifar10_datamodule.py @@ -1,7 +1,6 @@ from typing import Any, List, Union import pl_bolts -import torch from torch.utils.data import DataLoader from torchvision import transforms as transform_lib diff --git a/runner/src/datamodules/components/generators2d.py b/runner/src/datamodules/components/generators2d.py index 9b7de610..48eaa613 100644 --- a/runner/src/datamodules/components/generators2d.py +++ b/runner/src/datamodules/components/generators2d.py @@ -3,6 +3,7 @@ Largely from https://github.com/AmirTag/OT-ICNN/blob/6caa9b982596a101b90a8a947d10f35f18c7de4e/2_dim_experiments/W2-minimax-tf.py """ + import random import numpy as np diff --git a/runner/src/datamodules/components/tnet_dataset.py b/runner/src/datamodules/components/tnet_dataset.py index 32a4d9b4..4fa167ad 100644 --- a/runner/src/datamodules/components/tnet_dataset.py +++ b/runner/src/datamodules/components/tnet_dataset.py @@ -2,6 +2,7 @@ Loads datasets into uniform format for learning continuous flows """ + import math import numpy as np @@ -119,7 +120,7 @@ def plot_paths(self): plt.show() def factory(name, args): - if type(args) == dict: + if type(args) is dict: from argparse import Namespace args = Namespace(**args) @@ -161,7 +162,6 @@ def factory(name, args): def _get_data_points(adata, basis) -> np.ndarray: """Returns the data points corresponding to the selected basis.""" - if basis == "highly_variable": data_points = adata[:, adata.var[basis]].X.toarray() elif basis in adata.obsm.keys(): @@ -650,9 +650,10 @@ def get_paths(self, n=5000, n_steps=3): class CircleTestDataV5(TreeTestData): - """This builds on version 3 to include a better middle timepoint. Where instead of being - parametrically defined, the middle timepoint is defined in terms of the interpolant between the - first and last timepoints along the manifold. + """This builds on version 3 to include a better middle timepoint. + + Where instead of being parametrically defined, the middle timepoint is defined in terms of the + interpolant between the first and last timepoints along the manifold. This is a useful thing to relate to in terms of transport along the manifold. 
""" diff --git a/runner/src/datamodules/distribution_datamodule.py b/runner/src/datamodules/distribution_datamodule.py index 9bb6cfc6..dfd13715 100644 --- a/runner/src/datamodules/distribution_datamodule.py +++ b/runner/src/datamodules/distribution_datamodule.py @@ -7,7 +7,7 @@ from pytorch_lightning import LightningDataModule from pytorch_lightning.trainer.supporters import CombinedLoader from sklearn.preprocessing import StandardScaler -from torch.utils.data import DataLoader, Sampler, TensorDataset, random_split +from torch.utils.data import DataLoader, TensorDataset, random_split from torchdyn.datasets import ToyDataset from src import utils @@ -56,7 +56,7 @@ def split(self): """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" train_val_test_split = self.hparams.train_val_test_split if isinstance(train_val_test_split, int): - self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data)) + self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data] return splitter = partial( random_split, @@ -141,7 +141,7 @@ def split(self): """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" train_val_test_split = self.hparams.train_val_test_split if isinstance(train_val_test_split, int): - self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data)) + self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data] return splitter = partial( random_split, @@ -322,7 +322,7 @@ def split(self): """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" train_val_test_split = self.hparams.train_val_test_split if isinstance(train_val_test_split, int): - self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data)) + self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data] return splitter = partial( random_split, @@ -338,9 +338,7 @@ def closed_form_marginal(self, sigma, t): """ a = self.a mean = (2 * a * t - a) * torch.ones(self.dim) - cov = (math.sqrt(4 + sigma**4) * t * (1 - t) + (1 - t) ** 2 + t**2) * torch.eye( - self.dim - ) + cov = (math.sqrt(4 + sigma**4) * t * (1 - t) + (1 - t) ** 2 + t**2) * torch.eye(self.dim) return mean, cov def detailed_evaluation(self, xt, sigma, t): @@ -424,7 +422,7 @@ def split(self): """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" train_val_test_split = self.hparams.train_val_test_split if isinstance(train_val_test_split, int): - self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data)) + self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data] return splitter = partial( random_split, @@ -551,7 +549,7 @@ def split(self): """Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" train_val_test_split = self.hparams.train_val_test_split if isinstance(train_val_test_split, int): - self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data)) + self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data] return splitter = partial( random_split, @@ -644,7 +642,8 @@ def __init__( class DistributionDataModule(BaseLightningDataModule): - """DEPRECATED: Implements loader for datasets taking the form of a sequence of distributions over time. + """DEPRECATED: Implements loader for datasets taking the form of a sequence of distributions + over time. Each batch is a 3-tuple of data (data, time, causal graph) ([b x d], [b], [b x d x d]). 
""" diff --git a/runner/src/eval.py b/runner/src/eval.py index 5636b3f2..94528244 100644 --- a/runner/src/eval.py +++ b/runner/src/eval.py @@ -57,7 +57,6 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]: Returns: Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. """ - assert cfg.ckpt_path log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") diff --git a/runner/src/models/cfm_module.py b/runner/src/models/cfm_module.py index 3006d248..925b47b3 100644 --- a/runner/src/models/cfm_module.py +++ b/runner/src/models/cfm_module.py @@ -5,7 +5,6 @@ import numpy as np import torch -import torchsde from pytorch_lightning import LightningDataModule, LightningModule from torch.distributions import MultivariateNormal from torchdyn.core import NeuralODE @@ -19,7 +18,6 @@ from .components.distribution_distances import compute_distribution_distances from .components.optimal_transport import OTPlanSampler from .components.plotting import ( - plot_paths, plot_samples, plot_trajectory, store_trajectories, @@ -226,7 +224,6 @@ def calc_u(self, x0, x1, x, t, mu_t, sigma_t): def calc_loc_and_target(self, x0, x1, t, t_select, training): """Computes the loss on a batch of data.""" - t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))) mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape) eps_t = torch.randn_like(mu_t) @@ -246,7 +243,6 @@ def calc_loc_and_target(self, x0, x1, t, t_select, training): def step(self, batch: Any, training: bool = False): """Computes the loss on a batch of data.""" - X = self.unpack_batch(batch) x0, x1, t_select = self.preprocess_batch(X, training) # Either randomly sample a single T or sample a batch of T's @@ -280,9 +276,7 @@ def training_step(self, batch: Any, batch_idx: int): def image_eval_step(self, batch: Any, batch_idx: int, prefix: str): import os - from math import prod - from pl_bolts.transforms.dataset_normalizations import cifar10_normalization from torchvision.utils import save_image # val_augmentations = AugmentationModule( @@ -394,7 +388,6 @@ def preprocess_epoch_end(self, outputs: List[Any], prefix: str): def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix): # Build a trajectory t_span = torch.linspace(0, 1, 101) - aug_dims = self.val_augmentations.aug_dims regs = [] trajs = [] full_trajs = [] @@ -910,7 +903,8 @@ def step(self, batch: Any, training: bool = False): reg, vt, st = self.forward_flow_and_score(t, x) flow_loss = self.criterion(vt, ut) score_loss = self.criterion( - -sigma_t * st / (self.sigma(t_orig.reshape(sigma_t.shape)) ** 2) * 2, score_target + -sigma_t * st / (self.sigma(t_orig.reshape(sigma_t.shape)) ** 2) * 2, + score_target, ) return torch.mean(reg) + self.hparams.score_weight * score_loss, flow_loss @@ -1051,9 +1045,7 @@ def training_epoch_end(self, training_step_outputs): def image_eval_step(self, batch: Any, batch_idx: int, prefix: str): import os - from math import prod - from pl_bolts.transforms.dataset_normalizations import cifar10_normalization from torchvision.utils import save_image solver = self.partial_solver(self.net, self.dim) @@ -1128,7 +1120,6 @@ def step(self, batch: Any, training: bool = False): def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix): # Build a trajectory t_span = torch.linspace(0, 1, 101).type_as(x0) - aug_dims = self.val_augmentations.aug_dims regs = [] trajs = [] full_trajs = [] @@ -1238,7 +1229,6 @@ def step(self, batch: Any, training: bool = False): def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix): # Build a trajectory t_span = 
torch.linspace(0, 1, 101) - aug_dims = self.val_augmentations.aug_dims regs = [] trajs = [] full_trajs = [] @@ -1337,11 +1327,11 @@ def step(self, batch: Any, training: bool = False): class FMLitModule(CFMLitModule): - """Implements a Lipman et al. 2023 style flow matching loss. + """Implements a Lipman et al. - This maps the standard normal distribution to the data distribution by using conditional flows - that are the optimal transport flow from a narrow Gaussian around a datapoint to a standard N(x - | 0, 1). + 2023 style flow matching loss. This maps the standard normal distribution to the data + distribution by using conditional flows that are the optimal transport flow from a narrow + Gaussian around a datapoint to a standard N(x | 0, 1). """ def calc_mu_sigma(self, x0, x1, t): @@ -1365,7 +1355,7 @@ class SplineCFMLitModule(CFMLitModule): def preprocess_batch(self, X, training=False): from torchcubicspline import NaturalCubicSpline, natural_cubic_spline_coeffs - """converts a batch of data into matched a random pair of (x0, x1)""" + """Converts a batch of data into matched a random pair of (x0, x1)""" lotp = self.hparams.leaveout_timepoint valid_times = torch.arange(X.shape[1]).type_as(X) t_select = torch.zeros(1) @@ -1431,7 +1421,6 @@ def step(self, batch: Any, training: bool = False): obs = self.unpack_batch(batch) if not self.is_trajectory: obs = obs[:, None, :] - aug_dims = self.augmentations.aug_dims even_ts = torch.arange(obs.shape[1]).to(obs) + 1 self.prior = MultivariateNormal( torch.zeros(self.dim).type_as(obs), torch.eye(self.dim).type_as(obs) diff --git a/runner/src/models/components/augmentation.py b/runner/src/models/components/augmentation.py index 24901878..2f08d325 100644 --- a/runner/src/models/components/augmentation.py +++ b/runner/src/models/components/augmentation.py @@ -1,5 +1,3 @@ -from typing import Callable, List, Union - import torch from torch import nn @@ -213,16 +211,20 @@ def forward(self, x): class Augmenter(nn.Module): - """Augmentation class. Can handle several types of augmentation strategies for Neural DEs. + """Augmentation class. + Can handle several types of augmentation strategies for Neural DEs. :param augment_dims: number of augmented dimensions to initialize :type augment_dims: int :param augment_idx: index of dimension to augment :type augment_idx: int - :param augment_func: nn.Module applied to the input datasets of dimension `d` to determine the augmented initial condition of dimension `d + a`. - `a` is defined implicitly in `augment_func` e.g. augment_func=nn.Linear(2, 5) augments a 2 dimensional input with 3 additional dimensions. + :param augment_func: nn.Module applied to the input datasets of dimension `d` to determine the + augmented initial condition of dimension `d + a`. `a` is defined implicitly in + `augment_func` e.g. augment_func=nn.Linear(2, 5) augments a 2 dimensional input with 3 + additional dimensions. :type augment_func: nn.Module - :param order: whether to augment before datasets [augmentation, x] or after [x, augmentation] along dimension `augment_idx`. Options: ('first', 'last') + :param order: whether to augment before datasets [augmentation, x] or after [x, augmentation] + along dimension `augment_idx`. 
Options: ('first', 'last') :type order: str """ diff --git a/runner/src/models/components/base.py b/runner/src/models/components/base.py index ebf805c2..89715a65 100644 --- a/runner/src/models/components/base.py +++ b/runner/src/models/components/base.py @@ -568,9 +568,7 @@ def reset_parameters(self): def _get_kl(self, param_mean, sigma, prior_log_sigma): kl = torch.sum( - prior_log_sigma - - torch.log(sigma) - + 0.5 * (sigma**2) / (math.exp(prior_log_sigma * 2)) + prior_log_sigma - torch.log(sigma) + 0.5 * (sigma**2) / (math.exp(prior_log_sigma * 2)) ) kl += 0.5 * torch.sum(param_mean**2) / math.exp(prior_log_sigma * 2) return kl diff --git a/runner/src/models/components/emd.py b/runner/src/models/components/emd.py index 5654f11c..c1c394da 100644 --- a/runner/src/models/components/emd.py +++ b/runner/src/models/components/emd.py @@ -139,7 +139,7 @@ def interpolate_per_point_with_ot(p0, p1, tmap, interp_frac): ) I = len(p0) - J = len(p1) + # J = len(p1) # Assume growth is exponential and retrieve growth rate at t_interpolate # If all sums are the same then this does not change anything # This only matters if sum is not the same for all rows diff --git a/runner/src/models/components/evaluation.py b/runner/src/models/components/evaluation.py index 2b94e13f..8e411362 100644 --- a/runner/src/models/components/evaluation.py +++ b/runner/src/models/components/evaluation.py @@ -1,7 +1,6 @@ from collections import Counter import numpy as np -import torch from sklearn.metrics import average_precision_score, roc_auc_score diff --git a/runner/src/models/components/layers/diffeq_layers/basic.py b/runner/src/models/components/layers/diffeq_layers/basic.py index 3eec391b..7f7e1805 100644 --- a/runner/src/models/components/layers/diffeq_layers/basic.py +++ b/runner/src/models/components/layers/diffeq_layers/basic.py @@ -438,7 +438,7 @@ def __init__( groups=1, bias=True, transpose=False, - **unused_kwargs + **unused_kwargs, ): super().__init__() module = nn.ConvTranspose2d if transpose else nn.Conv2d diff --git a/runner/src/models/components/layers/odefunc.py b/runner/src/models/components/layers/odefunc.py index c39313bf..48dfed55 100644 --- a/runner/src/models/components/layers/odefunc.py +++ b/runner/src/models/components/layers/odefunc.py @@ -1,6 +1,5 @@ import copy -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F diff --git a/runner/src/models/components/logger.py b/runner/src/models/components/logger.py index 9c6e456e..eb328130 100644 --- a/runner/src/models/components/logger.py +++ b/runner/src/models/components/logger.py @@ -7,7 +7,6 @@ import json import os import os.path as osp -import shutil import sys import tempfile import time diff --git a/runner/src/models/components/nn.py b/runner/src/models/components/nn.py index 5e454703..cdbd6334 100644 --- a/runner/src/models/components/nn.py +++ b/runner/src/models/components/nn.py @@ -111,8 +111,8 @@ def checkpoint(func, inputs, params, flag): :param func: the function to evaluate. :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. + :param params: a sequence of parameters `func` depends on but does not explicitly take as + arguments. :param flag: if False, disable gradient checkpointing. 
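For readers of `base.py`, the reformatted `_get_kl` appears to compute the Gaussian KL divergence between a per-parameter posterior N(mu, sigma^2) and a zero-mean prior with log-std `prior_log_sigma`, up to an additive constant; this is an interpretation of the code, not a statement made by the diff.

```latex
% sigma_p = exp(prior_log_sigma). The code sums the first four terms over parameters
% and omits the constant -1/2 per parameter.
\mathrm{KL} = \sum_i \left[ \log \sigma_p - \log \sigma_i
  + \frac{\sigma_i^2}{2\sigma_p^2} + \frac{\mu_i^2}{2\sigma_p^2} - \frac{1}{2} \right]
```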
""" if flag: diff --git a/runner/src/models/components/plotting.py b/runner/src/models/components/plotting.py index f15b06e0..f3085d5d 100644 --- a/runner/src/models/components/plotting.py +++ b/runner/src/models/components/plotting.py @@ -77,7 +77,6 @@ def store_trajectories(obs: Union[torch.Tensor, list], model, title="trajs", sta batch_size, ts, dim = obs.shape start = obs[:n, start_time, :] obs = obs.reshape(-1, dim).detach().cpu().numpy() - tts = np.tile(np.arange(ts), batch_size) from torchdyn.core import NeuralODE with torch.no_grad(): @@ -108,11 +107,9 @@ def plot_trajectory( data = np.concatenate(data, axis=0) labels = np.concatenate(labels, axis=0) scprep.plot.scatter2d(data, c=labels) - start = obs[0][:n] ts = len(obs) else: batch_size, ts, dim = obs.shape - start = obs[:n, start_time, :] obs = obs.reshape(-1, dim).detach().cpu().numpy() tts = np.tile(np.arange(ts), batch_size) scprep.plot.scatter2d(obs, c=tts) diff --git a/runner/src/models/components/sinkhorn_knopp_unbalanced.py b/runner/src/models/components/sinkhorn_knopp_unbalanced.py index 7e7871e1..8010c0cd 100644 --- a/runner/src/models/components/sinkhorn_knopp_unbalanced.py +++ b/runner/src/models/components/sinkhorn_knopp_unbalanced.py @@ -6,6 +6,7 @@ something large we can compute an unbalanced optimal transport where all the scaling is done on the source distribution and none is done on the target distribution. """ + import warnings import numpy as np diff --git a/runner/src/models/components/solver.py b/runner/src/models/components/solver.py index d0916ed3..d7cf2db1 100644 --- a/runner/src/models/components/solver.py +++ b/runner/src/models/components/solver.py @@ -11,7 +11,7 @@ import torchsde from torchdyn.core import NeuralODE -from .augmentation import AugmentationModule, AugmentedVectorField, Sequential +from .augmentation import AugmentedVectorField, Sequential class TorchSDE(torch.nn.Module): diff --git a/runner/src/models/components/unet.py b/runner/src/models/components/unet.py index 72760e8e..5ff82a24 100644 --- a/runner/src/models/components/unet.py +++ b/runner/src/models/components/unet.py @@ -1,4 +1,5 @@ """From https://raw.githubusercontent.com/openai/guided-diffusion/main/guided_diffusion/unet.py.""" + import math from abc import abstractmethod @@ -372,27 +373,25 @@ class UNetModel(nn.Module): :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. + :param attention_resolutions: a collection of downsample rates at which attention will take + place. May be a set, list, or tuple. For example, if this contains 4, then at 4x + downsampling, attention will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. + :param conv_resample: if True, use learned convolutions for upsampling and downsampling. :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. + :param num_classes: if specified (as an int), then this model will be class-conditional with + `num_classes` classes. 
:param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. + :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width + per attention head. + :param num_heads_upsample: works with num_heads to set a different number of heads for + upsampling. Deprecated. :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. + :param use_new_attention_order: use a different attention pattern for potentially increased + efficiency. """ def __init__( diff --git a/runner/src/models/icnn_module.py b/runner/src/models/icnn_module.py index c6ec571a..33544b19 100644 --- a/runner/src/models/icnn_module.py +++ b/runner/src/models/icnn_module.py @@ -1,20 +1,11 @@ -from typing import Any, List, Optional, Union +from typing import Any, List import torch import torch.nn.functional as F from pytorch_lightning import LightningDataModule, LightningModule -from torch import autograd, nn -from torch.distributions import MultivariateNormal -from torchdyn.core import NeuralODE - -from .components.augmentation import ( - AugmentationModule, - AugmentedVectorField, - Sequential, -) +from torch import autograd + from .components.distribution_distances import compute_distribution_distances -from .components.optimal_transport import OTPlanSampler -from .components.plotting import plot_paths, plot_scatter_and_flow from .utils import get_wandb_logger @@ -197,7 +188,12 @@ def x_to_y(x): x_pred = y_to_x(x1) plot( - x0, x1, x_pred, pred, savename=f"{self.current_epoch}_match", wandb_logger=wandb_logger + x0, + x1, + x_pred, + pred, + savename=f"{self.current_epoch}_match", + wandb_logger=wandb_logger, ) def validation_step(self, batch: Any, batch_idx: int): diff --git a/runner/src/models/runner.py b/runner/src/models/runner.py index a5e3f781..29d4e567 100644 --- a/runner/src/models/runner.py +++ b/runner/src/models/runner.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Union +from typing import Any, List, Optional import numpy as np import torch @@ -112,7 +112,6 @@ def preprocess_epoch_end(self, outputs: List[Any], prefix: str): def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix): # Build a trajectory t_span = torch.linspace(0, 1, 101) - aug_dims = self.val_augmentations.aug_dims solver = self.solver(self.net, self.dim) solver.augmentations = self.val_augmentations traj, aug = solver.odeint(x0, t_span) diff --git a/runner/src/train.py b/runner/src/train.py index 554e7230..c724c6e8 100644 --- a/runner/src/train.py +++ b/runner/src/train.py @@ -47,8 +47,9 @@ @utils.task_wrapper def train(cfg: DictConfig) -> Tuple[dict, dict]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. + """Trains the model. + + Can additionally evaluate on a testset, using best weights obtained during training. This method is wrapped in optional @task_wrapper decorator which applies extra utilities before and after the call. 
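The UNet docstring reflow above lists the constructor parameters; a short construction sketch, mirroring the call in `tests/test_models.py` (MNIST-shaped input, class-conditional) rather than prescribing values:

```python
from torchcfm.models.unet import UNetModel

# Class-conditional UNet on 1x28x28 inputs, as instantiated in tests/test_models.py.
model = UNetModel(
    dim=(1, 28, 28),
    num_channels=32,
    num_res_blocks=1,
    num_classes=10,
    class_cond=True,
)

n_params = sum(p.numel() for p in model.parameters())
print(f"UNetModel parameters: {n_params}")
```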
@@ -59,7 +60,6 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: Returns: Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. """ - # set seed for random number generators in pytorch, numpy and python.random if cfg.get("seed"): pl.seed_everything(cfg.seed, workers=True) diff --git a/runner/src/utils/pylogger.py b/runner/src/utils/pylogger.py index 92ffa718..93759656 100644 --- a/runner/src/utils/pylogger.py +++ b/runner/src/utils/pylogger.py @@ -5,12 +5,19 @@ def get_pylogger(name=__name__) -> logging.Logger: """Initializes multi-GPU-friendly python command line logger.""" - logger = logging.getLogger(name) # this ensures all logging levels get marked with the rank zero decorator # otherwise logs would get multiplied for each GPU process in multi-GPU setup - logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + logging_levels = ( + "debug", + "info", + "warning", + "error", + "exception", + "fatal", + "critical", + ) for level in logging_levels: setattr(logger, level, rank_zero_only(getattr(logger, level))) diff --git a/runner/src/utils/rich_utils.py b/runner/src/utils/rich_utils.py index e340dc69..ddc8a1e2 100644 --- a/runner/src/utils/rich_utils.py +++ b/runner/src/utils/rich_utils.py @@ -37,7 +37,6 @@ def print_config_tree( resolve (bool, optional): Whether to resolve reference fields of DictConfig. save_to_file (bool, optional): Whether to export config to the hydra output folder. """ - style = "dim" tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) @@ -78,7 +77,6 @@ def print_config_tree( @rank_zero_only def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: """Prompts user to input tags from command line if no tags are provided in config.""" - if not cfg.get("tags"): if "id" in HydraConfig().cfg.hydra.job: raise ValueError("Specify tags before launching a multirun!") diff --git a/runner/src/utils/utils.py b/runner/src/utils/utils.py index cf61b547..9b08371b 100644 --- a/runner/src/utils/utils.py +++ b/runner/src/utils/utils.py @@ -2,7 +2,7 @@ import warnings from importlib.util import find_spec from pathlib import Path -from typing import Any, Callable, Dict, List +from typing import Callable, List import hydra from omegaconf import DictConfig @@ -60,7 +60,6 @@ def extras(cfg: DictConfig) -> None: - Setting tags from command line - Rich config printing """ - # return if no `extras` config if not cfg.get("extras"): log.warning("Extras config not found! ") @@ -134,7 +133,6 @@ def log_hyperparameters(object_dict: dict) -> None: Additionally saves: - Number of model parameters """ - hparams = {} cfg = object_dict["cfg"] @@ -174,7 +172,6 @@ def log_hyperparameters(object_dict: dict) -> None: def get_metric_value(metric_dict: dict, metric_name: str) -> float: """Safely retrieves value of the metric logged in LightningModule.""" - if not metric_name: log.info("Metric name is None! 
Skipping metric value retrieval...") return None @@ -194,7 +191,6 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float: def close_loggers() -> None: """Makes sure all loggers closed properly (prevents logging failure during multirun).""" - log.info("Closing loggers...") if find_spec("wandb"): # if wandb is installed diff --git a/runner/tests/test_datamodule.py b/runner/tests/test_datamodule.py index 0499d8d6..383bff45 100644 --- a/runner/tests/test_datamodule.py +++ b/runner/tests/test_datamodule.py @@ -1,5 +1,3 @@ -from pathlib import Path - import pytest import torch diff --git a/tests/test_models.py b/tests/test_models.py index fa7df44c..6d82f986 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,15 +1,13 @@ -import pytest - from torchcfm.models import MLP from torchcfm.models.unet import UNetModel def test_initialize_models(): - model = UNetModel( + UNetModel( dim=(1, 28, 28), num_channels=32, num_res_blocks=1, num_classes=10, class_cond=True, ) - model = MLP(dim=2, time_varying=True, w=64) + MLP(dim=2, time_varying=True, w=64) diff --git a/tests/test_optimal_transport.py b/tests/test_optimal_transport.py index da43219a..ad696e8b 100644 --- a/tests/test_optimal_transport.py +++ b/tests/test_optimal_transport.py @@ -2,8 +2,6 @@ # Author: Kilian Fatras -import math - import numpy as np import ot import pytest @@ -77,11 +75,15 @@ def test_wasserstein(batch_size=128, seed=1980): W1 = wasserstein(x0, x1, "exact", power=1) pot_eot = ot.sinkhorn2( - ot.unif(x0.shape[0]), ot.unif(x1.shape[0]), M.numpy(), reg=0.01, numItermax=int(1e7) + ot.unif(x0.shape[0]), + ot.unif(x1.shape[0]), + M.numpy(), + reg=0.01, + numItermax=int(1e7), ) eot = wasserstein(x0, x1, "sinkhorn", reg=0.01, power=1) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError): eot = wasserstein(x0, x1, "noname", reg=0.01, power=1) assert pot_W2 == W2 diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index b71b6bd2..c4585949 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -50,11 +50,12 @@ class ConditionalFlowMatcher: """ def __init__(self, sigma: Union[float, int] = 0.0): - r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$. + r"""Initialize the ConditionalFlowMatcher class. - Parameters - ---------- - sigma : Union[float, int] + It requires the hyper-parameter $\sigma$. + Parameters + ---------- + sigma : Union[float, int] """ self.sigma = sigma @@ -217,19 +218,22 @@ def compute_lambda(self, t): class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher): - """Child class for optimal transport conditional flow matching method. This class implements - the OT-CFM methods from [1] and inherits the ConditionalFlowMatcher parent class. + """Child class for optimal transport conditional flow matching method. + + This class implements the OT-CFM methods from [1] and inherits the ConditionalFlowMatcher + parent class. It overrides the sample_location_and_conditional_flow. """ def __init__(self, sigma: Union[float, int] = 0.0): - r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$. + r"""Initialize the ConditionalFlowMatcher class. - Parameters - ---------- - sigma : Union[float, int] - ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]). + It requires the hyper-parameter $\sigma$. 
+ Parameters + ---------- + sigma : Union[float, int] + ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]). """ super().__init__(sigma) self.ot_sampler = OTPlanSampler(method="exact") @@ -313,9 +317,11 @@ def guided_sample_location_and_conditional_flow( class TargetConditionalFlowMatcher(ConditionalFlowMatcher): - """Lipman et al. 2023 style target OT conditional flow matching. This class inherits the - ConditionalFlowMatcher and override the compute_mu_t, compute_sigma_t and - compute_conditional_flow functions in order to compute [2]'s flow matching. + """Lipman et al. + + 2023 style target OT conditional flow matching. This class inherits the ConditionalFlowMatcher + and override the compute_mu_t, compute_sigma_t and compute_conditional_flow functions in order + to compute [2]'s flow matching. [2] Flow Matching for Generative Modelling, ICLR, Lipman et al. """ @@ -389,16 +395,19 @@ def compute_conditional_flow(self, x0, x1, t, xt): class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher): - """Child class for Schrödinger bridge conditional flow matching method. This class implements - the SB-CFM methods from [1] and inherits the ConditionalFlowMatcher parent class. + """Child class for Schrödinger bridge conditional flow matching method. + + This class implements the SB-CFM methods from [1] and inherits the ConditionalFlowMatcher + parent class. It overrides the compute_sigma_t, compute_conditional_flow and sample_location_and_conditional_flow functions. """ def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"): - r"""Initialize the SchrodingerBridgeConditionalFlowMatcher class. It requires the hyper- - parameter $\sigma$ and the entropic OT map. + r"""Initialize the SchrodingerBridgeConditionalFlowMatcher class. + + It requires the hyper- parameter $\sigma$ and the entropic OT map. Parameters ---------- @@ -548,9 +557,11 @@ def guided_sample_location_and_conditional_flow( class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher): - """Albergo et al. 2023 trigonometric interpolants class. This class inherits the - ConditionalFlowMatcher and override the compute_mu_t and compute_conditional_flow functions in - order to compute [3]'s trigonometric interpolants. + """Albergo et al. + + 2023 trigonometric interpolants class. This class inherits the ConditionalFlowMatcher and + override the compute_mu_t and compute_conditional_flow functions in order to compute [3]'s + trigonometric interpolants. [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al. """ diff --git a/torchcfm/models/unet/logger.py b/torchcfm/models/unet/logger.py index 9c6e456e..eb328130 100644 --- a/torchcfm/models/unet/logger.py +++ b/torchcfm/models/unet/logger.py @@ -7,7 +7,6 @@ import json import os import os.path as osp -import shutil import sys import tempfile import time diff --git a/torchcfm/models/unet/nn.py b/torchcfm/models/unet/nn.py index 5e454703..cdbd6334 100644 --- a/torchcfm/models/unet/nn.py +++ b/torchcfm/models/unet/nn.py @@ -111,8 +111,8 @@ def checkpoint(func, inputs, params, flag): :param func: the function to evaluate. :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. + :param params: a sequence of parameters `func` depends on but does not explicitly take as + arguments. :param flag: if False, disable gradient checkpointing. 
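The docstring reflows above name the matcher variants and their constructors. A hedged instantiation sketch follows: sigma values and `ot_method` are the defaults visible in the diff, the toy batches are illustrative, and all three classes are assumed to expose the parent's `sample_location_and_conditional_flow` interface.

```python
import torch
from torchcfm.conditional_flow_matching import (
    ExactOptimalTransportConditionalFlowMatcher,
    SchrodingerBridgeConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
)

x0 = torch.randn(128, 2)
x1 = torch.randn(128, 2) + 2.0

# OT-CFM: draws (x0, x1) couplings from an exact OT plan before building conditional paths.
otcfm = ExactOptimalTransportConditionalFlowMatcher(sigma=0.0)
t, xt, ut = otcfm.sample_location_and_conditional_flow(x0, x1)

# SB-CFM: entropic couplings; sigma also sets the bridge noise level.
sbcfm = SchrodingerBridgeConditionalFlowMatcher(sigma=1.0, ot_method="exact")
t, xt, ut = sbcfm.sample_location_and_conditional_flow(x0, x1)

# Lipman et al. (2023) target-OT variant.
tfm = TargetConditionalFlowMatcher(sigma=0.0)
t, xt, ut = tfm.sample_location_and_conditional_flow(x0, x1)
```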
""" if flag: diff --git a/torchcfm/models/unet/unet.py b/torchcfm/models/unet/unet.py index 205ecab4..a75c9ded 100644 --- a/torchcfm/models/unet/unet.py +++ b/torchcfm/models/unet/unet.py @@ -1,4 +1,5 @@ """From https://raw.githubusercontent.com/openai/guided-diffusion/main/guided_diffusion/unet.py.""" + import math from abc import abstractmethod @@ -372,27 +373,25 @@ class UNetModel(nn.Module): :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. + :param attention_resolutions: a collection of downsample rates at which attention will take + place. May be a set, list, or tuple. For example, if this contains 4, then at 4x + downsampling, attention will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. + :param conv_resample: if True, use learned convolutions for upsampling and downsampling. :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. + :param num_classes: if specified (as an int), then this model will be class-conditional with + `num_classes` classes. :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. + :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width + per attention head. + :param num_heads_upsample: works with num_heads to set a different number of heads for + upsampling. Deprecated. :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. + :param use_new_attention_order: use a different attention pattern for potentially increased + efficiency. """ def __init__( diff --git a/torchcfm/optimal_transport.py b/torchcfm/optimal_transport.py index b11b3c3a..cfe055d6 100644 --- a/torchcfm/optimal_transport.py +++ b/torchcfm/optimal_transport.py @@ -144,6 +144,43 @@ def sample_plan(self, x0, x1, replace=True): i, j = self.sample_map(pi, x0.shape[0], replace=replace) return x0[i], x1[j] + def sample_plan_with_scipy(self, x0, x1): + r"""Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target + minibatch using scipy and draw source and target samples from pi $(x,z) \sim \pi$. 
+ + This sampler has two advantages: + * Reduced variance compared to sampling from the OT plan + * Preserves the order of x1 by construction + * Preserves entire batch if x0 and x1 have the same size + + Parameters + ---------- + x0 : Tensor, shape (bs, *dim) + represents the source minibatch + x1 : Tensor, shape (bs, *dim) + represents the source minibatch + + Returns + ------- + x0[i] : Tensor, shape (bs, *dim) + represents the source minibatch drawn from $\pi$ + x1[j] : Tensor, shape (bs, *dim) + represents the source minibatch drawn from $\pi$ + """ + import scipy + + if x0.dim() > 2: + x0 = x0.reshape(x0.shape[0], -1) + if x1.dim() > 2: + x1 = x1.reshape(x1.shape[0], -1) + M = torch.cdist(x0.detach(), x1.detach()) ** 2 + if self.normalize_cost: + M = M / M.max() + _, j = scipy.optimize.linear_sum_assignment(M.cpu().numpy()) + pi_x0 = x0[j] + pi_x1 = x1 + return pi_x0, pi_x1 + def sample_plan_with_labels(self, x0, x1, y0=None, y1=None, replace=True): r"""Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target minibatch and draw source and target labeled samples from pi $(x,z) \sim \pi$ diff --git a/torchcfm/utils.py b/torchcfm/utils.py index 77648c6b..9149b1e7 100644 --- a/torchcfm/utils.py +++ b/torchcfm/utils.py @@ -3,7 +3,6 @@ import matplotlib.pyplot as plt import numpy as np import torch -import torchdyn from torchdyn.datasets import generate_moons # Implement some helper functions
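Finally, the new `OTPlanSampler.sample_plan_with_scipy` pairs minibatches with a linear-sum-assignment solve while keeping `x1` in its original order. A short hedged usage sketch: the import path follows the edited `torchcfm/optimal_transport.py`, and the toy data is illustrative.

```python
import torch
from torchcfm.optimal_transport import OTPlanSampler

x0 = torch.randn(128, 2)
x1 = torch.randn(128, 2) + 3.0

sampler = OTPlanSampler(method="exact")

# Stochastic coupling: indices drawn from the OT plan (samples may repeat).
x0_ot, x1_ot = sampler.sample_plan(x0, x1)

# Assignment-based coupling: x1 keeps its order, and the whole batch is preserved
# when x0 and x1 have the same size.
x0_lap, x1_lap = sampler.sample_plan_with_scipy(x0, x1)
assert torch.equal(x1_lap, x1)
```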