Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 0 additions & 41 deletions .github/workflows/code-quality-pr.yaml

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@ name: Code Quality Main
on:
push:
branches: [main]
pull_request:
branches: [main, "release/*"]

jobs:
code-quality:
runs-on: ubuntu-latest

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v5

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.10"
python-version: "3.13"

- name: Run pre-commits
uses: pre-commit/action@v2.0.3
uses: pre-commit/action@v3.0.1
12 changes: 6 additions & 6 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v5

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -46,10 +46,10 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v5

- name: Set up Python 3.10
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.10"

Expand Down
18 changes: 9 additions & 9 deletions .github/workflows/test_runner.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
name: Runner Tests

on:
push:
branches: [main]
pull_request:
branches: [main, "release/*"]
#on:
# push:
# branches: [main]
# pull_request:
# branches: [main, "release/*"]

jobs:
run_tests_ubuntu:
Expand All @@ -18,10 +18,10 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v5

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -48,10 +48,10 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v5

- name: Set up Python 3.10
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.10"

Expand Down
71 changes: 16 additions & 55 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
default_language_version:
python: python3
node: 16.14.2

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
Expand All @@ -29,55 +25,26 @@ repos:
- id: check-added-large-files
require_serial: true

# python code formatting
- repo: https://github.com/psf/black
rev: 23.7.0
hooks:
- id: black
require_serial: true
args: [--line-length, "99"]

# python import sorting
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
require_serial: true
args: ["--profile", "black", "--filter-files"]

# python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade
rev: v3.9.0
rev: v3.17.0
hooks:
- id: pyupgrade
require_serial: true
args: [--py38-plus]

# python docstring formatting
- repo: https://github.com/myint/docformatter
rev: master
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.5
hooks:
- id: docformatter
require_serial: true
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]

# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
require_serial: true
entry: pflake8
additional_dependencies: ["pyproject-flake8"]
- id: ruff
args: [--fix]
- id: ruff-format

# python security linter
- repo: https://github.com/PyCQA/bandit
rev: "1.7.5"
- repo: https://github.com/gitleaks/gitleaks
rev: v8.18.2
hooks:
- id: bandit
require_serial: true
args: ["-c", "pyproject.toml"]
additional_dependencies: ["bandit[toml]"]
- id: gitleaks

# yaml formatting
- repo: https://github.com/pre-commit/mirrors-prettier
Expand All @@ -95,19 +62,13 @@ repos:
require_serial: true
args: ["-e", "SC2102"]

# md formatting
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
hooks:
- id: mdformat
require_serial: true
args: ["--number"]
additional_dependencies:
- mdformat-gfm
- mdformat-tables
- mdformat_frontmatter
# - mdformat-toc
# - mdformat-black
- id: prettier
# To avoid conflicts, tell prettier to ignore file types
# that ruff already handles.
exclude_types: [python]

# word spelling linter
- repo: https://github.com/codespell-project/codespell
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ A. Tong, N. Malkin, K. Fatras, L. Atanackovic, Y. Zhang, G. Huguet, G. Wolf, Y.

Major Changes:

- __Added cifar10 examples with an FID of 3.5__
- **Added cifar10 examples with an FID of 3.5**
- Added code for the new Simulation-free Score and Flow Matching (SF)2M preprint
- Created `torchcfm` pip installable package
- Moved `pytorch-lightning` implementation and experiments to `runner` directory
Expand Down
2 changes: 1 addition & 1 deletion examples/images/cifar10/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# CIFAR-10 experiments using TorchCFM

This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an __FID of 3.5__ on the Cifar10 dataset.
This repository is used to reproduce the CIFAR-10 experiments from [1](https://arxiv.org/abs/2302.00482). We have designed a novel experimental procedure that helps us to reach an **FID of 3.5** on the Cifar10 dataset.

<p align="center">
<img src="../../../assets/169_generated_samples_otcfm.png" width="600"/>
Expand Down
11 changes: 7 additions & 4 deletions examples/images/cifar10/compute_fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
# Authors: Kilian Fatras
# Alexander Tong

import os
import sys

import matplotlib.pyplot as plt
import torch
from absl import app, flags
from absl import flags
from cleanfid import fid
from torchdiffeq import odeint
from torchdyn.core import NeuralODE
Expand Down Expand Up @@ -81,7 +79,12 @@ def gen_1_img(unused_latent):
print("Use method: ", FLAGS.integration_method)
t_span = torch.linspace(0, 1, 2, device=device)
traj = odeint(
new_net, x, t_span, rtol=FLAGS.tol, atol=FLAGS.tol, method=FLAGS.integration_method
new_net,
x,
t_span,
rtol=FLAGS.tol,
atol=FLAGS.tol,
method=FLAGS.integration_method,
)
traj = traj[-1, :] # .view([-1, 3, 32, 32]).clip(-1, 1)
img = (traj * 127.5 + 128).clip(0, 255).to(torch.uint8) # .permute(1, 2, 0)
Expand Down
5 changes: 1 addition & 4 deletions examples/images/cifar10/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import torch
from absl import app, flags
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop
Expand Down Expand Up @@ -98,9 +97,7 @@ def train(argv):
num_head_channels=64,
attention_resolutions="16",
dropout=0.1,
).to(
device
) # new dropout + bs of 128
).to(device) # new dropout + bs of 128

ema_model = copy.deepcopy(net_model)
optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
Expand Down
11 changes: 6 additions & 5 deletions examples/images/cifar10/train_cifar10_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop, setup
Expand Down Expand Up @@ -116,9 +115,7 @@ def train(rank, total_num_gpus, argv):
num_head_channels=64,
attention_resolutions="16",
dropout=0.1,
).to(
rank
) # new dropout + bs of 128
).to(rank) # new dropout + bs of 128

ema_model = copy.deepcopy(net_model)
optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
Expand Down Expand Up @@ -181,7 +178,11 @@ def train(rank, total_num_gpus, argv):
# sample and Saving the weights
if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
generate_samples(
net_model, FLAGS.parallel, savedir, global_step, net_="normal"
net_model,
FLAGS.parallel,
savedir,
global_step,
net_="normal",
)
generate_samples(
ema_model, FLAGS.parallel, savedir, global_step, net_="ema"
Expand Down
3 changes: 1 addition & 2 deletions examples/images/cifar10/utils_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torchdyn.core import NeuralODE

# from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid, save_image
from torchvision.utils import save_image

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
Expand All @@ -28,7 +28,6 @@ def setup(
master_port: Port number of the master node.
backend: Backend to use.
"""

os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port

Expand Down
23 changes: 11 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,16 @@ exclude_lines = [
"if __name__ == .__main__.:",
]

[tool.flake8]
extend-ignore = ["E203", "E402", "E501", "F401", "F841", "E741", "F403"]
exclude = ["logs/*","data/*"]
per-file-ignores = [
'__init__.py:F401',
]
max-line-length = 99
count = true
[tool.ruff]
line-length = 99

[tool.ruff.lint]
ignore = ["C901", "E501", "E741", "W605", "C408", "E402"]
select = ["C", "E", "F", "I", "W"]

[tool.bandit]
skips = ["B101", "B311"]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]

[tool.isort]
known_first_party = ["tests", "src"]
[tool.ruff.lint.isort]
known-first-party = ["src"]
known-third-party = ["torch", "transformers", "wandb"]
1 change: 0 additions & 1 deletion runner/src/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, List, Union

import pl_bolts
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as transform_lib

Expand Down
Loading
Loading