"""Attention U-Net training script (attention_unet.py).

Running this module loads ``./cine_seg.npz``, builds augmented and plain data
loaders, defines an attention-gated U-Net plus several loss functions, trains
the network with binary cross-entropy and finally visualizes a test sample.
"""

import numpy as np
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import MSELoss  # noqa: F401  (unused here; kept from the original import list)
from torch.utils import data
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms

import bme1312.lab2 as lab
from bme1312.evaluation import get_DC

plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'


def convert_to_multi_labels(label):
    """Split a single-channel grey-level label map into three binary masks.

    Args:
        label: Tensor of shape ``(B, 1, H, W)`` holding grey levels.

    Returns:
        Tensor of shape ``(B, 3, H, W)`` on ``label``'s device, in the default
        float dtype, where channel 0 is ``label >= 255``, channel 1 is
        ``170 <= label < 255`` and channel 2 is ``85 <= label < 170``.

    Raises:
        ValueError: If ``label`` is not 4-D with exactly one channel.
    """
    if label.dim() != 4 or label.shape[1] != 1:
        raise ValueError(
            f"expected label of shape (B, 1, H, W), got {tuple(label.shape)}"
        )
    grey = label.squeeze(1)
    # The bins are contiguous on purpose: the previous upper bounds
    # (`< 255-1`, `< 170-1`) left the values 254 and 169 in no class at all.
    masks = (
        grey >= 255,
        (grey >= 170) & (grey < 255),
        (grey >= 85) & (grey < 170),
    )
    return torch.stack(masks, dim=1).to(torch.get_default_dtype())


def process_data(path="./cine_seg.npz"):
    """Load image/label pairs from an ``.npz`` archive.

    Every entry of the archive is indexed with ``[0]`` for the input and
    ``[1]`` for the label.

    Args:
        path: Location of the archive. Defaults to the original hard-coded
            ``"./cine_seg.npz"``.

    Returns:
        Tuple ``(inputs, labels)`` of float32 tensors stacked along a new
        leading dimension, one slice per archive entry.

    Raises:
        ValueError: If the archive contains no entries.
    """
    images = []
    masks = []
    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # loading; only point this at archives from a trusted source.
    with np.load(path, allow_pickle=True) as archive:
        for name in archive.files:
            # Index the archive once per entry: every NpzFile lookup re-reads
            # and decodes the member from disk.
            entry = archive[name]
            images.append(np.asarray(entry[0], dtype=np.float32))
            masks.append(np.asarray(entry[1], dtype=np.float32))
    if not images:
        raise ValueError(f"no entries found in {path!r}")
    # Stacking in NumPy first avoids the slow element-wise conversion that
    # torch.Tensor(list_of_ndarrays) performs.
    inputs = torch.from_numpy(np.stack(images))
    labels = torch.from_numpy(np.stack(masks))
    return inputs, labels


class SegmentationDataset(data.Dataset):
    """Dataset of (image, label) tensor pairs with an optional joint transform."""

    def __init__(self, inputs, labels, transform=None):
        """Store the tensors and the optional transform.

        Args:
            inputs: Indexable collection of image tensors.
            labels: Indexable collection of label tensors, aligned with
                ``inputs``.
            transform: Optional callable applied to the stacked pair.
        """
        self.inputs = inputs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``.

        When a transform is set, image and label are stacked into one tensor
        and transformed in a single call, then split again.
        """
        image = self.inputs[idx]
        label = self.labels[idx]

        if self.transform is not None:
            pair = self.transform(torch.stack((image, label), dim=0))
            image, label = pair[0], pair[1]

        return image, label


def extract_inputs_labels(dataset):
    """Materialize a dataset of ``(input, label)`` pairs into two stacked tensors.

    Args:
        dataset: Iterable yielding ``(input, label)`` tensor pairs.

    Returns:
        Tuple ``(inputs, labels)``, each stacked along a new leading dimension.
    """
    pairs = list(dataset)
    inputs = torch.stack([sample for sample, _ in pairs])
    labels = torch.stack([target for _, target in pairs])
    return inputs, labels


batch_size = 32

inputs, labels = process_data()
# Add the channel dimension: (N, H, W) -> (N, 1, H, W).
inputs = inputs.unsqueeze(1)
labels = labels.unsqueeze(1)
dataset = TensorDataset(inputs, labels)
train_size = int(4 / 7 * len(dataset))
val_size = int(1 / 7 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_set, val_set, test_set = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size]
)

inputs_train, labels_train = extract_inputs_labels(train_set)
inputs_val, labels_val = extract_inputs_labels(val_set)
inputs_test, labels_test = extract_inputs_labels(test_set)

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
])

train_dataset = SegmentationDataset(inputs_train, labels_train, transform=transform)
val_dataset = SegmentationDataset(inputs_val, labels_val, transform=None)

dataloader_train_aug = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dataloader_val_aug = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(inputs.shape)
print(labels.shape)

# Non-augmented loaders built directly on the random split.
dataloader_train = DataLoader(train_set, batch_size=batch_size, shuffle=True)
dataloader_val = DataLoader(val_set, batch_size=batch_size, shuffle=True)
dataloader_test = DataLoader(test_set, batch_size=batch_size, shuffle=True)


class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by BatchNorm and ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.conv(x)


class AttentionBlock(nn.Module):
    """Additive attention gate that rescales a skip connection.

    Args:
        F_g: Channel count of the gating signal.
        F_l: Channel count of the skip connection.
        F_int: Channel count of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, skip_connection):
        """Return ``skip_connection`` scaled by a one-channel attention map.

        ``g`` and ``skip_connection`` must have the same spatial size because
        their projections are added element-wise.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(skip_connection)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return skip_connection * psi


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        """Downsample ``x`` by two and convolve."""
        return self.mpconv(x)


class Up(nn.Module):
    """Decoder stage: upsample, attention-gate the skip, concatenate, convolve.

    Args:
        in_channels: Channel count after concatenation (twice the channel
            count of each of the two forward inputs).
        out_channels: Channel count produced by the stage.
        bilinear: Use bilinear upsampling when true, otherwise a transposed
            convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 2, stride=2)
        self.attention = AttentionBlock(
            F_g=in_channels // 2, F_l=in_channels // 2, F_int=in_channels // 4
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Merge the decoder feature ``x1`` with the encoder skip ``x2``."""
        x1 = self.up(x1)
        # Max-pooling floors odd sizes, so the upsampled map can be smaller
        # than the skip connection. Pad it back so the attention gate and the
        # concatenation see matching shapes; this is a no-op for equal sizes.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_y or diff_x:
            x1 = F.pad(
                x1,
                [diff_x // 2, diff_x - diff_x // 2,
                 diff_y // 2, diff_y - diff_y // 2],
            )
        gated_skip = self.attention(x1, x2)
        x = torch.cat([gated_skip, x1], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """Four-level U-Net with attention-gated skip connections.

    Args:
        n_channels: Number of input channels.
        n_classes: Number of output channels (raw scores, no activation).
        bilinear: Forwarded to every ``Up`` stage.
        C_base: Channel count of the first encoder stage.
    """

    def __init__(self, n_channels, n_classes, bilinear=True, C_base=64):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.C_base = C_base

        self.inc = DoubleConv(n_channels, C_base)
        self.down1 = Down(C_base, C_base * 2)
        self.down2 = Down(C_base * 2, C_base * 4)
        self.down3 = Down(C_base * 4, C_base * 8)
        self.down4 = Down(C_base * 8, C_base * 8)
        self.up1 = Up(C_base * 16, C_base * 4, bilinear)
        self.up2 = Up(C_base * 8, C_base * 2, bilinear)
        self.up3 = Up(C_base * 4, C_base, bilinear)
        self.up4 = Up(C_base * 2, C_base, bilinear)
        self.outc = nn.Conv2d(C_base, n_classes, 1)

    def forward(self, x):
        """Return per-pixel scores with ``n_classes`` channels."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x


class MyBinaryCrossEntropy(object):
    """Mean binary cross-entropy between raw scores and multi-label targets."""

    def __init__(self):
        # Kept so code that reads this attribute keeps working.
        self.sigmoid = nn.Sigmoid()
        # BCEWithLogitsLoss fuses the sigmoid with the loss, which is more
        # numerically stable than Sigmoid followed by BCELoss.
        self.bce = nn.BCEWithLogitsLoss(reduction='mean')

    def __call__(self, pred_seg, seg_gt):
        """Compute the loss.

        Args:
            pred_seg: Raw network scores of shape ``(B, 3, H, W)``.
            seg_gt: Grey-level label map of shape ``(B, 1, H, W)``.

        Returns:
            Scalar loss tensor.
        """
        seg_gt_probs = convert_to_multi_labels(seg_gt)
        return self.bce(pred_seg, seg_gt_probs)


def dice_coef_loss(y_pred, y_true):
    """Return ``1 - get_DC(max over channels of y_pred, y_true)``.

    NOTE(review): wrapping the score in ``torch.tensor(..., requires_grad=True)``
    creates a new leaf tensor, so no gradient flows back to ``y_pred``; as
    written this is usable as a metric but not as a training loss. Fixing it
    needs a differentiable Dice implementation in place of ``get_DC``.
    """
    y_pred_probs = torch.max(y_pred, dim=1).values
    return (1. - torch.tensor(get_DC(y_pred_probs, y_true), requires_grad=True))


def generalized_dice_coeff(y_pred, y_true, eps=1e-6):
    """Compute the generalized Dice coefficient with inverse-volume weights.

    Args:
        y_pred: Predictions of shape ``(B, 3, H, W)``.
            NOTE(review): presumably probabilities rather than raw scores;
            no activation is applied here.
        y_true: Grey-level label map of shape ``(B, 1, H, W)``.
        eps: Small constant guarding the divisions against zero.

    Returns:
        Scalar tensor ``2 * sum_c(w_c * I_c) / sum_c(w_c * U_c)`` where
        ``w_c = 1 / (volume_c ** 2 + eps)``.
    """
    y_true_prob = convert_to_multi_labels(y_true)
    # Reduce over batch and spatial axes only. Summing over the class axis as
    # well (as the code did before) makes the class weights cancel out.
    per_class = (0, 2, 3)
    weights = 1 / (torch.sum(y_true_prob, dim=per_class) ** 2 + eps)
    intersection = torch.sum(y_true_prob * y_pred, dim=per_class)
    union = torch.sum(y_true_prob + y_pred, dim=per_class)
    numerator = torch.sum(weights * intersection)
    denominator = torch.sum(weights * union)
    return 2 * numerator / (denominator + eps)


def generalized_dice_loss(y_pred, y_true):
    """Return ``1 - generalized_dice_coeff(y_pred, y_true)``.

    The first positional argument has always been treated as the prediction;
    the parameter names now say so, which also makes keyword calls correct.
    """
    return 1 - generalized_dice_coeff(y_pred, y_true)


net = UNet(n_channels=1, n_classes=3, C_base=32)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)

solver = lab.Solver(
    model=net,
    optimizer=optimizer,
    criterion=MyBinaryCrossEntropy(),
    lr_scheduler=lr_scheduler
)
solver.train(
    epochs=30,
    data_loader=dataloader_train_aug,
    val_loader=dataloader_val_aug,
    save_path='./model_attention_30.pth',
    img_name='attention_aug_30',
)

solver.visualize(data_loader=dataloader_test, idx=10)