# benchmark/collect_all.py
# -*- coding: utf-8 -*- 
2
import os, csv, pickle   
3
from xml.dom import minidom
4
from xml.etree import ElementTree as ET
5
from collections import defaultdict
6
from time import time 
7
import re 
8
from tqdm import tqdm 
9
10
from utils import dynamic_programming
11
12
13
def get_all_file(input_file="all_xml"):
    """Return the XML file paths listed in ``input_file``.

    Args:
        input_file: Text file holding one path per line. Defaults to
            ``"all_xml"`` in the current working directory, which is what
            the original hard-coded.

    Returns:
        list[str]: The paths with surrounding whitespace stripped. Blank
        lines are skipped so that no empty path is handed to the XML parser.
    """
    with open(input_file, 'r') as fin:
        stripped_lines = (line.strip() for line in fin)
        return [path for path in stripped_lines if path]
19
20
# Example of the list returned by get_all_file():
# input_file_lst = [
#     'ClinicalTrialGov/NCT0000xxxx/NCT00000102.xml',
#     'ClinicalTrialGov/NCT0000xxxx/NCT00000104.xml',
#     'ClinicalTrialGov/NCT0000xxxx/NCT00000105.xml',
#     ...
# ]
27
28
def remove_multiple_space(text):
    """Collapse every run of whitespace in ``text`` into a single space.

    Leading and trailing whitespace is removed as a side effect of
    ``str.split()`` with no separator.
    """
    return ' '.join(text.split())
31
32
def generate_complete_path(nctid):
    """Return the path of the XML record for trial ``nctid``.

    The record is looked up under ``./ClinicalTrialGov/`` in a folder named
    after the first seven characters of the id followed by ``"xxxx"``,
    e.g. ``NCT00000102`` -> ``NCT0000xxxx/NCT00000102.xml``.

    Args:
        nctid: Trial identifier; must be exactly 11 characters long.

    Raises:
        AssertionError: If ``nctid`` is not 11 characters long (kept as an
            assert so existing callers see the same exception type).
    """
    assert len(nctid) == 11, "nctid must be 11 characters long, got %r" % (nctid,)
    folder = nctid[:7] + "xxxx"
    return os.path.join("./ClinicalTrialGov/", folder, nctid + ".xml")
37
38
#  xml read blog:  https://blog.csdn.net/yiluochenwu/article/details/23515923 
39
def walkData(root_node, prefix, result_list):
    """Append ``[path, text]`` for ``root_node`` and every descendant.

    Nodes are visited in pre-order (a node before its children, children in
    document order).

    Args:
        root_node: ``xml.etree.ElementTree.Element`` to start from.
        prefix: Slash-separated path of the ancestors of ``root_node``
            (``''`` for the document root).
        result_list: List that is extended in place with two-item lists
            ``[prefix + '/' + tag, element.text]``.
    """
    node_path = prefix + '/' + root_node.tag
    result_list.append([node_path, root_node.text])
    # Element.getchildren() was removed in Python 3.9; iterating an Element
    # yields its direct children and works on every Python 3 version.
    for child in root_node:
        walkData(child, prefix=node_path, result_list=result_list)
47
48
def root2outcome(root):
    """Derive a binary label from the first p-value found under ``root``.

    The first element, in document order, whose tag contains ``'p_value'``
    is used. Its text is interpreted as follows:

    * starts with ``'<'``  -> 1
    * starts with ``'>'``  -> 0
    * optional leading ``'='`` followed by a number -> 1 if the number is
      below 0.05, else 0

    Args:
        root: Root ``Element`` of a parsed trial record.

    Returns:
        1, 0, or ``None`` when there is no such element, the element has no
        text, or the text cannot be parsed as a number.
    """
    # Picking the first element whose own tag contains 'p_value' selects the
    # same node as filtering the pre-order [path, text] list on the path: an
    # ancestor that makes a path match is always visited before its
    # descendants.
    p_value_node = next(
        (node for node in root.iter()
         if isinstance(node.tag, str) and 'p_value' in node.tag),
        None,
    )
    if p_value_node is None:
        return None
    # Empty elements have text None; treat them (and blank text) as "no label"
    # instead of failing on outcome[0].
    outcome = (p_value_node.text or '').strip()
    if not outcome:
        return None
    if outcome[0] == '<':
        return 1
    if outcome[0] == '>':
        return 0
    if outcome[0] == '=':
        outcome = outcome[1:]
    try:
        p_value = float(outcome)
    except ValueError:
        return None
    return 1 if p_value < 0.05 else 0
70
71
def _text_or_empty(node, path):
    """Return the text of the first element at ``path`` under ``node``, or ''."""
    return node.findtext(path) or ''


def file2dict(xml_file):
    """Parse one trial XML record into a flat tuple of fields.

    Args:
        xml_file: Path of the XML file to parse.

    Returns:
        ``(None,)`` when the record should be skipped: it has no NCT id, it
        is not an ``'Interventional'`` study, or no label can be derived
        from its p-value. Otherwise the 9-tuple
        ``(nctid, status, label, phase, conditions, drug_interventions,
        title, criteria, summary)`` where ``conditions`` and
        ``drug_interventions`` are lists of strings and missing optional
        text fields are ``''``.
    """
    root = ET.parse(xml_file).getroot()
    nctid = root.findtext('id_info/nct_id')    # e.g. 'NCT00000102'
    if not nctid:
        return (None,)
    title = _text_or_empty(root, 'brief_title')
    if root.findtext('study_type') != 'Interventional':
        return (None,)
    label = root2outcome(root)
    if label is None:
        return (None,)

    # Skip entries without text so callers can safely call str methods.
    conditions = [node.text for node in root.findall('condition') if node.text]
    # Only interventions typed 'Drug' are kept; a missing type or name tag
    # excludes the intervention instead of raising AttributeError.
    drug_names = (
        node.findtext('intervention_name')
        for node in root.findall('intervention')
        if node.findtext('intervention_type') == 'Drug'
    )
    drug_interventions = [name for name in drug_names if name]

    status = _text_or_empty(root, 'overall_status')
    criteria = _text_or_empty(root, 'eligibility/criteria/textblock')
    if criteria:
        print("criteria\n\t\t", criteria)
    # NOTE(review): the summary text appears to live in a nested <textblock>
    # (as it does for eligibility criteria) -- confirm against the data. The
    # element's own text is kept as a fallback, which is what was read before.
    summary = (_text_or_empty(root, 'brief_summary/textblock')
               or _text_or_empty(root, 'brief_summary'))
    if summary:
        print("summary\n\t\t", summary)
    phase = _text_or_empty(root, 'phase')
    if phase:
        print("phase\n\t\t", phase)
    return nctid, status, label, phase, conditions, drug_interventions, title, criteria, summary
111
112
113
114
def getXmlData(file_name):
    """Parse ``file_name`` and return the flat list built by ``walkData``.

    Returns:
        list: One ``[path, text]`` entry per element of the document, as
        appended by ``walkData`` starting from the root with an empty prefix.
    """
    result_list = []
    walkData(ET.parse(file_name).getroot(), prefix='', result_list=result_list)
    return result_list
119
120
121
def Get_Iqvia_data(nct2outcome_file="data/trial_outcomes_v1.csv",
                   outcome2label_file="data/outcome2label.txt"):
    """Build a mapping from NCT id to label from the outcome files.

    Args:
        nct2outcome_file: CSV with a header row; column 0 is the NCT id and
            column 1 the outcome string.
        outcome2label_file: Tab-separated text file; column 0 is the outcome
            string and column 1 its integer label.

    Returns:
        dict: ``{nctid: label}``. When an id appears several times the
        largest label wins; ids whose final label is ``-1`` are dropped.

    Raises:
        KeyError: If the CSV contains an outcome that is not listed in
            ``outcome2label_file``.
    """
    outcome2label = {}
    with open(outcome2label_file, 'r') as fin:
        for line in fin:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            fields = line.split('\t')
            outcome2label[fields[0]] = int(fields[1])

    nct2label = {}
    # newline='' is what the csv module expects from the files it reads.
    with open(nct2outcome_file, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',')
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue
            nctid, outcome = row[0], row[1]
            label = outcome2label[outcome]
            if nctid not in nct2label or label > nct2label[nctid]:
                nct2label[nctid] = label

    # Drop the ids whose best label is still -1.
    return {nctid: label for nctid, label in nct2label.items() if label != -1}
148
149
def load_drug2smiles_pkl(pkl_file="data/drug2smiles.pkl"):
    """Load and return the pickled object stored in ``pkl_file``.

    Callers in this file use the result as a mapping from drug name to
    SMILES.

    Args:
        pkl_file: Path of the pickle file; defaults to the project's
            ``data/drug2smiles.pkl``.
    """
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at files produced by this project.
    with open(pkl_file, 'rb') as fin:  # ensure the handle is closed
        return pickle.load(fin)
153
154
def load_disease2icd_pkl(iqvia_pkl_file="data/disease2icd.pkl",
                         public_pkl_file="icdcode/description2icd.pkl"):
    """Load the two pickled disease -> ICD lookup objects.

    Args:
        iqvia_pkl_file: Path of the first pickle file.
        public_pkl_file: Path of the second pickle file.

    Returns:
        tuple: ``(iqvia_disease2icd, public_disease2icd)``, the objects
        unpickled from the two files, in that order.
    """
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at files produced by this project.
    with open(iqvia_pkl_file, 'rb') as fin:
        iqvia_disease2icd = pickle.load(fin)
    with open(public_pkl_file, 'rb') as fin:
        public_disease2icd = pickle.load(fin)
    return iqvia_disease2icd, public_disease2icd
160
161
162
163
def drug_hit_smiles(drug, drug2smiles):
    """Look up ``drug`` in ``drug2smiles`` using two heuristics.

    1. Exact match on the whole name.
    2. First whitespace-separated word of at least 7 characters that is
       itself a key of ``drug2smiles``.

    Args:
        drug: Drug name (callers in this file pass it lower-cased).
        drug2smiles: Mapping from drug name to its SMILES value.

    Returns:
        The mapped value, or ``None`` when neither heuristic matches.
    """
    if drug in drug2smiles:
        return drug2smiles[drug]
    for word in drug.split():
        # Short words are ignored to avoid matching on generic fragments.
        if len(word) >= 7 and word in drug2smiles:
            return drug2smiles[word]
    # NOTE: a fuzzy fallback based on utils.dynamic_programming used to be
    # sketched here in commented-out form; it is intentionally not enabled.
    return None
183
184
185
def disease_hit_icd(disease, disease2icd, disease2diseaseset):
    """Look up ``disease`` in ``disease2icd`` using three heuristics.

    0. Exact match on the whole name.
    1. First whitespace-separated word of at least 7 characters that is
       itself a key of ``disease2icd``.
    2. Token overlap: the known disease sharing the most tokens with
       ``disease`` is chosen, provided at least two tokens are shared and
       the shared tokens total more than 8 characters.

    Args:
        disease: Disease name (callers in this file pass it lower-cased).
        disease2icd: Mapping from disease name to ICD code.
        disease2diseaseset: Mapping from each key of ``disease2icd`` to the
            set of tokens of that name (see ``disease_dict_reorganize``).

    Returns:
        The mapped ICD value, or ``None`` when no heuristic matches.
    """
    # match 0: exact name
    if disease in disease2icd:
        return disease2icd[disease]
    # match 1: a single long word that is a known disease
    for word in disease.split():
        if len(word) >= 7 and word in disease2icd:
            return disease2icd[word]
    # match 2: largest token overlap
    tokens = set(re.split(r"[\', /-]", disease))
    # Adjacent separators (e.g. ", ") make re.split emit '' tokens. Without
    # removing it, '' would count as a shared "word" and a single real shared
    # word would be enough to pass the two-token threshold below.
    tokens.discard('')
    max_length = 0
    best_disease = None
    for disease0, disease0set in disease2diseaseset.items():
        shared = disease0set.intersection(tokens)
        if len(shared) > max_length and len(''.join(shared)) > 8:
            max_length = len(shared)
            best_disease = disease0
    if max_length > 1:
        return disease2icd[best_disease]
    # NOTE: a fuzzy fallback based on utils.dynamic_programming used to be
    # sketched here in commented-out form; it is intentionally not enabled.
    return None
221
222
223
def disease_dict_reorganize(disease2icd):
    """Map every disease name in ``disease2icd`` to the set of its tokens.

    Names are split on apostrophe, comma, space, slash and hyphen. The empty
    token produced by adjacent separators (e.g. ``", "``) is dropped so it
    can never be counted as a shared word when token sets are intersected.

    Returns:
        dict: ``{disease_name: set_of_tokens}`` with the same keys as
        ``disease2icd``.
    """
    return {
        disease: set(re.split(r"[\', /-]", disease)) - {''}
        for disease in disease2icd
    }
225
226
227
228
def write_csv_file():
    """Build ``data/cooked_trial.csv`` from the listed trial XML files.

    For every file returned by ``get_all_file()`` the record is parsed with
    ``file2dict``. Records are skipped when ``file2dict`` returns its
    one-element sentinel, when none of their diseases maps to an ICD code,
    or when none of their drugs maps to a SMILES value. For the remaining
    records one CSV row is written; multi-valued columns are tab-joined.
    Hit counts and the elapsed time in minutes are printed at the end.
    """
    cook_csv_file = 'data/cooked_trial.csv'
    drug2smiles = load_drug2smiles_pkl()
    # Only the public description -> ICD table is used for matching here.
    _, public_disease2icd = load_disease2icd_pkl()
    disease2icd = public_disease2icd
    disease2diseaseset = disease_dict_reorganize(public_disease2icd)
    t1 = time()
    input_file_lst = get_all_file()
    fieldname = ['nctid', 'status', 'label', 'phase', 'diseases', 'icdcodes', 'drugs', 'smiless', 'title', 'criteria', 'summary']
    disease_hit, disease_all, drug_hit, drug_all = 0, 0, 0, 0  # disease hit icd && drug hit smiles
    # Text mode with an explicit encoding replaces the per-field
    # .encode('utf-8'), which under Python 3 wrote b'...' reprs into the CSV.
    # newline='' is what the csv module expects from the files it writes.
    with open(cook_csv_file, 'w', encoding='utf-8', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldname)
        writer.writeheader()
        for xml_file in tqdm(input_file_lst):
            result = file2dict(xml_file)
            if len(result) == 1:  # (None,) sentinel: record must be skipped
                continue
            nctid, status, label, phase, diseases, drugs, title, criteria, summary = result
            # Guard against entries without text so .lower()/.join cannot fail.
            diseases = [disease for disease in diseases if disease]
            drugs = [drug for drug in drugs if drug]

            icdcode_lst = []
            for disease in diseases:
                disease_all += 1
                icdcode = disease_hit_icd(disease.lower(), disease2icd, disease2diseaseset)
                if icdcode is not None:
                    disease_hit += 1
                    icdcode_lst.append(icdcode)
                else:
                    print("unfounded:", disease.lower())
            if not icdcode_lst:
                continue

            smiles_lst = []
            for drug in drugs:
                drug_all += 1
                smiles = drug_hit_smiles(drug.lower(), drug2smiles)
                if smiles is not None:
                    drug_hit += 1
                    smiles_lst.append(smiles)
            if not smiles_lst:
                continue

            writer.writerow({
                'nctid': nctid,
                # 'status' is declared in the header; it was parsed but never
                # written, leaving the column empty.
                'status': status,
                'label': label,
                'phase': phase,
                'diseases': '\t'.join(diseases),
                'icdcodes': '\t'.join(icdcode_lst),
                # Drug names go here; the SMILES list was written by mistake.
                'drugs': '\t'.join(drugs),
                'smiless': '\t'.join(smiles_lst),
                'title': title,
                'criteria': criteria,
                'summary': summary,
            })
    print("disease hit icdcode", disease_hit, "disease all", disease_all, "\n drug hit smiles", drug_hit, "drug all", drug_all)
    t2 = time()
    print(str(int((t2-t1)/60)) + " minutes")
286
287
288
## dynamic programming
289
# if __name__ == "__main__":
290
#   a = 'dynamdddwic'
291
#   b = 'mfewweic'
292
#   print(dynamic_programming(a,b))
293
294
## write csv file 
295
# Script entry point: build data/cooked_trial.csv when run directly.
if __name__ == "__main__":
    write_csv_file()
297
298
# #### check csvfile
299
# if __name__ == "__main__":
300
#   cook_csv_file = 'data/cooked_trial.csv'
301
#   positive_sample_cnt, negative_sample_cnt = 0, 0
302
#   wrong_nct_list = []
303
#   correct_cnt, total_cnt = 0, 0 
304
#   iqvia_nct2label = Get_Iqvia_data() 
305
#   with open(cook_csv_file, 'r') as csvfile:
306
#       reader = list(csv.reader(csvfile, delimiter = ','))[1:]
307
#       for row in reader:
308
#           nctid = row[0]
309
#           label = int(row[1])
310
#           if nctid in iqvia_nct2label:
311
#               total_cnt += 1
312
#               iqvia_label = iqvia_nct2label[nctid]
313
#               if iqvia_label == label:
314
#                   correct_cnt += 1
315
#               else:
316
#                   wrong_nct_list.append(nctid)
317
#           if label == 1:
318
#               positive_sample_cnt += 1
319
#           elif label==0:
320
#               negative_sample_cnt += 1 
321
#   print("positive_sample_cnt", positive_sample_cnt, "negative_sample_cnt", negative_sample_cnt)
322
#   print("correct_cnt", correct_cnt, "total_cnt", total_cnt)
323
#   with open("wrong_nct.txt", 'w') as fout:
324
#       for nctid in wrong_nct_list:
325
#           fout.write(nctid + '\n')
326
327
328
329
##### p_value 
330
# if __name__ == "__main__":
331
#   ##### server
332
#   nctid = "NCT00001723"
333
#   file = generate_complete_path(nctid)
334
#   ### local 
335
#   file = "NCT00001723.xml" 
336
337
#   input_file_lst = get_all_file() 
338
#   for file in input_file_lst[:100000]:
339
#       result_list = getXmlData(file)
340
#       filter_func = lambda x:'p_value' in x[0] 
341
#       outcome_list = list(filter(filter_func, result_list))
342
#       if len(outcome_list) > 0:
343
#           print('='*50)
344
#           print(file.split('/')[-1].split('.')[0])
345
#           for i in outcome_list:
346
#               print(i)
347
348
349
350