Remove dtype parameter, use previously existing "precision" instead #208

Merged · 9 commits · Aug 8, 2023
examples/ET-QM9.yaml (1 addition, 1 deletion)
@@ -55,4 +55,4 @@ train_size: 110000
 trainable_rbf: false
 val_size: 10000
 weight_decay: 0.0
-dtype: float
+precision: 32
tests/test_model.py (11 additions, 10 deletions)
@@ -7,18 +7,19 @@
 from torchmdnet import models
 from torchmdnet.models.model import create_model
 from torchmdnet.models import output_modules
+from torchmdnet.models.utils import dtype_mapping

 from utils import load_example_args, create_example_batch


 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize("use_batch", [True, False])
 @mark.parametrize("explicit_q_s", [True, False])
-@mark.parametrize("dtype", [torch.float32, torch.float64])
-def test_forward(model_name, use_batch, explicit_q_s, dtype):
+@mark.parametrize("precision", [32, 64])
+def test_forward(model_name, use_batch, explicit_q_s, precision):
     z, pos, batch = create_example_batch()
-    pos = pos.to(dtype=dtype)
-    model = create_model(load_example_args(model_name, prior_model=None, dtype=dtype))
+    pos = pos.to(dtype=dtype_mapping[precision])
+    model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
     batch = batch if use_batch else None
     if explicit_q_s:
         model(z, pos, batch=batch, q=None, s=None)
@@ -28,10 +29,10 @@ def test_forward(model_name, use_batch, explicit_q_s, dtype):

 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize("output_model", output_modules.__all__)
-@mark.parametrize("dtype", [torch.float32, torch.float64])
-def test_forward_output_modules(model_name, output_model, dtype):
+@mark.parametrize("precision", [32,64])
+def test_forward_output_modules(model_name, output_model, precision):
     z, pos, batch = create_example_batch()
-    args = load_example_args(model_name, remove_prior=True, output_model=output_model, dtype=dtype)
+    args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
     model = create_model(args)
     model(z, pos, batch=batch)

@@ -146,7 +147,7 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
 @mark.parametrize("model_name", models.__all__)
 def test_gradients(model_name):
     pl.seed_everything(1234)
-    dtype = torch.float64
+    precision = 64
     output_model = "Scalar"
     # create model and sample batch
     derivative = output_model in ["Scalar", "EquivariantScalar"]
@@ -155,12 +156,12 @@ def test_gradients(model_name):
         remove_prior=True,
         output_model=output_model,
         derivative=derivative,
-        dtype=dtype,
+        precision=precision
     )
     model = create_model(args)
     z, pos, batch = create_example_batch(n_atoms=5)
     pos.requires_grad_(True)
-    pos = pos.to(dtype)
+    pos = pos.to(torch.float64)
     torch.autograd.gradcheck(
         model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
     )
tests/test_module.py (4 additions, 2 deletions)
@@ -24,7 +24,8 @@ def test_load_model():

 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize("use_atomref", [True, False])
-def test_train(model_name, use_atomref, tmpdir):
+@mark.parametrize("precision", [32, 64])
+def test_train(model_name, use_atomref, precision, tmpdir):
     args = load_example_args(
         model_name,
         remove_prior=not use_atomref,
@@ -37,6 +38,7 @@ def test_train(model_name, use_atomref, tmpdir):
         num_layers=2,
         num_rbf=16,
         batch_size=8,
+        precision=precision,
     )
     datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref))

@@ -47,6 +49,6 @@ def test_train(model_name, use_atomref, tmpdir):

     module = LNNP(args, prior_model=prior)

-    trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir)
+    trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir, precision=args["precision"])
     trainer.fit(module, datamodule)
     trainer.test(module, datamodule)
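For orientation, a condensed sketch of the flow this updated test exercises; args, datamodule, and the prior are assumed to come from the surrounding test setup, the point being that a single integer now drives both the model dtype and Lightning's numerics:

import pytorch_lightning as pl
from torchmdnet.module import LNNP

args["precision"] = 64                    # same key the YAML config and CLI use
module = LNNP(args, prior_model=None)     # model weights follow dtype_mapping[64]
trainer = pl.Trainer(max_steps=10, precision=args["precision"])
trainer.fit(module, datamodule)           # the DataModule casts floating-point inputs to float64 as well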
tests/test_optimize.py (3 additions, 2 deletions)
@@ -3,7 +3,7 @@
 import torch as pt
 from torchmdnet.models.model import create_model
 from torchmdnet.optimize import optimize
-
+from torchmdnet.models.utils import dtype_mapping

 @mark.parametrize("device", ["cpu", "cuda"])
 @mark.parametrize("num_atoms", [10, 100])
@@ -39,6 +39,7 @@ def test_gn(device, num_atoms):
         "prior_model": None,
         "output_model": "Scalar",
         "reduce_op": "add",
+        "precision": 32,
     }
     ref_model = create_model(args).to(device)

@@ -47,7 +48,7 @@ def test_gn(device, num_atoms):

     # Optimize the model
     model = optimize(ref_model).to(device)
-
+    positions.to(dtype_mapping[args["precision"]])
     # Execute the optimize model
     energy, gradient = model(elements, positions)

tests/utils.py (2 additions, 2 deletions)
@@ -12,8 +12,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs
         config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
     with open(config_file, "r") as f:
         args = yaml.load(f, Loader=yaml.FullLoader)
-    if "dtype" not in args:
-        args["dtype"] = "float"
+    if "precision" not in args:
+        args["precision"] = 32
     args["model"] = model_name
     args["seed"] = 1234
     if remove_prior:
torchmdnet/data.py (33 additions, 1 deletion)
@@ -6,8 +6,37 @@
 from pytorch_lightning import LightningDataModule
 from pytorch_lightning.utilities import rank_zero_warn
 from torchmdnet import datasets
+from torch_geometric.data import Dataset
 from torchmdnet.utils import make_splits, MissingEnergyException
 from torch_scatter import scatter
+from torchmdnet.models.utils import dtype_mapping


+class FloatCastDatasetWrapper(Dataset):
+    def __init__(self, dataset, dtype=torch.float64):
+        super(FloatCastDatasetWrapper, self).__init__(
+            dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
+        )
+        self.dataset = dataset
+        self.dtype = dtype
+
+    def len(self):
+        return len(self.dataset)
+
+    def get(self, idx):
+        data = self.dataset.get(idx)
+        for key, value in data:
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                setattr(data, key, value.to(self.dtype))
+        return data
+
+    def __getattr__(self, name):
+        # Check if the attribute exists in the underlying dataset
+        if hasattr(self.dataset, name):
+            return getattr(self.dataset, name)
+        raise AttributeError(
+            f"'{type(self).__name__}' and its underlying dataset have no attribute '{name}'"
+        )
+
+
 class DataModule(LightningDataModule):
@@ -34,6 +63,9 @@ def setup(self, stage):
             self.dataset = getattr(datasets, self.hparams["dataset"])(
                 self.hparams["dataset_root"], **dataset_arg
             )
+        self.dataset = FloatCastDatasetWrapper(
+            self.dataset, dtype_mapping[self.hparams["precision"]]
+        )

         self.idx_train, self.idx_val, self.idx_test = make_splits(
             len(self.dataset),
@@ -62,7 +94,7 @@ def val_dataloader(self):
         loaders = [self._get_dataloader(self.val_dataset, "val")]
         if (
             len(self.test_dataset) > 0
-            and (self.trainer.current_epoch+1) % self.hparams["test_interval"] == 0
+            and (self.trainer.current_epoch + 1) % self.hparams["test_interval"] == 0
         ):
             loaders.append(self._get_dataloader(self.test_dataset, "test"))
         return loaders
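A minimal usage sketch of the new wrapper outside the DataModule; ds stands for any torch_geometric dataset and is purely hypothetical, and only floating-point tensors are cast while integer fields such as atomic numbers keep their dtype:

from torchmdnet.models.utils import dtype_mapping
from torchmdnet.data import FloatCastDatasetWrapper

wrapped = FloatCastDatasetWrapper(ds, dtype_mapping[64])  # ds: hypothetical PyG dataset instance
sample = wrapped.get(0)   # floating-point tensors (e.g. pos, y) come back as torch.float64
print(sample.pos.dtype)   # torch.float64; an integer field like sample.z is left untouched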
torchmdnet/models/model.py (4 additions, 5 deletions)
@@ -25,8 +25,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
     -------
     nn.Module: An instance of the TorchMD_Net model.
     """
-    args["dtype"] = "float32" if "dtype" not in args else args["dtype"]
-    args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"]
+    dtype = dtype_mapping[args["precision"]]
     shared_args = dict(
         hidden_channels=args["embedding_dimension"],
         num_layers=args["num_layers"],
@@ -38,7 +37,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         cutoff_upper=args["cutoff_upper"],
         max_z=args["max_z"],
         max_num_neighbors=args["max_num_neighbors"],
-        dtype=args["dtype"]
+        dtype=dtype
     )

     # representation network
@@ -102,7 +101,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         args["embedding_dimension"],
         activation=args["activation"],
         reduce_op=args["reduce_op"],
-        dtype=args["dtype"],
+        dtype=dtype,
     )

     # combine representation and output network
@@ -113,7 +112,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         mean=mean,
         std=std,
         derivative=args["derivative"],
-        dtype=args["dtype"],
+        dtype=dtype,
     )
     return model

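Since the dtype is now derived inside create_model, a caller only supplies the integer precision; a rough sketch reusing the load_example_args helper from tests/utils.py (the model name is just an example entry of models.__all__):

from torchmdnet.models.model import create_model
from utils import load_example_args   # tests/utils.py helper

args = load_example_args("equivariant-transformer", prior_model=None, precision=64)
model = create_model(args)   # every submodule receives dtype_mapping[64], i.e. torch.float64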
torchmdnet/models/utils.py (1 addition, 1 deletion)
@@ -526,4 +526,4 @@ def forward(self, x, v):
     "sigmoid": nn.Sigmoid,
 }

-dtype_mapping = {"float": torch.float, "double": torch.float64, "float32": torch.float32, "float64": torch.float64}
+dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64}
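The mapping is now keyed by the integer precision instead of a dtype string; a trivial check of what the new table yields:

import torch
from torchmdnet.models.utils import dtype_mapping

assert dtype_mapping[32] is torch.float32   # torch.float is an alias of torch.float32
assert dtype_mapping[64] is torch.float64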
torchmdnet/scripts/train.py (1 addition, 2 deletions)
@@ -37,7 +37,7 @@ def get_args():
     parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy')
     parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
     parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
-    parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision')
+    parser.add_argument('--precision', type=int, default=32, choices=[16, 32, 64], help='Floating point precision')
     parser.add_argument('--log-dir', '-l', default='/tmp/logs', help='log file')
     parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test')
     parser.add_argument('--train-size', type=number, default=None, help='Percentage/number of samples in training set (None to use all remaining samples)')
@@ -67,7 +67,6 @@ def get_args():
     parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')

     # architectural args
-    parser.add_argument('--dtype', type=str, default="float32", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float32 or float64')
     parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge')
     parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state')
     parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')