# tests/run_on_large_dataset.py
"""Script for testing selfies against large datasets."""
3
4
import argparse
5
import pathlib
6
7
import pandas as pd
8
from rdkit import Chem
9
from tqdm import tqdm
10
11
import selfies as sf
12
13
# Command-line interface. Arguments are parsed at module level because the
# path constants below are derived from --data_path.
parser = argparse.ArgumentParser(
    description="Round-trip SMILES from a large dataset through selfies."
)
parser.add_argument(
    "--data_path",
    type=str,
    default="version.smi.gz",
    help="dataset file, given relative to the test_sets directory that "
         "sits next to this script",
)
parser.add_argument(
    "--col_name",
    type=str,
    default="isosmiles",
    help="name of the dataset column that holds the SMILES strings",
)
parser.add_argument(
    "--sep",
    type=str,
    default=r"\s+",
    help="column separator handed to pandas.read_csv",
)
parser.add_argument(
    "--start_from",
    type=int,
    default=0,
    help="skip entries whose 1-based position in the dataset is below "
         "this value",
)
args = parser.parse_args()

# Directory containing this script; datasets and logs live beneath it.
TEST_DIR = pathlib.Path(__file__).parent
# Dataset that will be read by make_reader().
TEST_SET_PATH = TEST_DIR / "test_sets" / args.data_path
# Failing SMILES are appended to a text file in this directory; create it
# up front so opening the log cannot fail on a missing directory.
ERROR_LOG_DIR = TEST_DIR / "error_logs"
ERROR_LOG_DIR.mkdir(exist_ok=True, parents=True)
24
25
26
def make_reader():
    """Open the test set as a chunked reader.

    Returns whatever ``pd.read_csv`` produces for ``TEST_SET_PATH`` when
    called with the separator from ``--sep`` and ``chunksize=10000``, so
    callers iterate over the dataset piece by piece.
    """
    read_options = {"sep": args.sep, "chunksize": 10000}
    return pd.read_csv(TEST_SET_PATH, **read_options)
28
29
30
def roundtrip_translation():
    """Round-trip every SMILES of the test set through SELFIES and log failures.

    Each entry of column ``args.col_name`` is encoded with ``sf.encoder``
    (``strict=True``) and decoded back with ``sf.decoder``. The input SMILES
    is appended to ``ERROR_LOG_DIR/<test set stem>.txt`` and echoed through
    ``tqdm.write`` when either step raises ``sf.EncoderError`` or
    ``sf.DecoderError``, or when ``is_same_mol`` reports that the decoded
    SMILES differs from the input.

    Entries are skipped without being logged when:

    * their 1-based position in the dataset is below ``args.start_from``;
    * they are not strings (e.g. a missing value in the column);
    * they contain the ``*`` wildcard;
    * RDKit cannot parse them.
    """
    sf.set_semantic_constraints("hypervalent")

    # First pass over the file only counts rows so the progress bar has a
    # meaningful total.
    n_entries = sum(len(chunk) for chunk in make_reader())

    error_log_path = ERROR_LOG_DIR / f"{TEST_SET_PATH.stem}.txt"

    # Context managers make sure the log file is closed and the progress bar
    # is released even when an unexpected exception aborts the run.
    with open(error_log_path, "a+") as error_log, \
            tqdm(total=n_entries) as pbar:

        def report_failure(smiles):
            # Persist the failing input and show it without breaking the bar.
            error_log.write(smiles + "\n")
            tqdm.write(smiles)

        curr_idx = 0
        for chunk in make_reader():
            for in_smiles in chunk[args.col_name]:
                pbar.update(1)
                curr_idx += 1
                if curr_idx < args.start_from:
                    continue

                # Non-string cells (such as missing values) have no
                # .strip(); skip them instead of crashing mid-run.
                if not isinstance(in_smiles, str):
                    continue

                in_smiles = in_smiles.strip()

                # Cheap substring test first so wildcard entries never reach
                # the RDKit parser.
                if "*" in in_smiles:
                    continue
                if Chem.MolFromSmiles(in_smiles, sanitize=True) is None:
                    continue

                try:
                    selfies = sf.encoder(in_smiles, strict=True)
                    out_smiles = sf.decoder(selfies)
                except (sf.EncoderError, sf.DecoderError):
                    report_failure(in_smiles)
                    continue

                if not is_same_mol(in_smiles, out_smiles):
                    report_failure(in_smiles)
68
69
70
def is_same_mol(smiles1, smiles2):
    """Tell whether two SMILES strings have the same canonical form.

    Both inputs are passed through ``Chem.CanonSmiles`` and the results are
    compared for equality. Any exception raised while canonicalising either
    input makes the function return ``False``.
    """
    try:
        first, second = (Chem.CanonSmiles(s) for s in (smiles1, smiles2))
    except Exception:
        return False
    return first == second
77
78
79
# Script entry point: run the round-trip check only on direct execution.
if __name__ == "__main__":
    roundtrip_translation()