--- a +++ b/unet.py @@ -0,0 +1,526 @@ +import numpy as np +from matplotlib import pyplot as plt +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import data +from torchvision import transforms +import os +import random +from PIL import Image +import sklearn +from bme1312 import lab2 as lab +from torchvision import transforms as T +from torch.utils.data import TensorDataset, DataLoader +from torch.utils.tensorboard import SummaryWriter +from bme1312.evaluation import get_DC,get_accuracy +from torchvision import transforms + + +plt.rcParams['image.interpolation'] = 'nearest' +plt.rcParams['image.cmap'] = 'gray' +device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu') +print("Using device: ", device) + +def process_data(): + path="./cine_seg.npz" # input the path of cine_seg.npz in your environment + dataset=np.load(path,allow_pickle=True) + files=dataset.files + inputs=[] + labels=[] + for file in files: + inputs.append(dataset[file][0]) + labels.append(dataset[file][1]) + inputs = torch.Tensor(inputs) + labels = torch.Tensor(labels) + return inputs, labels + + + +inputs, labels = process_data() +inputs = inputs.unsqueeze(1) # Add channel dimension +labels = labels.unsqueeze(1) # Add channel dimension +print("inputs shape: ", inputs.shape) +print("labels shape: ", labels.shape) + + +def convert_to_multi_labels(label): + device = label.device + B, C, H, W = label.shape + new_tensor = torch.zeros((B, 3, H, W), device=device) + mask1 = (label >= 255).squeeze(1) + mask2 = ((label >= 170) & (label < 255-1)).squeeze(1) + mask3 = ((label >= 85) & (label < 170-1)).squeeze(1) + one = torch.ones(size= (B, H, W), device=device) + zero = torch.zeros(size=(B, H, W), device=device) + new_tensor[:, 0, :, :] = torch.where(mask1, one, zero) + new_tensor[:, 1, :, :] = torch.where(mask2, one, zero) + new_tensor[:, 2, :, :] = torch.where(mask3, one, zero) + return new_tensor + + +dataset = TensorDataset(inputs, labels) +batch_size = 32 +train_size = int(4/7 * len(dataset)) +val_size = int(1/7 * len(dataset)) +test_size = len(dataset) - train_size - val_size +train_set, val_set, test_set = torch.utils.data.random_split(dataset, [train_size, val_size, test_size]) + +dataloader_train = DataLoader(train_set, batch_size=batch_size, shuffle=True) +dataloader_val = DataLoader(val_set, batch_size=batch_size, shuffle=True) +dataloader_test = DataLoader(test_set, batch_size=batch_size, shuffle=True) + + +class DoubleConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(DoubleConv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.conv(x) + +class Down(nn.Module): + def __init__(self, in_channels, out_channels): + super(Down, self).__init__() + self.mpconv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.mpconv(x) + + +class Up(nn.Module): + def __init__(self, in_channels, out_channels, bilinear=True): + super(Up, self).__init__() + + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + else: + self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 2, stride=2) + + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + +class UNet(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=True, C_base=64): + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + self.C_base = C_base + + self.inc = DoubleConv(n_channels, C_base) + self.down1 = Down(C_base, C_base*2) + self.down2 = Down(C_base*2, C_base*4) + self.down3 = Down(C_base*4, C_base*8) + self.down4 = Down(C_base*8, C_base*8) + self.up1 = Up(C_base*16, C_base*4, bilinear) + self.up2 = Up(C_base*8, C_base*2, bilinear) + self.up3 = Up(C_base*4, C_base, bilinear) + self.up4 = Up(C_base*2, C_base, bilinear) + self.outc = nn.Conv2d(C_base, n_classes, 1) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + x = self.outc(x) + return x + +class MyBinaryCrossEntropy(object): + def __init__(self): + self.sigmoid = nn.Sigmoid() + self.bce = nn.BCELoss(reduction='mean') + + def __call__(self, pred_seg, seg_gt): + pred_seg_probs = self.sigmoid(pred_seg) + seg_gt_probs = convert_to_multi_labels(seg_gt) # convert to multi labels + loss = self.bce(pred_seg_probs, seg_gt_probs) + return loss + +import bme1312.lab2 as lab + +net = UNet(n_channels=1, n_classes=3, C_base=32) +optimizer = torch.optim.Adam(net.parameters(), lr=0.01) +lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95) + +solver = lab.Solver( + model=net, + optimizer=optimizer, + criterion=MyBinaryCrossEntropy(), + lr_scheduler=lr_scheduler, +) + +solver.train( + epochs=50, + data_loader=dataloader_train, + val_loader=dataloader_val, + save_path='./model_original_unet_bce.pth', + img_name='original_unet_bce' +) + + +# calculate the dice coefficient +dice_scores = [] +net.to(device) +for images, labels in dataloader_test: + images = images.to(device) + labels = labels.to(device) + preds = net(images) + preds = torch.sigmoid(preds) > 0.5 # Applying threshold to get binary output + labels = convert_to_multi_labels(labels) # Convert labels to multi-label format + + dice_rv = get_DC(preds[:, 0, :, :], labels[:, 0, :, :]) + dice_myo = get_DC(preds[:, 1, :, :], labels[:, 1, :, :]) + dice_lv = get_DC(preds[:, 2, :, :], labels[:, 2, :, :]) + + dice_scores.append((dice_rv, dice_myo, dice_lv)) + +mean_dice_rv = np.mean([score[0] for score in dice_scores]) +std_dice_rv = np.std([score[0] for score in dice_scores]) +mean_dice_myo = np.mean([score[1] for score in dice_scores]) +std_dice_myo = np.std([score[1] for score in dice_scores]) +mean_dice_lv = np.mean([score[2] for score in dice_scores]) +std_dice_lv = np.std([score[2] for score in dice_scores]) + +print(f'RV Dice Coefficient: Mean={mean_dice_rv}, SD={std_dice_rv}') +print(f'MYO Dice Coefficient: Mean={mean_dice_myo}, SD={std_dice_myo}') +print(f'LV Dice Coefficient: Mean={mean_dice_lv}, SD={std_dice_lv}') + + + +class Up_NoShortcut(nn.Module): + def __init__(self, in_channels, out_channels, bilinear=True): + super(Up_NoShortcut, self).__init__() + + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_channels, out_channels) + else: + self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 2, stride=2) + self.conv = DoubleConv(in_channels // 2, out_channels) + + + def forward(self, x1): + x1 = self.up(x1) + return self.conv(x1) + +class UNet_NoShortcut(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=True, C_base=64): + super(UNet_NoShortcut, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + self.C_base = C_base + + self.inc = DoubleConv(n_channels, C_base) + self.down1 = Down(C_base, C_base*2) + self.down2 = Down(C_base*2, C_base*4) + self.down3 = Down(C_base*4, C_base*8) + self.down4 = Down(C_base*8, C_base*8) + self.up1 = Up_NoShortcut(C_base*8, C_base*4, bilinear) + self.up2 = Up_NoShortcut(C_base*4, C_base*2, bilinear) + self.up3 = Up_NoShortcut(C_base*2, C_base, bilinear) + self.up4 = Up_NoShortcut(C_base, C_base, bilinear) + self.outc = nn.Conv2d(C_base, n_classes, 1) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5) + x = self.up2(x) + x = self.up3(x) + x = self.up4(x) + x = self.outc(x) + return x + +net_no_shortcut = UNet_NoShortcut(n_channels=1, n_classes=3, C_base=32) +optimizer = torch.optim.Adam(net_no_shortcut.parameters(), lr=0.01) +lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95) +solver_no_shortcut = lab.Solver( + model=net_no_shortcut, + optimizer=torch.optim.Adam(net_no_shortcut.parameters(), lr=0.01), + criterion=MyBinaryCrossEntropy(), + lr_scheduler=lr_scheduler +) + +solver_no_shortcut.train( + epochs=50, + data_loader=dataloader_train, + val_loader=dataloader_val, + save_path='./model_no_shortcut_unet.pth', + img_name='no_shortcut_unet' +) + +dice_scores_no_shortcut = [] + +net_no_shortcut.to(device) + +for images, labels in dataloader_test: + images = images.to(device) + labels = labels.to(device) + preds = net_no_shortcut(images) + preds = torch.sigmoid(preds) > 0.5 # Applying threshold to get binary output + labels = convert_to_multi_labels(labels) # Convert labels to multi-label format + + dice_rv_no_shortcut = get_DC(preds[:, 0, :, :], labels[:, 0, :, :]) + dice_myo_no_shortcut = get_DC(preds[:, 1, :, :], labels[:, 1, :, :]) + dice_lv_no_shortcut = get_DC(preds[:, 2, :, :], labels[:, 2, :, :]) + + dice_scores_no_shortcut.append((dice_rv_no_shortcut, dice_myo_no_shortcut, dice_lv_no_shortcut)) + +mean_dice_rv_no_shortcut = np.mean([score[0] for score in dice_scores_no_shortcut]) +std_dice_rv_no_shortcut = np.std([score[0] for score in dice_scores_no_shortcut]) +mean_dice_myo_no_shortcut = np.mean([score[1] for score in dice_scores_no_shortcut]) +std_dice_myo_no_shortcut = np.std([score[1] for score in dice_scores_no_shortcut]) +mean_dice_lv_no_shortcut = np.mean([score[2] for score in dice_scores_no_shortcut]) +std_dice_lv_no_shortcut = np.std([score[2] for score in dice_scores_no_shortcut]) + +print(f'RV Dice Coefficient Without Shortcut: Mean={mean_dice_rv_no_shortcut}, SD={std_dice_rv_no_shortcut}') +print(f'MYO Dice Coefficient Without Shortcut: Mean={mean_dice_myo_no_shortcut}, SD={std_dice_myo_no_shortcut}') +print(f'LV Dice Coefficient Without Shortcut: Mean={mean_dice_lv_no_shortcut}, SD={std_dice_lv_no_shortcut}') + + +transform = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.RandomRotation(10) + ]) + +class SegmentationDataset(data.Dataset): + def __init__(self, inputs, labels, transform=None): + self.inputs = inputs + self.labels = labels + self.transform = transform + + def __len__(self): + return len(self.inputs) + + def __getitem__(self, idx): + image = self.inputs[idx] + label = self.labels[idx] + + if self.transform: + all = torch.stack((image, label), dim = 0) + all = self.transform(all) + image = all[0] + label = all[1] + + return image, label + + +def extract_inputs_labels(dataset): + inputs = [] + labels = [] + for data in dataset: + input, label = data + inputs.append(input) + labels.append(label) + inputs = torch.stack(inputs) + labels = torch.stack(labels) + return inputs, labels + +inputs_train, labels_train = extract_inputs_labels(train_set) +inputs_val, labels_val = extract_inputs_labels(val_set) +inputs_test, labels_test = extract_inputs_labels(test_set) + +train_dataset = SegmentationDataset(inputs_train, labels_train, transform=transform) +val_dataset = SegmentationDataset(inputs_val, labels_val, transform=None) + + +dataloader_train_aug = data.DataLoader(train_dataset, batch_size=32, shuffle=True) +dataloader_val_aug = data.DataLoader(val_dataset, batch_size=32, shuffle=False) + +net_data_aug = UNet(n_channels=1, n_classes=3, C_base=32) +optimizer = torch.optim.Adam(net_data_aug.parameters(), lr=0.01) +lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95) + +solver = lab.Solver( + model=net_data_aug, + optimizer=optimizer, + criterion=MyBinaryCrossEntropy(), + lr_scheduler=lr_scheduler, +) + +solver.train( + epochs=50, + data_loader=dataloader_train_aug, + val_loader=dataloader_val_aug, + save_path='./model_original_unet_data_aug_acc.pth', + img_name='origianl_unet_data_aug_dice_acc' +) + +accuracy_scores = [] + +net_data_aug.to(device) + +for images, labels in dataloader_test: + images = images.to(device) + labels = labels.to(device) + preds = net_data_aug(images) + preds = torch.sigmoid(preds) > 0.5 # Applying threshold to get binary output + labels = convert_to_multi_labels(labels) # Convert labels to multi-label format + + accuracy_rv = get_accuracy(preds[:, 0, :, :], labels[:, 0, :, :]) + accuracy_myo = get_accuracy(preds[:, 1, :, :], labels[:, 1, :, :]) + accuracy_lv = get_accuracy(preds[:, 2, :, :], labels[:, 2, :, :]) + + accuracy_scores.append((accuracy_rv, accuracy_myo, accuracy_lv)) + +mean_accuracy_rv = np.mean([score[0] for score in accuracy_scores]) +std_accuracy_rv = np.std([score[0] for score in accuracy_scores]) +mean_accuracy_myo = np.mean([score[1] for score in accuracy_scores]) +std_accuracy_myo = np.std([score[1] for score in accuracy_scores]) +mean_accuracy_lv = np.mean([score[2] for score in accuracy_scores]) +std_accuracy_lv = np.std([score[2] for score in accuracy_scores]) + +print(f'RV Accuracy: Mean={mean_accuracy_rv}, SD={std_accuracy_rv}') +print(f'MYO Accuracy: Mean={mean_accuracy_myo}, SD={std_accuracy_myo}') +print(f'LV Accuracy: Mean={mean_accuracy_lv}, SD={std_accuracy_lv}') + + +dice_scores_data_aug = [] + +net_data_aug.to(device) + +for images, labels in dataloader_test: + images = images.to(device) + labels = labels.to(device) + preds = net_data_aug(images) + preds = torch.sigmoid(preds) > 0.5 # Applying threshold to get binary output + labels = convert_to_multi_labels(labels) # Convert labels to multi-label format + + dice_rv_data_aug = get_DC(preds[:, 0, :, :], labels[:, 0, :, :]) + dice_myo_data_aug = get_DC(preds[:, 1, :, :], labels[:, 1, :, :]) + dice_lv_data_aug = get_DC(preds[:, 2, :, :], labels[:, 2, :, :]) + + dice_scores_data_aug.append((dice_rv_data_aug, dice_myo_data_aug, dice_lv_data_aug)) + +mean_dice_rv_data_aug = np.mean([score[0] for score in dice_scores_data_aug]) +std_dice_rv_data_aug= np.std([score[0] for score in dice_scores_data_aug]) +mean_dice_myo_data_aug = np.mean([score[1] for score in dice_scores_data_aug]) +std_dice_myo_data_aug = np.std([score[1] for score in dice_scores_data_aug]) +mean_dice_lv_data_aug = np.mean([score[2] for score in dice_scores_data_aug]) +std_dice_lv_data_aug = np.std([score[2] for score in dice_scores_data_aug]) + +print(f'RV Dice Coefficient With Data Augmentation: Mean={mean_dice_rv_data_aug}, SD={std_dice_rv_data_aug}') +print(f'MYO Dice Coefficient With Data Augmentation: Mean={mean_dice_myo_data_aug}, SD={std_dice_myo_data_aug}') +print(f'LV Dice Coefficient With Data Augmentation: Mean={mean_dice_lv_data_aug}, SD={std_dice_lv_data_aug}') + + + +class SoftDiceLoss(nn.Module): + def __init__(self, smooth=1.): + super(SoftDiceLoss, self).__init__() + self.smooth = smooth + + def forward(self, inputs, targets): + inputs = torch.sigmoid(inputs) # input is logits + targets = convert_to_multi_labels(targets) # Convert labels to multi-label format + + intersection = (inputs * targets).sum(dim=(2, 3)) + union = inputs.sum(dim=(2, 3)) + targets.sum(dim=(2, 3)) + + dice_coefficient = (2. * intersection + self.smooth) / (union + self.smooth) + + dice_loss = 1 - dice_coefficient.mean() + + return dice_loss + +net_soft_dice = UNet(n_channels=1, n_classes=3, C_base=32) +optimizer = torch.optim.Adam(net_soft_dice.parameters(), lr=0.001) +lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95) + + +dice_loss = SoftDiceLoss() + +solver = lab.Solver( + model=net_soft_dice, + optimizer=optimizer, + criterion=dice_loss, + lr_scheduler=lr_scheduler, +) + + +solver.train( + epochs=50, + data_loader=dataloader_train_aug, + val_loader=dataloader_val_aug, + save_path='./model_soft_dice_loss.pth', + img_name='soft_dice_loss' +) + +accuracy_scores_soft_dice = [] + +net_soft_dice.to(device) + +for images, labels in dataloader_test: + images = images.to(device) + labels = labels.to(device) + preds = net_soft_dice(images) + preds = torch.sigmoid(preds) > 0.5 # Applying threshold to get binary output + labels = convert_to_multi_labels(labels) # Convert labels to multi-label format + + accuracy_rv_soft_dice = get_accuracy(preds[:, 0, :, :], labels[:, 0, :, :]) + accuracy_myo_soft_dice = get_accuracy(preds[:, 1, :, :], labels[:, 1, :, :]) + accuracy_lv_soft_dice = get_accuracy(preds[:, 2, :, :], labels[:, 2, :, :]) + + accuracy_scores_soft_dice.append((accuracy_rv_soft_dice, accuracy_myo_soft_dice, accuracy_lv_soft_dice)) + +mean_accuracy_rv_soft_dice = np.mean([score[0] for score in accuracy_scores_soft_dice]) +std_accuracy_rv_soft_dice = np.std([score[0] for score in accuracy_scores_soft_dice]) +mean_accuracy_myo_soft_dice = np.mean([score[1] for score in accuracy_scores_soft_dice]) +std_accuracy_myo_soft_dice = np.std([score[1] for score in accuracy_scores_soft_dice]) +mean_accuracy_lv_soft_dice = np.mean([score[2] for score in accuracy_scores_soft_dice]) +std_accuracy_lv_soft_dice = np.std([score[2] for score in accuracy_scores_soft_dice]) + +print(f'RV Accuracy With Soft Dice Loss: Mean={mean_accuracy_rv_soft_dice}, SD={std_accuracy_rv_soft_dice}') +print(f'MYO Accuracy With Soft Dice Loss: Mean={mean_accuracy_myo_soft_dice}, SD={std_accuracy_myo_soft_dice}') +print(f'LV Accuracy With Soft Dice Loss: Mean={mean_accuracy_lv_soft_dice}, SD={std_accuracy_lv_soft_dice}') + +with open('output_results.txt', 'w') as file: + file.write(f'RV Dice Coefficient: Mean={mean_dice_rv}, SD={std_dice_rv}\n') + file.write(f'MYO Dice Coefficient: Mean={mean_dice_myo}, SD={std_dice_myo}\n') + file.write(f'LV Dice Coefficient: Mean={mean_dice_lv}, SD={std_dice_lv}\n') + + file.write(f'RV Dice Coefficient Without Shortcut: Mean={mean_dice_rv_no_shortcut}, SD={std_dice_rv_no_shortcut}\n') + file.write(f'MYO Dice Coefficient Without Shortcut: Mean={mean_dice_myo_no_shortcut}, SD={std_dice_myo_no_shortcut}\n') + file.write(f'LV Dice Coefficient Without Shortcut: Mean={mean_dice_lv_no_shortcut}, SD={std_dice_lv_no_shortcut}\n') + + file.write(f'RV Dice Coefficient With Data Augmentation: Mean={mean_dice_rv_data_aug}, SD={std_dice_rv_data_aug}\n') + file.write(f'MYO Dice Coefficient With Data Augmentation: Mean={mean_dice_myo_data_aug}, SD={std_dice_myo_data_aug}\n') + file.write(f'LV Dice Coefficient With Data Augmentation: Mean={mean_dice_lv_data_aug}, SD={std_dice_lv_data_aug}\n') + + file.write(f'RV Accuracy: Mean={mean_accuracy_rv}, SD={std_accuracy_rv}\n') + file.write(f'MYO Accuracy: Mean={mean_accuracy_myo}, SD={std_accuracy_myo}\n') + file.write(f'LV Accuracy: Mean={mean_accuracy_lv}, SD={std_accuracy_lv}\n') + + file.write(f'RV Accuracy With Soft Dice Loss: Mean={mean_accuracy_rv_soft_dice}, SD={std_accuracy_rv_soft_dice}\n') + file.write(f'MYO Accuracy With Soft Dice Loss: Mean={mean_accuracy_myo_soft_dice}, SD={std_accuracy_myo_soft_dice}\n') + file.write(f'LV Accuracy With Soft Dice Loss: Mean={mean_accuracy_lv_soft_dice}, SD={std_accuracy_lv_soft_dice}\n') \ No newline at end of file