
Commit 87567a6

improved GPU memory management
1 parent 4a5ae03 commit 87567a6

2 files changed: +59 -77 lines

sgdml/__init__.py

mode changed 100755 → 100644; +3 -3
@@ -22,7 +22,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-__version__ = '0.5.2'
+__version__ = '0.5.3.dev0'

 MAX_PRINT_WIDTH = 100
 LOG_LEVELNAME_WIDTH = 7  # do not modify
@@ -108,8 +108,8 @@ def __init__(self, name):
         hd = logging.StreamHandler()
         hd.setFormatter(formatter)
         hd.setLevel(
-            logging.DEBUG
-        )  # control logging level here (default: logging.DEBUG)
+            logging.INFO
+        )  # control logging level here

         self.addHandler(hd)
         return
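
Note: the default handler level moves from DEBUG to INFO here. As a hedged aside (not part of this commit), a caller could still get DEBUG output at runtime by lowering the level of the already-attached handler; the logger name 'sgdml' below is an assumption, not something shown in this diff.

import logging

# Illustrative sketch: restore DEBUG verbosity after the handler default was raised to INFO.
log = logging.getLogger('sgdml')  # assumed logger name
log.setLevel(logging.DEBUG)
for handler in log.handlers:
    handler.setLevel(logging.DEBUG)
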

sgdml/torchtools.py

+56 -74
@@ -348,7 +348,8 @@ def forward(self, J_indx):
                 done = self._forward(self.J[i])
             except RuntimeError as e:
                 if 'out of memory' in str(e):
-                    torch.cuda.empty_cache()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()

                     if _n_batches < self.n_train:
                         _n_batches = _next_batch_size(
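
Note: this hunk adds a torch.cuda.is_available() guard before empty_cache(), so the same OOM-recovery path also runs on CPU-only hosts. A standalone sketch of the retry pattern this commit applies throughout (illustrative names, not code from the repository):

import torch

def run_with_oom_backoff(step, batch_size):
    # 'step' stands in for any callable that raises RuntimeError('... out of memory ...')
    # when the current batch does not fit; all names here are hypothetical.
    while True:
        try:
            return step(batch_size)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # guarded, so CPU-only hosts skip the CUDA call
            if batch_size <= 1:
                raise  # nothing left to shrink
            batch_size -= 1  # back off and retry with a smaller batch
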
@@ -478,8 +479,6 @@ def __init__(

         self.tril_indices = np.tril_indices(self.n_atoms, k=-1)

-        self.R_d_desc = None
-
         if torch.cuda.is_available():  # Ignore limits and take whatever the GPU has.
             max_memory = (
                 min(
@@ -522,39 +521,30 @@ def __init__(

         self.max_processes = max_processes

-        self.perm_idxs = (
-            torch.tensor(model['tril_perms_lin']).view(-1, self.n_perms).t()
-        )
-
-        self._xs_train = nn.Parameter(
-            self.apply_perms_to_obj(torch.tensor(model['R_desc']).t(), perm_idxs=None),
-            requires_grad=False,
-        )
-        self._Jx_alphas = nn.Parameter(
-            self.apply_perms_to_obj(
-                torch.tensor(np.array(model['R_d_desc_alpha'])), perm_idxs=None
-            ),
-            requires_grad=False,
-        )
+        self.R_d_desc = None
+        self._xs_train = nn.Parameter(torch.tensor(model['R_desc']).t(), requires_grad=False)
+        self._Jx_alphas = nn.Parameter(torch.tensor(np.array(model['R_d_desc_alpha'])), requires_grad=False)

         self._alphas_E = None
         if 'alphas_E' in model:
             self._alphas_E = nn.Parameter(
                 torch.from_numpy(model['alphas_E']), requires_grad=False
             )

+        self.perm_idxs = (
+            torch.tensor(model['tril_perms_lin']).view(-1, self.n_perms).t()
+        )
+
         # Try to cache all permutated variants of 'self._xs_train' and 'self._Jx_alphas'
         try:
             self.set_n_perm_batches(n_perm_batches)
         except RuntimeError as e:
             if 'out of memory' in str(e):
-                torch.cuda.empty_cache()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()

                 if n_perm_batches == 1:
-                    self._log.debug(
-                        'Trying to cache permutations FAILED during init (continuing without)'
-                    )
-                    self.set_n_perm_batches(2)
+                    self.set_n_perm_batches(2)  # Set to 2 perm batches, because that's the first batch size (and fastest) that is not cached.
                     pass
                 else:
                     self._log.critical(
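
Note: the constructor now stores '_xs_train' and '_Jx_alphas' without the permutation expansion and applies it afterwards via set_n_perm_batches(), falling back to two permutation batches if caching does not fit. A hedged sketch of what that expansion amounts to (the repository's apply_perms_to_obj may differ in detail):

import torch

def expand_over_perms(x, perm_idxs):
    # x: (n_train, dim_d) descriptors; perm_idxs: (n_perms, dim_d) index table.
    # Caching multiplies memory use by n_perms, which is why it is the first thing
    # dropped when the GPU runs out of memory. Illustrative only.
    return x[:, perm_idxs].reshape(-1, perm_idxs.shape[1])  # -> (n_train * n_perms, dim_d)
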
@@ -577,11 +567,8 @@ def __init__(
                 else self.n_train
             )
             _batch_size = min(_batch_size, max_batch_size)
-            self._log.debug(
-                'Starting with a batch size of {} ({} points in total).'.format(
-                    _batch_size, self.n_train
-                )
-            )
+
+            self._log.debug('Setting batch size to {}/{} points.'.format(_batch_size, self.n_train))

         self.desc = Desc(self.n_atoms, max_processes=max_processes)

@@ -594,9 +581,7 @@ def set_n_perm_batches(self, n_perm_batches):

         global _n_perm_batches

-        self._log.debug(
-            'Permutations will be generated in {} batches.'.format(n_perm_batches)
-        )
+        self._log.debug('Setting permutation batch size to {}{}.'.format(n_perm_batches, ' (no caching)' if n_perm_batches > 1 else ''))

         _n_perm_batches = n_perm_batches
         if n_perm_batches == 1 and self.n_perms > 1:
@@ -627,14 +612,12 @@ def uncache_perms(self):

         xs_train_n_perms = self._xs_train.numel() // (self.n_train * self.dim_d)
         if xs_train_n_perms != 1:  # Uncached already?
-            self._log.debug('Uncaching permutations for \'self._xs_train\'')
             self._xs_train = nn.Parameter(
                 self.remove_perms_from_obj(self._xs_train), requires_grad=False
             )

         Jx_alphas_n_perms = self._Jx_alphas.numel() // (self.n_train * self.dim_d)
         if Jx_alphas_n_perms != 1:  # Uncached already?
-            self._log.debug('Uncaching permutations for \'self._Jx_alphas\'')
             self._Jx_alphas = nn.Parameter(
                 self.remove_perms_from_obj(self._Jx_alphas), requires_grad=False
             )
@@ -643,19 +626,13 @@ def cache_perms(self):

         xs_train_n_perms = self._xs_train.numel() // (self.n_train * self.dim_d)
         if xs_train_n_perms == 1:  # Cached already?
-            self._log.debug('Caching permutations for \'self._xs_train\'')
-            xs_train = self.apply_perms_to_obj(self._xs_train, perm_idxs=self.perm_idxs)
+            self._xs_train = nn.Parameter(self.apply_perms_to_obj(self._xs_train, perm_idxs=self.perm_idxs), requires_grad=False)

         Jx_alphas_n_perms = self._Jx_alphas.numel() // (self.n_train * self.dim_d)
         if Jx_alphas_n_perms == 1:  # Cached already?
-            self._log.debug('Caching permutations for \'self._Jx_alphas\'')
-            Jx_alphas = self.apply_perms_to_obj(
+            self._Jx_alphas = nn.Parameter(self.apply_perms_to_obj(
                 self._Jx_alphas, perm_idxs=self.perm_idxs
-            )
-
-            # Do not overwrite before the operation above is successful.
-            self._xs_train = nn.Parameter(xs_train, requires_grad=False)
-            self._Jx_alphas = nn.Parameter(Jx_alphas, requires_grad=False)
+            ), requires_grad=False)

     def est_mem_requirement(self, return_min=False):
         """
@@ -741,15 +718,12 @@ def set_R_d_desc(self, R_d_desc):
             if 'out of memory' in str(e):
                 torch.cuda.empty_cache()

-                self._log.debug('Not enough memory to cache \'R_d_desc\' on GPU')
+                self._log.debug('Failed to cache \'R_d_desc\' on GPU.')
             else:
                 raise e
         else:
-            self._log.debug('\'R_d_desc\' lives on the GPU now')
             self.R_d_desc = R_d_desc

-        self.R_d_desc = nn.Parameter(self.R_d_desc, requires_grad=False)
-
     def set_alphas(self, alphas, alphas_E=None):
         """
         Reconfigure the current model with a new set of regression parameters.
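
Note: set_R_d_desc() no longer wraps 'R_d_desc' in an nn.Parameter and simply keeps it wherever it fits. A hedged sketch of that "GPU if possible, CPU otherwise" placement policy (illustrative, not the repository's exact code):

import torch

def place_preferring_gpu(tensor):
    # Try the GPU first; on OOM, release cached blocks and keep the tensor on the CPU.
    # Consumers then move only the slices they need to the GPU (see the '_forward' hunk below).
    if not torch.cuda.is_available():
        return tensor
    try:
        return tensor.cuda()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            torch.cuda.empty_cache()
            return tensor
        raise
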
@@ -776,68 +750,75 @@ def set_alphas(self, alphas, alphas_E=None):

         if alphas_E is not None:
             self._alphas_E = nn.Parameter(
-                torch.from_numpy(alphas_E).to(self.R_d_desc.device), requires_grad=False
+                torch.from_numpy(alphas_E).to(self._xs_train.device), requires_grad=False
             )

         del self._Jx_alphas
         while True:
             try:

-                alphas_torch = torch.from_numpy(alphas).to(self.R_d_desc.device)
+                alphas_torch = torch.from_numpy(alphas).to(self.R_d_desc.device)  # Send to whatever device 'R_d_desc' is on, first.
                 xs = self.desc.d_desc_dot_vec(
                     self.R_d_desc, alphas_torch.reshape(-1, self.dim_i)
                 )
                 del alphas_torch

+                if torch.cuda.is_available() and not xs.is_cuda:
+                    xs = xs.to(self._xs_train.device)  # Only now send it to the GPU ('_xs_train' will be for sure, if GPUs are available)
+
             except RuntimeError as e:
                 if 'out of memory' in str(e):
-                    if not torch.cuda.is_available():
-                        self._log.critical(
-                            'Not enough CPU memory to cache \'R_d_desc\'! There nothing we can do...'
-                        )
-                        print()
-                        os._exit(1)
-                    else:
-                        self.R_d_desc = self.R_d_desc.cpu()
+
+                    if torch.cuda.is_available():
                         torch.cuda.empty_cache()

+                        self.R_d_desc = self.R_d_desc.cpu()
+                        #torch.cuda.empty_cache()
+
                         self._log.debug(
-                            'Failed to \'set_alphas()\' on the GPU (\'R_d_desc\' was moved back from GPU to CPU)'
+                            'Failed to \'set_alphas()\': \'R_d_desc\' was moved back from GPU to CPU'
                         )

-                        pass
+                        pass
+
+                    else:
+
+                        self._log.critical(
+                            'Not enough memory to cache \'R_d_desc\'! There nothing we can do...'
+                        )
+                        print()
+                        os._exit(1)
+
                 else:
                     raise e
             else:
                 break

         try:
+
             perm_idxs = self.perm_idxs if _n_perm_batches == 1 else None
             self._Jx_alphas = nn.Parameter(
                 self.apply_perms_to_obj(xs, perm_idxs=perm_idxs), requires_grad=False
             )

         except RuntimeError as e:
             if 'out of memory' in str(e):
-                torch.cuda.empty_cache()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()

                 if _n_perm_batches < self.n_perms:

-                    self._log.debug('Uncaching permutations (within \'set_alphas()\')')
+                    #self._log.debug('Uncaching permutations (within \'set_alphas()\')')
+
+                    self._log.debug('Setting permutation batch size to {}{}.'.format(_n_perm_batches, ' (no caching)' if _n_perm_batches > 1 else ''))

                     _n_perm_batches += 1  # Do NOT change me to use 'self.set_n_perm_batches(_n_perm_batches + 1)'!
                     self._xs_train = nn.Parameter(
                         self.remove_perms_from_obj(self._xs_train), requires_grad=False
-                    )
+                    )  # Remove any permutations from 'self._xs_train'.
                     self._Jx_alphas = nn.Parameter(
                         self.apply_perms_to_obj(xs, perm_idxs=None), requires_grad=False
-                    )
-
-                    self._log.debug(
-                        'Trying {} permutation batches (within \'set_alphas()\')'.format(
-                            _n_perm_batches
-                        )
-                    )
+                    )  # Set 'self._Jx_alphas' without applying permutations.

                 else:
                     self._log.critical(
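
Note: within set_alphas(), the old '_Jx_alphas' parameter is deleted before its replacement is built, and on OOM 'R_d_desc' is demoted to the CPU instead of aborting (as long as a GPU exists). A small sketch of that delete-before-reallocate pattern on an nn.Module, with illustrative names only:

import torch
import torch.nn as nn

def replace_parameter(module, name, new_tensor):
    # Drop the old parameter first so the freed GPU blocks can be reused for the new one.
    if hasattr(module, name):
        delattr(module, name)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    setattr(module, name, nn.Parameter(new_tensor, requires_grad=False))
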
@@ -848,6 +829,7 @@ def set_alphas(self, alphas, alphas_E=None):
             else:
                 raise e

+
     def _forward(self, Rs_or_train_idxs, return_E=True):

         global _n_perm_batches
@@ -895,6 +877,8 @@ def _forward(self, Rs_or_train_idxs, return_E=True):
             ]  # ignore permutations

         Jxs = self.R_d_desc[train_idxs, :, :]
+        #if torch.cuda.is_available() and not self.R_d_desc.is_cuda:
+        Jxs = Jxs.to(xs.device)  # 'R_d_desc' can live on the CPU, as well.

         # current:
         # diffs: N, a, a, 3
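
Note: because 'R_d_desc' may now stay on the CPU, '_forward()' explicitly moves the rows it needs onto the evaluation device. A minimal illustration of that per-batch transfer (tensor shapes are placeholders, not the model's real dimensions):

import torch

R_d_desc = torch.randn(500, 66, 3)             # placeholder: per-point Jacobians kept on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_idxs = torch.arange(32)                  # one evaluation batch
Jxs = R_d_desc[train_idxs].to(device)          # copy only this batch to the GPU
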
@@ -1000,6 +984,7 @@ def _forward(self, Rs_or_train_idxs, return_E=True):
         diffs = torch.zeros(
             (n, self.n_atoms, self.n_atoms, 3), device=xs.device, dtype=xs.dtype
         )
+
         diffs[:, i, j, :] = Jxs * Fs_x[..., None]
         diffs[:, j, i, :] = -diffs[:, i, j, :]

@@ -1061,23 +1046,20 @@ def forward(self, Rs_or_train_idxs=None, return_E=True):
                 )
             except RuntimeError as e:
                 if 'out of memory' in str(e):
-                    torch.cuda.empty_cache()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()

                     if _batch_size > 1:
-                        _batch_size -= 1

-                        self._log.debug('Trying batch size of {}.'.format(_batch_size))
+                        self._log.debug('Setting batch size to {}/{} points.'.format(_batch_size, self.n_train))
+                        _batch_size -= 1

                     elif _n_perm_batches < self.n_perms:
                         n_perm_batches = _next_batch_size(
                             self.n_perms, _n_perm_batches
                         )
                         self.set_n_perm_batches(n_perm_batches)

-                        self._log.debug(
-                            'Trying {} permutation batches.'.format(n_perm_batches)
-                        )
-
                     else:
                         self._log.critical(
                             'Could not allocate enough (GPU) memory to evaluate model, despite reducing batch size.'