Diff of /HINT/utils.py [000000] .. [bc9e98]

Switch to unified view

a b/HINT/utils.py
1
###### import ######
2
3
import pickle
4
import numpy as np 
5
from rdkit import Chem 
6
from rdkit.Chem import AllChem
7
from rdkit import RDLogger  
8
RDLogger.DisableLog('rdApp.info')
9
RDLogger.DisableLog('rdApp.*')
10
###### import ######
11
12
13
14
def plot_hist(prefix_name, prediction, label):
    """Plot histograms of predictions split by label and save them to disk.

    Predictions whose label equals 1 are drawn in blue ("success") and those
    whose label equals 0 in red ("fail"). The two prediction lists are also
    pickled as a ``(positive, negative)`` tuple.

    Args:
        prefix_name: Path prefix; the figure is written to
            ``prefix_name + "_histogram.png"``.
        prediction: Indexable sequence of predicted scores, aligned with
            ``label``.
        label: Sequence of labels; only entries equal to 1 or 0 are used.

    Side effects:
        Writes ``results/<basename of prefix_name>_positive_negative.pkl``
        (the ``results/`` directory is not created here) and the PNG figure.
    """
    # Plotting libraries are imported lazily so the rest of the module can be
    # used without them installed.
    import seaborn as sns
    import matplotlib.pyplot as plt

    figure_name = prefix_name + "_histogram.png"
    positive_prediction = [prediction[i] for i, y in enumerate(label) if y == 1]
    negative_prediction = [prediction[i] for i, y in enumerate(label) if y == 0]

    save_file_name = "results/" + prefix_name.split('/')[-1] + "_positive_negative.pkl"
    # Use a context manager so the pickle file handle is always closed.
    with open(save_file_name, 'wb') as fout:
        pickle.dump((positive_prediction, negative_prediction), fout)

    # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
    # sns.histplot is the documented replacement - confirm the pinned version
    # before migrating.
    sns.distplot(positive_prediction, hist=True, kde=False, bins=20, color='blue', label='success')
    sns.distplot(negative_prediction, hist=True, kde=False, bins=20, color='red', label='fail')
    plt.xlabel("predicted success probability", fontsize=24)
    plt.ylabel("frequencies", fontsize=25)
    plt.legend(fontsize=21)
    plt.tight_layout()
    # NOTE(review): the current figure is not cleared after saving, so
    # repeated calls in one process draw onto the same axes - confirm intent.
    plt.savefig(figure_name)
    return
31
32
def replace_strange_symbol(text):
    """Return ``text`` with each of ``[``, ``]``, ``'``, newline and ``/``
    replaced by an underscore.

    Args:
        text: Input string.

    Returns:
        A new string of the same length with the listed characters replaced.
    """
    # One translate pass instead of five chained str.replace calls.
    table = str.maketrans({symbol: '_' for symbol in "[]'\n/"})
    return text.translate(table)
36
37
#  xml read blog:  https://blog.csdn.net/yiluochenwu/article/details/23515923 
38
def walkData(root_node, prefix, result_list):
    """Recursively flatten an XML element tree into ``[path, text]`` pairs.

    Nodes are visited in pre-order. For every node, a two-item list
    ``[prefix + '/' + tag, node.text]`` is appended to ``result_list``.

    Args:
        root_node: Element exposing ``.tag``, ``.text`` and iteration over
            its direct children (e.g. ``xml.etree.ElementTree.Element``).
        prefix: Slash-separated path of the ancestors of ``root_node``
            ('' for the document root).
        result_list: List mutated in place to collect the pairs.

    Returns:
        None; results are accumulated in ``result_list``.
    """
    node_path = prefix + '/' + root_node.tag
    result_list.append([node_path, root_node.text])
    # Iterate the element directly: Element.getchildren() was removed in
    # Python 3.9, and iterating yields the same direct children.
    for child in root_node:
        walkData(child, prefix=node_path, result_list=result_list)
46
47
48
def dynamic_programming(s1, s2):
    """Return the length of the longest common subsequence of ``s1`` and ``s2``.

    Args:
        s1: First sequence (e.g. a string).
        s2: Second sequence.

    Returns:
        Length of the longest (not necessarily contiguous) subsequence common
        to both inputs; 0 when either input is empty.
    """
    # Rolling-row DP with a leading zero column/row, which also covers empty
    # inputs (the previous implementation indexed s1[0] / s2[0] and raised
    # IndexError on them).
    previous_row = [0] * (len(s2) + 1)
    for item1 in s1:
        current_row = [0]
        for j, item2 in enumerate(s2, start=1):
            if item1 == item2:
                current_row.append(previous_row[j - 1] + 1)
            else:
                current_row.append(max(previous_row[j], current_row[j - 1]))
        previous_row = current_row
    return previous_row[-1]
69
70
71
def get_path_of_all_xml_file(input_file="./data/all_xml"):
    """Read a text file and return its lines stripped of surrounding whitespace.

    Args:
        input_file: Path of the text file to read, one entry per line.
            Defaults to ``./data/all_xml``, the location that was
            previously hard-coded.

    Returns:
        List with one stripped string per line of the file (blank lines are
        kept as empty strings).
    """
    with open(input_file, 'r') as fin:
        return [line.strip() for line in fin]
77
78
79
def remove_multiple_space(text):
    """Collapse every run of whitespace in ``text`` into a single space.

    Leading and trailing whitespace is dropped, because ``str.split()`` with
    no arguments discards empty fields at both ends.

    Args:
        text: Input string.

    Returns:
        The whitespace-normalised string.
    """
    words = text.split()
    return ' '.join(words)
82
83
def nctid_2_xml_file_path(nctid):
    """Build the XML file path for an 11-character NCT identifier.

    The file is located in a sub-folder named after the first seven
    characters of the id followed by ``xxxx``, e.g.
    ``./ClinicalTrialGov/NCT0000xxxx/NCT00000102.xml``.

    Args:
        nctid: Identifier string of length 11.

    Returns:
        Path string under ``./ClinicalTrialGov/``.

    Raises:
        AssertionError: If ``nctid`` is not exactly 11 characters long.
    """
    # Imported locally: the module's import section does not bring ``os``
    # into scope, so the original call raised NameError.
    import os

    assert len(nctid) == 11, "nctid must be exactly 11 characters long"
    prefix = nctid[:7] + "xxxx"
    return os.path.join("./ClinicalTrialGov/", prefix, nctid + ".xml")
88
89
90
def fingerprints_from_mol(mol):
    """Compute a folded, count-based Morgan fingerprint for a molecule.

    The fingerprint is requested with radius 3, ``useCounts=True`` and
    ``useFeatures=True``. Its sparse indices are folded into 2048 bins by
    taking each index modulo 2048, and counts that land in the same bin are
    summed.

    Args:
        mol: Molecule object accepted by ``AllChem.GetMorganFingerprint``
            (presumably an RDKit ``Mol`` - confirm against callers).

    Returns:
        ``numpy.ndarray`` of shape ``(1, 2048)`` and dtype ``int32``.
    """
    size = 2048
    sparse_fp = AllChem.GetMorganFingerprint(mol, 3, useCounts=True, useFeatures=True)
    folded = np.zeros((1, size), np.int32)
    for bit_index, count in sparse_fp.GetNonzeroElements().items():
        # Fold the sparse index space into ``size`` bins.
        folded[0, bit_index % size] += int(count)
    return folded
98
99
def smiles2fp(smiles):
    """Convert a SMILES string into a ``(1, 2048)`` int32 fingerprint.

    This is a best-effort conversion: if the SMILES cannot be parsed or
    fingerprinting fails, an all-zero vector of the same shape is returned
    instead of raising.

    Args:
        smiles: SMILES string describing a molecule.

    Returns:
        ``numpy.ndarray`` of shape ``(1, 2048)`` and dtype ``int32``.
    """
    try:
        # The original passed the undefined name ``smile`` here; the
        # resulting NameError was swallowed by a bare ``except`` and every
        # call returned the zero vector.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # No molecule was produced for this input; use the fallback.
            return np.zeros((1, 2048), np.int32)
        return fingerprints_from_mol(mol)
    except Exception:
        # Deliberate fallback so a single bad molecule does not abort a batch.
        return np.zeros((1, 2048), np.int32)
106
107
def smiles_lst2fp(smiles_lst):
    """Average the fingerprints of several SMILES strings.

    Each SMILES is converted with ``smiles2fp``; the resulting row vectors
    are stacked along axis 0 and averaged over that axis.

    Args:
        smiles_lst: Non-empty iterable of SMILES strings
            (``np.concatenate`` raises ``ValueError`` on an empty list).

    Returns:
        1-D ``numpy.ndarray`` holding the element-wise mean fingerprint.
    """
    stacked = np.concatenate([smiles2fp(smiles) for smiles in smiles_lst], 0)
    return np.mean(stacked, 0)
112
113
114
115
116
117
if __name__ == "__main__":
    # Small manual check: sanitise a result file name that contains brackets,
    # quotes and a path separator.
    text = "interpret_result/NCT00329602__completed____1__1.7650960683822632__phase 4__['restless legs syndrome']__['placebo', 'ropinirole'].png"
    print(replace_strange_symbol(text))
120
121
122
123
124
125
126
# if __name__ == "__main__":
127
#   input_file_lst = get_path_of_all_xml_file() 
128
#   print(input_file_lst[:5])
129
# '''
130
# input_file_lst = [ 
131
#   'ClinicalTrialGov/NCT0000xxxx/NCT00000102.xml', 
132
#   'ClinicalTrialGov/NCT0000xxxx/NCT00000104.xml', 
133
#   'ClinicalTrialGov/NCT0000xxxx/NCT00000105.xml', 
134
#     ... ]
135
# '''
136
137
138
139
# if __name__ == "__main__":
140
#   s1 = "328943"
141
#   s2 = "13785"
142
#   assert dynamic_programming(s1, s2)==2 
143
144
145
146
147