# Mask R-CNN model for lesion segmentation in chest CT scans
# Torchvision detection package is locally re-implemented
# by Alex Ter-Sarkisov@City, University of London
# alex.ter-sarkisov@city.ac.uk
# 2020
import argparse
import time
import pickle
import torch
import torchvision
import numpy as np
import os, sys
import cv2
import models.mask_net as mask_net
from models.mask_net.rpn_segmentation import AnchorGenerator
from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
from models.mask_net.covid_mask_net import MaskRCNNHeads, MaskRCNNPredictor
from torch.utils import data
import torch.utils as utils
import datasets.dataset_segmentation as dataset
from PIL import Image as PILImage
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import utils
import config_segmentation as config
# main method
def main(config, main_step):
    """Build the COVID-CT datasets, the Mask R-CNN model and the optimizer, then train.

    Args:
        config: parsed configuration object (see config_segmentation.py); must expose
            the attributes unpacked below (start_epoch, model_name, ..., truncation).
        main_step: callable that runs a single epoch; invoked as
            main_step(stage, epoch, dataloader, optimizer, device, model,
                      save_every, lrate, model_name, anchors, save_dir).

    Side effects: creates `save_dir` if missing, prints progress, and (via
    `main_step`) writes model checkpoints.
    """
    devices = ['cpu', 'cuda']
    mask_classes = ['both', 'ggo', 'merge']
    truncation_levels = ['0', '1', '2']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    assert config.backbone_name in backbones
    assert config.mask_type in mask_classes
    assert config.truncation in truncation_levels
    # import arguments from the config file
    start_epoch, model_name, use_pretrained_resnet_backbone, num_epochs, save_dir, train_data_dir, val_data_dir, \
    imgs_dir, gt_dir, batch_size, device, save_every, lrate, rpn_nms, mask_type, backbone_name, truncation = \
        config.start_epoch, config.model_name, config.use_pretrained_resnet_backbone, config.num_epochs, config.save_dir, \
        config.train_data_dir, config.val_data_dir, config.imgs_dir, config.gt_dir, config.batch_size, config.device, \
        config.save_every, config.lrate, config.rpn_nms_th, config.mask_type, config.backbone_name, config.truncation
    assert device in devices
    # Fix: the original membership test against os.listdir('.') fails for nested
    # paths and races with concurrent runs; makedirs(exist_ok=True) is idempotent.
    os.makedirs(save_dir, exist_ok=True)
    if batch_size > 1:
        print("The model was implemented for batch size of one")
    if device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(device)
    # Load the checkpoint if provided; a checkpoint supersedes the ImageNet backbone
    if config.pretrained_model is not None:
        pretrained_model = torch.load(config.pretrained_model, map_location=device)
        use_pretrained_resnet_backbone = False
    else:
        pretrained_model = None
    # Fix: torch.manual_seed is documented to take an int; cast explicitly
    torch.manual_seed(int(time.time()))
    ##############################################################################################
    # DATASETS + DATALOADERS
    # Alex: could be added in the config file in the future
    # parameters for the dataset
    dataset_covid_pars_train = {'stage': 'train', 'gt': os.path.join(train_data_dir, gt_dir),
                                'data': os.path.join(train_data_dir, imgs_dir), 'mask_type': mask_type,
                                'ignore_small': True}
    datapoint_covid_train = dataset.CovidCTData(**dataset_covid_pars_train)
    dataset_covid_pars_eval = {'stage': 'eval', 'gt': os.path.join(val_data_dir, gt_dir),
                               'data': os.path.join(val_data_dir, imgs_dir), 'mask_type': mask_type,
                               'ignore_small': True}
    datapoint_covid_eval = dataset.CovidCTData(**dataset_covid_pars_eval)
    ###############################################################################################
    dataloader_covid_pars_train = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_train = data.DataLoader(datapoint_covid_train, **dataloader_covid_pars_train)
    #
    dataloader_covid_pars_eval = {'shuffle': False, 'batch_size': batch_size}
    dataloader_covid_eval = data.DataLoader(datapoint_covid_eval, **dataloader_covid_pars_eval)
    ###############################################################################################
    # MASK R-CNN model
    # Alex: these settings could also be added to the config
    # number of classes: background + GGO + consolidation ('both'), else background + one lesion class
    if mask_type == "both":
        n_c = 3
    else:
        n_c = 2
    maskrcnn_args = {'min_size': 512, 'max_size': 1024, 'rpn_batch_size_per_image': 256, 'rpn_positive_fraction': 0.75,
                     'box_positive_fraction': 0.75, 'box_fg_iou_thresh': 0.75, 'box_bg_iou_thresh': 0.5,
                     'num_classes': None, 'box_batch_size_per_image': 256, 'rpn_nms_thresh': rpn_nms}
    # Alex: for Ground glass opacity and consolidation segmentation
    # many small anchors
    # use all outputs of FPN
    # IMPORTANT!! For the pretrained weights, this determines the size of the anchor layer in RPN!!!!
    # pretrained model must have anchors
    if pretrained_model is None:
        anchor_generator = AnchorGenerator(
            sizes=tuple([(2, 4, 8, 16, 32) for r in range(5)]),
            aspect_ratios=tuple([(0.1, 0.25, 0.5, 1, 1.5, 2) for rh in range(5)]))
    else:
        print("Loading the anchor generator")
        sizes = pretrained_model['anchor_generator'].sizes
        aspect_ratios = pretrained_model['anchor_generator'].aspect_ratios
        anchor_generator = AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios)
    print(anchor_generator, anchor_generator.num_anchors_per_location())
    # num_classes:3 (1+2)
    # in_channels
    # 256: number of channels from FPN
    # For the ResNet50+FPN: keep the torchvision architecture, but with 128 features
    # For lightweight models: re-implement MaskRCNNHeads with a single layer
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
    if backbone_name == 'resnet50':
        maskrcnn_heads = None
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        mask_predictor = MaskRCNNPredictor(in_channels=256, dim_reduced=256, num_classes=n_c)
    else:
        # Backbone->FPN->boxhead->boxpredictor
        box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)
        maskrcnn_heads = MaskRCNNHeads(in_channels=256, layers=(128,), dilation=1)
        mask_predictor = MaskRCNNPredictor(in_channels=128, dim_reduced=128, num_classes=n_c)
    maskrcnn_args['box_head'] = box_head
    maskrcnn_args['rpn_anchor_generator'] = anchor_generator
    maskrcnn_args['mask_head'] = maskrcnn_heads
    maskrcnn_args['mask_predictor'] = mask_predictor
    maskrcnn_args['box_predictor'] = box_predictor
    # Instantiate the segmentation model
    maskrcnn_model = mask_net.maskrcnn_resnet_fpn(backbone_name, truncation,
                                                  pretrained_backbone=use_pretrained_resnet_backbone,
                                                  **maskrcnn_args)
    # pretrained?
    print(maskrcnn_model.backbone.out_channels)
    if pretrained_model is not None:
        print("Loading pretrained weights")
        maskrcnn_model.load_state_dict(pretrained_model['model_weights'])
        # Fix: .get avoids a KeyError on checkpoints that lack these keys
        if pretrained_model.get('epoch'):
            start_epoch = int(pretrained_model['epoch']) + 1
        if 'model_name' in pretrained_model:
            model_name = str(pretrained_model['model_name'])
    # Set to training mode
    print(maskrcnn_model)
    maskrcnn_model.train().to(device)
    optimizer_pars = {'lr': lrate, 'weight_decay': 1e-3}
    optimizer = torch.optim.Adam(list(maskrcnn_model.parameters()), **optimizer_pars)
    # resume the optimizer state (momentum buffers etc.) when available
    if pretrained_model is not None and 'optimizer_state' in pretrained_model:
        optimizer.load_state_dict(pretrained_model['optimizer_state'])
    start_time = time.time()
    if start_epoch > 0:
        num_epochs += start_epoch
    print("Start training, epoch = {:d}".format(start_epoch))
    for e in range(start_epoch, num_epochs):
        # checkpoints are only written during the eval stage, hence the
        # anchors/save_dir arguments are passed only there
        train_loss_epoch = main_step("train", e, dataloader_covid_train, optimizer, device, maskrcnn_model,
                                     save_every, lrate, model_name, None, None)
        eval_loss_epoch = main_step("eval", e, dataloader_covid_eval, optimizer, device, maskrcnn_model,
                                    save_every, lrate, model_name, anchor_generator, save_dir)
        print(
            "Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".format(e, train_loss_epoch, eval_loss_epoch))
    end_time = time.time()
    print("Training took {0:.1f} seconds".format(end_time - start_time))
def step(stage, e, dataloader, optimizer, device, model, save_every, lrate, model_name, anchors, save_dir):
    """Run one epoch over `dataloader` and return the mean total loss.

    Args:
        stage: "train" (backprop + optimizer step) or "eval" (loss only).
        e: zero-based epoch index.
        dataloader: yields (X, y) with X of shape (1, C, H, W) and y a dict with
            'labels', 'boxes', 'masks' — batch size must be 1 (see main()).
        optimizer: optimizer over the model's parameters.
        device: torch.device the data should be moved to.
        model: Mask R-CNN model in training mode (returns a dict of losses).
        save_every: checkpoint period in epochs (checked only in the eval stage).
        lrate, model_name, anchors, save_dir: stored in / used for the checkpoint.

    Returns:
        Mean summed loss per batch (float).
    """
    epoch_loss = 0.0
    for batch in dataloader:
        optimizer.zero_grad()
        X, y = batch
        if device == torch.device('cuda'):
            X, y['labels'], y['boxes'], y['masks'] = X.to(device), y['labels'].to(device), y['boxes'].to(device), y[
                'masks'].to(device)
        images = [im for im in X]
        # THIS IS IMPORTANT!!!!!
        # The model expects per-image targets without a batch dimension.
        # Fix: use the non-in-place squeeze so tensors owned by the
        # dataset/dataloader are not mutated (original used squeeze_).
        lab = {'boxes': y['boxes'].squeeze(0),
               'labels': y['labels'].squeeze(0),
               'masks': y['masks'].squeeze(0)}
        targets = []
        if len(lab['boxes']) > 0 and len(lab['labels']) > 0 and len(lab['masks']) > 0:
            targets.append(lab)
        # avoid empty objects: skip images with no ground-truth instances
        if targets:
            loss_dict = model(images, targets)
            total_loss = sum(loss_dict.values())
            if stage == "train":
                total_loss.backward()
                optimizer.step()
            # Fix: .item() extracts the scalar directly instead of
            # clone().detach().cpu().numpy()
            epoch_loss += total_loss.item()
    epoch_loss = epoch_loss / len(dataloader)
    # checkpoint at the end of every save_every-th epoch, only after evaluation
    if not (e + 1) % save_every and stage == "eval":
        model.eval()
        state = {'epoch': str(e + 1), 'model_name': model_name, 'model_weights': model.state_dict(),
                 'optimizer_state': optimizer.state_dict(), 'lrate': lrate, 'anchor_generator': anchors}
        if model_name is None:
            ckpt_name = "mrcnn_covid_segmentation_model_ckpt_" + str(e + 1) + ".pth"
            print(save_dir, ckpt_name)
        else:
            ckpt_name = model_name + "_ckpt_" + str(e + 1) + ".pth"
        torch.save(state, os.path.join(save_dir, ckpt_name))
        model.train()
    return epoch_loss
# run the training of the segmentation algorithm
if __name__ == '__main__':
    # Build the train/val configuration and launch training with the epoch runner.
    trainval_config = config.get_config_pars("trainval")
    main(trainval_config, step)