|
a |
|
b/tests/test_selfies.py |
|
|
1 |
import faulthandler |
|
|
2 |
import random |
|
|
3 |
|
|
|
4 |
import pytest |
|
|
5 |
from rdkit.Chem import MolFromSmiles |
|
|
6 |
|
|
|
7 |
import selfies as sf |
|
|
8 |
|
|
|
9 |
# Dump Python-level tracebacks if the interpreter crashes or deadlocks —
# useful here because RDKit is a native extension that can segfault.
# NOTE(review): enabled at import time, so it applies to the whole test run.
faulthandler.enable()
|
|
10 |
|
|
|
11 |
|
|
|
12 |
@pytest.fixture()
def max_selfies_len():
    """Upper bound on the symbol length of randomly generated SELFIES."""
    return 1000
|
|
15 |
|
|
|
16 |
|
|
|
17 |
@pytest.fixture()
def large_alphabet():
    """The semantically robust alphabet extended with extra decorated
    symbols (bond prefixes, stereo variants, rings/branches, '[nop]'),
    returned as a list.
    """
    extra_symbols = {
        "[#Br]", "[#Branch1]", "[#Branch2]", "[#Branch3]", "[#C@@H1]",
        "[#C@@]", "[#C@H1]", "[#C@]", "[#C]", "[#Cl]", "[#F]", "[#H]", "[#I]",
        "[#NH1]", "[#N]", "[#O]", "[#P]", "[#Ring1]", "[#Ring2]", "[#Ring3]",
        "[#S]", "[/Br]", "[/C@@H1]", "[/C@@]", "[/C@H1]", "[/C@]", "[/C]",
        "[/Cl]", "[/F]", "[/H]", "[/I]", "[/NH1]", "[/N]", "[/O]", "[/P]",
        "[/S]", "[=Br]", "[=Branch1]", "[=Branch2]", "[=Branch3]", "[=C@@H1]",
        "[=C@@]", "[=C@H1]", "[=C@]", "[=C]", "[=Cl]", "[=F]", "[=H]", "[=I]",
        "[=NH1]", "[=N]", "[=O]", "[=P]", "[=Ring1]", "[=Ring2]", "[=Ring3]",
        "[=S]", "[Br]", "[Branch1]", "[Branch2]", "[Branch3]", "[C@@H1]",
        "[C@@]", "[C@H1]", "[C@]", "[C]", "[Cl]", "[F]", "[H]", "[I]", "[NH1]",
        "[N]", "[O]", "[P]", "[Ring1]", "[Ring2]", "[Ring3]", "[S]", "[\\Br]",
        "[\\C@@H1]", "[\\C@@]", "[\\C@H1]", "[\\C@]", "[\\C]", "[\\Cl]",
        "[\\F]", "[\\H]", "[\\I]", "[\\NH1]", "[\\N]", "[\\O]", "[\\P]",
        "[\\S]", "[nop]",
    }
    symbols = sf.get_semantic_robust_alphabet()
    symbols.update(extra_symbols)
    return list(symbols)
|
|
37 |
|
|
|
38 |
|
|
|
39 |
def test_random_selfies_decoder(trials, max_selfies_len, large_alphabet):
    """Tests that SELFIES that are generated by randomly stringing together
    symbols from the SELFIES alphabet are decoded into valid SMILES.
    """
    symbols = tuple(large_alphabet)

    for _ in range(trials):

        # build a random SELFIES string and decode it
        length = random.randint(1, max_selfies_len)
        selfies = "".join(random_choices(symbols, k=length))
        smiles = sf.decoder(selfies)

        # RDKit parse (with sanitization) decides validity
        try:
            is_valid = MolFromSmiles(smiles, sanitize=True) is not None
        except Exception:
            is_valid = False

        err_msg = "SMILES: {}\n\t SELFIES: {}".format(smiles, selfies)
        assert is_valid, err_msg
|
|
61 |
|
|
|
62 |
|
|
|
63 |
def test_nop_symbol_decoder(max_selfies_len, large_alphabet):
    """Tests that the '[nop]' symbol is always skipped over.
    """
    # large_alphabet comes from a set, so "[nop]" occurs exactly once
    symbols = [s for s in large_alphabet if s != "[nop]"]

    for _ in range(100):

        # pad a random SELFIES up to max length with '[nop]' and shuffle
        n_real = random.randint(1, max_selfies_len)
        selfies_symbols = random_choices(symbols, k=n_real)
        selfies_symbols += ["[nop]"] * (max_selfies_len - n_real)
        random.shuffle(selfies_symbols)

        padded = "".join(selfies_symbols)
        stripped = padded.replace("[nop]", "")

        assert sf.decoder(padded) == sf.decoder(stripped)
|
|
82 |
|
|
|
83 |
|
|
|
84 |
def test_get_semantic_constraints():
    """Checks that the constraints getter returns a fresh dict (not a
    shared alias) and that it contains the wildcard '?' key."""
    first = sf.get_semantic_constraints()
    second = sf.get_semantic_constraints()
    assert first is not second  # each call must hand back a new object
    assert "?" in first
|
|
88 |
|
|
|
89 |
|
|
|
90 |
def test_change_constraints_cache_clear():
    """Tests that setting new semantic constraints invalidates the cached
    alphabet and changes decoder output (carbon capped to 1 bond).

    Bug fix: the default constraints are now restored in a ``finally``
    block — previously, any assertion failure left the modified
    constraints in place, polluting every subsequently run test.
    """
    try:
        alphabet = sf.get_semantic_robust_alphabet()
        assert alphabet == sf.get_semantic_robust_alphabet()
        assert sf.decoder("[C][#C]") == "C#C"

        # cap carbon at one bond; the triple bond must now be dropped
        new_constraints = sf.get_semantic_constraints()
        new_constraints["C"] = 1
        sf.set_semantic_constraints(new_constraints)

        new_alphabet = sf.get_semantic_robust_alphabet()
        assert new_alphabet != alphabet
        assert sf.decoder("[C][#C]") == "CC"
    finally:
        sf.set_semantic_constraints()  # re-set alphabet even on failure
|
|
104 |
|
|
|
105 |
|
|
|
106 |
def test_invalid_or_unsupported_smiles_encoder():
    """Checks that sf.encoder raises EncoderError for malformed or
    unsupported SMILES strings."""
    bad_inputs = (
        "",
        "(",
        "C(Cl)(Cl)CC[13C",
        "C(CCCOC",
        "C=(CCOC",
        "CCCC)",
        "C1CCCCC",
        "C(F)(F)(F)(F)(F)F",  # violates bond constraints
        "C=C1=CCCCCC1",  # violates bond constraints
        "CC*CC",  # uses wildcard
        "C$C",  # uses $ bond
        "S[As@TB1](F)(Cl)(Br)N",  # unrecognized chirality
        "SOMETHINGWRONGHERE",
        "1243124124",
    )

    for bad in bad_inputs:
        with pytest.raises(sf.EncoderError):
            sf.encoder(bad)
|
|
127 |
|
|
|
128 |
|
|
|
129 |
def test_malformed_selfies_decoder():
    """Checks that a SELFIES with an unterminated symbol raises
    DecoderError."""
    malformed = "[O][=C][O][C][C][C][C][O][N][Branch2_3"
    with pytest.raises(sf.DecoderError):
        sf.decoder(malformed)
|
|
132 |
|
|
|
133 |
|
|
|
134 |
def random_choices(population, k):
    """Return ``k`` elements sampled uniformly from ``population`` with
    replacement.

    This used to be a hand-rolled loop because ``random.choices`` was new
    in Python 3.6 — but this module already uses f-strings, which also
    require 3.6+, so we simply delegate to the standard library. Kept as
    a named helper for the existing call sites in this file.
    """
    return random.choices(population, k=k)
|
|
136 |
|
|
|
137 |
|
|
|
138 |
def test_decoder_attribution():
    """Checks that the decoder attribution map links the output 'P' atom
    back to the '[P]' input symbol."""
    sm, am = sf.decoder(
        "[C][N][C][Branch1][C][P][C][C][Ring1][=Branch1]", attribute=True)

    for token_attr in am:
        if token_attr.token != 'P':
            continue
        # found the P output token; succeed if any source symbol is [P]
        if any(a.token == '[P]' for a in token_attr.attribution):
            return

    raise ValueError('Failed to find P in attribution map')
|
|
148 |
|
|
|
149 |
|
|
|
150 |
def test_encoder_attribution():
    """Checks the encoder attribution map for a ring-plus-branch molecule.

    Each attributed SELFIES symbol must point at the expected SMILES input
    index, and the chlorine symbol must trace back to 'Cl'.
    """
    smiles = "C1([O-])C=CC=C1Cl"
    indices = [0, 3, 3, 3, 5, 7, 8, 10, None, None, 12]
    s, am = sf.encoder(smiles, attribute=True)
    for i, ta in enumerate(am):
        if ta.attribution:
            # Bug fix: the failure message previously interpolated ta[1]
            # (the entire attribution list) rather than the index that is
            # actually being compared.
            assert indices[i] == ta.attribution[0].index, \
                f'found {ta.attribution[0].index}; should be {indices[i]}'
        if ta.token == '[Cl]':
            assert 'Cl' in [
                a.token for a in ta.attribution],\
                'Failed to find Cl in attribution map'