training.py
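"""Train the diffusion model on the selected battery dataset, checkpoint the best weights, and export the trained network and its EMA copy."""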
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.data import Dataset as tfds  # alias for tf.data.Dataset (not the tensorflow_datasets package)
from gaussian_diffusion import GaussianDiffusion
from diffusion_model import DiffusionModel
from unet import build_model
from postproc_utils import PostProcess
import platform
import subprocess
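
# Helper used when compiling the model below: Apple Silicon Macs get the
# legacy Adam optimizer instead of the default Keras Adam.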
def is_m_chip():
    """
    Checks if the current machine is a Mac with an M-chip (Apple Silicon).

    Returns:
        bool: True if the machine is a Mac with an M-chip, False otherwise.
    """
    # Check if the operating system is macOS
    if platform.system() != 'Darwin':
        return False
    # Check if the architecture is arm64
    try:
        # 'uname -m' returns 'arm64' for Apple Silicon
        chip_type = subprocess.check_output(['uname', '-m']).decode('utf-8').strip()
        return chip_type == 'arm64'
    except Exception as e:
        print(f"Error during check: {e}")
        return False
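
# Let TensorFlow allocate GPU memory on demand rather than reserving it all up front.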
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
tf.config.set_visible_devices(gpus, 'GPU')
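
# Training configuration and diffusion/U-Net hyperparameters.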
battery_dataset = 'matr_1'
num_epochs = 1 # Just for the sake of demonstration
total_timesteps = 1000
learning_rate = 1e-3
sequence_length = 256
first_conv_channels = 8
channel_multiplier = [1, 2, 4, 8]
widths = [first_conv_channels * mult for mult in channel_multiplier]
has_attention = [False, False, True, True]
num_res_blocks = 2 # Number of residual blocks
p_uncond = 0.2
# Build the unet model
network = build_model(
    sequence_length=sequence_length,
    widths=widths,
    has_attention=has_attention,
    first_conv_channels=first_conv_channels,
    num_res_blocks=num_res_blocks,
)
ema_network = build_model(
    sequence_length=sequence_length,
    widths=widths,
    has_attention=has_attention,
    first_conv_channels=first_conv_channels,
    num_res_blocks=num_res_blocks,
)
ema_network.set_weights(network.get_weights()) # Initially the weights are the same
# Get an instance of the Gaussian Diffusion utilities
gdf_util = GaussianDiffusion(timesteps=total_timesteps)
# Get the model
model = DiffusionModel(
    network=network,
    ema_network=ema_network,
    gdf_util=gdf_util,
    timesteps=total_timesteps,
    p_uncond=p_uncond,
)
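
# Load the pre-built tf.data datasets (presumably written with tf.data.Dataset.save) from ./data.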
train_ds = tfds.load(f'./data/{battery_dataset}_train_ds')
test_ds = tfds.load(f'./data/{battery_dataset}_test_ds')
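# Optional sanity check (sketch): the structure of the loaded examples can be
# inspected with, e.g.,
#   print(train_ds.element_spec)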
# Compile the model
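# The default Keras Adam optimizer can be slow on Apple Silicon, so the legacy
# optimizer is used there instead.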
if is_m_chip():
    keras_optimizer = keras.optimizers.legacy.Adam(learning_rate=learning_rate)
else:
    keras_optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=keras_optimizer,
    weighted_metrics=[],
)
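
# Keep only the weights with the lowest validation loss seen during training.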
checkpoint_filepath = f'./checkpoints/{battery_dataset}_checkpoint.weights.h5'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    verbose=1,
    monitor='val_loss',
    save_freq='epoch',
    mode='min',
    save_best_only=True,
)
# Train the model
hist = model.fit(
    train_ds,
    epochs=num_epochs,
    validation_data=test_ds,
    callbacks=[model_checkpoint_callback],
    verbose=2,
    validation_freq=1,
)
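
# Restore the best checkpoint and export both networks in TensorFlow SavedModel format.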
model.load_weights(checkpoint_filepath)
network.save(f'./checkpoints/{battery_dataset}_network', save_format='tf')
ema_network.save(f'./checkpoints/{battery_dataset}_ema_network', save_format='tf')
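
# The exported SavedModels can later be restored with Keras, e.g. (sketch):
#   restored_ema = keras.models.load_model(f'./checkpoints/{battery_dataset}_ema_network')
# How the restored network is used for sampling depends on GaussianDiffusion /
# DiffusionModel, which are defined elsewhere in this repository.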