Skip to content

Commit

Permalink
cifar10 generation
Browse files Browse the repository at this point in the history
  • Loading branch information
swyoon committed Jul 30, 2024
1 parent 365cf92 commit 11eedf7
Show file tree
Hide file tree
Showing 17 changed files with 2,901 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

datasets/*
pretrained/*
80 changes: 79 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,81 @@
# Diffusion-by-MaxEntIRL

Code is coming soon!
The official code release of
**Maximum Entropy Inverse Reinforcement Learning of Diffusion Models with Energy-Based Models**

Sangwoong Yoon, Himchan Hwang, Dohyun Kwon, Yung-Kyun Noh, Frank C. Park
arxiv: https://arxiv.org/abs/2407.00626

![DxMI](figure/DxMI_figure_crop.jpg)

## Environment

* python >= 3.8
* pytorch >= 2.0
* cuda >= 11.6

## Unit tests

```
python -m pytest tests/
```

## TODO & Status

[] 2D
[] CIFAR-10 DDPM
[] Training
[] Generation
[] CIFAR-10 DDGAN
[] Training
[v] Generation
[] ImageNet64
[] Training
[] Generation
[] Anomaly Detection
[] FID Evaluation


## Datasets

```
datasets
├── cifar-10-batches-py
├── cifar10_train_png
├── cifar10_train_fid_stats.pt
├── imagenet # corresponds to ILSVRC/Data/CLS-LOC/train
└── mvtec
├── train_data.pth
   └── val_data.pth
```

Dataset files are released in [dropbox link](https://www.dropbox.com/scl/fo/kk65utuwwirobbltha4oq/AFYUYYhqNZBq8FIr0VX8uPY?rlkey=vh90rf1o6vhsxmywbktsea3sf&dl=0)

**CIFAR-10**

**ImageNet 64x64**

**MVTec-AD**

## Model Checkpoints

Model checkpoints files can be found in [dropbox link](https://www.dropbox.com/scl/fo/hubdctq91m273eomviuvb/AOKLhw1gg50ljxOSMTla8Ko?rlkey=o5ixr0xdr05391ap2fwigzdkx&dl=0)


## Generation

**CIFAR-10**

Run `generate_cifar10.py` for unconditional CIFAR-10 generation. This script automatically loads the config and the checkpoint and generate images.

The script also reports FID evaluated using `pytorch_fid` package. However, the FID scores reported in the paper are computed using Tensorflow code. (See Evaluation)

```
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 generate_cifar10.py --log_dir pretrained/cifar10_ddpm_dxmi_T10 \
--stat datasets/cifar10_train_fid_stats.pt -n 50000
```

## Training


## Evaluation
62 changes: 62 additions & 0 deletions cmd_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
command-line argument parsing utilities
Example:
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='Config file')
parser.add_argument('--dataset', choices=['8gaussians', 'checkerboard', 'swissroll'], default='8gaussians', help='Dataset to use')
parser.add_argument('--run', default=None, type=str, help='Run name')
parser.add_argument('--device', default=0, type=str, help='Device to use')
args, unknown = parser.parse_known_args()
d_cmd_cfg = parse_unknown_args(unknown)
d_cmd_cfg = parse_nested_args(d_cmd_cfg)
"""

def parse_arg_type(val):
if val.isnumeric():
return int(val)
try:
return float(val)
except ValueError:

if val.lower() == 'true':
return True
elif val.lower() == 'false':
return False
elif val.lower() == 'null' or val.lower() == 'none':
return None
elif val.startswith('[') and val.endswith(']'): # parse list
return eval(val)
return val


def parse_unknown_args(l_args):
"""convert the list of unknown args into dict
this does similar stuff to OmegaConf.from_cli()
I may have invented the wheel again..."""
n_args = len(l_args) // 2
kwargs = {}
for i_args in range(n_args):
key = l_args[i_args*2]
val = l_args[i_args*2 + 1]
assert '=' not in key, 'optional arguments should be separated by space'
kwargs[key.strip('-')] = parse_arg_type(val)
return kwargs


def parse_nested_args(d_cmd_cfg):
"""produce a nested dictionary by parsing dot-separated keys
e.g. {key1.key2 : 1} --> {key1: {key2: 1}}"""
d_new_cfg = {}
for key, val in d_cmd_cfg.items():
l_key = key.split('.')
d = d_new_cfg
for i_key, each_key in enumerate(l_key):
if i_key == len(l_key) - 1:
d[each_key] = val
else:
if each_key not in d:
d[each_key] = {}
d = d[each_key]
return d_new_cfg
Binary file added figure/DxMI_figure_crop.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
172 changes: 172 additions & 0 deletions generate_cifar10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""
generate.py
==========
Generate samples from a trained sampler model and calculate FID score with the generated samples and the training set.
Example:
torchrun --nproc_per_node=4 generate.py --log_dir results/cifar10/gcdv3_vi/vanilla --batchsize 100 -n 50000
"""
import argparse
import cmd_utils as cmd
from omegaconf import OmegaConf
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
import random
import os
from utils import mkdir_p, print0
from hydra.utils import instantiate
from torchvision.utils import save_image
from tqdm import trange
from pytorch_fid.fid_score import calculate_fid_given_paths, calculate_fid_given_paths_cache


def print_size(net):
"""
Print the number of parameters of a network
"""
if net is not None and isinstance(net, torch.nn.Module):
module_parameters = filter(lambda p: p.requires_grad, net.parameters())
params = sum([np.prod(p.size()) for p in module_parameters])
print("{} Parameters: {:.6f}M".format(
net.__class__.__name__, params / 1e6), flush=True)

def rescale(X, batch=True):
return (X - (-1)) / (2)

if __name__ == '__main__':
"""
Example usage: torchrun --nproc_per_node=4 generate.py --log_dir results/cifar10/gcdv3_svi/test
Input: log directory (ex. results/cifar10/gcdv3_svi/test)
sampler.pth, config.yaml must be present in logdir
sampler model is loaded from sampler.pth, sampling parameters and dataset name is inferred from config.yaml
n_generate samples are saved in logdir/generated and FID score is calculated with the generated samples and the training set
generated images will be saved in logdir/generated, unless --save_images is set to False
n_generate should be a multiple of batchsize
"""
parser = argparse.ArgumentParser()
parser.add_argument('--log_dir', type=str, required=True, help='path to log_dir')
parser.add_argument("--batchsize", type=int, default=100, help="batch size to generate samples")
parser.add_argument('-n', '--n_generate', type=int, help='number of samples to generate', default = 50000)
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--epoch', type=str, default='best', help='which check point to load. can be "best" or "last"')
parser.add_argument('-save', '--save_images', type=bool, default=True, help='Whether to retain the generated images after calculating FID')
parser.add_argument('--guidance_scale', type=float, default=None, help = 'Value guidance scale, 0.0 for no guidance')
parser.add_argument('--stat', type=str, default=None, help='''path to precalculated statistics for FID calculation. Only used for pytorch_fid computation.
Example: datasets/cifar10_train_fid_stats.pt''')

# parse command line arguments
args, unknown = parser.parse_known_args()

# set random seed
# setting seeds
torch.backends.cudnn.deterministic = True
local_rank = int(os.environ["LOCAL_RANK"])
device = "cuda:{}".format(local_rank)
seed = args.seed
torch.cuda.set_device(device)
torch.manual_seed(seed + local_rank)
np.random.seed(seed + local_rank)
torch.backends.cudnn.benchmark = False
torch.cuda.manual_seed_all(seed + local_rank)
random.seed(seed + local_rank)
os.environ['PYTHONHASHSEED'] = str(seed + local_rank)

assert args.n_generate % args.batchsize == 0, "n_generate must be a multiple of batchsize"

config_path = os.path.join(args.log_dir, 'config.yaml')
if not os.path.exists(config_path):
raise ValueError("Config not found at {}".format(config_path))
run_config = OmegaConf.load(config_path)

data_path = os.path.join('datasets', f'{run_config.data.name}_train_png') # ex) datasets/cifar10_train_png
if not os.path.exists(data_path):
raise ValueError("Dataset not found at {}".format(data_path))

if args.guidance_scale is not None:
output_path = os.path.join(args.log_dir, f'generated_{args.guidance_scale}')
else:
output_path = os.path.join(args.log_dir, 'generated')
mkdir_p(output_path)

sampler_path = os.path.join(args.log_dir, f'sampler_{args.epoch}.pth')
old_sampler_path = os.path.join(args.log_dir, 'sampler.pth')
if not os.path.exists(sampler_path) and os.path.exists(old_sampler_path): # for backward compatibility
sampler_path = old_sampler_path
if not os.path.exists(sampler_path):
raise ValueError("Sampler not found at {}".format(sampler_path))

# load sampler
net = instantiate(run_config.sampler_net)
print_size(net)
sample_shape = run_config.sampler.sample_shape
sampler = instantiate(run_config.sampler, net=net).to(device)

checkpoint = torch.load(sampler_path, map_location=device)
sampler.net.load_state_dict(checkpoint['state_dict'])
print0("Loaded sampler from {}".format(sampler_path))
epoch_info = checkpoint['epoch']
fid_info = checkpoint['fid']
print0("Model was trained for {} epochs, FID score evaluated with 10000 samples are: {}".format(epoch_info, fid_info))
sampler.eval()

if args.guidance_scale is not None:
print0("Loading value function for guidance")
value_path = os.path.join(args.log_dir, 'value_best.pth')
if not os.path.exists(value_path):
raise ValueError("Value ftn not found at {}".format(value_path))
v = instantiate(run_config.value).to(device)
checkpoint = torch.load(value_path, map_location=device)
v.load_state_dict(checkpoint['state_dict'])
v.eval()
else:
v = None

ngpus = torch.cuda.device_count()
if ngpus >= 1:
print0(f"Using distributed training on {ngpus} gpus.")
torch.distributed.init_process_group(backend="nccl", init_method="env://")
sampler.net = DDP(sampler.net, device_ids=[local_rank], output_device=local_rank)
if v is not None:
v = DDP(v, device_ids=[local_rank], output_device=local_rank)

if args.guidance_scale is not None:
trainer = instantiate(run_config.trainer, batchsize=args.batchsize)
trainer.set_models(f=None, v=v, sampler=sampler, optimizer=None,
optimizer_fstar=None, optimizer_v=None)

n_sample_to_generate = args.n_generate / args.batchsize / ngpus
i_img = 0
for i in trange(int(n_sample_to_generate), ncols=80):
with torch.no_grad():
if args.guidance_scale is not None:
d_sample = trainer.sample_guidance(n_sample=args.batchsize, device=device, guidance_scale=args.guidance_scale)
else:
d_sample = sampler.sample(args.batchsize, device=device)
Xi = d_sample['sample'].detach().cpu()
sample = rescale(Xi).clamp(0,1).detach().cpu()
for s in sample:
save_image(s, os.path.join(output_path, f'{local_rank}_{i_img}.png'))
i_img += 1

torch.distributed.barrier() # make sure all files are generated
print0(f"Generated {args.n_generate} samples at {output_path}")

if local_rank <= 0:
print("Calculating FID score")
paths = [output_path, data_path]
kwargs = {'batch_size': args.batchsize, 'device': device, 'dims': 2048}
if args.stat is None:
fid_score = calculate_fid_given_paths(paths, **kwargs)
else:
print(f"Loading precomputed statistics from {args.stat}")
d_fid_stats = torch.load(args.stat)
m2, s2 = d_fid_stats['m2'], d_fid_stats['s2']
paths = [output_path, data_path]
fid_score, _, _ = calculate_fid_given_paths_cache(paths=paths, m2=m2, s2=s2, **kwargs)
print(f"FID score: {fid_score}")

if not args.save_images:
import shutil
shutil.rmtree(output_path)
Empty file added models/__init__.py
Empty file.
22 changes: 22 additions & 0 deletions models/diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
'''code from https://github.com/UW-Madison-Lee-Lab/SFT-PG/blob/main/toy_exp/train.py'''
import torch


def make_beta_schedule(schedule='linear', n_timesteps=1000, start=1e-5, end=1e-2):
if schedule == 'linear':
betas = torch.linspace(start, end, n_timesteps)
elif schedule == "quad":
betas = torch.linspace(start ** 0.5, end ** 0.5, n_timesteps) ** 2
elif schedule == "sigmoid": # this is what is used in Ying Fan
betas = torch.linspace(-6, 6, n_timesteps)
betas = torch.sigmoid(betas) * (end - start) + start
elif schedule == "constant":
betas = torch.ones(n_timesteps) * start
return betas


def extract(input, t, x):
shape = x.shape
out = torch.gather(input, 0, t.to(input.device))
reshape = [t.shape[0]] + [1] * (len(shape) - 1)
return out.reshape(*reshape)
Empty file added models/image32/__init__.py
Empty file.
Loading

0 comments on commit 11eedf7

Please sign in to comment.