b/tests/test_model_executor.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Main test script to test the primary functions/classes/methods."""
# run with: python -m pytest tests/test_model_executor.py

import glob
import logging
import os
import shutil
import sys

import pytest
import torch

# import unittest


# Set the logging level depending on the level of detail you would like to have in the logs while running the tests.
LOGGING_LEVEL = logging.INFO  # e.g. logging.WARNING for less verbose logs

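# Each entry is a tuple of (model_id, additional kwargs passed to generate(), expected number of generated samples).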
models_with_args = [
    (
        "00001_DCGAN_MMG_CALC_ROI",
        {},
        100,
    ),  # 100 samples to test automatic batch-wise image generation in model_executor
    (
        "00002",
        {},
        3,
    ),  # "00002" instead of "00002_DCGAN_MMG_MASS_ROI" to test shortcut model_ids
    (
        "03",
        {"translate_all_images": False},
        2,
    ),  # "03" instead of "00003_CYCLEGAN_MMG_DENSITY_FULL" to test shortcut model_ids
    (
        4,  # 4 instead of "00004_PIX2PIX_MMG_MASSES_W_MASKS" to test shortcut model_ids
        {
            "shapes": ["oval"],
            "ssim_threshold": 0.18,
            "image_size": [128, 128],
            "patch_size": [30, 30],
        },
        3,
    ),
    ("00005_DCGAN_MMG_MASS_ROI", {}, 3),
    ("00006_WGANGP_MMG_MASS_ROI", {}, 3),
    (
        "00007_INPAINT_BRAIN_MRI",
        {
            "image_size": (256, 256),
            "num_inpaints_per_sample": 2,
            "randomize_input_image_order": False,
            "add_variations_to_mask": False,
            "x_center": 120,
            "y_center": 140,
            "radius_1": 8,
            "radius_2": 12,
            "radius_3": 24,
        },
        3,
    ),
    (
        "00008_C-DCGAN_MMG_MASSES",
        {"condition": 0, "is_cbisddsm_training_data": False},
        3,
    ),
    ("00009_PGGAN_POLYP_PATCHES_W_MASKS", {"save_option": "image_only"}, 3),
    ("00010_FASTGAN_POLYP_PATCHES_W_MASKS", {"save_option": "image_only"}, 3),
    # ("00011_SINGAN_POLYP_PATCHES_W_MASKS", {"checkpoint_ids": [999]}, 3),  # removed after successful testing due to limited CI pipeline capacity
    # ("00012_C-DCGAN_MMG_MASSES", {"condition": 0}, 3),  # removed after successful testing due to limited CI pipeline capacity
    # ("00013_CYCLEGAN_MMG_DENSITY_OPTIMAM_MLO", {"translate_all_images": False}, 2),  # removed after successful testing due to limited CI pipeline capacity
    # ("00014_CYCLEGAN_MMG_DENSITY_OPTIMAM_CC", {"translate_all_images": False}, 2),  # removed after successful testing due to limited CI pipeline capacity
    # ("00015_CYCLEGAN_MMG_DENSITY_CSAW_MLO", {"translate_all_images": False}, 2),  # removed after successful testing due to limited CI pipeline capacity
    # ("00016_CYCLEGAN_MMG_DENSITY_CSAW_CC", {"translate_all_images": False}, 2),  # removed after successful testing due to limited CI pipeline capacity
    ("00017_DCGAN_XRAY_LUNG_NODULES", {}, 3),
    ("00018_WGANGP_XRAY_LUNG_NODULES", {}, 3),
    ("00019_PGGAN_CHEST_XRAY", {}, 3),
    ("00020_PGGAN_CHEST_XRAY", {"resize_pixel_dim": 512, "image_size": 256}, 3),
    (
        "00021_CYCLEGAN_BRAIN_MRI_T1_T2",
        {
            "input_path": "models/00021_CYCLEGAN_Brain_MRI_T1_T2/inputs/T2",
            "gpu_id": 0,
            "T1_to_T2": False,
        },
        3,
    ),
    ("00022_WGAN_CARDIAC_AGING", {}, 3),
    (
        "00023_PIX2PIXHD_BREAST_DCEMRI",
        {
            "input_path": "input",
            "gpu_id": 0,
            "image_size": 448,
        },
        3,
    ),
]


# class TestMediganExecutorMethods(unittest.TestCase):
class TestMediganExecutorMethods:
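    # setup_class runs once before the tests in this class: it configures logging, creates the
    # shared Generators instance, and removes leftover test output. Note that pytest passes the
    # class itself (not an instance) to setup_class, hence the explicit `self` arguments below.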
    def setup_class(self):
        # Logger config:
        # This logger, initialized on root level via logging.getLogger(), will also log all events
        # from the medigan library. Pass a logger name (e.g. __name__) instead if you only want logs from this test module.
        self.logger = logging.getLogger()  # (__name__)
        self.logger.setLevel(LOGGING_LEVEL)
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setLevel(LOGGING_LEVEL)
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        stream_handler.setFormatter(formatter)
        self.logger.addHandler(stream_handler)

        self.test_output_path = "test_output_path"
        self.num_samples = 2
        self.test_imports_and_init_generators(self)
        self._remove_dir_and_contents(self)  # in case something is left there.
        self.model_ids = self.generators.config_manager.model_ids

    def test_imports_and_init_generators(self):
        from src.medigan.constants import (
            CONFIG_FILE_KEY_EXECUTION,
            CONFIG_FILE_KEY_GENERATE,
            CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE,
        )
        from src.medigan.generators import Generators

        self.generators = Generators()
        self.CONFIG_FILE_KEY_EXECUTION = CONFIG_FILE_KEY_EXECUTION
        self.CONFIG_FILE_KEY_GENERATE = CONFIG_FILE_KEY_GENERATE
        self.CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE = (
            CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE
        )

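    # Driver test: iterates over all models registered in the config and calls the
    # @pytest.mark.skip-decorated helper tests below for each model. The whole models_with_args
    # list is passed as a single parametrize value, so this test runs once and loops itself.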
    @pytest.mark.parametrize("models_with_args", [models_with_args])
    def test_sample_generation_methods(self, models_with_args: list):
        self.logger.debug(f"models: {models_with_args}")
        for i, model_id in enumerate(self.model_ids):
            # if (
            #     model_id != "00011_SINGAN_POLYP_PATCHES_W_MASKS"
            # ):
            #     # avoiding full memory on Windows CI test server
            #     continue
            self.logger.debug(f"Now testing model {model_id}")
            self._remove_dir_and_contents()  # Already done in each test independently, but to be sure, here again.
            self.test_generate_method(model_id=model_id)

            # Check if args are available for model_id. Note: The models list may not include the latest medigan models.
            for model in models_with_args:
                if model_id == model[0]:
                    self.test_generate_method_with_additional_args(
                        model_id=model[0], args=model[1], expected_num_samples=model[2]
                    )
            self.test_get_generate_method(model_id=model_id)
            self.test_get_dataloader_method(model_id=model_id)

            # if i == 16:  # just for local testing
            #     self._remove_model_dir_and_zip(
            #         model_ids=[model_id], are_all_models_deleted=False
            #     )

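    # The search values below are matched against model metadata keys and values
    # (case-insensitive, AND-combined). The last case ("modalities") is expected to match
    # no model, so no samples should be generated for it.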
    @pytest.mark.parametrize(
        "values_list, should_sample_be_generated",
        [
            (["dcgan", "mMg", "ClF", "modality", "inbreast"], True),
            (["dcgan", "mMg", "ClF", "modality", "optimam"], True),
            (["dcgan", "mMg", "ClF", "modalities"], False),
        ],
    )
    def test_find_model_and_generate_method(
        self, values_list, should_sample_be_generated
    ):
        self._remove_dir_and_contents()

        self.generators.find_model_and_generate(
            values=values_list,
            target_values_operator="AND",
            are_keys_also_matched=True,
            is_case_sensitive=False,
            num_samples=self.num_samples,
            output_path=self.test_output_path,
        )

        self._check_if_samples_were_generated(
            should_sample_be_generated=should_sample_be_generated
        )

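    # Models matching the search values are ranked by the given performance metric
    # (order="asc") before samples are generated; see the TODO below on metric availability.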
    @pytest.mark.parametrize(
        "values_list, metric",
        [
            (["dcgan", "MMG"], "CLF.trained_on_real_and_fake.f1"),
            (["dcgan", "MMG"], "turing_test.AUC"),
        ],
    )
    def test_find_and_rank_models_then_generate_method(self, values_list, metric):
        self._remove_dir_and_contents()
        # TODO: This test needs the respective metrics for these models to be available in config/global.json.
        # The search values need to match at least two models.
        self.generators.find_models_rank_and_generate(
            values=values_list,
            target_values_operator="AND",
            are_keys_also_matched=True,
            is_case_sensitive=False,
            metric=metric,
            order="asc",
            num_samples=self.num_samples,
            output_path=self.test_output_path,
        )
        self._check_if_samples_were_generated()

    # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args])
    @pytest.mark.skip
    def test_generate_method(self, model_id):
        self._remove_dir_and_contents()
        self.generators.generate(
            model_id=model_id,
            num_samples=self.num_samples,
            output_path=self.test_output_path,
            install_dependencies=True,
        )
        self._check_if_samples_were_generated(model_id=model_id)

    # @pytest.mark.parametrize("model_id, args, expected_num_samples", models_with_args)
    @pytest.mark.skip
    def test_generate_method_with_additional_args(
        self, model_id, args, expected_num_samples
    ):
        self._remove_dir_and_contents()
        self.generators.generate(
            model_id=model_id,
            num_samples=expected_num_samples,
            output_path=self.test_output_path,
            **args,
        )
        self._check_if_samples_were_generated(
            model_id=model_id, num_samples=expected_num_samples
        )

    # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args])
    @pytest.mark.skip
    def test_get_generate_method(self, model_id):
        self._remove_dir_and_contents()
        gen_function = self.generators.get_generate_function(
            model_id=model_id,
            num_samples=self.num_samples,
            output_path=self.test_output_path,
        )
        gen_function()
        self._check_if_samples_were_generated(model_id=model_id)
        del gen_function

    # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args])
    @pytest.mark.skip
    def test_get_dataloader_method(self, model_id):
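        # Each item yielded by the dataloader is expected to be a dict with the keys
        # "sample", "mask", "other_imaging_output", and "label" (see assertions below).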
        self._remove_dir_and_contents()
        data_loader = self.generators.get_as_torch_dataloader(
            model_id=model_id, num_samples=self.num_samples
        )
        self.logger.debug(f"{model_id}: len(data_loader): {len(data_loader)}")

        if len(data_loader) != self.num_samples:
            logging.warning(
                f"{model_id}: The number of samples in the dataloader (={len(data_loader)}) does not equal the number of samples requested (={self.num_samples}). "
                f"Hint: Check whether the model's internal generate() function returned tuples as required by get_as_torch_dataloader()."
            )

        # Get the first item from the dataloader.
        data_dict = next(iter(data_loader))

        # The "sample" entry must be a torch tensor (e.g. of dtype torch.uint8) and not None, as required by the data structure design.
        assert torch.is_tensor(data_dict.get("sample"))

        # The "mask" and "other_imaging_output" entries may be None; if present, they must be torch tensors.
        assert data_dict.get("mask") is None or torch.is_tensor(data_dict.get("mask"))
        assert data_dict.get("other_imaging_output") is None or torch.is_tensor(
            data_dict.get("other_imaging_output")
        )

        # The "label" entry may be None; if present, it must be a list of strings.
        assert data_dict.get("label") is None or (
            isinstance(data_dict.get("label"), list)
            and isinstance(data_dict.get("label")[0], str)
        )
        del data_dict
        del data_loader

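    # visualize() is only supported for models whose generate() config defines an input latent
    # vector size; for all other models a ValueError is expected.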
    # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args])
    @pytest.mark.skip
    def test_visualize_method(self, model_id):
        if (
            self.CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE
            in self.generators.config_manager.config_dict[model_id][
                self.CONFIG_FILE_KEY_EXECUTION
            ][self.CONFIG_FILE_KEY_GENERATE]
        ):
            self.generators.visualize(model_id, auto_close=True)

        else:
            with pytest.raises(Exception) as e:
                self.generators.visualize(model_id, auto_close=True)

            assert e.type == ValueError

    @pytest.mark.skip
    def _check_if_samples_were_generated(
        self, model_id=None, num_samples=None, should_sample_be_generated: bool = True
    ):
        # Check if the number of samples generated for model_id is as expected.
        file_list = glob.glob(self.test_output_path + "/*")
        if num_samples is None:
            num_samples = self.num_samples
        self.logger.debug(f"{model_id}: {len(file_list)} == {num_samples} ?")

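        # Different models write a different number of files per requested sample
        # (e.g., masks, multiple inpaints, or label-balanced outputs), so several totals are accepted.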
        if should_sample_be_generated:
            assert (
                len(file_list) == num_samples
                # 00007_INPAINT_BRAIN_MRI: 2 inpaints per sample, 6 outputs per sample
                or len(file_list) == num_samples * 2 * 6
                # Temporary fix for different outputs per model.
                or len(file_list) == num_samples * 2
                # Some models are balanced per label by default: if num_samples is odd,
                # then len(file_list) == num_samples + 1.
                or len(file_list) == num_samples + 1
            ), (
                f"Model {model_id} generated {len(file_list)} samples instead of the expected "
                f"{num_samples}, {num_samples * 2}, {num_samples * 2 * 6}, or {num_samples + 1}."
            )
        else:
            assert len(file_list) == 0

    # @pytest.mark.skip
    def _remove_dir_and_contents(self):
        """After each test, empty the created folders and files to avoid corrupting a new test."""

        try:
            shutil.rmtree(self.test_output_path)
        except OSError as e:
            # This may raise an error if the folder has not been created.
            self.logger.debug(
                f"Exception while trying to delete folder. Likely it simply had not yet been created: {e}"
            )
        except Exception as e2:
            self.logger.error(f"Error while trying to delete folder: {e2}")

    @pytest.mark.skip
    def _remove_model_dir_and_zip(
        self, model_ids=[], are_all_models_deleted: bool = False
    ):
        """After a test, delete specific model folders, model_executors, and model zip files to avoid running out of disk space."""

        try:
            for i, model_executor in enumerate(self.generators.model_executors):
                if are_all_models_deleted or (
                    model_ids is not None and model_executor.model_id in model_ids
                ):
                    try:
                        # Delete the folder containing the model
                        model_path = os.path.dirname(
                            model_executor.deserialized_model_as_lib.__file__
                        )
                        shutil.rmtree(model_path)
                        self.logger.info(
                            f"Deleted directory of model {model_executor.model_id}. ({model_path})"
                        )

                    except OSError as e:
                        # This may raise an error if the folder is not present.
                        self.logger.warning(
                            f"Exception while trying to delete the model folder of model {model_executor.model_id}: {e}"
                        )
                    try:
                        # If the downloaded zip package of the model was not deleted inside the model_path, we explicitly delete it now.
                        if model_executor.package_path.is_file():
                            os.remove(model_executor.package_path)
                            self.logger.info(
                                f"Deleted zip file of model {model_executor.model_id}. ({model_executor.package_path})"
                            )
                    except Exception as e:
                        self.logger.warning(
                            f"Exception while trying to delete the ZIP file ({model_executor.package_path}) of model {model_executor.model_id}: {e}"
                        )
            # Delete the stateful model_executors instantiated by the generators module, after deleting folders and zips.
            if are_all_models_deleted:
                self.generators.model_executors.clear()
            else:
                if model_ids is not None:
                    for model_id in model_ids:
                        model_executor = self.generators.find_model_executor_by_id(
                            model_id
                        )
                        if model_executor is not None:
                            self.generators.model_executors.remove(model_executor)
                            del model_executor
        except Exception as e2:
            self.logger.error(
                f"Error while trying to delete model folders and zips: {e2}"
            )

    # @pytest.fixture(scope="session", autouse=True)
    def teardown_class(self):
        """After all tests, empty the large model folders, model_executors, and zip files to avoid running out of disk space."""

        # yield would be used if this were a fixture, signaling that statements after yield run after the last test has terminated.
        # https://docs.pytest.org/en/7.1.x/reference/reference.html?highlight=fixture#pytest.fixture
        # yield None

        # Remove all test outputs in test_output_path
        self._remove_dir_and_contents(self)

        # Remove all model folders, zip files and model executors
        # self._remove_model_dir_and_zip(
        #     self, model_ids=["00006_WGANGP_MMG_MASS_ROI"], are_all_models_deleted=False
        # )  # just for local testing
        # self._remove_model_dir_and_zip(
        #     self, model_ids=None, are_all_models_deleted=True
        # )