Commit be1ef54

Merge pull request #133 from cloneofsimo/convlora
Feature/ Better LoRA : Dropout, Conv2d
2 parents 25aeab4 + 770b97a commit be1ef54

File tree

4 files changed, +264 -78 lines


.gitignore

Lines changed: 1 addition & 1 deletion

@@ -5,5 +5,5 @@ __pycache__
 __test*
 merged_lora*
 wandb
-exps
+exps*
 .vscode

lora_diffusion/cli_lora_pti.py

Lines changed: 27 additions & 5 deletions

@@ -36,11 +36,13 @@
     PivotalTuningDatasetCapation,
     extract_lora_ups_down,
     inject_trainable_lora,
+    inject_trainable_lora_extended,
     inspect_lora,
     save_lora_weight,
     save_all,
     prepare_clip_model_sets,
     evaluate_pipe,
+    UNET_EXTENDED_TARGET_REPLACE,
 )

@@ -418,6 +420,8 @@ def perform_tuning(
     placeholder_tokens,
     save_path,
     lr_scheduler_lora,
+    lora_unet_target_modules,
+    lora_clip_target_modules,
 ):
 
     progress_bar = tqdm(range(num_steps))

@@ -467,6 +471,8 @@ def perform_tuning(
                 save_path=os.path.join(
                     save_path, f"step_{global_step}.safetensors"
                 ),
+                target_replace_module_text=lora_clip_target_modules,
+                target_replace_module_unet=lora_unet_target_modules,
             )
             moved = (
                 torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))

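The target-module sets are now threaded from train() through perform_tuning() and forwarded to save_all(), so intermediate checkpoints are saved against the same module classes that LoRA was injected into. Below is a minimal, hypothetical sketch of that forwarding; only the two target_replace_module_* keywords come from this hunk, the placeholder-token keywords follow the surrounding trainer code, and the wrapper itself is illustrative.

```python
import os

from lora_diffusion import save_all


def save_checkpoint(unet, text_encoder, placeholder_tokens, placeholder_token_ids,
                    global_step, save_path,
                    lora_unet_target_modules, lora_clip_target_modules):
    # Hypothetical helper mirroring the perform_tuning() change: the target
    # module sets used at injection time are forwarded to save_all(), so the
    # saved .safetensors covers the same set of replaced modules.
    save_all(
        unet,
        text_encoder,
        placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        save_path=os.path.join(save_path, f"step_{global_step}.safetensors"),
        target_replace_module_text=lora_clip_target_modules,
        target_replace_module_unet=lora_unet_target_modules,
    )
```
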
@@ -521,11 +527,12 @@ def train(
     lora_rank: int = 4,
     lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
     lora_clip_target_modules={"CLIPAttention"},
+    use_extended_lora: bool = False,
     clip_ti_decay: bool = True,
     learning_rate_unet: float = 1e-4,
     learning_rate_text: float = 1e-5,
     learning_rate_ti: float = 5e-4,
-    continue_inversion: bool = True,
+    continue_inversion: bool = False,
     continue_inversion_lr: Optional[float] = None,
     use_face_segmentation_condition: bool = False,
     scale_lr: bool = False,

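The new use_extended_lora switch (default False) makes Conv2d-aware injection opt-in, and continue_inversion now defaults to False. A minimal, illustrative sketch of enabling both from Python follows; the model id and data paths are placeholders, and any parameter name not shown in this hunk is assumed from the existing train() signature rather than taken from this diff.

```python
# Illustrative sketch only: enabling the new extended (Conv2d) LoRA path and
# opting back into inversion continuation, which now defaults to False.
from lora_diffusion.cli_lora_pti import train

train(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # placeholder
    instance_data_dir="./data/my_subject",                           # placeholder
    output_dir="./exps/my_subject_lora",                             # placeholder
    lora_rank=4,
    use_extended_lora=True,   # new flag added in this PR
    continue_inversion=True,  # default flipped to False in this PR
    continue_inversion_lr=1e-4,
)
```
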
@@ -690,9 +697,21 @@ def train(
     del ti_optimizer
 
     # Next perform Tuning with LoRA:
-    unet_lora_params, _ = inject_trainable_lora(
-        unet, r=lora_rank, target_replace_module=lora_unet_target_modules
-    )
+    if not use_extended_lora:
+        unet_lora_params, _ = inject_trainable_lora(
+            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
+        )
+    else:
+        print("USING EXTENDED UNET!!!")
+        lora_unet_target_modules = (
+            lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
+        )
+        print("Will replace modules: ", lora_unet_target_modules)
+
+        unet_lora_params, _ = inject_trainable_lora_extended(
+            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
+        )
+    print(f"PTI : has {len(unet_lora_params)} lora")
 
     print("Before training:")
     inspect_lora(unet)

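This hunk is the core of the Conv2d feature: when use_extended_lora is set, the attention-only target set is unioned with UNET_EXTENDED_TARGET_REPLACE and inject_trainable_lora_extended is used, so convolution-bearing UNet blocks also receive trainable LoRA layers. The following is a self-contained sketch of the same pattern outside the trainer, assuming a diffusers UNet2DConditionModel and assuming, as this diff's usage implies, that both injectors return a list of parameter groups plus the replaced-module names.

```python
# Minimal sketch of the injection pattern added here, outside the trainer.
# The model id is a placeholder; handling of the injectors' return value
# follows how it is used in this diff (len(...) and itertools.chain(*...)).
import itertools

from diffusers import UNet2DConditionModel
from lora_diffusion import (
    inject_trainable_lora,
    inject_trainable_lora_extended,
    UNET_EXTENDED_TARGET_REPLACE,
)

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder checkpoint
)
unet.requires_grad_(False)

use_extended_lora = True
lora_rank = 4
lora_unet_target_modules = {"CrossAttention", "Attention", "GEGLU"}

if not use_extended_lora:
    # Attention-only LoRA: only nn.Linear layers inside the listed modules.
    unet_lora_params, _ = inject_trainable_lora(
        unet, r=lora_rank, target_replace_module=lora_unet_target_modules
    )
else:
    # Extended LoRA: the union also pulls in the Conv2d-bearing blocks named
    # in UNET_EXTENDED_TARGET_REPLACE, so convolutions get LoRA layers too.
    lora_unet_target_modules = lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
    unet_lora_params, _ = inject_trainable_lora_extended(
        unet, r=lora_rank, target_replace_module=lora_unet_target_modules
    )

print(f"{len(unet_lora_params)} LoRA parameter groups injected")
trainable = list(itertools.chain(*unet_lora_params))
print(f"{sum(p.numel() for p in trainable)} trainable parameters")
```

Only the injected low-rank parameters end up requiring gradients, which is what keeps the LoRA phase of pivotal tuning lightweight even when convolutions are included.
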
@@ -720,7 +739,8 @@ def train(
         )
         for param in params_to_freeze:
             param.requires_grad = False
-
+    else:
+        text_encoder.requires_grad_(False)
     if train_text_encoder:
         text_encoder_lora_params, _ = inject_trainable_lora(
             text_encoder,

@@ -763,6 +783,8 @@ def train(
         placeholder_token_ids=placeholder_token_ids,
         save_path=output_dir,
         lr_scheduler_lora=lr_scheduler_lora,
+        lora_unet_target_modules=lora_unet_target_modules,
+        lora_clip_target_modules=lora_clip_target_modules,
     )