b/results/evaluate.py
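"""
Evaluate generated molecules against one or two reference SMILES sets.

Reports validity, uniqueness, novelty, internal diversity, QED and SA scores,
FCD, fragment/scaffold similarity, and Lipinski/Veber/PAINS drug-likeness
statistics, and writes them to <output_prefix>.json and <output_prefix>.csv.

Example invocation (file paths are illustrative, not part of the repo):
    python results/evaluate.py --gen generated.csv --ref1 train.smi --ref2 test.smi --output results/eval
"""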
|
|
import json
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem import QED, AllChem
from rdkit.Chem import Descriptors
from rdkit.Chem import RDConfig
import os
import sys

sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

from fcd_torch import FCD  # You'll need to install fcd_torch

sys.path.append(os.path.dirname(os.path.dirname(__file__)))  # Add parent directory to path
from src.util.utils import (fraction_valid, fraction_unique, novelty,
                            internal_diversity, obey_lipinski, obey_veber,
                            load_pains_filters, is_pains, FragMetric, ScafMetric)
import torch


class MoleculeEvaluator:
    def __init__(self, gen_smiles, ref_smiles_1, ref_smiles_2=None, n_jobs=8):
        """
        Initialize evaluator with generated and reference SMILES.
        ref_smiles_2 is optional.
        """
        self.gen_smiles = gen_smiles
        self.ref_smiles_1 = ref_smiles_1
        self.ref_smiles_2 = ref_smiles_2
        self.n_jobs = n_jobs

        # Convert SMILES to RDKit molecules and filter out invalid ones
        self.gen_mols = [mol for s in gen_smiles if s and (mol := Chem.MolFromSmiles(s)) is not None]
        self.ref_mols_1 = [mol for s in ref_smiles_1 if s and (mol := Chem.MolFromSmiles(s)) is not None]
        self.ref_mols_2 = [mol for s in ref_smiles_2 if s and (mol := Chem.MolFromSmiles(s)) is not None] if ref_smiles_2 else None

        # Initialize metrics that need setup
        self.fcd = FCD(device='cuda' if torch.cuda.is_available() else 'cpu')
        self.pains_catalog = load_pains_filters()
        self.frag_metric = FragMetric(n_jobs=1)
        self.scaf_metric = ScafMetric(n_jobs=1)

    def calculate_basic_metrics(self):
        """Calculate validity, uniqueness, novelty, and internal diversity"""
|
|
        # Generate Morgan fingerprints for internal diversity calculation
        fps = np.array([AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in self.gen_mols if mol is not None])
        internal_diversity_mean, internal_diversity_std = internal_diversity(fps)
        results = {
            'validity': fraction_valid(self.gen_smiles, n_jobs=self.n_jobs),
            'uniqueness': fraction_unique(self.gen_smiles, n_jobs=self.n_jobs),
            'novelty_ref1': novelty(self.gen_smiles, self.ref_smiles_1, n_jobs=self.n_jobs),
            'internal_diversity': internal_diversity_mean,
            'internal_diversity_std': internal_diversity_std
        }
        if self.ref_smiles_2:
            results['novelty_ref2'] = novelty(self.gen_smiles, self.ref_smiles_2, n_jobs=self.n_jobs)
        return results

    def calculate_property_metrics(self):
        """Calculate QED and SA scores"""
|
|
        qed_scores = [QED.qed(mol) for mol in self.gen_mols if mol is not None]
        sa_scores = [sascorer.calculateScore(mol) for mol in self.gen_mols if mol is not None]

        return {
            'qed_mean': np.mean(qed_scores),
            'qed_std': np.std(qed_scores),
            'sa_mean': np.mean(sa_scores),
            'sa_std': np.std(sa_scores)
        }

    def calculate_fcd(self):
        """Calculate FCD scores against the reference set(s); reference set 2 is optional"""
|
|
        # Filter out None values and convert mols back to SMILES
        gen_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.gen_mols if mol is not None]
        ref1_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.ref_mols_1 if mol is not None]

        results = {
            'fcd_ref1': self.fcd(gen_valid_smiles, ref1_valid_smiles)
        }

        if self.ref_mols_2:
            ref2_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.ref_mols_2 if mol is not None]
            results['fcd_ref2'] = self.fcd(gen_valid_smiles, ref2_valid_smiles)

        return results

    def calculate_similarity_metrics(self):
        """Calculate fragment and scaffold similarity"""
|
|
        results = {
            'frag_sim_ref1': self.frag_metric(gen=self.gen_mols, ref=self.ref_mols_1),
            'scaf_sim_ref1': self.scaf_metric(gen=self.gen_mols, ref=self.ref_mols_1)
        }
        if self.ref_mols_2:
            results.update({
                'frag_sim_ref2': self.frag_metric(gen=self.gen_mols, ref=self.ref_mols_2),
                'scaf_sim_ref2': self.scaf_metric(gen=self.gen_mols, ref=self.ref_mols_2)
            })
        return results

    def calculate_drug_likeness(self):
        """Calculate Lipinski, Veber and PAINS filtering results"""
|
|
        lipinski_scores = [obey_lipinski(mol) for mol in self.gen_mols if mol is not None]
        veber_scores = [obey_veber(mol) for mol in self.gen_mols if mol is not None]
        pains_results = [not is_pains(mol, self.pains_catalog) for mol in self.gen_mols if mol is not None]

        return {
            'lipinski_mean': np.mean(lipinski_scores),
            'lipinski_std': np.std(lipinski_scores),
            'veber_mean': np.mean(veber_scores),
            'veber_std': np.std(veber_scores),
            'pains_pass_rate': np.mean(pains_results)
        }

    def evaluate_all(self):
        """Run all evaluations and combine results"""
        results = {}

        print("\nCalculating basic metrics...")
        basic_metrics = self.calculate_basic_metrics()
        print("Basic metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in basic_metrics.items()})
        results.update(basic_metrics)

        print("\nCalculating property metrics...")
        property_metrics = self.calculate_property_metrics()
        print("Property metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in property_metrics.items()})
        results.update(property_metrics)

        print("\nCalculating FCD scores...")
        fcd_metrics = self.calculate_fcd()
        print("FCD metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in fcd_metrics.items()})
        results.update(fcd_metrics)

        print("\nCalculating similarity metrics...")
        similarity_metrics = self.calculate_similarity_metrics()
        print("Similarity metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in similarity_metrics.items()})
        results.update(similarity_metrics)

        print("\nCalculating drug-likeness metrics...")
        drug_likeness = self.calculate_drug_likeness()
        print("Drug-likeness metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in drug_likeness.items()})
        results.update(drug_likeness)

        return results
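
# Example (illustrative) of using the evaluator directly from Python rather than via the CLI:
#   evaluator = MoleculeEvaluator(gen_smiles, ref_smiles_1, ref_smiles_2=None, n_jobs=4)
#   metrics = evaluator.evaluate_all()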
|
|
def read_smi_file(file_path):
    """
    Read SMILES strings from a .smi file

    Args:
        file_path (str): Path to .smi file

    Returns:
        list: List of SMILES strings
    """
|
|
    smiles_list = []
    try:
        with open(file_path, 'r') as f:
            for line in f:
                # Handle different .smi formats:
                # Some .smi files have just SMILES, others have SMILES followed by an identifier
                parts = line.strip().split()
                if parts:  # Skip empty lines
                    smiles = parts[0]  # First part is always the SMILES
                    smiles_list.append(smiles)
    except FileNotFoundError:
        raise FileNotFoundError(f"Could not find SMI file: {file_path}")
    except Exception as e:
        raise ValueError(f"Error reading SMI file {file_path}: {str(e)}")

    return smiles_list


def evaluate_molecules_from_files(gen_path, ref_path_1, ref_path_2=None, smiles_col='SMILES', output_prefix="results", n_jobs=8):
    """
    Main function to evaluate generated molecules against reference sets

    Args:
        gen_path (str): Path to CSV file containing generated SMILES
        ref_path_1 (str): Path to .smi file containing first reference set SMILES
        ref_path_2 (str, optional): Path to .smi file containing second reference set SMILES
        smiles_col (str): Name of column containing SMILES strings in the CSV file
        output_prefix (str): Prefix for output files
        n_jobs (int): Number of parallel jobs
    """
    # Read generated SMILES from CSV file
    try:
        gen_df = pd.read_csv(gen_path)
        if smiles_col not in gen_df.columns:
            raise ValueError(f"SMILES column '{smiles_col}' not found in generated dataset")
        gen_smiles = gen_df[smiles_col].dropna().tolist()
    except FileNotFoundError:
        raise FileNotFoundError(f"Could not find generated CSV file: {gen_path}")
    except pd.errors.EmptyDataError:
        raise ValueError("Generated CSV file is empty")

    # Read reference SMILES from .smi files
    ref_smiles_1 = read_smi_file(ref_path_1)
    ref_smiles_2 = read_smi_file(ref_path_2) if ref_path_2 else None

    # Validate that we have SMILES to process
    if not gen_smiles:
        raise ValueError("No valid SMILES found in generated set")
    if not ref_smiles_1:
        raise ValueError("No valid SMILES found in reference set 1")
    if ref_path_2 and not ref_smiles_2:
        raise ValueError("No valid SMILES found in reference set 2")

    print("\nProcessing datasets:")
    print(f"Generated molecules: {len(gen_smiles)}")
    print(f"Reference set 1: {len(ref_smiles_1)}")
    if ref_smiles_2:
        print(f"Reference set 2: {len(ref_smiles_2)}")

    # Run evaluation
    evaluator = MoleculeEvaluator(gen_smiles, ref_smiles_1, ref_smiles_2, n_jobs=n_jobs)
    results = evaluator.evaluate_all()

    print("\nSaving results...")
    # Add dataset sizes to results
    results.update({
        'n_generated': len(gen_smiles),
        'n_reference_1': len(ref_smiles_1),
        'n_reference_2': len(ref_smiles_2) if ref_smiles_2 is not None else 0
    })

    # Save results
    # Format float values to 3 decimal places for JSON
    formatted_results = {k: round(v, 3) if isinstance(v, float) else v for k, v in results.items()}
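    # np.float64 subclasses Python float, so the numpy means/stds above are rounded
    # here and serialize to JSON without special handling.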
|
|
    with open(f"{output_prefix}.json", 'w') as f:
        json.dump(formatted_results, f, indent=4)

    # Create DataFrame with formatted results
    df = pd.DataFrame([formatted_results])
    df.to_csv(f"{output_prefix}.csv", index=False)

    return results

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Evaluate generated molecules against reference sets')
    parser.add_argument('--gen', required=True, help='Path to CSV file with generated SMILES')
    parser.add_argument('--ref1', required=True, help='Path to .smi file with first reference set SMILES')
    parser.add_argument('--ref2', help='Path to .smi file with second reference set SMILES (optional)')
    parser.add_argument('--smiles-col', default='SMILES', help='Name of SMILES column in generated CSV file')
    parser.add_argument('--output', default='results', help='Prefix for output files')
    parser.add_argument('--n-jobs', type=int, default=8, help='Number of parallel jobs')

    args = parser.parse_args()

    try:
        results = evaluate_molecules_from_files(
            args.gen,
            args.ref1,
            args.ref2,
            smiles_col=args.smiles_col,
            output_prefix=args.output,
            n_jobs=args.n_jobs
        )
        print(f"Evaluation complete. Results saved to {args.output}.json and {args.output}.csv")

    except Exception as e:
        print(f"Error during evaluation: {str(e)}")
        sys.exit(1)  # Exit non-zero so callers can detect failure