--- a +++ b/tests/test_selfies_utils.py @@ -0,0 +1,123 @@ +import pytest + +import selfies as sf + + +class Entry: + + def __init__(self, selfies, symbols, label, one_hot): + self.selfies = selfies + self.symbols = symbols + self.label = label + self.one_hot = one_hot + + +@pytest.fixture() +def dataset(): + stoi = {"[nop]": 0, "[O]": 1, ".": 2, "[C]": 3, "[F]": 4} + itos = {i: c for c, i in stoi.items()} + pad_to_len = 4 + + entries = [ + Entry(selfies="", + symbols=[], + label=[0, 0, 0, 0], + one_hot=[[1, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + [1, 0, 0, 0, 0]]), + Entry(selfies="[C][C][C]", + symbols=["[C]", "[C]", "[C]"], + label=[3, 3, 3, 0], + one_hot=[[0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 1, 0], + [1, 0, 0, 0, 0]]), + Entry(selfies="[C].[C]", + symbols=["[C]", ".", "[C]"], + label=[3, 2, 3, 0], + one_hot=[[0, 0, 0, 1, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [1, 0, 0, 0, 0]]), + Entry(selfies="[C][O][C][F]", + symbols=["[C]", "[O]", "[C]", "[F]"], + label=[3, 1, 3, 4], + one_hot=[[0, 0, 0, 1, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1]]), + Entry(selfies="[C][O][C]", + symbols=["[C]", "[O]", "[C]"], + label=[3, 1, 3, 0], + one_hot=[[0, 0, 0, 1, 0], + [0, 1, 0, 0, 0], + [0, 0, 0, 1, 0], + [1, 0, 0, 0, 0]]) + ] + + return entries, (stoi, itos, pad_to_len) + + +@pytest.fixture() +def dataset_flat_hots(dataset): + flat_hots = [] + for entry in dataset[0]: + hot = [elm for vec in entry.one_hot for elm in vec] + flat_hots.append(hot) + return flat_hots + + +def test_len_selfies(dataset): + for entry in dataset[0]: + assert sf.len_selfies(entry.selfies) == len(entry.symbols) + + +def test_split_selfies(dataset): + for entry in dataset[0]: + assert list(sf.split_selfies(entry.selfies)) == entry.symbols + + +def test_get_alphabet_from_selfies(dataset): + entries, (vocab_stoi, _, _) = dataset + + selfies = [entry.selfies for entry in entries] + alphabet = sf.get_alphabet_from_selfies(selfies) + alphabet.add("[nop]") + alphabet.add(".") + + assert alphabet == set(vocab_stoi.keys()) + + +def test_selfies_to_encoding(dataset): + entries, (vocab_stoi, vocab_itos, pad_to_len) = dataset + + for entry in entries: + label, one_hot = sf.selfies_to_encoding( + entry.selfies, vocab_stoi, pad_to_len, "both" + ) + + assert label == entry.label + assert one_hot == entry.one_hot + + # recover original selfies + selfies = sf.encoding_to_selfies(label, vocab_itos, "label") + selfies = selfies.replace("[nop]", "") + assert selfies == entry.selfies + + selfies = sf.encoding_to_selfies(one_hot, vocab_itos, "one_hot") + selfies = selfies.replace("[nop]", "") + assert selfies == entry.selfies + + +def test_selfies_to_flat_hot(dataset, dataset_flat_hots): + entries, (vocab_stoi, vocab_itos, pad_to_len) = dataset + + batch = [entry.selfies for entry in entries] + flat_hots = sf.batch_selfies_to_flat_hot(batch, vocab_stoi, pad_to_len) + + assert flat_hots == dataset_flat_hots + + # recover original selfies + recovered = sf.batch_flat_hot_to_selfies(flat_hots, vocab_itos) + assert batch == [s.replace("[nop]", "") for s in recovered]