36 | 36 | PivotalTuningDatasetCapation, |
37 | 37 | extract_lora_ups_down, |
38 | 38 | inject_trainable_lora, |
| 39 | + inject_trainable_lora_extended, |
39 | 40 | inspect_lora, |
40 | 41 | save_lora_weight, |
41 | 42 | save_all, |
42 | 43 | prepare_clip_model_sets, |
43 | 44 | evaluate_pipe, |
| 45 | + UNET_EXTENDED_TARGET_REPLACE, |
44 | 46 | ) |
45 | 47 |
46 | 48 |
@@ -418,6 +420,8 @@ def perform_tuning( |
418 | 420 | placeholder_tokens, |
419 | 421 | save_path, |
420 | 422 | lr_scheduler_lora, |
| 423 | + lora_unet_target_modules, |
| 424 | + lora_clip_target_modules, |
421 | 425 | ): |
422 | 426 |
423 | 427 | progress_bar = tqdm(range(num_steps)) |
@@ -467,6 +471,8 @@ def perform_tuning( |
467 | 471 | save_path=os.path.join( |
468 | 472 | save_path, f"step_{global_step}.safetensors" |
469 | 473 | ), |
| 474 | + target_replace_module_text=lora_clip_target_modules, |
| 475 | + target_replace_module_unet=lora_unet_target_modules, |
470 | 476 | ) |
471 | 477 | moved = ( |
472 | 478 | torch.tensor(list(itertools.chain(*inspect_lora(unet).values()))) |
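
A note on the two keywords added to this `save_all` call: the checkpoint writer has to look for LoRA layers inside the same module classes that were injected earlier, otherwise any extended-LoRA layers would be missing from the saved file. A minimal sketch of that invariant, assuming `extract_lora_ups_down` (imported at the top of this file) accepts a `target_replace_module` keyword like the injection helpers do:

```python
# Hedged sketch: save-time extraction should use the same target set as injection.
# If injection used the widened set but extraction fell back to the attention-only
# default, the extra LoRA layers would never reach the .safetensors file.
injected_targets = {"CrossAttention", "Attention", "GEGLU"} | UNET_EXTENDED_TARGET_REPLACE
loras = extract_lora_ups_down(unet, target_replace_module=injected_targets)  # keyword assumed
```
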
@@ -521,11 +527,12 @@ def train( |
521 | 527 | lora_rank: int = 4, |
522 | 528 | lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"}, |
523 | 529 | lora_clip_target_modules={"CLIPAttention"}, |
| 530 | + use_extended_lora: bool = False, |
524 | 531 | clip_ti_decay: bool = True, |
525 | 532 | learning_rate_unet: float = 1e-4, |
526 | 533 | learning_rate_text: float = 1e-5, |
527 | 534 | learning_rate_ti: float = 5e-4, |
528 | | - continue_inversion: bool = True, |
| 535 | + continue_inversion: bool = False, |
529 | 536 | continue_inversion_lr: Optional[float] = None, |
530 | 537 | use_face_segmentation_condition: bool = False, |
531 | 538 | scale_lr: bool = False, |
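
Two behavioural changes fall out of this hunk: `use_extended_lora` is off by default, and `continue_inversion` now defaults to `False`, so both must be requested explicitly. A minimal sketch of a call that opts into both; the data/model/output arguments are illustrative placeholders assumed from the rest of the script, not values taken from this diff:

```python
# Hedged sketch only -- paths and the model id are placeholders, and any parameter
# not shown in the hunks above is assumed rather than confirmed by this diff.
train(
    instance_data_dir="./my_subject_images",                         # placeholder
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # placeholder
    output_dir="./pti_output",                                       # placeholder
    use_extended_lora=True,      # new flag in this diff; default is False
    continue_inversion=True,     # default flipped from True to False in this diff
    continue_inversion_lr=1e-4,  # illustrative value
)
```
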
@@ -690,9 +697,21 @@ def train( |
690 | 697 | del ti_optimizer |
691 | 698 |
692 | 699 | # Next perform Tuning with LoRA: |
693 | | - unet_lora_params, _ = inject_trainable_lora( |
694 | | - unet, r=lora_rank, target_replace_module=lora_unet_target_modules |
695 | | - ) |
| 700 | + if not use_extended_lora: |
| 701 | + unet_lora_params, _ = inject_trainable_lora( |
| 702 | + unet, r=lora_rank, target_replace_module=lora_unet_target_modules |
| 703 | + ) |
| 704 | + else: |
| 705 | + print("USING EXTENDED UNET!!!") |
| 706 | + lora_unet_target_modules = ( |
| 707 | + lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE |
| 708 | + ) |
| 709 | + print("Will replace modules: ", lora_unet_target_modules) |
| 710 | + |
| 711 | + unet_lora_params, _ = inject_trainable_lora_extended( |
| 712 | + unet, r=lora_rank, target_replace_module=lora_unet_target_modules |
| 713 | + ) |
| 714 | + print(f"PTI : has {len(unet_lora_params)} lora") |
696 | 715 |
697 | 716 | print("Before training:") |
698 | 717 | inspect_lora(unet) |
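
The core of the change sits in this hunk: when `use_extended_lora` is set, the attention-only target set is widened by a plain Python set union before the extended injector runs. A small sketch of that union; the concrete members of `UNET_EXTENDED_TARGET_REPLACE` are whatever the library defines and are not shown in this diff:

```python
# Hedged sketch of the union performed above; "|" on sets is a non-destructive merge.
default_targets = {"CrossAttention", "Attention", "GEGLU"}  # train() default in this diff
extended_targets = default_targets | UNET_EXTENDED_TARGET_REPLACE
# extended_targets now holds the attention classes plus the extra UNet module classes
# the library lists (their names are not part of this diff).

# inject_trainable_lora_extended is then used exactly like inject_trainable_lora:
# it returns a pair, and only the first element (the trainable LoRA parameter
# groups, unet_lora_params above) is kept; the second is discarded with "_".
```
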
@@ -720,7 +739,8 @@ def train( |
720 | 739 | ) |
721 | 740 | for param in params_to_freeze: |
722 | 741 | param.requires_grad = False |
723 | | - |
| 742 | + else: |
| 743 | + text_encoder.requires_grad_(False) |
724 | 744 | if train_text_encoder: |
725 | 745 | text_encoder_lora_params, _ = inject_trainable_lora( |
726 | 746 | text_encoder, |
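
The new `else` branch makes the no-continued-inversion path explicit: rather than freezing a hand-picked subset via `params_to_freeze`, the whole text encoder is frozen in one call before LoRA is (optionally) injected into it. A tiny self-contained sketch of what `requires_grad_(False)` does; this is standard PyTorch behaviour, not something specific to this repository:

```python
from torch import nn

# requires_grad_(False) recursively disables gradients for every parameter of a module,
# which is what the new else-branch relies on to freeze the text encoder wholesale.
encoder_stub = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))  # stand-in, not CLIP
encoder_stub.requires_grad_(False)
print(all(not p.requires_grad for p in encoder_stub.parameters()))  # True
```
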
@@ -763,6 +783,8 @@ def train( |
763 | 783 | placeholder_token_ids=placeholder_token_ids, |
764 | 784 | save_path=output_dir, |
765 | 785 | lr_scheduler_lora=lr_scheduler_lora, |
| 786 | + lora_unet_target_modules=lora_unet_target_modules, |
| 787 | + lora_clip_target_modules=lora_clip_target_modules, |
766 | 788 | ) |
767 | 789 |
768 | 790 |
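
Taken together, the new keyword arguments close the loop: the (possibly widened) target-module sets chosen in `train()` are handed to `perform_tuning()`, which forwards them to `save_all()`, so injection, tuning, and checkpointing all agree on which module classes carry LoRA. A compact restatement of that flow; only the arguments added in this diff are named, everything else is elided rather than guessed:

```python
# Flow of the new arguments through this file, per the hunks above:
#   train(..., lora_unet_target_modules=..., lora_clip_target_modules=..., use_extended_lora=...)
#       -> perform_tuning(..., lora_unet_target_modules=..., lora_clip_target_modules=...)
#           -> save_all(..., target_replace_module_unet=lora_unet_target_modules,
#                            target_replace_module_text=lora_clip_target_modules)
```
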