diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml
deleted file mode 100644
index c8a2e5f3..00000000
--- a/.github/workflows/code-quality-pr.yaml
+++ /dev/null
@@ -1,41 +0,0 @@
-# This workflow finds which files were changed, prints them,
-# and runs `pre-commit` on those files.
-
-# Inspired by the sktime library:
-# https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml
-
-name: Code Quality PR
-
-env:
-  SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL: True
-
-on:
-  pull_request:
-    branches: [main, "release/*"]
-
-jobs:
-  code-quality:
-    runs-on: ubuntu-latest
-
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v3
-
-      - name: Set up Python
-        uses: actions/setup-python@v3
-        with:
-          python-version: "3.10"
-
-      - name: Find modified files
-        id: file_changes
-        uses: trilom/file-changes-action@v1.2.4
-        with:
-          output: " "
-
-      - name: List modified files
-        run: echo '${{ steps.file_changes.outputs.files}}'
-
-      - name: Run pre-commits
-        uses: pre-commit/action@v2.0.3
-        with:
-          extra_args: --files ${{ steps.file_changes.outputs.files}}
diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality.yaml
similarity index 64%
rename from .github/workflows/code-quality-main.yaml
rename to .github/workflows/code-quality.yaml
index ce95976b..7f7431a7 100644
--- a/.github/workflows/code-quality-main.yaml
+++ b/.github/workflows/code-quality.yaml
@@ -6,6 +6,8 @@ name: Code Quality Main
 on:
   push:
     branches: [main]
+  pull_request:
+    branches: [main, "release/*"]

 jobs:
   code-quality:
@@ -13,12 +15,12 @@ jobs:

     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v5

       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v6
         with:
-          python-version: "3.10"
+          python-version: "3.13"

       - name: Run pre-commits
-        uses: pre-commit/action@v2.0.3
+        uses: pre-commit/action@v3.0.1
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index a846507a..a320513f 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -13,15 +13,15 @@ jobs:
     strategy:
      fail-fast: false
      matrix:
-        os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        os: [ubuntu-latest, macos-latest, windows-latest]
+        python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]

     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v5

       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v6
         with:
           python-version: ${{ matrix.python-version }}

@@ -46,10 +46,10 @@ jobs:

     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v5

       - name: Set up Python 3.10
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v6
         with:
           python-version: "3.10"
diff --git a/.github/workflows/test_runner.yaml b/.github/workflows/test_runner.yaml
index 4bc58d36..e2cf62e2 100644
--- a/.github/workflows/test_runner.yaml
+++ b/.github/workflows/test_runner.yaml
@@ -1,10 +1,10 @@
 name: Runner Tests

-on:
-  push:
-    branches: [main]
-  pull_request:
-    branches: [main, "release/*"]
+#on:
+#  push:
+#    branches: [main]
+#  pull_request:
+#    branches: [main, "release/*"]

 jobs:
   run_tests_ubuntu:
@@ -18,10 +18,10 @@ jobs:

     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v5

       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v6
         with:
           python-version: ${{ matrix.python-version }}
@@ -48,10 +48,10 @@ jobs:

     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v5

       - name: Set up Python 3.10
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v6
         with:
           python-version: "3.10"
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 46e35536..29174a11 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,10 +1,6 @@
-default_language_version:
-  python: python3
-  node: 16.14.2
-
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.6.0
     hooks:
       # list of supported hooks: https://pre-commit.com/hooks.html
       - id: trailing-whitespace
@@ -29,55 +25,26 @@ repos:
       - id: check-added-large-files
         require_serial: true

-  # python code formatting
-  - repo: https://github.com/psf/black
-    rev: 23.7.0
-    hooks:
-      - id: black
-        require_serial: true
-        args: [--line-length, "99"]
-
-  # python import sorting
-  - repo: https://github.com/PyCQA/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-        require_serial: true
-        args: ["--profile", "black", "--filter-files"]
-
   # python upgrading syntax to newer version
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.9.0
+    rev: v3.17.0
     hooks:
       - id: pyupgrade
         require_serial: true
         args: [--py38-plus]

-  # python docstring formatting
-  - repo: https://github.com/myint/docformatter
-    rev: master
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.5.5
     hooks:
-      - id: docformatter
-        require_serial: true
-        args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
-
-  # python check (PEP8), programming errors and code complexity
-  - repo: https://github.com/PyCQA/flake8
-    rev: 6.0.0
-    hooks:
-      - id: flake8
-        require_serial: true
-        entry: pflake8
-        additional_dependencies: ["pyproject-flake8"]
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format

   # python security linter
-  - repo: https://github.com/PyCQA/bandit
-    rev: "1.7.5"
+  - repo: https://github.com/gitleaks/gitleaks
+    rev: v8.18.2
     hooks:
-      - id: bandit
-        require_serial: true
-        args: ["-c", "pyproject.toml"]
-        additional_dependencies: ["bandit[toml]"]
+      - id: gitleaks

   # yaml formatting
   - repo: https://github.com/pre-commit/mirrors-prettier
@@ -95,19 +62,13 @@
         require_serial: true
         args: ["-e", "SC2102"]

-  # md formatting
-  - repo: https://github.com/executablebooks/mdformat
-    rev: 0.7.16
+  - repo: https://github.com/pre-commit/mirrors-prettier
+    rev: v4.0.0-alpha.8
     hooks:
-      - id: mdformat
-        require_serial: true
-        args: ["--number"]
-        additional_dependencies:
-          - mdformat-gfm
-          - mdformat-tables
-          - mdformat_frontmatter
-          # - mdformat-toc
-          # - mdformat-black
+      - id: prettier
+        # To avoid conflicts, tell prettier to ignore file types
+        # that ruff already handles.
+        exclude_types: [python]

   # word spelling linter
   - repo: https://github.com/codespell-project/codespell
diff --git a/README.md b/README.md
index c029a913..321e93e9 100644
--- a/README.md
+++ b/README.md
@@ -96,7 +96,7 @@ A. Tong, N. Malkin, K. Fatras, L. Atanackovic, Y. Zhang, G. Huguet, G. Wolf, Y.
 Major Changes:

-- __Added cifar10 examples with an FID of 3.5__
+- **Added CIFAR-10 examples with an FID of 3.5**
 - Added code for the new Simulation-free Score and Flow Matching (SF)2M preprint
 - Created `torchcfm` pip installable package
 - Moved `pytorch-lightning` implementation and experiments to `runner` directory
diff --git a/examples/images/cifar10/README.md b/examples/images/cifar10/README.md
index edc5a53e..e2eb5927 100644
--- a/examples/images/cifar10/README.md
+++ b/examples/images/cifar10/README.md
@@ -1,6 +1,6 @@
 # CIFAR-10 experiments using TorchCFM

-This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset.
+This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us reach an **FID of 3.5** on the CIFAR-10 dataset.
diff --git a/examples/images/cifar10/compute_fid.py b/examples/images/cifar10/compute_fid.py
index 7596699c..565e4ba4 100644
--- a/examples/images/cifar10/compute_fid.py
+++ b/examples/images/cifar10/compute_fid.py
@@ -3,12 +3,10 @@
# Authors: Kilian Fatras
# Alexander Tong
-import os
import sys
-import matplotlib.pyplot as plt
import torch
-from absl import app, flags
+from absl import flags
from cleanfid import fid
from torchdiffeq import odeint
from torchdyn.core import NeuralODE
@@ -81,7 +79,12 @@ def gen_1_img(unused_latent):
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, 2, device=device)
traj = odeint(
- new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
+ new_net,
+ x,
+ t_span,
+ rtol=FLAGS.tol,
+ atol=FLAGS.tol,
+ method=FLAGS.integration_method,
)
traj = traj[-1, :] # .view([-1, 3, 32, 32]).clip(-1, 1)
img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8) # .permute(1, 2, 0)
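
For readers unfamiliar with the functional interface reformatted above: torchdiffeq's `odeint` takes a callable `f(t, x)`, an initial state, and a tensor of evaluation times, and returns the trajectory stacked along a leading time dimension. A minimal, self-contained sketch of the same call pattern; the toy dynamics and tolerances are illustrative, not the repository's model:

    import torch
    from torchdiffeq import odeint

    def toy_dynamics(t, x):
        # dx/dt = -x, so the exact solution is x0 * exp(-t).
        return -x

    x0 = torch.randn(4, 2)            # batch of initial states
    t_span = torch.linspace(0, 1, 2)  # endpoints only, as in compute_fid.py
    traj = odeint(toy_dynamics, x0, t_span, rtol=1e-5, atol=1e-5, method="dopri5")
    final_state = traj[-1]            # shape [4, 2]: the states at t=1
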
diff --git a/examples/images/cifar10/train_cifar10.py b/examples/images/cifar10/train_cifar10.py
index 14b8b04d..1cbad83b 100644
--- a/examples/images/cifar10/train_cifar10.py
+++ b/examples/images/cifar10/train_cifar10.py
@@ -8,7 +8,6 @@
import torch
from absl import app, flags
-from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop
@@ -98,9 +97,7 @@ def train(argv):
num_head_channels=64,
attention_resolutions="16",
dropout=0.1,
- ).to(
- device
- ) # new dropout + bs of 128
+ ).to(device) # new dropout + bs of 128
ema_model = copy.deepcopy(net_model)
optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
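
The script keeps `ema_model = copy.deepcopy(net_model)` and updates it with the `ema` helper imported from `utils_cifar`, which is not shown in this diff. As a rough sketch under that caveat, an exponential-moving-average update of this kind typically looks like:

    import copy
    import torch

    def ema_update(source: torch.nn.Module, target: torch.nn.Module, decay: float = 0.9999):
        # target <- decay * target + (1 - decay) * source, parameter-wise.
        with torch.no_grad():
            for p_src, p_tgt in zip(source.parameters(), target.parameters()):
                p_tgt.mul_(decay).add_(p_src, alpha=1 - decay)

    net_model = torch.nn.Linear(8, 8)
    ema_model = copy.deepcopy(net_model)  # same initialization, as above
    ema_update(net_model, ema_model)
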
diff --git a/examples/images/cifar10/train_cifar10_ddp.py b/examples/images/cifar10/train_cifar10_ddp.py
index 851f28ca..21b09683 100644
--- a/examples/images/cifar10/train_cifar10_ddp.py
+++ b/examples/images/cifar10/train_cifar10_ddp.py
@@ -12,7 +12,6 @@
from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
-from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop, setup
@@ -116,9 +115,7 @@ def train(rank, total_num_gpus, argv):
num_head_channels=64,
attention_resolutions="16",
dropout=0.1,
- ).to(
- rank
- ) # new dropout + bs of 128
+ ).to(rank) # new dropout + bs of 128
ema_model = copy.deepcopy(net_model)
optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
@@ -181,7 +178,11 @@ def train(rank, total_num_gpus, argv):
# sample and Saving the weights
if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
generate_samples(
- net_model, FLAGS.parallel, savedir, global_step, net_="normal"
+ net_model,
+ FLAGS.parallel,
+ savedir,
+ global_step,
+ net_="normal",
)
generate_samples(
ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
diff --git a/examples/images/cifar10/utils_cifar.py b/examples/images/cifar10/utils_cifar.py
index cfa36b8f..818df68a 100644
--- a/examples/images/cifar10/utils_cifar.py
+++ b/examples/images/cifar10/utils_cifar.py
@@ -6,7 +6,7 @@
from torchdyn.core import NeuralODE
# from torchvision.transforms import ToPILImage
-from torchvision.utils import make_grid, save_image
+from torchvision.utils import save_image
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
@@ -28,7 +28,6 @@ def setup(
master_port: Port number of the master node.
backend: Backend to use.
"""
-
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
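
`setup` exports `MASTER_ADDR`/`MASTER_PORT` so each DDP worker can find the rendezvous point. A sketch of the standard initialization those variables feed into; the rank and world size below are placeholders, and the repository's `setup` may take extra arguments:

    import os
    import torch
    import torch.distributed as dist

    def init_distributed(rank: int, world_size: int,
                         master_addr: str = "localhost", master_port: str = "12355"):
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = master_port
        # NCCL for GPU training; gloo as a CPU fallback.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend, rank=rank, world_size=world_size)
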
diff --git a/pyproject.toml b/pyproject.toml
index 94e48ab3..2424da72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,17 +24,16 @@ exclude_lines = [
"if __name__ == .__main__.:",
]
-[tool.flake8]
-extend-ignore = ["E203", "E402", "E501", "F401", "F841", "E741", "F403"]
-exclude = ["logs/*","data/*"]
-per-file-ignores = [
- '__init__.py:F401',
-]
-max-line-length = 99
-count = true
+[tool.ruff]
+line-length = 99
+
+[tool.ruff.lint]
+ignore = ["C901", "E501", "E741", "W605", "C408", "E402"]
+select = ["C", "E", "F", "I", "W"]
-[tool.bandit]
-skips = ["B101", "B311"]
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
-[tool.isort]
-known_first_party = ["tests", "src"]
+[tool.ruff.lint.isort]
+known-first-party = ["src"]
+known-third-party = ["torch", "transformers", "wandb"]
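
The `ignore` list carries over suppressions the old flake8 config already had, and the per-file-ignores keep ruff quiet about the re-export style (F401, F403) conventional in `__init__.py` files. E741, for example, covers single-letter names like the `I = len(p0)` that appears later in this diff (emd.py); under ruff's E rules that line would otherwise be flagged:

    p0 = list(range(10))
    I = len(p0)  # E741: ambiguous variable name -- suppressed repo-wide above
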
diff --git a/runner/src/datamodules/cifar10_datamodule.py b/runner/src/datamodules/cifar10_datamodule.py
index e7ec3633..fe680033 100644
--- a/runner/src/datamodules/cifar10_datamodule.py
+++ b/runner/src/datamodules/cifar10_datamodule.py
@@ -1,7 +1,6 @@
from typing import Any, List, Union
import pl_bolts
-import torch
from torch.utils.data import DataLoader
from torchvision import transforms as transform_lib
diff --git a/runner/src/datamodules/components/generators2d.py b/runner/src/datamodules/components/generators2d.py
index 9b7de610..48eaa613 100644
--- a/runner/src/datamodules/components/generators2d.py
+++ b/runner/src/datamodules/components/generators2d.py
@@ -3,6 +3,7 @@
Largely from
https://github.com/AmirTag/OT-ICNN/blob/6caa9b982596a101b90a8a947d10f35f18c7de4e/2_dim_experiments/W2-minimax-tf.py
"""
+
import random
import numpy as np
diff --git a/runner/src/datamodules/components/tnet_dataset.py b/runner/src/datamodules/components/tnet_dataset.py
index 32a4d9b4..4fa167ad 100644
--- a/runner/src/datamodules/components/tnet_dataset.py
+++ b/runner/src/datamodules/components/tnet_dataset.py
@@ -2,6 +2,7 @@
Loads datasets into uniform format for learning continuous flows
"""
+
import math
import numpy as np
@@ -119,7 +120,7 @@ def plot_paths(self):
plt.show()
def factory(name, args):
- if type(args) == dict:
+ if type(args) is dict:
from argparse import Namespace
args = Namespace(**args)
@@ -161,7 +162,6 @@ def factory(name, args):
def _get_data_points(adata, basis) -> np.ndarray:
"""Returns the data points corresponding to the selected basis."""
-
if basis == "highly_variable":
data_points = adata[:, adata.var[basis]].X.toarray()
elif basis in adata.obsm.keys():
@@ -650,9 +650,10 @@ def get_paths(self, n=5000, n_steps=3):
class CircleTestDataV5(TreeTestData):
- """This builds on version 3 to include a better middle timepoint. Where instead of being
- parametrically defined, the middle timepoint is defined in terms of the interpolant between the
- first and last timepoints along the manifold.
+ """This builds on version 3 to include a better middle timepoint.
+
+ Instead of being parametrically defined, the middle timepoint is defined in terms of the
+ interpolant between the first and last timepoints along the manifold.
This is a useful thing to relate to in terms of transport along the manifold.
"""
diff --git a/runner/src/datamodules/distribution_datamodule.py b/runner/src/datamodules/distribution_datamodule.py
index 9bb6cfc6..dfd13715 100644
--- a/runner/src/datamodules/distribution_datamodule.py
+++ b/runner/src/datamodules/distribution_datamodule.py
@@ -7,7 +7,7 @@
from pytorch_lightning import LightningDataModule
from pytorch_lightning.trainer.supporters import CombinedLoader
from sklearn.preprocessing import StandardScaler
-from torch.utils.data import DataLoader, Sampler, TensorDataset, random_split
+from torch.utils.data import DataLoader, TensorDataset, random_split
from torchdyn.datasets import ToyDataset
from src import utils
@@ -56,7 +56,7 @@ def split(self):
"""Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
train_val_test_split = self.hparams.train_val_test_split
if isinstance(train_val_test_split, int):
- self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data))
+ self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data]
return
splitter = partial(
random_split,
@@ -141,7 +141,7 @@ def split(self):
"""Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
train_val_test_split = self.hparams.train_val_test_split
if isinstance(train_val_test_split, int):
- self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data))
+ self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data]
return
splitter = partial(
random_split,
@@ -322,7 +322,7 @@ def split(self):
"""Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
train_val_test_split = self.hparams.train_val_test_split
if isinstance(train_val_test_split, int):
- self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data))
+ self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data]
return
splitter = partial(
random_split,
@@ -338,9 +338,7 @@ def closed_form_marginal(self, sigma, t):
"""
a = self.a
mean = (2 * a * t - a) * torch.ones(self.dim)
- cov = (math.sqrt(4 + sigma**4) * t * (1 - t) + (1 - t) ** 2 + t**2) * torch.eye(
- self.dim
- )
+ cov = (math.sqrt(4 + sigma**4) * t * (1 - t) + (1 - t) ** 2 + t**2) * torch.eye(self.dim)
return mean, cov
def detailed_evaluation(self, xt, sigma, t):
@@ -424,7 +422,7 @@ def split(self):
"""Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
train_val_test_split = self.hparams.train_val_test_split
if isinstance(train_val_test_split, int):
- self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data))
+ self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data]
return
splitter = partial(
random_split,
@@ -551,7 +549,7 @@ def split(self):
"""Split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
train_val_test_split = self.hparams.train_val_test_split
if isinstance(train_val_test_split, int):
- self.split_timepoint_data = list(map(lambda x: (x, x, x), self.timepoint_data))
+ self.split_timepoint_data = [(x, x, x) for x in self.timepoint_data]
return
splitter = partial(
random_split,
@@ -644,7 +642,8 @@ def __init__(
class DistributionDataModule(BaseLightningDataModule):
- """DEPRECATED: Implements loader for datasets taking the form of a sequence of distributions over time.
+ """DEPRECATED: Implements loader for datasets taking the form of a sequence of distributions
+ over time.
Each batch is a 3-tuple of data (data, time, causal graph) ([b x d], [b], [b x d x d]).
"""
diff --git a/runner/src/eval.py b/runner/src/eval.py
index 5636b3f2..94528244 100644
--- a/runner/src/eval.py
+++ b/runner/src/eval.py
@@ -57,7 +57,6 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
Returns:
Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
"""
-
assert cfg.ckpt_path
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
diff --git a/runner/src/models/cfm_module.py b/runner/src/models/cfm_module.py
index 3006d248..925b47b3 100644
--- a/runner/src/models/cfm_module.py
+++ b/runner/src/models/cfm_module.py
@@ -5,7 +5,6 @@
import numpy as np
import torch
-import torchsde
from pytorch_lightning import LightningDataModule, LightningModule
from torch.distributions import MultivariateNormal
from torchdyn.core import NeuralODE
@@ -19,7 +18,6 @@
from .components.distribution_distances import compute_distribution_distances
from .components.optimal_transport import OTPlanSampler
from .components.plotting import (
- plot_paths,
plot_samples,
plot_trajectory,
store_trajectories,
@@ -226,7 +224,6 @@ def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
def calc_loc_and_target(self, x0, x1, t, t_select, training):
"""Computes the loss on a batch of data."""
-
t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape)
eps_t = torch.randn_like(mu_t)
@@ -246,7 +243,6 @@ def calc_loc_and_target(self, x0, x1, t, t_select, training):
def step(self, batch: Any, training: bool = False):
"""Computes the loss on a batch of data."""
-
X = self.unpack_batch(batch)
x0, x1, t_select = self.preprocess_batch(X, training)
# Either randomly sample a single T or sample a batch of T's
@@ -280,9 +276,7 @@ def training_step(self, batch: Any, batch_idx: int):
def image_eval_step(self, batch: Any, batch_idx: int, prefix: str):
import os
- from math import prod
- from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torchvision.utils import save_image
# val_augmentations = AugmentationModule(
@@ -394,7 +388,6 @@ def preprocess_epoch_end(self, outputs: List[Any], prefix: str):
def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
# Build a trajectory
t_span = torch.linspace(0, 1, 101)
- aug_dims = self.val_augmentations.aug_dims
regs = []
trajs = []
full_trajs = []
@@ -910,7 +903,8 @@ def step(self, batch: Any, training: bool = False):
reg, vt, st = self.forward_flow_and_score(t, x)
flow_loss = self.criterion(vt, ut)
score_loss = self.criterion(
- -sigma_t * st / (self.sigma(t_orig.reshape(sigma_t.shape)) ** 2) * 2, score_target
+ -sigma_t * st / (self.sigma(t_orig.reshape(sigma_t.shape)) ** 2) * 2,
+ score_target,
)
return torch.mean(reg) + self.hparams.score_weight * score_loss, flow_loss
@@ -1051,9 +1045,7 @@ def training_epoch_end(self, training_step_outputs):
def image_eval_step(self, batch: Any, batch_idx: int, prefix: str):
import os
- from math import prod
- from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torchvision.utils import save_image
solver = self.partial_solver(self.net, self.dim)
@@ -1128,7 +1120,6 @@ def step(self, batch: Any, training: bool = False):
def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
# Build a trajectory
t_span = torch.linspace(0, 1, 101).type_as(x0)
- aug_dims = self.val_augmentations.aug_dims
regs = []
trajs = []
full_trajs = []
@@ -1238,7 +1229,6 @@ def step(self, batch: Any, training: bool = False):
def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
# Build a trajectory
t_span = torch.linspace(0, 1, 101)
- aug_dims = self.val_augmentations.aug_dims
regs = []
trajs = []
full_trajs = []
@@ -1337,11 +1327,11 @@ def step(self, batch: Any, training: bool = False):
class FMLitModule(CFMLitModule):
- """Implements a Lipman et al. 2023 style flow matching loss.
+ """Implements a Lipman et al.
- This maps the standard normal distribution to the data distribution by using conditional flows
- that are the optimal transport flow from a narrow Gaussian around a datapoint to a standard N(x
- | 0, 1).
+
+ This maps the standard normal distribution to the data distribution by using conditional
+ flows that are the optimal transport flow from a narrow Gaussian around a datapoint to a
+ standard N(x | 0, 1).
"""
def calc_mu_sigma(self, x0, x1, t):
@@ -1365,7 +1355,7 @@ class SplineCFMLitModule(CFMLitModule):
def preprocess_batch(self, X, training=False):
from torchcubicspline import NaturalCubicSpline, natural_cubic_spline_coeffs
- """converts a batch of data into matched a random pair of (x0, x1)"""
+ """Converts a batch of data into matched a random pair of (x0, x1)"""
lotp = self.hparams.leaveout_timepoint
valid_times = torch.arange(X.shape[1]).type_as(X)
t_select = torch.zeros(1)
@@ -1431,7 +1421,6 @@ def step(self, batch: Any, training: bool = False):
obs = self.unpack_batch(batch)
if not self.is_trajectory:
obs = obs[:, None, :]
- aug_dims = self.augmentations.aug_dims
even_ts = torch.arange(obs.shape[1]).to(obs) + 1
self.prior = MultivariateNormal(
torch.zeros(self.dim).type_as(obs), torch.eye(self.dim).type_as(obs)
diff --git a/runner/src/models/components/augmentation.py b/runner/src/models/components/augmentation.py
index 24901878..2f08d325 100644
--- a/runner/src/models/components/augmentation.py
+++ b/runner/src/models/components/augmentation.py
@@ -1,5 +1,3 @@
-from typing import Callable, List, Union
-
import torch
from torch import nn
@@ -213,16 +211,20 @@ def forward(self, x):
class Augmenter(nn.Module):
- """Augmentation class. Can handle several types of augmentation strategies for Neural DEs.
+ """Augmentation class.
+ Can handle several types of augmentation strategies for Neural DEs.
:param augment_dims: number of augmented dimensions to initialize
:type augment_dims: int
:param augment_idx: index of dimension to augment
:type augment_idx: int
- :param augment_func: nn.Module applied to the input datasets of dimension `d` to determine the augmented initial condition of dimension `d + a`.
- `a` is defined implicitly in `augment_func` e.g. augment_func=nn.Linear(2, 5) augments a 2 dimensional input with 3 additional dimensions.
+ :param augment_func: nn.Module applied to the input datasets of dimension `d` to determine the
+ augmented initial condition of dimension `d + a`. `a` is defined implicitly in
+ `augment_func` e.g. augment_func=nn.Linear(2, 5) augments a 2 dimensional input with 3
+ additional dimensions.
:type augment_func: nn.Module
- :param order: whether to augment before datasets [augmentation, x] or after [x, augmentation] along dimension `augment_idx`. Options: ('first', 'last')
+ :param order: whether to augment before datasets [augmentation, x] or after [x, augmentation]
+ along dimension `augment_idx`. Options: ('first', 'last')
:type order: str
"""
diff --git a/runner/src/models/components/base.py b/runner/src/models/components/base.py
index ebf805c2..89715a65 100644
--- a/runner/src/models/components/base.py
+++ b/runner/src/models/components/base.py
@@ -568,9 +568,7 @@ def reset_parameters(self):
def _get_kl(self, param_mean, sigma, prior_log_sigma):
kl = torch.sum(
- prior_log_sigma
- - torch.log(sigma)
- + 0.5 * (sigma**2) / (math.exp(prior_log_sigma * 2))
+ prior_log_sigma - torch.log(sigma) + 0.5 * (sigma**2) / (math.exp(prior_log_sigma * 2))
)
kl += 0.5 * torch.sum(param_mean**2) / math.exp(prior_log_sigma * 2)
return kl
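
As a sanity check on the reflowed `_get_kl` above: it computes the Gaussian KL divergence KL(N(mu, sigma^2) || N(0, sigma_p^2)) summed over parameters, up to an additive constant of 1/2 per dimension, which has no effect on gradients. A hedged numerical check against torch.distributions, with arbitrary test values:

    import math
    import torch
    from torch.distributions import Normal, kl_divergence

    param_mean = torch.randn(5)
    sigma = torch.rand(5) + 0.1
    prior_log_sigma = 0.3

    kl_code = torch.sum(
        prior_log_sigma - torch.log(sigma) + 0.5 * (sigma**2) / math.exp(prior_log_sigma * 2)
    ) + 0.5 * torch.sum(param_mean**2) / math.exp(prior_log_sigma * 2)

    kl_true = kl_divergence(
        Normal(param_mean, sigma), Normal(torch.zeros(5), math.exp(prior_log_sigma))
    ).sum()

    assert torch.allclose(kl_code, kl_true + 0.5 * param_mean.numel(), atol=1e-5)
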
diff --git a/runner/src/models/components/emd.py b/runner/src/models/components/emd.py
index 5654f11c..c1c394da 100644
--- a/runner/src/models/components/emd.py
+++ b/runner/src/models/components/emd.py
@@ -139,7 +139,7 @@ def interpolate_per_point_with_ot(p0, p1, tmap, interp_frac):
)
I = len(p0)
- J = len(p1)
+ # J = len(p1)
# Assume growth is exponential and retrieve growth rate at t_interpolate
# If all sums are the same then this does not change anything
# This only matters if sum is not the same for all rows
diff --git a/runner/src/models/components/evaluation.py b/runner/src/models/components/evaluation.py
index 2b94e13f..8e411362 100644
--- a/runner/src/models/components/evaluation.py
+++ b/runner/src/models/components/evaluation.py
@@ -1,7 +1,6 @@
from collections import Counter
import numpy as np
-import torch
from sklearn.metrics import average_precision_score, roc_auc_score
diff --git a/runner/src/models/components/layers/diffeq_layers/basic.py b/runner/src/models/components/layers/diffeq_layers/basic.py
index 3eec391b..7f7e1805 100644
--- a/runner/src/models/components/layers/diffeq_layers/basic.py
+++ b/runner/src/models/components/layers/diffeq_layers/basic.py
@@ -438,7 +438,7 @@ def __init__(
groups=1,
bias=True,
transpose=False,
- **unused_kwargs
+ **unused_kwargs,
):
super().__init__()
module = nn.ConvTranspose2d if transpose else nn.Conv2d
diff --git a/runner/src/models/components/layers/odefunc.py b/runner/src/models/components/layers/odefunc.py
index c39313bf..48dfed55 100644
--- a/runner/src/models/components/layers/odefunc.py
+++ b/runner/src/models/components/layers/odefunc.py
@@ -1,6 +1,5 @@
import copy
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
diff --git a/runner/src/models/components/logger.py b/runner/src/models/components/logger.py
index 9c6e456e..eb328130 100644
--- a/runner/src/models/components/logger.py
+++ b/runner/src/models/components/logger.py
@@ -7,7 +7,6 @@
import json
import os
import os.path as osp
-import shutil
import sys
import tempfile
import time
diff --git a/runner/src/models/components/nn.py b/runner/src/models/components/nn.py
index 5e454703..cdbd6334 100644
--- a/runner/src/models/components/nn.py
+++ b/runner/src/models/components/nn.py
@@ -111,8 +111,8 @@ def checkpoint(func, inputs, params, flag):
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
- :param params: a sequence of parameters `func` depends on but does not
- explicitly take as arguments.
+ :param params: a sequence of parameters `func` depends on but does not explicitly take as
+ arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
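
The `checkpoint` helper documented above implements the usual gradient-checkpointing trade: skip storing intermediate activations on the forward pass and recompute them during backward. PyTorch ships an equivalent; a sketch using it, with an illustrative module rather than the repository's:

    import torch
    from torch.utils.checkpoint import checkpoint as torch_checkpoint

    block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    x = torch.randn(2, 16, requires_grad=True)

    # Activations inside `block` are recomputed in backward rather than stored.
    y = torch_checkpoint(block, x, use_reentrant=False)
    y.sum().backward()
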
diff --git a/runner/src/models/components/plotting.py b/runner/src/models/components/plotting.py
index f15b06e0..f3085d5d 100644
--- a/runner/src/models/components/plotting.py
+++ b/runner/src/models/components/plotting.py
@@ -77,7 +77,6 @@ def store_trajectories(obs: Union[torch.Tensor, list], model, title="trajs", sta
batch_size, ts, dim = obs.shape
start = obs[:n, start_time, :]
obs = obs.reshape(-1, dim).detach().cpu().numpy()
- tts = np.tile(np.arange(ts), batch_size)
from torchdyn.core import NeuralODE
with torch.no_grad():
@@ -108,11 +107,9 @@ def plot_trajectory(
data = np.concatenate(data, axis=0)
labels = np.concatenate(labels, axis=0)
scprep.plot.scatter2d(data, c=labels)
- start = obs[0][:n]
ts = len(obs)
else:
batch_size, ts, dim = obs.shape
- start = obs[:n, start_time, :]
obs = obs.reshape(-1, dim).detach().cpu().numpy()
tts = np.tile(np.arange(ts), batch_size)
scprep.plot.scatter2d(obs, c=tts)
diff --git a/runner/src/models/components/sinkhorn_knopp_unbalanced.py b/runner/src/models/components/sinkhorn_knopp_unbalanced.py
index 7e7871e1..8010c0cd 100644
--- a/runner/src/models/components/sinkhorn_knopp_unbalanced.py
+++ b/runner/src/models/components/sinkhorn_knopp_unbalanced.py
@@ -6,6 +6,7 @@
something large we can compute an unbalanced optimal transport where all the scaling is done on the
source distribution and none is done on the target distribution.
"""
+
import warnings
import numpy as np
diff --git a/runner/src/models/components/solver.py b/runner/src/models/components/solver.py
index d0916ed3..d7cf2db1 100644
--- a/runner/src/models/components/solver.py
+++ b/runner/src/models/components/solver.py
@@ -11,7 +11,7 @@
import torchsde
from torchdyn.core import NeuralODE
-from .augmentation import AugmentationModule, AugmentedVectorField, Sequential
+from .augmentation import AugmentedVectorField, Sequential
class TorchSDE(torch.nn.Module):
diff --git a/runner/src/models/components/unet.py b/runner/src/models/components/unet.py
index 72760e8e..5ff82a24 100644
--- a/runner/src/models/components/unet.py
+++ b/runner/src/models/components/unet.py
@@ -1,4 +1,5 @@
"""From https://raw.githubusercontent.com/openai/guided-diffusion/main/guided_diffusion/unet.py."""
+
import math
from abc import abstractmethod
@@ -372,27 +373,25 @@ class UNetModel(nn.Module):
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
- :param attention_resolutions: a collection of downsample rates at which
- attention will take place. May be a set, list, or tuple.
- For example, if this contains 4, then at 4x downsampling, attention
- will be used.
+ :param attention_resolutions: a collection of downsample rates at which attention will take
+ place. May be a set, list, or tuple. For example, if this contains 4, then at 4x
+ downsampling, attention will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
- :param conv_resample: if True, use learned convolutions for upsampling and
- downsampling.
+ :param conv_resample: if True, use learned convolutions for upsampling and downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
- :param num_classes: if specified (as an int), then this model will be
- class-conditional with `num_classes` classes.
+ :param num_classes: if specified (as an int), then this model will be class-conditional with
+ `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
- :param num_heads_channels: if specified, ignore num_heads and instead use
- a fixed channel width per attention head.
- :param num_heads_upsample: works with num_heads to set a different number
- of heads for upsampling. Deprecated.
+ :param num_head_channels: if specified, ignore num_heads and instead use a fixed channel width
+ per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number of heads for
+ upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
- :param use_new_attention_order: use a different attention pattern for potentially
- increased efficiency.
+ :param use_new_attention_order: use a different attention pattern for potentially increased
+ efficiency.
"""
def __init__(
diff --git a/runner/src/models/icnn_module.py b/runner/src/models/icnn_module.py
index c6ec571a..33544b19 100644
--- a/runner/src/models/icnn_module.py
+++ b/runner/src/models/icnn_module.py
@@ -1,20 +1,11 @@
-from typing import Any, List, Optional, Union
+from typing import Any, List
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule
-from torch import autograd, nn
-from torch.distributions import MultivariateNormal
-from torchdyn.core import NeuralODE
-
-from .components.augmentation import (
- AugmentationModule,
- AugmentedVectorField,
- Sequential,
-)
+from torch import autograd
+
from .components.distribution_distances import compute_distribution_distances
-from .components.optimal_transport import OTPlanSampler
-from .components.plotting import plot_paths, plot_scatter_and_flow
from .utils import get_wandb_logger
@@ -197,7 +188,12 @@ def x_to_y(x):
x_pred = y_to_x(x1)
plot(
- x0, x1, x_pred, pred, savename=f"{self.current_epoch}_match", wandb_logger=wandb_logger
+ x0,
+ x1,
+ x_pred,
+ pred,
+ savename=f"{self.current_epoch}_match",
+ wandb_logger=wandb_logger,
)
def validation_step(self, batch: Any, batch_idx: int):
diff --git a/runner/src/models/runner.py b/runner/src/models/runner.py
index a5e3f781..29d4e567 100644
--- a/runner/src/models/runner.py
+++ b/runner/src/models/runner.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
import numpy as np
import torch
@@ -112,7 +112,6 @@ def preprocess_epoch_end(self, outputs: List[Any], prefix: str):
def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
# Build a trajectory
t_span = torch.linspace(0, 1, 101)
- aug_dims = self.val_augmentations.aug_dims
solver = self.solver(self.net, self.dim)
solver.augmentations = self.val_augmentations
traj, aug = solver.odeint(x0, t_span)
diff --git a/runner/src/train.py b/runner/src/train.py
index 554e7230..c724c6e8 100644
--- a/runner/src/train.py
+++ b/runner/src/train.py
@@ -47,8 +47,9 @@
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
- """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
- training.
+ """Trains the model.
+
+ Can additionally evaluate on a test set, using the best weights obtained during training.
This method is wrapped in optional @task_wrapper decorator which applies extra utilities
before and after the call.
@@ -59,7 +60,6 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
Returns:
Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
"""
-
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
pl.seed_everything(cfg.seed, workers=True)
diff --git a/runner/src/utils/pylogger.py b/runner/src/utils/pylogger.py
index 92ffa718..93759656 100644
--- a/runner/src/utils/pylogger.py
+++ b/runner/src/utils/pylogger.py
@@ -5,12 +5,19 @@
def get_pylogger(name=__name__) -> logging.Logger:
"""Initializes multi-GPU-friendly python command line logger."""
-
logger = logging.getLogger(name)
# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
- logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+ logging_levels = (
+ "debug",
+ "info",
+ "warning",
+ "error",
+ "exception",
+ "fatal",
+ "critical",
+ )
for level in logging_levels:
setattr(logger, level, rank_zero_only(getattr(logger, level)))
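
The `setattr` loop above swaps each logging method for a rank-zero-guarded version. In miniature, with a stand-in decorator instead of Lightning's `rank_zero_only`, which reads the process rank globally:

    import logging

    def rank_zero_only(fn, rank: int = 0):
        # Stand-in: only call through when this process is rank zero.
        def wrapped(*args, **kwargs):
            if rank == 0:
                return fn(*args, **kwargs)
        return wrapped

    logger = logging.getLogger("demo")
    for level in ("debug", "info", "warning", "error", "critical"):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))
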
diff --git a/runner/src/utils/rich_utils.py b/runner/src/utils/rich_utils.py
index e340dc69..ddc8a1e2 100644
--- a/runner/src/utils/rich_utils.py
+++ b/runner/src/utils/rich_utils.py
@@ -37,7 +37,6 @@ def print_config_tree(
resolve (bool, optional): Whether to resolve reference fields of DictConfig.
save_to_file (bool, optional): Whether to export config to the hydra output folder.
"""
-
style = "dim"
tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
@@ -78,7 +77,6 @@ def print_config_tree(
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
"""Prompts user to input tags from command line if no tags are provided in config."""
-
if not cfg.get("tags"):
if "id" in HydraConfig().cfg.hydra.job:
raise ValueError("Specify tags before launching a multirun!")
diff --git a/runner/src/utils/utils.py b/runner/src/utils/utils.py
index cf61b547..9b08371b 100644
--- a/runner/src/utils/utils.py
+++ b/runner/src/utils/utils.py
@@ -2,7 +2,7 @@
import warnings
from importlib.util import find_spec
from pathlib import Path
-from typing import Any, Callable, Dict, List
+from typing import Callable, List
import hydra
from omegaconf import DictConfig
@@ -60,7 +60,6 @@ def extras(cfg: DictConfig) -> None:
- Setting tags from command line
- Rich config printing
"""
-
# return if no `extras` config
if not cfg.get("extras"):
log.warning("Extras config not found!