# src/DataLoader/download.py
1
import requests
2
import sys
3
import xml.etree.ElementTree as ET
4
import os
5
import time
6
import joblib 
7
from tqdm.auto import tqdm
8
import numpy as np
9
from tenacity import retry, wait_random_exponential, stop_after_attempt
10
11
# # Open the log file
# log_file = open('../logs/download.log', 'w')
# # Redirect standard output to the log file
# sys.stdout = log_file
16
def normalize_whitespace(s):
    """Collapse every run of whitespace in *s* to a single space and trim the ends."""
    tokens = s.split()
    return " ".join(tokens)
18
19
20
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(10))
def get_cancer_trials_list(max_trials=15000):
    """Collect up to ``max_trials`` unique NCT IDs for cancer trials.

    Pages through the legacy ClinicalTrials.gov ``full_studies`` API
    (100 studies per request) until ``max_trials`` unique identifiers
    have been gathered or the result set is exhausted.

    Parameters
    ----------
    max_trials : int
        Upper bound on the number of unique trial IDs to return.

    Returns
    -------
    list[str]
        Unique NCT identifiers (order unspecified; accumulated in a set).
    """
    base_url = "https://clinicaltrials.gov/api/query/full_studies"
    trials_set = set()
    page_size = 100  # Number of trials per page
    current_rank = 1

    # Count unique IDs via len(trials_set) instead of a separate counter:
    # the original counter incremented even for duplicates returned across
    # pages, so fewer than max_trials unique IDs could be reported as "done".
    while len(trials_set) < max_trials:
        search_params = {
            "expr": "((cancer) OR (neoplasm)) AND ((interventional) OR (treatment)) AND ((mutation) OR (variant))",
            "min_rnk": current_rank,
            "max_rnk": current_rank + page_size - 1,
            "fmt": "json",
            "fields": "NCTId",
        }
        # Timeout so a hung connection raises and the @retry decorator can
        # re-attempt; without it the request could block forever.
        response = requests.get(base_url, params=search_params, timeout=30)
        if response.status_code != 200:
            print("Failed to retrieve data. Status code:", response.status_code)
            break
        trials_data = response.json()
        if "FullStudiesResponse" not in trials_data:
            print("No trials found matching the criteria.")
            break
        # .get(...) avoids a KeyError on a well-formed-but-empty response.
        studies = trials_data["FullStudiesResponse"].get("FullStudies", [])
        if not studies:
            break  # No more studies found, exit the loop
        for study in studies:
            trials_set.add(study["Study"]["ProtocolSection"]["IdentificationModule"]["NCTId"])
            if len(trials_set) == max_trials:
                break
        current_rank += page_size

    return list(trials_set)  # Convert set to list for output
56
    
57
58
def _fetch_online_root(nct_id):
    """Download and parse the current ClinicalTrials.gov XML for *nct_id*.

    Returns the parsed root Element, or ``None`` when the request fails,
    returns a non-200 status, or the payload is not well-formed XML.
    """
    url = f"https://clinicaltrials.gov/ct2/show/{nct_id}?displayxml=true"
    try:
        response = requests.get(url, timeout=30)
    except requests.exceptions.RequestException as e:
        print(f"Error fetching XML for trial {nct_id}: {e}")
        return None
    if response.status_code != 200:
        print(f"Error: received status code {response.status_code} when fetching XML for trial {nct_id}")
        return None
    try:
        return ET.fromstring(response.text)
    except ET.ParseError as e:
        print(f"Error parsing online XML for trial {nct_id}: {e}")
        return None


def download_study_info(nct_id):
    """Download or refresh the cached XML record for one clinical trial.

    If ``../data/trials_xmls/<nct_id>.xml`` is missing or corrupt, the
    online version is downloaded and cached. Otherwise the cached copy is
    compared against the online version on a few key fields and rewritten
    only when they differ.

    Returns an empty list always (kept for compatibility with the
    parallel driver, which stacks the per-trial results).

    Bug fixes vs. the original: every error path now returns early.
    Previously a local parse error deleted the file but kept using the
    undefined ``local_root``; a request exception left ``response``
    undefined; a failed online fetch left ``online_root`` undefined —
    each a NameError at runtime.
    """
    local_file_path = f"../data/trials_xmls/{nct_id}.xml"

    local_root = None
    if os.path.exists(local_file_path):
        # Read and parse the content of the existing local XML file.
        with open(local_file_path, "r") as f:
            local_xml_content = f.read()
        try:
            local_root = ET.fromstring(local_xml_content)
        except ET.ParseError as e:
            print(f"Error parsing XML for trial {nct_id}: {e}")
            # Corrupt cache entry: drop it and fall through to a fresh download.
            os.remove(local_file_path)

    if local_root is None:
        # No usable local copy: download and cache the online version.
        root = _fetch_online_root(nct_id)
        if root is not None:
            with open(local_file_path, "w") as f:
                f.write(ET.tostring(root, encoding='unicode'))
            print(f"Study information downloaded for {nct_id}")
        return []

    online_root = _fetch_online_root(nct_id)
    if online_root is None:
        return []  # keep the cached copy when the refresh fails

    # Fields whose changes should trigger a cache refresh.
    to_check = ["eligibility", "brief_title", "overall_status", "location"]
    local_version = []
    online_version = []
    for tag in to_check:
        local_elem = local_root.find(".//%s" % tag)
        online_elem = online_root.find(".//%s" % tag)
        # Compare only elements present in both versions.
        if local_elem is not None and online_elem is not None:
            local_version.append(local_elem)
            online_version.append(online_elem)

    # Whitespace-insensitive comparison of the serialized elements.
    is_updated = any(
        normalize_whitespace(ET.tostring(a, encoding='unicode').strip())
        != normalize_whitespace(ET.tostring(b, encoding='unicode').strip())
        for a, b in zip(local_version, online_version)
    )

    if is_updated:
        # Update the local XML with the online version.
        with open(local_file_path, "w") as f:
            f.write(ET.tostring(online_root, encoding='unicode'))
        print(f"Updated eligibility criteria for {nct_id}")
    else:
        print(f"No changes in eligibility criteria for {nct_id}.")
    return []
134
135
    
136
137
# Module-level joblib cache rooted in the working directory (currently unused
# by the functions below, but kept for compatibility).
memory = joblib.Memory(".")


def ParallelExecutor(use_bar="tqdm", **joblib_args):
    """Utility for tqdm progress bar in joblib.Parallel.

    ``use_bar`` picks the iterable wrapper ("tqdm" shows a progress bar;
    "False"/"None" iterate silently). ``joblib_args`` may carry ``n_jobs``
    (default 10) for the underlying ``joblib.Parallel``.
    """
    bar_wrappers = {
        "tqdm": lambda kwargs: lambda seq: tqdm(seq, **kwargs),
        "False": lambda kwargs: iter,
        "None": lambda kwargs: iter,
    }

    def aprun(bar=use_bar, **tq_args):
        def tmp(op_iter):
            key = str(bar)
            if key not in bar_wrappers:
                raise ValueError("Value %s not supported as bar type" % bar)
            wrap = bar_wrappers[key](tq_args)
            # n_jobs comes from the outer joblib_args, defaulting to 10.
            n_jobs = joblib_args.get("n_jobs", 10)
            return joblib.Parallel(n_jobs=n_jobs)(wrap(op_iter))

        return tmp

    return aprun
157
158
def parallel_downloader(
    n_jobs,
    nct_ids,
):
    """Download/refresh every trial in *nct_ids* using *n_jobs* parallel workers.

    Parameters
    ----------
    n_jobs : int
        Degree of parallelism passed to joblib.
    nct_ids : sequence[str]
        NCT identifiers to process.

    Returns
    -------
    numpy.ndarray
        Flat array of the stacked per-trial results (currently always
        empty, since ``download_study_info`` returns ``[]``).
    """
    # Guard the empty case: np.vstack raises ValueError on an empty
    # sequence, which crashed the original for an empty ID list.
    if not nct_ids:
        return np.array([])
    parallel_runner = ParallelExecutor(n_jobs=n_jobs)(total=len(nct_ids))
    X = parallel_runner(
        joblib.delayed(download_study_info)(
            nct_id,
        )
        for nct_id in nct_ids
    )
    updated_cts = np.vstack(X).flatten()
    return updated_cts
171
172
173
class Downloader:
    """Thin orchestrator: downloads/updates a list of trials in parallel."""

    def __init__(self, id_list, n_jobs):
        # NCT identifiers to process and the joblib worker count.
        self.id_list = id_list
        self.n_jobs = n_jobs

    def download_and_update_trials(self):
        """Run the parallel download, print wall-clock time, return the result."""
        start_time = time.time()
        updated_cts = parallel_downloader(self.n_jobs, self.id_list)
        elapsed_time = time.time() - start_time
        print(f"Elapsed time: {elapsed_time} seconds")
        return updated_cts
185
186
187
if __name__ == "__main__":
    # Placeholder driver — fill in real values before running as a script.
    id_list = [...]  # Replace [...] with your list of IDs
    n_jobs = ...  # Replace ... with the number of parallel jobs
    Downloader(id_list, n_jobs).download_and_update_trials()
192