-
Notifications
You must be signed in to change notification settings - Fork 161
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Example] Multitask Molecular Property Prediction with GNN (#159)
* Update * Update * Fix * Fix * Update * Fix * Fix * Fix * Fix * Fix * Update * Update
- Loading branch information
Showing
13 changed files
with
1,181 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Multitask Graph Neural Network for Molecular Property Prediction | ||
|
||
## Usage | ||
|
||
``` | ||
python main.py -c CSV -m MODEL --mode MODE -p PATH -s SMILES -t TASKS | ||
``` | ||
|
||
where: | ||
- `CSV` specifies the path to a CSV file for the dataset | ||
- `MODEL` specifies the model to use, which can be `GCN`, `GAT`, `MPNN`, or `AttentiveFP` | ||
- `MODE` specifies the multitask architecture to use, which can be `parallel` or `bypass` | ||
- `PATH` specifies the path to save training results | ||
- `SMILES` specifies the SMILES column header in the CSV file | ||
- `TASKS` specifies the CSV column headers for the tasks to model. For multiple tasks, separate them by comma, e.g., task1,task2,task3. If not specified, all columns except for the SMILES column will be treated as properties/tasks. | ||
|
||
## Example | ||
|
||
For demonstration, you can generate a synthetic dataset as follows. | ||
|
||
```python | ||
import torch | ||
import pandas as pd | ||
|
||
data = { | ||
'smiles': ['CCO' for _ in range(128)], | ||
'logP': torch.randn(128).numpy().tolist(), | ||
'logD': torch.randn(128).numpy().tolist() | ||
} | ||
df = pd.DataFrame(data) | ||
df.to_csv('syn_data.csv', index=False) | ||
``` | ||
|
||
You can then run an experiment with | ||
|
||
``` | ||
python main.py -c syn_data.csv -m GCN --mode parallel -p results -s smiles -t logP,logD | ||
``` | ||
|
||
Once the experiment is completed, `results/model.pth` is the trained model checkpoint | ||
and `results/results.txt` is the evaluation result. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Hyperparameter presets, one per (model, multitask architecture) pair.
# Every preset carries the optimisation settings (lr, weight_decay, dropout,
# patience, batch_size) next to the model-specific sizes.

GCN_parallel = dict(
    gnn_hidden_feats=64,
    num_gnn_layers=2,
    regressor_hidden_feats=64,
    lr=1e-3,
    weight_decay=0.,
    dropout=0.,
    patience=50,
    batch_size=128,
)

GCN_bypass = dict(
    gnn_hidden_feats=128,
    num_gnn_layers=2,
    regressor_hidden_feats=32,
    lr=1e-3,
    weight_decay=0.,
    dropout=0.,
    patience=30,
    batch_size=128,
)

GAT_parallel = dict(
    gnn_hidden_feats=32,
    num_gnn_layers=2,
    num_heads=6,
    regressor_hidden_feats=32,
    lr=3e-3,
    weight_decay=3e-5,
    dropout=0.01,
    patience=100,
    batch_size=128,
)

GAT_bypass = dict(
    gnn_hidden_feats=32,
    num_gnn_layers=3,
    num_heads=8,
    regressor_hidden_feats=32,
    lr=1e-3,
    weight_decay=3e-6,
    dropout=0.1,
    patience=30,
    batch_size=128,
)

MPNN_parallel = dict(
    node_hidden_dim=64,
    edge_hidden_dim=16,
    num_step_message_passing=2,
    num_step_set2set=3,
    num_layer_set2set=2,
    regressor_hidden_feats=32,
    lr=1e-3,
    weight_decay=0.,
    dropout=0.,
    patience=50,
    batch_size=128,
)

MPNN_bypass = dict(
    node_hidden_dim=32,
    edge_hidden_dim=64,
    num_step_message_passing=2,
    num_step_set2set=2,
    num_layer_set2set=2,
    regressor_hidden_feats=32,
    lr=1e-3,
    weight_decay=0.,
    dropout=0.01,
    patience=50,
    batch_size=128,
)

AttentiveFP_parallel = dict(
    num_gnn_layers=3,
    gnn_out_feats=64,
    num_timesteps=3,
    regressor_hidden_feats=32,
    lr=1e-3,
    weight_decay=0.,
    dropout=0.,
    patience=50,
    batch_size=32,
)

AttentiveFP_bypass = dict(
    num_gnn_layers=2,
    gnn_out_feats=32,
    num_timesteps=2,
    regressor_hidden_feats=32,
    lr=1e-3,
    weight_decay=0.,
    dropout=0.,
    patience=50,
    batch_size=32,
)

# Lookup table keyed by '<MODEL>_<MODE>'.
configs = dict(
    GCN_parallel=GCN_parallel,
    GCN_bypass=GCN_bypass,
    GAT_parallel=GAT_parallel,
    GAT_bypass=GAT_bypass,
    MPNN_parallel=MPNN_parallel,
    MPNN_bypass=MPNN_bypass,
    AttentiveFP_parallel=AttentiveFP_parallel,
    AttentiveFP_bypass=AttentiveFP_bypass,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from dgllife.utils import atom_type_one_hot, atom_degree_one_hot, \ | ||
atom_hybridization_one_hot, atom_is_aromatic_one_hot, \ | ||
atom_chiral_tag_one_hot, atom_formal_charge_one_hot, atom_mass, \ | ||
atom_implicit_valence_one_hot, BaseAtomFeaturizer, \ | ||
ConcatFeaturizer, CanonicalBondFeaturizer | ||
from functools import partial | ||
from rdkit import Chem | ||
|
||
# Allowable value sets for the one-hot atom descriptors used below.
_ATOM_DEGREES = [1, 2, 3, 4, 6]
_ATOM_TYPES = ['B', 'Br', 'C', 'Cl', 'F', 'H', 'I', 'N', 'O', 'P', 'S', 'Se', 'Si']
_FORMAL_CHARGES = [-1, 0, 1]
_HYBRIDIZATIONS = [
    Chem.rdchem.HybridizationType.S,
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D2,
]
_IMPLICIT_VALENCES = [0, 1, 2, 3]

# Per-atom feature functions, in the order they are handed to ConcatFeaturizer.
_atom_feature_funcs = [
    partial(atom_degree_one_hot, allowable_set=_ATOM_DEGREES),
    partial(atom_type_one_hot, allowable_set=_ATOM_TYPES),
    atom_chiral_tag_one_hot,
    partial(atom_formal_charge_one_hot, allowable_set=_FORMAL_CHARGES),
    partial(atom_hybridization_one_hot, allowable_set=_HYBRIDIZATIONS),
    partial(atom_implicit_valence_one_hot, allowable_set=_IMPLICIT_VALENCES),
    atom_is_aromatic_one_hot,
    atom_mass,
]

# Atom featurizer shared by the experiments; the combined features are
# registered under the 'hv' key.
atom_featurizer = BaseAtomFeaturizer(
    featurizer_funcs={'hv': ConcatFeaturizer(_atom_feature_funcs)}
)
|
||
if __name__ == '__main__':
    # Command-line entry point: parse the options, build a featurized dataset
    # from the user's CSV file, split it, and pass everything to ``run.main``.
    import os

    import pandas as pd

    from argparse import ArgumentParser
    from dgllife.data import MoleculeCSVDataset
    from dgllife.utils import smiles_to_bigraph, RandomSplitter

    from configure import configs
    from run import main
    from utils import mkdir_p, setup

    parser = ArgumentParser('(Multitask) Molecular Property Prediction with GNNs for a user-specified .csv file.')
    parser.add_argument('-c', '--csv-path', type=str, required=True,
                        help='Path to a csv file for loading a dataset.')
    # --model and --mode are joined below to pick a configuration. They are
    # marked required so that a missing option yields a clear argparse error
    # instead of a TypeError from joining None values.
    parser.add_argument('-m', '--model', type=str, required=True,
                        choices=['GCN', 'GAT', 'MPNN', 'AttentiveFP'],
                        help='Model to use')
    parser.add_argument('--mode', type=str, required=True,
                        choices=['parallel', 'bypass'],
                        help='Architecture to use for multitask learning')
    parser.add_argument('-n', '--num-epochs', type=int, default=4000,
                        help='Maximum number of epochs allowed for training. '
                             'We set a large number by default as early stopping will be performed.')
    parser.add_argument('-p', '--result-path', type=str, required=True,
                        help='Path to training results')
    parser.add_argument('-s', '--smiles-column', type=str, default='smiles',
                        help='CSV column header for the SMILES strings. (default: smiles)')
    parser.add_argument('-t', '--tasks', default=None, type=str,
                        help='CSV column headers for the tasks to model. For multiple tasks, separate them by '
                             'comma, e.g., task1,task2,task3, ... If None, we will model '
                             'all the columns except for the smiles_column in the CSV file. '
                             '(default: None)')
    args = vars(parser.parse_args())

    # The experiment name doubles as the key into the hyperparameter presets.
    args['exp_name'] = '_'.join([args['model'], args['mode']])
    if args['tasks'] is not None:
        args['tasks'] = args['tasks'].split(',')
    args.update(configs[args['exp_name']])

    # Setup for experiments
    mkdir_p(args['result_path'])

    node_featurizer = atom_featurizer
    edge_featurizer = CanonicalBondFeaturizer(bond_data_field='he', self_loop=True)
    df = pd.read_csv(args['csv_path'])
    dataset = MoleculeCSVDataset(
        df, partial(smiles_to_bigraph, add_self_loop=True),
        node_featurizer=node_featurizer,
        edge_featurizer=edge_featurizer,
        smiles_column=args['smiles_column'],
        cache_file_path=os.path.join(args['result_path'], 'graph.bin'),
        task_names=args['tasks']
    )
    # Take the task names from the dataset so they are set even when -t was
    # not given on the command line.
    args['tasks'] = dataset.task_names
    args = setup(args)
    # Fixed 80/10/10 random split with a fixed seed.
    train_set, val_set, test_set = RandomSplitter.train_val_test_split(
        dataset, frac_train=0.8, frac_val=0.1,
        frac_test=0.1, random_state=0)

    main(args, node_featurizer, edge_featurizer, train_set, val_set, test_set)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .gcn import GCNRegressor, GCNRegressorBypass | ||
from .gat import GATRegressor, GATRegressorBypass | ||
from .mpnn import MPNNRegressor, MPNNRegressorBypass | ||
from .attentivefp import AttentiveFPRegressor, AttentiveFPRegressorBypass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import torch.nn as nn | ||
|
||
from dgllife.model import AttentiveFPGNN, AttentiveFPReadout | ||
|
||
from .regressor import BaseGNNRegressor, BaseGNNRegressorBypass | ||
|
||
class AttentiveFPRegressor(BaseGNNRegressor):
    """Multitask molecular property regressor built on AttentiveFP.

    All tasks are assumed to be regression problems.

    Parameters
    ----------
    in_node_feats : int
        Number of input node features
    in_edge_feats : int
        Number of input edge features
    gnn_out_feats : int
        The GNN output size
    num_layers : int
        Number of GNN layers
    num_timesteps : int
        Number of timesteps for updating molecular representations with GRU during readout
    n_tasks : int
        Number of prediction tasks
    regressor_hidden_feats : int
        Hidden size in MLP regressor. Default to 128.
    dropout : float
        The probability for dropout. Default to 0, i.e. no dropout is performed.
    """
    def __init__(self, in_node_feats, in_edge_feats, gnn_out_feats, num_layers, num_timesteps,
                 n_tasks, regressor_hidden_feats=128, dropout=0.):
        # The readout emits gnn_out_feats features, which is the size the
        # base class is told to expect.
        base_kwargs = dict(
            readout_feats=gnn_out_feats,
            n_tasks=n_tasks,
            regressor_hidden_feats=regressor_hidden_feats,
            dropout=dropout,
        )
        super().__init__(**base_kwargs)

        self.gnn = AttentiveFPGNN(
            in_node_feats, in_edge_feats, num_layers, gnn_out_feats, dropout)
        self.readout = AttentiveFPReadout(gnn_out_feats, num_timesteps, dropout)
|
||
class AttentiveFPRegressorBypass(BaseGNNRegressorBypass):
    """Bypass-architecture multitask property regressor built on AttentiveFP.

    All tasks are assumed to be regression problems. One GNN is shared by all
    tasks, and every task additionally gets its own GNN and readout.

    Parameters
    ----------
    in_node_feats : int
        Number of input node features
    in_edge_feats : int
        Number of input edge features
    gnn_out_feats : int
        The GNN output size
    num_layers : int
        Number of GNN layers
    num_timesteps : int
        Number of timesteps for updating molecular representations with GRU during readout
    n_tasks : int
        Number of prediction tasks
    regressor_hidden_feats : int
        Hidden size in MLP regressor. Default to 128.
    dropout : float
        The probability for dropout. Default to 0, i.e. no dropout is performed.
    """
    def __init__(self, in_node_feats, in_edge_feats, gnn_out_feats, num_layers, num_timesteps,
                 n_tasks, regressor_hidden_feats=128, dropout=0.):
        # Readouts work on 2 * gnn_out_feats features -- presumably the shared
        # and task-specific GNN outputs combined (see the base class to confirm).
        readout_size = 2 * gnn_out_feats
        super().__init__(
            readout_feats=readout_size,
            n_tasks=n_tasks,
            regressor_hidden_feats=regressor_hidden_feats,
            dropout=dropout,
        )

        self.shared_gnn = AttentiveFPGNN(
            in_node_feats, in_edge_feats, num_layers, gnn_out_feats, dropout)

        # One GNN/readout pair per task, created in interleaved order.
        for _task in range(n_tasks):
            task_gnn = AttentiveFPGNN(
                in_node_feats, in_edge_feats, num_layers, gnn_out_feats, dropout)
            self.task_gnns.append(task_gnn)
            task_readout = AttentiveFPReadout(readout_size, num_timesteps, dropout)
            self.readouts.append(task_readout)
Oops, something went wrong.