
Commit 25aeab4

fix : gradient checkpointing. Added to default as well.
1 parent: d9bc29b

2 files changed (+2 −2 lines)


lora_diffusion/cli_lora_pti.py

Lines changed: 1 addition & 2 deletions
@@ -598,8 +598,7 @@ def train(
     )
 
     if gradient_checkpointing:
-        text_encoder.gradient_checkpointing_enable()
-        unet.gradient_checkpointing_enable()
+        unet.enable_gradient_checkpointing()
 
     if scale_lr:
         unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
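
The bug here is a naming mix-up between two libraries: gradient_checkpointing_enable() is the transformers method, while diffusers models expose enable_gradient_checkpointing(), so the old call on the UNet would raise an AttributeError. The commit switches the UNet to the diffusers API (and drops the text-encoder call in this hunk). A minimal sketch of the two APIs side by side; the checkpoint ID below is illustrative, not taken from this repo:

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

# Hypothetical checkpoint for illustration; any SD 1.x repo layout works.
MODEL_ID = "runwayml/stable-diffusion-v1-5"

unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder")

# diffusers models (ModelMixin) use enable_gradient_checkpointing() ...
unet.enable_gradient_checkpointing()

# ... while transformers models use gradient_checkpointing_enable().
text_encoder.gradient_checkpointing_enable()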

training_scripts/multivector_example.sh

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ lora_pti \
   --resolution=512 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
+  --gradient_checkpointing \
   --scale_lr \
   --learning_rate_unet=1e-4 \
   --learning_rate_text=1e-5 \
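
The new --gradient_checkpointing flag in the example script just toggles the code path fixed above. The underlying mechanism trades compute for memory: intermediate activations inside a checkpointed block are discarded during the forward pass and recomputed during backward. A standalone PyTorch sketch of that mechanism (illustration only, not code from this repo; assumes a PyTorch version that supports use_reentrant):

import torch
from torch.utils.checkpoint import checkpoint

# A small block whose intermediate activations we choose not to store.
block = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 512),
)

x = torch.randn(4, 512, requires_grad=True)

# Forward pass: activations inside `block` are discarded, saving memory.
y = checkpoint(block, x, use_reentrant=False)

# Backward pass: `block` is re-run to recompute the activations needed
# for gradients, trading extra compute for the memory saved above.
y.sum().backward()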
