# templates/examples/__init__.py
"""
authors: Richard Osuala, Zuzanna Szafranowska
BCN-AIM 2021
"""

import logging
import os
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel


class BaseGenerator(nn.Module):
    def __init__(
        self,
        nz: int,
        ngf: int,
        nc: int,
        ngpu: int,
        leakiness: float = 0.2,
        bias: bool = False,
    ):
        super(BaseGenerator, self).__init__()
        self.nz = nz  # dimension of the latent noise vector z
        self.ngf = ngf  # number of generator feature maps
        self.nc = nc  # number of image channels
        self.ngpu = ngpu  # number of GPUs to use
        self.leakiness = leakiness  # negative slope of LeakyReLU activations
        self.bias = bias  # whether conv layers use a bias term
        self.main = None  # the network itself; defined by subclasses

    def forward(self, input):
        raise NotImplementedError


class Generator(BaseGenerator):
    def __init__(
        self,
        nz: int,
        ngf: int,
        nc: int,
        ngpu: int,
        image_size: int,
        conditional: bool,
        leakiness: float,
        bias: bool = False,
        n_cond: int = 10,
        is_condition_categorical: bool = False,
        num_embedding_dimensions: int = 50,
    ):
        super(Generator, self).__init__(
            nz=nz,
            ngf=ngf,
            nc=nc,
            ngpu=ngpu,
            leakiness=leakiness,
            bias=bias,
        )
        # if is_condition_categorical is False, we model the condition as a continuous input to the network
        self.is_condition_categorical = is_condition_categorical

        # n_cond is only used if is_condition_categorical is True.
        self.num_embedding_input = n_cond

        # num_embedding_dimensions is only used if is_condition_categorical is True.
        # num_embedding_dimensions would normally be dim(z), but at the moment an nn.Linear after
        # nn.Embedding upscales the dimension to self.nz. The same value of num_embedding_dimensions
        # is used in both D and G.
        self.num_embedding_dimensions = num_embedding_dimensions

        # whether there is a conditional input into the GAN, i.e. a cGAN
        self.conditional: bool = conditional

        # the image size (supported values are 128 and 64)
        self.image_size = image_size

        if self.image_size == 128:
            self.first_layers = nn.Sequential(
                # input is Z, going into a convolution
                nn.ConvTranspose2d(
                    self.nz * self.nc, self.ngf * 16, 4, 1, 0, bias=self.bias
                ),
                nn.BatchNorm2d(self.ngf * 16),
                nn.ReLU(True),
                # state size. (ngf*16) x 4 x 4
                nn.ConvTranspose2d(
                    self.ngf * 16, self.ngf * 8, 4, 2, 1, bias=self.bias
                ),
                nn.BatchNorm2d(self.ngf * 8),
                nn.ReLU(True),
            )
        elif self.image_size == 64:
            self.first_layers = nn.Sequential(
                # input is Z, going into a convolution
                nn.ConvTranspose2d(
                    self.nz * self.nc, self.ngf * 8, 4, 1, 0, bias=self.bias
                ),
                nn.BatchNorm2d(self.ngf * 8),
                nn.ReLU(True),
            )
        else:
            raise ValueError(
                f"Allowed image sizes are 128 and 64. You provided {self.image_size}. Please adjust."
            )

        self.main = nn.Sequential(
            *self.first_layers.children(),
            # state size. (ngf*8) x 8 x 8
            # (spatial sizes below are for image_size=128; halve them for image_size=64)
            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=self.bias),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 16 x 16
            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=self.bias),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 32 x 32
            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=self.bias),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(True),
            # state size. (ngf) x 64 x 64
            # Note that out_channels=1 instead of out_channels=self.nc.
            # This is due to the conditional input channel of our grayscale images.
            nn.ConvTranspose2d(
                in_channels=self.ngf,
                out_channels=1,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=self.bias,
            ),
            nn.Tanh(),
            # state size. 1 x 128 x 128
        )

        if self.is_condition_categorical:
            self.embed_nn = nn.Sequential(
                # e.g. condition -> int -> embedding -> fcl -> feature map -> concat with image -> conv layers...
                # embedding layer
                nn.Embedding(
                    num_embeddings=self.num_embedding_input,
                    embedding_dim=self.num_embedding_dimensions,
                ),
                # target output dim of the dense layer is batch_size x self.nz x 1 x 1
                # input is the dimension of the embedding layer output
                nn.Linear(
                    in_features=self.num_embedding_dimensions, out_features=self.nz
                ),
                # nn.BatchNorm1d(self.nz),
                nn.LeakyReLU(self.leakiness, inplace=True),
            )
        else:
            self.embed_nn = nn.Sequential(
                # target output dim of the dense layer is: nz x 1 x 1
                # input is the one-dimensional continuous condition
                nn.Linear(in_features=1, out_features=self.nz),
                # TODO Ablation: How does BatchNorm1d affect the conditional model performance?
                nn.BatchNorm1d(self.nz),
                nn.LeakyReLU(self.leakiness, inplace=True),
            )

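    # Shape walk-through of the conditional path, as a sketch (batch size B is
    # illustrative): a categorical condition of shape (B,) passes through
    # nn.Embedding -> (B, num_embedding_dimensions) and nn.Linear -> (B, nz);
    # in forward() it is reshaped to (B, nz, 1, 1) and concatenated with the
    # noise x along the channel dimension, which is why the first convolution
    # above takes self.nz * self.nc input channels.
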
    def forward(self, x, conditions=None):
        if self.conditional:
            # combining condition labels and input images via a new image channel
            if not self.is_condition_categorical:
                # If labels are continuous (not modelled as categorical), use floats instead of integers for labels.
                # Also adjust dimensions to (batch_size x 1) as needed for input into the linear layer.
                # Labels should already be of type float, so the .float() conversion is only a safety check.
                conditions = conditions.view(conditions.size(0), -1).float()
            embedded_conditions = self.embed_nn(conditions)
            embedded_conditions_with_random_noise_dim = embedded_conditions.view(
                -1, self.nz, 1, 1
            )
            x = torch.cat([x, embedded_conditions_with_random_noise_dim], 1)
        return self.main(x)


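# A minimal sketch of how this Generator might be instantiated and sampled
# (not part of the original module; the nz/ngf values and batch size are
# illustrative assumptions):
#
#     netG = Generator(nz=100, ngf=64, nc=1, ngpu=0, image_size=128,
#                      conditional=False, leakiness=0.1)
#     z = torch.randn(4, 100, 1, 1)  # batch of 4 latent vectors
#     fake = netG(z)                 # shape (4, 1, 128, 128), values in [-1, 1]
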
def interval_mapping(image, from_min, from_max, to_min, to_max):
    # map values from [from_min, from_max] to [to_min, to_max]
    # image: input array
    from_range = from_max - from_min
    to_range = to_max - to_min
    # scale to interval [0, 1]
    scaled = np.array((image - from_min) / float(from_range), dtype=float)
    # multiply by range and add minimum to get interval [to_min, to_max]
    return to_min + (scaled * to_range)

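# For instance, mapping the generator's tanh range to 8-bit pixel values
# (a worked example, not from the original module):
#
#     interval_mapping(np.array([-1.0, 0.0, 1.0]), -1.0, 1.0, 0, 255)
#     # -> array([  0. , 127.5, 255. ])
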
def image_generator(model_path, device, nz, ngf, nc, ngpu, num_samples):
    # instantiate the model
    logging.debug("Instantiating model...")
    netG = Generator(
        nz=nz,
        ngf=ngf,
        nc=nc,
        ngpu=ngpu,
        image_size=128,
        leakiness=0.1,
        conditional=False,
    )
    if device.type == "cuda":
        netG.cuda()

    # load the model's weights from the state_dict stored in the *.pt file
    logging.debug(f"Loading model weights from {model_path} ...")
    checkpoint = torch.load(model_path, map_location=device)
    try:
        netG.load_state_dict(state_dict=checkpoint["generator"])
    except KeyError:
        raise KeyError("checkpoint['generator'] was not found.")
    logging.debug("Loaded generator weights from the 'generator' checkpoint entry")
    netG.eval()

    # generate the images
    logging.debug(f"Generating {num_samples} images using {device}...")
    z = torch.randn(num_samples, nz, 1, 1, device=device)
    images = netG(z).detach().cpu().numpy()
    return [img_ for img_ in images]


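# A hedged example call (the checkpoint path is hypothetical; nz=100, ngf=64,
# nc=1 match the defaults used by generate() below):
#
#     device = torch.device("cpu")
#     imgs = image_generator("malignant.pt", device, nz=100, ngf=64,
#                            nc=1, ngpu=0, num_samples=8)
#     # imgs: list of 8 numpy arrays of shape (1, 128, 128), values in [-1, 1]
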
def save_generated_images(image_list, path):
    logging.debug(f"Saving generated images now in {path}")
    Path(path).mkdir(parents=True, exist_ok=True)
    for i, img_ in enumerate(image_list):
        img_path = f"{path}/{i}.png"
        # map from the generator's tanh output range [-1, 1] to 8-bit pixel values
        img_ = interval_mapping(img_.transpose(1, 2, 0), -1.0, 1.0, 0, 255)
        img_ = img_.astype("uint8")
        cv2.imwrite(img_path, img_)
    logging.debug(f"Saved generated images to {path}")


def return_images(image_list):
    # logging.debug(f"Returning generated images as {type(image_list)}.")
    processed_image_list = []
    for img_ in image_list:
        # map from the generator's tanh output range [-1, 1] to 8-bit pixel values
        img_ = interval_mapping(img_.transpose(1, 2, 0), -1.0, 1.0, 0, 255)
        img_ = img_.astype("uint8")
        processed_image_list.append(img_)
    return processed_image_list


def generate(model_file, num_samples, output_path, save_images: bool):
    """This function generates synthetic images of mammography regions of interest."""
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # compare against device.type: a torch.device never equals a plain string
        ngpu = 1 if device.type == "cuda" else 0
        image_list = image_generator(model_file, device, 100, 64, 1, ngpu, num_samples)
        if save_images:
            save_generated_images(image_list, output_path)
        else:
            return return_images(image_list)
    except Exception as e:
        logging.error(
            f"Error while trying to generate {num_samples} images with model {model_file}: {e}"
        )
        raise e
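

# A minimal usage sketch, assuming a checkpoint saved with a "generator" entry
# in its state dict; the file name below is illustrative, not from the
# original module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    generate(
        model_file="mammography_generator.pt",  # hypothetical checkpoint path
        num_samples=10,
        output_path="generated_images",
        save_images=True,
    )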