Diff of /src/Preporcessor/utils.py [000000] .. [f87529]

Switch to unified view

a b/src/Preporcessor/utils.py
1
import requests
2
import xml.etree.ElementTree as ET
3
import os
4
import time
5
import json
6
import re
7
import gzip, tarfile
8
9
def normalize_whitespace(s):
    """Collapse every run of whitespace in *s* to a single space and trim the ends.

    Parameters:
        s (str): Text to normalize.

    Returns:
        str: The normalized text.
    """
    tokens = s.split()
    return " ".join(tokens)
def download_study_info(nct_id, runs=2):
    """Download or refresh the ClinicalTrials.gov XML record for a trial.

    If a cached XML file exists under ``../data/trials_xmls/``, the online
    version is fetched and compared against it on a fixed set of fields
    (eligibility, brief_title, overall_status, location); the cache is
    rewritten when any of them changed.  If no cached file exists, the XML
    is downloaded, retrying every 2 seconds until it succeeds.

    Parameters:
        nct_id (str): NCT identifier of the trial.
        runs (int): Number of check/download passes to perform (default 2,
            so a freshly downloaded file is immediately re-checked).

    Returns:
        list: NCT ids whose cached XML was updated (either [] or [nct_id]).
    """
    local_file_path = f"../data/trials_xmls/{nct_id}.xml"
    url = f"https://clinicaltrials.gov/ct2/show/{nct_id}?displayxml=true"
    updated_cts = []
    for _ in range(runs):
        if os.path.exists(local_file_path):
            # Read and parse the cached copy; drop it if it is corrupt so the
            # next pass re-downloads a fresh one.
            with open(local_file_path, "r") as f:
                local_xml_content = f.read()
            try:
                local_root = ET.fromstring(local_xml_content)
            except ET.ParseError as e:
                print(f"Error parsing XML for trial {nct_id}: {e}")
                os.remove(local_file_path)
                continue

            # Download the online version of the XML.
            # BUG FIX: added a timeout so a stalled connection cannot hang forever.
            response = requests.get(url, timeout=30)

            if response.status_code == 200:
                # BUG FIX: the online payload was parsed unguarded; a malformed
                # response used to raise an uncaught ParseError.
                try:
                    online_root = ET.fromstring(response.text)
                except ET.ParseError as e:
                    print(f"Error parsing XML for trial {nct_id}: {e}")
                    continue

                # Compare only the fields that matter for trial matching.
                to_check = ["eligibility", "brief_title", "overall_status", "location"]
                pairs = []
                for tag in to_check:
                    local_elem = local_root.find(f".//{tag}")
                    online_elem = online_root.find(f".//{tag}")
                    # Only compare fields present in both versions.
                    if local_elem is not None and online_elem is not None:
                        pairs.append((local_elem, online_elem))

                is_updated = any(
                    normalize_whitespace(ET.tostring(a, encoding='unicode').strip()) !=
                    normalize_whitespace(ET.tostring(b, encoding='unicode').strip())
                    for a, b in pairs
                )

                if is_updated:
                    updated_cts.append(nct_id)
                    # Overwrite the cache with the online version.
                    with open(local_file_path, "w") as f:
                        f.write(ET.tostring(online_root, encoding='unicode'))
                    print(f"Updated eligibility criteria for {nct_id}")
                else:
                    print(f"No changes in eligibility criteria for {nct_id}.")
            else:
                print(f"Error downloading study information for {nct_id}")
        else:
            # No cached file yet: download with retry until it succeeds.
            downloaded = False
            while not downloaded:
                # BUG FIX: a network error (requests.RequestException) used to
                # escape the retry loop; it is now treated as a failed attempt.
                try:
                    response = requests.get(url, timeout=30)
                except requests.RequestException:
                    response = None
                if response is not None and response.status_code == 200:
                    root = ET.fromstring(response.text)
                    with open(local_file_path, "w") as f:
                        f.write(ET.tostring(root, encoding='unicode'))
                    downloaded = True
                    print(f"Study information downloaded for {nct_id}")
                else:
                    print(f"Error downloading study information for {nct_id}")

                if not downloaded:
                    print(f'Download of {nct_id}.xml failed. Retrying in 2 seconds...')
                    time.sleep(2)
    return updated_cts
def _element_text(node, path):
    """Return the stripped text of the first element at *path* under *node*, or None."""
    elem = node.find(path)
    # Guard both a missing element and an element with empty text; the
    # original code crashed with AttributeError on `.text.strip()` when
    # an element was present but empty.
    if elem is not None and elem.text:
        return elem.text.strip()
    return None


def _write_section(f, label, value):
    """Write a ``Label:\\nvalue\\n\\n`` section to *f* when *value* is not None."""
    if value is not None:
        f.write(f"{label}:\n{value}\n\n")


def extract_study_info(nct_id):
    """Extract study information from a trial's XML and save it as text.

    Reads ``../data/trials_xmls/{nct_id}.xml`` and writes a human-readable
    summary to ``../data/trials_xmls/{nct_id}_info.txt``.  If the info file
    already exists the extraction is skipped.

    The extracted fields are: long/short title, cancer sites, start/end and
    primary end dates, overall status, study phase, study type, brief
    summary, detailed description, number of arms and arm descriptions,
    eligibility criteria, gender, minimum/maximum age, interventions, and
    locations (city, country).  Fields absent from the XML are omitted.

    Parameters:
        nct_id (str): The unique identifier (NCT ID) of the clinical trial.

    Returns:
        int or None: 0 if the info file already exists (nothing done),
        otherwise None after writing the file.
    """
    info_path = f"../data/trials_xmls/{nct_id}_info.txt"
    if os.path.exists(info_path):
        return 0

    tree = ET.parse(f"../data/trials_xmls/{nct_id}.xml")
    root = tree.getroot()
    with open(info_path, "w") as f:
        _write_section(f, "Long Title", _element_text(root, ".//official_title"))
        _write_section(f, "Short Title", _element_text(root, ".//brief_title"))

        # List sections: `findall` returns a list (never None), so the
        # original `is not None` checks always passed and wrote empty
        # headers; headers are now written only when entries exist.
        conditions = root.findall(".//condition")
        if conditions:
            f.write("Cancer Site(s):\n")
            for condition in conditions:
                text = condition.text.strip() if condition.text else ""
                f.write(f"- {text}\n")
            f.write("\n")

        _write_section(f, "Start Date", _element_text(root, ".//start_date"))
        _write_section(f, "End Date", _element_text(root, ".//completion_date"))
        # BUG FIX: the original looked up primary_completion_date but then
        # re-tested and re-wrote the overall completion date here.
        _write_section(f, "Primary End Date",
                       _element_text(root, ".//primary_completion_date"))
        _write_section(f, "Overall Status", _element_text(root, ".//overall_status"))
        # NOTE: label normalized from "Study Phase: " (trailing space) to
        # match every other section label.
        _write_section(f, "Study Phase", _element_text(root, ".//phase"))
        _write_section(f, "Study Type", _element_text(root, ".//study_type"))

        brief_summary = root.find(".//brief_summary")
        if brief_summary is not None:
            _write_section(f, "Brief Summary",
                           _element_text(brief_summary, ".//textblock"))

        detailed_description = root.find(".//detailed_description")
        if detailed_description is not None:
            _write_section(f, "Detailed Description",
                           _element_text(detailed_description, ".//textblock"))

        number_of_arms = _element_text(root, ".//number_of_arms")
        if number_of_arms is not None:
            f.write(f"Number of Arms: {number_of_arms}\n\n")

        arms = root.findall(".//arm_group")
        if arms:
            f.write("Arms:\n")
            for arm in arms:
                label = _element_text(arm, ".//arm_group_label")
                description = _element_text(arm, ".//arm_group_description")
                if description is not None:
                    f.write(f"- {label}: {description}\n")
                else:
                    f.write(f"- {label}\n")
            f.write("\n")

        criteria = root.find(".//eligibility/criteria")
        if criteria is not None:
            _write_section(f, "Eligibility Criteria",
                           _element_text(criteria, ".//textblock"))

        _write_section(f, "Gender", _element_text(root, ".//gender"))
        _write_section(f, "Minimum Age", _element_text(root, ".//eligibility/minimum_age"))
        _write_section(f, "Maximum Age", _element_text(root, ".//eligibility/maximum_age"))

        interventions = root.findall(".//intervention")
        if interventions:
            f.write("Interventions:\n")
            for intervention in interventions:
                name = _element_text(intervention, ".//intervention_name")
                f.write(f"- {name}\n")
            f.write("\n")

        locations = root.findall(".//location")
        if locations:
            f.write("Locations:\n")
            for location in locations:
                city = _element_text(location, ".//city")
                country = _element_text(location, ".//country")
                if city is not None and country is not None:
                    f.write(f"- {city}, {country}\n")
            f.write("\n")

    print(f"{nct_id} info extracted and saved to {nct_id}_info.txt")
def add_spaces_around_punctuation(text):
    """
    Surround each occurrence of . , ! ? ( ) with single spaces.

    Parameters
    ----------
    text : str
        The text to be preprocessed

    Returns
    -------
    str
        The preprocessed text
    """
    pattern = re.compile(r'([.,!?()])')
    return pattern.sub(r' \1 ', text)
def remove_special_characters(text):
    """
    Replace every non-alphanumeric character with a space.

    Parameters
    ----------
    text : str
        The text to be preprocessed

    Returns
    -------
    str
        The preprocessed text
    """
    cleaned = re.sub(r'[^a-zA-Z0-9]', ' ', text)
    return cleaned
def remove_dashes_at_the_start_of_sentences(text):
    """
    Strip a leading "- " bullet marker from the text.

    NOTE(review): despite the name, only the very start of the string is
    handled (the pattern is anchored with ^ and no re.MULTILINE flag), so
    dashes opening later lines are kept — presumably callers pass one
    line/sentence at a time; confirm before changing.

    Parameters
    ----------
    text : str
        The text to be preprocessed

    Returns
    -------
    str
        The preprocessed text
    """
    return re.sub(r'^- ', '', text)
def post_process_entities(entities):
    """
    Merge consecutive BIO-tagged token entities into whole-entity spans.

    Each input dict (as produced by a token-level NER pipeline) carries:
    - "entity" (str): prefixed tag, e.g. "B-ORG" or "I-LOC".
    - "score" (float): model confidence for the token.
    - "word" (str): token text; WordPiece continuations start with "##".
    - "start" (int): start index in the source text.
    - "end" (int): end index (exclusive) in the source text.

    A "B-" token opens a new entity; following "I-" tokens are merged into
    it ("##" continuations are glued without a space, other tokens joined
    with one space).  The merged entity keeps the un-prefixed type, the
    maximum score of its pieces, and the overall start/end span.

    Parameters:
        entities (list): Token-level entity dicts from the NER model.

    Returns:
        list: Merged entity dicts with keys "entity", "score", "word",
            "start" and "end".
    """
    merged_entities = []
    current_entity = None

    for entity in entities:
        tag = entity["entity"]
        if tag.startswith("B-"):
            # A new entity begins; flush any entity still being built.
            if current_entity is not None:
                merged_entities.append(current_entity)
            current_entity = {
                "entity": tag[2:],
                "score": entity["score"],
                "word": entity["word"].replace("##", " "),
                "start": entity["start"],
                "end": entity["end"],
            }
        elif tag.startswith("I-"):
            # BUG FIX: a stray "I-" token with no preceding "B-" used to
            # dereference current_entity (None) and crash; skip it instead.
            if current_entity is None:
                continue
            if entity["word"].startswith("##"):
                # WordPiece continuation: glue without a space.
                current_entity["word"] += entity["word"].replace("##", "")
            else:
                current_entity["word"] += " " + entity["word"].lstrip()
            current_entity["end"] = entity["end"]
            current_entity["score"] = max(current_entity["score"], entity["score"])
        else:
            # Non-entity token ("O" etc.) terminates the current entity.
            if current_entity is not None:
                merged_entities.append(current_entity)
                current_entity = None

    if current_entity is not None:
        merged_entities.append(current_entity)

    return merged_entities
def get_dictionaries_with_values(list_of_dicts, key, values):
    """
    Filter a list of dictionaries on the presence of values under a key.

    A dictionary is kept when any of *values* matches its entry for *key*:
    for container entries (lists, sets, strings, ...) a membership test is
    used; for scalar entries (int, float, ...) plain equality is used.
    Dictionaries missing the key are dropped.

    BUG FIX: the original used only ``val in d.get(key, [])``, which raised
    TypeError for scalar entries — its own docstring example ("age": 30)
    crashed.  Scalars now fall back to equality; container behavior is
    unchanged.

    Parameters:
        list_of_dicts (list): Dictionaries to filter.
        key (str): Key whose value is inspected.
        values (list): Candidate values to match against.

    Returns:
        list: Dictionaries that match at least one value.

    Example:
        list_of_dicts = [
            {"name": "Alice", "age": 30},
            {"name": "Bob", "age": 25},
            {"name": "Charlie", "age": 35},
            {"name": "David", "age": 30},
        ]

        get_dictionaries_with_values(list_of_dicts, "age", [30, 35])
        # Output: [
        #   {"name": "Alice", "age": 30},
        #   {"name": "Charlie", "age": 35},
        #   {"name": "David", "age": 30}
        # ]
    """
    def matches(d):
        field = d.get(key, [])
        for val in values:
            try:
                if val in field:
                    return True
            except TypeError:
                # Non-container entry: compare directly.
                if val == field:
                    return True
        return False

    return [d for d in list_of_dicts if matches(d)]
def _spans_overlap(a_start, a_end, b_start, b_end):
    """True when half-open spans [a_start, a_end) and [b_start, b_end) intersect."""
    return a_start < b_end and a_end > b_start


def resolve_ner_overlaps(ner1_results, ner2_results):
    """
    Resolve overlaps between entities detected by two NER models.

    An overlap occurs when the character span of one entity intersects the
    span of another.  First, every entity from the first model that does
    not overlap any entity from the second model is kept.  Then entities
    from the second model are added one by one, skipping any that overlap
    an already-kept entity (including second-model entities kept earlier
    in this pass).

    Each entity dict must carry "start" and "end" keys; all other keys
    (e.g. "entity_group") are passed through untouched.

    Parameters:
        ner1_results (list): Entity dicts from the first NER model.
        ner2_results (list): Entity dicts from the second NER model.

    Returns:
        list: The resolved entities with overlaps removed.

    Example:
        ner1_results = [
            {"start": 5, "end": 10, "entity_group": "PERSON"},
            {"start": 20, "end": 25, "entity_group": "LOCATION"}
        ]

        ner2_results = [
            {"start": 8, "end": 15, "entity_group": "PERSON"},
            {"start": 18, "end": 30, "entity_group": "ORGANIZATION"}
        ]

        resolve_ner_overlaps(ner1_results, ner2_results)
        # Both ner1 entities overlap a ner2 entity and are dropped:
        # Output: [
        #   {"start": 8, "end": 15, "entity_group": "PERSON"},
        #   {"start": 18, "end": 30, "entity_group": "ORGANIZATION"}
        # ]
    """
    # NOTE: unused `entity_group` locals from the original were removed;
    # only the spans participate in the resolution.
    resolved_results = []

    # Keep first-model entities that clash with nothing from the second model.
    for entity1 in ner1_results:
        clashes = any(
            _spans_overlap(entity1['start'], entity1['end'],
                           entity2['start'], entity2['end'])
            for entity2 in ner2_results
        )
        if not clashes:
            resolved_results.append(entity1)

    # Add second-model entities that clash with nothing already kept
    # (the check deliberately includes entities appended in this loop).
    for entity2 in ner2_results:
        clashes = any(
            _spans_overlap(entity2['start'], entity2['end'],
                           kept['start'], kept['end'])
            for kept in resolved_results
        )
        if not clashes:
            resolved_results.append(entity2)

    return resolved_results
def extract_eligibility_criteria(trial_id):
    """
    Extract the eligibility criteria text for a clinical trial.

    Reads ``../data/trials_xmls/{trial_id}.xml``, locates the
    ``eligibility/criteria/textblock`` element, and returns its stripped
    text.

    Parameters:
        trial_id (str): The unique identifier of the clinical trial.

    Returns:
        str or None: The eligibility criteria text if the file exists,
            parses, and contains a non-empty textblock; otherwise None.
    """
    xml_file_path = f'../data/trials_xmls/{trial_id}.xml'
    if not os.path.exists(xml_file_path):
        return None

    with open(xml_file_path, 'r') as xml_file:
        xml_content = xml_file.read()
    try:
        # Simplified: ET.fromstring already yields the root element; the
        # original wrapped it in a redundant ElementTree.
        root = ET.fromstring(xml_content)
    except ET.ParseError as e:
        print(f"Error parsing XML for trial {trial_id}: {e}")
        return None

    # Find the eligibility criteria textblock within the XML.
    textblock = root.find(".//eligibility/criteria/textblock")
    # BUG FIX: guard against an element with no text; the original crashed
    # with AttributeError on `.strip()` when `.text` was None.
    if textblock is not None and textblock.text is not None:
        return textblock.text.strip()

    # Trial not found or textblock missing/empty.
    return None
def replace_parentheses_with_braces(text):
    """
    Replace parentheses and square brackets with curly braces.

    Every opening '(' or '[' becomes '{'.  A closing ')' or ']' becomes '}'
    only when an unmatched opener precedes it; otherwise it is left
    unchanged.  (The docstring of the original mentioned only parentheses,
    but the code has always converted square brackets as well.)

    Parameters:
        text (str): The input text containing delimiters to be replaced.

    Returns:
        str: The modified text with matched delimiters replaced by braces.
    """
    # The original kept a stack whose contents were never inspected, so a
    # plain depth counter is equivalent.  Output is accumulated in a list
    # and joined once, avoiding quadratic `str +=` concatenation.
    depth = 0
    out = []
    for char in text:
        if char in '([':
            depth += 1
            out.append('{')
        elif char in ')]':
            if depth:
                depth -= 1
                out.append('}')
            else:
                # Unmatched closer: keep as-is.
                out.append(char)
        else:
            out.append(char)
    return ''.join(out)
def line_starts_with_capitalized_alphanumeric(line):
    """
    Check if the given line starts with a capitalized word.

    NOTE(review): despite "alphanumeric" in the name, the check requires
    the first character to be an uppercase *letter*; a line starting with
    a digit returns False — confirm against callers before renaming.

    Parameters:
        line (str): The input string representing a line.

    Returns:
        bool: True if the first whitespace-separated word begins with an
            uppercase letter, False otherwise (including empty/blank lines).
    """
    words = line.split()
    if not words:
        return False
    first_char = words[0][0]
    return first_char.isalpha() and first_char.isupper()