Commit a65f262

add some new configs & whatnot
1 parent ddad80c commit a65f262

4 files changed, +17 -14 lines changed

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+unet:
+  in_channels: 3
+  # 6 channels b/c need to output learned variance interpolation vector
+  out_channels: 6
+  model_channels: 256
+  channel_mult: [1, 2, 3, 4]
+  layer_attn: [false, false, true, true]
+  res_blocks: 3
+  attention_heads: 4
+diffusion:
+  schedule: cosine
+  steps: 1000
+  learn_sigma: true
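
For orientation, a minimal sketch of how a model config shaped like the one above could be consumed; only the YAML keys come from this commit, while the loader function and the UNet / GaussianDiffusion constructor names are assumptions for illustration.

import yaml

def load_model_config(path):
    # Parse the model config; expects the "unet:" and "diffusion:" sections above.
    with open(path) as f:
        cfg = yaml.safe_load(f)
    # out_channels is 6 rather than 3 because, with learn_sigma: true, the
    # network predicts both the noise and the variance interpolation vector.
    return cfg["unet"], cfg["diffusion"]

# unet_cfg, diffusion_cfg = load_model_config("model.yaml")
# model = UNet(**unet_cfg)                         # hypothetical constructor
# diffusion = GaussianDiffusion(**diffusion_cfg)   # hypothetical constructor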

config_run/1xA100-80GB-learned.yaml

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ warmup_batches: 500
 # b/c I suspect it will increase convergence
 # significantly
 lr: 1.00*1e-4
+cosine_factor: 1.2

 batch_size: 302
 micro_batches: 1
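
The newly exposed cosine_factor presumably stretches the cosine decay so the learning rate has not fully reached zero at the final batch; that reading is an assumption. The sketch below only reuses the names and values that appear in this commit (lr 1.00*1e-4, warmup_batches 500, cosine_factor 1.2).

import math

def lr_at(step, total_steps, base_lr, warmup, cosine_factor):
    # Linear warmup for the first `warmup` batches (assumed behaviour).
    if step < warmup:
        return base_lr * (step + 1) / warmup
    # Cosine decay over a period stretched by cosine_factor: with a factor
    # of 1.2 the schedule ends before the cosine reaches zero.
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress / cosine_factor))

# e.g. lr_at(step, total_steps, base_lr=1.00 * 1e-4, warmup=500, cosine_factor=1.2)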

config_run/1xA100-80GB.bash

Lines changed: 0 additions & 13 deletions
This file was deleted.

train.py

Lines changed: 3 additions & 1 deletion

@@ -17,6 +17,7 @@ class RunConfig(hp.Hparams):

     warmup_batches: int = hp.required("Warmup batches")
     lr: str = hp.required("Learning rate")
+    cosine_factor: float = hp.required("# the factor to use for the cosine scheduler")

     batch_size: int = hp.required("Batch size")
     micro_batches: int = hp.required("Micro-batches")
@@ -27,6 +28,7 @@ class RunConfig(hp.Hparams):
     evals: int = hp.required("# Evals")


+
 def main(
     model_config_file: Path = typer.Option(...),
     run_config_file: Path = typer.Option(...),
@@ -62,7 +64,7 @@ def get_num(expr, typ):
         batch_rate=c.batch_rate,
         target_time=target_time,
         warmup=c.warmup_batches,
-        cosine_factor=1.4,
+        cosine_factor=c.cosine_factor,
     )

     train_dl = dataloader(dir_train, c.batch_size, workers=8)
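
Because lr is declared as a str hparam while the yaml stores the expression 1.00*1e-4, get_num (whose body is not part of this diff; only its signature appears in the hunk header) presumably evaluates such expressions. The stand-in below is a guess at that behaviour, not the repo's implementation.

def get_num(expr, typ):
    # Hypothetical stand-in: evaluate a simple numeric expression such as
    # "1.00*1e-4" with no builtins or names available, then cast it.
    return typ(eval(expr, {"__builtins__": {}}, {}))

# get_num("1.00*1e-4", float) -> 0.0001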

0 commit comments
