# tests/test_model_executor.py (reconstructed from a unified-diff review view)
1
# -*- coding: utf-8 -*-
2
# ! /usr/bin/env python
3
""" main test script to test the primary functions/classes/methods. """
4
# run with python -m tests.test_generator
5
6
import glob
7
import logging
8
import os
9
import shutil
10
import sys
11
12
import pytest
13
import torch
14
15
# import unittest
16
17
18
# Set the logging level depending on the level of detail you would like to have in the logs while running the tests.
19
LOGGING_LEVEL = logging.INFO  # WARNING  # logging.INFO
20
21
# Each entry is a 3-tuple: (model_id, kwargs forwarded to generate(), expected number of generated samples).
# model_id may be a full id string, a shortened string ("00002", "03"), or an int (4) to exercise shortcut ids.
models_with_args = [
    (
        "00001_DCGAN_MMG_CALC_ROI",
        {},
        100,
    ),  # 100 samples to test automatic batch-wise image generation in model_executor
    (
        "00002",
        {},
        3,
    ),  # "00002" instead of "00002_DCGAN_MMG_MASS_ROI" to test shortcut model_ids
    (
        "03",
        {"translate_all_images": False},
        2,
    ),  # "03" instead of "00003_CYCLEGAN_MMG_DENSITY_FULL" to test shortcut model_ids
    (
        4,  # 4 instead of "00004_PIX2PIX_MMG_MASSES_W_MASKS" to test shortcut model_ids
        {
            "shapes": ["oval"],
            "ssim_threshold": 0.18,
            "image_size": [128, 128],
            "patch_size": [30, 30],
        },
        3,
    ),
    ("00005_DCGAN_MMG_MASS_ROI", {}, 3),
    ("00006_WGANGP_MMG_MASS_ROI", {}, 3),
    (
        "00007_INPAINT_BRAIN_MRI",
        {
            "image_size": (256, 256),
            "num_inpaints_per_sample": 2,
            "randomize_input_image_order": False,
            "add_variations_to_mask": False,
            "x_center": 120,
            "y_center": 140,
            "radius_1": 8,
            "radius_2": 12,
            "radius_3": 24,
        },
        3,
    ),
    (
        "00008_C-DCGAN_MMG_MASSES",
        {"condition": 0, "is_cbisddsm_training_data": False},
        3,
    ),
    ("00009_PGGAN_POLYP_PATCHES_W_MASKS", {"save_option": "image_only"}, 3),
    ("00010_FASTGAN_POLYP_PATCHES_W_MASKS", {"save_option": "image_only"}, 3),
    # ("00011_SINGAN_POLYP_PATCHES_W_MASKS", {"checkpoint_ids": [999]}, 3), # removed after successful testing due to limited CI pipeline capacity
    # ("00012_C-DCGAN_MMG_MASSES", {"condition": 0}, 3), # removed after successful testing due to limited CI pipeline capacity
    # ("00013_CYCLEGAN_MMG_DENSITY_OPTIMAM_MLO", {"translate_all_images": False}, 2), # removed after successful testing due to limited CI pipeline capacity
    # ("00014_CYCLEGAN_MMG_DENSITY_OPTIMAM_CC", {"translate_all_images": False}, 2), # removed after successful testing due to limited CI pipeline capacity
    # ("00015_CYCLEGAN_MMG_DENSITY_CSAW_MLO", {"translate_all_images": False}, 2), # removed after successful testing due to limited CI pipeline capacity
    # ("00016_CYCLEGAN_MMG_DENSITY_CSAW_CC", {"translate_all_images": False}, 2), # removed after successful testing due to limited CI pipeline capacity
    ("00017_DCGAN_XRAY_LUNG_NODULES", {}, 3),
    ("00018_WGANGP_XRAY_LUNG_NODULES", {}, 3),
    ("00019_PGGAN_CHEST_XRAY", {}, 3),
    ("00020_PGGAN_CHEST_XRAY", {"resize_pixel_dim": 512, "image_size": 256}, 3),
    (
        "00021_CYCLEGAN_BRAIN_MRI_T1_T2",
        {
            "input_path": "models/00021_CYCLEGAN_Brain_MRI_T1_T2/inputs/T2",
            "gpu_id": 0,
            "T1_to_T2": False,
        },
        3,
    ),
    ("00022_WGAN_CARDIAC_AGING", {}, 3),
    (
        "00023_PIX2PIXHD_BREAST_DCEMRI",
        {
            "input_path": "input",
            "gpu_id": 0,
            "image_size": 448,
        },
        3,
    ),
]
# class TestMediganExecutorMethods(unittest.TestCase):
class TestMediganExecutorMethods:
    """End-to-end tests of medigan's sample generation via the model executor.

    NOTE(review): pytest invokes ``setup_class``/``teardown_class`` with the
    class object itself, so the ``self`` parameter there is the class, and the
    attributes assigned in ``setup_class`` act as class-level state shared by
    all test methods. The ``@pytest.mark.skip`` on the per-model helpers keeps
    pytest from collecting them directly; they are driven explicitly from
    ``test_sample_generation_methods`` instead.
    """

    def setup_class(self):
        """Configure logging, initialize the Generators instance, and cache model ids."""
        ## unittest logger config
        # This logger on root level initialized via logging.getLogger() will also log all log events
        # from the medigan library. Pass a logger name (e.g. __name__) instead if you only want logs from tests.py
        self.logger = logging.getLogger()  # (__name__)
        self.logger.setLevel(LOGGING_LEVEL)
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setLevel(LOGGING_LEVEL)
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        stream_handler.setFormatter(formatter)
        self.logger.addHandler(stream_handler)

        self.test_output_path = "test_output_path"
        self.num_samples = 2
        # `self` is the class object here, hence it is passed explicitly below.
        self.test_imports_and_init_generators(self)
        self._remove_dir_and_contents(self)  # in case something is left there.
        self.model_ids = self.generators.config_manager.model_ids

    def test_imports_and_init_generators(self):
        """Import medigan lazily and store a Generators instance plus config keys."""
        from src.medigan.constants import (
            CONFIG_FILE_KEY_EXECUTION,
            CONFIG_FILE_KEY_GENERATE,
            CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE,
        )
        from src.medigan.generators import Generators

        self.generators = Generators()
        self.CONFIG_FILE_KEY_EXECUTION = CONFIG_FILE_KEY_EXECUTION
        self.CONFIG_FILE_KEY_GENERATE = CONFIG_FILE_KEY_GENERATE
        self.CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE = (
            CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE
        )

    @pytest.mark.parametrize("models_with_args", [models_with_args])
    def test_sample_generation_methods(self, models_with_args: list):
        """Drive all per-model generation helpers for every model id in the config."""
        self.logger.debug(f"models: {models_with_args}")
        for i, model_id in enumerate(self.model_ids):
            # if (
            #    model_id != "00011_SINGAN_POLYP_PATCHES_W_MASKS"
            # ):
            ## avoiding full memory on Windows ci test server
            # continue
            self.logger.debug(f"Now testing model {model_id}")
            self._remove_dir_and_contents()  # Already done in each test independently, but to be sure, here again.
            self.test_generate_method(model_id=model_id)

            # Check if args are available for model_id. Note: The models list may not include the latest medigan models
            for model in models_with_args:
                if model_id == model[0]:
                    self.test_generate_method_with_additional_args(
                        model_id=model[0], args=model[1], expected_num_samples=model[2]
                    )
            self.test_get_generate_method(model_id=model_id)
            self.test_get_dataloader_method(model_id=model_id)

            # if i == 16:  # just for local testing
            # self._remove_model_dir_and_zip(
            #    model_ids=[model_id], are_all_models_deleted=False
            # )

    @pytest.mark.parametrize(
        "values_list, should_sample_be_generated",
        [
            (["dcgan", "mMg", "ClF", "modality", "inbreast"], True),
            (["dcgan", "mMg", "ClF", "modality", "optimam"], True),
            (["dcgan", "mMg", "ClF", "modalities"], False),
        ],
    )
    def test_find_model_and_generate_method(
        self, values_list, should_sample_be_generated
    ):
        """Find a model by matching values and generate; verify samples were (not) produced."""
        self._remove_dir_and_contents()

        self.generators.find_model_and_generate(
            values=values_list,
            target_values_operator="AND",
            are_keys_also_matched=True,
            is_case_sensitive=False,
            num_samples=self.num_samples,
            output_path=self.test_output_path,
        )

        self._check_if_samples_were_generated(
            should_sample_be_generated=should_sample_be_generated
        )

    @pytest.mark.parametrize(
        "values_list, metric",
        [
            (["dcgan", "MMG"], "CLF.trained_on_real_and_fake.f1"),
            (["dcgan", "MMG"], "turing_test.AUC"),
        ],
    )
    def test_find_and_rank_models_then_generate_method(self, values_list, metric):
        """Find matching models, rank them by `metric`, generate, and verify output."""
        self._remove_dir_and_contents()
        # TODO This test needs the respective metrics for any of these models to be available in config/global.json.
        # These values would need to find at least two models.
        self.generators.find_models_rank_and_generate(
            values=values_list,
            target_values_operator="AND",
            are_keys_also_matched=True,
            is_case_sensitive=False,
            metric=metric,
            order="asc",
            num_samples=self.num_samples,
            output_path=self.test_output_path,
        )
        self._check_if_samples_were_generated()

    # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args])
    @pytest.mark.skip
    def test_generate_method(self, model_id):
        """Generate `self.num_samples` samples for `model_id` and verify the output files."""
        self._remove_dir_and_contents()
        self.generators.generate(
            model_id=model_id,
            num_samples=self.num_samples,
            output_path=self.test_output_path,
            install_dependencies=True,
        )
        self._check_if_samples_were_generated(model_id=model_id)

    # @pytest.mark.parametrize("model_id, args, expected_num_samples", models_with_args)
    @pytest.mark.skip
    def test_generate_method_with_additional_args(
        self, model_id, args, expected_num_samples
    ):
        """Generate with model-specific kwargs and verify the expected sample count."""
        self._remove_dir_and_contents()
        self.generators.generate(
            model_id=model_id,
            num_samples=expected_num_samples,
            output_path=self.test_output_path,
            **args,
        )
        self._check_if_samples_were_generated(
            model_id=model_id, num_samples=expected_num_samples
        )

    # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args])
    @pytest.mark.skip
    def test_get_generate_method(self, model_id):
        """Retrieve the model's generate function, call it, and verify output files."""
        self._remove_dir_and_contents()
        gen_function = self.generators.get_generate_function(
            model_id=model_id,
            num_samples=self.num_samples,
            output_path=self.test_output_path,
        )
        gen_function()
        self._check_if_samples_were_generated(model_id=model_id)
        del gen_function

    # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args])
    @pytest.mark.skip
    def test_get_dataloader_method(self, model_id):
        """Wrap the model's output in a torch dataloader and validate item structure."""
        self._remove_dir_and_contents()
        data_loader = self.generators.get_as_torch_dataloader(
            model_id=model_id, num_samples=self.num_samples
        )
        self.logger.debug(f"{model_id}: len(data_loader): {len(data_loader)}")

        if len(data_loader) != self.num_samples:
            logging.warning(
                f"{model_id}: The number of samples in the dataloader (={len(data_loader)}) is not equal the number of samples requested (={self.num_samples}). "
                f"Hint: Revise if the model's internal generate() function returned tuples as required in get_as_torch_dataloader()."
            )

        #### Get the object at index 0 from the dataloader
        data_dict = next(iter(data_loader))

        # Test if the items at index [0] of the aforementioned object is of type torch tensor (e.g. torch.uint8) and not None, as expected by data structure design decision.
        assert torch.is_tensor(data_dict.get("sample"))

        # Test if the items at index [1], [2] of the aforementioned object are None and, if not, whether they are of type torch tensor, as expected
        assert data_dict.get("mask") is None or torch.is_tensor(data_dict.get("mask"))
        assert data_dict.get("other_imaging_output") is None or torch.is_tensor(
            data_dict.get("other_imaging_output")
        )

        # Test if the items at index [3] of the aforementioned object is None and, if not, whether it is of type list of strings, as expected.
        assert data_dict.get("label") is None or (
            isinstance(data_dict.get("label"), list)
            and isinstance(data_dict.get("label")[0], str)
        )
        del data_dict
        del data_loader

    # @pytest.mark.parametrize("model_id", [model[0] for model in models_with_args])
    @pytest.mark.skip
    def test_visualize_method(self, model_id):
        """Visualize models that expose a latent input vector; otherwise expect an error."""
        if (
            self.CONFIG_FILE_KEY_GENERATE_ARGS_INPUT_LATENT_VECTOR_SIZE
            in self.generators.config_manager.config_dict[model_id][
                self.CONFIG_FILE_KEY_EXECUTION
            ][self.CONFIG_FILE_KEY_GENERATE]
        ):
            self.generators.visualize(model_id, auto_close=True)

        else:
            with pytest.raises(Exception) as e:
                self.generators.visualize(model_id, auto_close=True)
            # BUGFIX: this assert previously sat INSIDE the `with` block after the
            # raising call, where it could never execute (the raise exits the
            # block). ExceptionInfo must be inspected after the `with` block.
            assert e.type == ValueError

    @pytest.mark.skip
    def _check_if_samples_were_generated(
        self, model_id=None, num_samples=None, should_sample_be_generated: bool = True
    ):
        """Assert that the expected number of files (or none at all) exists in the output path."""
        # check if the number of generated samples of model_id_1 is as expected.
        file_list = glob.glob(self.test_output_path + "/*")
        self.logger.debug(f"{model_id}: {len(file_list)} == {self.num_samples} ?")
        if num_samples is None:
            num_samples = self.num_samples

        if should_sample_be_generated:
            assert (
                len(file_list) == num_samples
                or len(file_list)
                == num_samples
                * 2
                * 6  # 00007_INPAINT_BRAIN_MRI: 2 inpaints per sample, 6 outputs per sample
                or len(file_list)
                == num_samples * 2  # Temporary fix for different outputs per model.
                or len(file_list) == num_samples + 1
            ), f"Model {model_id} generated {len(file_list)} samples instead of the expected {num_samples}, {num_samples * 2}, {num_samples * 2 * 6}, or {num_samples + 1}."
            # BUGFIX: the message above previously omitted the accepted
            # `num_samples * 2` count even though the condition checks it.
            # Some models are balanced per label by default: If num_samples is odd, then len(file_list)==num_samples +1
        else:
            assert len(file_list) == 0

    # @pytest.mark.skip
    def _remove_dir_and_contents(self):
        """After each test, empty the created folders and files to avoid corrupting a new test."""

        try:
            shutil.rmtree(self.test_output_path)
        except OSError as e:
            # This may give an error if the folders are not created.
            self.logger.debug(
                f"Exception while trying to delete folder. Likely it simply had not yet been created: {e}"
            )
        except Exception as e2:
            self.logger.error(f"Error while trying to delete folder: {e2}")

    @pytest.mark.skip
    def _remove_model_dir_and_zip(
        self, model_ids=None, are_all_models_deleted: bool = False
    ):
        """Delete specific model folders, model_executors, and model zip files to avoid running out-of-disk space.

        BUGFIX: `model_ids` previously used a mutable default argument (`[]`);
        `None` is behaviorally equivalent here because every use of `model_ids`
        is already guarded by `model_ids is not None`.
        """

        try:
            for i, model_executor in enumerate(self.generators.model_executors):
                if are_all_models_deleted or (
                    model_ids is not None and model_executor.model_id in model_ids
                ):
                    try:
                        # Delete the folder containing the model
                        model_path = os.path.dirname(
                            model_executor.deserialized_model_as_lib.__file__
                        )
                        shutil.rmtree(model_path)
                        self.logger.info(
                            f"Deleted directory of model {model_executor.model_id}. ({model_path})"
                        )

                    except OSError as e:
                        # This may give an error if the FOLDER is not present
                        self.logger.warning(
                            f"Exception while trying to delete the model folder of model {model_executor.model_id}: {e}"
                        )
                    try:
                        # If the downloaded zip package of the model was not deleted inside the model_path, we explicitly delete it now.
                        if model_executor.package_path.is_file():
                            os.remove(model_executor.package_path)
                            self.logger.info(
                                f"Deleted zip file of model {model_executor.model_id}. ({model_executor.package_path})"
                            )
                    except Exception as e:
                        self.logger.warning(
                            f"Exception while trying to delete the ZIP file ({model_executor.package_path}) of model {model_executor.model_id}: {e}"
                        )
            # Deleting the stateful model_executors instantiated by the generators module, after deleting folders and zips
            if are_all_models_deleted:
                self.generators.model_executors.clear()
            else:
                if model_ids is not None:
                    for model_id in model_ids:
                        model_executor = self.generators.find_model_executor_by_id(
                            model_id
                        )
                        if model_executor is not None:
                            self.generators.model_executors.remove(model_executor)
                        del model_executor
        except Exception as e2:
            self.logger.error(
                f"Error while trying to delete model folders and zips: {e2}"
            )

    # @pytest.fixture(scope="session", autouse=True)
    def teardown_class(self):
        """After all tests, empty the large model folders, model_executors, and zip files to avoid running out-of-disk space."""

        # yield is at test-time, signaling that things after yield are run after the execution of the last test has terminated
        # https://docs.pytest.org/en/7.1.x/reference/reference.html?highlight=fixture#pytest.fixture
        # yield None

        # Remove all test outputs in test_output_path
        self._remove_dir_and_contents(self)

        # Remove all model folders, zip files and model executors
        # self._remove_model_dir_and_zip(
        #    self, model_ids=["00006_WGANGP_MMG_MASS_ROI"], are_all_models_deleted=False
        # )  # just for local testing
        # self._remove_model_dir_and_zip(
        #    self, model_ids=None, are_all_models_deleted=True
        # )