From 29eeea68f013beae1e05099b4f8d6598083c9be6 Mon Sep 17 00:00:00 2001 From: stefanch <41121473+stefanch@users.noreply.github.com> Date: Thu, 20 Oct 2022 16:28:13 +0200 Subject: [PATCH] bugfix --- README.md | 11 +- sgdml/__init__.py | 2 +- sgdml/cli.py | 92 +++++++++++-- sgdml/predict.py | 46 +++++-- sgdml/torchtools.py | 321 ++++++++++++++++++++++++-------------------- sgdml/train.py | 20 ++- 6 files changed, 319 insertions(+), 173 deletions(-) diff --git a/README.md b/README.md index f8248f1..8ba159e 100644 --- a/README.md +++ b/README.md @@ -115,12 +115,17 @@ We appreciate and welcome contributions and would like to thank the following pe Science Advances, 3(5), e1603015 (2017) [10.1126/sciadv.1603015](http://dx.doi.org/10.1126/sciadv.1603015) -* [2] Chmiela, S., Sauceda, H. E., Müller, K.-R., & Tkatchenko, A., +* [2] Chmiela, S., Sauceda, H. E., Müller, K.-R., Tkatchenko, A., *Towards Exact Molecular Dynamics Simulations with Machine-Learned Force Fields.* Nature Communications, 9(1), 3887 (2018) [10.1038/s41467-018-06169-2](https://doi.org/10.1038/s41467-018-06169-2) -* [3] Chmiela, S., Sauceda, H. E., Poltavsky, I., Müller, K.-R., & Tkatchenko, A., +* [3] Chmiela, S., Sauceda, H. E., Poltavsky, I., Müller, K.-R., Tkatchenko, A., *sGDML: Constructing Accurate and Data Efficient Molecular Force Fields Using Machine Learning.* Computer Physics Communications, 240, 38-45 (2019) -[10.1016/j.cpc.2019.02.007](https://doi.org/10.1016/j.cpc.2019.02.007) \ No newline at end of file +[10.1016/j.cpc.2019.02.007](https://doi.org/10.1016/j.cpc.2019.02.007) + +* [4] Chmiela, S., Vassilev-Galindo, V., Unke, O. T., Kabylda, A., Sauceda, H. E., Tkatchenko, A., Müller, K.-R., +*Accurate global machine learning force fields for molecules with hundreds of atoms* +Preprint (2022) +[arXiv:2209.14865](https://arxiv.org/abs/2209.14865) \ No newline at end of file diff --git a/sgdml/__init__.py b/sgdml/__init__.py index f553b48..4c2e3d1 100644 --- a/sgdml/__init__.py +++ b/sgdml/__init__.py @@ -22,7 +22,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -__version__ = '0.5.3.dev0' +__version__ = '0.5.3' MAX_PRINT_WIDTH = 100 LOG_LEVELNAME_WIDTH = 7 # do not modify diff --git a/sgdml/cli.py b/sgdml/cli.py index af920af..73af1d2 100644 --- a/sgdml/cli.py +++ b/sgdml/cli.py @@ -45,6 +45,17 @@ else: _has_torch = True +try: + _torch_mps_is_available = torch.backends.mps.is_available() +except AttributeError: + _torch_mps_is_available = False +_torch_mps_is_available = False + +try: + _torch_cuda_is_available = torch.cuda.is_available() +except AttributeError: + _torch_cuda_is_available = False + try: import ase except ImportError: @@ -94,10 +105,14 @@ def _print_splash(max_memory, max_processes, use_torch): max_processes_str = '{:d} CPU(s)'.format(max_processes) hardware_str = 'using {}, {}'.format(max_memory_str, max_processes_str) - if use_torch and _has_torch and torch.cuda.is_available(): - num_gpu = torch.cuda.device_count() - if num_gpu > 0: - hardware_str += ', {:d} GPU(s)'.format(num_gpu) + if use_torch and _has_torch: + + if _torch_cuda_is_available: + num_gpu = torch.cuda.device_count() + if num_gpu > 0: + hardware_str += ', {:d} GPU(s)'.format(num_gpu) + elif _torch_mps_is_available: + hardware_str += ', MPS enabled' logo_str_split = logo_str.splitlines() print('\n'.join(logo_str_split[:-1])) @@ -126,6 +141,9 @@ def _print_splash(max_memory, max_processes, use_torch): ) + _print_billboard() + + def _check_update(): try: @@ -133,8 +151,8 @@ def _check_update(): except ImportError: from urllib2 import urlopen - base_url = 'http://www.quantum-machine.org/gdml/' - url = '%supdate.php?v=%s' % (base_url, __version__) + base_url = 'http://api.sgdml.org/' + url = '{}update.php?v={}'.format(base_url, __version__) can_update, must_update = '0', '0' latest_version = '' @@ -148,6 +166,60 @@ def _check_update(): return can_update == '1', latest_version +def _print_billboard(): + + try: + from urllib.request import urlopen + except ImportError: + from urllib2 import urlopen + + base_url = 'http://api.sgdml.org/' + url = '{}billboard.php'.format(base_url) + + resp_str = '' + try: + response = urlopen(url, timeout=1) + resp_str = response.read().decode() + response.close() + except: + pass + + bbs = None + try: + import json + bbs = json.loads(resp_str) + except: + pass + + if bbs is not None: + + for bb in bbs: + + back_color = ui.WHITE + if bb['color'] == 'YELLOW': + back_color = ui.YELLOW + elif bb['color'] == 'GREEN': + back_color = ui.GREEN + elif bb['color'] == 'RED': + back_color = ui.RED + elif bb['color'] == 'CYAN': + back_color = ui.CYAN + + print( + '\n' + + ui.color_str( + ' {} '.format(bb['title']), + fore_color=ui.BLACK, + back_color=back_color, + bold=True, + ) + + '\n' + + '-' * MAX_PRINT_WIDTH + ) + + print(ui.wrap_str(bb['text'], width=MAX_PRINT_WIDTH - 2)) + + def _print_dataset_properties(dataset, title_str='Dataset properties'): print(ui.color_str(title_str, bold=True)) @@ -1692,11 +1764,11 @@ def test( if model['use_E']: model['e_err'] = { - 'mae': np.asscalar(e_mae), - 'rmse': np.asscalar(e_rmse), + 'mae': e_mae.item(), + 'rmse': e_rmse.item(), } - model['f_err'] = {'mae': np.asscalar(f_mae), 'rmse': np.asscalar(f_rmse)} + model['f_err'] = {'mae':f_mae.item(), 'rmse': f_rmse.item()} np.savez_compressed(model_path, **model) if is_test and model['n_test'] > 0: @@ -2191,7 +2263,7 @@ def _add_argument_dir_with_file_type(parser, type, or_file=False): # Check PyTorch GPU support. if ('use_torch' in args and args.use_torch) or 'use_torch' not in args: if _has_torch: - if not torch.cuda.is_available(): + if not (_torch_cuda_is_available or _torch_mps_is_available): print() # TODO: print only if log level includes warning log.warning( 'Your PyTorch installation does not see any GPU(s) on your system and will thus run all calculations on the CPU! If this is what you want, we recommend bypassing PyTorch using \'--cpu\' for improved performance.' diff --git a/sgdml/predict.py b/sgdml/predict.py index f85b213..4768226 100644 --- a/sgdml/predict.py +++ b/sgdml/predict.py @@ -45,6 +45,17 @@ else: _has_torch = True +try: + _torch_mps_is_available = torch.backends.mps.is_available() +except AttributeError: + _torch_mps_is_available = False +_torch_mps_is_available = False + +try: + _torch_cuda_is_available = torch.cuda.is_available() +except AttributeError: + _torch_cuda_is_available = False + import numpy as np from . import __version__ @@ -367,15 +378,22 @@ def __init__( self.torch_predict = torch.nn.DataParallel(self.torch_predict) # Send model to device - self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu' + #self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu' + if _torch_cuda_is_available: + self.torch_device = 'cuda' + elif _torch_mps_is_available: + self.torch_device = 'mps' + else: + self.torch_device = 'cpu' + while True: try: self.torch_predict.to(self.torch_device) except RuntimeError as e: if 'out of memory' in str(e): - torch.cuda.empty_cache() - print('sending to device -> fail (trying again)') + if _torch_cuda_is_available: + torch.cuda.empty_cache() model = self.torch_predict if isinstance(self.torch_predict, torch.nn.DataParallel): @@ -387,9 +405,9 @@ def __init__( model.set_n_perm_batches( model.get_n_perm_batches() + 1 ) # uncache - self.torch_predict.to( - self.torch_device - ) # try sending to device again + #self.torch_predict.to( # NOTE! + # self.torch_device + #) # try sending to device again pass else: self.log.critical( @@ -1120,7 +1138,7 @@ def get_GPU_batch(self): if self.use_torch: model = self.torch_predict - if isinstance(self.torch_predict, torch.nn.DataParallel): + if isinstance(model, torch.nn.DataParallel): model = model.module return model._batch_size() @@ -1176,11 +1194,14 @@ def predict(self, R=None, return_E=True): print() os._exit(1) else: - R_torch = torch.from_numpy(R.reshape(-1, self.n_atoms, 3)).to( + R_torch = torch.from_numpy(R.reshape(-1, self.n_atoms, 3)).type(torch.float32).to( self.torch_device ) - E_torch_F_torch = self.torch_predict.forward(R_torch, return_E=return_E) + model = self.torch_predict + if R_torch.shape[0] < torch.cuda.device_count() and isinstance(model, torch.nn.DataParallel): + model = self.torch_predict.module + E_torch_F_torch = model.forward(R_torch, return_E=return_E) if return_E: E_torch, F_torch = E_torch_F_torch @@ -1258,9 +1279,10 @@ def predict(self, R=None, return_E=True): ) ) - E_F *= self.std - F = E_F[:, 1:] - E = E_F[:, 0] + self.c + if R is not None: # Not in train mode. TODO: better set y_std to zero + E_F *= self.std + F = E_F[:, 1:] + E = E_F[:, 0] + self.c ret = (F,) if return_E: diff --git a/sgdml/torchtools.py b/sgdml/torchtools.py index d137325..fb2cdce 100644 --- a/sgdml/torchtools.py +++ b/sgdml/torchtools.py @@ -31,9 +31,23 @@ import torch.nn as nn from torch.utils.data import DataLoader +try: + _torch_mps_is_available = torch.backends.mps.is_available() +except AttributeError: + _torch_mps_is_available = False +_torch_mps_is_available = False + +try: + _torch_cuda_is_available = torch.cuda.is_available() +except AttributeError: + _torch_cuda_is_available = False + + from .utils.desc import Desc from .utils import ui +_dtype = torch.float64 + def _next_batch_size(n_total, batch_size): @@ -76,12 +90,12 @@ def __init__( self.sig = float(sig) self.tril_perms_lin = tril_perms_lin - self.n_perms = int(len(self.tril_perms_lin) / self.dim_d) + self.n_perms = len(self.tril_perms_lin) // self.dim_d self.use_E_cstr = use_E_cstr - self.R_desc_torch = nn.Parameter(R_desc_torch, requires_grad=False) - self.R_d_desc_torch = nn.Parameter(R_d_desc_torch, requires_grad=False) + self.R_desc_torch = nn.Parameter(R_desc_torch.type(_dtype), requires_grad=False) + self.R_d_desc_torch = nn.Parameter(R_d_desc_torch.type(_dtype), requires_grad=False) self._desc = Desc(self.n_atoms) @@ -121,10 +135,7 @@ def _forward( ) keep_idxs_3n = slice(None) # same as [:] - mat52_base_div = 3 * self.sig ** 4 - sqrt5 = np.sqrt(5.0) - sig_pow2 = self.sig ** 2 - + q = np.sqrt(5) / self.sig if ( j < self.n_train @@ -166,42 +177,40 @@ def _forward( for i_batch in np.array_split(np.arange(self.n_train), _n_batches): - diff_ab_perms_torch = ( + x_diffs = q * ( self.R_desc_torch[i_batch, None, :] - rj_desc_perms_torch[None, :, :] ) # N, n_perms, d - norm_ab_perms_torch = sqrt5 * diff_ab_perms_torch.norm( - dim=-1 - ) # N, n_perms - mat52_base_perms_torch = ( - torch.exp(-norm_ab_perms_torch / self.sig) - / mat52_base_div - * 5 - ) # N, n_perms + x_dists = x_diffs.norm(dim=-1) # N, n_perms + + exp_xs = torch.exp(-x_dists) * (q ** 2) / 3 # N, n_perms + exp_xs_1_x_dists = exp_xs * (1 + x_dists) # N, n_perms*N_train + + del x_dists # E_cstr diff_ab_outer_perms_torch = torch.einsum( '...ki,...kj->...ij', # (slow) - #'...ij,...ik->...jk', #(fast) - diff_ab_perms_torch - * mat52_base_perms_torch[:, :, None] - * 5, # N, n_perms, d + x_diffs * exp_xs[:, :, None], # N, n_perms, d torch.einsum( '...ki,jik -> ...kj', - diff_ab_perms_torch, + x_diffs, rj_d_desc_perms_torch, ), # N, n_perms, a*3 ) # N, n_perms, a*3 - # del diff_ab_perms_torch # E_cstr + del exp_xs + + if not self.use_E_cstr: + del x_diffs diff_ab_outer_perms_torch -= torch.einsum( 'ikj,...j->...ki', rj_d_desc_perms_torch, - (sig_pow2 + self.sig * norm_ab_perms_torch) - * mat52_base_perms_torch, + exp_xs_1_x_dists, ) - # del norm_ab_perms_torch # E_cstr - del mat52_base_perms_torch + + if not self.use_E_cstr: + del exp_xs_1_x_dists R_d_desc_decomp_torch = self._desc.d_desc_from_comp( self.R_d_desc_torch[i_batch, :, :] @@ -231,13 +240,9 @@ def _forward( # First derivative constraints if self.use_E_cstr: - K_fe = ( - 5 - * diff_ab_perms_torch - / (3 * self.sig ** 3) - * (norm_ab_perms_torch[:, :, None] + self.sig) - * torch.exp(-norm_ab_perms_torch / self.sig)[:, :, None] - ) + K_fe = (x_diffs / q) * exp_xs_1_x_dists[:, :, None] + del x_diffs + del exp_xs_1_x_dists K_fe = -torch.einsum( '...ik,jki -> ...j', K_fe, rj_d_desc_perms_torch @@ -245,7 +250,9 @@ def _forward( E_off_i = self.n_train * self.dim_i i_batch_off = i_batch + E_off_i - self.out[i_batch_off[0] : (i_batch_off[-1] + 1), blk_j] = K_fe.cpu().numpy() + self.out[ + i_batch_off[0] : (i_batch_off[-1] + 1), blk_j + ] = K_fe.cpu().numpy() del rj_desc_perms_torch del rj_d_desc_perms_torch @@ -267,9 +274,7 @@ def _forward( n_perms_batch = len(perm_batch) n_perms_done += n_perms_batch - for i_batch in np.array_split( - np.arange(self.n_train), _n_batches - ): + for i_batch in np.array_split(np.arange(self.n_train), _n_batches): ri_desc_perms_torch = torch.reshape( torch.tile( @@ -289,23 +294,19 @@ def _forward( ], (len(i_batch), self.dim_d, n_perms_batch, -1), ) - #del ri_d_desc_decomp_torch + # del ri_d_desc_decomp_torch - diff_ab_perms_torch = ( + x_diffs = q * ( self.R_desc_torch[j % self.n_train, None, :, None] - ri_desc_perms_torch ) - #norm_ab_perms_torch = sqrt5 * diff_ab_perms_torch.norm(dim=-1) - norm_ab_perms_torch = sqrt5 / self.sig * diff_ab_perms_torch.norm(dim=1) + x_dists = x_diffs.norm(dim=1) - K_fe = ( - 5 - * diff_ab_perms_torch - / (3 * self.sig ** 2) - * (norm_ab_perms_torch[:, None, :] + 1) - * torch.exp(-norm_ab_perms_torch)[:, None, :] - ) + exp_xs = torch.exp(-x_dists) * (q ** 2) / 3 + exp_xs_1_x_dists = exp_xs * (1 + x_dists) + + K_fe = x_diffs / q * exp_xs_1_x_dists[:, None, :] K_fe = -torch.einsum( '...ik,...ikj -> ...j', K_fe, ri_d_desc_perms_torch ).ravel() @@ -313,14 +314,14 @@ def _forward( k_ee = -torch.einsum( '...i,...i -> ...', - 1 - + (norm_ab_perms_torch) - * (1 + norm_ab_perms_torch / 3), - torch.exp(-norm_ab_perms_torch), + 1 + x_dists * (1 + x_dists / 3), + torch.exp(-x_dists), ) k_ee = k_ee.cpu().numpy() - E_off_i = (self.n_train * self.dim_i) # Account for 'alloc_extra_rows'!. + E_off_i = ( + self.n_train * self.dim_i + ) # Account for 'alloc_extra_rows'!. blk_i_full = slice( i_batch[0] * self.dim_i, (i_batch[-1] + 1) * self.dim_i ) @@ -334,7 +335,6 @@ def _forward( self.out[E_off_i + i_batch, K_j] = ( self.out[E_off_i + i_batch, K_j] + k_ee ) - # del k return blk_j.stop - blk_j.start @@ -348,13 +348,11 @@ def forward(self, J_indx): done = self._forward(self.J[i]) except RuntimeError as e: if 'out of memory' in str(e): - if torch.cuda.is_available(): + if _torch_cuda_is_available: torch.cuda.empty_cache() if _n_batches < self.n_train: - _n_batches = _next_batch_size( - self.n_train, _n_batches - ) + _n_batches = _next_batch_size(self.n_train, _n_batches) self._log.debug( 'Assembling each kernel column in {} batches, i.e. {} points/batch ({} points in total).'.format( @@ -479,7 +477,7 @@ def __init__( self.tril_indices = np.tril_indices(self.n_atoms, k=-1) - if torch.cuda.is_available(): # Ignore limits and take whatever the GPU has. + if _torch_cuda_is_available: # Ignore limits and take whatever the GPU has. max_memory = ( min( [ @@ -489,7 +487,7 @@ def __init__( ) // 2 ** 30 ) # bytes to GB - else: + else: # TODO: what about MPS? default_cpu_max_mem = 32 if max_memory is None: self._log.warning( @@ -512,7 +510,7 @@ def __init__( ) log_type( '{} memory report: max./avail. {}, min. req. (const./per-sample) ~{}/~{}'.format( - 'GPU' if torch.cuda.is_available() else 'CPU', + 'GPU' if (_torch_cuda_is_available or _torch_mps_is_available) else 'CPU', ui.gen_memory_str(max_memory), ui.gen_memory_str(min_const_mem), ui.gen_memory_str(min_per_sample_mem), @@ -522,29 +520,40 @@ def __init__( self.max_processes = max_processes self.R_d_desc = None - self._xs_train = nn.Parameter(torch.tensor(model['R_desc']).t(), requires_grad=False) - self._Jx_alphas = nn.Parameter(torch.tensor(np.array(model['R_d_desc_alpha'])), requires_grad=False) + self._xs_train = nn.Parameter( + torch.tensor(model['R_desc'], dtype=_dtype).t(), requires_grad=False + ) + self._Jx_alphas = nn.Parameter( + torch.tensor(np.array(model['R_d_desc_alpha']), dtype=_dtype), requires_grad=False + ) self._alphas_E = None if 'alphas_E' in model: self._alphas_E = nn.Parameter( - torch.from_numpy(model['alphas_E']), requires_grad=False + torch.from_numpy(model['alphas_E'], dtype=_dtype), requires_grad=False ) self.perm_idxs = ( - torch.tensor(model['tril_perms_lin']).view(-1, self.n_perms).t() + torch.tensor(model['tril_perms_lin'], dtype=torch.long).view(-1, self.n_perms).t() ) + i, j = self.tril_indices + self.register_buffer('agg_mat', torch.zeros((self.n_atoms, self.dim_d), dtype=torch.int8)) + self.agg_mat[i, range(self.dim_d)] = -1 + self.agg_mat[j, range(self.dim_d)] = 1 + # Try to cache all permutated variants of 'self._xs_train' and 'self._Jx_alphas' try: self.set_n_perm_batches(n_perm_batches) except RuntimeError as e: if 'out of memory' in str(e): - if torch.cuda.is_available(): + if _torch_cuda_is_available: torch.cuda.empty_cache() if n_perm_batches == 1: - self.set_n_perm_batches(2) # Set to 2 perm batches, because that's the first batch size (and fastest) that is not cached. + self.set_n_perm_batches( + 2 + ) # Set to 2 perm batches, because that's the first batch size (and fastest) that is not cached. pass else: self._log.critical( @@ -563,12 +572,14 @@ def __init__( ) max_batch_size = ( self.n_train // torch.cuda.device_count() - if torch.cuda.is_available() + if _torch_cuda_is_available else self.n_train ) _batch_size = min(_batch_size, max_batch_size) - self._log.debug('Setting batch size to {}/{} points.'.format(_batch_size, self.n_train)) + self._log.debug( + 'Setting batch size to {}/{} points.'.format(_batch_size, self.n_train) + ) self.desc = Desc(self.n_atoms, max_processes=max_processes) @@ -581,7 +592,11 @@ def set_n_perm_batches(self, n_perm_batches): global _n_perm_batches - self._log.debug('Setting permutation batch size to {}{}.'.format(n_perm_batches, ' (no caching)' if n_perm_batches > 1 else '')) + self._log.debug( + 'Setting permutation batch size to {}/{}{}.'.format( + self.n_perms // n_perm_batches, self.n_perms, ' (no caching)' if n_perm_batches > 1 else '' + ) + ) _n_perm_batches = n_perm_batches if n_perm_batches == 1 and self.n_perms > 1: @@ -626,13 +641,17 @@ def cache_perms(self): xs_train_n_perms = self._xs_train.numel() // (self.n_train * self.dim_d) if xs_train_n_perms == 1: # Cached already? - self._xs_train = nn.Parameter(self.apply_perms_to_obj(self._xs_train, perm_idxs=self.perm_idxs), requires_grad=False) + self._xs_train = nn.Parameter( + self.apply_perms_to_obj(self._xs_train, perm_idxs=self.perm_idxs), + requires_grad=False, + ) Jx_alphas_n_perms = self._Jx_alphas.numel() // (self.n_train * self.dim_d) if Jx_alphas_n_perms == 1: # Cached already? - self._Jx_alphas = nn.Parameter(self.apply_perms_to_obj( - self._Jx_alphas, perm_idxs=self.perm_idxs - ), requires_grad=False) + self._Jx_alphas = nn.Parameter( + self.apply_perms_to_obj(self._Jx_alphas, perm_idxs=self.perm_idxs), + requires_grad=False, + ) def est_mem_requirement(self, return_min=False): """ @@ -660,6 +679,7 @@ def est_mem_requirement(self, return_min=False): const_mem += ( n_perms_mem * self.n_train * self.dim_d * 2 ) # _xs_train and _Jx_alphas + const_mem += self.n_atoms * self.dim_d # agg_mat const_mem *= 8 const_mem = int(const_mem) @@ -708,15 +728,17 @@ def set_R_d_desc(self, R_d_desc): each training point. """ - self.R_d_desc = torch.from_numpy(R_d_desc) + self.R_d_desc = torch.from_numpy(R_d_desc).type(_dtype) # Try moving to GPU memory. - if torch.cuda.is_available(): + if _torch_cuda_is_available or _torch_mps_is_available: try: R_d_desc = self.R_d_desc.to(self._xs_train.device) except RuntimeError as e: if 'out of memory' in str(e): - torch.cuda.empty_cache() + + if _torch_cuda_is_available: + torch.cuda.empty_cache() self._log.debug('Failed to cache \'R_d_desc\' on GPU.') else: @@ -750,30 +772,36 @@ def set_alphas(self, alphas, alphas_E=None): if alphas_E is not None: self._alphas_E = nn.Parameter( - torch.from_numpy(alphas_E).to(self._xs_train.device), requires_grad=False + torch.from_numpy(alphas_E).to(self._xs_train.device).type(_dtype), + requires_grad=False, ) del self._Jx_alphas while True: try: - alphas_torch = torch.from_numpy(alphas).to(self.R_d_desc.device) # Send to whatever device 'R_d_desc' is on, first. + alphas_torch = torch.from_numpy(alphas).type(_dtype).to( + self.R_d_desc.device + ) # Send to whatever device 'R_d_desc' is on, first. xs = self.desc.d_desc_dot_vec( self.R_d_desc, alphas_torch.reshape(-1, self.dim_i) ) del alphas_torch - if torch.cuda.is_available() and not xs.is_cuda: - xs = xs.to(self._xs_train.device) # Only now send it to the GPU ('_xs_train' will be for sure, if GPUs are available) + if (_torch_cuda_is_available and not xs.is_cuda) or (_torch_mps_is_available and not xs.is_mps): + xs = xs.to( + self._xs_train.device + ) # Only now send it to the GPU ('_xs_train' will be for sure, if GPUs are available) except RuntimeError as e: if 'out of memory' in str(e): - if torch.cuda.is_available(): - torch.cuda.empty_cache() + if _torch_cuda_is_available or _torch_mps_is_available: + + if _torch_cuda_is_available: + torch.cuda.empty_cache() self.R_d_desc = self.R_d_desc.cpu() - #torch.cuda.empty_cache() self._log.debug( 'Failed to \'set_alphas()\': \'R_d_desc\' was moved back from GPU to CPU' @@ -808,17 +836,19 @@ def set_alphas(self, alphas, alphas_E=None): if _n_perm_batches < self.n_perms: - #self._log.debug('Uncaching permutations (within \'set_alphas()\')') - - self._log.debug('Setting permutation batch size to {}{}.'.format(_n_perm_batches, ' (no caching)' if _n_perm_batches > 1 else '')) + self._log.debug( + 'Setting permutation batch size to {}/{}{}.'.format( + self.n_perms // n_perm_batches, self.n_perms, ' (no caching)' if n_perm_batches > 1 else '' + ) + ) - _n_perm_batches += 1 # Do NOT change me to use 'self.set_n_perm_batches(_n_perm_batches + 1)'! + _n_perm_batches += 1 # Do NOT change me to use 'self.set_n_perm_batches(_n_perm_batches + 1)'! self._xs_train = nn.Parameter( self.remove_perms_from_obj(self._xs_train), requires_grad=False - ) # Remove any permutations from 'self._xs_train'. + ) # Remove any permutations from 'self._xs_train'. self._Jx_alphas = nn.Parameter( self.apply_perms_to_obj(xs, perm_idxs=None), requires_grad=False - ) # Set 'self._Jx_alphas' without applying permutations. + ) # Set 'self._Jx_alphas' without applying permutations. else: self._log.critical( @@ -829,7 +859,6 @@ def set_alphas(self, alphas, alphas_E=None): else: raise e - def _forward(self, Rs_or_train_idxs, return_E=True): global _n_perm_batches @@ -841,7 +870,8 @@ def _forward(self, Rs_or_train_idxs, return_E=True): if not is_train_pred: # Rs Rs = Rs_or_train_idxs - diffs = Rs[:, :, None, :] - Rs[:, None, :, :] # N, a, a, 3 + diffs = Rs[:, :, None, :] - Rs[:, None, :, :] # N, a, a, 3 + diffs = diffs[:, i, j, :] # N, d, 3 if self._lat_and_inv is not None: @@ -862,7 +892,11 @@ def _forward(self, Rs_or_train_idxs, return_E=True): diffs = diffs.reshape(diffs_shape) - xs = 1 / diffs.norm(dim=-1)[:, i, j] # R_desc # N, d + xs = 1 / diffs.norm(dim=-1) # N, d + + diffs *= xs[:, :, None] ** 3 + Jxs = diffs + del diffs else: # xs_train @@ -876,9 +910,7 @@ def _forward(self, Rs_or_train_idxs, return_E=True): train_idxs, idx_id_perm, : ] # ignore permutations - Jxs = self.R_d_desc[train_idxs, :, :] - #if torch.cuda.is_available() and not self.R_d_desc.is_cuda: - Jxs = Jxs.to(xs.device) # 'R_d_desc' can live on the CPU, as well. + Jxs = self.R_d_desc[train_idxs, :, :].to(xs.device) # 'R_d_desc' might be living on the CPU... # current: # diffs: N, a, a, 3 @@ -915,19 +947,24 @@ def _forward(self, Rs_or_train_idxs, return_E=True): ) # N, n_perms*N_train, d x_dists = x_diffs.norm(dim=-1) # N, n_perms*N - exp_xs = ( - 5.0 / (3 * self._sig ** 2) * torch.exp(-x_dists) - ) # N, n_perms*N_train + exp_xs = torch.exp(-x_dists) * (q ** 2) / 3 # N, n_perms exp_xs_1_x_dists = exp_xs * (1 + x_dists) # N, n_perms*N_train + if self._alphas_E is None: + del x_dists + dot_x_diff_Jx_alphas = torch.einsum( 'ij...,j...->ij', x_diffs, Jx_alphas_perm_split ) # N, n_perms*N_train # Fs_x = ((exp_xs * dot_x_diff_Jx_alphas)[..., None] * x_diffs).sum(dim=1) - Fs_x = Fs_x + torch.einsum( + Fs_x += torch.einsum( # NOTE ! Fs_x = Fs_x + torch.einsum( '...j,...j,...jk', exp_xs, dot_x_diff_Jx_alphas, x_diffs ) # N, d + del exp_xs + + if self._alphas_E is None: + del x_diffs # current: # diffs: N, a, a, 3 @@ -939,8 +976,6 @@ def _forward(self, Rs_or_train_idxs, return_E=True): # exp_xs_1_x_dists: N, n_perms*N_train # Fs_x: N, d - del exp_xs - Fs_x -= exp_xs_1_x_dists.mm(Jx_alphas_perm_split) # N, d if return_E: @@ -949,17 +984,25 @@ def _forward(self, Rs_or_train_idxs, return_E=True): / q ) + del dot_x_diff_Jx_alphas + + if self._alphas_E is None: + del exp_xs_1_x_dists + # Note: Energies are automatically predicted with a flipped sign here (because -E are trained, instead of E) if self._alphas_E is not None: - #K_fe = (x_diffs / q) * exp_xs[:, :, None] * (x_dists[:, :, None] + 1) K_fe = (x_diffs / q) * exp_xs_1_x_dists[:, :, None] + del exp_xs_1_x_dists + del x_diffs + K_fe = K_fe.reshape(-1, self.n_train, len(perm_batch), self.dim_d) Fs_x += torch.einsum('j,...jkl->...l', self._alphas_E, K_fe) - del x_diffs del K_fe K_ee = (1 + x_dists * (1 + x_dists / 3)) * torch.exp(-x_dists) + del x_dists + K_ee = K_ee.reshape(-1, self.n_train, len(perm_batch)) Es += torch.einsum('j,...jk->...', self._alphas_E, K_ee) del K_ee @@ -972,24 +1015,10 @@ def _forward(self, Rs_or_train_idxs, return_E=True): # exp_xs_1_x_dists: N, n_perms*N # Fs_x: N, d - if not is_train_pred: # Rs - - Fs_x *= xs ** 3 - diffs[:, i, j, :] *= Fs_x[..., None] - diffs[:, j, i, :] *= Fs_x[..., None] - - else: # xs_train - - n = Jxs.shape[0] - diffs = torch.zeros( - (n, self.n_atoms, self.n_atoms, 3), device=xs.device, dtype=xs.dtype - ) - - diffs[:, i, j, :] = Jxs * Fs_x[..., None] - diffs[:, j, i, :] = -diffs[:, i, j, :] + Fs = torch.einsum('ji,...ik,...i->...jk', self.agg_mat, Jxs, Fs_x) - Fs = diffs.sum(dim=1) * self._std - del diffs + if not is_train_pred: # TODO: set std to zero in training mode? + Fs *= self._std if return_E: Es *= self._std @@ -997,16 +1026,16 @@ def _forward(self, Rs_or_train_idxs, return_E=True): return Es, Fs - def forward(self, Rs_or_train_idxs=None, return_E=True): + def forward(self, Rs_or_train_idxs, return_E=True): """ Predict energy and forces for a batch of geometries. Parameters ---------- - Rs : :obj:`torch.Tensor`, optional - (dims M x N x 3) Cartesian coordinates of M molecules composed of N atoms. - If this parameter is ommited, the training error is returned. Note that the training - geometries need to be set right after initialization using `set_R()` for this to work. + Rs_or_train_idxs : :obj:`torch.Tensor` + (dims M x N x 3) Cartesian coordinates of M molecules composed of N atoms or + (dims N) index list of training points to evaluate. Note that `self.R_d_desc` + needs to be set for the latter to work. return_E : boolean, optional If false (default: true), only the forces are returned. @@ -1020,21 +1049,21 @@ def forward(self, Rs_or_train_idxs=None, return_E=True): global _batch_size, _n_perm_batches - if Rs_or_train_idxs.dim() == 1: - # contains index list. return predictions for these training points - dtype = self.R_d_desc.dtype - elif Rs_or_train_idxs.dim() == 3: + #if Rs_or_train_idxs.dim() == 1: + # # contains index list. return predictions for these training points + # dtype = self.R_d_desc.dtype + #elif Rs_or_train_idxs.dim() == 3: # this is real data - assert Rs_or_train_idxs.shape[1:] == (self.n_atoms, 3) - Rs_or_train_idxs = Rs_or_train_idxs.double() - dtype = Rs_or_train_idxs.dtype + # assert Rs_or_train_idxs.shape[1:] == (self.n_atoms, 3) + # Rs_or_train_idxs = Rs_or_train_idxs.double() + # dtype = Rs_or_train_idxs.dtype - else: - # unknown input - self._log.critical('Invalid input for \'Rs_or_train_idxs\'.') - print() - os._exit(1) + #else: + # # unknown input + # self._log.critical('Invalid input for \'Rs_or_train_idxs\'.') + # print() + # os._exit(1) while True: try: @@ -1051,13 +1080,15 @@ def forward(self, Rs_or_train_idxs=None, return_E=True): if _batch_size > 1: - self._log.debug('Setting batch size to {}/{} points.'.format(_batch_size, self.n_train)) + self._log.debug( + 'Setting batch size to {}/{} points.'.format( + _batch_size, self.n_train + ) + ) _batch_size -= 1 elif _n_perm_batches < self.n_perms: - n_perm_batches = _next_batch_size( - self.n_perms, _n_perm_batches - ) + n_perm_batches = _next_batch_size(self.n_perms, _n_perm_batches) self.set_n_perm_batches(n_perm_batches) else: @@ -1071,8 +1102,8 @@ def forward(self, Rs_or_train_idxs=None, return_E=True): else: break - ret = (torch.cat(Fs).to(dtype),) + ret = (torch.cat(Fs),) if return_E: - ret = (torch.cat(Es).to(dtype),) + ret + ret = (torch.cat(Es),) + ret - return ret + return ret \ No newline at end of file diff --git a/sgdml/train.py b/sgdml/train.py index 929e7b5..7de91ef 100644 --- a/sgdml/train.py +++ b/sgdml/train.py @@ -47,6 +47,17 @@ else: _has_torch = True +try: + _torch_mps_is_available = torch.backends.mps.is_available() +except AttributeError: + _torch_mps_is_available = False +_torch_mps_is_available = False + +try: + _torch_cuda_is_available = torch.cuda.is_available() +except AttributeError: + _torch_cuda_is_available = False + from . import __version__, DONE, NOT_DONE from .solvers.analytic import Analytic @@ -959,7 +970,7 @@ def train( # noqa: C901 self.log.debug('Iterative solver not installed.') use_analytic_solver = True - # use_analytic_solver = False # remove me! + # use_analytic_solver = True # remove me! if use_analytic_solver: @@ -1415,7 +1426,12 @@ def progress_callback(done): start = timeit.default_timer() - torch_device = 'cuda' if torch.cuda.is_available() else 'cpu' + if _torch_cuda_is_available: + torch_device = 'cuda' + elif _torch_mps_is_available: + torch_device = 'mps' + else: + torch_device = 'cpu' R_desc_torch = torch.from_numpy(R_desc).to(torch_device) # N, d R_d_desc_torch = torch.from_numpy(R_d_desc).to(torch_device)