Diff of /results/evaluate.py [000000] .. [7d53f6]

import json
import os
import sys

import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import QED, AllChem, RDConfig

# The SA scorer ships in RDKit's contrib directory and must be added to the path
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

from fcd_torch import FCD  # requires the fcd_torch package to be installed

sys.path.append(os.path.dirname(os.path.dirname(__file__)))  # add the project root so src.util.utils resolves
from src.util.utils import (fraction_valid, fraction_unique, novelty,
                            internal_diversity, obey_lipinski, obey_veber,
                            load_pains_filters, is_pains, FragMetric, ScafMetric)

class MoleculeEvaluator:
    def __init__(self, gen_smiles, ref_smiles_1, ref_smiles_2=None, n_jobs=8):
        """
        Initialize the evaluator with generated and reference SMILES.
        ref_smiles_2 is optional.
        """
        self.gen_smiles = gen_smiles
        self.ref_smiles_1 = ref_smiles_1
        self.ref_smiles_2 = ref_smiles_2
        self.n_jobs = n_jobs

        # Parse SMILES into RDKit molecules, dropping anything that fails to parse
        # (the walrus operator requires Python 3.8+)
        self.gen_mols = [mol for s in gen_smiles if s and (mol := Chem.MolFromSmiles(s)) is not None]
        self.ref_mols_1 = [mol for s in ref_smiles_1 if s and (mol := Chem.MolFromSmiles(s)) is not None]
        self.ref_mols_2 = [mol for s in ref_smiles_2 if s and (mol := Chem.MolFromSmiles(s)) is not None] if ref_smiles_2 else None

        # Initialize metrics that need setup
        self.fcd = FCD(device='cuda' if torch.cuda.is_available() else 'cpu')
        self.pains_catalog = load_pains_filters()
        self.frag_metric = FragMetric(n_jobs=1)
        self.scaf_metric = ScafMetric(n_jobs=1)

    def calculate_basic_metrics(self):
        """Calculate validity, uniqueness, novelty, and internal diversity"""
        # Radius-2 Morgan fingerprints (ECFP4-like), 1024 bits, for internal diversity;
        # self.gen_mols has already been filtered for invalid molecules
        fps = np.array([AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024) for mol in self.gen_mols])
        internal_diversity_mean, internal_diversity_std = internal_diversity(fps)
        results = {
            'validity': fraction_valid(self.gen_smiles, n_jobs=self.n_jobs),
            'uniqueness': fraction_unique(self.gen_smiles, n_jobs=self.n_jobs),
            'novelty_ref1': novelty(self.gen_smiles, self.ref_smiles_1, n_jobs=self.n_jobs),
            'internal_diversity': internal_diversity_mean,
            'internal_diversity_std': internal_diversity_std
        }
        if self.ref_smiles_2:
            results['novelty_ref2'] = novelty(self.gen_smiles, self.ref_smiles_2, n_jobs=self.n_jobs)
        return results

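    # Note (assumption): internal_diversity is taken to follow the MOSES
    # convention, i.e. 1 - mean pairwise Tanimoto similarity over the
    # fingerprints (higher = more diverse); see src/util/utils.py for the
    # exact definition used here.
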
    def calculate_property_metrics(self):
        """Calculate QED and SA scores"""
        # QED lies in [0, 1] (higher = more drug-like); the SA score runs from
        # 1 (easy to synthesize) to 10 (hard to synthesize)
        qed_scores = [QED.qed(mol) for mol in self.gen_mols]
        sa_scores = [sascorer.calculateScore(mol) for mol in self.gen_mols]

        return {
            'qed_mean': np.mean(qed_scores),
            'qed_std': np.std(qed_scores),
            'sa_mean': np.mean(sa_scores),
            'sa_std': np.std(sa_scores)
        }

    def calculate_fcd(self):
        """Calculate the Fréchet ChemNet Distance (FCD) against both reference sets; lower means the distributions are closer"""
        # Round-trip through RDKit to get canonical SMILES for the FCD model
        # (the molecule lists are already filtered for validity)
        gen_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.gen_mols]
        ref1_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.ref_mols_1]

        results = {
            'fcd_ref1': self.fcd(gen_valid_smiles, ref1_valid_smiles)
        }

        if self.ref_mols_2:
            ref2_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.ref_mols_2]
            results['fcd_ref2'] = self.fcd(gen_valid_smiles, ref2_valid_smiles)

        return results

    def calculate_similarity_metrics(self):
        """Calculate fragment and scaffold similarity"""
        results = {
            'frag_sim_ref1': self.frag_metric(gen=self.gen_mols, ref=self.ref_mols_1),
            'scaf_sim_ref1': self.scaf_metric(gen=self.gen_mols, ref=self.ref_mols_1)
        }
        if self.ref_mols_2:
            results.update({
                'frag_sim_ref2': self.frag_metric(gen=self.gen_mols, ref=self.ref_mols_2),
                'scaf_sim_ref2': self.scaf_metric(gen=self.gen_mols, ref=self.ref_mols_2)
            })
        return results

    def calculate_drug_likeness(self):
        """Calculate Lipinski, Veber, and PAINS filtering results"""
        lipinski_scores = [obey_lipinski(mol) for mol in self.gen_mols]
        veber_scores = [obey_veber(mol) for mol in self.gen_mols]
        # True means the molecule is NOT flagged by any PAINS filter
        pains_results = [not is_pains(mol, self.pains_catalog) for mol in self.gen_mols]

        return {
            'lipinski_mean': np.mean(lipinski_scores),
            'lipinski_std': np.std(lipinski_scores),
            'veber_mean': np.mean(veber_scores),
            'veber_std': np.std(veber_scores),
            'pains_pass_rate': np.mean(pains_results)
        }

    def evaluate_all(self):
        """Run all evaluations and combine the results"""
        results = {}

        print("\nCalculating basic metrics...")
        basic_metrics = self.calculate_basic_metrics()
        print("Basic metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in basic_metrics.items()})
        results.update(basic_metrics)

        print("\nCalculating property metrics...")
        property_metrics = self.calculate_property_metrics()
        print("Property metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in property_metrics.items()})
        results.update(property_metrics)

        print("\nCalculating FCD scores...")
        fcd_metrics = self.calculate_fcd()
        print("FCD metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in fcd_metrics.items()})
        results.update(fcd_metrics)

        print("\nCalculating similarity metrics...")
        similarity_metrics = self.calculate_similarity_metrics()
        print("Similarity metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in similarity_metrics.items()})
        results.update(similarity_metrics)

        print("\nCalculating drug-likeness metrics...")
        drug_likeness = self.calculate_drug_likeness()
        print("Drug-likeness metrics:", {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in drug_likeness.items()})
        results.update(drug_likeness)

        return results

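# Minimal usage sketch (hypothetical SMILES lists, illustration only; real
# evaluations should use sets large enough for distribution metrics like FCD):
#
#   evaluator = MoleculeEvaluator(gen_smiles=["CCO", "c1ccccc1"],
#                                 ref_smiles_1=["CCN", "CC(=O)O"])
#   metrics = evaluator.evaluate_all()
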
def read_smi_file(file_path):
    """
    Read SMILES strings from a .smi file

    Args:
        file_path (str): Path to the .smi file

    Returns:
        list: List of SMILES strings
    """
    smiles_list = []
    try:
        with open(file_path, 'r') as f:
            for line in f:
                # Handle both common .smi formats: a bare SMILES per line, or a
                # SMILES followed by an identifier (see the sample after this function)
                parts = line.strip().split()
                if parts:  # skip empty lines
                    smiles_list.append(parts[0])  # the first field is always the SMILES
    except FileNotFoundError:
        raise FileNotFoundError(f"Could not find SMI file: {file_path}")
    except Exception as e:
        raise ValueError(f"Error reading SMI file {file_path}: {str(e)}")

    return smiles_list

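# A .smi file accepted by read_smi_file may look like either of the following
# (hypothetical contents; only the first whitespace-separated field is kept):
#
#   CCO
#   c1ccccc1
#
# or, with identifiers:
#
#   CCO ethanol
#   c1ccccc1 benzene
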
def evaluate_molecules_from_files(gen_path, ref_path_1, ref_path_2=None, smiles_col='SMILES', output_prefix="results", n_jobs=8):
    """
    Main function to evaluate generated molecules against reference sets

    Args:
        gen_path (str): Path to CSV file containing generated SMILES
        ref_path_1 (str): Path to .smi file containing the first reference set
        ref_path_2 (str, optional): Path to .smi file containing the second reference set
        smiles_col (str): Name of the column containing SMILES strings in the CSV file
        output_prefix (str): Prefix for output files
        n_jobs (int): Number of parallel jobs
    """
    # Read generated SMILES from the CSV file
    try:
        gen_df = pd.read_csv(gen_path)
        if smiles_col not in gen_df.columns:
            raise ValueError(f"SMILES column '{smiles_col}' not found in generated dataset")
        gen_smiles = gen_df[smiles_col].dropna().tolist()
    except FileNotFoundError:
        raise FileNotFoundError(f"Could not find generated CSV file: {gen_path}")
    except pd.errors.EmptyDataError:
        raise ValueError("Generated CSV file is empty")

    # Read reference SMILES from the .smi files
    ref_smiles_1 = read_smi_file(ref_path_1)
    ref_smiles_2 = read_smi_file(ref_path_2) if ref_path_2 else None

    # Validate that we have SMILES to process
    if not gen_smiles:
        raise ValueError("No valid SMILES found in generated set")
    if not ref_smiles_1:
        raise ValueError("No valid SMILES found in reference set 1")
    if ref_path_2 and not ref_smiles_2:
        raise ValueError("No valid SMILES found in reference set 2")

    print("\nProcessing datasets:")
    print(f"Generated molecules: {len(gen_smiles)}")
    print(f"Reference set 1: {len(ref_smiles_1)}")
    if ref_smiles_2:
        print(f"Reference set 2: {len(ref_smiles_2)}")

    # Run the evaluation
    evaluator = MoleculeEvaluator(gen_smiles, ref_smiles_1, ref_smiles_2, n_jobs=n_jobs)
    results = evaluator.evaluate_all()

    print("\nSaving results...")
    # Add dataset sizes to the results
    results.update({
        'n_generated': len(gen_smiles),
        'n_reference_1': len(ref_smiles_1),
        'n_reference_2': len(ref_smiles_2) if ref_smiles_2 is not None else 0
    })

    # Round float values to 3 decimal places, then save as JSON and CSV
    formatted_results = {k: round(v, 3) if isinstance(v, float) else v for k, v in results.items()}
    with open(f"{output_prefix}.json", 'w') as f:
        json.dump(formatted_results, f, indent=4)

    df = pd.DataFrame([formatted_results])
    df.to_csv(f"{output_prefix}.csv", index=False)

    return results

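# Example programmatic use (hypothetical file names, illustration only):
#
#   results = evaluate_molecules_from_files("generated.csv", "train.smi",
#                                           ref_path_2="test.smi",
#                                           output_prefix="eval_run", n_jobs=4)
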
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Evaluate generated molecules against reference sets')
    parser.add_argument('--gen', required=True, help='Path to CSV file with generated SMILES')
    parser.add_argument('--ref1', required=True, help='Path to .smi file with the first reference set')
    parser.add_argument('--ref2', help='Path to .smi file with the second reference set (optional)')
    parser.add_argument('--smiles-col', default='SMILES', help='Name of the SMILES column in the generated CSV file')
    parser.add_argument('--output', default='results', help='Prefix for output files')
    parser.add_argument('--n-jobs', type=int, default=8, help='Number of parallel jobs')

    args = parser.parse_args()

    try:
        results = evaluate_molecules_from_files(
            args.gen,
            args.ref1,
            args.ref2,
            smiles_col=args.smiles_col,
            output_prefix=args.output,
            n_jobs=args.n_jobs
        )
        print(f"Evaluation complete. Results saved to {args.output}.json and {args.output}.csv")
    except Exception as e:
        print(f"Error during evaluation: {str(e)}")
        sys.exit(1)  # signal failure to the caller