--- a +++ b/Evidential_segmentation/TRAINING-RBF.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python +# coding: utf-8 + +# In[]: + +########################### IMPORTS ############################################# + +import torch +import torch.nn as nn +import torch.nn.functional as F +import monai +from monai.networks.nets import UNet,UNet_RBF +from monai.networks.utils import one_hot +from monai.transforms import ( + AsDiscrete, + AddChanneld, + AsChannelFirstd, + Compose, + LoadNiftid, + RandCropByPosNegLabeld, + RandRotate90d, + ScaleIntensityd, + ToTensord, +) +from monai.visualize import plot_2d_or_3d_image +from monai.data.utils import list_data_collate, worker_init_fn +from monai.inferers import sliding_window_inference +from monai.metrics import DiceMetric +from monai.metrics import compute_meandice +from torch.autograd import Variable +import torch.optim as optim +from torch.optim import lr_scheduler +from torch.utils.data import DataLoader +from torch.utils.data import Dataset +import torch.backends.cudnn as cudnn +import torchvision +from torchvision import datasets, models, transforms +import csv +import time +import SimpleITK as sitk +from os.path import splitext,basename +import random +from glob import glob + +import matplotlib.pyplot as plt +from matplotlib.backends.backend_pdf import PdfPages +from copy import copy +import os +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +#from global_tools.tools import display_loading_bar +from class_modalities.transforms import LoadNifti, Roi2Mask, ResampleReshapeAlign, Sitk2Numpy, ConcatModality +from monai.utils import first, set_determinism + + + +################## +train_transforms = Compose( + [ # read img + meta info + LoadNifti(keys=["pet_img", "ct_img", "mask_img"]), + Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']), + # user can also add other random transforms + + # Prepare for neural network + ConcatModality(keys=['pet_img', 'ct_img']), + AddChanneld(keys=["mask_img"]), # Add channel to the first axis + ToTensord(keys=["image", "mask_img"]), + ]) +# without data augmentation for validation +val_transforms = Compose( + [ # read img + meta info + LoadNifti(keys=["pet_img", "ct_img", "mask_img"]), + Sitk2Numpy(keys=['pet_img', 'ct_img', 'mask_img']), + + # Prepare for neural network + ConcatModality(keys=['pet_img', 'ct_img']), + AddChanneld(keys=["mask_img"]), # Add channel to the first axis + ToTensord(keys=["image", "mask_img"]), + ]) + + +##################loading data############################### + +device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu") + + +base_path="/home/lab/hualing/2.5_SUV_dilation" #####the path you put the pre_processed data. +pet_path = base_path + '/' + 'pet_test' +ct_path = base_path + '/' + 'ct_test' +mask_path = base_path + '/' + 'pet_test_mask' +PET_ids = sorted(glob(os.path.join(pet_path, '*pet.nii'))) +CT_ids = sorted(glob(os.path.join(ct_path, '*ct.nii'))) +MASK_ids = sorted(glob(os.path.join(mask_path, '*mask.nii'))) + +data_dicts= zip(PET_ids, CT_ids, MASK_ids) +files=list(data_dicts) + + +train_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[:138]] +val_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[138:156]] +test_files = [{"pet_img": PET, "ct_img": CT, 'mask_img': MASK} for PET, CT, MASK in files[156:]] + + +train_ds = monai.data.Dataset(data=train_files,transform=train_transforms) +val_ds = monai.data.Dataset(data=val_files,transform=val_transforms) +test_ds = monai.data.Dataset(data=test_files,transform=val_transforms) + +train_loader = DataLoader( + train_ds, + batch_size=3, + shuffle=True, + num_workers=4, + collate_fn=list_data_collate, + pin_memory=torch.cuda.is_available(),) + +val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) +test_loader = DataLoader(test_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate) + +###################################### define model####################################### +trained_model_path="./pre-trained_model/best_metric_model_unet.pth" ####path to the pre-trained UNET model +model = UNet_RBF( + dimensions=3, # 3D + in_channels=2, + out_channels=2, + kernel_size=5, + channels=(8,16, 32, 64,128), + strides=(2, 2, 2, 2), + num_res_units=2,).to(device) +model_dict = model.state_dict() +pre_dict = torch.load(trained_model_path) +pre_dict = {k: v for k, v in pre_dict.items() if k in model_dict} +model_dict.update(pre_dict) + +model.load_state_dict(model_dict) + +####code to make sure only the parameters from ENN are optimized +for name, param in model.named_parameters(): + if param.requires_grad==True: + print(name) +####code to make sure only the parameters from ENN are optimized +params = filter(lambda p: p.requires_grad, model.parameters()) +optimizer = torch.optim.Adam(params, 1e-2) +dice_metric = monai.metrics.DiceMetric( include_background=False,reduction="mean") +scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,'min',patience=10) + +loss_function = monai.losses.DiceLoss(include_background=False,softmax=False,squared_pred=True,to_onehot_y=True) + + +# TODO : generate a learning rate scheduler + +val_interval = 1 +best_metric = -1 +best_metric_epoch = -1 +epoch_loss_values = list() +metric_values = list() + +writer = SummaryWriter() +post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2) +post_label = AsDiscrete(to_onehot=True, n_classes=2) + + +############################################# training and validation############################################# + +for epoch in range(100): + print("-" * 10) + print(f"epoch {epoch + 1}/{100}") + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step += 1 + inputs, labels = batch_data["image"].to(device), batch_data["mask_img"].to(device) + optimizer.zero_grad() + pm,mass = model(inputs) + + loss=loss_function(pm, labels) + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + epoch_len = len(train_ds) // train_loader.batch_size + print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}") + writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step) + epoch_loss /= step + epoch_loss_values.append(epoch_loss) + print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") + + scheduler.step(epoch_loss) + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + metric_sum = 0.0 + metric_count = 0 + + for val_data in val_loader: + val_images, val_labels = val_data["image"].to(device), val_data["mask_img"].to(device) + pm,mass = model(val_images) + val_outputs=pm + output=pm + value = dice_metric(y_pred=val_outputs, y=val_labels) + + metric_count += len(value) + metric_sum += value.item() * len(value) + + metric = metric_sum / metric_count + metric_values.append(metric) + + if metric > best_metric: + best_metric = metric + best_metric_epoch = epoch + 1 + torch.save(model.state_dict(), "RBF_best_metric_model_segmentation3d_dict.pth") + print("saved new best metric model") + print( + "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format( + epoch + 1, metric, best_metric, best_metric_epoch + ) + ) +print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}") +writer.close() \ No newline at end of file