You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Remove dtype parameter, use previously existing "precision" instead (#208)
* Remove dtype parameter, use previously existing "precision" instead
* Do not store dtype in args when creating the model
* Wrap the dataset in the DataLoader to cast data to the requested precision
* Inherit every member from the wrapped datset when casting to other
float precision
* blacken
* Add tests for double precision training
* Remove unnecessary default
* Add precision to a test
* Fix a test
Copy file name to clipboardExpand all lines: torchmdnet/scripts/train.py
+1-2
Original file line number
Diff line number
Diff line change
@@ -37,7 +37,7 @@ def get_args():
37
37
parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy')
38
38
parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
39
39
parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
40
-
parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision')
40
+
parser.add_argument('--precision', type=int, default=32, choices=[16, 32, 64], help='Floating point precision')
parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test')
43
43
parser.add_argument('--train-size', type=number, default=None, help='Percentage/number of samples in training set (None to use all remaining samples)')
@@ -67,7 +67,6 @@ def get_args():
67
67
parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')
68
68
69
69
# architectural args
70
-
parser.add_argument('--dtype', type=str, default="float32", choices=list(dtype_mapping.keys()), help='Floating point precision. Can be float32 or float64')
71
70
parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge')
72
71
parser.add_argument('--spin', type=bool, default=False, help='Model needs a spin state')
0 commit comments