Diff of /train_classifier.py [000000] .. [a49583]

Switch to unified view

a b/train_classifier.py
1
# COVID-CT-Mask-Net
2
# I re-implemented Torchvision's detection library (Faster and Mask R-CNN) as a classifier
3
# Alex Ter-Sarkisov @ City, University of London
4
# alex.ter-sarkisov@city.ac.uk
5
#
6
import os
7
import pickle
8
import sys
9
import sys
10
import time
11
12
import config_classifier
13
import cv2
14
import datasets.dataset_classifier as dataset
15
# IMPORT LOCAL IMPLEMENTATION OF TORCHVISION'S DETECTION LIBRARY
16
import numpy as np
17
import torch
18
import torch.nn.functional as F
19
import torchvision
20
import utils
21
from PIL import Image as PILImage
22
# IMPORT LOCAL IMPLEMENTATION OF TORCHVISION'S DETECTION LIBRARY
23
# Faster R-CNN interface
24
import models.mask_net as mask_net
25
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
26
from models.mask_net.rpn import AnchorGenerator
27
from torch.utils import data
28
from torchvision import transforms
29
30
31
# main method
32
def main(config, main_step):
    """Build the COVID-CT-Mask-Net classifier from a checkpoint and train it.

    Args:
        config: object exposing the attributes read below: ``device``,
            ``backbone_name``, ``truncation``, ``start_epoch``,
            ``pretrained_classification_model``,
            ``pretrained_segmentation_model``, ``model_name``, ``num_epochs``,
            ``save_dir``, ``train_data_dir``, ``val_data_dir``, ``batch_size``,
            ``save_every``, ``lrate``, ``rpn_nms_th``, ``roi_nms_th``,
            ``roi_batch_size``, ``num_classes`` and ``s_features``.
        main_step: callable run once per stage and epoch as
            ``main_step(stage, epoch, dataloader, optimizer, device, model,
            save_every, lrate, model_name, anchors, save_dir)``; its return
            value is printed as the loss of that stage.

    Raises:
        ValueError: if ``config.device``, ``config.backbone_name`` or
            ``config.truncation`` is not one of the supported values, or if
            neither pretrained checkpoint path is set.
    """
    devices = ['cpu', 'cuda']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    # Explicit exceptions instead of ``assert``: asserts vanish under ``python -O``.
    if config.device not in devices:
        raise ValueError("device must be one of {}, got {!r}".format(devices, config.device))
    if config.backbone_name not in backbones:
        raise ValueError("backbone_name must be one of {}, got {!r}".format(backbones, config.backbone_name))
    if config.truncation not in truncation_levels:
        raise ValueError("truncation must be one of {}, got {!r}".format(truncation_levels, config.truncation))

    pretrained_classifier = config.pretrained_classification_model
    pretrained_segment = config.pretrained_segmentation_model
    # Fail early with a clear message rather than inside torch.load(None).
    if pretrained_classifier is None and pretrained_segment is None:
        raise ValueError("You must have at least one pretrained model!")

    # Seed from the wall clock; made an explicit integer.
    torch.manual_seed(int(time.time()))
    start_time = time.time()

    start_epoch = config.start_epoch
    model_name = config.model_name
    num_epochs = config.num_epochs
    save_dir = config.save_dir
    batch_size = config.batch_size
    save_every = config.save_every
    lrate = config.lrate
    backbone_name = config.backbone_name
    truncation = config.truncation

    # The classification checkpoint wins when both are supplied.
    if pretrained_classifier is not None:
        if pretrained_segment is not None:
            print("Not clear which model to use, switching to the classifier")
        pretrained_model = pretrained_classifier
    else:
        pretrained_model = pretrained_segment

    # Fall back to the CPU when CUDA was requested but is not available.
    use_cuda = config.device == 'cuda' and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    ##############################################################################################
    # DATASETS+DATALOADERS
    # Alex: could be added in the config file in the future
    # 512x512 is the recommended image size input
    img_size = (512, 512)
    datapoint_covid_train_cl = dataset.COVID_CT_DATA(stage='train', data=config.train_data_dir, img_size=img_size)
    datapoint_covid_eval_cl = dataset.COVID_CT_DATA(stage='eval', data=config.val_data_dir, img_size=img_size)
    dataloader_covid_train_cl = data.DataLoader(datapoint_covid_train_cl, shuffle=True, batch_size=batch_size)
    dataloader_covid_eval_cl = data.DataLoader(datapoint_covid_eval_cl, shuffle=True, batch_size=batch_size)

    ##### LOAD PRETRAINED WEIGHTS FROM MASK R-CNN MODEL
    # This must be the full path to the checkpoint with the anchor generator and model weights.
    # Assumed that the keys in the checkpoint are model_weights and anchor_generator.
    # NOTE(review): the checkpoint stores a pickled anchor generator object, not only tensors;
    # recent PyTorch releases default torch.load to weights_only=True - confirm loading still works.
    ckpt = torch.load(pretrained_model, map_location=device)

    # copy the anchor generator parameters, create a new one to avoid implementations' clash
    saved_anchors = ckpt['anchor_generator']
    anchor_generator = AnchorGenerator(saved_anchors.sizes, saved_anchors.aspect_ratios)
    # out_channels:256, FPN
    # num_classes:3 (1+2)
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=config.num_classes)

    # keyword arguments
    # box_score_threshold:negative!
    # set both NMS thresholds to 0.75 to get adjacent RoIs
    # Box detections/image: batch size for the classifier
    covid_mask_net_args = {
        'num_classes': None,
        'min_size': 512,
        'max_size': 1024,
        'box_detections_per_img': config.roi_batch_size,
        'box_nms_thresh': config.roi_nms_th,
        'box_score_thresh': -0.01,
        'rpn_nms_thresh': config.rpn_nms_th,
        'rpn_anchor_generator': anchor_generator,
        'box_predictor': box_predictor,
        'box_head': box_head,
        's_representation_size': config.s_features,
    }
    # Instantiate the model
    covid_mask_net_model = mask_net.fasterrcnn_resnet_fpn(backbone_name, truncation, **covid_mask_net_args)

    if pretrained_classifier is None:
        # Weights come from the segmentation model: copy only the parameters
        # whose names also exist in the checkpoint.
        saved_weights = ckpt['model_weights']
        for _n, _par in covid_mask_net_model.state_dict().items():
            if _n in saved_weights:
                print('Loading parameter', _n)
                _par.copy_(saved_weights[_n])
    else:
        # Weights come from the classification model: load everything and
        # resume the epoch counter / model name stored with it, if any.
        covid_mask_net_model.load_state_dict(ckpt['model_weights'])
        if 'epoch' in ckpt:
            start_epoch = int(ckpt['epoch']) + 1
        if 'model_name' in ckpt:
            model_name = str(ckpt['model_name'])

    # Evaluation mode, no labels!
    covid_mask_net_model.eval()
    # set the model to training mode without triggering the 'training' mode of Mask R-CNN
    # which parameters to train? presumably filled in place by utils.switch_model_on - TODO confirm
    trained_pars = []
    utils.switch_model_on(covid_mask_net_model, ckpt, trained_pars)
    utils.set_to_train_mode(covid_mask_net_model)
    print(covid_mask_net_model)
    covid_mask_net_model = covid_mask_net_model.to(device)
    total_trained_pars = sum(x.numel() for x in trained_pars)
    print("Total trained pars {0:d}".format(total_trained_pars))

    # set up the optimizer
    optimizer = torch.optim.Adam(trained_pars, lr=lrate, weight_decay=1e-3)
    if pretrained_classifier is not None and 'optimizer_state' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer_state'])

    # When resuming, num_epochs counts the additional epochs to run.
    if start_epoch > 0:
        num_epochs += start_epoch
    print("Start training, epoch = {:d}".format(start_epoch))
    for e in range(start_epoch, num_epochs):
        # The train stage is called without anchors / save_dir; only the eval
        # stage receives what is needed to write a checkpoint.
        train_loss_epoch = main_step("train", e, dataloader_covid_train_cl, optimizer, device, covid_mask_net_model,
                                     save_every, lrate, model_name, None, None)
        eval_loss_epoch = main_step("eval", e, dataloader_covid_eval_cl, optimizer, device, covid_mask_net_model,
                                    save_every, lrate, model_name, anchor_generator, save_dir)
        print(
            "Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".format(e, train_loss_epoch, eval_loss_epoch))
    end_time = time.time()
    print("Training took {0:.1f} seconds".format(end_time - start_time))
152
153
154
def step(stage, e, dataloader, optimizer, device, model, save_every, lrate, model_name, anchors, save_dir):
    """Run one epoch over ``dataloader`` and return the mean batch loss.

    Args:
        stage: ``"train"`` updates the weights; any other value only computes
            the loss. A checkpoint is written only when stage is ``"eval"``.
        e: zero-based epoch index; checkpoints are labelled ``e + 1``.
        dataloader: sized iterable yielding ``(X, y)`` batches.
        optimizer: optimizer stepped after every training batch.
        device: device the batch tensors are moved to.
        model: called with a one-element list holding a single image; the
            ``'final_scores'`` entry of its first output is used as logits.
        save_every: a checkpoint is written when ``(e + 1) % save_every == 0``.
        lrate: learning rate, stored in the checkpoint.
        model_name: checkpoint file prefix; ``None`` selects
            ``covid_ct_mask_net``.
        anchors: anchor generator, stored in the checkpoint.
        save_dir: directory the checkpoint is written to.

    Returns:
        Sum of the batch losses divided by ``len(dataloader)``.
    """
    is_training = stage == "train"
    epoch_loss = 0
    for X, y in dataloader:
        optimizer.zero_grad()
        # .to() is a no-op for tensors already on ``device``, so no device
        # comparison is needed (and 'cuda:0' is handled like 'cuda').
        X, y = X.to(device), y.to(device)
        # Only the training stage back-propagates; skip building the autograd
        # graph otherwise.
        with torch.set_grad_enabled(is_training):
            # input the images of the batch one by one into COVID-Mask-Net
            # (some batches are smaller than batch_size), then batchify the
            # per-image scores and compute binary cross-entropy loss
            batch_scores = torch.stack([model([image])[0]['final_scores'] for image in X])
            batch_loss = F.binary_cross_entropy_with_logits(batch_scores, y)
        if is_training:
            batch_loss.backward()
            optimizer.step()
        epoch_loss += batch_loss.detach().cpu().numpy()
    epoch_loss = epoch_loss / len(dataloader)
    if (e + 1) % save_every == 0 and stage == "eval":
        model.eval()
        state = {'epoch': str(e + 1), 'model_weights': model.state_dict(),
                 'optimizer_state': optimizer.state_dict(), 'lrate': lrate, 'anchor_generator': anchors,
                 'model_name': model_name}
        prefix = "covid_ct_mask_net" if model_name is None else model_name
        # Create the target directory so the save cannot fail on a missing folder.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        torch.save(state, os.path.join(save_dir, prefix + "_ckpt_" + str(e + 1) + ".pth"))
        # model.eval() above changed the mode; switch back for the next epoch.
        utils.set_to_train_mode(model)
    return epoch_loss
190
191
192
# run the training
193
if __name__ == '__main__':
    # Read the train/validation configuration and start training.
    config_train = config_classifier.get_config_pars_classifier("trainval")
    if config_train.pretrained_classification_model is None and config_train.pretrained_segmentation_model is None:
        print("You must have at least one pretrained model!")
        # Missing checkpoint is a failure: report it with a non-zero exit status.
        sys.exit(1)
    else:
        main(config_train, step)