|
a |
|
b/benchmark/pseudolabel.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
|
|
|
3 |
''' |
|
|
4 |
|
|
|
5 |
input: 348k data |
|
|
6 |
1. ClinicalTrialGov/NCTxxxx/xxxxxx.xml & all_xml |
|
|
7 |
1. data/diseases.csv |
|
|
8 |
2. data/drug2smiles.pkl |
|
|
9 |
output: data/raw_data.csv |
|
|
10 |
|
|
|
11 |
processing: |
|
|
12 |
0.1 Interventional: 273k data (348k total, e.g., observatorial, surgery, ) |
|
|
13 |
0.2 intervention_type == Drug (drug not empty) |
|
|
14 |
0.3 drop_set 96k data (273k), (we don't use drop_set to filter out) |
|
|
15 |
0.4 -1 -> 0 based on "why_stop" |
|
|
16 |
0.5 filter out -1(invalid) |
|
|
17 |
|
|
|
18 |
1. disease -> icd |
|
|
19 |
2. drug -> smiles |
|
|
20 |
3. inclusive / exclusive criteria ---- to do |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
requires ~10 minutes. |
|
|
24 |
|
|
|
25 |
''' |
|
|
26 |
|
|
|
27 |
##### standard library |
|
|
28 |
import os, csv, pickle |
|
|
29 |
from xml.dom import minidom |
|
|
30 |
from xml.etree import ElementTree as ET |
|
|
31 |
from collections import defaultdict |
|
|
32 |
from time import time |
|
|
33 |
import re |
|
|
34 |
from tqdm import tqdm |
|
|
35 |
|
|
|
36 |
from utils import get_path_of_all_xml_file, walkData |
|
|
37 |
|
|
|
38 |
drop_set = ['Active, not recruiting', 'Enrolling by invitation', 'No longer available', |
|
|
39 |
'Not yet recruiting', 'Recruiting', 'Temporarily not available', 'Unknown status'] |
|
|
40 |
|
|
|
41 |
''' |
|
|
42 |
14 overall_status |
|
|
43 |
|
|
|
44 |
Active, not recruiting |
|
|
45 |
Approved for marketing |
|
|
46 |
Available |
|
|
47 |
Completed |
|
|
48 |
Enrolling by invitation |
|
|
49 |
No longer available |
|
|
50 |
Not yet recruiting |
|
|
51 |
Recruiting |
|
|
52 |
Suspended |
|
|
53 |
Temporarily not available |
|
|
54 |
Terminated |
|
|
55 |
Unknown status |
|
|
56 |
Withdrawn |
|
|
57 |
Withheld |
|
|
58 |
''' |
|
|
59 |
|
|
|
60 |
|
|
|
61 |
|
|
|
62 |
def root2outcome(root): |
|
|
63 |
result_list = [] |
|
|
64 |
walkData(root, prefix = '', result_list = result_list) |
|
|
65 |
filter_func = lambda x:'p_value' in x[0] |
|
|
66 |
outcome_list = list(filter(filter_func, result_list)) |
|
|
67 |
if len(outcome_list)==0: |
|
|
68 |
return None |
|
|
69 |
outcome = outcome_list[0][1] |
|
|
70 |
if outcome[0]=='<': |
|
|
71 |
return 1 |
|
|
72 |
if outcome[0]=='>': |
|
|
73 |
return 0 |
|
|
74 |
if outcome[0]=='=': |
|
|
75 |
outcome = outcome[1:] |
|
|
76 |
try: |
|
|
77 |
label = float(outcome) |
|
|
78 |
if label < 0.05: |
|
|
79 |
return 1 |
|
|
80 |
else: |
|
|
81 |
return 0 |
|
|
82 |
except: |
|
|
83 |
return None |
|
|
84 |
|
|
|
85 |
def xmlfile_2_label(xml_file): |
|
|
86 |
tree = ET.parse(xml_file) |
|
|
87 |
root = tree.getroot() |
|
|
88 |
nctid = root.find('id_info').find('nct_id').text ### nctid: 'NCT00000102' |
|
|
89 |
|
|
|
90 |
label = root2outcome(root) |
|
|
91 |
label = -1 if label is None else label |
|
|
92 |
|
|
|
93 |
return label |
|
|
94 |
|
|
|
95 |
|
|
|
96 |
|
|
|
97 |
|
|
|
98 |
|
|
|
99 |
|