Skip to content

Commit d9bc29b

Browse files
committed
format : black
1 parent ae44451 commit d9bc29b

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

lora_diffusion/utils.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,15 @@
5151
]
5252

5353

54-
def image_grid(_imgs, rows = None, cols = None):
55-
54+
def image_grid(_imgs, rows=None, cols=None):
55+
5656
if rows is None and cols is None:
5757
rows = cols = math.ceil(len(_imgs) ** 0.5)
58-
58+
5959
if rows is None:
6060
rows = math.ceil(len(_imgs) / cols)
6161
if cols is None:
6262
cols = math.ceil(len(_imgs) / rows)
63-
6463

6564
w, h = _imgs[0].size
6665
grid = Image.new("RGB", size=(cols * w, rows * h))
@@ -176,25 +175,23 @@ def visualize_progress(
176175
text_sclae=1.0,
177176
num_inference_steps=50,
178177
guidance_scale=5.0,
179-
offset : int = 0,
180-
limit : int = 10,
181-
seed : int = 0
178+
offset: int = 0,
179+
limit: int = 10,
180+
seed: int = 0,
182181
):
183182

184-
185183
imgs = []
186184
if isinstance(path_alls, str):
187185
alls = list(set(glob.glob(path_alls)))
188-
186+
189187
alls.sort(key=os.path.getmtime)
190188
else:
191189
alls = path_alls
192-
190+
193191
pipe = StableDiffusionPipeline.from_pretrained(
194192
model_id, torch_dtype=torch.float16
195193
).to(device)
196194

197-
198195
print(f"Found {len(alls)} checkpoints")
199196
for path in alls[offset:limit]:
200197
print(path)
@@ -207,8 +204,11 @@ def visualize_progress(
207204
tune_lora_scale(pipe.text_encoder, text_sclae)
208205

209206
torch.manual_seed(seed)
210-
image = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
207+
image = pipe(
208+
prompt,
209+
num_inference_steps=num_inference_steps,
210+
guidance_scale=guidance_scale,
211+
).images[0]
211212
imgs.append(image)
212213

213214
return imgs
214-

training_scripts/train_lora_dreambooth.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,20 +90,25 @@ def __init__(
9090
self.class_prompt = class_prompt
9191
else:
9292
self.class_data_root = None
93-
94-
img_transforms = []
9593

94+
img_transforms = []
9695

9796
if resize:
98-
img_transforms.append(transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR))
97+
img_transforms.append(
98+
transforms.Resize(
99+
size, interpolation=transforms.InterpolationMode.BILINEAR
100+
)
101+
)
99102
if center_crop:
100103
img_transforms.append(transforms.CenterCrop(size))
101104
if color_jitter:
102105
img_transforms.append(transforms.ColorJitter(0.2, 0.1))
103106
if h_flip:
104107
img_transforms.append(transforms.RandomHorizontalFlip())
105108

106-
self.image_transforms = transforms.Compose([*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
109+
self.image_transforms = transforms.Compose(
110+
[*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
111+
)
107112

108113
def __len__(self):
109114
return self._length

0 commit comments

Comments
 (0)