[1aa732]: / tests / test_selfies_utils.py

Download this file

124 lines (94 with data), 3.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import pytest
import selfies as sf
class Entry:
    """A single test case: a SELFIES string and its expected tokenisation/encodings."""

    def __init__(self, selfies, symbols, label, one_hot):
        self.selfies = selfies  # SELFIES string fed to the functions under test
        self.symbols = symbols  # expected result of splitting ``selfies`` into symbols
        self.label = label  # expected padded label (vocabulary index) encoding
        self.one_hot = one_hot  # expected padded one-hot encoding, one row per position

    def __repr__(self):
        # A readable repr lets pytest's assertion-failure report show which
        # test case failed, instead of an opaque "<Entry object at 0x...>".
        return "{}(selfies={!r}, symbols={!r}, label={!r}, one_hot={!r})".format(
            type(self).__name__,
            self.selfies,
            self.symbols,
            self.label,
            self.one_hot,
        )
@pytest.fixture()
def dataset():
    """Provide sample SELFIES strings together with their expected encodings.

    Returns:
        A pair ``(entries, (stoi, itos, pad_to_len))``. ``entries`` is a list
        of ``Entry`` test cases. ``stoi`` maps symbol -> index and ``itos``
        maps index -> symbol. ``pad_to_len`` is the padded length that the
        expected ``label``/``one_hot`` values were written for.
    """
    stoi = {"[nop]": 0, "[O]": 1, ".": 2, "[C]": 3, "[F]": 4}
    itos = dict(zip(stoi.values(), stoi.keys()))
    pad_to_len = 4

    # Each row is (selfies, symbols, label, one_hot). Positions past the end
    # of a string are padded with "[nop]" (index 0).
    rows = [
        ("",
         [],
         [0, 0, 0, 0],
         [[1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0],
          [1, 0, 0, 0, 0]]),
        ("[C][C][C]",
         ["[C]", "[C]", "[C]"],
         [3, 3, 3, 0],
         [[0, 0, 0, 1, 0],
          [0, 0, 0, 1, 0],
          [0, 0, 0, 1, 0],
          [1, 0, 0, 0, 0]]),
        ("[C].[C]",
         ["[C]", ".", "[C]"],
         [3, 2, 3, 0],
         [[0, 0, 0, 1, 0],
          [0, 0, 1, 0, 0],
          [0, 0, 0, 1, 0],
          [1, 0, 0, 0, 0]]),
        ("[C][O][C][F]",
         ["[C]", "[O]", "[C]", "[F]"],
         [3, 1, 3, 4],
         [[0, 0, 0, 1, 0],
          [0, 1, 0, 0, 0],
          [0, 0, 0, 1, 0],
          [0, 0, 0, 0, 1]]),
        ("[C][O][C]",
         ["[C]", "[O]", "[C]"],
         [3, 1, 3, 0],
         [[0, 0, 0, 1, 0],
          [0, 1, 0, 0, 0],
          [0, 0, 0, 1, 0],
          [1, 0, 0, 0, 0]]),
    ]
    entries = [Entry(*row) for row in rows]
    return entries, (stoi, itos, pad_to_len)
@pytest.fixture()
def dataset_flat_hots(dataset):
    """Flatten each entry's one-hot matrix into a single row-major list."""
    entries = dataset[0]
    return [
        [bit for row in entry.one_hot for bit in row]
        for entry in entries
    ]
def test_len_selfies(dataset):
    """``len_selfies`` counts the same number of symbols as the expected split."""
    entries, _vocab = dataset
    for case in entries:
        expected_len = len(case.symbols)
        assert sf.len_selfies(case.selfies) == expected_len
def test_split_selfies(dataset):
    """``split_selfies`` yields exactly the expected symbols, in order."""
    entries, _vocab = dataset
    for case in entries:
        produced = list(sf.split_selfies(case.selfies))
        assert produced == case.symbols
def test_get_alphabet_from_selfies(dataset):
    """The alphabet built from the samples matches the fixture vocabulary.

    "[nop]" (padding) and "." are added explicitly before comparing. Neither
    is expected to be reported by ``get_alphabet_from_selfies`` for these
    inputs ("[nop]" never appears in them).
    """
    entries, vocab = dataset
    vocab_stoi = vocab[0]
    alphabet = sf.get_alphabet_from_selfies([e.selfies for e in entries])
    alphabet.update({"[nop]", "."})
    assert alphabet == set(vocab_stoi)
def test_selfies_to_encoding(dataset):
    """Encode each sample to label/one-hot form, check it, then round-trip it."""
    entries, vocab = dataset
    vocab_stoi, vocab_itos, pad_to_len = vocab
    for case in entries:
        label, one_hot = sf.selfies_to_encoding(
            case.selfies, vocab_stoi, pad_to_len, "both"
        )
        assert label == case.label
        assert one_hot == case.one_hot

        # Decoding either representation and stripping the "[nop]" padding
        # must recover the original SELFIES string.
        for encoding, enc_type in ((label, "label"), (one_hot, "one_hot")):
            decoded = sf.encoding_to_selfies(encoding, vocab_itos, enc_type)
            assert decoded.replace("[nop]", "") == case.selfies
def test_selfies_to_flat_hot(dataset, dataset_flat_hots):
    """Batch-encode to flat one-hot vectors, compare, then decode back."""
    entries, vocab = dataset
    vocab_stoi, vocab_itos, pad_to_len = vocab
    batch = [case.selfies for case in entries]

    encoded = sf.batch_selfies_to_flat_hot(batch, vocab_stoi, pad_to_len)
    assert encoded == dataset_flat_hots

    # Decode, then drop the "[nop]" padding to get back the input batch.
    decoded = sf.batch_flat_hot_to_selfies(encoded, vocab_itos)
    unpadded = [text.replace("[nop]", "") for text in decoded]
    assert batch == unpadded