Diff of /tests/test_selfies.py [000000] .. [1aa732]

Switch to unified view

a b/tests/test_selfies.py
1
import faulthandler
2
import random
3
4
import pytest
5
from rdkit.Chem import MolFromSmiles
6
7
import selfies as sf
8
9
faulthandler.enable()
10
11
12
@pytest.fixture()
13
def max_selfies_len():
14
    return 1000
15
16
17
@pytest.fixture()
18
def large_alphabet():
19
    alphabet = sf.get_semantic_robust_alphabet()
20
    alphabet.update([
21
        "[#Br]", "[#Branch1]", "[#Branch2]", "[#Branch3]", "[#C@@H1]",
22
        "[#C@@]", "[#C@H1]", "[#C@]", "[#C]", "[#Cl]", "[#F]", "[#H]", "[#I]",
23
        "[#NH1]", "[#N]", "[#O]", "[#P]", "[#Ring1]", "[#Ring2]", "[#Ring3]",
24
        "[#S]", "[/Br]", "[/C@@H1]", "[/C@@]", "[/C@H1]", "[/C@]", "[/C]",
25
        "[/Cl]", "[/F]", "[/H]", "[/I]", "[/NH1]", "[/N]", "[/O]", "[/P]",
26
        "[/S]", "[=Br]", "[=Branch1]", "[=Branch2]", "[=Branch3]", "[=C@@H1]",
27
        "[=C@@]", "[=C@H1]", "[=C@]", "[=C]", "[=Cl]", "[=F]", "[=H]", "[=I]",
28
        "[=NH1]", "[=N]", "[=O]", "[=P]", "[=Ring1]", "[=Ring2]", "[=Ring3]",
29
        "[=S]", "[Br]", "[Branch1]", "[Branch2]", "[Branch3]", "[C@@H1]",
30
        "[C@@]", "[C@H1]", "[C@]", "[C]", "[Cl]", "[F]", "[H]", "[I]", "[NH1]",
31
        "[N]", "[O]", "[P]", "[Ring1]", "[Ring2]", "[Ring3]", "[S]", "[\\Br]",
32
        "[\\C@@H1]", "[\\C@@]", "[\\C@H1]", "[\\C@]", "[\\C]", "[\\Cl]",
33
        "[\\F]", "[\\H]", "[\\I]", "[\\NH1]", "[\\N]", "[\\O]", "[\\P]",
34
        "[\\S]", "[nop]"
35
    ])
36
    return list(alphabet)
37
38
39
def test_random_selfies_decoder(trials, max_selfies_len, large_alphabet):
40
    """Tests that SELFIES that are generated by randomly stringing together
41
    symbols from the SELFIES alphabet are decoded into valid SMILES.
42
    """
43
44
    alphabet = tuple(large_alphabet)
45
46
    for _ in range(trials):
47
48
        # create random SELFIES and decode
49
        rand_len = random.randint(1, max_selfies_len)
50
        rand_selfies = "".join(random_choices(alphabet, k=rand_len))
51
        smiles = sf.decoder(rand_selfies)
52
53
        # check if SMILES is valid
54
        try:
55
            is_valid = MolFromSmiles(smiles, sanitize=True) is not None
56
        except Exception:
57
            is_valid = False
58
59
        err_msg = "SMILES: {}\n\t SELFIES: {}".format(smiles, rand_selfies)
60
        assert is_valid, err_msg
61
62
63
def test_nop_symbol_decoder(max_selfies_len, large_alphabet):
64
    """Tests that the '[nop]' symbol is always skipped over.
65
    """
66
67
    alphabet = list(large_alphabet)
68
    alphabet.remove("[nop]")
69
70
    for _ in range(100):
71
72
        # create random SELFIES with and without [nop]
73
        rand_len = random.randint(1, max_selfies_len)
74
        rand_mol = random_choices(alphabet, k=rand_len)
75
        rand_mol.extend(["[nop]"] * (max_selfies_len - rand_len))
76
        random.shuffle(rand_mol)
77
78
        with_nops = "".join(rand_mol)
79
        without_nops = with_nops.replace("[nop]", "")
80
81
        assert sf.decoder(with_nops) == sf.decoder(without_nops)
82
83
84
def test_get_semantic_constraints():
85
    constraints = sf.get_semantic_constraints()
86
    assert constraints is not sf.get_semantic_constraints()  # not alias
87
    assert "?" in constraints
88
89
90
def test_change_constraints_cache_clear():
91
    alphabet = sf.get_semantic_robust_alphabet()
92
    assert alphabet == sf.get_semantic_robust_alphabet()
93
    assert sf.decoder("[C][#C]") == "C#C"
94
95
    new_constraints = sf.get_semantic_constraints()
96
    new_constraints["C"] = 1
97
    sf.set_semantic_constraints(new_constraints)
98
99
    new_alphabet = sf.get_semantic_robust_alphabet()
100
    assert new_alphabet != alphabet
101
    assert sf.decoder("[C][#C]") == "CC"
102
103
    sf.set_semantic_constraints()  # re-set alphabet
104
105
106
def test_invalid_or_unsupported_smiles_encoder():
107
    malformed_smiles = [
108
        "",
109
        "(",
110
        "C(Cl)(Cl)CC[13C",
111
        "C(CCCOC",
112
        "C=(CCOC",
113
        "CCCC)",
114
        "C1CCCCC",
115
        "C(F)(F)(F)(F)(F)F",  # violates bond constraints
116
        "C=C1=CCCCCC1",  # violates bond constraints
117
        "CC*CC",  # uses wildcard
118
        "C$C",  # uses $ bond
119
        "S[As@TB1](F)(Cl)(Br)N",  # unrecognized chirality,
120
        "SOMETHINGWRONGHERE",
121
        "1243124124",
122
    ]
123
124
    for smiles in malformed_smiles:
125
        with pytest.raises(sf.EncoderError):
126
            sf.encoder(smiles)
127
128
129
def test_malformed_selfies_decoder():
130
    with pytest.raises(sf.DecoderError):
131
        sf.decoder("[O][=C][O][C][C][C][C][O][N][Branch2_3")
132
133
134
def random_choices(population, k):  # random.choices was new in Python v3.6
135
    return [random.choice(population) for _ in range(k)]
136
137
138
def test_decoder_attribution():
139
    sm, am = sf.decoder(
140
        "[C][N][C][Branch1][C][P][C][C][Ring1][=Branch1]", attribute=True)
141
    # check that P lined up
142
    for ta in am:
143
        if ta.token == 'P':
144
            for a in ta.attribution:
145
                if a.token == '[P]':
146
                    return
147
    raise ValueError('Failed to find P in attribution map')
148
149
150
def test_encoder_attribution():
151
    smiles = "C1([O-])C=CC=C1Cl"
152
    indices = [0, 3, 3, 3, 5, 7, 8, 10, None, None, 12]
153
    s, am = sf.encoder(smiles, attribute=True)
154
    for i, ta in enumerate(am):
155
        if ta.attribution:
156
            assert indices[i] == ta.attribution[0].index, \
157
                f'found {ta[1]}; should be {indices[i]}'
158
        if ta.token == '[Cl]':
159
            assert 'Cl' in [
160
                a.token for a in ta.attribution],\
161
                'Failed to find Cl in attribution map'