diff --git a/baler/modules/data_processing.py b/baler/modules/data_processing.py index d14d9958..682ec4fa 100644 --- a/baler/modules/data_processing.py +++ b/baler/modules/data_processing.py @@ -57,7 +57,10 @@ def encoder_saver(model, model_path: str) -> None: Returns: None: Saved encoder state dictionary as `.pt` file. """ - torch.save(model.encoder.state_dict(), model_path) + if hasattr(model.encoder, "state_dict"): + torch.save(model.encoder.state_dict(), model_path) + else: + model.save_encoder(model_path) def decoder_saver(model, model_path: str) -> None: @@ -70,7 +73,10 @@ def decoder_saver(model, model_path: str) -> None: Returns: None: Saved decoder state dictionary as `.pt` file. """ - torch.save(model.decoder.state_dict(), model_path) + if hasattr(model.decoder, "state_dict"): + torch.save(model.decoder.state_dict(), model_path) + else: + model.save_decoder(model_path) def initialise_model(model_name: str): diff --git a/baler/modules/models.py b/baler/modules/models.py index f95ac4ca..cc3dc6cb 100644 --- a/baler/modules/models.py +++ b/baler/modules/models.py @@ -712,3 +712,81 @@ def get_final_layer_dims(self): def set_final_layer_dims(self, conv_op_shape): self.conv_op_shape = conv_op_shape + + +class PJ_Conv_AE_FPGA(nn.Module): + def __init__(self, n_features, z_dim=10, *args, **kwargs): + super(PJ_Conv_AE_FPGA, self).__init__(*args, **kwargs) + + # Encoder layers + self.en1 = nn.Conv2d(1, 20, kernel_size=5, stride=2, padding=2) + self.en_act1 = nn.ReLU() + self.en2 = nn.Conv2d(20, 50, kernel_size=5, stride=2, padding=2) + self.en_act2 = nn.Flatten() + self.en3 = nn.Linear(50 * 7 * 7, 500) + self.en4 = nn.Linear(500, z_dim) + + # Decoder layers + self.de1 = nn.Linear(z_dim, 500) + self.de_act1 = nn.ReLU() + self.de2 = nn.Linear(500, 2450) + self.de_unflatten = nn.Unflatten(1, (50, 7, 7)) + self.de_conv1 = nn.ConvTranspose2d( + 50, 20, kernel_size=5, stride=2, padding=2, output_padding=1 + ) + self.de_conv2 = nn.ConvTranspose2d( + 20, 1, kernel_size=5, stride=2, padding=2, output_padding=1 + ) + self.de_act2 = nn.ReLU() + + self.output_shape = None + + def encoder(self, x): + s1 = self.en1(x) + s2 = self.en_act1(s1) + s3 = self.en2(s2) + s4 = self.en_act2(s3) + s5 = self.en3(s4) + s6 = self.en4(s5) + return s6 + + def decoder(self, z): + d1 = self.de1(z) + d2 = self.de_act1(d1) + d3 = self.de2(d2) + d4 = self.de_unflatten(d3) + d5 = self.de_conv1(d4) + d6 = self.de_conv2(d5) + self.output_shape = d6.shape + return d6 + + def forward(self, x): + encoded = self.encoder(x) + decoded = self.decoder(encoded) + return decoded + + def get_final_layer_dims(self): + return + + def set_final_layer_dims(self, conv_op_shape): + self.conv_op_shape = conv_op_shape + + def save_encoder(self, file_path): + # Create an instance of the encoder + encoder_instance = nn.Sequential( + self.en1, self.en_act1, self.en2, self.en_act2, self.en3, self.en4 + ) + torch.save(encoder_instance.state_dict(), file_path) + + def save_decoder(self, file_path): + # Create an instance of the decoder + decoder_instance = nn.Sequential( + self.de1, + self.de_act1, + self.de2, + self.de_unflatten, + self.de_conv1, + self.de_conv2, + self.de_act2, + ) + torch.save(decoder_instance.state_dict(), file_path)