# tests/test_download.py
1
import os
2
import unittest
3
from unittest.mock import patch, MagicMock, mock_open, call
4
import xml.etree.ElementTree as ET
5
from download import download_study_info, get_cancer_trials_list
6
7
8
class GetCancerTrialsListTestCase(unittest.TestCase):
    """Tests for ``download.get_cancer_trials_list``."""

    @patch('requests.get')
    def test_get_cancer_trials_list(self, mock_get):
        """The search request is issued as expected and every NCT ID in the reply is returned."""
        expected_ids = ["NCT12345678", "NCT87654321"]

        # Fake a successful JSON reply containing one study per expected NCT ID,
        # nested in the shape this test expects the function to parse.
        mock_response = mock_get.return_value
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "FullStudiesResponse": {
                "FullStudies": [
                    {
                        "Study": {
                            "ProtocolSection": {
                                "IdentificationModule": {"NCTId": nct_id}
                            }
                        }
                    }
                    for nct_id in expected_ids
                ]
            }
        }

        result = get_cancer_trials_list(max_trials=2)

        # NOTE(review): max_rnk is 100 although max_trials=2; presumably the function
        # requests a fixed-size page and trims the result afterwards — confirm against
        # download.py.
        mock_get.assert_called_once_with(
            "https://clinicaltrials.gov/api/query/full_studies",
            params={
                "expr": "((cancer) OR (neoplasm)) AND ((interventional) OR (treatment)) AND ((mutation) OR (variant))",
                "min_rnk": 1,
                "max_rnk": 100,
                "fmt": "json",
                "fields": "NCTId",
            },
        )

        # The order of the returned IDs is not part of the contract, so compare
        # the two collections as multisets.
        self.assertCountEqual(result, expected_ids)
56
57
58
59
class DownloadStudyInfoTestCase(unittest.TestCase):
    """Tests for ``download.download_study_info``."""

    @patch('os.path.exists')
    @patch('requests.get')
    def test_download_study_info(self, mock_get, mock_exists):
        """A study with no cached file is fetched, written out verbatim, and yields ``[]``."""
        new_xml = "<root><eligibility>Age less than 18</eligibility></root>"

        mock_response = mock_get.return_value
        mock_response.status_code = 200
        mock_response.text = new_xml

        # Pretend no cached XML file exists on disk yet.
        mock_exists.return_value = False

        # mock_open hands back the same file handle on every open() call, so the
        # writes can be inspected through ``mock_file.return_value`` afterwards.
        mock_file = mock_open()
        with patch('builtins.open', mock_file, create=True):
            result = download_study_info("NCT000000000")

        mock_get.assert_called_once_with("https://clinicaltrials.gov/ct2/show/NCT000000000?displayxml=true")
        self.assertEqual(result, [])
        # Use return_value instead of calling mock_file() again, which would record a
        # spurious extra open() call on the mock.
        mock_file.return_value.write.assert_called_once_with(new_xml)

    @patch('os.path.exists')
    @patch('requests.get')
    def test_download_study_info_updates_file(self, mock_get, mock_exists):
        """An existing cached XML file is overwritten with the freshly downloaded text."""
        old_xml = "<root><eligibility>Age less than 18</eligibility></root>"
        new_xml = "<root><eligibility>Age more than 18</eligibility></root>"
        expected_path = "../data/trials_xmls/NCT000000000.xml"

        mock_response = mock_get.return_value
        mock_response.status_code = 200
        mock_response.text = new_xml

        # Pretend the cached XML file is already on disk.
        mock_exists.return_value = True

        # One mock file handle per opened path, pre-loaded with the stale XML so
        # any read returns the old content.
        mock_files = {}

        def fake_open(file_path, mode='r', *args, **kwargs):
            # Accept open()'s optional arguments (mode, encoding, ...) so the test
            # does not depend on exactly how download_study_info spells its open() calls.
            if file_path not in mock_files:
                mock_files[file_path] = mock_open(read_data=old_xml)()
            return mock_files[file_path]

        with patch('builtins.open', side_effect=fake_open):
            download_study_info("NCT000000000")

        # Fail with a readable assertion (not a bare KeyError) if the expected
        # file was never opened.
        self.assertIn(expected_path, mock_files)
        mock_files[expected_path].write.assert_called_once_with(new_xml)
112
        
113
if __name__ == '__main__':
    # Running this file directly executes every TestCase defined in this module.
    unittest.main()