Diff of /inference.py [000000] .. [7d53f6]

Switch to unified view

a b/inference.py
1
import os
2
import sys
3
import time
4
import random
5
import pickle
6
import argparse
7
import os.path as osp
8
9
import torch
10
import torch.utils.data
11
from torch_geometric.loader import DataLoader
12
13
import pandas as pd
14
from tqdm import tqdm
15
16
from rdkit import RDLogger, Chem
17
from rdkit.Chem import QED, RDConfig
18
19
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
20
import sascorer
21
22
from src.util.utils import *
23
from src.model.models import Generator
24
from src.data.dataset import DruggenDataset
25
from src.data.utils import get_encoders_decoders, load_molecules
26
from src.model.loss import generator_loss
27
from src.util.smiles_cor import smi_correct
28
29
30
class Inference(object):
31
    """Inference class for DrugGEN."""
32
33
    def __init__(self, config):
34
        if config.set_seed:
35
            np.random.seed(config.seed)
36
            random.seed(config.seed)
37
            torch.manual_seed(config.seed)
38
            torch.cuda.manual_seed_all(config.seed)
39
40
            torch.backends.cudnn.deterministic = True
41
            torch.backends.cudnn.benchmark = False
42
43
            os.environ["PYTHONHASHSEED"] = str(config.seed)
44
45
            print(f'Using seed {config.seed}')
46
47
        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
48
49
        # Initialize configurations
50
        self.submodel = config.submodel
51
        self.inference_model = config.inference_model
52
        self.sample_num = config.sample_num
53
        self.disable_correction = config.disable_correction
54
55
        # Data loader.
56
        self.inf_smiles = config.inf_smiles  # SMILES containing text file for first dataset. 
57
                                         # Write the full path to file.
58
        
59
        inf_smiles_basename = osp.basename(self.inf_smiles)
60
        
61
        # Get the base name without extension and add max_atom to it
62
        self.max_atom = config.max_atom  # Model is based on one-shot generation.
63
        inf_smiles_base = os.path.splitext(inf_smiles_basename)[0]
64
        
65
        # Change extension from .smi to .pt and add max_atom to the filename
66
        self.inf_dataset_file = f"{inf_smiles_base}{self.max_atom}.pt"
67
68
        self.inf_batch_size = config.inf_batch_size
69
        self.train_smiles = config.train_smiles
70
        self.train_drug_smiles = config.train_drug_smiles
71
        self.mol_data_dir = config.mol_data_dir  # Directory where the dataset files are stored.
72
        self.dataset_name = self.inf_dataset_file.split(".")[0]
73
        self.features = config.features  # Small model uses atom types as node features. (Boolean, False uses atom types only.)
74
                                         # Additional node features can be added. Please check new_dataloarder.py Line 102.
75
76
        # Get atom and bond encoders/decoders
77
        self.atom_encoder, self.atom_decoder, self.bond_encoder, self.bond_decoder = get_encoders_decoders(
78
            self.train_smiles,
79
            self.train_drug_smiles,
80
            self.max_atom
81
        )
82
83
        self.inf_dataset = DruggenDataset(self.mol_data_dir,
84
                                      self.inf_dataset_file,
85
                                      self.inf_smiles,
86
                                      self.max_atom,
87
                                      self.features,
88
                                      atom_encoder=self.atom_encoder,
89
                                      atom_decoder=self.atom_decoder,
90
                                      bond_encoder=self.bond_encoder,
91
                                      bond_decoder=self.bond_decoder)
92
93
        self.inf_loader = DataLoader(self.inf_dataset,
94
                                 shuffle=True,
95
                                 batch_size=self.inf_batch_size,
96
                                 drop_last=True)  # PyG dataloader for the first GAN.
97
98
        self.m_dim = len(self.atom_decoder) if not self.features else int(self.inf_loader.dataset[0].x.shape[1]) # Atom type dimension.
99
        self.b_dim = len(self.bond_decoder) # Bond type dimension.
100
        self.vertexes = int(self.inf_loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
101
102
        # Model configurations.
103
        self.act = config.act
104
        self.dim = config.dim
105
        self.depth = config.depth
106
        self.heads = config.heads
107
        self.mlp_ratio = config.mlp_ratio
108
        self.dropout = config.dropout
109
110
        self.build_model()
111
112
    def build_model(self):
113
        """Create generators and discriminators."""
114
        self.G = Generator(self.act,
115
                           self.vertexes,
116
                           self.b_dim,
117
                           self.m_dim,
118
                           self.dropout,
119
                           dim=self.dim,
120
                           depth=self.depth,
121
                           heads=self.heads,
122
                           mlp_ratio=self.mlp_ratio)
123
        self.G.to(self.device)
124
        self.print_network(self.G, 'G')
125
126
    def print_network(self, model, name):
127
        """Print out the network information."""
128
        num_params = 0
129
        for p in model.parameters():
130
            num_params += p.numel() 
131
        print(model)
132
        print(name)
133
        print("The number of parameters: {}".format(num_params))
134
135
    def restore_model(self, submodel, model_directory):
136
        """Restore the trained generator and discriminator."""
137
        print('Loading the model...')
138
        G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
139
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
140
141
    def inference(self):
142
        # Load the trained generator.
143
        self.restore_model(self.submodel, self.inference_model)
144
145
        # smiles data for metrics calculation.
146
        chembl_smiles = [line for line in open(self.train_smiles, 'r').read().splitlines()]
147
        chembl_test = [line for line in open(self.inf_smiles, 'r').read().splitlines()]
148
        drug_smiles = [line for line in open(self.train_drug_smiles, 'r').read().splitlines()]
149
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
150
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
151
152
153
        # Make directories if not exist.
154
        if not os.path.exists("experiments/inference/{}".format(self.submodel)):
155
            os.makedirs("experiments/inference/{}".format(self.submodel))
156
157
        if not self.disable_correction:
158
            correct = smi_correct(self.submodel, "experiments/inference/{}".format(self.submodel))
159
160
        search_res = pd.DataFrame(columns=["submodel", "validity",
161
                                           "uniqueness", "novelty",
162
                                           "novelty_test", "drug_novelty",
163
                                           "max_len", "mean_atom_type",
164
                                           "snn_chembl", "snn_drug", "IntDiv", "qed", "sa"])
165
166
        self.G.eval()
167
168
        start_time = time.time()
169
        metric_calc_dr = []
170
        uniqueness_calc = []
171
        real_smiles_snn = []
172
        nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
173
        f = open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w")
174
        f.write("SMILES")
175
        f.write("\n")
176
        val_counter = 0
177
        none_counter = 0
178
179
        # Inference mode
180
        with torch.inference_mode():
181
            pbar = tqdm(range(self.sample_num))
182
            pbar.set_description('Inference mode for {} model started'.format(self.submodel))
183
            for i, data in enumerate(self.inf_loader):
184
185
                val_counter += 1
186
                # Preprocess dataset 
187
                _, a_tensor, x_tensor = load_molecules(
188
                    data=data, 
189
                    batch_size=self.inf_batch_size,
190
                    device=self.device,
191
                    b_dim=self.b_dim,
192
                    m_dim=self.m_dim,
193
                )
194
195
                _, _, node_sample, edge_sample = self.G(a_tensor, x_tensor)
196
197
                g_edges_hat_sample = torch.max(edge_sample, -1)[1]
198
                g_nodes_hat_sample = torch.max(node_sample, -1)[1]
199
200
                fake_mol_g = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=False, file_name=self.dataset_name) 
201
                        for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
202
203
                a_tensor_sample = torch.max(a_tensor, -1)[1]
204
                x_tensor_sample = torch.max(x_tensor, -1)[1]
205
                real_mols = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name) 
206
                        for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]
207
208
                inference_drugs = [None if line is None else Chem.MolToSmiles(line) for line in fake_mol_g]
209
                inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
210
211
                for molecules in inference_drugs:
212
                            if molecules is None:
213
                                none_counter += 1
214
215
                for molecules in inference_drugs:
216
                    if molecules is not None:
217
                        molecules = molecules.replace("*", "C") 
218
                        f.write(molecules)
219
                        f.write("\n")
220
                        uniqueness_calc.append(molecules)
221
                        nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
222
                        pbar.update(1)
223
                    metric_calc_dr.append(molecules)
224
225
                real_smiles_snn.append(real_mols[0])
226
                generation_number = len([x for x in metric_calc_dr if x is not None])
227
                if generation_number == self.sample_num or none_counter == self.sample_num:
228
                    break
229
230
        f.close()
231
        print("Inference completed, starting metrics calculation.")
232
        if not self.disable_correction:
233
            corrected = correct.correct("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
234
            gen_smi = corrected["SMILES"].tolist()
235
            
236
        else:
237
            gen_smi = pd.read_csv("experiments/inference/{}/inference_drugs.txt".format(self.submodel))["SMILES"].tolist()
238
            
239
            
240
        et = time.time() - start_time
241
        
242
        gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
243
        real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
244
        print("Inference mode is lasted for {:.2f} seconds".format(et))
245
246
        print("Metrics calculation started using MOSES.")
247
        
248
        if not self.disable_correction:
249
            val = round(len(gen_smi)/self.sample_num, 3)
250
            print("Validity: ", val, "\n")
251
        else: 
252
            val = round(fraction_valid(gen_smi), 3)
253
            print("Validity: ", val, "\n")
254
255
        uniq = round(fraction_unique(gen_smi), 3)
256
        nov = round(novelty(gen_smi, chembl_smiles), 3)
257
        nov_test = round(novelty(gen_smi, chembl_test), 3)
258
        drug_nov = round(novelty(gen_smi, drug_smiles), 3)
259
        max_len = round(Metrics.max_component(gen_smi, self.vertexes), 3)
260
        mean_atom = round(Metrics.mean_atom_type(nodes_sample), 3)
261
        snn_chembl = round(average_agg_tanimoto(np.array(real_vecs), np.array(gen_vecs)), 3)
262
        snn_drug = round(average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs)), 3)
263
        int_div = round((internal_diversity(np.array(gen_vecs)))[0], 3)
264
        qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
265
        sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
266
267
        print("Uniqueness: ", uniq, "\n")
268
        print("Novelty (Train): ", nov, "\n")
269
        print("Novelty (Inference): ", nov_test, "\n")
270
        print("Novelty (Real Inhibitors): ", drug_nov, "\n")
271
        print("Average Length: ", max_len, "\n")
272
        print("Mean Atom Type: ", mean_atom, "\n")
273
        print("SNN (ChEMBL): ", snn_chembl, "\n")
274
        print("SNN (Real Inhibitors): ", snn_drug, "\n")
275
        print("Internal Diversity: ", int_div, "\n")
276
        print("QED: ", qed, "\n")
277
        print("SA: ", sa, "\n")
278
279
        print("Metrics are calculated.")
280
        model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
281
                        "uniqueness": [uniq], "novelty": [nov],
282
                        "novelty_inference": [nov_test], "novelty_real_inhibitor": [drug_nov],
283
                        "ave_len": [max_len], "mean_atom_type": [mean_atom],
284
                        "snn_chembl": [snn_chembl], "snn_real_inhibitor": [snn_drug], 
285
                        "IntDiv": [int_div], "qed": [qed], "sa": [sa]})
286
        search_res = pd.concat([search_res, model_res], axis=0)
287
        os.remove("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
288
        search_res.to_csv("experiments/inference/{}/inference_results.csv".format(self.submodel), index=False)
289
        generatedsmiles = pd.DataFrame({"SMILES": gen_smi})
290
        generatedsmiles.to_csv("experiments/inference/{}/inference_drugs.csv".format(self.submodel), index=False)
291
292
293
if __name__=="__main__":
294
    parser = argparse.ArgumentParser()
295
296
    # Inference configuration.
297
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Chose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
298
    parser.add_argument('--inference_model', type=str, help="Path to the model for inference")
299
    parser.add_argument('--sample_num', type=int, default=100, help='inference samples')
300
    parser.add_argument('--disable_correction', action='store_true', help='Disable SMILES correction')
301
   
302
    # Data configuration.
303
    parser.add_argument('--inf_smiles', type=str, required=True)
304
    parser.add_argument('--train_smiles', type=str, required=True)
305
    parser.add_argument('--train_drug_smiles', type=str, required=True)
306
    parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
307
    parser.add_argument('--mol_data_dir', type=str, default='data')
308
    parser.add_argument('--features', action='store_true', help='features dimension for nodes')
309
310
    # Model configuration.
311
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
312
    parser.add_argument('--max_atom', type=int, default=45, help='Max atom number for molecules must be specified.')
313
    parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
314
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
315
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
316
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
317
    parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
318
319
    # Seed configuration.
320
    parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
321
    parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')
322
323
    config = parser.parse_args()
324
    inference = Inference(config)
325
    inference.inference()