# process_bindingmoad.py
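"""
Preprocess the Binding MOAD dataset: read the `every.csv` label file,
filter ligands by QED and frequency, split complexes by EC number,
extract ligand and pocket coordinates plus one-hot features from the
PDB structures, and store one .npz file per split together with dataset
statistics (SMILES, size histogram, bond length tables, type histograms).
"""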
from pathlib import Path
from time import time
import random
from collections import defaultdict
import argparse
import warnings

from tqdm import tqdm
import numpy as np
import torch
from Bio.PDB import PDBParser
from Bio.PDB.Polypeptide import three_to_one, is_aa
from Bio.PDB import PDBIO, Select
from openbabel import openbabel
from rdkit import Chem
from rdkit.Chem import QED
from scipy.ndimage import gaussian_filter

from geometry_utils import get_bb_transform
from analysis.molecule_builder import build_molecule
from analysis.metrics import rdmol_to_smiles
import constants
from constants import covalent_radii, dataset_params
import utils

dataset_info = dataset_params['bindingmoad']
amino_acid_dict = dataset_info['aa_encoder']
atom_dict = dataset_info['atom_encoder']
atom_decoder = dataset_info['atom_decoder']


class Model0(Select):
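    """Biopython `Select` filter that keeps only the first model (id 0)
    when saving a structure to a PDB file."""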
    def accept_model(self, model):
        return model.id == 0


def read_label_file(csv_path):
    """
    Read BindingMOAD's label file
    Args:
        csv_path: path to 'every.csv'
    Returns:
        Nested dictionary with all ligands. First level: EC number,
        Second level: PDB ID, Third level: list of ligands. Each ligand is
        represented as a list [ligand name, validity, SMILES string]
    """
    ligand_dict = {}

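    # Rows of every.csv are distinguished by their first non-empty column:
    # an EC number in column 0 starts a new protein class, a PDB ID in
    # column 2 starts a new protein, and columns 3/4/9 hold a ligand's
    # name, validity flag and SMILES string, respectively.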
    with open(csv_path, 'r') as f:
        for line in f.readlines():
            row = line.split(',')

            # new protein class
            if len(row[0]) > 0:
                curr_class = row[0]
                ligand_dict[curr_class] = {}
                continue

            # new protein
            if len(row[2]) > 0:
                curr_prot = row[2]
                ligand_dict[curr_class][curr_prot] = []
                continue

            # new small molecule
            if len(row[3]) > 0:
                ligand_dict[curr_class][curr_prot].append(
                    # [ligand name, validity, SMILES string]
                    [row[3], row[4], row[9]]
                )

    return ligand_dict


def compute_druglikeness(ligand_dict):
    """
    Computes RDKit's QED value and adds it to the dictionary
    Args:
        ligand_dict: nested ligand dictionary
    Returns:
        the same ligand dictionary with additional QED values
    """
    print("Computing QED values...")
    for p, m in tqdm([(p, m) for c in ligand_dict for p in ligand_dict[c]
                      for m in ligand_dict[c][p]]):
        mol = Chem.MolFromSmiles(m[2])
        if mol is None:
            mol_id = f'{p}_{m}'
            warnings.warn(f"Could not construct molecule {mol_id} from SMILES "
                          f"string '{m[2]}'")
            continue
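        # append QED as the fourth list element (m[3]); filter_and_flatten
        # relies on its presence below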
        m.append(QED.qed(mol))
    return ligand_dict


def filter_and_flatten(ligand_dict, qed_thresh, max_occurences, seed):

    filtered_examples = []
    all_examples = [(c, p, m) for c in ligand_dict for p in ligand_dict[c]
                    for m in ligand_dict[c][p]]

    # shuffle to select random examples of ligands that occur more than
    # max_occurences times
    random.seed(seed)
    random.shuffle(all_examples)

    ligand_name_counter = defaultdict(int)
    print("Filtering examples...")
    for c, p, m in tqdm(all_examples):

        ligand_name, ligand_chain, ligand_resi = m[0].split(':')
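        # keep ligands that are flagged 'valid' and whose QED value exists
        # (m[3] is only present if the SMILES string could be parsed) and
        # exceeds the threshold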
        if m[1] == 'valid' and len(m) > 3 and m[3] > qed_thresh:
            if ligand_name_counter[ligand_name] < max_occurences:
                filtered_examples.append(
                    (c, p, m)
                )
                ligand_name_counter[ligand_name] += 1

    return filtered_examples


def split_by_ec_number(data_list, n_val, n_test, ec_level=1):
    """
    Split dataset into training, validation and test sets based on EC numbers
    https://en.wikipedia.org/wiki/Enzyme_Commission_number
    Args:
        data_list: list of ligands
        n_val: number of validation examples
        n_test: number of test examples
        ec_level: level in the EC numbering hierarchy at which the split is
            made, i.e. items with matching EC numbers at this level are put in
            the same set
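            (e.g. with the default ec_level=1, all hydrolases, EC 3.*.*.*,
            end up in the same set)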
    Returns:
        dictionary with keys 'train', 'val', and 'test'
    """

    examples_per_class = defaultdict(int)
    for c, p, m in data_list:
        c_sub = '.'.join(c.split('.')[:ec_level])
        examples_per_class[c_sub] += 1

    assert sum(examples_per_class.values()) == len(data_list)

    # split ec numbers
    val_classes = set()
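    # greedily assign classes to the validation set in order of decreasing
    # size, as long as the n_val budget is not exceeded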
    for c, num in sorted(examples_per_class.items(), key=lambda x: x[1],
                         reverse=True):
        if sum([examples_per_class[x] for x in val_classes]) + num <= n_val:
            val_classes.add(c)

    test_classes = set()
    for c, num in sorted(examples_per_class.items(), key=lambda x: x[1],
                         reverse=True):
        # skip classes already used in the validation set
        if c in val_classes:
            continue
        if sum([examples_per_class[x] for x in test_classes]) + num <= n_test:
            test_classes.add(c)

    # remaining classes belong to the training set
    train_classes = {x for x in examples_per_class if
                     x not in val_classes and x not in test_classes}

    # create separate lists of examples
    data_split = {}
    data_split['train'] = [x for x in data_list if '.'.join(
        x[0].split('.')[:ec_level]) in train_classes]
    data_split['val'] = [x for x in data_list if '.'.join(
        x[0].split('.')[:ec_level]) in val_classes]
    data_split['test'] = [x for x in data_list if '.'.join(
        x[0].split('.')[:ec_level]) in test_classes]

    assert len(data_split['train']) + len(data_split['val']) + \
           len(data_split['test']) == len(data_list)

    return data_split


def ligand_list_to_dict(ligand_list):
    out_dict = defaultdict(list)
    for _, p, m in ligand_list:
        out_dict[p].append(m)
    return out_dict


def process_ligand_and_pocket(pdb_struct, ligand_name, ligand_chain,
                              ligand_resi, dist_cutoff, ca_only,
                              compute_quaternion=False):
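    # NOTE: `pdbfile` in the error messages below is not a parameter of this
    # function; it refers to the module-level variable set in the processing
    # loop of the __main__ block.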
    try:
        residues = {obj.id[1]: obj for obj in
                    pdb_struct[0][ligand_chain].get_residues()}
    except KeyError as e:
        raise KeyError(f'Chain {e} not found ({pdbfile}, '
                       f'{ligand_name}:{ligand_chain}:{ligand_resi})')
    ligand = residues[ligand_resi]
    assert ligand.get_resname() == ligand_name, \
        f"{ligand.get_resname()} != {ligand_name}"

    # remove H atoms if they are not in atom_dict; other atom types that
    # aren't allowed should stay so that the entire ligand can be removed
    # from the dataset (via the KeyError below)
    lig_atoms = [a for a in ligand.get_atoms()
                 if (a.element.capitalize() in atom_dict or a.element != 'H')]
    lig_coords = np.array([a.get_coord() for a in lig_atoms])

    try:
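        # np.eye(1, n, k) is a 1 x n matrix with a single 1 at column k,
        # i.e. a one-hot row vector for atom type k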
        lig_one_hot = np.stack([
            np.eye(1, len(atom_dict), atom_dict[a.element.capitalize()]).squeeze()
            for a in lig_atoms
        ])
    except KeyError as e:
        raise KeyError(
            f'Ligand atom {e} not in atom dict ({pdbfile}, '
            f'{ligand_name}:{ligand_chain}:{ligand_resi})')

    # Find interacting pocket residues based on distance cutoff
    pocket_residues = []
    for residue in pdb_struct[0].get_residues():
        res_coords = np.array([a.get_coord() for a in residue.get_atoms()])
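        # keep standard amino acids whose minimum atom distance to any
        # ligand atom is below the cutoff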
        if is_aa(residue.get_resname(), standard=True) and \
                (((res_coords[:, None, :] - lig_coords[None, :, :]) ** 2).sum(-1) ** 0.5).min() < dist_cutoff:
            pocket_residues.append(residue)

    # Compute transform of the canonical reference frame
    n_xyz = np.array([res['N'].get_coord() for res in pocket_residues])
    ca_xyz = np.array([res['CA'].get_coord() for res in pocket_residues])
    c_xyz = np.array([res['C'].get_coord() for res in pocket_residues])

    if compute_quaternion:
        quaternion, c_alpha = get_bb_transform(n_xyz, ca_xyz, c_xyz)
        if np.any(np.isnan(quaternion)):
            raise ValueError(
                f'Invalid value in quaternion ({pdbfile}, '
                f'{ligand_name}:{ligand_chain}:{ligand_resi})')
    else:
        c_alpha = ca_xyz

    if ca_only:
        pocket_coords = c_alpha
        try:
            pocket_one_hot = np.stack([
                np.eye(1, len(amino_acid_dict),
                       amino_acid_dict[three_to_one(res.get_resname())]).squeeze()
                for res in pocket_residues])
        except KeyError as e:
            raise KeyError(
                f'{e} not in amino acid dict ({pdbfile}, '
                f'{ligand_name}:{ligand_chain}:{ligand_resi})')
    else:
        pocket_atoms = [a for res in pocket_residues for a in res.get_atoms()
                        if (a.element.capitalize() in atom_dict or a.element != 'H')]
        pocket_coords = np.array([a.get_coord() for a in pocket_atoms])
        try:
            pocket_one_hot = np.stack([
                np.eye(1, len(atom_dict), atom_dict[a.element.capitalize()]).squeeze()
                for a in pocket_atoms
            ])
        except KeyError as e:
            raise KeyError(
                f'Pocket atom {e} not in atom dict ({pdbfile}, '
                f'{ligand_name}:{ligand_chain}:{ligand_resi})')

    pocket_ids = [f'{res.parent.id}:{res.id[1]}' for res in pocket_residues]

    ligand_data = {
        'lig_coords': lig_coords,
        'lig_one_hot': lig_one_hot,
    }
    pocket_data = {
        'pocket_coords': pocket_coords,
        'pocket_one_hot': pocket_one_hot,
        'pocket_ids': pocket_ids,
    }
    if compute_quaternion:
        pocket_data['pocket_quaternion'] = quaternion
    return ligand_data, pocket_data


def compute_smiles(positions, one_hot, mask):
    print("Computing SMILES ...")

    atom_types = np.argmax(one_hot, axis=-1)

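    # `mask` holds one sample index per atom; splitting where the index
    # changes recovers the per-molecule coordinate and type arrays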
    sections = np.where(np.diff(mask))[0] + 1
    positions = [torch.from_numpy(x) for x in np.split(positions, sections)]
    atom_types = [torch.from_numpy(x) for x in np.split(atom_types, sections)]

    mols_smiles = []

    pbar = tqdm(enumerate(zip(positions, atom_types)),
                total=len(np.unique(mask)))
    for i, (pos, atom_type) in pbar:
        mol = build_molecule(pos, atom_type, dataset_info)

        # BasicMolecularMetrics() computes SMILES after sanitization
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            continue

        smiles = rdmol_to_smiles(mol)
        if smiles is not None:
            mols_smiles.append(smiles)
        pbar.set_description(f'{len(mols_smiles)}/{i + 1} successful')

    return mols_smiles


def get_n_nodes(lig_mask, pocket_mask, smooth_sigma=None):
    # Joint distribution of ligand's and pocket's number of nodes
    idx_lig, n_nodes_lig = np.unique(lig_mask, return_counts=True)
    idx_pocket, n_nodes_pocket = np.unique(pocket_mask, return_counts=True)
    assert np.all(idx_lig == idx_pocket)

    joint_histogram = np.zeros((np.max(n_nodes_lig) + 1,
                                np.max(n_nodes_pocket) + 1))

    for nlig, npocket in zip(n_nodes_lig, n_nodes_pocket):
        joint_histogram[nlig, npocket] += 1

    print(f'Original histogram: {np.count_nonzero(joint_histogram)}/'
          f'{joint_histogram.shape[0] * joint_histogram.shape[1]} bins filled')

    # Smooth the histogram
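    # (assigns a small non-zero probability to size combinations that are
    # close to, but not among, the observed ones)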
    if smooth_sigma is not None:
        filtered_histogram = gaussian_filter(
            joint_histogram, sigma=smooth_sigma, order=0, mode='constant',
            cval=0.0, truncate=4.0)

        print(f'Smoothed histogram: {np.count_nonzero(filtered_histogram)}/'
              f'{filtered_histogram.shape[0] * filtered_histogram.shape[1]} bins filled')

        joint_histogram = filtered_histogram

    return joint_histogram


def get_bond_length_arrays(atom_mapping):
    bond_arrays = []
    for i in range(3):
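        # constants.bonds1/bonds2/bonds3 are nested dicts mapping element
        # pairs to reference lengths, presumably one table per bond order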
        bond_dict = getattr(constants, f'bonds{i + 1}')
        bond_array = np.zeros((len(atom_mapping), len(atom_mapping)))
        for a1 in atom_mapping.keys():
            for a2 in atom_mapping.keys():
                if a1 in bond_dict and a2 in bond_dict[a1]:
                    bond_len = bond_dict[a1][a2]
                else:
                    bond_len = 0
                bond_array[atom_mapping[a1], atom_mapping[a2]] = bond_len

        assert np.all(bond_array == bond_array.T)
        bond_arrays.append(bond_array)

    return bond_arrays


def get_lennard_jones_rm(atom_mapping):
    # Bond radii for the Lennard-Jones potential
    LJ_rm = np.zeros((len(atom_mapping), len(atom_mapping)))

    for a1 in atom_mapping.keys():
        for a2 in atom_mapping.keys():
            all_bond_lengths = []
            for btype in ['bonds1', 'bonds2', 'bonds3']:
                bond_dict = getattr(constants, btype)
                if a1 in bond_dict and a2 in bond_dict[a1]:
                    all_bond_lengths.append(bond_dict[a1][a2])

            if len(all_bond_lengths) > 0:
                # take the shortest possible bond length because slightly
                # larger values aren't penalized as much
                bond_len = min(all_bond_lengths)
            else:
                # Replace missing values with the sum of average covalent radii
                bond_len = covalent_radii[a1] + covalent_radii[a2]

            LJ_rm[atom_mapping[a1], atom_mapping[a2]] = bond_len

    assert np.all(LJ_rm == LJ_rm.T)
    return LJ_rm


def get_type_histograms(lig_one_hot, pocket_one_hot, atom_encoder, aa_encoder):

    atom_decoder = list(atom_encoder.keys())
    atom_counts = {k: 0 for k in atom_encoder.keys()}
    for a in [atom_decoder[x] for x in lig_one_hot.argmax(1)]:
        atom_counts[a] += 1

    aa_decoder = list(aa_encoder.keys())
    aa_counts = {k: 0 for k in aa_encoder.keys()}
    for r in [aa_decoder[x] for x in pocket_one_hot.argmax(1)]:
        aa_counts[r] += 1

    return atom_counts, aa_counts


def saveall(filename, pdb_and_mol_ids, lig_coords, lig_one_hot, lig_mask,
            pocket_coords, pocket_quaternion, pocket_one_hot, pocket_mask):

    np.savez(filename,
             names=pdb_and_mol_ids,
             lig_coords=lig_coords,
             lig_one_hot=lig_one_hot,
             lig_mask=lig_mask,
             pocket_coords=pocket_coords,
             pocket_quaternion=pocket_quaternion,
             pocket_one_hot=pocket_one_hot,
             pocket_mask=pocket_mask
             )
    return True


if __name__ == '__main__':
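    # Example invocation (the path is illustrative):
    #   python process_bindingmoad.py data/MOAD --make_split --ca_only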
    parser = argparse.ArgumentParser()
    parser.add_argument('basedir', type=Path)
    parser.add_argument('--outdir', type=Path, default=None)
    parser.add_argument('--qed_thresh', type=float, default=0.3)
    parser.add_argument('--max_occurences', type=int, default=50)
    parser.add_argument('--num_val', type=int, default=300)
    parser.add_argument('--num_test', type=int, default=300)
    parser.add_argument('--dist_cutoff', type=float, default=8.0)
    parser.add_argument('--ca_only', action='store_true')
    parser.add_argument('--random_seed', type=int, default=42)
    parser.add_argument('--make_split', action='store_true')
    args = parser.parse_args()

    pdbdir = args.basedir / 'BindingMOAD_2020/'

    # Make output directory
    if args.outdir is None:
        suffix = '' if 'H' in atom_dict else '_noH'
        suffix += '_ca_only' if args.ca_only else '_full'
        processed_dir = Path(args.basedir, f'processed{suffix}')
    else:
        processed_dir = args.outdir

    processed_dir.mkdir(exist_ok=True, parents=True)

    if args.make_split:
        # Process the label file
        csv_path = args.basedir / 'every.csv'
        ligand_dict = read_label_file(csv_path)
        ligand_dict = compute_druglikeness(ligand_dict)
        filtered_examples = filter_and_flatten(
            ligand_dict, args.qed_thresh, args.max_occurences, args.random_seed)
        print(f'{len(filtered_examples)} examples after filtering')

        # Make data split
        data_split = split_by_ec_number(filtered_examples, args.num_val,
                                        args.num_test)

    else:
        # Use precomputed data split
        data_split = {}
        for split in ['test', 'val', 'train']:
            with open(f'data/moad_{split}.txt', 'r') as f:
                pocket_ids = f.read().split(',')
            # (ec-number, protein, molecule tuple)
            data_split[split] = [(None, x.split('_')[0][:4], (x.split('_')[1],))
                                 for x in pocket_ids]

    n_train_before = len(data_split['train'])
    n_val_before = len(data_split['val'])
    n_test_before = len(data_split['test'])

    # Read and process PDB files
    n_samples_after = {}
    for split in data_split.keys():
        lig_coords = []
        lig_one_hot = []
        lig_mask = []
        pocket_coords = []
        pocket_one_hot = []
        pocket_mask = []
        pdb_and_mol_ids = []
        receptors = []
        count = 0

        pdb_sdf_dir = processed_dir / split
        pdb_sdf_dir.mkdir(exist_ok=True)

        n_tot = len(data_split[split])
        pair_dict = ligand_list_to_dict(data_split[split])

        tic = time()
        num_failed = 0
        with tqdm(total=n_tot) as pbar:
            for p in pair_dict:

                pdb_successful = set()

                # try all available .bio files
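                # (Binding MOAD provides each PDB entry as one or more
                # biological-assembly files, e.g. <pdb_id>.bio1, <pdb_id>.bio2)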
                for pdbfile in sorted(pdbdir.glob(f"{p.lower()}.bio*")):

                    # Skip if all ligands have been processed already
                    if len(pair_dict[p]) == len(pdb_successful):
                        continue

                    pdb_struct = PDBParser(QUIET=True).get_structure('', pdbfile)
                    struct_copy = pdb_struct.copy()

                    n_bio_successful = 0
                    for m in pair_dict[p]:

                        # Skip already processed ligand
                        if m[0] in pdb_successful:
                            continue

                        ligand_name, ligand_chain, ligand_resi = m[0].split(':')
                        ligand_resi = int(ligand_resi)

                        try:
                            ligand_data, pocket_data = process_ligand_and_pocket(
                                pdb_struct, ligand_name, ligand_chain, ligand_resi,
                                dist_cutoff=args.dist_cutoff, ca_only=args.ca_only)
                        except (KeyError, AssertionError, FileNotFoundError,
                                IndexError, ValueError) as e:
                            # print(type(e).__name__, e)
                            continue

                        pdb_and_mol_ids.append(f"{p}_{m[0]}")
                        receptors.append(pdbfile.name)
                        lig_coords.append(ligand_data['lig_coords'])
                        lig_one_hot.append(ligand_data['lig_one_hot'])
                        lig_mask.append(
                            count * np.ones(len(ligand_data['lig_coords'])))
                        pocket_coords.append(pocket_data['pocket_coords'])
                        # pocket_quaternion.append(
                        #     pocket_data['pocket_quaternion'])
                        pocket_one_hot.append(pocket_data['pocket_one_hot'])
                        pocket_mask.append(
                            count * np.ones(len(pocket_data['pocket_coords'])))
                        count += 1

                        pdb_successful.add(m[0])
                        n_bio_successful += 1

                        # Save additional files for affinity analysis
                        if split in {'val', 'test'}:
                        # if split in {'val', 'test', 'train'}:
                            # remove ligand from receptor
                            try:
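                                # Biopython identifies hetero residues by the
                                # id tuple ('H_<resname>', resseq, icode)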
                                struct_copy[0][ligand_chain].detach_child((f'H_{ligand_name}', ligand_resi, ' '))
                            except KeyError:
                                warnings.warn(f"Could not find ligand {(f'H_{ligand_name}', ligand_resi, ' ')} in {pdbfile}")
                                continue

                            # Create SDF file
                            atom_types = [atom_decoder[np.argmax(i)] for i in ligand_data['lig_one_hot']]
                            xyz_file = Path(pdb_sdf_dir, 'tmp.xyz')
                            utils.write_xyz_file(ligand_data['lig_coords'], atom_types, xyz_file)

                            obConversion = openbabel.OBConversion()
                            obConversion.SetInAndOutFormats("xyz", "sdf")
                            mol = openbabel.OBMol()
                            obConversion.ReadFile(mol, str(xyz_file))
                            xyz_file.unlink()

                            name = f"{p}-{pdbfile.suffix[1:]}_{m[0]}"
                            sdf_file = Path(pdb_sdf_dir, f'{name}.sdf')
                            obConversion.WriteFile(mol, str(sdf_file))

                            # specify pocket residues
                            with open(Path(pdb_sdf_dir, f'{name}.txt'), 'w') as f:
                                f.write(' '.join(pocket_data['pocket_ids']))

                    if split in {'val', 'test'} and n_bio_successful > 0:
                    # if split in {'val', 'test', 'train'} and n_bio_successful > 0:
                        # create receptor PDB file
                        pdb_file_out = Path(pdb_sdf_dir, f'{p}-{pdbfile.suffix[1:]}.pdb')
                        io = PDBIO()
                        io.set_structure(struct_copy)
                        io.save(str(pdb_file_out), select=Model0())

                pbar.update(len(pair_dict[p]))
                num_failed += (len(pair_dict[p]) - len(pdb_successful))
                pbar.set_description(f'#failed: {num_failed}')

        lig_coords = np.concatenate(lig_coords, axis=0)
        lig_one_hot = np.concatenate(lig_one_hot, axis=0)
        lig_mask = np.concatenate(lig_mask, axis=0)
        pocket_coords = np.concatenate(pocket_coords, axis=0)
        pocket_one_hot = np.concatenate(pocket_one_hot, axis=0)
        pocket_mask = np.concatenate(pocket_mask, axis=0)

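        # one flat array per key; individual molecules can be recovered by
        # grouping entries that share a lig_mask / pocket_mask value (see
        # compute_smiles above)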
        np.savez(processed_dir / f'{split}.npz', names=pdb_and_mol_ids,
                 receptors=receptors, lig_coords=lig_coords,
                 lig_one_hot=lig_one_hot, lig_mask=lig_mask,
                 pocket_coords=pocket_coords, pocket_one_hot=pocket_one_hot,
                 pocket_mask=pocket_mask)

        n_samples_after[split] = len(pdb_and_mol_ids)
        print(f"Processing {split} set took {(time() - tic)/60.0:.2f} minutes")

    # --------------------------------------------------------------------------
    # Compute statistics & additional information
    # --------------------------------------------------------------------------
    with np.load(processed_dir / 'train.npz', allow_pickle=True) as data:
        lig_mask = data['lig_mask']
        pocket_mask = data['pocket_mask']
        lig_coords = data['lig_coords']
        lig_one_hot = data['lig_one_hot']
        pocket_one_hot = data['pocket_one_hot']

    # Compute SMILES for all training examples
    train_smiles = compute_smiles(lig_coords, lig_one_hot, lig_mask)
    np.save(processed_dir / 'train_smiles.npy', train_smiles)

    # Joint histogram of number of ligand and pocket nodes
    n_nodes = get_n_nodes(lig_mask, pocket_mask, smooth_sigma=1.0)
    np.save(Path(processed_dir, 'size_distribution.npy'), n_nodes)

    # Convert bond length dictionaries to arrays for batch processing
    bonds1, bonds2, bonds3 = get_bond_length_arrays(atom_dict)

    # Get bond length definitions for Lennard-Jones potential
    rm_LJ = get_lennard_jones_rm(atom_dict)

    # Get histograms of ligand and pocket node types
    atom_hist, aa_hist = get_type_histograms(lig_one_hot, pocket_one_hot,
                                             atom_dict, amino_acid_dict)

    # Create summary string
    summary_string = '# SUMMARY\n\n'
    summary_string += '# Before processing\n'
    summary_string += f'num_samples train: {n_train_before}\n'
    summary_string += f'num_samples val: {n_val_before}\n'
    summary_string += f'num_samples test: {n_test_before}\n\n'
    summary_string += '# After processing\n'
    summary_string += f"num_samples train: {n_samples_after['train']}\n"
    summary_string += f"num_samples val: {n_samples_after['val']}\n"
    summary_string += f"num_samples test: {n_samples_after['test']}\n\n"
    summary_string += '# Info\n'
    summary_string += f"'atom_encoder': {atom_dict}\n"
    summary_string += f"'atom_decoder': {list(atom_dict.keys())}\n"
    summary_string += f"'aa_encoder': {amino_acid_dict}\n"
    summary_string += f"'aa_decoder': {list(amino_acid_dict.keys())}\n"
    summary_string += f"'bonds1': {bonds1.tolist()}\n"
    summary_string += f"'bonds2': {bonds2.tolist()}\n"
    summary_string += f"'bonds3': {bonds3.tolist()}\n"
    summary_string += f"'lennard_jones_rm': {rm_LJ.tolist()}\n"
    summary_string += f"'atom_hist': {atom_hist}\n"
    summary_string += f"'aa_hist': {aa_hist}\n"
    summary_string += f"'n_nodes': {n_nodes.tolist()}\n"

    # Write summary to text file
    with open(processed_dir / 'summary.txt', 'w') as f:
        f.write(summary_string)

    # Print summary
    print(summary_string)