Diff of /results/evaluate.py [000000] .. [7d53f6]

--- /dev/null
+++ b/results/evaluate.py
@@ -0,0 +1,261 @@
+import json
+import os
+import sys
+
+import numpy as np
+import pandas as pd
+import torch
+from rdkit import Chem
+from rdkit.Chem import QED, AllChem, RDConfig
+from fcd_torch import FCD  # requires the fcd_torch package
+
+# sascorer ships in RDKit's contrib directory, not the main package
+sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
+import sascorer
+
+# Make the project root importable so src.util.utils resolves
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from src.util.utils import (fraction_valid, fraction_unique, novelty,
+                            internal_diversity, obey_lipinski, obey_veber,
+                            load_pains_filters, is_pains, FragMetric, ScafMetric)
+
+class MoleculeEvaluator:
+    def __init__(self, gen_smiles, ref_smiles_1, ref_smiles_2=None, n_jobs=8):
+        """
+        Initialize evaluator with generated and reference SMILES
+        ref_smiles_2 is optional
+        """
+        self.gen_smiles = gen_smiles
+        self.ref_smiles_1 = ref_smiles_1
+        self.ref_smiles_2 = ref_smiles_2
+        self.n_jobs = n_jobs
+        
+        # Convert SMILES to RDKit molecules and filter out invalid ones
+        self.gen_mols = [mol for s in gen_smiles if s and (mol := Chem.MolFromSmiles(s)) is not None]
+        self.ref_mols_1 = [mol for s in ref_smiles_1 if s and (mol := Chem.MolFromSmiles(s)) is not None]
+        self.ref_mols_2 = [mol for s in ref_smiles_2 if s and (mol := Chem.MolFromSmiles(s)) is not None] if ref_smiles_2 else None
+        
+        # Initialize metrics that need setup
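+        # FCD (Fréchet ChemNet Distance) scores molecule sets with a pretrained
+        # ChemNet model; the GPU is used when available since CPU scoring can be slow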
+        self.fcd = FCD(device='cuda' if torch.cuda.is_available() else 'cpu')
+        self.pains_catalog = load_pains_filters()
+        self.frag_metric = FragMetric(n_jobs=1)
+        self.scaf_metric = ScafMetric(n_jobs=1)
+
+    def calculate_basic_metrics(self):
+        """Calculate validity, uniqueness, novelty, and internal diversity"""
+        # Morgan fingerprints (radius 2, 1024 bits) for the internal-diversity
+        # calculation; self.gen_mols already excludes molecules that failed to parse
+        fps = np.array([AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024) for mol in self.gen_mols])
+        internal_diversity_mean, internal_diversity_std = internal_diversity(fps)
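+        # Note: GetMorganFingerprintAsBitVect is deprecated in recent RDKit releases.
+        # A sketch of the generator-based replacement (same radius/bit size as above):
+        #   from rdkit.Chem import rdFingerprintGenerator
+        #   fpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=1024)
+        #   fps = np.array([fpgen.GetFingerprint(mol) for mol in self.gen_mols])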
+        results = {
+            'validity': fraction_valid(self.gen_smiles, n_jobs=self.n_jobs),
+            'uniqueness': fraction_unique(self.gen_smiles, n_jobs=self.n_jobs),
+            'novelty_ref1': novelty(self.gen_smiles, self.ref_smiles_1, n_jobs=self.n_jobs),
+            'internal_diversity': internal_diversity_mean, 
+            'internal_diversity_std': internal_diversity_std
+        }
+        if self.ref_smiles_2:
+            results['novelty_ref2'] = novelty(self.gen_smiles, self.ref_smiles_2, n_jobs=self.n_jobs)
+        return results
+
+    def calculate_property_metrics(self):
+        """Calculate QED and SA scores"""
+        qed_scores = [QED.qed(mol) for mol in self.gen_mols]
+        # SA scores (Ertl & Schuffenhauer) range from ~1 (easy to synthesize) to 10 (hard)
+        sa_scores = [sascorer.calculateScore(mol) for mol in self.gen_mols]
+        
+        return {
+            'qed_mean': np.mean(qed_scores),
+            'qed_std': np.std(qed_scores),
+            'sa_mean': np.mean(sa_scores),
+            'sa_std': np.std(sa_scores)
+        }
+
+    def calculate_fcd(self):
+        """Calculate FCD score against both reference sets"""
+        # The mol lists were already filtered at construction; canonicalize
+        # them back to SMILES for the FCD computation
+        gen_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.gen_mols]
+        ref1_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.ref_mols_1]
+        
+        results = {
+            'fcd_ref1': self.fcd(gen_valid_smiles, ref1_valid_smiles)
+        }
+        
+        if self.ref_mols_2:
+            ref2_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.ref_mols_2]
+            results['fcd_ref2'] = self.fcd(gen_valid_smiles, ref2_valid_smiles)
+            
+        return results
+
+    def calculate_similarity_metrics(self):
+        """Calculate fragment and scaffold similarity"""
+
+        results = {
+            'frag_sim_ref1': self.frag_metric(gen=self.gen_mols, ref=self.ref_mols_1),
+            'scaf_sim_ref1': self.scaf_metric(gen=self.gen_mols, ref=self.ref_mols_1)
+        }
+        if self.ref_mols_2:
+            results.update({
+                'frag_sim_ref2': self.frag_metric(gen=self.gen_mols, ref=self.ref_mols_2),
+                'scaf_sim_ref2': self.scaf_metric(gen=self.gen_mols, ref=self.ref_mols_2)
+            })
+        return results
+
+    def calculate_drug_likeness(self):
+        """Calculate Lipinski, Veber and PAINS filtering results"""
+        lipinski_scores = [obey_lipinski(mol) for mol in self.gen_mols]
+        veber_scores = [obey_veber(mol) for mol in self.gen_mols]
+        # PAINS pass rate: fraction of molecules that do NOT match a PAINS alert
+        pains_results = [not is_pains(mol, self.pains_catalog) for mol in self.gen_mols]
+        
+        return {
+            'lipinski_mean': np.mean(lipinski_scores),
+            'lipinski_std': np.std(lipinski_scores),
+            'veber_mean': np.mean(veber_scores),
+            'veber_std': np.std(veber_scores),
+            'pains_pass_rate': np.mean(pains_results)
+        }
+
+    def evaluate_all(self):
+        """Run all evaluations and combine results"""
+        def fmt(metrics):
+            # Render float values to three decimals for progress output
+            return {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in metrics.items()}
+
+        results = {}
+
+        print("\nCalculating basic metrics...")
+        basic_metrics = self.calculate_basic_metrics()
+        print("Basic metrics:", fmt(basic_metrics))
+        results.update(basic_metrics)
+
+        print("\nCalculating property metrics...")
+        property_metrics = self.calculate_property_metrics()
+        print("Property metrics:", fmt(property_metrics))
+        results.update(property_metrics)
+
+        print("\nCalculating FCD scores...")
+        fcd_metrics = self.calculate_fcd()
+        print("FCD metrics:", fmt(fcd_metrics))
+        results.update(fcd_metrics)
+
+        print("\nCalculating similarity metrics...")
+        similarity_metrics = self.calculate_similarity_metrics()
+        print("Similarity metrics:", fmt(similarity_metrics))
+        results.update(similarity_metrics)
+
+        print("\nCalculating drug-likeness metrics...")
+        drug_likeness = self.calculate_drug_likeness()
+        print("Drug-likeness metrics:", fmt(drug_likeness))
+        results.update(drug_likeness)
+
+        return results
+
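+# Example usage of MoleculeEvaluator directly (a sketch; the SMILES below are
+# illustrative placeholders, not project data):
+#   gen = ['CCO', 'c1ccccc1O', 'CC(=O)Nc1ccc(O)cc1']
+#   ref = ['CCN', 'c1ccccc1N', 'CC(=O)O']
+#   evaluator = MoleculeEvaluator(gen, ref)
+#   print(evaluator.calculate_basic_metrics())
+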
+def read_smi_file(file_path):
+    """
+    Read SMILES strings from a .smi file
+    
+    Args:
+        file_path (str): Path to .smi file
+        
+    Returns:
+        list: List of SMILES strings
+    """
+    smiles_list = []
+    try:
+        with open(file_path, 'r') as f:
+            for line in f:
+                # Handle different .smi formats:
+                # Some .smi files have just SMILES, others have SMILES followed by an identifier
+                parts = line.strip().split()
+                if parts:  # Skip empty lines
+                    smiles = parts[0]  # First part is always the SMILES
+                    smiles_list.append(smiles)
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Could not find SMI file: {file_path}")
+    except Exception as e:
+        raise ValueError(f"Error reading SMI file {file_path}: {str(e)}")
+    
+    return smiles_list
+
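+# Example: read_smi_file handles both bare-SMILES files and files with an
+# identifier after the SMILES (a sketch; 'data/train.smi' is an illustrative path):
+#   smiles = read_smi_file('data/train.smi')
+#   print(len(smiles), smiles[0])
+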
+def evaluate_molecules_from_files(gen_path, ref_path_1, ref_path_2=None, smiles_col='SMILES', output_prefix="results", n_jobs=8):
+    """
+    Main function to evaluate generated molecules against reference sets
+    
+    Args:
+        gen_path (str): Path to CSV file containing generated SMILES
+        ref_path_1 (str): Path to .smi file containing first reference set SMILES
+        ref_path_2 (str, optional): Path to .smi file containing second reference set SMILES
+        smiles_col (str): Name of column containing SMILES strings in the CSV file
+        output_prefix (str): Prefix for output files
+        n_jobs (int): Number of parallel jobs
+    """
+    # Read generated SMILES from CSV file
+    try:
+        gen_df = pd.read_csv(gen_path)
+        if smiles_col not in gen_df.columns:
+            raise ValueError(f"SMILES column '{smiles_col}' not found in generated dataset")
+        gen_smiles = gen_df[smiles_col].dropna().tolist()
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Could not find generated CSV file: {gen_path}")
+    except pd.errors.EmptyDataError:
+        raise ValueError("Generated CSV file is empty")
+    
+    # Read reference SMILES from .smi files
+    ref_smiles_1 = read_smi_file(ref_path_1)
+    ref_smiles_2 = read_smi_file(ref_path_2) if ref_path_2 else None
+    
+    # Validate that we have SMILES to process
+    if not gen_smiles:
+        raise ValueError("No valid SMILES found in generated set")
+    if not ref_smiles_1:
+        raise ValueError("No valid SMILES found in reference set 1")
+    if ref_path_2 and not ref_smiles_2:
+        raise ValueError("No valid SMILES found in reference set 2")
+    
+    print(f"\nProcessing datasets:")
+    print(f"Generated molecules: {len(gen_smiles)}")
+    print(f"Reference set 1: {len(ref_smiles_1)}")
+    if ref_smiles_2:
+        print(f"Reference set 2: {len(ref_smiles_2)}")
+    
+    # Run evaluation
+    evaluator = MoleculeEvaluator(gen_smiles, ref_smiles_1, ref_smiles_2, n_jobs=n_jobs)
+    results = evaluator.evaluate_all()
+    
+    print("\nSaving results...")
+    # Add dataset sizes to results
+    results.update({
+        'n_generated': len(gen_smiles),
+        'n_reference_1': len(ref_smiles_1),
+        'n_reference_2': len(ref_smiles_2) if ref_smiles_2 is not None else 0
+    })
+    
+    # Save results
+    # Format float values to 3 decimal places for JSON
+    formatted_results = {k: round(v, 3) if isinstance(v, float) else v for k, v in results.items()}
+    with open(f"{output_prefix}.json", 'w') as f:
+        json.dump(formatted_results, f, indent=4)
+    
+    # Create DataFrame with formatted results
+    df = pd.DataFrame([formatted_results])
+    df.to_csv(f"{output_prefix}.csv", index=False)
+    
+    return results
+
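+# Example programmatic call (a sketch; the file paths are illustrative):
+#   results = evaluate_molecules_from_files(
+#       'generated.csv', 'train.smi', ref_path_2='test.smi',
+#       smiles_col='SMILES', output_prefix='eval_run', n_jobs=8)
+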
+if __name__ == "__main__":
+    import argparse
+    
+    parser = argparse.ArgumentParser(description='Evaluate generated molecules against reference sets')
+    parser.add_argument('--gen', required=True, help='Path to CSV file with generated SMILES')
+    parser.add_argument('--ref1', required=True, help='Path to .smi file with first reference set SMILES')
+    parser.add_argument('--ref2', help='Path to .smi file with second reference set SMILES (optional)')
+    parser.add_argument('--smiles-col', default='SMILES', help='Name of SMILES column in generated CSV file')
+    parser.add_argument('--output', default='results', help='Prefix for output files')
+    parser.add_argument('--n-jobs', type=int, default=8, help='Number of parallel jobs')
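+
+    # Example invocation (illustrative paths):
+    #   python results/evaluate.py --gen generated.csv --ref1 train.smi \
+    #       --ref2 test.smi --output eval_run --n-jobs 8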
+    
+    args = parser.parse_args()
+    
+    try:
+        results = evaluate_molecules_from_files(
+            args.gen, 
+            args.ref1, 
+            args.ref2, 
+            smiles_col=args.smiles_col,
+            output_prefix=args.output,
+            n_jobs=args.n_jobs
+        )
+        print(f"Evaluation complete. Results saved to {args.output}.json and {args.output}.csv")
+        
+    except Exception as e:
+        print(f"Error during evaluation: {str(e)}")
\ No newline at end of file