-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
2,901 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,81 @@ | ||
# Diffusion-by-MaxEntIRL | ||
|
||
Code is coming soon! | ||
The official code release of | ||
**Maximum Entropy Inverse Reinforcement Learning of Diffusion Models with Energy-Based Models** | ||
|
||
Sangwoong Yoon, Himchan Hwang, Dohyun Kwon, Yung-Kyun Noh, Frank C. Park | ||
arxiv: https://arxiv.org/abs/2407.00626 | ||
|
||
 | ||
|
||
## Environment | ||
|
||
* python >= 3.8 | ||
* pytorch >= 2.0 | ||
* cuda >= 11.6 | ||
|
||
## Unit tests | ||
|
||
``` | ||
python -m pytest tests/ | ||
``` | ||
|
||
## TODO & Status | ||
|
||
[] 2D | ||
[] CIFAR-10 DDPM | ||
[] Training | ||
[] Generation | ||
[] CIFAR-10 DDGAN | ||
[] Training | ||
[v] Generation | ||
[] ImageNet64 | ||
[] Training | ||
[] Generation | ||
[] Anomaly Detection | ||
[] FID Evaluation | ||
|
||
|
||
## Datasets | ||
|
||
``` | ||
datasets | ||
├── cifar-10-batches-py | ||
├── cifar10_train_png | ||
├── cifar10_train_fid_stats.pt | ||
├── imagenet # corresponds to ILSVRC/Data/CLS-LOC/train | ||
└── mvtec | ||
├── train_data.pth | ||
└── val_data.pth | ||
``` | ||
|
||
Dataset files are released in [dropbox link](https://www.dropbox.com/scl/fo/kk65utuwwirobbltha4oq/AFYUYYhqNZBq8FIr0VX8uPY?rlkey=vh90rf1o6vhsxmywbktsea3sf&dl=0) | ||
|
||
**CIFAR-10** | ||
|
||
**ImageNet 64x64** | ||
|
||
**MVTec-AD** | ||
|
||
## Model Checkpoints | ||
|
||
Model checkpoints files can be found in [dropbox link](https://www.dropbox.com/scl/fo/hubdctq91m273eomviuvb/AOKLhw1gg50ljxOSMTla8Ko?rlkey=o5ixr0xdr05391ap2fwigzdkx&dl=0) | ||
|
||
|
||
## Generation | ||
|
||
**CIFAR-10** | ||
|
||
Run `generate_cifar10.py` for unconditional CIFAR-10 generation. This script automatically loads the config and the checkpoint and generate images. | ||
|
||
The script also reports FID evaluated using `pytorch_fid` package. However, the FID scores reported in the paper are computed using Tensorflow code. (See Evaluation) | ||
|
||
``` | ||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 generate_cifar10.py --log_dir pretrained/cifar10_ddpm_dxmi_T10 \ | ||
--stat datasets/cifar10_train_fid_stats.pt -n 50000 | ||
``` | ||
|
||
## Training | ||
|
||
|
||
## Evaluation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
""" | ||
command-line argument parsing utilities | ||
Example: | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--config', type=str, help='Config file') | ||
parser.add_argument('--dataset', choices=['8gaussians', 'checkerboard', 'swissroll'], default='8gaussians', help='Dataset to use') | ||
parser.add_argument('--run', default=None, type=str, help='Run name') | ||
parser.add_argument('--device', default=0, type=str, help='Device to use') | ||
args, unknown = parser.parse_known_args() | ||
d_cmd_cfg = parse_unknown_args(unknown) | ||
d_cmd_cfg = parse_nested_args(d_cmd_cfg) | ||
""" | ||
|
||
def parse_arg_type(val): | ||
if val.isnumeric(): | ||
return int(val) | ||
try: | ||
return float(val) | ||
except ValueError: | ||
|
||
if val.lower() == 'true': | ||
return True | ||
elif val.lower() == 'false': | ||
return False | ||
elif val.lower() == 'null' or val.lower() == 'none': | ||
return None | ||
elif val.startswith('[') and val.endswith(']'): # parse list | ||
return eval(val) | ||
return val | ||
|
||
|
||
def parse_unknown_args(l_args): | ||
"""convert the list of unknown args into dict | ||
this does similar stuff to OmegaConf.from_cli() | ||
I may have invented the wheel again...""" | ||
n_args = len(l_args) // 2 | ||
kwargs = {} | ||
for i_args in range(n_args): | ||
key = l_args[i_args*2] | ||
val = l_args[i_args*2 + 1] | ||
assert '=' not in key, 'optional arguments should be separated by space' | ||
kwargs[key.strip('-')] = parse_arg_type(val) | ||
return kwargs | ||
|
||
|
||
def parse_nested_args(d_cmd_cfg): | ||
"""produce a nested dictionary by parsing dot-separated keys | ||
e.g. {key1.key2 : 1} --> {key1: {key2: 1}}""" | ||
d_new_cfg = {} | ||
for key, val in d_cmd_cfg.items(): | ||
l_key = key.split('.') | ||
d = d_new_cfg | ||
for i_key, each_key in enumerate(l_key): | ||
if i_key == len(l_key) - 1: | ||
d[each_key] = val | ||
else: | ||
if each_key not in d: | ||
d[each_key] = {} | ||
d = d[each_key] | ||
return d_new_cfg |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
""" | ||
generate.py | ||
========== | ||
Generate samples from a trained sampler model and calculate FID score with the generated samples and the training set. | ||
Example: | ||
torchrun --nproc_per_node=4 generate.py --log_dir results/cifar10/gcdv3_vi/vanilla --batchsize 100 -n 50000 | ||
""" | ||
import argparse | ||
import cmd_utils as cmd | ||
from omegaconf import OmegaConf | ||
import torch | ||
from torch.nn.parallel import DistributedDataParallel as DDP | ||
import numpy as np | ||
import random | ||
import os | ||
from utils import mkdir_p, print0 | ||
from hydra.utils import instantiate | ||
from torchvision.utils import save_image | ||
from tqdm import trange | ||
from pytorch_fid.fid_score import calculate_fid_given_paths, calculate_fid_given_paths_cache | ||
|
||
|
||
def print_size(net): | ||
""" | ||
Print the number of parameters of a network | ||
""" | ||
if net is not None and isinstance(net, torch.nn.Module): | ||
module_parameters = filter(lambda p: p.requires_grad, net.parameters()) | ||
params = sum([np.prod(p.size()) for p in module_parameters]) | ||
print("{} Parameters: {:.6f}M".format( | ||
net.__class__.__name__, params / 1e6), flush=True) | ||
|
||
def rescale(X, batch=True): | ||
return (X - (-1)) / (2) | ||
|
||
if __name__ == '__main__': | ||
""" | ||
Example usage: torchrun --nproc_per_node=4 generate.py --log_dir results/cifar10/gcdv3_svi/test | ||
Input: log directory (ex. results/cifar10/gcdv3_svi/test) | ||
sampler.pth, config.yaml must be present in logdir | ||
sampler model is loaded from sampler.pth, sampling parameters and dataset name is inferred from config.yaml | ||
n_generate samples are saved in logdir/generated and FID score is calculated with the generated samples and the training set | ||
generated images will be saved in logdir/generated, unless --save_images is set to False | ||
n_generate should be a multiple of batchsize | ||
""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--log_dir', type=str, required=True, help='path to log_dir') | ||
parser.add_argument("--batchsize", type=int, default=100, help="batch size to generate samples") | ||
parser.add_argument('-n', '--n_generate', type=int, help='number of samples to generate', default = 50000) | ||
parser.add_argument('--seed', type=int, default=0, help='random seed') | ||
parser.add_argument('--epoch', type=str, default='best', help='which check point to load. can be "best" or "last"') | ||
parser.add_argument('-save', '--save_images', type=bool, default=True, help='Whether to retain the generated images after calculating FID') | ||
parser.add_argument('--guidance_scale', type=float, default=None, help = 'Value guidance scale, 0.0 for no guidance') | ||
parser.add_argument('--stat', type=str, default=None, help='''path to precalculated statistics for FID calculation. Only used for pytorch_fid computation. | ||
Example: datasets/cifar10_train_fid_stats.pt''') | ||
|
||
# parse command line arguments | ||
args, unknown = parser.parse_known_args() | ||
|
||
# set random seed | ||
# setting seeds | ||
torch.backends.cudnn.deterministic = True | ||
local_rank = int(os.environ["LOCAL_RANK"]) | ||
device = "cuda:{}".format(local_rank) | ||
seed = args.seed | ||
torch.cuda.set_device(device) | ||
torch.manual_seed(seed + local_rank) | ||
np.random.seed(seed + local_rank) | ||
torch.backends.cudnn.benchmark = False | ||
torch.cuda.manual_seed_all(seed + local_rank) | ||
random.seed(seed + local_rank) | ||
os.environ['PYTHONHASHSEED'] = str(seed + local_rank) | ||
|
||
assert args.n_generate % args.batchsize == 0, "n_generate must be a multiple of batchsize" | ||
|
||
config_path = os.path.join(args.log_dir, 'config.yaml') | ||
if not os.path.exists(config_path): | ||
raise ValueError("Config not found at {}".format(config_path)) | ||
run_config = OmegaConf.load(config_path) | ||
|
||
data_path = os.path.join('datasets', f'{run_config.data.name}_train_png') # ex) datasets/cifar10_train_png | ||
if not os.path.exists(data_path): | ||
raise ValueError("Dataset not found at {}".format(data_path)) | ||
|
||
if args.guidance_scale is not None: | ||
output_path = os.path.join(args.log_dir, f'generated_{args.guidance_scale}') | ||
else: | ||
output_path = os.path.join(args.log_dir, 'generated') | ||
mkdir_p(output_path) | ||
|
||
sampler_path = os.path.join(args.log_dir, f'sampler_{args.epoch}.pth') | ||
old_sampler_path = os.path.join(args.log_dir, 'sampler.pth') | ||
if not os.path.exists(sampler_path) and os.path.exists(old_sampler_path): # for backward compatibility | ||
sampler_path = old_sampler_path | ||
if not os.path.exists(sampler_path): | ||
raise ValueError("Sampler not found at {}".format(sampler_path)) | ||
|
||
# load sampler | ||
net = instantiate(run_config.sampler_net) | ||
print_size(net) | ||
sample_shape = run_config.sampler.sample_shape | ||
sampler = instantiate(run_config.sampler, net=net).to(device) | ||
|
||
checkpoint = torch.load(sampler_path, map_location=device) | ||
sampler.net.load_state_dict(checkpoint['state_dict']) | ||
print0("Loaded sampler from {}".format(sampler_path)) | ||
epoch_info = checkpoint['epoch'] | ||
fid_info = checkpoint['fid'] | ||
print0("Model was trained for {} epochs, FID score evaluated with 10000 samples are: {}".format(epoch_info, fid_info)) | ||
sampler.eval() | ||
|
||
if args.guidance_scale is not None: | ||
print0("Loading value function for guidance") | ||
value_path = os.path.join(args.log_dir, 'value_best.pth') | ||
if not os.path.exists(value_path): | ||
raise ValueError("Value ftn not found at {}".format(value_path)) | ||
v = instantiate(run_config.value).to(device) | ||
checkpoint = torch.load(value_path, map_location=device) | ||
v.load_state_dict(checkpoint['state_dict']) | ||
v.eval() | ||
else: | ||
v = None | ||
|
||
ngpus = torch.cuda.device_count() | ||
if ngpus >= 1: | ||
print0(f"Using distributed training on {ngpus} gpus.") | ||
torch.distributed.init_process_group(backend="nccl", init_method="env://") | ||
sampler.net = DDP(sampler.net, device_ids=[local_rank], output_device=local_rank) | ||
if v is not None: | ||
v = DDP(v, device_ids=[local_rank], output_device=local_rank) | ||
|
||
if args.guidance_scale is not None: | ||
trainer = instantiate(run_config.trainer, batchsize=args.batchsize) | ||
trainer.set_models(f=None, v=v, sampler=sampler, optimizer=None, | ||
optimizer_fstar=None, optimizer_v=None) | ||
|
||
n_sample_to_generate = args.n_generate / args.batchsize / ngpus | ||
i_img = 0 | ||
for i in trange(int(n_sample_to_generate), ncols=80): | ||
with torch.no_grad(): | ||
if args.guidance_scale is not None: | ||
d_sample = trainer.sample_guidance(n_sample=args.batchsize, device=device, guidance_scale=args.guidance_scale) | ||
else: | ||
d_sample = sampler.sample(args.batchsize, device=device) | ||
Xi = d_sample['sample'].detach().cpu() | ||
sample = rescale(Xi).clamp(0,1).detach().cpu() | ||
for s in sample: | ||
save_image(s, os.path.join(output_path, f'{local_rank}_{i_img}.png')) | ||
i_img += 1 | ||
|
||
torch.distributed.barrier() # make sure all files are generated | ||
print0(f"Generated {args.n_generate} samples at {output_path}") | ||
|
||
if local_rank <= 0: | ||
print("Calculating FID score") | ||
paths = [output_path, data_path] | ||
kwargs = {'batch_size': args.batchsize, 'device': device, 'dims': 2048} | ||
if args.stat is None: | ||
fid_score = calculate_fid_given_paths(paths, **kwargs) | ||
else: | ||
print(f"Loading precomputed statistics from {args.stat}") | ||
d_fid_stats = torch.load(args.stat) | ||
m2, s2 = d_fid_stats['m2'], d_fid_stats['s2'] | ||
paths = [output_path, data_path] | ||
fid_score, _, _ = calculate_fid_given_paths_cache(paths=paths, m2=m2, s2=s2, **kwargs) | ||
print(f"FID score: {fid_score}") | ||
|
||
if not args.save_images: | ||
import shutil | ||
shutil.rmtree(output_path) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
'''code from https://github.com/UW-Madison-Lee-Lab/SFT-PG/blob/main/toy_exp/train.py''' | ||
import torch | ||
|
||
|
||
def make_beta_schedule(schedule='linear', n_timesteps=1000, start=1e-5, end=1e-2): | ||
if schedule == 'linear': | ||
betas = torch.linspace(start, end, n_timesteps) | ||
elif schedule == "quad": | ||
betas = torch.linspace(start ** 0.5, end ** 0.5, n_timesteps) ** 2 | ||
elif schedule == "sigmoid": # this is what is used in Ying Fan | ||
betas = torch.linspace(-6, 6, n_timesteps) | ||
betas = torch.sigmoid(betas) * (end - start) + start | ||
elif schedule == "constant": | ||
betas = torch.ones(n_timesteps) * start | ||
return betas | ||
|
||
|
||
def extract(input, t, x): | ||
shape = x.shape | ||
out = torch.gather(input, 0, t.to(input.device)) | ||
reshape = [t.shape[0]] + [1] * (len(shape) - 1) | ||
return out.reshape(*reshape) |
Empty file.
Oops, something went wrong.