Diff of /train_segmentation.py [000000] .. [a49583]

# Mask R-CNN model for lesion segmentation in chest CT scans
# Torchvision detection package is locally re-implemented
# by Alex Ter-Sarkisov@City, University of London
# alex.ter-sarkisov@city.ac.uk
# 2020
import argparse
import time
import pickle
import torch
import torchvision
import numpy as np
import os
import sys
import cv2
import models.mask_net as mask_net
from models.mask_net.rpn_segmentation import AnchorGenerator
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor
from torch.utils import data
import datasets.dataset_segmentation as dataset
from PIL import Image as PILImage
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import utils
import config_segmentation as config


# main method
def main(config, main_step):
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    truncation_levels = ['0', '1', '2']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    assert config.backbone_name in backbones
    assert config.mask_type in mask_classes
    assert config.truncation in truncation_levels

    # import arguments from the config file
    start_epoch, model_name, use_pretrained_resnet_backbone, num_epochs, save_dir, train_data_dir, val_data_dir, \
        imgs_dir, gt_dir, batch_size, device, save_every, lrate, rpn_nms, mask_type, backbone_name, truncation = \
        config.start_epoch, config.model_name, config.use_pretrained_resnet_backbone, config.num_epochs, \
        config.save_dir, config.train_data_dir, config.val_data_dir, config.imgs_dir, config.gt_dir, \
        config.batch_size, config.device, config.save_every, config.lrate, config.rpn_nms_th, config.mask_type, \
        config.backbone_name, config.truncation

    assert device in devices
    if save_dir not in os.listdir('.'):
        os.mkdir(save_dir)

    if batch_size > 1:
        print("The model was implemented for a batch size of one")
    if device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print(device)

    # Load the weights if provided
    if config.pretrained_model is not None:
        pretrained_model = torch.load(config.pretrained_model, map_location=device)
        use_pretrained_resnet_backbone = False
    else:
        pretrained_model = None
    torch.manual_seed(int(time.time()))
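    # NB: seeding from the wall clock makes every run different; fix the seed here
    # if you need reproducible training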
    ##############################################################################################
    # DATASETS + DATALOADERS
    # Alex: these could be added to the config file in the future
    # parameters for the dataset
    dataset_covid_pars_train = {'stage': 'train', 'gt': os.path.join(train_data_dir, gt_dir),
                                'data': os.path.join(train_data_dir, imgs_dir), 'mask_type': mask_type,
                                'ignore_small': True}
    datapoint_covid_train = dataset.CovidCTData(**dataset_covid_pars_train)

    dataset_covid_pars_eval = {'stage': 'eval', 'gt': os.path.join(val_data_dir, gt_dir),
                               'data': os.path.join(val_data_dir, imgs_dir), 'mask_type': mask_type,
                               'ignore_small': True}
    datapoint_covid_eval = dataset.CovidCTData(**dataset_covid_pars_eval)
    ###############################################################################################
    dataloader_covid_pars_train = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_train = data.DataLoader(datapoint_covid_train, **dataloader_covid_pars_train)

    dataloader_covid_pars_eval = {'shuffle': False, 'batch_size': batch_size}
    dataloader_covid_eval = data.DataLoader(datapoint_covid_eval, **dataloader_covid_pars_eval)
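    # NB: step() below assumes batch_size == 1 and strips the batch dimension from the targets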
    ###############################################################################################
    # MASK R-CNN model
    # Alex: these settings could also be added to the config
    if mask_type == "both":
        n_c = 3
    else:
        n_c = 2
    maskrcnn_args = {'min_size': 512, 'max_size': 1024, 'rpn_batch_size_per_image': 256, 'rpn_positive_fraction': 0.75,
                     'box_positive_fraction': 0.75, 'box_fg_iou_thresh': 0.75, 'box_bg_iou_thresh': 0.5,
                     'num_classes': None, 'box_batch_size_per_image': 256, 'rpn_nms_thresh': rpn_nms}
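    # NB: positive fractions of 0.75 oversample foreground proposals relative to the
    # torchvision defaults (0.5 for the RPN, 0.25 for the box head), presumably to offset
    # the imbalance between small lesion regions and background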

    # Alex: for ground glass opacity and consolidation segmentation
    # many small anchors
    # use all outputs of FPN
    # IMPORTANT!! For the pretrained weights, this determines the size of the anchor layer in the RPN!!!!
    # the pretrained model must have anchors
    if pretrained_model is None:
        anchor_generator = AnchorGenerator(
            sizes=tuple([(2, 4, 8, 16, 32) for r in range(5)]),
            aspect_ratios=tuple([(0.1, 0.25, 0.5, 1, 1.5, 2) for rh in range(5)]))
    else:
        print("Loading the anchor generator")
        sizes = pretrained_model['anchor_generator'].sizes
        aspect_ratios = pretrained_model['anchor_generator'].aspect_ratios
        anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)
        print(anchor_generator, anchor_generator.num_anchors_per_location())
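    # with the settings above: 5 sizes x 6 aspect ratios = 30 anchors per location,
    # one (sizes, aspect_ratios) pair for each of the 5 FPN levels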
    # num_classes: 3 for 'both' (background + 2 lesion classes), otherwise 2
    # in_channels: 256, the number of channels from the FPN
    # For the ResNet50+FPN: keep the torchvision architecture, but with 128 features
    # For lightweight models: re-implement MaskRCNNHeads with a single layer
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
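    # the box head flattens the 256-channel 7x7 RoI features (256 * 7 * 7 = 12544 values)
    # and projects them to a 128-dimensional representation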
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)
    else:
        # Backbone->FPN->box head->box predictor
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c)
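        # here the mask head is a single 128-channel layer; torchvision's stock Mask R-CNN
        # uses four 256-channel layers (passing mask_head=None above keeps that default,
        # assuming the local re-implementation mirrors torchvision's behaviour)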

    maskrcnn_args['box_head'] = box_head
    maskrcnn_args['rpn_anchor_generator'] = anchor_generator
    maskrcnn_args['mask_head'] = maskrcnn_heads
    maskrcnn_args['mask_predictor'] = mask_predictor
    maskrcnn_args['box_predictor'] = box_predictor
    # Instantiate the segmentation model
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation,
                                                  pretrained_backbone=use_pretrained_resnet_backbone, **maskrcnn_args)
    # pretrained?
    print(maskrcnn_model.backbone.out_channels)
    if pretrained_model is not None:
        print("Loading pretrained weights")
        maskrcnn_model.load_state_dict(pretrained_model['model_weights'])
        if pretrained_model['epoch']:
            start_epoch = int(pretrained_model['epoch']) + 1
        if 'model_name' in pretrained_model.keys():
            model_name = str(pretrained_model['model_name'])

    # Set to training mode
    print(maskrcnn_model)
    maskrcnn_model.train().to(device)

    optimizer_pars = {'lr': lrate, 'weight_decay': 1e-3}
    optimizer = torch.optim.Adam(maskrcnn_model.parameters(), **optimizer_pars)
    if pretrained_model is not None and 'optimizer_state' in pretrained_model.keys():
        optimizer.load_state_dict(pretrained_model['optimizer_state'])

    start_time = time.time()
    if start_epoch > 0:
        num_epochs += start_epoch
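        # i.e. when resuming, config.num_epochs counts additional epochs on top of start_epoch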
    print("Start training, epoch = {:d}".format(start_epoch))
    for e in range(start_epoch, num_epochs):
        train_loss_epoch = main_step("train", e, dataloader_covid_train, optimizer, device, maskrcnn_model,
                                     save_every, lrate, model_name, None, None)
        eval_loss_epoch = main_step("eval", e, dataloader_covid_eval, optimizer, device, maskrcnn_model,
                                    save_every, lrate, model_name, anchor_generator, save_dir)
        print("Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".format(e, train_loss_epoch,
                                                                                    eval_loss_epoch))
    end_time = time.time()
    print("Training took {0:.1f} seconds".format(end_time - start_time))


def step(stage, e, dataloader, optimizer, device, model, save_every, lrate, model_name, anchors, save_dir):
    epoch_loss = 0
    for b in dataloader:
        optimizer.zero_grad()
        X, y = b
        if device == torch.device('cuda'):
            X, y['labels'], y['boxes'], y['masks'] = \
                X.to(device), y['labels'].to(device), y['boxes'].to(device), y['masks'].to(device)
        images = [im for im in X]
        targets = []
        lab = {}
        # THIS IS IMPORTANT!!!!!
        # get rid of the first (batch) dimension
        # IF you have >1 image, make another loop
        # REPEAT: DO NOT USE THE BATCH DIMENSION
        lab['boxes'] = y['boxes'].squeeze_(0)
        lab['labels'] = y['labels'].squeeze_(0)
        lab['masks'] = y['masks'].squeeze_(0)
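        # each target dict now holds boxes [N, 4], labels [N] and masks [N, H, W],
        # the per-image format the detection model expects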
        if len(lab['boxes']) > 0 and len(lab['labels']) > 0 and len(lab['masks']) > 0:
            targets.append(lab)
        # avoid empty objects
        if len(targets) > 0:
            loss = model(images, targets)
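            # in training mode the model returns a dict of loss components
            # (RPN objectness and box regression, box classification/regression, mask loss)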
            total_loss = 0
            for k in loss.keys():
                total_loss += loss[k]
            if stage == "train":
                total_loss.backward()
                optimizer.step()
            epoch_loss += total_loss.clone().detach().cpu().numpy()
    epoch_loss = epoch_loss / len(dataloader)
    if not (e + 1) % save_every and stage == "eval":
        model.eval()
        state = {'epoch': str(e + 1), 'model_name': model_name, 'model_weights': model.state_dict(),
                 'optimizer_state': optimizer.state_dict(), 'lrate': lrate, 'anchor_generator': anchors}
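        # storing the anchor generator with the weights lets main() rebuild the RPN
        # with identical anchors when this checkpoint is reloaded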
        if model_name is None:
            print(save_dir, "mrcnn_covid_segmentation_model_ckpt_" + str(e + 1) + ".pth")
            torch.save(state, os.path.join(save_dir, "mrcnn_covid_segmentation_model_ckpt_" + str(e + 1) + ".pth"))
        else:
            torch.save(state, os.path.join(save_dir, model_name + "_ckpt_" + str(e + 1) + ".pth"))

        model.train()
    return epoch_loss


# run the training of the segmentation algorithm
if __name__ == '__main__':
    config_train = config.get_config_pars("trainval")
    main(config_train, step)