--- a
+++ b/tests/test_selfies.py
@@ -0,0 +1,161 @@
+import faulthandler
+import random
+
+import pytest
+from rdkit.Chem import MolFromSmiles
+
+import selfies as sf
+
+faulthandler.enable()  # dump tracebacks on fatal errors (e.g. in RDKit)
+
+
+@pytest.fixture()
+def max_selfies_len():
+    return 1000
+
+
+@pytest.fixture()
+def large_alphabet():
+    alphabet = sf.get_semantic_robust_alphabet()
+    alphabet.update([
+        "[#Br]", "[#Branch1]", "[#Branch2]", "[#Branch3]", "[#C@@H1]",
+        "[#C@@]", "[#C@H1]", "[#C@]", "[#C]", "[#Cl]", "[#F]", "[#H]", "[#I]",
+        "[#NH1]", "[#N]", "[#O]", "[#P]", "[#Ring1]", "[#Ring2]", "[#Ring3]",
+        "[#S]", "[/Br]", "[/C@@H1]", "[/C@@]", "[/C@H1]", "[/C@]", "[/C]",
+        "[/Cl]", "[/F]", "[/H]", "[/I]", "[/NH1]", "[/N]", "[/O]", "[/P]",
+        "[/S]", "[=Br]", "[=Branch1]", "[=Branch2]", "[=Branch3]", "[=C@@H1]",
+        "[=C@@]", "[=C@H1]", "[=C@]", "[=C]", "[=Cl]", "[=F]", "[=H]", "[=I]",
+        "[=NH1]", "[=N]", "[=O]", "[=P]", "[=Ring1]", "[=Ring2]", "[=Ring3]",
+        "[=S]", "[Br]", "[Branch1]", "[Branch2]", "[Branch3]", "[C@@H1]",
+        "[C@@]", "[C@H1]", "[C@]", "[C]", "[Cl]", "[F]", "[H]", "[I]", "[NH1]",
+        "[N]", "[O]", "[P]", "[Ring1]", "[Ring2]", "[Ring3]", "[S]", "[\\Br]",
+        "[\\C@@H1]", "[\\C@@]", "[\\C@H1]", "[\\C@]", "[\\C]", "[\\Cl]",
+        "[\\F]", "[\\H]", "[\\I]", "[\\NH1]", "[\\N]", "[\\O]", "[\\P]",
+        "[\\S]", "[nop]"
+    ])
+    return list(alphabet)
+
+
+def test_random_selfies_decoder(trials, max_selfies_len, large_alphabet):
+    """Tests that SELFIES built by randomly stringing together symbols
+    from the SELFIES alphabet decode into valid SMILES.
+    """
+
+    alphabet = tuple(large_alphabet)
+
+    for _ in range(trials):  # "trials" fixture is provided by conftest.py
+
+        # create a random SELFIES and decode it
+        rand_len = random.randint(1, max_selfies_len)
+        rand_selfies = "".join(random_choices(alphabet, k=rand_len))
+        smiles = sf.decoder(rand_selfies)
+
+        # check that the decoded SMILES is parseable by RDKit
+        try:
+            is_valid = MolFromSmiles(smiles, sanitize=True) is not None
+        except Exception:
+            is_valid = False
+
+        err_msg = "SMILES: {}\n\t SELFIES: {}".format(smiles, rand_selfies)
+        assert is_valid, err_msg
+
+
+def test_nop_symbol_decoder(max_selfies_len, large_alphabet):
+    """Tests that the '[nop]' symbol is always skipped over by the decoder.
+    """
+
+    alphabet = list(large_alphabet)
+    alphabet.remove("[nop]")
+
+    for _ in range(100):
+
+        # create random SELFIES with and without [nop]
+        rand_len = random.randint(1, max_selfies_len)
+        rand_mol = random_choices(alphabet, k=rand_len)
+        rand_mol.extend(["[nop]"] * (max_selfies_len - rand_len))
+        random.shuffle(rand_mol)
+
+        with_nops = "".join(rand_mol)
+        without_nops = with_nops.replace("[nop]", "")
+
+        assert sf.decoder(with_nops) == sf.decoder(without_nops)
+
+
+def test_get_semantic_constraints():
+    constraints = sf.get_semantic_constraints()
+    assert constraints is not sf.get_semantic_constraints()  # not an alias
+    assert "?" in constraints
+
+
+def test_change_constraints_cache_clear():
+    alphabet = sf.get_semantic_robust_alphabet()
+    assert alphabet == sf.get_semantic_robust_alphabet()
+    assert sf.decoder("[C][#C]") == "C#C"
+
+    new_constraints = sf.get_semantic_constraints()
+    new_constraints["C"] = 1
+    sf.set_semantic_constraints(new_constraints)
+
+    new_alphabet = sf.get_semantic_robust_alphabet()
+    assert new_alphabet != alphabet
+    assert sf.decoder("[C][#C]") == "CC"
+
+    sf.set_semantic_constraints()  # restore the default constraints
+
+
+def test_invalid_or_unsupported_smiles_encoder():
+    malformed_smiles = [
+        "",
+        "(",
+        "C(Cl)(Cl)CC[13C",
+        "C(CCCOC",
+        "C=(CCOC",
+        "CCCC)",
+        "C1CCCCC",
+        "C(F)(F)(F)(F)(F)F",  # violates bond constraints
+        "C=C1=CCCCCC1",  # violates bond constraints
+        "CC*CC",  # uses wildcard
+        "C$C",  # uses $ bond
+        "S[As@TB1](F)(Cl)(Br)N",  # unrecognized chirality
+        "SOMETHINGWRONGHERE",
+        "1243124124",
+    ]
+
+    for smiles in malformed_smiles:
+        with pytest.raises(sf.EncoderError):
+            sf.encoder(smiles)
+
+
+def test_malformed_selfies_decoder():
+    with pytest.raises(sf.DecoderError):
+        sf.decoder("[O][=C][O][C][C][C][C][O][N][Branch2_3")
+
+
+def random_choices(population, k):  # random.choices was added in Python 3.6
+    return [random.choice(population) for _ in range(k)]
+
+
+def test_decoder_attribution():
+    sm, am = sf.decoder(
+        "[C][N][C][Branch1][C][P][C][C][Ring1][=Branch1]", attribute=True)
+    # check that the P atom is attributed to the [P] symbol
+    for ta in am:
+        if ta.token == 'P':
+            for a in ta.attribution:
+                if a.token == '[P]':
+                    return
+    raise ValueError('Failed to find P in attribution map')
+
+
+def test_encoder_attribution():
+    smiles = "C1([O-])C=CC=C1Cl"
+    indices = [0, 3, 3, 3, 5, 7, 8, 10, None, None, 12]
+    s, am = sf.encoder(smiles, attribute=True)
+    for i, ta in enumerate(am):
+        if ta.attribution:
+            assert indices[i] == ta.attribution[0].index, \
+                f'found {ta.attribution[0].index}; should be {indices[i]}'
+        if ta.token == '[Cl]':
+            assert 'Cl' in [
+                a.token for a in ta.attribution], \
+                'Failed to find Cl in attribution map'
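
Note: test_random_selfies_decoder takes a "trials" fixture that this file does not define; under pytest it would have to be supplied by a conftest.py (or a plugin) alongside the tests. Below is a minimal sketch of such a conftest; the "--trials" option name and its default of 100 are illustrative assumptions, not part of the diff above.

# conftest.py (sketch only, assuming a "--trials" command-line option)
import pytest


def pytest_addoption(parser):
    # register the option that controls how many random SELFIES
    # each randomized test generates
    parser.addoption("--trials", action="store", default="100")


@pytest.fixture()
def trials(request):
    # expose the option value to tests as an integer fixture
    return int(request.config.getoption("--trials"))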