[de07e6]: / tests / test_download.py

Download this file

114 lines (94 with data), 4.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import unittest
from unittest.mock import patch, MagicMock, mock_open, call
import xml.etree.ElementTree as ET
from download import download_study_info, get_cancer_trials_list
class GetCancerTrialsListTestCase(unittest.TestCase):
    """Tests for ``get_cancer_trials_list`` with ``requests.get`` stubbed out."""

    @patch('requests.get')
    def test_get_cancer_trials_list(self, mock_get):
        """The search is issued once with the expected params and the NCT IDs are returned."""
        expected_ids = ["NCT12345678", "NCT87654321"]

        # Stub a successful JSON reply holding one FullStudies entry per NCT ID.
        fake_reply = mock_get.return_value
        fake_reply.status_code = 200
        fake_reply.json.return_value = {
            "FullStudiesResponse": {
                "FullStudies": [
                    {
                        "Study": {
                            "ProtocolSection": {
                                "IdentificationModule": {"NCTId": nct_id}
                            }
                        }
                    }
                    for nct_id in expected_ids
                ]
            }
        }

        result = get_cancer_trials_list(max_trials=2)

        # NOTE(review): max_trials=2 but max_rnk=100 is expected -- presumably the
        # function requests a fixed page size and trims locally; confirm against download.py.
        expected_params = {
            "expr": "((cancer) OR (neoplasm)) AND ((interventional) OR (treatment)) AND ((mutation) OR (variant))",
            "min_rnk": 1,
            "max_rnk": 100,
            "fmt": "json",
            "fields": "NCTId",
        }
        mock_get.assert_called_once_with(
            "https://clinicaltrials.gov/api/query/full_studies",
            params=expected_params,
        )

        # Both sides are sorted, so the order of the returned IDs is not checked.
        self.assertEqual(sorted(result), sorted(expected_ids))
class DownloadStudyInfoTestCase(unittest.TestCase):
    """Tests for ``download_study_info`` with network and filesystem access stubbed out."""

    @patch('os.path.exists')
    @patch('requests.get')
    def test_download_study_info(self, mock_get, mock_exists):
        """With no existing file, the fetched XML is written and ``[]`` is returned."""
        # Stub a successful HTTP reply carrying the study XML.
        mock_response = mock_get.return_value
        mock_response.status_code = 200
        mock_response.text = "<root><eligibility>Age less than 18</eligibility></root>"
        # Pretend the local XML file does not exist yet.
        mock_exists.return_value = False

        # mock_open hands out a single handle that records every write made on it.
        mock_file = mock_open()
        with patch('builtins.open', mock_file, create=True):
            result = download_study_info("NCT000000000")

        mock_get.assert_called_once_with("https://clinicaltrials.gov/ct2/show/NCT000000000?displayxml=true")
        self.assertEqual(result, [])
        # Reach the handle through return_value instead of calling mock_file() again,
        # which would record a spurious extra open() call on the mock.
        mock_file.return_value.write.assert_called_once_with(
            "<root><eligibility>Age less than 18</eligibility></root>"
        )

    @patch('os.path.exists')
    @patch('requests.get')
    def test_download_study_info_updates_file(self, mock_get, mock_exists):
        """With an existing file holding older XML, the new XML is written to that file."""
        expected_path = "../data/trials_xmls/NCT000000000.xml"

        # Stub a successful HTTP reply carrying the newer study XML.
        mock_response = mock_get.return_value
        mock_response.status_code = 200
        mock_response.text = "<root><eligibility>Age more than 18</eligibility></root>"
        # Pretend the local XML file already exists.
        mock_exists.return_value = True

        # One mock handle per path, so repeated opens of the same file share one handle.
        mock_files = {}

        def fake_open(file_path, *args, **kwargs):
            # Accept the full open() signature (optional mode, encoding=..., etc.) so the
            # stub does not raise TypeError for open(path) or keyword arguments.
            # os.fspath lets str and os.PathLike paths map to the same key.
            key = os.fspath(file_path)
            if key not in mock_files:
                mock_files[key] = mock_open(
                    read_data="<root><eligibility>Age less than 18</eligibility></root>"
                )()
            return mock_files[key]

        with patch('builtins.open', side_effect=fake_open):
            download_study_info("NCT000000000")

        # Report a clear failure (rather than a KeyError) if the file was never opened.
        self.assertIn(
            expected_path,
            mock_files,
            "download_study_info did not open the expected study XML file",
        )
        mock_files[expected_path].write.assert_called_once_with(
            "<root><eligibility>Age more than 18</eligibility></root>"
        )
if __name__ == "__main__":
    # Run every TestCase in this module when the file is executed as a script.
    unittest.main()