@@ -348,7 +348,8 @@ def forward(self, J_indx):
348
348
done = self ._forward (self .J [i ])
349
349
except RuntimeError as e :
350
350
if 'out of memory' in str (e ):
351
- torch .cuda .empty_cache ()
351
+ if torch .cuda .is_available ():
352
+ torch .cuda .empty_cache ()
352
353
353
354
if _n_batches < self .n_train :
354
355
_n_batches = _next_batch_size (
@@ -478,8 +479,6 @@ def __init__(
478
479
479
480
self .tril_indices = np .tril_indices (self .n_atoms , k = - 1 )
480
481
481
- self .R_d_desc = None
482
-
483
482
if torch .cuda .is_available (): # Ignore limits and take whatever the GPU has.
484
483
max_memory = (
485
484
min (
@@ -522,39 +521,30 @@ def __init__(
522
521
523
522
self .max_processes = max_processes
524
523
525
- self .perm_idxs = (
526
- torch .tensor (model ['tril_perms_lin' ]).view (- 1 , self .n_perms ).t ()
527
- )
528
-
529
- self ._xs_train = nn .Parameter (
530
- self .apply_perms_to_obj (torch .tensor (model ['R_desc' ]).t (), perm_idxs = None ),
531
- requires_grad = False ,
532
- )
533
- self ._Jx_alphas = nn .Parameter (
534
- self .apply_perms_to_obj (
535
- torch .tensor (np .array (model ['R_d_desc_alpha' ])), perm_idxs = None
536
- ),
537
- requires_grad = False ,
538
- )
524
+ self .R_d_desc = None
525
+ self ._xs_train = nn .Parameter (torch .tensor (model ['R_desc' ]).t (), requires_grad = False )
526
+ self ._Jx_alphas = nn .Parameter (torch .tensor (np .array (model ['R_d_desc_alpha' ])), requires_grad = False )
539
527
540
528
self ._alphas_E = None
541
529
if 'alphas_E' in model :
542
530
self ._alphas_E = nn .Parameter (
543
531
torch .from_numpy (model ['alphas_E' ]), requires_grad = False
544
532
)
545
533
534
+ self .perm_idxs = (
535
+ torch .tensor (model ['tril_perms_lin' ]).view (- 1 , self .n_perms ).t ()
536
+ )
537
+
546
538
# Try to cache all permutated variants of 'self._xs_train' and 'self._Jx_alphas'
547
539
try :
548
540
self .set_n_perm_batches (n_perm_batches )
549
541
except RuntimeError as e :
550
542
if 'out of memory' in str (e ):
551
- torch .cuda .empty_cache ()
543
+ if torch .cuda .is_available ():
544
+ torch .cuda .empty_cache ()
552
545
553
546
if n_perm_batches == 1 :
554
- self ._log .debug (
555
- 'Trying to cache permutations FAILED during init (continuing without)'
556
- )
557
- self .set_n_perm_batches (2 )
547
+ self .set_n_perm_batches (2 ) # Set to 2 perm batches, because that's the first batch size (and fastest) that is not cached.
558
548
pass
559
549
else :
560
550
self ._log .critical (
@@ -577,11 +567,8 @@ def __init__(
577
567
else self .n_train
578
568
)
579
569
_batch_size = min (_batch_size , max_batch_size )
580
- self ._log .debug (
581
- 'Starting with a batch size of {} ({} points in total).' .format (
582
- _batch_size , self .n_train
583
- )
584
- )
570
+
571
+ self ._log .debug ('Setting batch size to {}/{} points.' .format (_batch_size , self .n_train ))
585
572
586
573
self .desc = Desc (self .n_atoms , max_processes = max_processes )
587
574
@@ -594,9 +581,7 @@ def set_n_perm_batches(self, n_perm_batches):
594
581
595
582
global _n_perm_batches
596
583
597
- self ._log .debug (
598
- 'Permutations will be generated in {} batches.' .format (n_perm_batches )
599
- )
584
+ self ._log .debug ('Setting permutation batch size to {}{}.' .format (n_perm_batches , ' (no caching)' if n_perm_batches > 1 else '' ))
600
585
601
586
_n_perm_batches = n_perm_batches
602
587
if n_perm_batches == 1 and self .n_perms > 1 :
@@ -627,14 +612,12 @@ def uncache_perms(self):
627
612
628
613
xs_train_n_perms = self ._xs_train .numel () // (self .n_train * self .dim_d )
629
614
if xs_train_n_perms != 1 : # Uncached already?
630
- self ._log .debug ('Uncaching permutations for \' self._xs_train\' ' )
631
615
self ._xs_train = nn .Parameter (
632
616
self .remove_perms_from_obj (self ._xs_train ), requires_grad = False
633
617
)
634
618
635
619
Jx_alphas_n_perms = self ._Jx_alphas .numel () // (self .n_train * self .dim_d )
636
620
if Jx_alphas_n_perms != 1 : # Uncached already?
637
- self ._log .debug ('Uncaching permutations for \' self._Jx_alphas\' ' )
638
621
self ._Jx_alphas = nn .Parameter (
639
622
self .remove_perms_from_obj (self ._Jx_alphas ), requires_grad = False
640
623
)
@@ -643,19 +626,13 @@ def cache_perms(self):
643
626
644
627
xs_train_n_perms = self ._xs_train .numel () // (self .n_train * self .dim_d )
645
628
if xs_train_n_perms == 1 : # Cached already?
646
- self ._log .debug ('Caching permutations for \' self._xs_train\' ' )
647
- xs_train = self .apply_perms_to_obj (self ._xs_train , perm_idxs = self .perm_idxs )
629
+ self ._xs_train = nn .Parameter (self .apply_perms_to_obj (self ._xs_train , perm_idxs = self .perm_idxs ), requires_grad = False )
648
630
649
631
Jx_alphas_n_perms = self ._Jx_alphas .numel () // (self .n_train * self .dim_d )
650
632
if Jx_alphas_n_perms == 1 : # Cached already?
651
- self ._log .debug ('Caching permutations for \' self._Jx_alphas\' ' )
652
- Jx_alphas = self .apply_perms_to_obj (
633
+ self ._Jx_alphas = nn .Parameter (self .apply_perms_to_obj (
653
634
self ._Jx_alphas , perm_idxs = self .perm_idxs
654
- )
655
-
656
- # Do not overwrite before the operation above is successful.
657
- self ._xs_train = nn .Parameter (xs_train , requires_grad = False )
658
- self ._Jx_alphas = nn .Parameter (Jx_alphas , requires_grad = False )
635
+ ), requires_grad = False )
659
636
660
637
def est_mem_requirement (self , return_min = False ):
661
638
"""
@@ -741,15 +718,12 @@ def set_R_d_desc(self, R_d_desc):
741
718
if 'out of memory' in str (e ):
742
719
torch .cuda .empty_cache ()
743
720
744
- self ._log .debug ('Not enough memory to cache \' R_d_desc\' on GPU' )
721
+ self ._log .debug ('Failed to cache \' R_d_desc\' on GPU. ' )
745
722
else :
746
723
raise e
747
724
else :
748
- self ._log .debug ('\' R_d_desc\' lives on the GPU now' )
749
725
self .R_d_desc = R_d_desc
750
726
751
- self .R_d_desc = nn .Parameter (self .R_d_desc , requires_grad = False )
752
-
753
727
def set_alphas (self , alphas , alphas_E = None ):
754
728
"""
755
729
Reconfigure the current model with a new set of regression parameters.
@@ -776,68 +750,75 @@ def set_alphas(self, alphas, alphas_E=None):
776
750
777
751
if alphas_E is not None :
778
752
self ._alphas_E = nn .Parameter (
779
- torch .from_numpy (alphas_E ).to (self .R_d_desc .device ), requires_grad = False
753
+ torch .from_numpy (alphas_E ).to (self ._xs_train .device ), requires_grad = False
780
754
)
781
755
782
756
del self ._Jx_alphas
783
757
while True :
784
758
try :
785
759
786
- alphas_torch = torch .from_numpy (alphas ).to (self .R_d_desc .device )
760
+ alphas_torch = torch .from_numpy (alphas ).to (self .R_d_desc .device ) # Send to whatever device 'R_d_desc' is on, first.
787
761
xs = self .desc .d_desc_dot_vec (
788
762
self .R_d_desc , alphas_torch .reshape (- 1 , self .dim_i )
789
763
)
790
764
del alphas_torch
791
765
766
+ if torch .cuda .is_available () and not xs .is_cuda :
767
+ xs = xs .to (self ._xs_train .device ) # Only now send it to the GPU ('_xs_train' will be for sure, if GPUs are available)
768
+
792
769
except RuntimeError as e :
793
770
if 'out of memory' in str (e ):
794
- if not torch .cuda .is_available ():
795
- self ._log .critical (
796
- 'Not enough CPU memory to cache \' R_d_desc\' ! There nothing we can do...'
797
- )
798
- print ()
799
- os ._exit (1 )
800
- else :
801
- self .R_d_desc = self .R_d_desc .cpu ()
771
+
772
+ if torch .cuda .is_available ():
802
773
torch .cuda .empty_cache ()
803
774
775
+ self .R_d_desc = self .R_d_desc .cpu ()
776
+ #torch.cuda.empty_cache()
777
+
804
778
self ._log .debug (
805
- 'Failed to \' set_alphas()\' on the GPU ( \' R_d_desc\' was moved back from GPU to CPU) '
779
+ 'Failed to \' set_alphas()\' : \' R_d_desc\' was moved back from GPU to CPU'
806
780
)
807
781
808
- pass
782
+ pass
783
+
784
+ else :
785
+
786
+ self ._log .critical (
787
+ 'Not enough memory to cache \' R_d_desc\' ! There nothing we can do...'
788
+ )
789
+ print ()
790
+ os ._exit (1 )
791
+
809
792
else :
810
793
raise e
811
794
else :
812
795
break
813
796
814
797
try :
798
+
815
799
perm_idxs = self .perm_idxs if _n_perm_batches == 1 else None
816
800
self ._Jx_alphas = nn .Parameter (
817
801
self .apply_perms_to_obj (xs , perm_idxs = perm_idxs ), requires_grad = False
818
802
)
819
803
820
804
except RuntimeError as e :
821
805
if 'out of memory' in str (e ):
822
- torch .cuda .empty_cache ()
806
+ if torch .cuda .is_available ():
807
+ torch .cuda .empty_cache ()
823
808
824
809
if _n_perm_batches < self .n_perms :
825
810
826
- self ._log .debug ('Uncaching permutations (within \' set_alphas()\' )' )
811
+ #self._log.debug('Uncaching permutations (within \'set_alphas()\')')
812
+
813
+ self ._log .debug ('Setting permutation batch size to {}{}.' .format (_n_perm_batches , ' (no caching)' if _n_perm_batches > 1 else '' ))
827
814
828
815
_n_perm_batches += 1 # Do NOT change me to use 'self.set_n_perm_batches(_n_perm_batches + 1)'!
829
816
self ._xs_train = nn .Parameter (
830
817
self .remove_perms_from_obj (self ._xs_train ), requires_grad = False
831
- )
818
+ ) # Remove any permutations from 'self._xs_train'.
832
819
self ._Jx_alphas = nn .Parameter (
833
820
self .apply_perms_to_obj (xs , perm_idxs = None ), requires_grad = False
834
- )
835
-
836
- self ._log .debug (
837
- 'Trying {} permutation batches (within \' set_alphas()\' )' .format (
838
- _n_perm_batches
839
- )
840
- )
821
+ ) # Set 'self._Jx_alphas' without applying permutations.
841
822
842
823
else :
843
824
self ._log .critical (
@@ -848,6 +829,7 @@ def set_alphas(self, alphas, alphas_E=None):
848
829
else :
849
830
raise e
850
831
832
+
851
833
def _forward (self , Rs_or_train_idxs , return_E = True ):
852
834
853
835
global _n_perm_batches
@@ -895,6 +877,8 @@ def _forward(self, Rs_or_train_idxs, return_E=True):
895
877
] # ignore permutations
896
878
897
879
Jxs = self .R_d_desc [train_idxs , :, :]
880
+ #if torch.cuda.is_available() and not self.R_d_desc.is_cuda:
881
+ Jxs = Jxs .to (xs .device ) # 'R_d_desc' can live on the CPU, as well.
898
882
899
883
# current:
900
884
# diffs: N, a, a, 3
@@ -1000,6 +984,7 @@ def _forward(self, Rs_or_train_idxs, return_E=True):
1000
984
diffs = torch .zeros (
1001
985
(n , self .n_atoms , self .n_atoms , 3 ), device = xs .device , dtype = xs .dtype
1002
986
)
987
+
1003
988
diffs [:, i , j , :] = Jxs * Fs_x [..., None ]
1004
989
diffs [:, j , i , :] = - diffs [:, i , j , :]
1005
990
@@ -1061,23 +1046,20 @@ def forward(self, Rs_or_train_idxs=None, return_E=True):
1061
1046
)
1062
1047
except RuntimeError as e :
1063
1048
if 'out of memory' in str (e ):
1064
- torch .cuda .empty_cache ()
1049
+ if torch .cuda .is_available ():
1050
+ torch .cuda .empty_cache ()
1065
1051
1066
1052
if _batch_size > 1 :
1067
- _batch_size -= 1
1068
1053
1069
- self ._log .debug ('Trying batch size of {}.' .format (_batch_size ))
1054
+ self ._log .debug ('Setting batch size to {}/{} points.' .format (_batch_size , self .n_train ))
1055
+ _batch_size -= 1
1070
1056
1071
1057
elif _n_perm_batches < self .n_perms :
1072
1058
n_perm_batches = _next_batch_size (
1073
1059
self .n_perms , _n_perm_batches
1074
1060
)
1075
1061
self .set_n_perm_batches (n_perm_batches )
1076
1062
1077
- self ._log .debug (
1078
- 'Trying {} permutation batches.' .format (n_perm_batches )
1079
- )
1080
-
1081
1063
else :
1082
1064
self ._log .critical (
1083
1065
'Could not allocate enough (GPU) memory to evaluate model, despite reducing batch size.'
0 commit comments