"""Evaluate generated molecules against one or two reference SMILES sets.

Computes validity, uniqueness, novelty, internal diversity, QED/SA scores,
FCD, fragment/scaffold similarity, and drug-likeness filter pass rates.
"""
import json
import os
import sys

import numpy as np
import pandas as pd
import torch
from rdkit import Chem, RDConfig
from rdkit.Chem import QED, AllChem

# sascorer ships with RDKit's contrib package (Contrib/SA_Score)
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

from fcd_torch import FCD  # pip install fcd_torch

# Add the project root to the path so src.util.utils is importable
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from src.util.utils import (fraction_valid, fraction_unique, novelty,
                            internal_diversity, obey_lipinski, obey_veber,
                            load_pains_filters, is_pains, FragMetric, ScafMetric)
class MoleculeEvaluator:
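    """Compute distribution-level metrics for a set of generated molecules.

    A minimal usage sketch (the SMILES here are illustrative):

        evaluator = MoleculeEvaluator(gen_smiles=["CCO", "c1ccccc1"],
                                      ref_smiles_1=["CCN", "CCO"])
        results = evaluator.evaluate_all()
    """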
def __init__(self, gen_smiles, ref_smiles_1, ref_smiles_2=None, n_jobs=8):
"""
Initialize evaluator with generated and reference SMILES
ref_smiles_2 is optional
"""
self.gen_smiles = gen_smiles
self.ref_smiles_1 = ref_smiles_1
self.ref_smiles_2 = ref_smiles_2
self.n_jobs = n_jobs
# Convert SMILES to RDKit molecules and filter out invalid ones
self.gen_mols = [mol for s in gen_smiles if s and (mol := Chem.MolFromSmiles(s)) is not None]
self.ref_mols_1 = [mol for s in ref_smiles_1 if s and (mol := Chem.MolFromSmiles(s)) is not None]
self.ref_mols_2 = [mol for s in ref_smiles_2 if s and (mol := Chem.MolFromSmiles(s)) is not None] if ref_smiles_2 else None
# Initialize metrics that need setup
self.fcd = FCD(device='cuda' if torch.cuda.is_available() else 'cpu')
self.pains_catalog = load_pains_filters()
self.frag_metric = FragMetric(n_jobs=1)
self.scaf_metric = ScafMetric(n_jobs=1)
def calculate_basic_metrics(self):
"""Calculate validity, uniqueness, novelty, and internal diversity"""
        # Morgan fingerprints (radius 2, 1024 bits); gen_mols contains no Nones
        fps = np.array([AllChem.GetMorganFingerprintAsBitVect(mol, 2, 1024)
                        for mol in self.gen_mols])
internal_diversity_mean, internal_diversity_std = internal_diversity(fps)
results = {
'validity': fraction_valid(self.gen_smiles, n_jobs=self.n_jobs),
'uniqueness': fraction_unique(self.gen_smiles, n_jobs=self.n_jobs),
'novelty_ref1': novelty(self.gen_smiles, self.ref_smiles_1, n_jobs=self.n_jobs),
'internal_diversity': internal_diversity_mean,
'internal_diversity_std': internal_diversity_std
}
if self.ref_smiles_2:
results['novelty_ref2'] = novelty(self.gen_smiles, self.ref_smiles_2, n_jobs=self.n_jobs)
return results
def calculate_property_metrics(self):
"""Calculate QED and SA scores"""
        qed_scores = [QED.qed(mol) for mol in self.gen_mols]
        sa_scores = [sascorer.calculateScore(mol) for mol in self.gen_mols]
return {
'qed_mean': np.mean(qed_scores),
'qed_std': np.std(qed_scores),
'sa_mean': np.mean(sa_scores),
'sa_std': np.std(sa_scores)
}
    def calculate_fcd(self):
        """Calculate the Frechet ChemNet Distance (FCD) against each reference set"""
        # gen_mols/ref_mols_1 are already filtered, so no None checks are needed here
        gen_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.gen_mols]
        ref1_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.ref_mols_1]
results = {
'fcd_ref1': self.fcd(gen_valid_smiles, ref1_valid_smiles)
}
if self.ref_mols_2:
            ref2_valid_smiles = [Chem.MolToSmiles(mol) for mol in self.ref_mols_2]
results['fcd_ref2'] = self.fcd(gen_valid_smiles, ref2_valid_smiles)
return results
def calculate_similarity_metrics(self):
"""Calculate fragment and scaffold similarity"""
results = {
'frag_sim_ref1': self.frag_metric(gen=self.gen_mols, ref=self.ref_mols_1),
'scaf_sim_ref1': self.scaf_metric(gen=self.gen_mols, ref=self.ref_mols_1)
}
if self.ref_mols_2:
results.update({
'frag_sim_ref2': self.frag_metric(gen=self.gen_mols, ref=self.ref_mols_2),
'scaf_sim_ref2': self.scaf_metric(gen=self.gen_mols, ref=self.ref_mols_2)
})
return results
def calculate_drug_likeness(self):
"""Calculate Lipinski, Veber and PAINS filtering results"""
lipinski_scores = [obey_lipinski(mol) for mol in self.gen_mols if mol is not None]
veber_scores = [obey_veber(mol) for mol in self.gen_mols if mol is not None]
pains_results = [not is_pains(mol, self.pains_catalog) for mol in self.gen_mols if mol is not None]
return {
'lipinski_mean': np.mean(lipinski_scores),
'lipinski_std': np.std(lipinski_scores),
'veber_mean': np.mean(veber_scores),
'veber_std': np.std(veber_scores),
'pains_pass_rate': np.mean(pains_results)
}
    def evaluate_all(self):
        """Run all evaluations and combine the results"""
        def _fmt(metrics):
            # Pretty-print float values to three decimals for console output
            return {k: f"{v:.3f}" if isinstance(v, float) else v for k, v in metrics.items()}
        results = {}
        print("\nCalculating basic metrics...")
        basic_metrics = self.calculate_basic_metrics()
        print("Basic metrics:", _fmt(basic_metrics))
        results.update(basic_metrics)
        print("\nCalculating property metrics...")
        property_metrics = self.calculate_property_metrics()
        print("Property metrics:", _fmt(property_metrics))
        results.update(property_metrics)
        print("\nCalculating FCD scores...")
        fcd_metrics = self.calculate_fcd()
        print("FCD metrics:", _fmt(fcd_metrics))
        results.update(fcd_metrics)
        print("\nCalculating similarity metrics...")
        similarity_metrics = self.calculate_similarity_metrics()
        print("Similarity metrics:", _fmt(similarity_metrics))
        results.update(similarity_metrics)
        print("\nCalculating drug-likeness metrics...")
        drug_likeness = self.calculate_drug_likeness()
        print("Drug-likeness metrics:", _fmt(drug_likeness))
        results.update(drug_likeness)
        return results
def read_smi_file(file_path):
"""
Read SMILES strings from a .smi file
Args:
file_path (str): Path to .smi file
Returns:
list: List of SMILES strings
"""
smiles_list = []
try:
with open(file_path, 'r') as f:
for line in f:
                # .smi lines contain a SMILES string, optionally followed
                # by whitespace and an identifier
                parts = line.strip().split()
                if parts:  # skip empty lines
                    smiles_list.append(parts[0])  # first token is the SMILES
except FileNotFoundError:
raise FileNotFoundError(f"Could not find SMI file: {file_path}")
except Exception as e:
raise ValueError(f"Error reading SMI file {file_path}: {str(e)}")
return smiles_list
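# For example, a .smi line "CCO ethanol" yields the SMILES "CCO".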
def evaluate_molecules_from_files(gen_path, ref_path_1, ref_path_2=None, smiles_col='SMILES', output_prefix="results", n_jobs=8):
"""
Main function to evaluate generated molecules against reference sets
Args:
gen_path (str): Path to CSV file containing generated SMILES
ref_path_1 (str): Path to .smi file containing first reference set SMILES
ref_path_2 (str, optional): Path to .smi file containing second reference set SMILES
smiles_col (str): Name of column containing SMILES strings in the CSV file
output_prefix (str): Prefix for output files
n_jobs (int): Number of parallel jobs
"""
# Read generated SMILES from CSV file
try:
gen_df = pd.read_csv(gen_path)
if smiles_col not in gen_df.columns:
raise ValueError(f"SMILES column '{smiles_col}' not found in generated dataset")
gen_smiles = gen_df[smiles_col].dropna().tolist()
except FileNotFoundError:
raise FileNotFoundError(f"Could not find generated CSV file: {gen_path}")
except pd.errors.EmptyDataError:
raise ValueError("Generated CSV file is empty")
# Read reference SMILES from .smi files
ref_smiles_1 = read_smi_file(ref_path_1)
ref_smiles_2 = read_smi_file(ref_path_2) if ref_path_2 else None
# Validate that we have SMILES to process
if not gen_smiles:
raise ValueError("No valid SMILES found in generated set")
if not ref_smiles_1:
raise ValueError("No valid SMILES found in reference set 1")
if ref_path_2 and not ref_smiles_2:
raise ValueError("No valid SMILES found in reference set 2")
print(f"\nProcessing datasets:")
print(f"Generated molecules: {len(gen_smiles)}")
print(f"Reference set 1: {len(ref_smiles_1)}")
if ref_smiles_2:
print(f"Reference set 2: {len(ref_smiles_2)}")
# Run evaluation
evaluator = MoleculeEvaluator(gen_smiles, ref_smiles_1, ref_smiles_2, n_jobs=n_jobs)
results = evaluator.evaluate_all()
print("\nSaving results...")
# Add dataset sizes to results
results.update({
'n_generated': len(gen_smiles),
'n_reference_1': len(ref_smiles_1),
'n_reference_2': len(ref_smiles_2) if ref_smiles_2 is not None else 0
})
# Save results
# Format float values to 3 decimal places for JSON
formatted_results = {k: round(v, 3) if isinstance(v, float) else v for k, v in results.items()}
with open(f"{output_prefix}.json", 'w') as f:
json.dump(formatted_results, f, indent=4)
# Create DataFrame with formatted results
df = pd.DataFrame([formatted_results])
df.to_csv(f"{output_prefix}.csv", index=False)
return results
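# A minimal programmatic call (file names are illustrative):
#   results = evaluate_molecules_from_files("generated.csv", "train.smi",
#                                           ref_path_2="test.smi",
#                                           output_prefix="my_results")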
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Evaluate generated molecules against reference sets')
parser.add_argument('--gen', required=True, help='Path to CSV file with generated SMILES')
parser.add_argument('--ref1', required=True, help='Path to .smi file with first reference set SMILES')
parser.add_argument('--ref2', help='Path to .smi file with second reference set SMILES (optional)')
parser.add_argument('--smiles-col', default='SMILES', help='Name of SMILES column in generated CSV file')
parser.add_argument('--output', default='results', help='Prefix for output files')
parser.add_argument('--n-jobs', type=int, default=8, help='Number of parallel jobs')
args = parser.parse_args()
try:
results = evaluate_molecules_from_files(
args.gen,
args.ref1,
args.ref2,
smiles_col=args.smiles_col,
output_prefix=args.output,
n_jobs=args.n_jobs
)
print(f"Evaluation complete. Results saved to {args.output}.json and {args.output}.csv")
    except Exception as e:
        print(f"Error during evaluation: {e}", file=sys.stderr)
        sys.exit(1)