--- /dev/null
+++ b/attention_unet.py
@@ -0,0 +1,322 @@
+import numpy as np
+from matplotlib import pyplot as plt
+import torch
+from torch.utils import data
+from torchvision import transforms
+
+plt.rcParams['image.interpolation'] = 'nearest'
+plt.rcParams['image.cmap'] = 'gray'
+
+def convert_to_multi_labels(label):
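+    """
+    Convert a single-channel grayscale mask (B, 1, H, W) into a 3-channel
+    one-hot tensor (B, 3, H, W). Assumes the cine_seg labels encode the three
+    structures at grayscale levels 255, 170 and 85 (0 = background).
+    """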
+    device = label.device
+    B, _, H, W = label.shape
+    new_tensor = torch.zeros((B, 3, H, W), device=device)
+    # Threshold bands keep the mapping robust to small intensity deviations
+    mask1 = (label >= 255).squeeze(1)                    # level 255 -> channel 0
+    mask2 = ((label >= 170) & (label < 254)).squeeze(1)  # level 170 -> channel 1
+    mask3 = ((label >= 85) & (label < 169)).squeeze(1)   # level 85  -> channel 2
+    new_tensor[:, 0] = mask1.float()
+    new_tensor[:, 1] = mask2.float()
+    new_tensor[:, 2] = mask3.float()
+    return new_tensor
+
+
+def process_data():
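+    """Load cine_seg.npz; each entry is assumed to hold an (image, label) pair."""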
+    path = "./cine_seg.npz"
+    dataset = np.load(path, allow_pickle=True)
+    files = dataset.files
+    inputs = []
+    labels = []
+    for file in files:
+        inputs.append(dataset[file][0])
+        labels.append(dataset[file][1])
+    # Stack in NumPy first: building a tensor from a list of arrays is very
+    # slow and triggers a PyTorch warning
+    inputs = torch.from_numpy(np.stack(inputs)).float()
+    labels = torch.from_numpy(np.stack(labels)).float()
+    return inputs, labels
+
+
+from torch.utils.data import TensorDataset, DataLoader
+class SegmentationDataset(data.Dataset):
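+    """Dataset wrapper that applies one random transform jointly to image and label."""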
+    def __init__(self, inputs, labels, transform=None):
+        self.inputs = inputs
+        self.labels = labels
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.inputs)
+
+    def __getitem__(self, idx):
+        image = self.inputs[idx]
+        label = self.labels[idx]
+
+        if self.transform:
+            # Stack image and label so the same random flip/rotation hits both
+            stacked = torch.stack((image, label), dim=0)
+            stacked = self.transform(stacked)
+            image, label = stacked[0], stacked[1]
+
+        return image, label
+
+def extract_inputs_labels(dataset):
+    inputs = []
+    labels = []
+    # Avoid shadowing the `data` module and the `input` builtin
+    for image, label in dataset:
+        inputs.append(image)
+        labels.append(label)
+    inputs = torch.stack(inputs)
+    labels = torch.stack(labels)
+    return inputs, labels
+
+
+inputs, labels = process_data()
+inputs = inputs.unsqueeze(1)
+labels = labels.unsqueeze(1)
+dataset = TensorDataset(inputs, labels)
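+# Split into train/val/test with a 4:1:2 ratio (4/7, 1/7 and 2/7 of the samples)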
+train_size = int(4/7 * len(dataset))
+val_size = int(1/7 * len(dataset))
+test_size = len(dataset) - train_size - val_size
+
+train_set, val_set, test_set = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
+
+inputs_train, labels_train = extract_inputs_labels(train_set)
+inputs_val, labels_val = extract_inputs_labels(val_set)
+inputs_test, labels_test = extract_inputs_labels(test_set)
+
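+# Training-time augmentation; RandomRotation defaults to nearest-neighbour
+# interpolation, so the discrete grayscale label levels are preserved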
+transform = transforms.Compose([
+        transforms.RandomHorizontalFlip(),
+        transforms.RandomRotation(10)
+    ])
+
+
+train_dataset = SegmentationDataset(inputs_train, labels_train, transform=transform)
+val_dataset = SegmentationDataset(inputs_val, labels_val, transform=None)
+
+dataloader_train_aug = data.DataLoader(train_dataset, batch_size=32, shuffle=True)
+dataloader_val_aug = data.DataLoader(val_dataset, batch_size=32, shuffle=False)
+
+print(inputs.shape)
+print(labels.shape)
+
+batch_size = 32
+# Non-augmented loaders; validation and test sets need no shuffling
+dataloader_train = DataLoader(train_set, batch_size=batch_size, shuffle=True)
+dataloader_val = DataLoader(val_set, batch_size=batch_size, shuffle=False)
+dataloader_test = DataLoader(test_set, batch_size=batch_size, shuffle=False)
+
+from torch import nn
+
+
+class DoubleConv(nn.Module):
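+    """(Conv3x3 -> BatchNorm -> ReLU) twice: the standard U-Net conv block."""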
+    def __init__(self, in_channels, out_channels):
+        super(DoubleConv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class AttentionBlock(nn.Module):
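+    """
+    Additive attention gate (Oktay et al., "Attention U-Net"). The gating
+    signal g from the coarser decoder level and the skip connection are each
+    projected to F_int channels, summed, and squashed to a per-pixel weight
+    psi in (0, 1) that rescales the skip connection.
+    """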
+    def __init__(self, F_g, F_l, F_int):
+        super(AttentionBlock, self).__init__()
+        self.W_g = nn.Sequential(
+            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(F_int)
+        )
+
+        self.W_x = nn.Sequential(
+            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(F_int)
+        )
+
+        self.psi = nn.Sequential(
+            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(1),
+            nn.Sigmoid()
+        )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, g, skip_connection):
+        g1 = self.W_g(g)
+        x1 = self.W_x(skip_connection)
+        psi = self.relu(g1 + x1)
+        psi = self.psi(psi)
+        return skip_connection * psi
+
+class Down(nn.Module):
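+    """Downscaling: 2x2 max-pool followed by DoubleConv."""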
+    def __init__(self, in_channels, out_channels):
+        super(Down, self).__init__()
+        self.mpconv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, out_channels)
+        )
+
+    def forward(self, x):
+        return self.mpconv(x)
+
+
+class Up(nn.Module):
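+    """Upscaling: upsample the decoder features, gate the skip connection
+    with attention, concatenate, then apply DoubleConv."""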
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super(Up, self).__init__()
+
+        if bilinear:
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+        else:
+            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 2, stride=2)
+        self.attention = AttentionBlock(F_g=in_channels // 2, F_l=in_channels // 2, F_int=in_channels // 4)
+        self.conv = DoubleConv(in_channels, out_channels)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+        # Assumes x1 and x2 match spatially after upsampling (input H and W
+        # divisible by 16); gate the skip connection, then concatenate
+        x3 = self.attention(x1, x2)
+        x = torch.cat([x3, x1], dim=1)
+        return self.conv(x)
+
+
+class UNet(nn.Module):
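+    """Attention U-Net: the classic U-Net encoder/decoder with each skip
+    connection filtered by an AttentionBlock before concatenation."""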
+    def __init__(self, n_channels, n_classes, bilinear=True, C_base=64):
+        super(UNet, self).__init__()
+        self.n_channels = n_channels
+        self.n_classes = n_classes
+        self.bilinear = bilinear
+        self.C_base = C_base
+
+        self.inc = DoubleConv(n_channels, C_base)
+        self.down1 = Down(C_base, C_base * 2)
+        self.down2 = Down(C_base * 2, C_base * 4)
+        self.down3 = Down(C_base * 4, C_base * 8)
+        self.down4 = Down(C_base * 8, C_base * 8)
+        self.up1 = Up(C_base * 16, C_base * 4, bilinear)
+        self.up2 = Up(C_base * 8, C_base * 2, bilinear)
+        self.up3 = Up(C_base * 4, C_base, bilinear)
+        self.up4 = Up(C_base * 2, C_base, bilinear)
+        self.outc = nn.Conv2d(C_base, n_classes, 1)
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        x = self.outc(x)
+        return x
+
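+# Quick shape sanity check (hypothetical sizes; any H, W divisible by 16 work):
+#   net = UNet(n_channels=1, n_classes=3, C_base=32)
+#   net(torch.randn(1, 1, 192, 192)).shape  # -> torch.Size([1, 3, 192, 192])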
+
+class MyBinaryCrossEntropy(object):
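+    """Channel-wise binary cross-entropy between predicted logits and the
+    one-hot target produced by convert_to_multi_labels."""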
+    def __init__(self):
+        # BCEWithLogitsLoss fuses sigmoid and BCE for numerical stability
+        self.bce = nn.BCEWithLogitsLoss(reduction='mean')
+
+    def __call__(self, pred_seg, seg_gt):
+        seg_gt_probs = convert_to_multi_labels(seg_gt)
+        return self.bce(pred_seg, seg_gt_probs)
+
+
+def dice_coef_loss(y_pred, y_true):
+    # Soft (probability-weighted) Dice loss over the three one-hot channels.
+    # The thresholded course metric get_DC is not differentiable, so it
+    # cannot drive backpropagation directly.
+    y_true_multi = convert_to_multi_labels(y_true)
+    y_pred_probs = torch.sigmoid(y_pred)
+    smooth = 1e-6
+    intersection = torch.sum(y_pred_probs * y_true_multi, dim=(0, 2, 3))
+    cardinality = torch.sum(y_pred_probs + y_true_multi, dim=(0, 2, 3))
+    dice = (2. * intersection + smooth) / (cardinality + smooth)
+    return 1. - dice.mean()
+
+def generalized_dice_coeff(y_pred, y_true):
+    # Generalized Dice (Sudre et al., 2017): per-class weights inversely
+    # proportional to the squared volume of each structure
+    y_true_prob = convert_to_multi_labels(y_true)
+    y_pred_probs = torch.sigmoid(y_pred)  # logits -> per-channel probabilities
+    w = torch.sum(y_true_prob, dim=(0, 2, 3))
+    w = 1 / (w ** 2 + 1e-6)
+    # Reduce over batch and spatial dims only, so the class dim survives and
+    # each class receives its own weight
+    numerator = torch.sum(w * torch.sum(y_true_prob * y_pred_probs, dim=(0, 2, 3)))
+    denominator = torch.sum(w * torch.sum(y_true_prob + y_pred_probs, dim=(0, 2, 3)))
+    return 2 * numerator / denominator
+
+
+def generalized_dice_loss(y_pred, y_true):
+    return 1 - generalized_dice_coeff(y_pred, y_true)
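+# Note: training below uses MyBinaryCrossEntropy; the Dice-based losses above
+# are drop-in alternatives for the Solver criterion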
+
+import bme1312.lab2 as lab
+
+net = UNet(n_channels=1, n_classes=3, C_base=32)
+optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
+lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)
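+# gamma=0.95 decays the learning rate by 5% per scheduler step (presumably
+# once per epoch inside the course Solver)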
+
+solver = lab.Solver(
+    model=net,
+    optimizer=optimizer,
+    criterion=MyBinaryCrossEntropy(),
+    lr_scheduler=lr_scheduler
+)
+solver.train(
+    epochs=30,
+    data_loader=dataloader_train_aug,
+    val_loader=dataloader_val_aug,
+    save_path='./model_attention_30.pth',
+    img_name='attention_aug_30',
+)
+
+solver.visualize(data_loader=dataloader_test, idx=10)