Diff of /tests/fid.py [000000] .. [3ee609]

Switch to unified view

a b/tests/fid.py
1
"""
2
Calculates the Frechet Inception Distance between two distributions, using chosen feature extractor model.
3
4
RadImageNet Model source: https://github.com/BMEII-AI/RadImageNet
5
RadImageNet InceptionV3 weights (original, broken since 11.07.2023): https://drive.google.com/file/d/1p0q9AhG3rufIaaUE1jc2okpS8sdwN6PU
6
RadImageNet InceptionV3 weights (for medigan, updated link 11.07.2023): https://drive.google.com/drive/folders/1lGFiS8_a5y28l4f8zpc7fklwzPJC-gZv
7
8
Usage:
9
    python fid.py dir1 dir2 
10
"""
11
12
import argparse
13
import os
14
15
import cv2
16
import numpy as np
17
import tensorflow as tf
18
import tensorflow_gan as tfgan
19
import wget
20
from tensorflow.keras.applications import InceptionV3
21
from tensorflow.keras.applications.inception_v3 import preprocess_input
22
23
img_size = 299
24
batch_size = 64
25
num_batches = 1
26
RADIMAGENET_URL = "https://drive.google.com/uc?id=1uvJHLG1K71Qzl7Km4JMpNOwE7iTjN8g9"
27
RADIMAGENET_WEIGHTS = "RadImageNet-InceptionV3_notop.h5"
28
IMAGENET_TFHUB_URL = "https://tfhub.dev/tensorflow/tfgan/eval/inception/1"
29
30
31
def parse_args() -> argparse.Namespace:
32
    parser = argparse.ArgumentParser(
33
        description="Calculates the Frechet Inception Distance between two distributions using RadImageNet model."
34
    )
35
    parser.add_argument(
36
        "dataset_path_1",
37
        type=str,
38
        help="Path to images from first dataset",
39
    )
40
    parser.add_argument(
41
        "dataset_path_2",
42
        type=str,
43
        help="Path to images from second dataset",
44
    )
45
    parser.add_argument(
46
        "--model",
47
        type=str,
48
        default="imagenet",
49
        help="Use RadImageNet feature extractor for FID calculation",
50
    )
51
    parser.add_argument(
52
        "--lower_bound",
53
        action="store_true",
54
        help="Calculate lower bound of FID using the 50/50 split of images from dataset_path_1",
55
    )
56
    parser.add_argument(
57
        "--normalize_images",
58
        action="store_true",
59
        help="Normalize images from both datasources using min and max of each sample",
60
    )
61
    args = parser.parse_args()
62
    return args
63
64
65
def load_images(directory, normalize=False, split=False, limit=None):
66
    """
67
    Loads images from the given directory.
68
    If split is True, then half of the images is loaded to one array and the other half to another.
69
    """
70
    if split:
71
        subset_1 = []
72
        subset_2 = []
73
    else:
74
        images = []
75
76
    for count, filename in enumerate(os.listdir(directory)):
77
        if filename.lower().endswith((".png", ".jpg", ".jpeg")):
78
            img = cv2.imread(os.path.join(directory, filename))
79
            img = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_LINEAR)
80
            if normalize:
81
                img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
82
            if len(img.shape) > 2 and img.shape[2] == 4:
83
                img = img[:, :, :3]
84
            if len(img.shape) == 2:
85
                img = np.stack([img] * 3, axis=2)
86
87
            if split:
88
                if count % 2 == 0:
89
                    subset_1.append(img)
90
                else:
91
                    subset_2.append(img)
92
            else:
93
                images.append(img)
94
        if count == limit:
95
            break
96
    if split:
97
        subset_1 = preprocess_input(np.array(subset_1))
98
        subset_2 = preprocess_input(np.array(subset_2))
99
        return subset_1, subset_2
100
    else:
101
        images = preprocess_input(np.array(images))
102
        return images
103
104
105
def check_model_weights(model_name):
106
    """
107
    Checks if the model weights are available and download them if not.
108
    """
109
    model_weights_path = None
110
    if model_name == "radimagenet":
111
        model_weights_path = RADIMAGENET_WEIGHTS
112
        if not os.path.exists(RADIMAGENET_WEIGHTS):
113
            print("Downloading RadImageNet InceptionV3 model:")
114
            wget.download(
115
                RADIMAGENET_URL,
116
                model_weights_path,
117
            )
118
            print("\n")
119
        return model_weights_path
120
121
122
def _radimagenet_fn(images):
123
    """
124
    Get RadImageNet inception v3 model
125
    """
126
    model = InceptionV3(
127
        weights=RADIMAGENET_WEIGHTS,
128
        input_shape=(img_size, img_size, 3),
129
        include_top=False,
130
        pooling="avg",
131
    )
132
    output = model(images)
133
    output = tf.nest.map_structure(tf.keras.layers.Flatten(), output)
134
    return output
135
136
137
def get_classifier_fn(model_name="imagenet"):
138
    """
139
    Get model as TF function for optimized inference.
140
    """
141
    check_model_weights(model_name)
142
143
    if model_name == "radimagenet":
144
        return _radimagenet_fn
145
    elif model_name == "imagenet":
146
        return tfgan.eval.classifier_fn_from_tfhub(IMAGENET_TFHUB_URL, "pool_3", True)
147
    else:
148
        raise ValueError("Model {} not recognized".format(model_name))
149
150
151
def calculate_fid(
152
    directory_1,
153
    directory_2,
154
    model_name,
155
    lower_bound=False,
156
    normalize_images=False,
157
):
158
    """
159
    Calculates the Frechet Inception Distance between two distributions using chosen feature extractor model.
160
    """
161
    limit = min(len(os.listdir(directory_1)), len(os.listdir(directory_2)))
162
    if lower_bound:
163
        images_1, images_2 = load_images(directory_1, split=True, limit=limit)
164
    else:
165
        images_1 = load_images(directory_1, limit=limit, normalize=normalize_images)
166
        images_2 = load_images(directory_2, limit=limit, normalize=normalize_images)
167
168
    fid = tfgan.eval.frechet_classifier_distance(
169
        images_1, images_2, get_classifier_fn(model_name)
170
    )
171
172
    return fid
173
174
175
if __name__ == "__main__":
176
    args = parse_args()
177
178
    directory_1 = args.dataset_path_1
179
    directory_2 = args.dataset_path_2
180
    lower_bound = args.lower_bound
181
    normalize_images = args.normalize_images
182
    model_name = args.model
183
184
    fid = calculate_fid(
185
        directory_1=directory_1,
186
        directory_2=directory_2,
187
        model_name=model_name,
188
        lower_bound=lower_bound,
189
        normalize_images=normalize_images,
190
    )
191
192
    if lower_bound:
193
        print("Lower bound FID {}: {}".format(model_name, fid))
194
    else:
195
        print("FID {}: {}".format(model_name, fid))