b/inference.py
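"""Inference script for DrugGEN.

Loads a trained generator checkpoint, samples molecular graphs from the
inference dataset, converts them to SMILES, and computes MOSES-style
evaluation metrics.
"""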
import os
import sys
import time
import random
import pickle
import argparse
import os.path as osp

import numpy as np  # needed below for np.random.seed and metric aggregation

import torch
import torch.utils.data
from torch_geometric.loader import DataLoader

import pandas as pd
from tqdm import tqdm

from rdkit import RDLogger, Chem
from rdkit.Chem import QED, AllChem, RDConfig  # AllChem is needed for the Morgan fingerprints below

sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

from src.util.utils import *
from src.model.models import Generator
from src.data.dataset import DruggenDataset
from src.data.utils import get_encoders_decoders, load_molecules
from src.model.loss import generator_loss
from src.util.smiles_cor import smi_correct

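# NOTE: the metric helpers used below (fraction_valid, fraction_unique, novelty,
# average_agg_tanimoto, internal_diversity, Metrics) are assumed to be provided
# by the star import from src.util.utils.
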
class Inference(object):
    """Inference class for DrugGEN."""

    def __init__(self, config):
        if config.set_seed:
            np.random.seed(config.seed)
            random.seed(config.seed)
            torch.manual_seed(config.seed)
            torch.cuda.manual_seed_all(config.seed)

            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            os.environ["PYTHONHASHSEED"] = str(config.seed)

            print(f'Using seed {config.seed}')

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        # Initialize configurations.
        self.submodel = config.submodel
        self.inference_model = config.inference_model
        self.sample_num = config.sample_num
        self.disable_correction = config.disable_correction

        # Data loader.
        self.inf_smiles = config.inf_smiles  # Text file containing the SMILES strings used for inference.
        inf_smiles_basename = osp.basename(self.inf_smiles)

        # The model performs one-shot generation, so the maximum atom count is fixed.
        self.max_atom = config.max_atom
        inf_smiles_base = os.path.splitext(inf_smiles_basename)[0]

        # Change the extension from .smi to .pt and append max_atom to the filename.
        self.inf_dataset_file = f"{inf_smiles_base}{self.max_atom}.pt"

        self.inf_batch_size = config.inf_batch_size
        self.train_smiles = config.train_smiles
        self.train_drug_smiles = config.train_drug_smiles
        self.mol_data_dir = config.mol_data_dir  # Directory where the dataset files are stored.
        self.dataset_name = self.inf_dataset_file.split(".")[0]
        self.features = config.features  # Boolean; False uses atom types as the only node features.
        # Additional node features can be added. Please check new_dataloarder.py, line 102.

        # Get atom and bond encoders/decoders.
        self.atom_encoder, self.atom_decoder, self.bond_encoder, self.bond_decoder = get_encoders_decoders(
            self.train_smiles,
            self.train_drug_smiles,
            self.max_atom
        )

        self.inf_dataset = DruggenDataset(self.mol_data_dir,
                                          self.inf_dataset_file,
                                          self.inf_smiles,
                                          self.max_atom,
                                          self.features,
                                          atom_encoder=self.atom_encoder,
                                          atom_decoder=self.atom_decoder,
                                          bond_encoder=self.bond_encoder,
                                          bond_decoder=self.bond_decoder)

        self.inf_loader = DataLoader(self.inf_dataset,
                                     shuffle=True,
                                     batch_size=self.inf_batch_size,
                                     drop_last=True)  # PyG DataLoader for the inference set.

        self.m_dim = len(self.atom_decoder) if not self.features else int(self.inf_loader.dataset[0].x.shape[1])  # Atom type dimension.
        self.b_dim = len(self.bond_decoder)  # Bond type dimension.
        self.vertexes = int(self.inf_loader.dataset[0].x.shape[0])  # Number of nodes in the graph.

        # Model configurations.
        self.act = config.act
        self.dim = config.dim
        self.depth = config.depth
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.dropout = config.dropout

        self.build_model()

    def build_model(self):
        """Create the generator (no discriminator is needed for inference)."""
        self.G = Generator(self.act,
                           self.vertexes,
                           self.b_dim,
                           self.m_dim,
                           self.dropout,
                           dim=self.dim,
                           depth=self.depth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio)
        self.G.to(self.device)
        self.print_network(self.G, 'G')

    def print_network(self, model, name):
        """Print out the network information."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, submodel, model_directory):
        """Restore the trained generator from <model_directory>/<submodel>-G.ckpt."""
        print('Loading the model...')
        G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))

    def inference(self):
        # Load the trained generator.
        self.restore_model(self.submodel, self.inference_model)

        # SMILES data for metrics calculation.
        chembl_smiles = open(self.train_smiles, 'r').read().splitlines()
        chembl_test = open(self.inf_smiles, 'r').read().splitlines()
        drug_smiles = open(self.train_drug_smiles, 'r').read().splitlines()
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]

        # Make the output directory if it does not exist.
        os.makedirs("experiments/inference/{}".format(self.submodel), exist_ok=True)

        if not self.disable_correction:
            correct = smi_correct(self.submodel, "experiments/inference/{}".format(self.submodel))

        # Column names must match those of model_res below so that pd.concat aligns them.
        search_res = pd.DataFrame(columns=["submodel", "validity",
                                           "uniqueness", "novelty",
                                           "novelty_inference", "novelty_real_inhibitor",
                                           "ave_len", "mean_atom_type",
                                           "snn_chembl", "snn_real_inhibitor", "IntDiv", "qed", "sa"])

        self.G.eval()

        start_time = time.time()
        metric_calc_dr = []
        uniqueness_calc = []
        real_smiles_snn = []
        # Accumulator for generated node types; seeded with one uninitialized placeholder row.
        nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
        f = open("experiments/inference/{}/inference_drugs.txt".format(self.submodel), "w")
        f.write("SMILES\n")
        val_counter = 0
        none_counter = 0

        # Inference mode.
        with torch.inference_mode():
            pbar = tqdm(range(self.sample_num))
            pbar.set_description('Inference mode for the {} model started'.format(self.submodel))
            for i, data in enumerate(self.inf_loader):

                val_counter += 1
                # Preprocess the batch.
                _, a_tensor, x_tensor = load_molecules(
                    data=data,
                    batch_size=self.inf_batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                _, _, node_sample, edge_sample = self.G(a_tensor, x_tensor)

                g_edges_hat_sample = torch.max(edge_sample, -1)[1]
                g_nodes_hat_sample = torch.max(node_sample, -1)[1]

                fake_mol_g = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=False, file_name=self.dataset_name)
                              for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]

                a_tensor_sample = torch.max(a_tensor, -1)[1]
                x_tensor_sample = torch.max(x_tensor, -1)[1]
                real_mols = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name)
                             for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]

                inference_drugs = [None if line is None else Chem.MolToSmiles(line) for line in fake_mol_g]
                # Keep only the largest fragment of each (possibly disconnected) molecule.
                inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]

                none_counter += sum(1 for molecules in inference_drugs if molecules is None)

                for molecules in inference_drugs:
                    if molecules is not None:
                        # Replace wildcard atoms with carbon so the SMILES is chemically concrete.
                        molecules = molecules.replace("*", "C")
                        f.write(molecules)
                        f.write("\n")
                        uniqueness_calc.append(molecules)
                        nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
                        pbar.update(1)
                        metric_calc_dr.append(molecules)

                real_smiles_snn.append(real_mols[0])
                generation_number = len([x for x in metric_calc_dr if x is not None])
                if generation_number == self.sample_num or none_counter == self.sample_num:
                    break

        f.close()
        print("Inference completed; starting metrics calculation.")

        if not self.disable_correction:
            corrected = correct.correct("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
            gen_smi = corrected["SMILES"].tolist()
        else:
            gen_smi = pd.read_csv("experiments/inference/{}/inference_drugs.txt".format(self.submodel))["SMILES"].tolist()

        et = time.time() - start_time

        gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
        real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
        print("Inference took {:.2f} seconds.".format(et))

        print("Metrics calculation started using MOSES.")

        if not self.disable_correction:
            # After correction, validity is the fraction of corrected molecules among all samples.
            val = round(len(gen_smi) / self.sample_num, 3)
        else:
            val = round(fraction_valid(gen_smi), 3)
        print("Validity: ", val, "\n")

        uniq = round(fraction_unique(gen_smi), 3)
        nov = round(novelty(gen_smi, chembl_smiles), 3)
        nov_test = round(novelty(gen_smi, chembl_test), 3)
        drug_nov = round(novelty(gen_smi, drug_smiles), 3)
        max_len = round(Metrics.max_component(gen_smi, self.vertexes), 3)
        mean_atom = round(Metrics.mean_atom_type(nodes_sample), 3)
        snn_chembl = round(average_agg_tanimoto(np.array(real_vecs), np.array(gen_vecs)), 3)
        snn_drug = round(average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs)), 3)
        int_div = round(internal_diversity(np.array(gen_vecs))[0], 3)
        qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
        sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)

        print("Uniqueness: ", uniq, "\n")
        print("Novelty (Train): ", nov, "\n")
        print("Novelty (Inference): ", nov_test, "\n")
        print("Novelty (Real Inhibitors): ", drug_nov, "\n")
        print("Average Length: ", max_len, "\n")
        print("Mean Atom Type: ", mean_atom, "\n")
        print("SNN (ChEMBL): ", snn_chembl, "\n")
        print("SNN (Real Inhibitors): ", snn_drug, "\n")
        print("Internal Diversity: ", int_div, "\n")
        print("QED: ", qed, "\n")
        print("SA: ", sa, "\n")

        print("Metrics calculation completed.")
        model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
                                  "uniqueness": [uniq], "novelty": [nov],
                                  "novelty_inference": [nov_test], "novelty_real_inhibitor": [drug_nov],
                                  "ave_len": [max_len], "mean_atom_type": [mean_atom],
                                  "snn_chembl": [snn_chembl], "snn_real_inhibitor": [snn_drug],
                                  "IntDiv": [int_div], "qed": [qed], "sa": [sa]})
        search_res = pd.concat([search_res, model_res], axis=0)
        # Replace the raw text dump with the final CSV outputs.
        os.remove("experiments/inference/{}/inference_drugs.txt".format(self.submodel))
        search_res.to_csv("experiments/inference/{}/inference_results.csv".format(self.submodel), index=False)
        generatedsmiles = pd.DataFrame({"SMILES": gen_smi})
        generatedsmiles.to_csv("experiments/inference/{}/inference_drugs.csv".format(self.submodel), index=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Inference configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Choose the model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
    parser.add_argument('--inference_model', type=str, help="Path to the directory containing the trained model checkpoint")
    parser.add_argument('--sample_num', type=int, default=100, help='Number of molecules to sample')
    parser.add_argument('--disable_correction', action='store_true', help='Disable SMILES correction')

    # Data configuration.
    parser.add_argument('--inf_smiles', type=str, required=True)
    parser.add_argument('--train_smiles', type=str, required=True)
    parser.add_argument('--train_drug_smiles', type=str, required=True)
    parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
    parser.add_argument('--mol_data_dir', type=str, default='data')
    parser.add_argument('--features', action='store_true', help='Use additional node features (atom types only if not set)')

    # Model configuration.
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Maximum number of atoms per generated molecule.')
    parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer encoder in the GAN.')
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer encoder in the GAN.')
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module in the GAN.')
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    parser.add_argument('--dropout', type=float, default=0., help='Dropout rate')

    # Seed configuration.
    parser.add_argument('--set_seed', action='store_true', help='Set a seed for reproducibility')
    parser.add_argument('--seed', type=int, default=1, help='Seed for reproducibility')

    config = parser.parse_args()
    inference = Inference(config)
    inference.inference()
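
# Example invocation (file paths are illustrative; substitute your own data and
# checkpoint locations):
#   python inference.py --submodel DrugGEN \
#       --inference_model experiments/models/DrugGEN \
#       --inf_smiles data/chembl_test.smi \
#       --train_smiles data/chembl_train.smi \
#       --train_drug_smiles data/akt_train.smi \
#       --sample_num 100 --set_seed --seed 1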