Diff of /src/data/dataset.py [000000] .. [7d53f6]

Switch to unified view

a b/src/data/dataset.py
1
import os
2
import os.path as osp
3
import re
4
import pickle
5
6
import numpy as np
7
import pandas as pd
8
from tqdm import tqdm
9
10
import torch
11
from torch_geometric.data import Data, InMemoryDataset
12
13
from rdkit import Chem, RDLogger
14
15
from src.data.utils import label2onehot
16
17
RDLogger.DisableLog('rdApp.*') 
18
19
20
class DruggenDataset(InMemoryDataset):
21
    def __init__(self, root, dataset_file, raw_files, max_atom, features, 
22
                 atom_encoder, atom_decoder, bond_encoder, bond_decoder,
23
                 transform=None, pre_transform=None, pre_filter=None):
24
        """
25
        Initialize the DruggenDataset with pre-loaded encoder/decoder dictionaries.
26
27
        Parameters:
28
            root (str): Root directory.
29
            dataset_file (str): Name of the processed dataset file.
30
            raw_files (str): Path to the raw SMILES file.
31
            max_atom (int): Maximum number of atoms allowed in a molecule.
32
            features (bool): Whether to include additional node features.
33
            atom_encoder (dict): Pre-loaded atom encoder dictionary.
34
            atom_decoder (dict): Pre-loaded atom decoder dictionary.
35
            bond_encoder (dict): Pre-loaded bond encoder dictionary.
36
            bond_decoder (dict): Pre-loaded bond decoder dictionary.
37
            transform, pre_transform, pre_filter: See PyG InMemoryDataset.
38
        """
39
        self.dataset_name = dataset_file.split(".")[0]
40
        self.dataset_file = dataset_file
41
        self.raw_files = raw_files
42
        self.max_atom = max_atom
43
        self.features = features
44
45
        # Use the provided encoder/decoder mappings.
46
        self.atom_encoder_m = atom_encoder
47
        self.atom_decoder_m = atom_decoder
48
        self.bond_encoder_m = bond_encoder
49
        self.bond_decoder_m = bond_decoder
50
51
        self.atom_num_types = len(atom_encoder)
52
        self.bond_num_types = len(bond_encoder)
53
54
        super().__init__(root, transform, pre_transform, pre_filter)
55
        path = osp.join(self.processed_dir, dataset_file)
56
        self.data, self.slices = torch.load(path)
57
        self.root = root
58
59
    @property
60
    def processed_dir(self):
61
        """
62
        Returns the directory where processed dataset files are stored.
63
        """
64
        return self.root
65
    
66
    @property
67
    def raw_file_names(self):
68
        """
69
        Returns the raw SMILES file name.
70
        """
71
        return self.raw_files
72
73
    @property
74
    def processed_file_names(self):
75
        """
76
        Returns the name of the processed dataset file.
77
        """
78
        return self.dataset_file
79
80
    def _filter_smiles(self, smiles_list):
81
        """
82
        Filters the input list of SMILES strings to keep only valid molecules that:
83
         - Can be successfully parsed,
84
         - Have a number of atoms less than or equal to the maximum allowed (max_atom),
85
         - Contain only atoms present in the atom_encoder,
86
         - Contain only bonds present in the bond_encoder.
87
88
        Parameters:
89
            smiles_list (list): List of SMILES strings.
90
91
        Returns:
92
            max_length (int): Maximum number of atoms found in the filtered molecules.
93
            filtered_smiles (list): List of valid SMILES strings.
94
        """
95
        max_length = 0
96
        filtered_smiles = []
97
        for smiles in tqdm(smiles_list, desc="Filtering SMILES"):
98
            mol = Chem.MolFromSmiles(smiles)
99
            if mol is None:
100
                continue
101
102
            # Check molecule size
103
            molecule_size = mol.GetNumAtoms()
104
            if molecule_size > self.max_atom:
105
                continue
106
107
            # Filter out molecules with atoms not in the atom_encoder
108
            if not all(atom.GetAtomicNum() in self.atom_encoder_m for atom in mol.GetAtoms()):
109
                continue
110
111
            # Filter out molecules with bonds not in the bond_encoder
112
            if not all(bond.GetBondType() in self.bond_encoder_m for bond in mol.GetBonds()):
113
                continue
114
115
            filtered_smiles.append(smiles)
116
            max_length = max(max_length, molecule_size)
117
        return max_length, filtered_smiles
118
119
    def _genA(self, mol, connected=True, max_length=None):
120
        """
121
        Generates the adjacency matrix for a molecule based on its bond structure.
122
123
        Parameters:
124
            mol (rdkit.Chem.Mol): The molecule.
125
            connected (bool): If True, ensures all atoms are connected.
126
            max_length (int, optional): The size of the matrix; if None, uses number of atoms in mol.
127
128
        Returns:
129
            np.array: Adjacency matrix with bond types as entries, or None if disconnected.
130
        """
131
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
132
        A = np.zeros((max_length, max_length))
133
        begin = [b.GetBeginAtomIdx() for b in mol.GetBonds()]
134
        end = [b.GetEndAtomIdx() for b in mol.GetBonds()]
135
        bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]
136
        A[begin, end] = bond_type
137
        A[end, begin] = bond_type
138
        degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)
139
        return A if connected and (degree > 0).all() else None
140
141
    def _genX(self, mol, max_length=None):
142
        """
143
        Generates the feature vector for each atom in a molecule by encoding their atomic numbers.
144
145
        Parameters:
146
            mol (rdkit.Chem.Mol): The molecule.
147
            max_length (int, optional): Length of the feature vector; if None, uses number of atoms in mol.
148
149
        Returns:
150
            np.array: Array of atom feature indices, padded with zeros if necessary, or None on error.
151
        """
152
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
153
        try:
154
            return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] +
155
                            [0] * (max_length - mol.GetNumAtoms()))
156
        except KeyError as e:
157
            print(f"Skipping molecule with unsupported atom: {e}")
158
            print(f"Skipped SMILES: {Chem.MolToSmiles(mol)}")
159
            return None
160
161
    def _genF(self, mol, max_length=None):
162
        """
163
        Generates additional node features for a molecule using various atomic properties.
164
165
        Parameters:
166
            mol (rdkit.Chem.Mol): The molecule.
167
            max_length (int, optional): Number of rows in the features matrix; if None, uses number of atoms.
168
169
        Returns:
170
            np.array: Array of additional features for each atom, padded with zeros if necessary.
171
        """
172
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
173
        features = np.array([[*[a.GetDegree() == i for i in range(5)],
174
                               *[a.GetExplicitValence() == i for i in range(9)],
175
                               *[int(a.GetHybridization()) == i for i in range(1, 7)],
176
                               *[a.GetImplicitValence() == i for i in range(9)],
177
                               a.GetIsAromatic(),
178
                               a.GetNoImplicit(),
179
                               *[a.GetNumExplicitHs() == i for i in range(5)],
180
                               *[a.GetNumImplicitHs() == i for i in range(5)],
181
                               *[a.GetNumRadicalElectrons() == i for i in range(5)],
182
                               a.IsInRing(),
183
                               *[a.IsInRingSize(i) for i in range(2, 9)]]
184
                              for a in mol.GetAtoms()], dtype=np.int32)
185
        return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))
186
187
    def decoder_load(self, dictionary_name, file):
188
        """
189
        Returns the pre-loaded decoder dictionary based on the dictionary name.
190
191
        Parameters:
192
            dictionary_name (str): Name of the dictionary ("atom" or "bond").
193
            file: Placeholder parameter for compatibility.
194
195
        Returns:
196
            dict: The corresponding decoder dictionary.
197
        """
198
        if dictionary_name == "atom":
199
            return self.atom_decoder_m
200
        elif dictionary_name == "bond":
201
            return self.bond_decoder_m
202
        else:
203
            raise ValueError("Unknown dictionary name.")
204
205
    def matrices2mol(self, node_labels, edge_labels, strict=True, file_name=None):
206
        """
207
        Converts graph representations (node labels and edge labels) back to an RDKit molecule.
208
209
        Parameters:
210
            node_labels (iterable): Encoded atom labels.
211
            edge_labels (np.array): Adjacency matrix with encoded bond types.
212
            strict (bool): If True, sanitizes the molecule and returns None on failure.
213
            file_name: Placeholder parameter for compatibility.
214
215
        Returns:
216
            rdkit.Chem.Mol: The resulting molecule, or None if sanitization fails.
217
        """
218
        mol = Chem.RWMol()
219
        for node_label in node_labels:
220
            mol.AddAtom(Chem.Atom(self.atom_decoder_m[node_label]))
221
        for start, end in zip(*np.nonzero(edge_labels)):
222
            if start > end:
223
                mol.AddBond(int(start), int(end), self.bond_decoder_m[edge_labels[start, end]])
224
        if strict:
225
            try:
226
                Chem.SanitizeMol(mol)
227
            except Exception:
228
                mol = None
229
        return mol
230
231
    def check_valency(self, mol):
232
        """
233
        Checks that no atom in the molecule has exceeded its allowed valency.
234
235
        Parameters:
236
            mol (rdkit.Chem.Mol): The molecule.
237
238
        Returns:
239
            tuple: (True, None) if valid; (False, atomid_valence) if there is a valency issue.
240
        """
241
        try:
242
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
243
            return True, None
244
        except ValueError as e:
245
            e = str(e)
246
            p = e.find('#')
247
            e_sub = e[p:]
248
            atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
249
            return False, atomid_valence
250
251
    def correct_mol(self, mol):
252
        """
253
        Corrects a molecule by removing bonds until all atoms satisfy their valency limits.
254
255
        Parameters:
256
            mol (rdkit.Chem.Mol): The molecule.
257
258
        Returns:
259
            rdkit.Chem.Mol: The corrected molecule.
260
        """
261
        while True:
262
            flag, atomid_valence = self.check_valency(mol)
263
            if flag:
264
                break
265
            else:
266
                # Expecting two numbers: atom index and its valence.
267
                assert len(atomid_valence) == 2
268
                idx = atomid_valence[0]
269
                queue = []
270
                for b in mol.GetAtomWithIdx(idx).GetBonds():
271
                    queue.append((b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
272
                queue.sort(key=lambda tup: tup[1], reverse=True)
273
                if queue:
274
                    start = queue[0][2]
275
                    end = queue[0][3]
276
                    mol.RemoveBond(start, end)
277
        return mol
278
279
280
    def process(self, size=None):
281
        """
282
        Processes the raw SMILES file by filtering and converting each valid SMILES into a PyTorch Geometric Data object.
283
        The resulting dataset is saved to disk.
284
285
        Parameters:
286
            size (optional): Placeholder parameter for compatibility.
287
288
        Side Effects:
289
            Saves the processed dataset as a file in the processed directory.
290
        """
291
        # Read raw SMILES from file (assuming CSV with no header)
292
        smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
293
        max_length, filtered_smiles = self._filter_smiles(smiles_list)
294
        data_list = []
295
        self.m_dim = len(self.atom_decoder_m)
296
        for smiles in tqdm(filtered_smiles, desc='Processing dataset', total=len(filtered_smiles)):
297
            mol = Chem.MolFromSmiles(smiles)
298
            A = self._genA(mol, connected=True, max_length=max_length)
299
            if A is not None:
300
                x_array = self._genX(mol, max_length=max_length)
301
                if x_array is None:
302
                    continue
303
                x = torch.from_numpy(x_array).to(torch.long).view(1, -1)
304
                x = label2onehot(x, self.m_dim).squeeze()
305
                if self.features:
306
                    f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
307
                    x = torch.concat((x, f), dim=-1)
308
                adjacency = torch.from_numpy(A)
309
                edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
310
                edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)
311
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)
312
                if self.pre_filter is not None and not self.pre_filter(data):
313
                    continue
314
                if self.pre_transform is not None:
315
                    data = self.pre_transform(data)
316
                data_list.append(data)
317
        torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))