Commit dca6679

Remove dtype parameter, use previously existing "precision" instead (#208)
* Remove dtype parameter, use previously existing "precision" instead
* Do not store dtype in args when creating the model
* Wrap the dataset in the DataLoader to cast data to the requested precision
* Inherit every member from the wrapped dataset when casting to other float precision
* blacken
* Add tests for double precision training
* Remove unnecessary default
* Add precision to a test
* Fix a test
1 parent: 4645fa4
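
Before the per-file diffs, a minimal sketch of how a double-precision model is created after this change. The sketch is not part of the commit: the config path mirrors the examples/ET-QM9.yaml file edited below, disabling the prior model is an assumption made so the snippet stands alone, and it presumes the example config provides every remaining argument create_model expects.

    import yaml
    from torchmdnet.models.model import create_model
    from torchmdnet.models.utils import dtype_mapping

    # Load the example training configuration shipped with the repository.
    with open("examples/ET-QM9.yaml") as f:
        args = yaml.safe_load(f)

    args["precision"] = 64      # integer precision (16/32/64) replaces the old string-valued "dtype"
    args["prior_model"] = None  # assumption: skip the prior so no dataset is needed for this sketch

    # create_model resolves dtype_mapping[args["precision"]] internally and no longer
    # stores a dtype entry back into args.
    model = create_model(args)
    print(dtype_mapping[args["precision"]])  # torch.float64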

9 files changed, +60 -26 lines


examples/ET-QM9.yaml (+1 -1)

@@ -55,4 +55,4 @@ train_size: 110000
 trainable_rbf: false
 val_size: 10000
 weight_decay: 0.0
-dtype: float
+precision: 32

tests/test_model.py (+11 -10)

@@ -7,18 +7,19 @@
 from torchmdnet import models
 from torchmdnet.models.model import create_model
 from torchmdnet.models import output_modules
+from torchmdnet.models.utils import dtype_mapping
 
 from utils import load_example_args, create_example_batch
 
 
 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize("use_batch", [True, False])
 @mark.parametrize("explicit_q_s", [True, False])
-@mark.parametrize("dtype", [torch.float32, torch.float64])
-def test_forward(model_name, use_batch, explicit_q_s, dtype):
+@mark.parametrize("precision", [32, 64])
+def test_forward(model_name, use_batch, explicit_q_s, precision):
     z, pos, batch = create_example_batch()
-    pos = pos.to(dtype=dtype)
-    model = create_model(load_example_args(model_name, prior_model=None, dtype=dtype))
+    pos = pos.to(dtype=dtype_mapping[precision])
+    model = create_model(load_example_args(model_name, prior_model=None, precision=precision))
     batch = batch if use_batch else None
     if explicit_q_s:
         model(z, pos, batch=batch, q=None, s=None)
@@ -28,10 +29,10 @@ def test_forward(model_name, use_batch, explicit_q_s, dtype):
 
 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize("output_model", output_modules.__all__)
-@mark.parametrize("dtype", [torch.float32, torch.float64])
-def test_forward_output_modules(model_name, output_model, dtype):
+@mark.parametrize("precision", [32,64])
+def test_forward_output_modules(model_name, output_model, precision):
     z, pos, batch = create_example_batch()
-    args = load_example_args(model_name, remove_prior=True, output_model=output_model, dtype=dtype)
+    args = load_example_args(model_name, remove_prior=True, output_model=output_model, precision=precision)
     model = create_model(args)
     model(z, pos, batch=batch)
 
@@ -146,7 +147,7 @@ def test_forward_output(model_name, output_model, overwrite_reference=False):
 @mark.parametrize("model_name", models.__all__)
 def test_gradients(model_name):
     pl.seed_everything(1234)
-    dtype = torch.float64
+    precision = 64
     output_model = "Scalar"
     # create model and sample batch
     derivative = output_model in ["Scalar", "EquivariantScalar"]
@@ -155,12 +156,12 @@
         remove_prior=True,
         output_model=output_model,
         derivative=derivative,
-        dtype=dtype,
+        precision=precision
     )
     model = create_model(args)
     z, pos, batch = create_example_batch(n_atoms=5)
     pos.requires_grad_(True)
-    pos = pos.to(dtype)
+    pos = pos.to(torch.float64)
     torch.autograd.gradcheck(
         model, (z, pos, batch), eps=1e-4, atol=1e-3, rtol=1e-2, nondet_tol=1e-3
     )

tests/test_module.py (+4 -2)

@@ -24,7 +24,8 @@ def test_load_model():
 
 @mark.parametrize("model_name", models.__all__)
 @mark.parametrize("use_atomref", [True, False])
-def test_train(model_name, use_atomref, tmpdir):
+@mark.parametrize("precision", [32, 64])
+def test_train(model_name, use_atomref, precision, tmpdir):
     args = load_example_args(
         model_name,
         remove_prior=not use_atomref,
@@ -37,6 +38,7 @@ def test_train(model_name, use_atomref, tmpdir):
         num_layers=2,
         num_rbf=16,
         batch_size=8,
+        precision=precision,
     )
     datamodule = DataModule(args, DummyDataset(has_atomref=use_atomref))
 
@@ -47,6 +49,6 @@ def test_train(model_name, use_atomref, tmpdir):
 
     module = LNNP(args, prior_model=prior)
 
-    trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir)
+    trainer = pl.Trainer(max_steps=10, default_root_dir=tmpdir, precision=args["precision"])
     trainer.fit(module, datamodule)
     trainer.test(module, datamodule)

tests/test_optimize.py (+3 -2)

@@ -3,7 +3,7 @@
 import torch as pt
 from torchmdnet.models.model import create_model
 from torchmdnet.optimize import optimize
-
+from torchmdnet.models.utils import dtype_mapping
 
 @mark.parametrize("device", ["cpu", "cuda"])
 @mark.parametrize("num_atoms", [10, 100])
@@ -39,6 +39,7 @@ def test_gn(device, num_atoms):
         "prior_model": None,
         "output_model": "Scalar",
         "reduce_op": "add",
+        "precision": 32,
     }
     ref_model = create_model(args).to(device)
 
@@ -47,7 +48,7 @@ def test_gn(device, num_atoms):
 
     # Optimize the model
     model = optimize(ref_model).to(device)
-
+    positions.to(dtype_mapping[args["precision"]])
     # Execute the optimize model
     energy, gradient = model(elements, positions)
 

tests/utils.py (+2 -2)

@@ -12,8 +12,8 @@ def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs
         config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
     with open(config_file, "r") as f:
         args = yaml.load(f, Loader=yaml.FullLoader)
-    if "dtype" not in args:
-        args["dtype"] = "float"
+    if "precision" not in args:
+        args["precision"] = 32
     args["model"] = model_name
     args["seed"] = 1234
     if remove_prior:

torchmdnet/data.py (+33 -1)

@@ -6,8 +6,37 @@
 from pytorch_lightning import LightningDataModule
 from pytorch_lightning.utilities import rank_zero_warn
 from torchmdnet import datasets
+from torch_geometric.data import Dataset
 from torchmdnet.utils import make_splits, MissingEnergyException
 from torch_scatter import scatter
+from torchmdnet.models.utils import dtype_mapping
+
+
+class FloatCastDatasetWrapper(Dataset):
+    def __init__(self, dataset, dtype=torch.float64):
+        super(FloatCastDatasetWrapper, self).__init__(
+            dataset.root, dataset.transform, dataset.pre_transform, dataset.pre_filter
+        )
+        self.dataset = dataset
+        self.dtype = dtype
+
+    def len(self):
+        return len(self.dataset)
+
+    def get(self, idx):
+        data = self.dataset.get(idx)
+        for key, value in data:
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                setattr(data, key, value.to(self.dtype))
+        return data
+
+    def __getattr__(self, name):
+        # Check if the attribute exists in the underlying dataset
+        if hasattr(self.dataset, name):
+            return getattr(self.dataset, name)
+        raise AttributeError(
+            f"'{type(self).__name__}' and its underlying dataset have no attribute '{name}'"
+        )
 
 
 class DataModule(LightningDataModule):
@@ -34,6 +63,9 @@ def setup(self, stage):
         self.dataset = getattr(datasets, self.hparams["dataset"])(
             self.hparams["dataset_root"], **dataset_arg
         )
+        self.dataset = FloatCastDatasetWrapper(
+            self.dataset, dtype_mapping[self.hparams["precision"]]
+        )
 
         self.idx_train, self.idx_val, self.idx_test = make_splits(
             len(self.dataset),
@@ -62,7 +94,7 @@ def val_dataloader(self):
         loaders = [self._get_dataloader(self.val_dataset, "val")]
         if (
             len(self.test_dataset) > 0
-            and (self.trainer.current_epoch+1) % self.hparams["test_interval"] == 0
+            and (self.trainer.current_epoch + 1) % self.hparams["test_interval"] == 0
         ):
             loaders.append(self._get_dataloader(self.test_dataset, "test"))
         return loaders
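
To make the behaviour of the new FloatCastDatasetWrapper concrete, here is a small usage sketch. Only FloatCastDatasetWrapper and dtype_mapping come from the diff above; TinyDataset is a toy stand-in invented for illustration (real code would wrap one of the torchmdnet.datasets classes, as DataModule.setup now does).

    import torch
    from torch_geometric.data import Data, Dataset
    from torchmdnet.data import FloatCastDatasetWrapper
    from torchmdnet.models.utils import dtype_mapping


    class TinyDataset(Dataset):
        # Toy two-sample dataset used only to demonstrate the wrapper.
        def len(self):
            return 2

        def get(self, idx):
            return Data(
                z=torch.tensor([1, 8], dtype=torch.long),    # atomic numbers (integer)
                pos=torch.rand(2, 3, dtype=torch.float32),   # float32 coordinates
                y=torch.zeros(1, 1, dtype=torch.float32),    # float32 label
            )


    wrapped = FloatCastDatasetWrapper(TinyDataset(), dtype_mapping[64])
    sample = wrapped.get(0)
    assert sample.pos.dtype == torch.float64  # floating-point tensors are cast on access
    assert sample.z.dtype == torch.long       # integer tensors are left untouched
    assert len(wrapped) == 2                  # len() and other members fall through to the wrapped dataset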

torchmdnet/models/model.py (+4 -5)

@@ -25,8 +25,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
     -------
         nn.Module: An instance of the TorchMD_Net model.
     """
-    args["dtype"] = "float32" if "dtype" not in args else args["dtype"]
-    args["dtype"] = dtype_mapping[args["dtype"]] if isinstance(args["dtype"], str) else args["dtype"]
+    dtype = dtype_mapping[args["precision"]]
     shared_args = dict(
         hidden_channels=args["embedding_dimension"],
         num_layers=args["num_layers"],
@@ -38,7 +37,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         cutoff_upper=args["cutoff_upper"],
         max_z=args["max_z"],
         max_num_neighbors=args["max_num_neighbors"],
-        dtype=args["dtype"]
+        dtype=dtype
     )
 
     # representation network
@@ -102,7 +101,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         args["embedding_dimension"],
         activation=args["activation"],
         reduce_op=args["reduce_op"],
-        dtype=args["dtype"],
+        dtype=dtype,
     )
 
     # combine representation and output network
@@ -113,7 +112,7 @@ def create_model(args, prior_model=None, mean=None, std=None):
         mean=mean,
         std=std,
         derivative=args["derivative"],
-        dtype=args["dtype"],
+        dtype=dtype,
     )
     return model
 

torchmdnet/models/utils.py (+1 -1)

@@ -526,4 +526,4 @@ def forward(self, x, v):
     "sigmoid": nn.Sigmoid,
 }
 
-dtype_mapping = {"float": torch.float, "double": torch.float64, "float32": torch.float32, "float64": torch.float64}
+dtype_mapping = {16: torch.float16, 32: torch.float, 64: torch.float64}
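
The keys of dtype_mapping are now the integer values accepted by --precision rather than strings; a quick sanity check of the new mapping (nothing beyond the mapping itself is assumed here):

    import torch
    from torchmdnet.models.utils import dtype_mapping

    assert dtype_mapping[16] is torch.float16
    assert dtype_mapping[32] is torch.float32  # torch.float is an alias of torch.float32
    assert dtype_mapping[64] is torch.float64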

torchmdnet/scripts/train.py (+1 -2)

@@ -37,7 +37,7 @@ def get_args():
     parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy')
     parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
     parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
-    parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision')
+    parser.add_argument('--precision', type=int, default=32, choices=[16, 32, 64], help='Floating point precision')
     parser.add_argument('--log-dir', '-l', default='/tmp/logs', help='log file')
     parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test')
     parser.add_argument('--train-size', type=number, default=None, help='Percentage/number of samples in training set (None to use all remaining samples)')
@@ -67,7 +67,6 @@ def get_args():
     parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')
 
     # architectural args
-    parser.add_argument('--dtype', type=str, default="float32", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float32 or float64')
     parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge')
     parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state')
     parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
