# unimol/utils/decode_utils.py
1
from unimol.data.dictionary import DecoderDictionary
2
import selfies as sf
3
from rdkit import Chem
4
from rdkit.Chem import AllChem
5
from rdkit.Chem.Crippen import MolLogP
6
from rdkit.Chem import MolFromSmiles
7
8
9
def one_hot_to_selfies(hot, dict1: "DecoderDictionary"):
    """Decode a one-hot (or logit) matrix into a SELFIES string.

    The token id at each position is ``hot.transpose(0, 1).argmax(1)``.
    Decoding stops at the first ``[SEP]`` or ``[PAD]`` token; ``[UNK]`` and
    ``[CLS]`` tokens are emitted as the ``[nop]`` placeholder symbol.

    Args:
        hot: 2-D tensor supporting ``transpose(0, 1)``, ``argmax(1)`` and
            ``tolist()``; presumably laid out as (vocabulary, sequence) --
            confirm against the caller.
        dict1: Vocabulary exposing ``index(symbol)`` and
            ``index2symbol(index)``.

    Returns:
        The concatenated SELFIES string with all spaces removed.
    """
    # Special-token ids do not change while decoding: resolve them once
    # instead of on every loop iteration.
    stop_ids = {dict1.index('[SEP]'), dict1.index('[PAD]')}
    nop_ids = {dict1.index('[UNK]'), dict1.index('[CLS]')}

    selfies_list = []
    # tolist() converts all ids to plain ints in one go, replacing the
    # repeated per-element .item() calls.
    for token_id in hot.transpose(0, 1).argmax(1).tolist():
        if token_id in stop_ids:
            break
        if token_id in nop_ids:
            selfies_list.append('[nop]')
        else:
            selfies_list.append(dict1.index2symbol(token_id))
    return ''.join(selfies_list).replace(' ', '')
24
25
26
def one_hot_to_smiles(hot, dict_):
    """Decode a one-hot (or logit) matrix into a SMILES string.

    The matrix is first turned into a SELFIES string by
    ``one_hot_to_selfies`` and that string is then translated with
    ``selfies.decoder``.

    Args:
        hot: Matrix passed through unchanged to ``one_hot_to_selfies``.
        dict_: Vocabulary passed through unchanged to
            ``one_hot_to_selfies``.

    Returns:
        The result of ``selfies.decoder`` for the decoded SELFIES string.
        Only the SMILES is returned; the intermediate SELFIES is not.
    """
    return sf.decoder(one_hot_to_selfies(hot, dict_))
32
33
34
def label_smiles(smiles: list):
    """Label a batch of SMILES strings as Unimol compatible dataset records.

    SMILES that RDKit cannot parse, or that ``selfies`` cannot encode, are
    skipped, so the result may be shorter than ``smiles``.

    Args:
        smiles: SMILES strings to label.

    Returns:
        A list with one dict per accepted SMILES, holding the keys
        ``atoms`` (element symbols after ``Chem.AddHs``), ``coordinates``
        (always an empty list), ``smi`` (the input SMILES), ``scaffold``
        (always ``''``), ``ori_index`` (always ``-1``), ``selfies`` (list of
        SELFIES tokens) and ``target`` (the ``MolLogP`` value).
    """
    new_data_list = []

    for smile in smiles:
        # MolFromSmiles reports an unparsable SMILES by returning None, so
        # check for it explicitly instead of relying on a bare except.
        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            continue

        # Encode per molecule (after the RDKit check) so that a single bad
        # SMILES is skipped rather than aborting the whole batch.
        try:
            encoded = sf.encoder(smile)
        except sf.EncoderError:
            continue
        if encoded is None:
            # NOTE(review): older selfies releases appear to return None
            # instead of raising on failure -- confirm for the pinned version.
            continue

        mol_with_hs = Chem.AddHs(mol)
        new_data_list.append({
            'atoms': [atom.GetSymbol() for atom in mol_with_hs.GetAtoms()],
            'coordinates': [],  # No need to add coordinates
            'smi': smile,
            'scaffold': '',
            'ori_index': -1,
            'selfies': list(sf.split_selfies(encoded)),
            # Reuse the already parsed molecule instead of parsing again.
            'target': MolLogP(mol),
        })

    return new_data_list