tests/test_on_datasets.py

import faulthandler
import pathlib
import random

import pandas as pd
import pytest
from rdkit import Chem

import selfies as sf

faulthandler.enable()

TEST_SET_DIR = pathlib.Path(__file__).parent / "test_sets"
ERROR_LOG_DIR = pathlib.Path(__file__).parent / "error_logs"
ERROR_LOG_DIR.mkdir(exist_ok=True, parents=True)

datasets = list(TEST_SET_DIR.glob("**/*.csv"))
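# one parametrized test case is generated per CSV file found under test_sets/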


@pytest.mark.parametrize("test_path", datasets)
def test_roundtrip_translation(test_path, dataset_samples):
    """Tests SMILES -> SELFIES -> SMILES translation on various datasets."""

    # very relaxed constraints
    constraints = sf.get_preset_constraints("hypervalent")
    constraints.update({"P": 7, "P-1": 8, "P+1": 6, "?": 12})
    sf.set_semantic_constraints(constraints)

    error_path = ERROR_LOG_DIR / "{}.csv".format(test_path.stem)
    with open(error_path, "w+") as error_log:
        error_log.write("In, Out\n")

    error_data = []
    error_found = False
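
    # subsample the dataset: randomly pick rows to skip so that at most
    # `dataset_samples` SMILES are read; a non-positive or oversized value
    # keeps the whole file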
    n_lines = sum(1 for _ in open(test_path)) - 1
    n_keep = dataset_samples if (0 < dataset_samples <= n_lines) else n_lines
    skip = random.sample(range(1, n_lines + 1), n_lines - n_keep)
    reader = pd.read_csv(test_path, chunksize=10000, header=0, skiprows=skip)

    for chunk in reader:

        for in_smiles in chunk["smiles"]:
            in_smiles = in_smiles.strip()

            mol = Chem.MolFromSmiles(in_smiles, sanitize=True)
            if (mol is None) or ("*" in in_smiles):
                continue

            try:
                selfies = sf.encoder(in_smiles, strict=True)
                out_smiles = sf.decoder(selfies)
            except (sf.EncoderError, sf.DecoderError):
                error_data.append((in_smiles, ""))
                continue

            if not is_same_mol(in_smiles, out_smiles):
                error_data.append((in_smiles, out_smiles))
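
        # flush this chunk's failures to the on-disk error log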
        with open(error_path, "a") as error_log:
            for entry in error_data:
                error_log.write(",".join(entry) + "\n")
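
        # remember whether any chunk produced errors, then clear the buffer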
        error_found = error_found or error_data
        error_data = []

    sf.set_semantic_constraints()  # restore constraints

    assert not error_found
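

# two SMILES are considered equivalent iff their RDKit canonical forms match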
def is_same_mol(smiles1, smiles2):
    try:
        can_smiles1 = Chem.CanonSmiles(smiles1)
        can_smiles2 = Chem.CanonSmiles(smiles2)
        return can_smiles1 == can_smiles2
    except Exception:
        return False
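

Note: the `dataset_samples` argument of `test_roundtrip_translation` is a pytest fixture that is not defined in this file; it is expected to come from the test suite's `conftest.py`. A minimal sketch of such a fixture, assuming it is backed by a `--dataset_samples` command-line option (the option name and default below are illustrative, not necessarily the project's actual code):

# tests/conftest.py (illustrative sketch, not the project's actual conftest)
import pytest


def pytest_addoption(parser):
    # number of molecules to sample from each dataset CSV;
    # a non-positive value means "use every row"
    parser.addoption("--dataset_samples", type=int, default=10000)


@pytest.fixture
def dataset_samples(request):
    return int(request.config.getoption("--dataset_samples"))

With a fixture like this, running `pytest --dataset_samples=1000` limits each CSV to a random sample of 1000 SMILES, while a non-positive value exercises the full file.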