Skip to content

Commit

Permalink
Textual Inversion for M1
Browse files Browse the repository at this point in the history
Update main.py

Update ddpm.py

Update personalized.py

Update personalized_style.py

Update v1-finetune.yaml

Update environment-mac.yaml

Rename v1-finetune.yaml to v1-m1-finetune.yaml

Create v1-finetune.yaml

Update main.py

Update main.py

Update environment-mac.yaml

Update v1-inference.yaml
  • Loading branch information
Any-Winter-4079 committed Sep 26, 2022
1 parent d2b5702 commit e19aab4
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 23 deletions.
2 changes: 1 addition & 1 deletion configs/stable-diffusion/v1-finetune.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,4 @@ lightning:
benchmark: True
max_steps: 4000000
# max_steps: 4000


4 changes: 2 additions & 2 deletions configs/stable-diffusion/v1-inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ model:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
initializer_words: ['face', 'man', 'photo', 'africanmale']
per_image_tokens: false
num_vectors_per_token: 1
num_vectors_per_token: 6
progressive_words: False

unet_config:
Expand Down
110 changes: 110 additions & 0 deletions configs/stable-diffusion/v1-m1-finetune.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
model:
base_learning_rate: 5.0e-03
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0

personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ['face', 'man', 'photo', 'africanmale']
per_image_tokens: false
num_vectors_per_token: 6
progressive_words: False

unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False

first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
target: main.DataModuleFromConfig
params:
batch_size: 1
num_workers: 2
wrap: false
train:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
validation:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: val
per_image_tokens: false
repeats: 10

lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 5
increase_log_steps: False

trainer:
benchmark: False
max_steps: 6200
# max_steps: 4000

6 changes: 3 additions & 3 deletions environment-mac.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ dependencies:
- omegaconf==2.1.1
- onnx==1.12.0
- onnxruntime==1.12.1
- protobuf==3.20.1
- protobuf==3.19.5
- pudb==2022.1
- pytorch-lightning==1.6.5
- pytorch-lightning==1.7.5
- scipy==1.9.1
- streamlit==1.12.2
- sympy==1.10.1
- tensorboard==2.9.0
- tensorboard==2.10.0
- torchmetrics==0.9.3
- pip:
- flask==2.1.3
Expand Down
2 changes: 1 addition & 1 deletion ldm/data/personalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(

self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
]

# self._length = len(self.image_paths)
Expand Down
2 changes: 1 addition & 1 deletion ldm/data/personalized_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(

self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
]

# self._length = len(self.image_paths)
Expand Down
4 changes: 2 additions & 2 deletions ldm/models/diffusion/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def make_cond_schedule(

@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
# only for very first batch
if (
self.scale_by_std
Expand Down Expand Up @@ -1890,7 +1890,7 @@ def log_images(
N=8,
n_row=4,
sample=True,
ddim_steps=200,
ddim_steps=50,
ddim_eta=1.0,
return_keys=None,
quantize_denoised=True,
Expand Down
11 changes: 8 additions & 3 deletions ldm/modules/embedding_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,14 @@ def forward(
placeholder_embedding.shape[0], max_step_tokens
)

placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
if torch.cuda.is_available():
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
else:
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token
)

if placeholder_rows.nelement() == 0:
continue
Expand Down
45 changes: 35 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,23 @@
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config

def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
return orig(*args, **kw).to(device)
return new_func
return orig

torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)

def load_model_from_config(config, ckpt, verbose=False):
print(f'Loading model from {ckpt}')
Expand Down Expand Up @@ -422,9 +439,7 @@ def __init__(
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {
pl.loggers.TestTubeLogger: self._testtube,
}
self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { }
self.log_steps = [
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
]
Expand Down Expand Up @@ -527,15 +542,15 @@ def check_frequency(self, check_idx):
return False

def on_train_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
):
if not self.disabled and (
pl_module.global_step > 0 or self.log_first_step
):
self.log_img(pl_module, batch, batch_idx, split='train')

def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split='val')
Expand All @@ -555,7 +570,7 @@ def on_train_epoch_start(self, trainer, pl_module):
torch.cuda.synchronize(trainer.root_gpu)
self.start_time = time.time()

def on_train_epoch_end(self, trainer, pl_module, outputs):
def on_train_epoch_end(self, trainer, pl_module, outputs=None):
if torch.cuda.is_available():
torch.cuda.synchronize(trainer.root_gpu)
epoch_time = time.time() - self.start_time
Expand Down Expand Up @@ -736,6 +751,12 @@ def on_train_epoch_start(self, trainer, pl_module):
trainer_kwargs = dict()

# default logger configs
if torch.cuda.is_available():
def_logger = 'testtube'
def_logger_target = 'TestTubeLogger'
else:
def_logger = 'csv'
def_logger_target = 'CSVLogger'
default_logger_cfgs = {
'wandb': {
'target': 'pytorch_lightning.loggers.WandbLogger',
Expand All @@ -746,15 +767,15 @@ def on_train_epoch_start(self, trainer, pl_module):
'id': nowname,
},
},
'testtube': {
'target': 'pytorch_lightning.loggers.TestTubeLogger',
def_logger: {
'target': 'pytorch_lightning.loggers.' + def_logger_target,
'params': {
'name': 'testtube',
'name': def_logger,
'save_dir': logdir,
},
},
}
default_logger_cfg = default_logger_cfgs['testtube']
default_logger_cfg = default_logger_cfgs[def_logger]
if 'logger' in lightning_config:
logger_cfg = lightning_config.logger
else:
Expand Down Expand Up @@ -868,6 +889,10 @@ def on_train_epoch_start(self, trainer, pl_module):
]
trainer_kwargs['max_steps'] = trainer_opt.max_steps

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
trainer_opt.accelerator = 'mps'
trainer_opt.detect_anomaly = False

trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ###

Expand Down

0 comments on commit e19aab4

Please sign in to comment.