--- a +++ b/train_segmentation.py @@ -0,0 +1,218 @@ +# Mask R-CNN model for lesion segmentation in chest CT scans +# Torchvision detection package is locally re-implemented +# by Alex Ter-Sarkisov@City, University of London +# alex.ter-sarkisov@city.ac.uk +# 2020 +import argparse +import time +import pickle +import torch +import torchvision +import numpy as np +import os, sys +import cv2 +import models.mask_net as mask_net +from models.mask_net.rpn_segmentation import AnchorGenerator +from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead +from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor +from torch.utils import data +import torch.utils as utils +import datasets.dataset_segmentation as dataset +from PIL import Image as PILImage +import matplotlib.pyplot as plt +import matplotlib.patches as patches +from matplotlib.patches import Rectangle +import utils +import config_segmentation as config + + +# main method +def main(config, main_step): + devices = ['cpu', 'cuda'] + mask_classes = ['both', 'ggo', 'merge'] + truncation_levels = ['0','1','2'] + backbones = ['resnet50', 'resnet34', 'resnet18'] + assert config.backbone_name in backbones + assert config.mask_type in mask_classes + assert config.truncation in truncation_levels + + # import arguments from the config file + start_epoch, model_name, use_pretrained_resnet_backbone, num_epochs, save_dir, train_data_dir, val_data_dir, imgs_dir, gt_dir, batch_size, device, save_every, lrate, rpn_nms, mask_type, backbone_name, truncation = \ + config.start_epoch, config.model_name, config.use_pretrained_resnet_backbone, config.num_epochs, config.save_dir, \ + config.train_data_dir, config.val_data_dir, config.imgs_dir, config.gt_dir, config.batch_size, config.device, config.save_every, config.lrate, config.rpn_nms_th, config.mask_type, config.backbone_name, config.truncation + + assert device in devices + if not save_dir in os.listdir('.'): + os.mkdir(save_dir) + + if batch_size > 1: + print("The model was implemented for batch size of one") + if device == 'cuda' and torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + + print(device) + + # Load the weights if provided + if config.pretrained_model is not None: + pretrained_model = torch.load(config.pretrained_model, map_location = device) + use_pretrained_resnet_backbone = False + else: + pretrained_model=None + torch.manual_seed(time.time()) + ############################################################################################## + # DATASETS + DATALOADERS + # Alex: could be added in the config file in the future + # parameters for the dataset + dataset_covid_pars_train = {'stage': 'train', 'gt': os.path.join(train_data_dir, gt_dir), + 'data': os.path.join(train_data_dir, imgs_dir), 'mask_type':mask_type, 'ignore_small':True} + datapoint_covid_train = dataset.CovidCTData(**dataset_covid_pars_train) + + dataset_covid_pars_eval = {'stage': 'eval', 'gt': os.path.join(val_data_dir, gt_dir), + 'data': os.path.join(val_data_dir, imgs_dir), 'mask_type':mask_type, 'ignore_small':True} + datapoint_covid_eval = dataset.CovidCTData(**dataset_covid_pars_eval) + ############################################################################################### + dataloader_covid_pars_train = {'shuffle': True, 'batch_size': batch_size} + dataloader_covid_train = data.DataLoader(datapoint_covid_train, **dataloader_covid_pars_train) + # + dataloader_covid_pars_eval = {'shuffle': False, 'batch_size': batch_size} + dataloader_covid_eval = data.DataLoader(datapoint_covid_eval, **dataloader_covid_pars_eval) + ############################################################################################### + # MASK R-CNN model + # Alex: these settings could also be added to the config + if mask_type == "both": + n_c = 3 + else: + n_c = 2 + maskrcnn_args = {'min_size': 512, 'max_size': 1024, 'rpn_batch_size_per_image': 256, 'rpn_positive_fraction': 0.75, + 'box_positive_fraction': 0.75, 'box_fg_iou_thresh': 0.75, 'box_bg_iou_thresh': 0.5, + 'num_classes': None, 'box_batch_size_per_image': 256, 'rpn_nms_thresh': rpn_nms} + + # Alex: for Ground glass opacity and consolidatin segmentation + # many small anchors + # use all outputs of FPN + # IMPORTANT!! For the pretrained weights, this determines the size of the anchor layer in RPN!!!! + # pretrained model must have anchors + if pretrained_model is None: + anchor_generator = AnchorGenerator( + sizes=tuple([(2, 4, 8, 16, 32) for r in range(5)]), + aspect_ratios=tuple([(0.1, 0.25, 0.5, 1, 1.5, 2) for rh in range(5)])) + else: + print("Loading the anchor generator") + sizes = pretrained_model['anchor_generator'].sizes + aspect_ratios = pretrained_model['anchor_generator'].aspect_ratios + anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios) + print(anchor_generator, anchor_generator.num_anchors_per_location()) + # num_classes:3 (1+2) + # in_channels + # 256: number if channels from FPN + # For the ResNet50+FPN: keep the torchvision architecture, but with 128 features + # For lightweights models: re-implement MaskRCNNHeads with a single layer + box_head = TwoMLPHead(in_channels=256*7*7,representation_size=128) + if backbone_name == 'resnet50': + maskrcnn_heads = None + box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c) + mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c) + else: + #Backbone->FPN->boxhead->boxpredictor + box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c) + maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1) + mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c) + + maskrcnn_args['box_head'] = box_head + maskrcnn_args['rpn_anchor_generator'] = anchor_generator + maskrcnn_args['mask_head'] = maskrcnn_heads + maskrcnn_args['mask_predictor'] = mask_predictor + maskrcnn_args['box_predictor'] = box_predictor + # Instantiate the segmentation model + maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation, pretrained_backbone=use_pretrained_resnet_backbone, **maskrcnn_args) + # pretrained? + print(maskrcnn_model.backbone.out_channels) + if pretrained_model is not None: + print("Loading pretrained weights") + maskrcnn_model.load_state_dict(pretrained_model['model_weights']) + if pretrained_model['epoch']: + start_epoch = int(pretrained_model['epoch'])+1 + if 'model_name' in pretrained_model.keys(): + model_name = str(pretrained_model['model_name']) + + # Set to training mode + print(maskrcnn_model) + maskrcnn_model.train().to(device) + + optimizer_pars = {'lr': lrate, 'weight_decay': 1e-3} + optimizer = torch.optim.Adam(list(maskrcnn_model.parameters()), **optimizer_pars) + if pretrained_model is not None and 'optimizer_state' in pretrained_model.keys(): + optimizer.load_state_dict(pretrained_model['optimizer_state']) + + start_time = time.time() + if start_epoch>0: + num_epochs += start_epoch + print("Start training, epoch = {:d}".format(start_epoch)) + for e in range(start_epoch, num_epochs): + train_loss_epoch = main_step("train", e, dataloader_covid_train, optimizer, device, maskrcnn_model, save_every, + lrate, model_name, None, None) + eval_loss_epoch = main_step("eval", e, dataloader_covid_eval, optimizer, device, maskrcnn_model, save_every, lrate, model_name, anchor_generator, save_dir) + print( + "Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".format(e, train_loss_epoch, eval_loss_epoch)) + end_time = time.time() + print("Training took {0:.1f} seconds".format(end_time - start_time)) + + +def step(stage, e, dataloader, optimizer, device, model, save_every, lrate, model_name, anchors, save_dir): + epoch_loss = 0 + for b in dataloader: + optimizer.zero_grad() + X, y = b + if device == torch.device('cuda'): + X, y['labels'], y['boxes'], y['masks'] = X.to(device), y['labels'].to(device), y['boxes'].to(device), y[ + 'masks'].to(device) + images = [im for im in X] + targets = [] + lab = {} + # THIS IS IMPORTANT!!!!! + # get rid of the first dimension (batch) + # IF you have >1 images, make another loop + # REPEAT: DO NOT USE BATCH DIMENSION + lab['boxes'] = y['boxes'].squeeze_(0) + lab['labels'] = y['labels'].squeeze_(0) + lab['masks'] = y['masks'].squeeze_(0) + if len(lab['boxes']) > 0 and len(lab['labels']) > 0 and len(lab['masks']) > 0: + targets.append(lab) + else: + pass + # avoid empty objects + if len(targets) > 0: + loss = model(images, targets) + total_loss = 0 + for k in loss.keys(): + total_loss += loss[k] + if stage == "train": + total_loss.backward() + optimizer.step() + else: + pass + epoch_loss += total_loss.clone().detach().cpu().numpy() + epoch_loss = epoch_loss / len(dataloader) + if not (e+1) % save_every and stage == "eval": + model.eval() + state = {'epoch': str(e+1), 'model_name':model_name, 'model_weights': model.state_dict(), + 'optimizer_state': optimizer.state_dict(), 'lrate': lrate, 'anchor_generator':anchors} + if model_name is None: + print(save_dir, "mrcnn_covid_segmentation_model_ckpt_" + str(e+1) + ".pth") + torch.save(state, os.path.join(save_dir, "mrcnn_covid_segmentation_model_ckpt_" + str(e+1) + ".pth")) + else: + torch.save(state, os.path.join(save_dir, model_name + "_ckpt_" + str(e+1) + ".pth")) + + model.train() + return epoch_loss + + +# run the training of the segmentation algoithm +if __name__ == '__main__': + config_train = config.get_config_pars("trainval") + main(config_train, step) + + +