a b/tests/test_selfies_utils.py
1
import pytest
2
3
import selfies as sf
4
5
6
class Entry:
    """One hand-written test sample: a SELFIES string and its expected forms.

    Attributes:
        selfies: the SELFIES string under test.
        symbols: the symbols the string is expected to split into.
        label: the expected label (integer) encoding of the string.
        one_hot: the expected one-hot encoding, one row per position.
    """

    def __init__(self, selfies, symbols, label, one_hot):
        self.selfies = selfies
        self.symbols = symbols
        self.label = label
        self.one_hot = one_hot

    def __repr__(self):
        # Without this, a failing test shows "<Entry object at 0x...>",
        # which does not say which sample broke.
        return (
            f"{type(self).__name__}("
            f"selfies={self.selfies!r}, "
            f"symbols={self.symbols!r}, "
            f"label={self.label!r}, "
            f"one_hot={self.one_hot!r})"
        )
13
14
15
@pytest.fixture()
def dataset():
    """Provide the hand-written samples and the vocabulary they are encoded with.

    Returns:
        A pair ``(entries, (stoi, itos, pad_to_len))`` where ``entries`` is a
        list of ``Entry`` objects, ``stoi`` maps symbol -> index, ``itos`` is
        its inverse, and ``pad_to_len`` is the length the expected encodings
        are padded to.
    """
    pad_to_len = 4
    stoi = {"[nop]": 0, "[O]": 1, ".": 2, "[C]": 3, "[F]": 4}
    itos = {index: symbol for symbol, index in stoi.items()}

    # One record per sample: (selfies, symbols, label, one_hot).
    # Strings shorter than pad_to_len are filled up with index 0 ("[nop]").
    records = [
        (
            "",
            [],
            [0, 0, 0, 0],
            [
                [1, 0, 0, 0, 0],
                [1, 0, 0, 0, 0],
                [1, 0, 0, 0, 0],
                [1, 0, 0, 0, 0],
            ],
        ),
        (
            "[C][C][C]",
            ["[C]", "[C]", "[C]"],
            [3, 3, 3, 0],
            [
                [0, 0, 0, 1, 0],
                [0, 0, 0, 1, 0],
                [0, 0, 0, 1, 0],
                [1, 0, 0, 0, 0],
            ],
        ),
        (
            "[C].[C]",
            ["[C]", ".", "[C]"],
            [3, 2, 3, 0],
            [
                [0, 0, 0, 1, 0],
                [0, 0, 1, 0, 0],
                [0, 0, 0, 1, 0],
                [1, 0, 0, 0, 0],
            ],
        ),
        (
            "[C][O][C][F]",
            ["[C]", "[O]", "[C]", "[F]"],
            [3, 1, 3, 4],
            [
                [0, 0, 0, 1, 0],
                [0, 1, 0, 0, 0],
                [0, 0, 0, 1, 0],
                [0, 0, 0, 0, 1],
            ],
        ),
        (
            "[C][O][C]",
            ["[C]", "[O]", "[C]"],
            [3, 1, 3, 0],
            [
                [0, 0, 0, 1, 0],
                [0, 1, 0, 0, 0],
                [0, 0, 0, 1, 0],
                [1, 0, 0, 0, 0],
            ],
        ),
    ]

    entries = [
        Entry(selfies=text, symbols=parts, label=indices, one_hot=matrix)
        for text, parts, indices, matrix in records
    ]

    return entries, (stoi, itos, pad_to_len)
60
61
62
@pytest.fixture()
def dataset_flat_hots(dataset):
    """Provide each sample's expected one-hot matrix flattened row by row.

    Returns:
        A list with one flat list per entry of ``dataset[0]``, holding the
        rows of ``entry.one_hot`` concatenated in order.
    """
    return [
        [bit for row in sample.one_hot for bit in row]
        for sample in dataset[0]
    ]
69
70
71
def test_len_selfies(dataset):
    """``sf.len_selfies`` must equal the number of expected symbols per sample."""
    samples = dataset[0]
    for sample in samples:
        measured = sf.len_selfies(sample.selfies)
        assert measured == len(sample.symbols)
74
75
76
def test_split_selfies(dataset):
    """``sf.split_selfies`` must yield exactly the expected symbols, in order."""
    samples = dataset[0]
    for sample in samples:
        # Materialize the result so it can be compared against a list.
        pieces = list(sf.split_selfies(sample.selfies))
        assert pieces == sample.symbols
79
80
81
def test_get_alphabet_from_selfies(dataset):
    """The alphabet built from all samples must match the vocabulary keys."""
    samples, (vocab_stoi, _, _) = dataset

    alphabet = sf.get_alphabet_from_selfies(
        [sample.selfies for sample in samples]
    )
    # "[nop]" and "." are in the vocabulary but are added by hand here;
    # presumably the library does not return them -- confirm in selfies docs.
    alphabet.update(("[nop]", "."))

    assert alphabet == set(vocab_stoi)
90
91
92
def test_selfies_to_encoding(dataset):
    """Encode each sample both ways, then decode each encoding back again."""
    samples, (vocab_stoi, vocab_itos, pad_to_len) = dataset

    for sample in samples:
        label, one_hot = sf.selfies_to_encoding(
            sample.selfies, vocab_stoi, pad_to_len, "both"
        )

        assert label == sample.label
        assert one_hot == sample.one_hot

        # Round trip: each encoding must decode to the original string once
        # the "[nop]" padding symbols are removed.
        for encoding, enc_type in ((label, "label"), (one_hot, "one_hot")):
            decoded = sf.encoding_to_selfies(encoding, vocab_itos, enc_type)
            assert decoded.replace("[nop]", "") == sample.selfies
111
112
113
def test_selfies_to_flat_hot(dataset, dataset_flat_hots):
    """Batch-encode all samples to flat one-hot vectors and decode them back."""
    samples, (vocab_stoi, vocab_itos, pad_to_len) = dataset
    originals = [sample.selfies for sample in samples]

    encoded = sf.batch_selfies_to_flat_hot(originals, vocab_stoi, pad_to_len)
    assert encoded == dataset_flat_hots

    # Round trip: decoding must give back the inputs once the "[nop]"
    # padding symbols are removed.
    decoded = sf.batch_flat_hot_to_selfies(encoded, vocab_itos)
    stripped = [text.replace("[nop]", "") for text in decoded]
    assert originals == stripped