Diff of /tests/test_on_datasets.py [000000] .. [1aa732]

Switch to unified view

a b/tests/test_on_datasets.py
1
import faulthandler
2
import pathlib
3
import random
4
5
import pandas as pd
6
import pytest
7
from rdkit import Chem
8
9
import selfies as sf
10
11
faulthandler.enable()
12
13
TEST_SET_DIR = pathlib.Path(__file__).parent / "test_sets"
14
ERROR_LOG_DIR = pathlib.Path(__file__).parent / "error_logs"
15
ERROR_LOG_DIR.mkdir(exist_ok=True, parents=True)
16
17
datasets = list(TEST_SET_DIR.glob("**/*.csv"))
18
19
20
@pytest.mark.parametrize("test_path", datasets)
21
def test_roundtrip_translation(test_path, dataset_samples):
22
    """Tests SMILES -> SELFIES -> SMILES translation on various datasets.
23
    """
24
25
    # very relaxed constraints
26
    constraints = sf.get_preset_constraints("hypervalent")
27
    constraints.update({"P": 7, "P-1": 8, "P+1": 6, "?": 12})
28
    sf.set_semantic_constraints(constraints)
29
30
    error_path = ERROR_LOG_DIR / "{}.csv".format(test_path.stem)
31
    with open(error_path, "w+") as error_log:
32
        error_log.write("In, Out\n")
33
34
    error_data = []
35
    error_found = False
36
37
    n_lines = sum(1 for _ in open(test_path)) - 1
38
    n_keep = dataset_samples if (0 < dataset_samples <= n_lines) else n_lines
39
    skip = random.sample(range(1, n_lines + 1), n_lines - n_keep)
40
    reader = pd.read_csv(test_path, chunksize=10000, header=0, skiprows=skip)
41
42
    for chunk in reader:
43
44
        for in_smiles in chunk["smiles"]:
45
            in_smiles = in_smiles.strip()
46
47
            mol = Chem.MolFromSmiles(in_smiles, sanitize=True)
48
            if (mol is None) or ("*" in in_smiles):
49
                continue
50
51
            try:
52
                selfies = sf.encoder(in_smiles, strict=True)
53
                out_smiles = sf.decoder(selfies)
54
            except (sf.EncoderError, sf.DecoderError):
55
                error_data.append((in_smiles, ""))
56
                continue
57
58
            if not is_same_mol(in_smiles, out_smiles):
59
                error_data.append((in_smiles, out_smiles))
60
61
        with open(error_path, "a") as error_log:
62
            for entry in error_data:
63
                error_log.write(",".join(entry) + "\n")
64
65
        error_found = error_found or error_data
66
        error_data = []
67
68
    sf.set_semantic_constraints()  # restore constraints
69
70
    assert not error_found
71
72
73
def is_same_mol(smiles1, smiles2):
74
    try:
75
        can_smiles1 = Chem.CanonSmiles(smiles1)
76
        can_smiles2 = Chem.CanonSmiles(smiles2)
77
        return can_smiles1 == can_smiles2
78
    except Exception:
79
        return False