-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Hamidreza Eivazi
committed
Oct 21, 2024
0 parents
commit 180129d
Showing
152 changed files
with
9,338 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Ignore compiled Python code | ||
__pycache__ | ||
|
||
# Ignore VSCode data | ||
.vscode/ | ||
|
||
# Ignore spyproject | ||
.spyproject |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
This is a directory for model checkpoints. |
Binary file added
BIN
+2.75 MB
data/clo_test_ds/9068255351728993909/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
9068255351728993909�갗�Ї *0�> |
Binary file added
BIN
+4.24 MB
data/clo_train_ds/11191737472994864236/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
11191737472994864236�����Ї *0�> |
Binary file added
BIN
+1.33 MB
data/cruh_test_ds/3493010029797832045/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.33 MB
data/cruh_test_ds/5236434107691057326/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.33 MB
data/cruh_test_ds/6698691441096244991/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.33 MB
data/cruh_test_ds/7703266081458975963/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
5236434107691057326��â�݇ *0�> |
Binary file added
BIN
+2 MB
data/cruh_train_ds/10685006334357990431/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
7821779845748606900��â�݇ *0�> |
Binary file added
BIN
+526 KB
data/crush_test_ds/3353282546599082243/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3353282546599082243���ч *0�> |
Binary file added
BIN
+811 KB
data/crush_train_ds/17241199862395263327/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
17241199862395263327�΄��ч *0�> |
Binary file added
BIN
+884 KB
data/hust_test_ds/11302614873109255603/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
11302614873109255603��ݒ�Ї *0�> |
Binary file added
BIN
+2.16 MB
data/hust_train_ds/1938179409874429169/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
1938179409874429169��ܒ�Ї *0�> |
Binary file added
BIN
+1.65 MB
data/matr_1_test_ds/12707456702768718607/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.65 MB
data/matr_1_test_ds/4583604515328584653/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
12707456702768718607Ѻ���· *0�> |
Binary file added
BIN
+1.61 MB
data/matr_1_train_ds/16475868263709912137/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.61 MB
data/matr_1_train_ds/8060084496854373314/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
16475868263709912137�ٓ��· *0�> |
Binary file added
BIN
+1.57 MB
data/matr_2_test_ds/14764405622374977204/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.57 MB
data/matr_2_test_ds/17842825233777586549/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.57 MB
data/matr_2_test_ds/18166745313780747664/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.57 MB
data/matr_2_test_ds/4237353906892565456/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.57 MB
data/matr_2_test_ds/7959759803595615959/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
4237353906892565456�����ԇ *0�> |
Binary file added
BIN
+1.61 MB
data/matr_2_train_ds/17757197488228993877/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.61 MB
data/matr_2_train_ds/4152201548562599209/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.61 MB
data/matr_2_train_ds/6497096738088307729/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.61 MB
data/matr_2_train_ds/6676123268512669149/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+1.61 MB
data/matr_2_train_ds/8357047373603310238/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
8357047373603310238�����ԇ *0�> |
Binary file added
BIN
+5.37 MB
data/mix_test_ds/18376487205760940729/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+5.37 MB
data/mix_test_ds/9766666449346203155/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
18376487205760940729͉���χ *0�> |
Binary file added
BIN
+8.04 MB
data/mix_train_ds/12966391879456103793/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+8.04 MB
data/mix_train_ds/1338059158835488053/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������dd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
12966391879456103793с���χ *0�> |
Binary file added
BIN
+223 KB
data/snl_test_ds/11021408624446146764/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file added
BIN
+223 KB
data/snl_test_ds/12056647078796054666/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
12056647078796054666�ᬨ�Ї *0�> |
Binary file added
BIN
+357 KB
data/snl_train_ds/17054954317676099652/00000000.shard/00000000.snapshot
Binary file not shown.
Binary file not shown.
Binary file added
BIN
+357 KB
data/snl_train_ds/3274803510225957421/00000000.shard/00000000.snapshot
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
�[ | ||
����������� | ||
���������� | ||
����������d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3274803510225957421�����Ї *0�> |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow import keras | ||
|
||
class DiffusionModel(keras.Model): | ||
def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999, p_uncond=0.0, first_channels=8): | ||
super().__init__() | ||
self.network = network | ||
self.ema_network = ema_network | ||
self.timesteps = timesteps | ||
self.gdf_util = gdf_util | ||
self.ema = ema | ||
self.p_uncond = p_uncond | ||
self.first_channels = first_channels | ||
|
||
@tf.function | ||
def bernoulli(self, shape): | ||
c = tf.random.uniform(shape, minval=0, maxval=1, dtype=tf.float32) | ||
c = tf.where(c < self.p_uncond, 0.0, 1.0) | ||
return c | ||
|
||
@tf.function | ||
def train_step(self, images): | ||
images, _, protocol = images | ||
# 1. Get the batch size | ||
batch_size = tf.shape(images)[0] | ||
|
||
# 2. Sample timesteps uniformly | ||
t = tf.random.uniform( | ||
minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64 | ||
) | ||
|
||
c_mask = self.bernoulli(shape=(batch_size,)) | ||
c_mask = tf.tile(c_mask[...,None], [1, self.first_channels*4]) | ||
|
||
with tf.GradientTape() as tape: | ||
# 3. Sample random noise to be added to the images in the batch | ||
noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype) | ||
|
||
# 4. Diffuse the images with noise | ||
images_t = self.gdf_util.q_sample(images, t, noise) | ||
|
||
# 5. Pass the diffused images and time steps to the network | ||
pred_noise = self.network([images_t, t, protocol, c_mask], training=True) | ||
|
||
# 6. Calculate the loss | ||
loss = self.loss(noise, pred_noise) | ||
|
||
# 7. Get the gradients | ||
gradients = tape.gradient(loss, self.network.trainable_weights) | ||
|
||
# 8. Update the weights of the network | ||
self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights)) | ||
|
||
# 9. Updates the weight values for the network with EMA weights | ||
for weight, ema_weight in zip(self.network.weights, self.ema_network.weights): | ||
ema_weight.assign(self.ema * ema_weight + (1 - self.ema) * weight) | ||
|
||
# 10. Return loss values | ||
return {"loss": loss} | ||
|
||
@tf.function | ||
def test_step(self, images): | ||
images, _, protocol = images | ||
# 1. Get the batch size | ||
batch_size = tf.shape(images)[0] | ||
|
||
# 2. Sample timesteps uniformly | ||
t = tf.random.uniform( | ||
minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64 | ||
) | ||
|
||
c_mask = tf.ones(shape=(batch_size, self.first_channels*4)) | ||
|
||
# 3. Sample random noise to be added to the images in the batch | ||
noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype) | ||
|
||
# 4. Diffuse the images with noise | ||
images_t = self.gdf_util.q_sample(images, t, noise) | ||
|
||
# 5. Pass the diffused images and time steps to the network | ||
pred_noise = self.network([images_t, t, protocol, c_mask], training=False) | ||
|
||
# 6. Calculate the loss | ||
loss = self.loss(noise, pred_noise) | ||
|
||
# 10. Return loss values | ||
return {"loss": loss} | ||
|
||
@tf.function | ||
def generate(self, samples, tt, capacity_matrices, guide_w): | ||
ones = tf.ones((len(samples), self.first_channels*4)) | ||
zeros = tf.zeros((len(samples), self.first_channels*4)) | ||
|
||
pred_noise1 = self.ema_network([samples, tt, capacity_matrices, ones], training=False) | ||
pred_noise2 = self.ema_network([samples, tt, capacity_matrices, zeros], training=False) | ||
pred_noise = (1+guide_w)*pred_noise1 - guide_w*pred_noise2 | ||
samples = self.gdf_util.p_sample( | ||
pred_noise, samples, tt, clip_denoised=False | ||
) | ||
return samples | ||
|
||
def generate_samples(self, capacity_matrices, guide_w = 0.0, record_samples=False): | ||
# 1. Randomly sample noise (starting point for reverse process) | ||
num_images = len(capacity_matrices) | ||
samples = tf.random.normal( | ||
shape=(num_images, 256, 1), dtype=tf.float32 | ||
) | ||
capacity_matrices = tf.cast(capacity_matrices, dtype=tf.float32) | ||
|
||
record = [] | ||
record.append(samples) | ||
# 2. Sample from the model iteratively | ||
for t in reversed(range(0, self.timesteps)): | ||
tt = tf.cast(tf.fill(num_images, t), dtype=tf.int64) | ||
samples = self.generate(samples, tt, capacity_matrices, guide_w) | ||
if record_samples: | ||
record.append(samples) | ||
# 3. Return generated samples | ||
if record_samples: | ||
return samples, record | ||
else: | ||
return samples | ||
|
||
def get_config(self): | ||
config = super().get_config().copy() | ||
config.update({ | ||
'network': self.network, | ||
'ema_network': self.ema_network, | ||
'timesteps': self.timesteps, | ||
'gdf_util': self.gdf_util, | ||
'ema': self.ema, | ||
'p_uncond': self.p_uncond, | ||
'first_channels': self.first_channels, | ||
}) | ||
return config |
Oops, something went wrong.