tests/test_model_selector.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Main test script for medigan's primary model search and ranking functions/classes/methods."""
# Run with: python -m tests.test_model_selector

import logging
import sys

import pytest

# import unittest


# Set the logging level depending on the level of detail you would like to have in the logs while running the tests.
LOGGING_LEVEL = logging.INFO  # e.g. logging.WARNING for quieter test runs

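# Each tuple below presumably holds (model_id, generation kwargs, sample count);
# only the model_id (first element) is read by the tests in this module.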
models = [
    (
        "00001_DCGAN_MMG_CALC_ROI",
        {},
        100,
    ),
    ("00002_DCGAN_MMG_MASS_ROI", {}, 3),
    ("00003_CYCLEGAN_MMG_DENSITY_FULL", {"translate_all_images": False}, 2),
    ("00005_DCGAN_MMG_MASS_ROI", {}, 3),
    # Further models can be added here if/when needed.
]


# class TestMediganSelectorMethods(unittest.TestCase):
class TestMediganSelectorMethods:
    def setup_method(self):
        # Logger config: this root-level logger initialized via logging.getLogger() will
        # also log all log events from the medigan library. Pass a logger name
        # (e.g. __name__) instead if you only want logs from this test module.
        self.logger = logging.getLogger()  # (__name__)
        self.logger.setLevel(LOGGING_LEVEL)
        # setup_method runs before every test, so attach the stream handler only once;
        # otherwise each test would add another handler and duplicate every log line.
        if not self.logger.handlers:
            stream_handler = logging.StreamHandler(sys.stdout)
            stream_handler.setLevel(LOGGING_LEVEL)
            formatter = logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
            stream_handler.setFormatter(formatter)
            self.logger.addHandler(stream_handler)
        self.test_init_generators()

    def test_init_generators(self):
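        """Instantiate the Generators class, medigan's main API entry point used by all tests below."""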
        from src.medigan.generators import Generators

        self.generators = Generators()

    @pytest.mark.parametrize(
        "values_list",
        [
            (["dcgan", "mMg", "ClF", "modality"]),
            (["DCGAN", "Mammography"]),
        ],
    )
    def test_search_for_models_method(self, values_list):
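        """Find models whose metadata matches all given values; keys are matched too and case is ignored."""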
        found_models = self.generators.find_matching_models_by_values(
            values=values_list,
            target_values_operator="AND",
            are_keys_also_matched=True,
            is_case_sensitive=False,
        )
        self.logger.debug(
            f"For values {values_list}, these models were found: {found_models}"
        )
        assert len(found_models) > 0

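    # Metric names such as "CLF.trained_on_real_and_fake.f1" appear to be dot-separated
    # paths into the performance section of each model's metadata (see config/global.json).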
    @pytest.mark.parametrize(
        "models, values_list, metric",
        [
            (
                models,
                ["dcgan", "MMG"],
                "CLF.trained_on_real_and_fake.f1",
            ),
            (models, ["dcgan", "MMG"], "turing_test.AUC"),
        ],
    )
    def test_find_and_rank_models_by_performance(self, models, values_list, metric):
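        """Search for matching models and rank the matches by a performance metric in descending order."""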
        # These values need to match at least two models. See metrics and values in the config/global.json file.
        found_ranked_models = self.generators.find_models_and_rank(
            values=values_list,
            target_values_operator="AND",
            are_keys_also_matched=True,
            is_case_sensitive=False,
            metric=metric,
            order="desc",
        )
        assert (
            len(found_ranked_models) > 0  # some models were found, as expected
            and found_ranked_models[0]["model_id"] is not None  # has a model id
            and (
                len(found_ranked_models) < 2
                or found_ranked_models[0][metric] > found_ranked_models[1][metric]
            )  # descending order (the higher a model's value, the lower its index in the list) is working
        )

    @pytest.mark.parametrize(
        "models, metric, order",
        [
            (
                models,
                "FID",
                "asc",
            ),  # Note: normally a lower FID is better, therefore asc (model with lowest FID has lowest result list index).
            (
                models,
                "FID_RADIMAGENET_ratio",
                "desc",  # descending, as the higher the FID ratio the better.
            ),
            (models, "CLF.trained_on_real_and_fake.f1", "desc"),
            (models, "turing_test.AUC", "desc"),
        ],
    )
    def test_rank_models_by_performance(self, models, metric, order):
        """Ranking according to metrics in the config/global.json file."""
        ranked_models = self.generators.rank_models_by_performance(
            model_ids=None,
            metric=metric,
            order=order,
        )
        assert (
            len(ranked_models) > 0  # at least one model was found
            and (
                len(ranked_models) >= 21 or metric != "FID"
            )  # we should find at least 21 models with FID in medigan
            and ranked_models[0]["model_id"]
            is not None  # found model has a model id (i.e. correctly formatted results)
            and (
                len(ranked_models) == 1
                or (
                    ranked_models[0][metric] > ranked_models[1][metric]
                    or metric == "FID"
                )
            )  # descending order (the higher a model's value, the lower its index in the list) is working. For FID, where lower is better and the order is ascending, this pairwise check is skipped.
        )

    @pytest.mark.parametrize(
        "models, metric, order",
        [
            (models, "CLF.trained_on_real_and_fake.f1", "desc"),
            (models, "turing_test.AUC", "desc"),
        ],
    )
    def test_rank_models_by_performance_with_given_ids(self, models, metric, order):
        """Ranking a specified set of models according to metrics in the config/global.json file."""
        ranked_models = self.generators.rank_models_by_performance(
            model_ids=[models[1][0], models[2][0]],
            metric=metric,
            order=order,
        )
        assert 0 < len(ranked_models) <= 2 and (
            len(ranked_models) < 2
            or (ranked_models[0][metric] > ranked_models[1][metric])
        )  # checking if descending order (the higher a model's value, the lower its index in the list) is working.

    @pytest.mark.parametrize(
        "key1, value1, expected",
        [
            ("modality", "Full-Field Mammography", 2),
            ("license", "BSD", 2),
            ("performance.CLF.trained_on_real_and_fake.f1", "0.96", 0),
            ("performance.turing_test.AUC", "0.56", 0),
        ],
    )
    def test_get_models_by_key_value_pair(self, key1, value1, expected):
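        """Look up models by a metadata key-value pair; `expected` is the minimum number of matches."""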
        found_models = self.generators.get_models_by_key_value_pair(
            key1=key1, value1=value1, is_case_sensitive=False
        )
        assert len(found_models) >= expected
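

if __name__ == "__main__":
    # Minimal entry point so that "python -m tests.test_model_selector" (see the header
    # comment) runs these tests directly; a sketch assuming pytest is installed, using
    # pytest's public pytest.main() API, which accepts a list of CLI arguments.
    sys.exit(pytest.main([__file__, "-v"]))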