diff --git a/src/layers.py b/src/layers.py index 509a6a1..ac02dfe 100644 --- a/src/layers.py +++ b/src/layers.py @@ -54,6 +54,7 @@ def __init__( union_tp=None, dropout=0.0, no_mix=True, + device = 'cuda' ): super().__init__() assert embed_time % num_heads == 0 @@ -69,7 +70,7 @@ def __init__( nn.Linear(embed_time, embed_time, bias=False), nn.Linear(input_size * num_heads, nhidden, bias=False) ]) - self.time_emb = TimeEmbedding(embed_time, arg='periodic') + self.time_emb = TimeEmbedding(embed_time, arg='periodic',device=device) self.dropout = nn.Dropout(p=dropout) self.union_tp = union_tp self.no_mix = no_mix diff --git a/src/models.py b/src/models.py index c9a1f42..23e6da2 100644 --- a/src/models.py +++ b/src/models.py @@ -26,6 +26,7 @@ def load_network(args, dim, union_tp=None, device="cuda"): mse_weight=args.mse_weight, norm=args.norm, mixing=args.mixing, + device=device ).to(device) elif args.net == 'hetvae_det': net = HeTVAE_DET( @@ -46,6 +47,7 @@ def load_network(args, dim, union_tp=None, device="cuda"): mse_weight=args.mse_weight, norm=args.norm, mixing=args.mixing, + device=device ).to(device) elif args.net == 'hetvae_prob': net = HeTVAE_PROB( @@ -66,6 +68,7 @@ def load_network(args, dim, union_tp=None, device="cuda"): mse_weight=args.mse_weight, norm=args.norm, mixing=args.mixing, + device=device ).to(device) else: raise ValueError("Network not available") diff --git a/src/vae_models.py b/src/vae_models.py index 7896f3c..2d34940 100644 --- a/src/vae_models.py +++ b/src/vae_models.py @@ -243,12 +243,14 @@ def __init__( intensity=self.intensity, union_tp=self.union_tp, no_mix=True, + device=self.device ) self.decoder = UnTAN( input_dim=self.latent_dim, nhidden=self.nhidden, embed_time=self.embed_time, num_heads=self.num_heads, + device=self.device ) @@ -276,6 +278,7 @@ def __init__( nhidden=self.nhidden, embed_time=self.embed_time, num_heads=self.num_heads, + device=self.device ) def encode(self, context_x, context_y):