# train.py
import argparse
import json
import os

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import BrainSegmentationDataset as Dataset
from logger import Logger
from loss import DiceLoss
from transform import transforms
from unet import UNet
from utils import log_images, dsc


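# Trains a U-Net on the brain MRI dataset, alternating train/valid phases each
# epoch, logging Dice loss, and checkpointing the weights that achieve the
# best mean validation DSC.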
def main(args):
    makedirs(args)
    snapshotargs(args)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr)

    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    step = 0

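    # Each epoch runs a training pass followed by a validation pass over the
    # corresponding loaders.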
    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

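                # Enable gradients only for the training phase; validation runs
                # the forward pass with autograd disabled and collects per-slice
                # predictions for the volume-level DSC computed below.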
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)

                    loss = dsc_loss(y_pred, y_true)

                    if phase == "valid":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                        )
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])]
                        )
                        if (epoch % args.vis_freq == 0) or (epoch == args.epochs - 1):
                            if i * args.batch_size < args.vis_images:
                                tag = "image/{}".format(i)
                                num_images = args.vis_images - i * args.batch_size
                                logger.image_list_summary(
                                    tag,
                                    log_images(x, y_true, y_pred)[:num_images],
                                    step,
                                )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

                if phase == "train" and (step + 1) % 10 == 0:
                    log_loss_summary(logger, loss_train, step)
                    loss_train = []

            if phase == "valid":
                log_loss_summary(logger, loss_valid, step, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    )
                )
                logger.scalar_summary("val_dsc", mean_dsc, step)
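                # Checkpoint the model whenever the mean validation DSC improves.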
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
                loss_valid = []

    print("Best validation mean DSC: {:.4f}".format(best_validation_dsc))


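# Builds the train/valid DataLoaders; each worker seeds numpy deterministically
# so that random augmentations are reproducible across runs.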
def data_loaders(args):
    dataset_train, dataset_valid = datasets(args)

    def worker_init(worker_id):
        np.random.seed(42 + worker_id)

    loader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
        worker_init_fn=worker_init,
    )
    loader_valid = DataLoader(
        dataset_valid,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        worker_init_fn=worker_init,
    )

    return loader_train, loader_valid


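# The training split applies random scale/rotation/flip augmentation; the
# validation split is read in a fixed order with no augmentation.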
def datasets(args):
    train = Dataset(
        images_dir=args.images,
        subset="train",
        image_size=args.image_size,
        transform=transforms(scale=args.aug_scale, angle=args.aug_angle, flip_prob=0.5),
    )
    valid = Dataset(
        images_dir=args.images,
        subset="validation",
        image_size=args.image_size,
        random_sampling=False,
    )
    return train, valid


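# Regroups the flat list of validation slices into per-patient volumes using
# patient_slice_index, then computes one Dice score per volume.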
def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    dsc_list = []
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        y_pred = np.array(validation_pred[index : index + num_slices[p]])
        y_true = np.array(validation_true[index : index + num_slices[p]])
        dsc_list.append(dsc(y_pred, y_true))
        index += num_slices[p]
    return dsc_list


def log_loss_summary(logger, loss, step, prefix=""):
    logger.scalar_summary(prefix + "loss", np.mean(loss), step)


def makedirs(args):
    os.makedirs(args.weights, exist_ok=True)
    os.makedirs(args.logs, exist_ok=True)


def snapshotargs(args):
    args_file = os.path.join(args.logs, "args.json")
    with open(args_file, "w") as fp:
        json.dump(vars(args), fp)


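# Example invocation using the defaults defined below (adjust paths as needed):
#   python train.py --images ./kaggle_3m --weights ./weights --logs ./logs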
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Training U-Net model for segmentation of brain MRI"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="input batch size for training (default: 16)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        help="initial learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="device for training (default: cuda:0)",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="number of workers for data loading (default: 4)",
    )
    parser.add_argument(
        "--vis-images",
        type=int,
        default=200,
        help="number of visualization images to save in log file (default: 200)",
    )
    parser.add_argument(
        "--vis-freq",
        type=int,
        default=10,
        help="frequency of saving images to log file (default: 10)",
    )
    parser.add_argument(
        "--weights", type=str, default="./weights", help="folder to save weights"
    )
    parser.add_argument(
        "--logs", type=str, default="./logs", help="folder to save logs"
    )
    parser.add_argument(
        "--images", type=str, default="./kaggle_3m", help="root folder with images"
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=256,
        help="target input image size (default: 256)",
    )
    parser.add_argument(
        "--aug-scale",
        type=float,
        default=0.05,
        help="scale factor range for augmentation (default: 0.05)",
    )
    parser.add_argument(
        "--aug-angle",
        type=int,
        default=15,
        help="rotation angle range in degrees for augmentation (default: 15)",
    )
    args = parser.parse_args()
    main(args)