Diff of /train_classifier.py [000000] .. [a49583]

Switch to side-by-side view

--- /dev/null
+++ b/train_classifier.py
@@ -0,0 +1,199 @@
+# COVID-CT-Mask-Net
+# I re-implemented Torchvision's detection library (Faster and Mask R-CNN) as a classifier
+# Alex Ter-Sarkisov @ City, University of London
+# alex.ter-sarkisov@city.ac.uk
+#
+import os
+import pickle
+import sys
+import sys
+import time
+
+import config_classifier
+import cv2
+import datasets.dataset_classifier as dataset
+# IMPORT LOCAL IMPLEMENTATION OF TORCHVISION'S DETECTION LIBRARY
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+import utils
+from PIL import Image as PILImage
+# IMPORT LOCAL IMPLEMENTATION OF TORCHVISION'S DETECTION LIBRARY
+# Faster R-CNN interface
+import models.mask_net as mask_net
+from models.mask_net.faster_rcnn import FastRCNNPredictor, TwoMLPHead
+from models.mask_net.rpn import AnchorGenerator
+from torch.utils import data
+from torchvision import transforms
+
+
def main(config, main_step):
    """Build the COVID-CT-Mask-Net classifier from a config and run training.

    Args:
        config: argument namespace produced by config_classifier.
        main_step: callable that executes one epoch (the step() function below).
    """
    # torch.manual_seed requires an integer seed; time.time() returns a float,
    # which raises a TypeError in current PyTorch releases — cast explicitly.
    torch.manual_seed(int(time.time()))
    start_time = time.time()
    # Validate the coarse-grained config choices up front.
    devices = ['cpu', 'cuda']
    backbones = ['resnet50', 'resnet34', 'resnet18']
    truncation_levels = ['0', '1', '2']
    assert config.device in devices
    assert config.backbone_name in backbones
    assert config.truncation in truncation_levels

    # Copy the config values into locals, one per line for readability.
    start_epoch = config.start_epoch
    pretrained_classifier = config.pretrained_classification_model
    pretrained_segment = config.pretrained_segmentation_model
    model_name = config.model_name
    num_epochs = config.num_epochs
    save_dir = config.save_dir
    train_data_dir = config.train_data_dir
    val_data_dir = config.val_data_dir
    batch_size = config.batch_size
    device = config.device
    save_every = config.save_every
    lrate = config.lrate
    rpn_nms = config.rpn_nms_th
    roi_nms = config.roi_nms_th
    backbone_name = config.backbone_name
    truncation = config.truncation
    roi_batch_size = config.roi_batch_size
    n_c = config.num_classes
    s_features = config.s_features

    # Decide which checkpoint to start from; the classifier wins a tie.
    if pretrained_classifier is not None:
        if pretrained_segment is not None:
            print("Not clear which model to use, switching to the classifier")
        pretrained_model = pretrained_classifier
    else:
        pretrained_model = pretrained_segment

    if device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    ##############################################################################################
    # DATASETS+DATALOADERS
    # Alex: could be added in the config file in the future
    # 512x512 is the recommended image size input
    dataset_covid_pars_train_cl = {'stage': 'train', 'data': train_data_dir, 'img_size': (512, 512)}
    datapoint_covid_train_cl = dataset.COVID_CT_DATA(**dataset_covid_pars_train_cl)
    #
    dataset_covid_pars_eval_cl = {'stage': 'eval', 'data': val_data_dir, 'img_size': (512, 512)}
    datapoint_covid_eval_cl = dataset.COVID_CT_DATA(**dataset_covid_pars_eval_cl)
    #
    dataloader_covid_pars_train_cl = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_train_cl = data.DataLoader(datapoint_covid_train_cl, **dataloader_covid_pars_train_cl)
    #
    dataloader_covid_pars_eval_cl = {'shuffle': True, 'batch_size': batch_size}
    dataloader_covid_eval_cl = data.DataLoader(datapoint_covid_eval_cl, **dataloader_covid_pars_eval_cl)
    #
    ##### LOAD PRETRAINED WEIGHTS FROM MASK R-CNN MODEL
    # Full path to a checkpoint whose keys include 'model_weights' and
    # 'anchor_generator'.
    ckpt = torch.load(pretrained_model, map_location=device)
    # box_score_thresh is negative so no RoI is filtered out by score;
    # both NMS thresholds are set high (e.g. 0.75) to keep adjacent RoIs;
    # box_detections_per_img is the RoI batch size fed to the classifier head.
    covid_mask_net_args = {'num_classes': None, 'min_size': 512, 'max_size': 1024,
                           'box_detections_per_img': roi_batch_size,
                           'box_nms_thresh': roi_nms, 'box_score_thresh': -0.01,
                           'rpn_nms_thresh': rpn_nms}

    # Copy the anchor generator parameters into a fresh instance to avoid
    # implementations' clash.
    sizes = ckpt['anchor_generator'].sizes
    aspect_ratios = ckpt['anchor_generator'].aspect_ratios
    anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    # Box head input: 256 FPN channels x 7x7 RoIAlign output.
    box_head = TwoMLPHead(in_channels=256 * 7 * 7, representation_size=128)
    box_predictor = FastRCNNPredictor(in_channels=128, num_classes=n_c)

    covid_mask_net_args['rpn_anchor_generator'] = anchor_generator
    covid_mask_net_args['box_predictor'] = box_predictor
    covid_mask_net_args['box_head'] = box_head
    covid_mask_net_args['s_representation_size'] = s_features
    # Instantiate the model.
    covid_mask_net_model = mask_net.fasterrcnn_resnet_fpn(backbone_name, truncation, **covid_mask_net_args)
    # Parameters that will actually be optimized (filled by utils below).
    trained_pars = []
    if pretrained_classifier is None:
        # Weights come from the segmentation model: copy only the parameters
        # whose names match; the classification head is trained from scratch.
        for _n, _par in covid_mask_net_model.state_dict().items():
            if _n in ckpt['model_weights']:
                print('Loading parameter', _n)
                _par.copy_(ckpt['model_weights'][_n])
    else:
        # Weights come from a classifier checkpoint: also resume the epoch
        # counter and the model name if they were saved.
        covid_mask_net_model.load_state_dict(ckpt['model_weights'])
        if 'epoch' in ckpt.keys():
            start_epoch = int(ckpt['epoch']) + 1
        if 'model_name' in ckpt.keys():
            model_name = str(ckpt['model_name'])

    # Evaluation mode, no labels! This avoids triggering the 'training'
    # branch of Mask R-CNN; utils then enables gradients only for the
    # parameters that should be trained and sets up those modules.
    covid_mask_net_model.eval()
    utils.switch_model_on(covid_mask_net_model, ckpt, trained_pars)
    utils.set_to_train_mode(covid_mask_net_model)
    print(covid_mask_net_model)
    covid_mask_net_model = covid_mask_net_model.to(device)
    total_trained_pars = sum(x.numel() for x in trained_pars)
    print("Total trained pars {0:d}".format(total_trained_pars))
    optimizer_pars = {'lr': lrate, 'weight_decay': 1e-3}
    optimizer = torch.optim.Adam(trained_pars, **optimizer_pars)
    if pretrained_classifier is not None and 'optimizer_state' in ckpt.keys():
        optimizer.load_state_dict(ckpt['optimizer_state'])

    # When resuming, keep the number of *remaining* epochs equal to num_epochs.
    if start_epoch > 0:
        num_epochs += start_epoch
    print("Start training, epoch = {:d}".format(start_epoch))
    for e in range(start_epoch, num_epochs):
        train_loss_epoch = main_step("train", e, dataloader_covid_train_cl, optimizer, device, covid_mask_net_model,
                                     save_every, lrate, model_name, None, None)
        eval_loss_epoch = main_step("eval", e, dataloader_covid_eval_cl, optimizer, device, covid_mask_net_model,
                                    save_every, lrate, model_name, anchor_generator, save_dir)
        print(
            "Epoch {0:d}: train loss = {1:.3f}, validation loss = {2:.3f}".format(e, train_loss_epoch, eval_loss_epoch))
    end_time = time.time()
    print("Training took {0:.1f} seconds".format(end_time - start_time))
+
+
def step(stage, e, dataloader, optimizer, device, model, save_every, lrate, model_name, anchors, save_dir):
    """Run one epoch of training or evaluation; return the mean batch loss.

    Args:
        stage: "train" (backprop + optimizer step) or "eval" (loss only).
        e: zero-based epoch index, used for checkpoint naming.
        dataloader: yields (X, y) batches of images and binary label vectors.
        optimizer, device, model: the usual training objects.
        save_every: checkpoint period in epochs (saved on eval stages only).
        lrate, model_name, anchors, save_dir: recorded in the checkpoint.

    Returns:
        The epoch loss averaged over batches (float).
    """
    is_train = stage == "train"
    epoch_loss = 0.0
    for batch in dataloader:
        optimizer.zero_grad()
        X, y = batch
        if device == torch.device('cuda'):
            X, y = X.to(device), y.to(device)
        # some batches are less than batch_size
        batch_s = X.size()[0]
        # Skip autograd bookkeeping entirely during evaluation.
        with torch.set_grad_enabled(is_train):
            batch_scores = []
            # input all images in the batch into COVID-Mask-Net to get B scores
            for img_idx in range(batch_s):  # renamed: 'id' shadowed the builtin
                image = [X[img_idx]]  # remove the batch dimension
                predict_scores = model(image)
                batch_scores.append(predict_scores[0]['final_scores'])
            # batchify scores/image and compute binary cross-entropy loss
            batch_scores = torch.stack(batch_scores)
            batch_loss = F.binary_cross_entropy_with_logits(batch_scores, y)
        if is_train:
            batch_loss.backward()
            optimizer.step()
        # .item() extracts the Python float directly (no clone/detach/numpy).
        epoch_loss += batch_loss.item()
    epoch_loss = epoch_loss / len(dataloader)
    # Checkpoint at the end of every save_every-th eval epoch.
    if not (e + 1) % save_every and stage == "eval":
        model.eval()
        state = {'epoch': str(e + 1), 'model_weights': model.state_dict(),
                 'optimizer_state': optimizer.state_dict(), 'lrate': lrate, 'anchor_generator': anchors,
                 'model_name': model_name}
        if model_name is None:
            torch.save(state, os.path.join(save_dir, "covid_ct_mask_net_ckpt_" + str(e + 1) + ".pth"))
        else:
            torch.save(state, os.path.join(save_dir, model_name + "_ckpt_" + str(e + 1) + ".pth"))
        utils.set_to_train_mode(model)
    return epoch_loss
+
+
# run the training
if __name__ == '__main__':
    config_train = config_classifier.get_config_pars_classifier("trainval")
    if config_train.pretrained_classification_model is None and config_train.pretrained_segmentation_model is None:
        # Misconfiguration is an error: report on stderr and exit with a
        # non-zero status (the original exited 0, signalling success).
        print("You must have at least one pretrained model!", file=sys.stderr)
        sys.exit(1)
    main(config_train, step)