|
a |
|
b/src/util/utils.py |
|
|
1 |
import os |
|
|
2 |
import time |
|
|
3 |
import math |
|
|
4 |
import datetime |
|
|
5 |
import warnings |
|
|
6 |
import itertools |
|
|
7 |
from copy import deepcopy |
|
|
8 |
from functools import partial |
|
|
9 |
from collections import Counter |
|
|
10 |
from multiprocessing import Pool |
|
|
11 |
from statistics import mean |
|
|
12 |
|
|
|
13 |
import numpy as np |
|
|
14 |
import matplotlib.pyplot as plt |
|
|
15 |
from matplotlib.lines import Line2D |
|
|
16 |
from scipy.spatial.distance import cosine as cos_distance |
|
|
17 |
|
|
|
18 |
import torch |
|
|
19 |
import wandb |
|
|
20 |
|
|
|
21 |
from rdkit import Chem, DataStructs, RDLogger |
|
|
22 |
from rdkit.Chem import ( |
|
|
23 |
AllChem, |
|
|
24 |
Draw, |
|
|
25 |
Descriptors, |
|
|
26 |
Lipinski, |
|
|
27 |
Crippen, |
|
|
28 |
rdMolDescriptors, |
|
|
29 |
FilterCatalog, |
|
|
30 |
) |
|
|
31 |
from rdkit.Chem.Scaffolds import MurckoScaffold |
|
|
32 |
|
|
|
33 |
# Disable RDKit warnings |
|
|
34 |
RDLogger.DisableLog("rdApp.*") |
|
|
35 |
|
|
|
36 |
|
|
|
37 |
class Metrics(object): |
|
|
38 |
""" |
|
|
39 |
Collection of static methods to compute various metrics for molecules. |
|
|
40 |
""" |
|
|
41 |
|
|
|
42 |
@staticmethod |
|
|
43 |
def valid(x): |
|
|
44 |
""" |
|
|
45 |
Checks whether the molecule is valid. |
|
|
46 |
|
|
|
47 |
Args: |
|
|
48 |
x: RDKit molecule object. |
|
|
49 |
|
|
|
50 |
Returns: |
|
|
51 |
bool: True if molecule is valid and has a non-empty SMILES representation. |
|
|
52 |
""" |
|
|
53 |
return x is not None and Chem.MolToSmiles(x) != '' |
|
|
54 |
|
|
|
55 |
@staticmethod |
|
|
56 |
def tanimoto_sim_1v2(data1, data2): |
|
|
57 |
""" |
|
|
58 |
Computes the average Tanimoto similarity for paired fingerprints. |
|
|
59 |
|
|
|
60 |
Args: |
|
|
61 |
data1: Fingerprint data for first set. |
|
|
62 |
data2: Fingerprint data for second set. |
|
|
63 |
|
|
|
64 |
Returns: |
|
|
65 |
float: The average Tanimoto similarity between corresponding fingerprints. |
|
|
66 |
""" |
|
|
67 |
# Determine the minimum size between two arrays for pairing |
|
|
68 |
min_len = data1.size if data1.size > data2.size else data2 |
|
|
69 |
sims = [] |
|
|
70 |
for i in range(min_len): |
|
|
71 |
sim = DataStructs.FingerprintSimilarity(data1[i], data2[i]) |
|
|
72 |
sims.append(sim) |
|
|
73 |
# Use 'mean' from statistics; note that variable 'sim' was used, corrected to use sims list. |
|
|
74 |
mean_sim = mean(sims) |
|
|
75 |
return mean_sim |
|
|
76 |
|
|
|
77 |
@staticmethod |
|
|
78 |
def mol_length(x): |
|
|
79 |
""" |
|
|
80 |
Computes the length of the largest fragment (by character count) in a SMILES string. |
|
|
81 |
|
|
|
82 |
Args: |
|
|
83 |
x (str): SMILES string. |
|
|
84 |
|
|
|
85 |
Returns: |
|
|
86 |
int: Number of alphabetic characters in the longest fragment of the SMILES. |
|
|
87 |
""" |
|
|
88 |
if x is not None: |
|
|
89 |
# Split at dots (.) and take the fragment with maximum length, then count alphabetic characters. |
|
|
90 |
return len([char for char in max(x.split(sep="."), key=len).upper() if char.isalpha()]) |
|
|
91 |
else: |
|
|
92 |
return 0 |
|
|
93 |
|
|
|
94 |
@staticmethod |
|
|
95 |
def max_component(data, max_len): |
|
|
96 |
""" |
|
|
97 |
Returns the average normalized length of molecules in the dataset. |
|
|
98 |
|
|
|
99 |
Each molecule's length is computed and divided by max_len, then averaged. |
|
|
100 |
|
|
|
101 |
Args: |
|
|
102 |
data (iterable): Collection of SMILES strings. |
|
|
103 |
max_len (int): Maximum possible length for normalization. |
|
|
104 |
|
|
|
105 |
Returns: |
|
|
106 |
float: Normalized average length. |
|
|
107 |
""" |
|
|
108 |
lengths = np.array(list(map(Metrics.mol_length, data)), dtype=np.float32) |
|
|
109 |
return (lengths / max_len).mean() |
|
|
110 |
|
|
|
111 |
@staticmethod |
|
|
112 |
def mean_atom_type(data): |
|
|
113 |
""" |
|
|
114 |
Computes the average number of unique atom types in the provided node data. |
|
|
115 |
|
|
|
116 |
Args: |
|
|
117 |
data (iterable): Iterable containing node data with unique atom types. |
|
|
118 |
|
|
|
119 |
Returns: |
|
|
120 |
float: The average count of unique atom types, subtracting one. |
|
|
121 |
""" |
|
|
122 |
atom_types_used = [] |
|
|
123 |
for i in data: |
|
|
124 |
# Assuming each element i has a .unique() method that returns unique atom types. |
|
|
125 |
atom_types_used.append(len(i.unique().tolist())) |
|
|
126 |
av_type = np.mean(atom_types_used) - 1 |
|
|
127 |
return av_type |
|
|
128 |
|
|
|
129 |
|
|
|
130 |
def mols2grid_image(mols, path): |
|
|
131 |
""" |
|
|
132 |
Saves grid images for a list of molecules. |
|
|
133 |
|
|
|
134 |
For each molecule in the list, computes 2D coordinates and saves an image file. |
|
|
135 |
|
|
|
136 |
Args: |
|
|
137 |
mols (list): List of RDKit molecule objects. |
|
|
138 |
path (str): Directory where images will be saved. |
|
|
139 |
""" |
|
|
140 |
# Replace None molecules with an empty molecule |
|
|
141 |
mols = [e if e is not None else Chem.RWMol() for e in mols] |
|
|
142 |
|
|
|
143 |
for i in range(len(mols)): |
|
|
144 |
if Metrics.valid(mols[i]): |
|
|
145 |
AllChem.Compute2DCoords(mols[i]) |
|
|
146 |
file_path = os.path.join(path, "{}.png".format(i + 1)) |
|
|
147 |
Draw.MolToFile(mols[i], file_path, size=(1200, 1200)) |
|
|
148 |
# wandb.save(file_path) # Optionally save to Weights & Biases |
|
|
149 |
else: |
|
|
150 |
continue |
|
|
151 |
|
|
|
152 |
|
|
|
153 |
def save_smiles_matrices(mols, edges_hard, nodes_hard, path, data_source=None): |
|
|
154 |
""" |
|
|
155 |
Saves the edge and node matrices along with SMILES strings to text files. |
|
|
156 |
|
|
|
157 |
Each file contains the edge matrix, node matrix, and SMILES representation for a molecule. |
|
|
158 |
|
|
|
159 |
Args: |
|
|
160 |
mols (list): List of RDKit molecule objects. |
|
|
161 |
edges_hard (torch.Tensor): Tensor of edge features. |
|
|
162 |
nodes_hard (torch.Tensor): Tensor of node features. |
|
|
163 |
path (str): Directory where files will be saved. |
|
|
164 |
data_source: Optional data source information (not used in function). |
|
|
165 |
""" |
|
|
166 |
mols = [e if e is not None else Chem.RWMol() for e in mols] |
|
|
167 |
|
|
|
168 |
for i in range(len(mols)): |
|
|
169 |
if Metrics.valid(mols[i]): |
|
|
170 |
save_path = os.path.join(path, "{}.txt".format(i + 1)) |
|
|
171 |
with open(save_path, "a") as f: |
|
|
172 |
np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n", fmt='%1.2f') |
|
|
173 |
f.write("\n") |
|
|
174 |
np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:", fmt='%1.2f') |
|
|
175 |
f.write("\n") |
|
|
176 |
# Append the SMILES representation to the file |
|
|
177 |
with open(save_path, "a") as f: |
|
|
178 |
print(Chem.MolToSmiles(mols[i]), file=f) |
|
|
179 |
# wandb.save(save_path) # Optionally save to Weights & Biases |
|
|
180 |
else: |
|
|
181 |
continue |
|
|
182 |
|
|
|
183 |
def dense_to_sparse_with_attr(adj): |
|
|
184 |
""" |
|
|
185 |
Converts a dense adjacency matrix to a sparse representation. |
|
|
186 |
|
|
|
187 |
Args: |
|
|
188 |
adj (torch.Tensor): Adjacency matrix tensor (2D or 3D) with square last two dimensions. |
|
|
189 |
|
|
|
190 |
Returns: |
|
|
191 |
tuple: A tuple containing indices and corresponding edge attributes. |
|
|
192 |
""" |
|
|
193 |
assert adj.dim() >= 2 and adj.dim() <= 3 |
|
|
194 |
assert adj.size(-1) == adj.size(-2) |
|
|
195 |
|
|
|
196 |
index = adj.nonzero(as_tuple=True) |
|
|
197 |
edge_attr = adj[index] |
|
|
198 |
|
|
|
199 |
if len(index) == 3: |
|
|
200 |
batch = index[0] * adj.size(-1) |
|
|
201 |
index = (batch + index[1], batch + index[2]) |
|
|
202 |
return index, edge_attr |
|
|
203 |
|
|
|
204 |
|
|
|
205 |
def mol_sample(sample_directory, edges, nodes, idx, i, matrices2mol, dataset_name): |
|
|
206 |
""" |
|
|
207 |
Samples molecules from edge and node predictions, then saves grid images and text files. |
|
|
208 |
|
|
|
209 |
Args: |
|
|
210 |
sample_directory (str): Directory to save the samples. |
|
|
211 |
edges (torch.Tensor): Edge predictions tensor. |
|
|
212 |
nodes (torch.Tensor): Node predictions tensor. |
|
|
213 |
idx (int): Current index for naming the sample. |
|
|
214 |
i (int): Epoch/iteration index. |
|
|
215 |
matrices2mol (callable): Function to convert matrices to RDKit molecule. |
|
|
216 |
dataset_name (str): Name of the dataset for file naming. |
|
|
217 |
""" |
|
|
218 |
sample_path = os.path.join(sample_directory, "{}_{}-epoch_iteration".format(idx + 1, i + 1)) |
|
|
219 |
# Get the index of the maximum predicted feature along the last dimension |
|
|
220 |
g_edges_hat_sample = torch.max(edges, -1)[1] |
|
|
221 |
g_nodes_hat_sample = torch.max(nodes, -1)[1] |
|
|
222 |
# Convert matrices to molecule objects |
|
|
223 |
mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), |
|
|
224 |
strict=True, file_name=dataset_name) |
|
|
225 |
for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)] |
|
|
226 |
|
|
|
227 |
if not os.path.exists(sample_path): |
|
|
228 |
os.makedirs(sample_path) |
|
|
229 |
|
|
|
230 |
mols2grid_image(mol, sample_path) |
|
|
231 |
save_smiles_matrices(mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), sample_path) |
|
|
232 |
|
|
|
233 |
# Remove the directory if no files were saved |
|
|
234 |
if len(os.listdir(sample_path)) == 0: |
|
|
235 |
os.rmdir(sample_path) |
|
|
236 |
|
|
|
237 |
print("Valid molecules are saved.") |
|
|
238 |
print("Valid matrices and smiles are saved") |
|
|
239 |
|
|
|
240 |
|
|
|
241 |
def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node, |
|
|
242 |
matrices2mol, dataset_name, real_adj, real_annot, drug_vecs): |
|
|
243 |
""" |
|
|
244 |
Logs training statistics and evaluation metrics. |
|
|
245 |
|
|
|
246 |
The function generates molecules from predictions, computes various metrics such as |
|
|
247 |
validity, uniqueness, novelty, and similarity scores, and logs them using wandb and a file. |
|
|
248 |
|
|
|
249 |
Args: |
|
|
250 |
log_path (str): Path to save the log file. |
|
|
251 |
start_time (float): Start time to compute elapsed time. |
|
|
252 |
i (int): Current iteration index. |
|
|
253 |
idx (int): Current epoch index. |
|
|
254 |
loss (dict): Dictionary to update with loss and metric values. |
|
|
255 |
save_path (str): Directory path to save sample outputs. |
|
|
256 |
drug_smiles (list): List of reference drug SMILES. |
|
|
257 |
edge (torch.Tensor): Edge prediction tensor. |
|
|
258 |
node (torch.Tensor): Node prediction tensor. |
|
|
259 |
matrices2mol (callable): Function to convert matrices to molecules. |
|
|
260 |
dataset_name (str): Dataset name. |
|
|
261 |
real_adj (torch.Tensor): Ground truth adjacency matrix tensor. |
|
|
262 |
real_annot (torch.Tensor): Ground truth annotation tensor. |
|
|
263 |
drug_vecs (list): List of drug vectors for similarity calculation. |
|
|
264 |
""" |
|
|
265 |
g_edges_hat_sample = torch.max(edge, -1)[1] |
|
|
266 |
g_nodes_hat_sample = torch.max(node, -1)[1] |
|
|
267 |
|
|
|
268 |
a_tensor_sample = torch.max(real_adj, -1)[1].float() |
|
|
269 |
x_tensor_sample = torch.max(real_annot, -1)[1].float() |
|
|
270 |
|
|
|
271 |
# Generate molecules from predictions and real data |
|
|
272 |
mols = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), |
|
|
273 |
strict=True, file_name=dataset_name) |
|
|
274 |
for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)] |
|
|
275 |
real_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), |
|
|
276 |
strict=True, file_name=dataset_name) |
|
|
277 |
for e_, n_ in zip(a_tensor_sample, x_tensor_sample)] |
|
|
278 |
|
|
|
279 |
# Compute average number of atom types |
|
|
280 |
atom_types_average = Metrics.mean_atom_type(g_nodes_hat_sample) |
|
|
281 |
real_smiles = [Chem.MolToSmiles(x) for x in real_mol if x is not None] |
|
|
282 |
gen_smiles = [] |
|
|
283 |
uniq_smiles = [] |
|
|
284 |
for line in mols: |
|
|
285 |
if line is not None: |
|
|
286 |
gen_smiles.append(Chem.MolToSmiles(line)) |
|
|
287 |
uniq_smiles.append(Chem.MolToSmiles(line)) |
|
|
288 |
elif line is None: |
|
|
289 |
gen_smiles.append(None) |
|
|
290 |
|
|
|
291 |
# Process SMILES to take the longest fragment if multiple are present |
|
|
292 |
gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles] |
|
|
293 |
uniq_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in uniq_smiles] |
|
|
294 |
|
|
|
295 |
# Save the generated SMILES to a text file |
|
|
296 |
sample_save_dir = os.path.join(save_path, "samples.txt") |
|
|
297 |
with open(sample_save_dir, "a") as f: |
|
|
298 |
for s in gen_smiles_saves: |
|
|
299 |
if s is not None: |
|
|
300 |
f.write(s + "\n") |
|
|
301 |
|
|
|
302 |
k = len(set(uniq_smiles_saves) - {None}) |
|
|
303 |
et = time.time() - start_time |
|
|
304 |
et = str(datetime.timedelta(seconds=et))[:-7] |
|
|
305 |
log_str = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(et, idx, i + 1) |
|
|
306 |
|
|
|
307 |
# Generate molecular fingerprints for similarity computations |
|
|
308 |
gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in mols if x is not None] |
|
|
309 |
chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_mol if x is not None] |
|
|
310 |
|
|
|
311 |
# Compute evaluation metrics: validity, uniqueness, novelty, similarity scores, and average maximum molecule length. |
|
|
312 |
valid = fraction_valid(gen_smiles_saves) |
|
|
313 |
unique = fraction_unique(uniq_smiles_saves, k) |
|
|
314 |
novel_starting_mol = novelty(gen_smiles_saves, real_smiles) |
|
|
315 |
novel_akt = novelty(gen_smiles_saves, drug_smiles) |
|
|
316 |
if len(uniq_smiles_saves) == 0: |
|
|
317 |
snn_chembl = 0 |
|
|
318 |
snn_akt = 0 |
|
|
319 |
maxlen = 0 |
|
|
320 |
else: |
|
|
321 |
snn_chembl = average_agg_tanimoto(np.array(chembl_vecs), np.array(gen_vecs)) |
|
|
322 |
snn_akt = average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs)) |
|
|
323 |
maxlen = Metrics.max_component(uniq_smiles_saves, 45) |
|
|
324 |
|
|
|
325 |
# Update loss dictionary with computed metrics |
|
|
326 |
loss.update({ |
|
|
327 |
'Validity': valid, |
|
|
328 |
'Uniqueness': unique, |
|
|
329 |
'Novelty': novel_starting_mol, |
|
|
330 |
'Novelty_akt': novel_akt, |
|
|
331 |
'SNN_chembl': snn_chembl, |
|
|
332 |
'SNN_akt': snn_akt, |
|
|
333 |
'MaxLen': maxlen, |
|
|
334 |
'Atom_types': atom_types_average |
|
|
335 |
}) |
|
|
336 |
|
|
|
337 |
# Log metrics using wandb |
|
|
338 |
wandb.log({ |
|
|
339 |
"Validity": valid, |
|
|
340 |
"Uniqueness": unique, |
|
|
341 |
"Novelty": novel_starting_mol, |
|
|
342 |
"Novelty_akt": novel_akt, |
|
|
343 |
"SNN_chembl": snn_chembl, |
|
|
344 |
"SNN_akt": snn_akt, |
|
|
345 |
"MaxLen": maxlen, |
|
|
346 |
"Atom_types": atom_types_average |
|
|
347 |
}) |
|
|
348 |
|
|
|
349 |
# Append each metric to the log string and write to the log file |
|
|
350 |
for tag, value in loss.items(): |
|
|
351 |
log_str += ", {}: {:.4f}".format(tag, value) |
|
|
352 |
with open(log_path, "a") as f: |
|
|
353 |
f.write(log_str + "\n") |
|
|
354 |
print(log_str) |
|
|
355 |
print("\n") |
|
|
356 |
|
|
|
357 |
|
|
|
358 |
def plot_grad_flow(named_parameters, model, itera, epoch, grad_flow_directory): |
|
|
359 |
""" |
|
|
360 |
Plots the gradients flowing through different layers during training. |
|
|
361 |
|
|
|
362 |
This is useful to check for possible gradient vanishing or exploding problems. |
|
|
363 |
|
|
|
364 |
Args: |
|
|
365 |
named_parameters (iterable): Iterable of (name, parameter) tuples from the model. |
|
|
366 |
model (str): Name of the model (used for saving the plot). |
|
|
367 |
itera (int): Iteration index. |
|
|
368 |
epoch (int): Current epoch. |
|
|
369 |
grad_flow_directory (str): Directory to save the gradient flow plot. |
|
|
370 |
""" |
|
|
371 |
ave_grads = [] |
|
|
372 |
max_grads = [] |
|
|
373 |
layers = [] |
|
|
374 |
for n, p in named_parameters: |
|
|
375 |
if p.requires_grad and ("bias" not in n): |
|
|
376 |
layers.append(n) |
|
|
377 |
ave_grads.append(p.grad.abs().mean().cpu()) |
|
|
378 |
max_grads.append(p.grad.abs().max().cpu()) |
|
|
379 |
# Plot maximum gradients and average gradients for each layer |
|
|
380 |
plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") |
|
|
381 |
plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") |
|
|
382 |
plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") |
|
|
383 |
plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") |
|
|
384 |
plt.xlim(left=0, right=len(ave_grads)) |
|
|
385 |
plt.ylim(bottom=-0.001, top=1) # Zoom in on lower gradient regions |
|
|
386 |
plt.xlabel("Layers") |
|
|
387 |
plt.ylabel("Average Gradient") |
|
|
388 |
plt.title("Gradient Flow") |
|
|
389 |
plt.grid(True) |
|
|
390 |
plt.legend([ |
|
|
391 |
Line2D([0], [0], color="c", lw=4), |
|
|
392 |
Line2D([0], [0], color="b", lw=4), |
|
|
393 |
Line2D([0], [0], color="k", lw=4) |
|
|
394 |
], ['max-gradient', 'mean-gradient', 'zero-gradient']) |
|
|
395 |
# Save the plot to the specified directory |
|
|
396 |
plt.savefig(os.path.join(grad_flow_directory, "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"), dpi=500, bbox_inches='tight') |
|
|
397 |
|
|
|
398 |
|
|
|
399 |
def get_mol(smiles_or_mol): |
|
|
400 |
""" |
|
|
401 |
Loads a SMILES string or molecule into an RDKit molecule object. |
|
|
402 |
|
|
|
403 |
Args: |
|
|
404 |
smiles_or_mol (str or RDKit Mol): SMILES string or RDKit molecule. |
|
|
405 |
|
|
|
406 |
Returns: |
|
|
407 |
RDKit Mol or None: Sanitized molecule object, or None if invalid. |
|
|
408 |
""" |
|
|
409 |
if isinstance(smiles_or_mol, str): |
|
|
410 |
if len(smiles_or_mol) == 0: |
|
|
411 |
return None |
|
|
412 |
mol = Chem.MolFromSmiles(smiles_or_mol) |
|
|
413 |
if mol is None: |
|
|
414 |
return None |
|
|
415 |
try: |
|
|
416 |
Chem.SanitizeMol(mol) |
|
|
417 |
except ValueError: |
|
|
418 |
return None |
|
|
419 |
return mol |
|
|
420 |
return smiles_or_mol |
|
|
421 |
|
|
|
422 |
|
|
|
423 |
def mapper(n_jobs): |
|
|
424 |
""" |
|
|
425 |
Returns a mapping function for parallel or serial processing. |
|
|
426 |
|
|
|
427 |
If n_jobs == 1, returns the built-in map function. |
|
|
428 |
If n_jobs > 1, returns a function that uses a multiprocessing pool. |
|
|
429 |
|
|
|
430 |
Args: |
|
|
431 |
n_jobs (int or pool object): Number of jobs or a Pool instance. |
|
|
432 |
|
|
|
433 |
Returns: |
|
|
434 |
callable: A function that acts like map. |
|
|
435 |
""" |
|
|
436 |
if n_jobs == 1: |
|
|
437 |
def _mapper(*args, **kwargs): |
|
|
438 |
return list(map(*args, **kwargs)) |
|
|
439 |
return _mapper |
|
|
440 |
if isinstance(n_jobs, int): |
|
|
441 |
pool = Pool(n_jobs) |
|
|
442 |
def _mapper(*args, **kwargs): |
|
|
443 |
try: |
|
|
444 |
result = pool.map(*args, **kwargs) |
|
|
445 |
finally: |
|
|
446 |
pool.terminate() |
|
|
447 |
return result |
|
|
448 |
return _mapper |
|
|
449 |
return n_jobs.map |
|
|
450 |
|
|
|
451 |
|
|
|
452 |
def remove_invalid(gen, canonize=True, n_jobs=1): |
|
|
453 |
""" |
|
|
454 |
Removes invalid molecules from the provided dataset. |
|
|
455 |
|
|
|
456 |
Optionally canonizes the SMILES strings. |
|
|
457 |
|
|
|
458 |
Args: |
|
|
459 |
gen (list): List of SMILES strings. |
|
|
460 |
canonize (bool): Whether to convert to canonical SMILES. |
|
|
461 |
n_jobs (int): Number of parallel jobs. |
|
|
462 |
|
|
|
463 |
Returns: |
|
|
464 |
list: Filtered list of valid molecules. |
|
|
465 |
""" |
|
|
466 |
if not canonize: |
|
|
467 |
mols = mapper(n_jobs)(get_mol, gen) |
|
|
468 |
return [gen_ for gen_, mol in zip(gen, mols) if mol is not None] |
|
|
469 |
return [x for x in mapper(n_jobs)(canonic_smiles, gen) if x is not None] |
|
|
470 |
|
|
|
471 |
|
|
|
472 |
def fraction_valid(gen, n_jobs=1): |
|
|
473 |
""" |
|
|
474 |
Computes the fraction of valid molecules in the dataset. |
|
|
475 |
|
|
|
476 |
Args: |
|
|
477 |
gen (list): List of SMILES strings. |
|
|
478 |
n_jobs (int): Number of parallel jobs. |
|
|
479 |
|
|
|
480 |
Returns: |
|
|
481 |
float: Fraction of molecules that are valid. |
|
|
482 |
""" |
|
|
483 |
gen = mapper(n_jobs)(get_mol, gen) |
|
|
484 |
return 1 - gen.count(None) / len(gen) |
|
|
485 |
|
|
|
486 |
|
|
|
487 |
def canonic_smiles(smiles_or_mol): |
|
|
488 |
""" |
|
|
489 |
Converts a SMILES string or molecule to its canonical SMILES. |
|
|
490 |
|
|
|
491 |
Args: |
|
|
492 |
smiles_or_mol (str or RDKit Mol): Input molecule. |
|
|
493 |
|
|
|
494 |
Returns: |
|
|
495 |
str or None: Canonical SMILES string or None if invalid. |
|
|
496 |
""" |
|
|
497 |
mol = get_mol(smiles_or_mol) |
|
|
498 |
if mol is None: |
|
|
499 |
return None |
|
|
500 |
return Chem.MolToSmiles(mol) |
|
|
501 |
|
|
|
502 |
|
|
|
503 |
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True): |
|
|
504 |
""" |
|
|
505 |
Computes the fraction of unique molecules. |
|
|
506 |
|
|
|
507 |
Optionally computes unique@k, where only the first k molecules are considered. |
|
|
508 |
|
|
|
509 |
Args: |
|
|
510 |
gen (list): List of SMILES strings. |
|
|
511 |
k (int): Optional cutoff for unique@k computation. |
|
|
512 |
n_jobs (int): Number of parallel jobs. |
|
|
513 |
check_validity (bool): Whether to check for validity of molecules. |
|
|
514 |
|
|
|
515 |
Returns: |
|
|
516 |
float: Fraction of unique molecules. |
|
|
517 |
""" |
|
|
518 |
if k is not None: |
|
|
519 |
if len(gen) < k: |
|
|
520 |
warnings.warn("Can't compute unique@{}.".format(k) + |
|
|
521 |
" gen contains only {} molecules".format(len(gen))) |
|
|
522 |
gen = gen[:k] |
|
|
523 |
if check_validity: |
|
|
524 |
canonic = list(mapper(n_jobs)(canonic_smiles, gen)) |
|
|
525 |
canonic = [i for i in canonic if i is not None] |
|
|
526 |
set_cannonic = set(canonic) |
|
|
527 |
return 0 if len(canonic) == 0 else len(set_cannonic) / len(canonic) |
|
|
528 |
|
|
|
529 |
|
|
|
530 |
def novelty(gen, train, n_jobs=1): |
|
|
531 |
""" |
|
|
532 |
Computes the novelty score of generated molecules. |
|
|
533 |
|
|
|
534 |
Novelty is defined as the fraction of generated molecules that do not appear in the training set. |
|
|
535 |
|
|
|
536 |
Args: |
|
|
537 |
gen (list): List of generated SMILES strings. |
|
|
538 |
train (list): List of training SMILES strings. |
|
|
539 |
n_jobs (int): Number of parallel jobs. |
|
|
540 |
|
|
|
541 |
Returns: |
|
|
542 |
float: Novelty score. |
|
|
543 |
""" |
|
|
544 |
gen_smiles = mapper(n_jobs)(canonic_smiles, gen) |
|
|
545 |
gen_smiles_set = set(gen_smiles) - {None} |
|
|
546 |
train_set = set(train) |
|
|
547 |
return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set) |
|
|
548 |
|
|
|
549 |
|
|
|
550 |
def internal_diversity(gen): |
|
|
551 |
""" |
|
|
552 |
Computes the internal diversity of a set of molecules. |
|
|
553 |
|
|
|
554 |
Internal diversity is defined as one minus the average Tanimoto similarity between all pairs. |
|
|
555 |
|
|
|
556 |
Args: |
|
|
557 |
gen: Array-like representation of molecules. |
|
|
558 |
|
|
|
559 |
Returns: |
|
|
560 |
tuple: Mean and standard deviation of internal diversity. |
|
|
561 |
""" |
|
|
562 |
diversity = [1 - x for x in average_agg_tanimoto(gen, gen, agg="mean", intdiv=True)] |
|
|
563 |
return np.mean(diversity), np.std(diversity) |
|
|
564 |
|
|
|
565 |
|
|
|
566 |
def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cpu', p=1, intdiv=False): |
|
|
567 |
""" |
|
|
568 |
Computes the average aggregated Tanimoto similarity between two sets of molecular fingerprints. |
|
|
569 |
|
|
|
570 |
For each fingerprint in gen_vecs, finds the closest (max or mean) similarity with fingerprints in stock_vecs. |
|
|
571 |
|
|
|
572 |
Args: |
|
|
573 |
stock_vecs (numpy.ndarray): Array of fingerprint vectors from the reference set. |
|
|
574 |
gen_vecs (numpy.ndarray): Array of fingerprint vectors from the generated set. |
|
|
575 |
batch_size (int): Batch size for processing fingerprints. |
|
|
576 |
agg (str): Aggregation method, either 'max' or 'mean'. |
|
|
577 |
device (str): Device to perform computations on. |
|
|
578 |
p (int): Power for averaging. |
|
|
579 |
intdiv (bool): Whether to return individual similarities or the average. |
|
|
580 |
|
|
|
581 |
Returns: |
|
|
582 |
float or numpy.ndarray: Average aggregated Tanimoto similarity or array of individual scores. |
|
|
583 |
""" |
|
|
584 |
assert agg in ['max', 'mean'], "Can aggregate only max or mean" |
|
|
585 |
agg_tanimoto = np.zeros(len(gen_vecs)) |
|
|
586 |
total = np.zeros(len(gen_vecs)) |
|
|
587 |
for j in range(0, stock_vecs.shape[0], batch_size): |
|
|
588 |
x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float() |
|
|
589 |
for i in range(0, gen_vecs.shape[0], batch_size): |
|
|
590 |
y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float() |
|
|
591 |
y_gen = y_gen.transpose(0, 1) |
|
|
592 |
tp = torch.mm(x_stock, y_gen) |
|
|
593 |
# Compute Jaccard/Tanimoto similarity |
|
|
594 |
jac = (tp / (x_stock.sum(1, keepdim=True) + y_gen.sum(0, keepdim=True) - tp)).cpu().numpy() |
|
|
595 |
jac[np.isnan(jac)] = 1 |
|
|
596 |
if p != 1: |
|
|
597 |
jac = jac ** p |
|
|
598 |
if agg == 'max': |
|
|
599 |
agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum( |
|
|
600 |
agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0)) |
|
|
601 |
elif agg == 'mean': |
|
|
602 |
agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0) |
|
|
603 |
total[i:i + y_gen.shape[1]] += jac.shape[0] |
|
|
604 |
if agg == 'mean': |
|
|
605 |
agg_tanimoto /= total |
|
|
606 |
if p != 1: |
|
|
607 |
agg_tanimoto = (agg_tanimoto) ** (1 / p) |
|
|
608 |
if intdiv: |
|
|
609 |
return agg_tanimoto |
|
|
610 |
else: |
|
|
611 |
return np.mean(agg_tanimoto) |
|
|
612 |
|
|
|
613 |
|
|
|
614 |
def str2bool(v): |
|
|
615 |
""" |
|
|
616 |
Converts a string to a boolean. |
|
|
617 |
|
|
|
618 |
Args: |
|
|
619 |
v (str): Input string. |
|
|
620 |
|
|
|
621 |
Returns: |
|
|
622 |
bool: True if the string is 'true' (case insensitive), else False. |
|
|
623 |
""" |
|
|
624 |
return v.lower() in ('true') |
|
|
625 |
|
|
|
626 |
|
|
|
627 |
def obey_lipinski(mol): |
|
|
628 |
""" |
|
|
629 |
Checks if a molecule obeys Lipinski's Rule of Five. |
|
|
630 |
|
|
|
631 |
The function evaluates weight, hydrogen bond donors and acceptors, logP, and rotatable bonds. |
|
|
632 |
|
|
|
633 |
Args: |
|
|
634 |
mol (RDKit Mol): Molecule object. |
|
|
635 |
|
|
|
636 |
Returns: |
|
|
637 |
int: Number of Lipinski rules satisfied. |
|
|
638 |
""" |
|
|
639 |
mol = deepcopy(mol) |
|
|
640 |
Chem.SanitizeMol(mol) |
|
|
641 |
rule_1 = Descriptors.ExactMolWt(mol) < 500 |
|
|
642 |
rule_2 = Lipinski.NumHDonors(mol) <= 5 |
|
|
643 |
rule_3 = Lipinski.NumHAcceptors(mol) <= 10 |
|
|
644 |
rule_4 = (logp := Crippen.MolLogP(mol) >= -2) & (logp <= 5) |
|
|
645 |
rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10 |
|
|
646 |
return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]]) |
|
|
647 |
|
|
|
648 |
|
|
|
649 |
def obey_veber(mol): |
|
|
650 |
""" |
|
|
651 |
Checks if a molecule obeys Veber's rules. |
|
|
652 |
|
|
|
653 |
Veber's rules focus on the number of rotatable bonds and topological polar surface area. |
|
|
654 |
|
|
|
655 |
Args: |
|
|
656 |
mol (RDKit Mol): Molecule object. |
|
|
657 |
|
|
|
658 |
Returns: |
|
|
659 |
int: Number of Veber's rules satisfied. |
|
|
660 |
""" |
|
|
661 |
mol = deepcopy(mol) |
|
|
662 |
Chem.SanitizeMol(mol) |
|
|
663 |
rule_1 = rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10 |
|
|
664 |
rule_2 = rdMolDescriptors.CalcTPSA(mol) <= 140 |
|
|
665 |
return np.sum([int(a) for a in [rule_1, rule_2]]) |
|
|
666 |
|
|
|
667 |
|
|
|
668 |
def load_pains_filters(): |
|
|
669 |
""" |
|
|
670 |
Loads the PAINS (Pan-Assay INterference compoundS) filters A, B, and C. |
|
|
671 |
|
|
|
672 |
Returns: |
|
|
673 |
FilterCatalog: An RDKit FilterCatalog object containing PAINS filters. |
|
|
674 |
""" |
|
|
675 |
params = FilterCatalog.FilterCatalogParams() |
|
|
676 |
params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_A) |
|
|
677 |
params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_B) |
|
|
678 |
params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_C) |
|
|
679 |
catalog = FilterCatalog.FilterCatalog(params) |
|
|
680 |
return catalog |
|
|
681 |
|
|
|
682 |
|
|
|
683 |
def is_pains(mol, catalog): |
|
|
684 |
""" |
|
|
685 |
Checks if the given molecule is a PAINS compound. |
|
|
686 |
|
|
|
687 |
Args: |
|
|
688 |
mol (RDKit Mol): Molecule object. |
|
|
689 |
catalog (FilterCatalog): A catalog of PAINS filters. |
|
|
690 |
|
|
|
691 |
Returns: |
|
|
692 |
bool: True if the molecule matches a PAINS filter, else False. |
|
|
693 |
""" |
|
|
694 |
entry = catalog.GetFirstMatch(mol) |
|
|
695 |
return entry is not None |
|
|
696 |
|
|
|
697 |
|
|
|
698 |
def mapper(n_jobs): |
|
|
699 |
""" |
|
|
700 |
Returns a mapping function for parallel or serial processing. |
|
|
701 |
|
|
|
702 |
If n_jobs == 1, returns the built-in map function. |
|
|
703 |
If n_jobs > 1, returns a function that uses a multiprocessing pool. |
|
|
704 |
|
|
|
705 |
Args: |
|
|
706 |
n_jobs (int or pool object): Number of jobs or a Pool instance. |
|
|
707 |
|
|
|
708 |
Returns: |
|
|
709 |
callable: A function that acts like map. |
|
|
710 |
""" |
|
|
711 |
if n_jobs == 1: |
|
|
712 |
def _mapper(*args, **kwargs): |
|
|
713 |
return list(map(*args, **kwargs)) |
|
|
714 |
return _mapper |
|
|
715 |
if isinstance(n_jobs, int): |
|
|
716 |
pool = Pool(n_jobs) |
|
|
717 |
def _mapper(*args, **kwargs): |
|
|
718 |
try: |
|
|
719 |
result = pool.map(*args, **kwargs) |
|
|
720 |
finally: |
|
|
721 |
pool.terminate() |
|
|
722 |
return result |
|
|
723 |
return _mapper |
|
|
724 |
return n_jobs.map |
|
|
725 |
|
|
|
726 |
|
|
|
727 |
def fragmenter(mol): |
|
|
728 |
""" |
|
|
729 |
Fragments a molecule using BRICS and returns a list of fragment SMILES. |
|
|
730 |
|
|
|
731 |
Args: |
|
|
732 |
mol (str or RDKit Mol): Input molecule. |
|
|
733 |
|
|
|
734 |
Returns: |
|
|
735 |
list: List of fragment SMILES strings. |
|
|
736 |
""" |
|
|
737 |
fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol)) |
|
|
738 |
fgs_smi = Chem.MolToSmiles(fgs).split(".") |
|
|
739 |
return fgs_smi |
|
|
740 |
|
|
|
741 |
|
|
|
742 |
def get_mol(smiles_or_mol): |
|
|
743 |
""" |
|
|
744 |
Loads a SMILES string or molecule into an RDKit molecule object. |
|
|
745 |
|
|
|
746 |
Args: |
|
|
747 |
smiles_or_mol (str or RDKit Mol): SMILES string or molecule. |
|
|
748 |
|
|
|
749 |
Returns: |
|
|
750 |
RDKit Mol or None: Sanitized molecule object or None if invalid. |
|
|
751 |
""" |
|
|
752 |
if isinstance(smiles_or_mol, str): |
|
|
753 |
if len(smiles_or_mol) == 0: |
|
|
754 |
return None |
|
|
755 |
mol = Chem.MolFromSmiles(smiles_or_mol) |
|
|
756 |
if mol is None: |
|
|
757 |
return None |
|
|
758 |
try: |
|
|
759 |
Chem.SanitizeMol(mol) |
|
|
760 |
except ValueError: |
|
|
761 |
return None |
|
|
762 |
return mol |
|
|
763 |
return smiles_or_mol |
|
|
764 |
|
|
|
765 |
|
|
|
766 |
def compute_fragments(mol_list, n_jobs=1): |
|
|
767 |
""" |
|
|
768 |
Fragments a list of molecules using BRICS and returns a counter of fragment occurrences. |
|
|
769 |
|
|
|
770 |
Args: |
|
|
771 |
mol_list (list): List of molecules (SMILES or RDKit Mol). |
|
|
772 |
n_jobs (int): Number of parallel jobs. |
|
|
773 |
|
|
|
774 |
Returns: |
|
|
775 |
Counter: A Counter dictionary mapping fragment SMILES to counts. |
|
|
776 |
""" |
|
|
777 |
fragments = Counter() |
|
|
778 |
for mol_frag in mapper(n_jobs)(fragmenter, mol_list): |
|
|
779 |
fragments.update(mol_frag) |
|
|
780 |
return fragments |
|
|
781 |
|
|
|
782 |
|
|
|
783 |
def compute_scaffolds(mol_list, n_jobs=1, min_rings=2): |
|
|
784 |
""" |
|
|
785 |
Extracts scaffolds from a list of molecules as canonical SMILES. |
|
|
786 |
|
|
|
787 |
Only scaffolds with at least min_rings rings are considered. |
|
|
788 |
|
|
|
789 |
Args: |
|
|
790 |
mol_list (list): List of molecules. |
|
|
791 |
n_jobs (int): Number of parallel jobs. |
|
|
792 |
min_rings (int): Minimum number of rings required in a scaffold. |
|
|
793 |
|
|
|
794 |
Returns: |
|
|
795 |
Counter: A Counter mapping scaffold SMILES to counts. |
|
|
796 |
""" |
|
|
797 |
scaffolds = Counter() |
|
|
798 |
map_ = mapper(n_jobs) |
|
|
799 |
scaffolds = Counter(map_(partial(compute_scaffold, min_rings=min_rings), mol_list)) |
|
|
800 |
if None in scaffolds: |
|
|
801 |
scaffolds.pop(None) |
|
|
802 |
return scaffolds |
|
|
803 |
|
|
|
804 |
|
|
|
805 |
def get_n_rings(mol): |
|
|
806 |
""" |
|
|
807 |
Computes the number of rings in a molecule. |
|
|
808 |
|
|
|
809 |
Args: |
|
|
810 |
mol (RDKit Mol): Molecule object. |
|
|
811 |
|
|
|
812 |
Returns: |
|
|
813 |
int: Number of rings. |
|
|
814 |
""" |
|
|
815 |
return mol.GetRingInfo().NumRings() |
|
|
816 |
|
|
|
817 |
|
|
|
818 |
def compute_scaffold(mol, min_rings=2): |
|
|
819 |
""" |
|
|
820 |
Computes the Murcko scaffold of a molecule and returns its canonical SMILES if it has enough rings. |
|
|
821 |
|
|
|
822 |
Args: |
|
|
823 |
mol (str or RDKit Mol): Input molecule. |
|
|
824 |
min_rings (int): Minimum number of rings required. |
|
|
825 |
|
|
|
826 |
Returns: |
|
|
827 |
str or None: Canonical SMILES of the scaffold if valid, else None. |
|
|
828 |
""" |
|
|
829 |
mol = get_mol(mol) |
|
|
830 |
try: |
|
|
831 |
scaffold = MurckoScaffold.GetScaffoldForMol(mol) |
|
|
832 |
except (ValueError, RuntimeError): |
|
|
833 |
return None |
|
|
834 |
n_rings = get_n_rings(scaffold) |
|
|
835 |
scaffold_smiles = Chem.MolToSmiles(scaffold) |
|
|
836 |
if scaffold_smiles == '' or n_rings < min_rings: |
|
|
837 |
return None |
|
|
838 |
return scaffold_smiles |
|
|
839 |
|
|
|
840 |
|
|
|
841 |
class Metric: |
|
|
842 |
""" |
|
|
843 |
Abstract base class for chemical metrics. |
|
|
844 |
|
|
|
845 |
Derived classes should implement the precalc and metric methods. |
|
|
846 |
""" |
|
|
847 |
def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs): |
|
|
848 |
self.n_jobs = n_jobs |
|
|
849 |
self.device = device |
|
|
850 |
self.batch_size = batch_size |
|
|
851 |
for k, v in kwargs.items(): |
|
|
852 |
setattr(self, k, v) |
|
|
853 |
|
|
|
854 |
def __call__(self, ref=None, gen=None, pref=None, pgen=None): |
|
|
855 |
""" |
|
|
856 |
Computes the metric between reference and generated molecules. |
|
|
857 |
|
|
|
858 |
Exactly one of ref or pref, and gen or pgen should be provided. |
|
|
859 |
|
|
|
860 |
Args: |
|
|
861 |
ref: Reference molecule list. |
|
|
862 |
gen: Generated molecule list. |
|
|
863 |
pref: Precalculated reference metric. |
|
|
864 |
pgen: Precalculated generated metric. |
|
|
865 |
|
|
|
866 |
Returns: |
|
|
867 |
Metric value computed by the metric method. |
|
|
868 |
""" |
|
|
869 |
assert (ref is None) != (pref is None), "specify ref xor pref" |
|
|
870 |
assert (gen is None) != (pgen is None), "specify gen xor pgen" |
|
|
871 |
if pref is None: |
|
|
872 |
pref = self.precalc(ref) |
|
|
873 |
if pgen is None: |
|
|
874 |
pgen = self.precalc(gen) |
|
|
875 |
return self.metric(pref, pgen) |
|
|
876 |
|
|
|
877 |
def precalc(self, molecules): |
|
|
878 |
""" |
|
|
879 |
Pre-calculates necessary representations from a list of molecules. |
|
|
880 |
Should be implemented by derived classes. |
|
|
881 |
""" |
|
|
882 |
raise NotImplementedError |
|
|
883 |
|
|
|
884 |
def metric(self, pref, pgen): |
|
|
885 |
""" |
|
|
886 |
Computes the metric given precalculated representations. |
|
|
887 |
Should be implemented by derived classes. |
|
|
888 |
""" |
|
|
889 |
raise NotImplementedError |
|
|
890 |
|
|
|
891 |
|
|
|
892 |
class FragMetric(Metric): |
|
|
893 |
""" |
|
|
894 |
Metrics based on molecular fragments. |
|
|
895 |
""" |
|
|
896 |
def precalc(self, mols): |
|
|
897 |
return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)} |
|
|
898 |
|
|
|
899 |
def metric(self, pref, pgen): |
|
|
900 |
return cos_similarity(pref['frag'], pgen['frag']) |
|
|
901 |
|
|
|
902 |
|
|
|
903 |
class ScafMetric(Metric): |
|
|
904 |
""" |
|
|
905 |
Metrics based on molecular scaffolds. |
|
|
906 |
""" |
|
|
907 |
def precalc(self, mols): |
|
|
908 |
return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)} |
|
|
909 |
|
|
|
910 |
def metric(self, pref, pgen): |
|
|
911 |
return cos_similarity(pref['scaf'], pgen['scaf']) |
|
|
912 |
|
|
|
913 |
|
|
|
914 |
def cos_similarity(ref_counts, gen_counts): |
|
|
915 |
""" |
|
|
916 |
Computes cosine similarity between two molecular vectors. |
|
|
917 |
|
|
|
918 |
Args: |
|
|
919 |
ref_counts (dict): Reference molecular vectors. |
|
|
920 |
gen_counts (dict): Generated molecular vectors. |
|
|
921 |
|
|
|
922 |
Returns: |
|
|
923 |
float: Cosine similarity between the two molecular vectors. |
|
|
924 |
""" |
|
|
925 |
if len(ref_counts) == 0 or len(gen_counts) == 0: |
|
|
926 |
return np.nan |
|
|
927 |
keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys())) |
|
|
928 |
ref_vec = np.array([ref_counts.get(k, 0) for k in keys]) |
|
|
929 |
gen_vec = np.array([gen_counts.get(k, 0) for k in keys]) |
|
|
930 |
return 1 - cos_distance(ref_vec, gen_vec) |