Diff of /main.py [000000] .. [f804b3]

Switch to unified view

a b/main.py
1
import argparse
2
import json
3
import os
4
5
import numpy as np
6
import torch
7
import torch.optim as optim
8
from torch.utils.data import DataLoader
9
from torchvision import transforms
10
from tqdm import tqdm
11
import albumentations as A
12
from albumentations.pytorch import ToTensor
13
14
15
from common. dataset import MedicalImageDataset as Dataset
16
from common.logger import Logger
17
from common.loss import bce_dice_loss, dice_coef_metric,_fast_hist, jaccard_index
18
from model.Att_Unet import Att_Unet
19
from common.utils import log_images
20
21
22
def main(config):
23
    makedirs(config)
24
    snapshotargs(config)
25
    device = torch.device("cpu" if not torch.cuda.is_available() else config.device)
26
27
    loader_train, loader_valid = data_loaders(config)
28
    loaders = {"train": loader_train, "valid": loader_valid}
29
30
    unet =Att_Unet()
31
    unet.to(device)
32
33
34
    best_validation_dsc = 0.0
35
36
    optimizer = optim.Adam(unet.parameters(), lr=config.lr,weight_decay=1e-5)
37
    lr_scheduler= torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=15, verbose=False)
38
39
40
    logger = Logger(config.logs)
41
    loss_train = []
42
    loss_valid = []
43
44
    step = 0
45
46
    for epoch in tqdm(range(config.epochs), total=config.epochs):
47
        for phase in ["train", "valid"]:
48
                if phase == "train":
49
                    unet.train()
50
                else:
51
                    unet.eval()
52
53
                validation_pred = []
54
                validation_true = []
55
                running_loss = 0.0
56
57
                for i, data in enumerate(loaders[phase]):
58
                    if phase == "train":
59
                        step += 1
60
61
                    x, y_true = data
62
                    x, y_true = x.to(device), y_true.to(device)
63
64
                    optimizer.zero_grad()
65
66
                    with torch.set_grad_enabled(phase == "train"):
67
                            y_pred = unet(x)
68
69
                            loss = bce_dice_loss(y_pred, y_true)
70
71
                            if phase == "valid":
72
                                loss_valid.append(loss.item())
73
                                y_pred_np = y_pred.detach().cpu().numpy()
74
                                validation_pred.extend(
75
                                    [y_pred_np[s] for s in range(y_pred_np.shape[0])]
76
                                )
77
                                y_true_np = y_true.detach().cpu().numpy()
78
                                validation_true.extend(
79
                                    [y_true_np[s] for s in range(y_true_np.shape[0])]
80
                                )
81
                                if (epoch % config.vis_freq == 0) or (epoch == config.epochs - 1):
82
                                    if i * config.batch_size < config.vis_images:
83
                                        tag = "image/{}".format(i)
84
                                        num_images = config.vis_images - i * config.batch_size
85
                                        logger.image_list_summary(
86
                                            tag,
87
                                            log_images(x, y_true, y_pred)[:num_images],
88
                                            step,
89
                                        )
90
91
                            if phase == "train":
92
                                loss_train.append(loss.item())
93
                                loss.backward()
94
                                optimizer.step()
95
                            running_loss += loss.detach() * x.size(0)
96
97
                    if i % 50 == 0:
98
                        for param_group in optimizer.param_groups:
99
                            print("Current learning rate is: {}".format(param_group['lr']))
100
101
102
103
                    if phase == "train" and (step + 1) % 10 == 0:
104
                        log_loss_summary(logger, loss_train, step)
105
                        loss_train = []
106
107
                print('Epoch [%d/%d], Loss: %.4f, ' %(epoch+1, config.epochs, running_loss/len(loaders[phase].dataset)))
108
109
                if phase == "valid":
110
                    log_loss_summary(logger, loss_valid, step, prefix="val_")
111
                    mean_dsc,mean_iou = compute_metric(unet,loaders[phase])
112
                    logger.scalar_summary("val_dsc", mean_dsc, step)
113
                    logger.scalar_summary("val_iou", mean_iou, step)
114
                    lr_scheduler.step(mean_dsc)
115
                    print("\nMean DICE on validation:", mean_dsc)
116
                    print("Mean IOU on validation:", mean_iou)
117
                    print("..........................................")
118
119
                    if mean_dsc > best_validation_dsc:
120
                        best_validation_dsc = mean_dsc
121
                        torch.save(unet.state_dict(), os.path.join(config.weights, "unet.pt"))
122
                    loss_valid = []
123
124
125
126
    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
127
128
129
def data_loaders(config):
130
    dataset_train, dataset_valid = datasets(config)
131
132
133
134
    loader_train = DataLoader(
135
        dataset_train,
136
        batch_size=config.batch_size,
137
        num_workers=config.workers
138
    )
139
    loader_valid = DataLoader(
140
        dataset_valid,
141
        batch_size=config.batch_size,
142
        num_workers=config.workers
143
144
    )
145
146
    return loader_train, loader_valid
147
148
data_transforms = A.Compose ([
149
    A.Resize(width = 256, height = 256, p=1.0),
150
    A.HorizontalFlip(p=0.5),
151
    A.VerticalFlip(p=0.5),
152
    A.Rotate((-5,5),p=0.5),
153
    A.RandomSunFlare(flare_roi=(0, 0, 1, 0.5), angle_lower=0, angle_upper=1,
154
                                   num_flare_circles_lower=1, num_flare_circles_upper=2,
155
                                   src_radius=160, src_color=(255, 255, 255),  always_apply=False, p=0.2),
156
     A.RGBShift (r_shift_limit=10, g_shift_limit=10,
157
                 b_shift_limit=10, always_apply=False, p=0.2),
158
    A. ElasticTransform (alpha=2, sigma=15, alpha_affine=25, interpolation=1,
159
                                      border_mode=4, value=None, mask_value=None,
160
                                      always_apply=False, approximate=False, p=0.2) ,
161
    A.Normalize( p=1.0),
162
    ToTensor(),
163
])
164
165
166
def datasets(config):
167
    train = Dataset('train', config.root,
168
                    transform=data_transforms)
169
170
171
    valid = Dataset('val', config.root,
172
                    transform=data_transforms)
173
174
    return train, valid
175
176
177
def compute_metric(model, loader, threshold=0.3):
178
    """
179
    Computes accuracy on the dataset wrapped in a loader
180
181
    Returns: accuracy as a float value between 0 and 1
182
    """
183
    device = torch.device("cpu" if not torch.cuda.is_available() else config.device)
184
    #model.eval()
185
    valloss_one = 0
186
    valloss_two = 0
187
188
    with torch.no_grad():
189
190
        for i_step, (data, target) in enumerate(loader):
191
192
            data = data.to(device)
193
            target = target.to(device)
194
195
196
            #prediction = model(x_gpu)
197
198
            outputs = model(data)
199
           # print("val_output:", outputs.shape)
200
201
            out_cut = np.copy(outputs.data.cpu().numpy())
202
            out_cut[np.nonzero(out_cut < threshold)] = 0.0
203
            out_cut[np.nonzero(out_cut >= threshold)] = 1.0
204
            hist=_fast_hist(target.data.cpu().numpy(),out_cut,num_classes=2)
205
206
            picloss = dice_coef_metric(hist)
207
            iouloss,_=jaccard_index(hist)
208
            valloss_one += picloss
209
            valloss_two +=iouloss
210
211
212
    return valloss_one / i_step,valloss_two/i_step
213
214
215
def log_loss_summary(logger, loss, step, prefix=""):
216
    logger.scalar_summary(prefix + "loss", np.mean(loss), step)
217
218
219
def makedirs(config):
220
    os.makedirs(config.weights, exist_ok=True)
221
    os.makedirs(config.logs, exist_ok=True)
222
223
224
def snapshotargs(config):
225
    config_file = os.path.join(config.logs, "config.json")
226
    with open(config_file, "w") as fp:
227
        json.dump(vars(config), fp)
228
229
230
if __name__ == "__main__":
231
    parser = argparse.ArgumentParser(
232
        description="Finetuning pretrained Unet"
233
    )
234
    parser.add_argument(
235
        "--batch-size",
236
        type=int,
237
        default=16,
238
        help="input batch size for training (default: 16)",
239
    )
240
    parser.add_argument(
241
        "--epochs",
242
        type=int,
243
        default=100,
244
        help="number of epochs to train (default: 100)",
245
    )
246
    parser.add_argument(
247
        "--lr",
248
        type=float,
249
        default=0.001,
250
        help="initial learning rate (default: 0.001)",
251
    )
252
    parser.add_argument(
253
        "--device",
254
        type=str,
255
        default="cuda:0",
256
        help="device for training (default: cuda:0)",
257
    )
258
    parser.add_argument(
259
        "--workers",
260
        type=int,
261
        default=4,
262
        help="number of workers for data loading (default: 4)",
263
    )
264
    parser.add_argument(
265
        "--vis-images",
266
        type=int,
267
        default=100,
268
        help="number of visualization images to save in log file (default: 200)",
269
    )
270
    parser.add_argument(
271
        "--vis-freq",
272
        type=int,
273
        default=10,
274
        help="frequency of saving images to log file (default: 10)",
275
    )
276
    parser.add_argument(
277
        "--weights", type=str, default="./weights", help="folder to save weights"
278
    )
279
    parser.add_argument(
280
        "--logs", type=str, default="./logs", help="folder to save logs"
281
    )
282
    parser.add_argument(
283
        "--root", type=str, default="./medico2020", help="root folder with images"
284
    )
285
    parser.add_argument(
286
        "--image-size",
287
        type=int,
288
        default=256,
289
        help="target input image size (default: 256)",
290
    )
291
    parser.add_argument(
292
        "--aug-scale",
293
        type=int,
294
        default=0.05,
295
        help="scale factor range for augmentation (default: 0.05)",
296
    )
297
    parser.add_argument(
298
        "--aug-angle",
299
        type=int,
300
        default=6,
301
        help="rotation angle range in degrees for augmentation (default: 15)",
302
    )
303
    config = parser.parse_args()
304
    main(config)