As you know, PyG is one of the most useful packages for graph-based neural networks, just like DGL-LifeSci. Fortunately, recent versions of PyG are easy to install because they support conda, so users no longer need to install related packages such as pytorch_scatter, pytorch_cluster, etc. by hand. PyG also provides many predefined models, which are listed in the original documentation. AttentiveFP, a model for molecular representation learning, is one of them. I wrote a post about AttentiveFP with DGL before, so today I tried the PyG implementation of AttentiveFP. An AttentiveFP example is provided in the original repository. However, the example uses the torch_geometric.datasets.MoleculeNet class for data preparation, so the available datasets are limited to those from MoleculeNet. I would like to use a local dataset with the model, so I modified the original code and tried it. The code is below. Following the example, I used the ESOL data from MoleculeNet, but I downloaded the file before running the code. I defined a Molecule class which loads a local CSV file and processes it. The difference from the original MoleculeNet class is that this class doesn't download data from the web; the data comes from a local file.
# Most of this code comes from the original PyG repo:
# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/attentive_fp.py
import csv
import os.path as osp
import re
from math import sqrt

import torch
import torch.nn.functional as F
from rdkit import Chem
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import AttentiveFP


class GenFeatures(object):
    """Transform that (re)builds AttentiveFP features for a PyG ``Data`` object.

    The molecule is re-parsed from ``data.smiles``. Then ``data.x`` is
    overwritten with one float vector per atom (``num_atom_features`` long,
    39 with the default lists). ``data.edge_index`` and ``data.edge_attr`` are
    overwritten with both directions of every bond, each bond having 10
    float features.
    """

    # Sizes of the bounded one-hot blocks. Larger values are clamped into the
    # last bin instead of raising IndexError.
    NUM_DEGREES = 6
    NUM_HYDROGENS = 5

    def __init__(self):
        self.symbols = [
            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br',
            'Te', 'I', 'At', 'other'
        ]

        self.hybridizations = [
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            'other',
        ]

        # NOTE(review): other bond stereo values (e.g. STEREOCIS/STEREOTRANS)
        # are not listed here and will raise ValueError in __call__.
        self.stereos = [
            Chem.rdchem.BondStereo.STEREONONE,
            Chem.rdchem.BondStereo.STEREOANY,
            Chem.rdchem.BondStereo.STEREOZ,
            Chem.rdchem.BondStereo.STEREOE,
        ]

    @property
    def num_atom_features(self):
        """Length of each atom feature vector (39 with the default lists)."""
        return (len(self.symbols) + self.NUM_DEGREES + 2 +
                len(self.hybridizations) + 1 + self.NUM_HYDROGENS + 1 + 2)

    @staticmethod
    def _one_hot(value, choices):
        """One-hot encode ``value`` over ``choices``.

        Values not present in ``choices`` go to the last slot, which is the
        ``'other'`` entry of the lists defined in ``__init__``.
        """
        vec = [0.] * len(choices)
        try:
            vec[choices.index(value)] = 1.
        except ValueError:
            vec[-1] = 1.
        return vec

    def __call__(self, data):
        """Replace the node/edge tensors of ``data`` and return it.

        Raises:
            ValueError: if ``data.smiles`` cannot be parsed by RDKit.
        """
        # Generate AttentiveFP features according to Table 1.
        mol = Chem.MolFromSmiles(data.smiles)
        if mol is None:
            raise ValueError(f'Invalid SMILES: {data.smiles!r}')

        xs = []
        for atom in mol.GetAtoms():
            symbol = self._one_hot(atom.GetSymbol(), self.symbols)
            degree = [0.] * self.NUM_DEGREES
            degree[min(atom.GetDegree(), self.NUM_DEGREES - 1)] = 1.
            formal_charge = atom.GetFormalCharge()
            radical_electrons = atom.GetNumRadicalElectrons()
            hybridization = self._one_hot(atom.GetHybridization(),
                                          self.hybridizations)
            aromaticity = 1. if atom.GetIsAromatic() else 0.
            hydrogens = [0.] * self.NUM_HYDROGENS
            hydrogens[min(atom.GetTotalNumHs(), self.NUM_HYDROGENS - 1)] = 1.
            chirality = 1. if atom.HasProp('_ChiralityPossible') else 0.
            chirality_type = [0.] * 2
            if atom.HasProp('_CIPCode'):
                cip = atom.GetProp('_CIPCode')
                # Only R/S have a slot; any other label is left all-zero
                # instead of raising.
                if cip in ('R', 'S'):
                    chirality_type[['R', 'S'].index(cip)] = 1.

            x = torch.tensor(symbol + degree + [formal_charge] +
                             [radical_electrons] + hybridization +
                             [aromaticity] + hydrogens + [chirality] +
                             chirality_type, dtype=torch.float)
            xs.append(x)

        if xs:
            data.x = torch.stack(xs, dim=0)
        else:
            # A molecule without atoms: torch.stack cannot handle an empty list.
            data.x = torch.zeros((0, self.num_atom_features), dtype=torch.float)

        edge_indices = []
        edge_attrs = []
        for bond in mol.GetBonds():
            # Store each bond in both directions (undirected graph).
            edge_indices += [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]]
            edge_indices += [[bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]]

            bond_type = bond.GetBondType()
            single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
            double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
            triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
            aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
            conjugation = 1. if bond.GetIsConjugated() else 0.
            ring = 1. if bond.IsInRing() else 0.
            stereo = [0.] * 4
            stereo[self.stereos.index(bond.GetStereo())] = 1.

            edge_attr = torch.tensor(
                [single, double, triple, aromatic, conjugation, ring] + stereo)

            edge_attrs += [edge_attr, edge_attr]

        if len(edge_attrs) == 0:
            data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            data.edge_attr = torch.zeros((0, 10), dtype=torch.float)
        else:
            data.edge_index = torch.tensor(edge_indices).t().contiguous()
            data.edge_attr = torch.stack(edge_attrs, dim=0)

        return data


# OGB-style integer vocabularies used by Molecule.process. Each atom/bond
# property is stored as its index in the corresponding list.
x_map = {
    'atomic_num':
    list(range(0, 119)),
    'chirality': [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    'degree':
    list(range(0, 11)),
    'formal_charge':
    list(range(-5, 7)),
    'num_hs':
    list(range(0, 9)),
    'num_radical_electrons':
    list(range(0, 5)),
    'hybridization': [
        'UNSPECIFIED',
        'S',
        'SP',
        'SP2',
        'SP3',
        'SP3D',
        'SP3D2',
        'OTHER',
    ],
    'is_aromatic': [False, True],
    'is_in_ring': [False, True],
}

e_map = {
    'bond_type': [
        'misc',
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'AROMATIC',
    ],
    'stereo': [
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    'is_conjugated': [False, True],
}


class Molecule(InMemoryDataset):
    r"""In-memory PyG dataset built from a local CSV file of molecules.

    Adapted from :class:`torch_geometric.datasets.MoleculeNet` (`"MoleculeNet:
    A Benchmark for Molecular Machine Learning"
    <https://arxiv.org/abs/1703.00564>`_). Instead of downloading a benchmark,
    it reads a CSV file that already exists on disk. Node and edge features
    are the integer encodings defined in ``x_map`` / ``e_map``, in the style
    of the `Open Graph Benchmark <https://ogb.stanford.edu/docs/graphprop/>`_.
    Pass a ``pre_transform`` such as :class:`GenFeatures` to replace them.

    The processed result is cached in ``<root_dir>/processed/data.pt``. PyG
    skips :meth:`process` when that file already exists, so delete it after
    changing the CSV file or the transforms.

    Args:
        root_dir (string): Root directory.
        name (string): CSV file name. It is resolved inside
            ``<root_dir>/raw``; an absolute path is used as-is.
        smi_idx (integer): Index of the SMILES column.
        target_idx (integer or slice): Index (or slice) of the target
            column(s). Empty target cells become ``NaN``.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
    """

    def __init__(self, root_dir, name, smi_idx, target_idx, transform=None,
                 pre_transform=None, pre_filter=None):
        self.root_dir = root_dir
        self.name = name
        self.smi_idx = smi_idx
        self.target_idx = target_idx
        # root=None: raw_dir/processed_dir below are derived from root_dir.
        super(Molecule, self).__init__(None, transform, pre_transform,
                                       pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root_dir, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root_dir, 'processed')

    @property
    def raw_file_names(self):
        return f'{self.name}'

    @property
    def processed_file_names(self):
        return 'data.pt'

    def _csv_path(self):
        """Locate the CSV file.

        Prefer ``raw_dir/name``; if that does not exist, fall back to ``name``
        as given (relative to the CWD), which is what earlier versions opened.
        """
        path = self.raw_paths[0]
        return path if osp.exists(path) else self.raw_file_names

    def process(self):
        """Parse the CSV file into graphs and save the collated dataset."""
        data_list = []
        with open(self._csv_path(), 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            # csv handles quoted fields (which may contain commas) correctly,
            # and it does not drop a final row that lacks a trailing newline.
            for row in reader:
                if not row:
                    continue  # skip empty lines

                smiles = row[self.smi_idx]
                ys = row[self.target_idx]
                ys = ys if isinstance(ys, list) else [ys]
                ys = [float(y) if len(y) > 0 else float('NaN') for y in ys]
                y = torch.tensor(ys, dtype=torch.float).view(1, -1)

                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    continue  # unparsable SMILES are skipped

                xs = []
                for atom in mol.GetAtoms():
                    x = []
                    x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
                    x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
                    x.append(x_map['degree'].index(atom.GetTotalDegree()))
                    x.append(x_map['formal_charge'].index(
                        atom.GetFormalCharge()))
                    x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
                    x.append(x_map['num_radical_electrons'].index(
                        atom.GetNumRadicalElectrons()))
                    x.append(x_map['hybridization'].index(
                        str(atom.GetHybridization())))
                    x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
                    x.append(x_map['is_in_ring'].index(atom.IsInRing()))
                    xs.append(x)

                x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

                edge_indices, edge_attrs = [], []
                for bond in mol.GetBonds():
                    i = bond.GetBeginAtomIdx()
                    j = bond.GetEndAtomIdx()

                    e = []
                    e.append(e_map['bond_type'].index(str(bond.GetBondType())))
                    e.append(e_map['stereo'].index(str(bond.GetStereo())))
                    e.append(e_map['is_conjugated'].index(
                        bond.GetIsConjugated()))

                    edge_indices += [[i, j], [j, i]]
                    edge_attrs += [e, e]

                edge_index = torch.tensor(edge_indices)
                edge_index = edge_index.t().to(torch.long).view(2, -1)
                edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

                # Sort edges by (source, target).
                if edge_index.numel() > 0:
                    perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
                    edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                            y=y, smiles=smiles)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])

    def __repr__(self):
        # The original referenced self.names, which this class never defines.
        return f'{self.__class__.__name__}({len(self)})'


# The definition of AttentiveFP is not required because PyG already provides
# the model; just import it and use it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# in_channels / edge_dim must match the GenFeatures sizes (39 atom, 10 bond).
model = AttentiveFP(in_channels=39, hidden_channels=200, out_channels=1,
                    edge_dim=10, num_layers=2, num_timesteps=2,
                    dropout=0.2).to(device)
print(model)
# Output:
# AttentiveFP(
#   (lin1): Linear(in_features=39, out_features=200, bias=True)
#   (atom_convs): ModuleList(
#     (0): GATEConv(
#       (lin1): Linear(in_features=210, out_features=200, bias=False)
#       (lin2): Linear(in_features=200, out_features=200, bias=False)
#     )
#     (1): GATConv(200, 200, heads=1)
#   )
#   (atom_grus): ModuleList(
#     (0): GRUCell(200, 200)
#     (1): GRUCell(200, 200)
#   )
#   (mol_conv): GATConv(200, 200, heads=1)
#   (mol_gru): GRUCell(200, 200)
#   (lin2): Linear(in_features=200, out_features=1, bias=True)
# )

# Then load the data and split it into train/validation/test sets.
dataset = Molecule(
    root_dir='/home/iwatobipen/dev/data/AFP_Mol/esol/testf',
    name='/home/iwatobipen/dev/data/AFP_Mol/esol/testf/delaney-processed.csv',
    smi_idx=-1, target_idx=-2,
    pre_transform=GenFeatures()).shuffle()

# 10% validation, 10% test, 80% training.
N = len(dataset) // 10
val_dataset = dataset[:N]
test_dataset = dataset[N:2 * N]
train_dataset = dataset[2 * N:]

train_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=200)
test_loader = DataLoader(test_dataset, batch_size=200)

# Now we are almost there; let's train the model.
optimizer = torch.optim.Adam(model.parameters(), lr=10**-2.5,
                             weight_decay=10**-5)


def train():
    """Run one training epoch and return the training RMSE."""
    model.train()  # re-enable dropout (test() switches it off)
    total_loss = total_examples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        loss = F.mse_loss(out, data.y)
        loss.backward()
        optimizer.step()
        # Weight each batch's mean loss by its size so the epoch RMSE is exact.
        total_loss += float(loss) * data.num_graphs
        total_examples += data.num_graphs
    return sqrt(total_loss / total_examples)


@torch.no_grad()
def test(loader):
    """Return the RMSE of ``model`` over every graph in ``loader``."""
    model.eval()  # disable dropout so evaluation is deterministic
    mse = []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        mse.append(F.mse_loss(out, data.y, reduction='none').cpu())
    return float(torch.cat(mse, dim=0).mean().sqrt())


for epoch in range(1, 20):
    train_rmse = train()
    val_rmse = test(val_loader)
    test_rmse = test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '
          f'Test: {test_rmse:.4f}')
# Output (recorded with the original code, whose test() ran with dropout on):
# Epoch: 001, Loss: 3.4306 Val: 2.7534 Test: 2.4873
# Epoch: 002, Loss: 2.4330 Val: 2.2135 Test: 1.9825
# Epoch: 003, Loss: 1.7889 Val: 2.0010 Test: 2.0716
# Epoch: 004, Loss: 1.7881 Val: 1.8208 Test: 1.8639
# Epoch: 005, Loss: 1.7114 Val: 1.7611 Test: 1.7465
# Epoch: 006, Loss: 1.6452 Val: 1.7461 Test: 1.7135
# Epoch: 007, Loss: 1.5740 Val: 1.6480 Test: 1.6811
# Epoch: 008, Loss: 1.5113 Val: 1.4189 Test: 1.4966
# Epoch: 009, Loss: 1.3412 Val: 1.2268 Test: 1.3886
# Epoch: 010, Loss: 1.2381 Val: 1.1057 Test: 1.2172
# Epoch: 011, Loss: 1.1652 Val: 1.0242 Test: 1.0864
# Epoch: 012, Loss: 1.1103 Val: 1.0130 Test: 1.0764
# Epoch: 013, Loss: 1.0821 Val: 0.9757 Test: 0.9694
# Epoch: 014, Loss: 1.0460 Val: 1.0927 Test: 1.0020
# Epoch: 015, Loss: 1.0466 Val: 1.0505 Test: 0.9732
# Epoch: 016, Loss: 1.0152 Val: 0.9447 Test: 0.9786
# Epoch: 017, Loss: 0.9574 Val: 0.8512 Test: 0.9221
# Epoch: 018, Loss: 0.9336 Val: 0.7625 Test: 0.8277
# Epoch: 019, Loss: 0.9327 Val: 0.8616 Test: 0.9142

# Now that training is finished, let's use the model for prediction. To do so,
# I defined a process_mol function which converts SMILES into graph data.


def process_mol(smiles_list, pre_transform, pre_filter=None):
    """Convert SMILES strings into a list of PyG ``Data`` graphs.

    Each entry may be a plain SMILES string or a comma-separated line whose
    first field is the SMILES. Entries RDKit cannot parse are skipped. The
    graphs carry the integer ``x_map`` / ``e_map`` encodings unless
    ``pre_transform`` (e.g. ``GenFeatures()``) replaces them.

    Args:
        smiles_list: iterable of SMILES strings (or CSV-style lines).
        pre_transform: optional callable applied to each ``Data``.
        pre_filter: optional predicate; graphs for which it returns False
            are dropped.

    Returns:
        list of ``Data`` objects (without ``y``).
    """
    from rdkit import Chem
    data_list = []
    for smi in smiles_list:
        smi = re.sub(r'\".*\"', '', smi)  # Replace ".*" strings.
        smiles = smi.split(',')[0]
        # Parse the extracted SMILES field (the original parsed the whole
        # line), so the graph agrees with data.smiles used by pre_transform.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue

        xs = []
        for atom in mol.GetAtoms():
            x = []
            x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
            x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
            x.append(x_map['degree'].index(atom.GetTotalDegree()))
            x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
            x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
            x.append(x_map['num_radical_electrons'].index(
                atom.GetNumRadicalElectrons()))
            x.append(x_map['hybridization'].index(
                str(atom.GetHybridization())))
            x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
            x.append(x_map['is_in_ring'].index(atom.IsInRing()))
            xs.append(x)

        x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

        edge_indices, edge_attrs = [], []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()

            e = []
            e.append(e_map['bond_type'].index(str(bond.GetBondType())))
            e.append(e_map['stereo'].index(str(bond.GetStereo())))
            e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

            edge_indices += [[i, j], [j, i]]
            edge_attrs += [e, e]

        edge_index = torch.tensor(edge_indices)
        edge_index = edge_index.t().to(torch.long).view(2, -1)
        edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

        # Sort edges by (source, target).
        if edge_index.numel() > 0:
            perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
            edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=None,
                    smiles=smiles)

        if pre_filter is not None and not pre_filter(data):
            continue

        if pre_transform is not None:
            data = pre_transform(data)

        data_list.append(data)

    return data_list


# Let's build a dataset and predict the molecules' properties.
dataset = process_mol(['CCCC', 'OCCO', 'c1ccccc1'], pre_transform=GenFeatures())
model.eval()  # disable dropout so predictions are deterministic
with torch.no_grad():
    for data in dataset:
        data = data.to(device)  # the model may live on the GPU
        # PyG's ``batch`` vector maps every node to its graph, so it needs one
        # entry per atom (all 0 for a single molecule), not just one element.
        batch = torch.zeros(data.num_nodes, dtype=torch.long, device=device)
        print(model(data.x, data.edge_index, data.edge_attr, batch))
# Output (recorded with the original code, which ran the model in training
# mode with a one-element batch vector; values from this version will differ):
# tensor([[-1.6434]], grad_fn=<AddmmBackward>)
# tensor([[0.0233]], grad_fn=<AddmmBackward>)
# tensor([[-1.5886]], grad_fn=<AddmmBackward>)

# The model predicts that ethylene glycol is soluble and benzene is not, which
# seems reasonable.
# In summary, recent versions of PyG seem more cheminformatics-friendly, but
# there is still no native function to read molecules and convert them into
# graph objects. I would like to write helper functions for using PyG in
# cheminformatics.
|
No comments:
Post a Comment
Note: Only a member of this blog may post a comment.