Skip to content

Commit

Permalink
[Example] Multitask Molecular Property Prediction with GNN (#159)
Browse files Browse the repository at this point in the history
* Update

* Update

* Fix

* Fix

* Update

* Fix

* Fix

* Fix

* Fix

* Fix

* Update

* Update
  • Loading branch information
mufeili authored Nov 16, 2021
1 parent ef58e80 commit d7522cc
Show file tree
Hide file tree
Showing 13 changed files with 1,181 additions and 9 deletions.
41 changes: 41 additions & 0 deletions examples/property_prediction/MTL/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Multitask Graph Neural Network for Molecular Property Prediction

## Usage

```
python -c CSV -m MODEL --mode MODE -p PATH -s SMILES -t TASKS
```

where:
- `CSV` specifies the path to a CSV file for the dataset
- `MODEL` specifies the model to use, which can be `GCN`, `GAT`, `MPNN`, or `AttentiveFP`
- `MODE` specifies the multitask architecture to use, which can be `parallel` or `bypass`
- `PATH` specifies the path to save training results
- `SMILES` specifies the SMIELS column header in the CSV file
- `TASKS` specifies the CSV column headers for the tasks to model. For multiple tasks, separate them by comma, e.g., task1,task2,task3. It not specified, all columns except for the SMILES column will be treated as properties/tasks.

## Example

For demonstration, you can generate a synthetic dataset as follows.

```python
import torch
import pandas as pd

data = {
'smiles': ['CCO' for _ in range(128)],
'logP': torch.randn(128).numpy().tolist(),
'logD': torch.randn(128).numpy().tolist()
}
df = pd.DataFrame(data)
df.to_csv('syn_data.csv', index=False)
```

After you run an experiment with

```
python main.py -c syn_data.csv -m GCN --mode parallel -p results -s smiles -t logP,logD
```

Once the experiment is completed, `results/model.pth` is the trained model checkpoint
and `results/results.txt` is the evaluation result.
108 changes: 108 additions & 0 deletions examples/property_prediction/MTL/configure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
GCN_parallel = {
'gnn_hidden_feats': 64,
'num_gnn_layers': 2,
'regressor_hidden_feats': 64,
'lr': 1e-3,
'weight_decay': 0.,
'dropout': 0.,
'patience': 50,
'batch_size': 128
}

GCN_bypass = {
'gnn_hidden_feats': 128,
'num_gnn_layers': 2,
'regressor_hidden_feats': 32,
'lr': 1e-3,
'weight_decay': 0.,
'dropout': 0.,
'patience': 30,
'batch_size': 128
}

GAT_parallel = {
'gnn_hidden_feats': 32,
'num_gnn_layers': 2,
'num_heads': 6,
'regressor_hidden_feats': 32,
'lr': 3e-3,
'weight_decay': 3e-5,
'dropout': 0.01,
'patience': 100,
'batch_size': 128
}

GAT_bypass = {
'gnn_hidden_feats': 32,
'num_gnn_layers': 3,
'num_heads': 8,
'regressor_hidden_feats': 32,
'lr': 1e-3,
'weight_decay': 3e-6,
'dropout': 0.1,
'patience': 30,
'batch_size': 128
}

MPNN_parallel = {
'node_hidden_dim': 64,
'edge_hidden_dim': 16,
'num_step_message_passing': 2,
'num_step_set2set': 3,
'num_layer_set2set': 2,
'regressor_hidden_feats': 32,
'lr': 1e-3,
'weight_decay': 0.,
'dropout': 0.,
'patience': 50,
'batch_size': 128
}

MPNN_bypass = {
'node_hidden_dim': 32,
'edge_hidden_dim': 64,
'num_step_message_passing': 2,
'num_step_set2set': 2,
'num_layer_set2set': 2,
'regressor_hidden_feats': 32,
'lr': 1e-3,
'weight_decay': 0.,
'dropout': 0.01,
'patience': 50,
'batch_size': 128
}

AttentiveFP_parallel = {
'num_gnn_layers': 3,
'gnn_out_feats': 64,
'num_timesteps': 3,
'regressor_hidden_feats': 32,
'lr': 1e-3,
'weight_decay': 0.,
'dropout': 0.,
'patience': 50,
'batch_size': 32
}

AttentiveFP_bypass = {
'num_gnn_layers': 2,
'gnn_out_feats': 32,
'num_timesteps': 2,
'regressor_hidden_feats': 32,
'lr': 1e-3,
'weight_decay': 0.,
'dropout': 0.,
'patience': 50,
'batch_size': 32
}

configs = {
'GCN_parallel': GCN_parallel,
'GCN_bypass': GCN_bypass,
'GAT_parallel': GAT_parallel,
'GAT_bypass': GAT_bypass,
'MPNN_parallel': MPNN_parallel,
'MPNN_bypass': MPNN_bypass,
'AttentiveFP_parallel': AttentiveFP_parallel,
'AttentiveFP_bypass': AttentiveFP_bypass
}
87 changes: 87 additions & 0 deletions examples/property_prediction/MTL/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from dgllife.utils import atom_type_one_hot, atom_degree_one_hot, \
atom_hybridization_one_hot, atom_is_aromatic_one_hot, \
atom_chiral_tag_one_hot, atom_formal_charge_one_hot, atom_mass, \
atom_implicit_valence_one_hot, BaseAtomFeaturizer, \
ConcatFeaturizer, CanonicalBondFeaturizer
from functools import partial
from rdkit import Chem

atom_featurizer = BaseAtomFeaturizer(
featurizer_funcs={
'hv': ConcatFeaturizer(
[partial(atom_degree_one_hot, allowable_set=[1, 2, 3, 4, 6]),
partial(atom_type_one_hot, allowable_set=[
'B', 'Br', 'C', 'Cl', 'F', 'H', 'I', 'N', 'O', 'P', 'S', 'Se', 'Si']),
atom_chiral_tag_one_hot,
partial(atom_formal_charge_one_hot, allowable_set=[-1, 0, 1]),
partial(atom_hybridization_one_hot, allowable_set=[
Chem.rdchem.HybridizationType.S,
Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3,
Chem.rdchem.HybridizationType.SP3D2
]),
partial(atom_implicit_valence_one_hot, allowable_set=list(range(4))),
atom_is_aromatic_one_hot, atom_mass,
])}
)

if __name__ == '__main__':
import pandas as pd

from argparse import ArgumentParser
from dgllife.data import MoleculeCSVDataset
from dgllife.utils import smiles_to_bigraph, RandomSplitter

from configure import configs
from run import main
from utils import mkdir_p, setup

parser = ArgumentParser('(Multitask) Molecular Property Prediction with GNNs for a user-specified .csv file.')
parser.add_argument('-c', '--csv-path', type=str, required=True,
help='Path to a csv file for loading a dataset.')
parser.add_argument('-m', '--model', type=str,
choices=['GCN', 'GAT', 'MPNN', 'AttentiveFP'],
help='Model to use')
parser.add_argument('--mode', type=str, choices=['parallel', 'bypass'],
help='Architecture to use for multitask learning')
parser.add_argument('-n', '--num-epochs', type=int, default=4000,
help='Maximum number of epochs allowed for training. '
'We set a large number by default as early stopping will be performed.')
parser.add_argument('-p', '--result-path', type=str, required=True,
help='Path to training results')
parser.add_argument('-s', '--smiles-column', type=str, default='smiles',
help='CSV column header for the SMIELS strings. (default: smiles)')
parser.add_argument('-t', '--tasks', default=None, type=str,
help='CSV column headers for the tasks to model. For multiple tasks, separate them by '
'comma, e.g., task1,task2,task3, ... If None, we will model '
'all the columns except for the smiles_column in the CSV file. '
'(default: None)')
args = parser.parse_args().__dict__

args['exp_name'] = '_'.join([args['model'], args['mode']])
if args['tasks'] is not None:
args['tasks'] = args['tasks'].split(',')
args.update(configs[args['exp_name']])

# Setup for experiments
mkdir_p(args['result_path'])

node_featurizer = atom_featurizer
edge_featurizer = CanonicalBondFeaturizer(bond_data_field='he', self_loop=True)
df = pd.read_csv(args['csv_path'])
dataset = MoleculeCSVDataset(
df, partial(smiles_to_bigraph, add_self_loop=True),
node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer,
smiles_column=args['smiles_column'],
cache_file_path=args['result_path'] + '/graph.bin',
task_names=args['tasks']
)
args['tasks'] = dataset.task_names
args = setup(args)
train_set, val_set, test_set = RandomSplitter.train_val_test_split(
dataset, frac_train=0.8, frac_val=0.1,
frac_test=0.1, random_state=0)

main(args, node_featurizer, edge_featurizer, train_set, val_set, test_set)
4 changes: 4 additions & 0 deletions examples/property_prediction/MTL/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .gcn import GCNRegressor, GCNRegressorBypass
from .gat import GATRegressor, GATRegressorBypass
from .mpnn import MPNNRegressor, MPNNRegressorBypass
from .attentivefp import AttentiveFPRegressor, AttentiveFPRegressorBypass
74 changes: 74 additions & 0 deletions examples/property_prediction/MTL/model/attentivefp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch.nn as nn

from dgllife.model import AttentiveFPGNN, AttentiveFPReadout

from .regressor import BaseGNNRegressor, BaseGNNRegressorBypass

class AttentiveFPRegressor(BaseGNNRegressor):
"""AttentiveFP-based model for multitask molecular property prediction.
We assume all tasks are regression problems.
Parameters
----------
in_node_feats : int
Number of input node features
in_edge_feats : int
Number of input edge features
gnn_out_feats : int
The GNN output size
num_layers : int
Number of GNN layers
num_timesteps : int
Number of timesteps for updating molecular representations with GRU during readout
n_tasks : int
Number of prediction tasks
regressor_hidden_feats : int
Hidden size in MLP regressor
dropout : float
The probability for dropout. Default to 0, i.e. no dropout is performed.
"""
def __init__(self, in_node_feats, in_edge_feats, gnn_out_feats, num_layers, num_timesteps,
n_tasks, regressor_hidden_feats=128, dropout=0.):
super(AttentiveFPRegressor, self).__init__(readout_feats=gnn_out_feats,
n_tasks=n_tasks,
regressor_hidden_feats=regressor_hidden_feats,
dropout=dropout)
self.gnn = AttentiveFPGNN(in_node_feats, in_edge_feats, num_layers,
gnn_out_feats, dropout)
self.readout = AttentiveFPReadout(gnn_out_feats, num_timesteps, dropout)

class AttentiveFPRegressorBypass(BaseGNNRegressorBypass):
"""AttentiveFP-based model for bypass multitask molecular property prediction.
We assume all tasks are regression problems.
Parameters
----------
in_node_feats : int
Number of input node features
in_edge_feats : int
Number of input edge features
gnn_out_feats : int
The GNN output size
num_layers : int
Number of GNN layers
num_timesteps : int
Number of timesteps for updating molecular representations with GRU during readout
n_tasks : int
Number of prediction tasks
regressor_hidden_feats : int
Hidden size in MLP regressor
dropout : float
The probability for dropout. Default to 0, i.e. no dropout is performed.
"""
def __init__(self, in_node_feats, in_edge_feats, gnn_out_feats, num_layers, num_timesteps,
n_tasks, regressor_hidden_feats=128, dropout=0.):
super(AttentiveFPRegressorBypass, self).__init__(
readout_feats= 2 * gnn_out_feats, n_tasks=n_tasks,
regressor_hidden_feats=regressor_hidden_feats,
dropout=dropout)
self.shared_gnn = AttentiveFPGNN(in_node_feats, in_edge_feats, num_layers,
gnn_out_feats, dropout)
for _ in range(n_tasks):
self.task_gnns.append(AttentiveFPGNN(in_node_feats, in_edge_feats, num_layers,
gnn_out_feats, dropout))
self.readouts.append(AttentiveFPReadout(2 * gnn_out_feats, num_timesteps, dropout))
Loading

0 comments on commit d7522cc

Please sign in to comment.