Diff of /unet.py [000000] .. [2147a4]

Switch to side-by-side view

--- a
+++ b/unet.py
@@ -0,0 +1,526 @@
+import numpy as np
+from matplotlib import pyplot as plt
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.utils import data
+from torchvision import transforms
+import os
+import random
+from PIL import Image
+import sklearn
+from bme1312 import lab2 as lab
+from torchvision import transforms as T
+from torch.utils.data import TensorDataset, DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from bme1312.evaluation import get_DC,get_accuracy
+from torchvision import transforms
+
+
+plt.rcParams['image.interpolation'] = 'nearest'
+plt.rcParams['image.cmap'] = 'gray'
+device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
+print("Using device: ", device)
+
+def process_data():
+    path="./cine_seg.npz" # input the path of cine_seg.npz in your environment
+    dataset=np.load(path,allow_pickle=True)
+    files=dataset.files
+    inputs=[]
+    labels=[]
+    for file in files:
+        inputs.append(dataset[file][0])
+        labels.append(dataset[file][1])
+    inputs = torch.Tensor(inputs)
+    labels = torch.Tensor(labels)
+    return inputs, labels
+
+
+
+inputs, labels = process_data()
+inputs = inputs.unsqueeze(1)  # Add channel dimension
+labels = labels.unsqueeze(1)  # Add channel dimension
+print("inputs shape: ", inputs.shape)
+print("labels shape: ", labels.shape)
+
+
+def convert_to_multi_labels(label): 
+    device = label.device
+    B, C, H, W = label.shape
+    new_tensor = torch.zeros((B, 3, H, W), device=device)
+    mask1 = (label >= 255).squeeze(1)
+    mask2 = ((label >= 170) & (label < 255-1)).squeeze(1)
+    mask3 = ((label >= 85) & (label < 170-1)).squeeze(1)
+    one = torch.ones(size= (B, H, W), device=device)
+    zero = torch.zeros(size=(B, H, W), device=device)
+    new_tensor[:, 0, :, :] = torch.where(mask1, one, zero)
+    new_tensor[:, 1, :, :] = torch.where(mask2, one, zero)
+    new_tensor[:, 2, :, :] = torch.where(mask3, one, zero)
+    return new_tensor
+    
+    
+dataset = TensorDataset(inputs, labels) 
+batch_size = 32
+train_size = int(4/7 * len(dataset))
+val_size = int(1/7 * len(dataset))
+test_size = len(dataset) - train_size - val_size
+train_set, val_set, test_set = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
+
+dataloader_train = DataLoader(train_set, batch_size=batch_size, shuffle=True)
+dataloader_val = DataLoader(val_set, batch_size=batch_size, shuffle=True)
+dataloader_test =  DataLoader(test_set, batch_size=batch_size, shuffle=True)
+
+
+class DoubleConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DoubleConv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+class Down(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(Down, self).__init__()
+        self.mpconv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, out_channels)
+        )
+
+    def forward(self, x):
+        return self.mpconv(x)
+
+
+class Up(nn.Module):
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super(Up, self).__init__()
+
+        if bilinear:
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+        else:
+            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 2, stride=2)
+
+        self.conv = DoubleConv(in_channels, out_channels)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+
+        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                        diffY // 2, diffY - diffY // 2])
+        x = torch.cat([x2, x1], dim=1)
+        return self.conv(x)
+
+class UNet(nn.Module):
+    def __init__(self, n_channels, n_classes, bilinear=True, C_base=64):
+        super(UNet, self).__init__()
+        self.n_channels = n_channels
+        self.n_classes = n_classes
+        self.bilinear = bilinear
+        self.C_base = C_base
+
+        self.inc = DoubleConv(n_channels, C_base)
+        self.down1 = Down(C_base, C_base*2)
+        self.down2 = Down(C_base*2, C_base*4)
+        self.down3 = Down(C_base*4, C_base*8)
+        self.down4 = Down(C_base*8, C_base*8)
+        self.up1 = Up(C_base*16, C_base*4, bilinear)
+        self.up2 = Up(C_base*8, C_base*2, bilinear)
+        self.up3 = Up(C_base*4, C_base, bilinear)
+        self.up4 = Up(C_base*2, C_base, bilinear)
+        self.outc = nn.Conv2d(C_base, n_classes, 1)
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        x = self.outc(x)
+        return x
+
+class MyBinaryCrossEntropy(object):
+    def __init__(self):
+        self.sigmoid = nn.Sigmoid()
+        self.bce = nn.BCELoss(reduction='mean')
+
+    def __call__(self, pred_seg, seg_gt):
+        pred_seg_probs = self.sigmoid(pred_seg)
+        seg_gt_probs = convert_to_multi_labels(seg_gt) # convert to multi labels
+        loss = self.bce(pred_seg_probs, seg_gt_probs)
+        return loss
+    
+import bme1312.lab2 as lab
+
+net = UNet(n_channels=1, n_classes=3, C_base=32)
+optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
+lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)
+
+solver = lab.Solver(
+    model=net,
+    optimizer=optimizer,
+    criterion=MyBinaryCrossEntropy(),
+    lr_scheduler=lr_scheduler,
+)
+
+solver.train(
+    epochs=50,
+    data_loader=dataloader_train,
+    val_loader=dataloader_val,
+    save_path='./model_original_unet_bce.pth',
+    img_name='original_unet_bce'
+)
+
+
+# calculate the dice coefficient
+dice_scores = []
+net.to(device)
+for images, labels in dataloader_test:
+    images = images.to(device)  
+    labels = labels.to(device) 
+    preds = net(images)
+    preds = torch.sigmoid(preds) > 0.5  # Applying threshold to get binary output
+    labels = convert_to_multi_labels(labels)  # Convert labels to multi-label format
+
+    dice_rv = get_DC(preds[:, 0, :, :], labels[:, 0, :, :])
+    dice_myo = get_DC(preds[:, 1, :, :], labels[:, 1, :, :])
+    dice_lv = get_DC(preds[:, 2, :, :], labels[:, 2, :, :])
+
+    dice_scores.append((dice_rv, dice_myo, dice_lv))
+
+mean_dice_rv = np.mean([score[0] for score in dice_scores])
+std_dice_rv = np.std([score[0] for score in dice_scores])
+mean_dice_myo = np.mean([score[1] for score in dice_scores])
+std_dice_myo = np.std([score[1] for score in dice_scores])
+mean_dice_lv = np.mean([score[2] for score in dice_scores])
+std_dice_lv = np.std([score[2] for score in dice_scores])
+
+print(f'RV Dice Coefficient: Mean={mean_dice_rv}, SD={std_dice_rv}')
+print(f'MYO Dice Coefficient: Mean={mean_dice_myo}, SD={std_dice_myo}')
+print(f'LV Dice Coefficient: Mean={mean_dice_lv}, SD={std_dice_lv}')
+
+
+
+class Up_NoShortcut(nn.Module):
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super(Up_NoShortcut, self).__init__()
+
+        if bilinear:
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+            self.conv = DoubleConv(in_channels, out_channels)
+        else:
+            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, 2, stride=2)
+            self.conv = DoubleConv(in_channels // 2, out_channels)
+
+
+    def forward(self, x1):
+        x1 = self.up(x1)
+        return self.conv(x1)
+
+class UNet_NoShortcut(nn.Module):
+    def __init__(self, n_channels, n_classes, bilinear=True, C_base=64):
+        super(UNet_NoShortcut, self).__init__()
+        self.n_channels = n_channels
+        self.n_classes = n_classes
+        self.bilinear = bilinear
+        self.C_base = C_base
+
+        self.inc = DoubleConv(n_channels, C_base)
+        self.down1 = Down(C_base, C_base*2)
+        self.down2 = Down(C_base*2, C_base*4)
+        self.down3 = Down(C_base*4, C_base*8)
+        self.down4 = Down(C_base*8, C_base*8)
+        self.up1 = Up_NoShortcut(C_base*8, C_base*4, bilinear)
+        self.up2 = Up_NoShortcut(C_base*4, C_base*2, bilinear)
+        self.up3 = Up_NoShortcut(C_base*2, C_base, bilinear)
+        self.up4 = Up_NoShortcut(C_base, C_base, bilinear)
+        self.outc = nn.Conv2d(C_base, n_classes, 1)
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        x = self.up1(x5)
+        x = self.up2(x)
+        x = self.up3(x)
+        x = self.up4(x)
+        x = self.outc(x)
+        return x
+    
+net_no_shortcut = UNet_NoShortcut(n_channels=1, n_classes=3, C_base=32)
+optimizer = torch.optim.Adam(net_no_shortcut.parameters(), lr=0.01)
+lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)
+solver_no_shortcut = lab.Solver(
+    model=net_no_shortcut,
+    optimizer=torch.optim.Adam(net_no_shortcut.parameters(), lr=0.01),
+    criterion=MyBinaryCrossEntropy(),
+    lr_scheduler=lr_scheduler
+)
+
+solver_no_shortcut.train(
+    epochs=50,
+    data_loader=dataloader_train,
+    val_loader=dataloader_val,
+    save_path='./model_no_shortcut_unet.pth',
+    img_name='no_shortcut_unet'
+)
+
+dice_scores_no_shortcut = []
+
+net_no_shortcut.to(device)
+
+for images, labels in dataloader_test:
+    images = images.to(device)  
+    labels = labels.to(device) 
+    preds = net_no_shortcut(images)
+    preds = torch.sigmoid(preds) > 0.5  # Applying threshold to get binary output
+    labels = convert_to_multi_labels(labels)  # Convert labels to multi-label format
+
+    dice_rv_no_shortcut = get_DC(preds[:, 0, :, :], labels[:, 0, :, :])
+    dice_myo_no_shortcut = get_DC(preds[:, 1, :, :], labels[:, 1, :, :])
+    dice_lv_no_shortcut = get_DC(preds[:, 2, :, :], labels[:, 2, :, :])
+
+    dice_scores_no_shortcut.append((dice_rv_no_shortcut, dice_myo_no_shortcut, dice_lv_no_shortcut))
+
+mean_dice_rv_no_shortcut = np.mean([score[0] for score in dice_scores_no_shortcut])
+std_dice_rv_no_shortcut = np.std([score[0] for score in dice_scores_no_shortcut])
+mean_dice_myo_no_shortcut = np.mean([score[1] for score in dice_scores_no_shortcut])
+std_dice_myo_no_shortcut = np.std([score[1] for score in dice_scores_no_shortcut])
+mean_dice_lv_no_shortcut = np.mean([score[2] for score in dice_scores_no_shortcut])
+std_dice_lv_no_shortcut = np.std([score[2] for score in dice_scores_no_shortcut])
+
+print(f'RV Dice Coefficient Without Shortcut: Mean={mean_dice_rv_no_shortcut}, SD={std_dice_rv_no_shortcut}')
+print(f'MYO Dice Coefficient Without Shortcut: Mean={mean_dice_myo_no_shortcut}, SD={std_dice_myo_no_shortcut}')
+print(f'LV Dice Coefficient Without Shortcut: Mean={mean_dice_lv_no_shortcut}, SD={std_dice_lv_no_shortcut}')
+
+
+transform = transforms.Compose([
+        transforms.RandomHorizontalFlip(),
+        transforms.RandomRotation(10)
+    ])
+
+class SegmentationDataset(data.Dataset):
+    def __init__(self, inputs, labels, transform=None):
+        self.inputs = inputs
+        self.labels = labels
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.inputs)
+
+    def __getitem__(self, idx):
+        image = self.inputs[idx]
+        label = self.labels[idx]
+        
+        if self.transform: 
+            all = torch.stack((image, label), dim = 0)
+            all = self.transform(all)
+            image = all[0]
+            label = all[1]
+            
+        return image, label
+
+
+def extract_inputs_labels(dataset):
+    inputs = []
+    labels = []
+    for data in dataset:
+        input, label = data
+        inputs.append(input)
+        labels.append(label)
+    inputs = torch.stack(inputs)
+    labels = torch.stack(labels)
+    return inputs, labels
+
+inputs_train, labels_train = extract_inputs_labels(train_set)
+inputs_val, labels_val = extract_inputs_labels(val_set)
+inputs_test, labels_test = extract_inputs_labels(test_set)
+
+train_dataset = SegmentationDataset(inputs_train, labels_train, transform=transform)
+val_dataset = SegmentationDataset(inputs_val, labels_val, transform=None)
+
+
+dataloader_train_aug = data.DataLoader(train_dataset, batch_size=32, shuffle=True)
+dataloader_val_aug = data.DataLoader(val_dataset, batch_size=32, shuffle=False)
+
+net_data_aug = UNet(n_channels=1, n_classes=3, C_base=32)
+optimizer = torch.optim.Adam(net_data_aug.parameters(), lr=0.01)
+lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)
+
+solver = lab.Solver(
+    model=net_data_aug,
+    optimizer=optimizer,
+    criterion=MyBinaryCrossEntropy(),
+    lr_scheduler=lr_scheduler,
+)
+
+solver.train(
+    epochs=50,
+    data_loader=dataloader_train_aug,
+    val_loader=dataloader_val_aug,
+    save_path='./model_original_unet_data_aug_acc.pth',
+    img_name='origianl_unet_data_aug_dice_acc'
+)
+
+accuracy_scores = []
+
+net_data_aug.to(device)
+
+for images, labels in dataloader_test:
+    images = images.to(device)  
+    labels = labels.to(device) 
+    preds = net_data_aug(images)
+    preds = torch.sigmoid(preds) > 0.5  # Applying threshold to get binary output
+    labels = convert_to_multi_labels(labels)  # Convert labels to multi-label format
+
+    accuracy_rv = get_accuracy(preds[:, 0, :, :], labels[:, 0, :, :])
+    accuracy_myo = get_accuracy(preds[:, 1, :, :], labels[:, 1, :, :])
+    accuracy_lv = get_accuracy(preds[:, 2, :, :], labels[:, 2, :, :])
+
+    accuracy_scores.append((accuracy_rv, accuracy_myo, accuracy_lv))
+
+mean_accuracy_rv = np.mean([score[0] for score in accuracy_scores])
+std_accuracy_rv = np.std([score[0] for score in accuracy_scores])
+mean_accuracy_myo = np.mean([score[1] for score in accuracy_scores])
+std_accuracy_myo = np.std([score[1] for score in accuracy_scores])
+mean_accuracy_lv = np.mean([score[2] for score in accuracy_scores])
+std_accuracy_lv = np.std([score[2] for score in accuracy_scores])
+
+print(f'RV Accuracy: Mean={mean_accuracy_rv}, SD={std_accuracy_rv}')
+print(f'MYO Accuracy: Mean={mean_accuracy_myo}, SD={std_accuracy_myo}')
+print(f'LV Accuracy: Mean={mean_accuracy_lv}, SD={std_accuracy_lv}')
+
+
+dice_scores_data_aug = []
+
+net_data_aug.to(device)
+
+for images, labels in dataloader_test:
+    images = images.to(device)  
+    labels = labels.to(device) 
+    preds = net_data_aug(images)
+    preds = torch.sigmoid(preds) > 0.5  # Applying threshold to get binary output
+    labels = convert_to_multi_labels(labels)  # Convert labels to multi-label format
+
+    dice_rv_data_aug = get_DC(preds[:, 0, :, :], labels[:, 0, :, :])
+    dice_myo_data_aug = get_DC(preds[:, 1, :, :], labels[:, 1, :, :])
+    dice_lv_data_aug = get_DC(preds[:, 2, :, :], labels[:, 2, :, :])
+
+    dice_scores_data_aug.append((dice_rv_data_aug, dice_myo_data_aug, dice_lv_data_aug))
+
+mean_dice_rv_data_aug = np.mean([score[0] for score in dice_scores_data_aug])
+std_dice_rv_data_aug= np.std([score[0] for score in dice_scores_data_aug])
+mean_dice_myo_data_aug = np.mean([score[1] for score in dice_scores_data_aug])
+std_dice_myo_data_aug = np.std([score[1] for score in dice_scores_data_aug])
+mean_dice_lv_data_aug = np.mean([score[2] for score in dice_scores_data_aug])
+std_dice_lv_data_aug = np.std([score[2] for score in dice_scores_data_aug])
+
+print(f'RV Dice Coefficient With Data Augmentation: Mean={mean_dice_rv_data_aug}, SD={std_dice_rv_data_aug}')
+print(f'MYO Dice Coefficient With Data Augmentation: Mean={mean_dice_myo_data_aug}, SD={std_dice_myo_data_aug}')
+print(f'LV Dice Coefficient With Data Augmentation: Mean={mean_dice_lv_data_aug}, SD={std_dice_lv_data_aug}')
+
+
+
+class SoftDiceLoss(nn.Module):
+    def __init__(self, smooth=1.):
+        super(SoftDiceLoss, self).__init__()
+        self.smooth = smooth
+
+    def forward(self, inputs, targets):
+        inputs = torch.sigmoid(inputs)  # input is logits
+        targets = convert_to_multi_labels(targets)  # Convert labels to multi-label format
+
+        intersection = (inputs * targets).sum(dim=(2, 3))
+        union = inputs.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))
+
+        dice_coefficient = (2. * intersection + self.smooth) / (union + self.smooth)
+
+        dice_loss = 1 - dice_coefficient.mean()
+
+        return dice_loss
+
+net_soft_dice = UNet(n_channels=1, n_classes=3, C_base=32)
+optimizer = torch.optim.Adam(net_soft_dice.parameters(), lr=0.001)
+lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)
+
+
+dice_loss = SoftDiceLoss()
+
+solver = lab.Solver(
+    model=net_soft_dice,
+    optimizer=optimizer,
+    criterion=dice_loss,
+    lr_scheduler=lr_scheduler,
+)
+
+
+solver.train(
+    epochs=50,
+    data_loader=dataloader_train_aug,
+    val_loader=dataloader_val_aug,
+    save_path='./model_soft_dice_loss.pth',
+    img_name='soft_dice_loss'
+)
+
+accuracy_scores_soft_dice = []
+
+net_soft_dice.to(device)
+
+for images, labels in dataloader_test:
+    images = images.to(device)  
+    labels = labels.to(device) 
+    preds = net_soft_dice(images)
+    preds = torch.sigmoid(preds) > 0.5  # Applying threshold to get binary output
+    labels = convert_to_multi_labels(labels)  # Convert labels to multi-label format
+
+    accuracy_rv_soft_dice = get_accuracy(preds[:, 0, :, :], labels[:, 0, :, :])
+    accuracy_myo_soft_dice = get_accuracy(preds[:, 1, :, :], labels[:, 1, :, :])
+    accuracy_lv_soft_dice = get_accuracy(preds[:, 2, :, :], labels[:, 2, :, :])
+
+    accuracy_scores_soft_dice.append((accuracy_rv_soft_dice, accuracy_myo_soft_dice, accuracy_lv_soft_dice))
+
+mean_accuracy_rv_soft_dice = np.mean([score[0] for score in accuracy_scores_soft_dice])
+std_accuracy_rv_soft_dice = np.std([score[0] for score in accuracy_scores_soft_dice])
+mean_accuracy_myo_soft_dice = np.mean([score[1] for score in accuracy_scores_soft_dice])
+std_accuracy_myo_soft_dice = np.std([score[1] for score in accuracy_scores_soft_dice])
+mean_accuracy_lv_soft_dice  = np.mean([score[2] for score in accuracy_scores_soft_dice])
+std_accuracy_lv_soft_dice  = np.std([score[2] for score in accuracy_scores_soft_dice])
+
+print(f'RV Accuracy With Soft Dice Loss: Mean={mean_accuracy_rv_soft_dice}, SD={std_accuracy_rv_soft_dice}')
+print(f'MYO Accuracy With Soft Dice Loss: Mean={mean_accuracy_myo_soft_dice}, SD={std_accuracy_myo_soft_dice}')
+print(f'LV Accuracy With Soft Dice Loss: Mean={mean_accuracy_lv_soft_dice}, SD={std_accuracy_lv_soft_dice}')
+
+with open('output_results.txt', 'w') as file:
+    file.write(f'RV Dice Coefficient: Mean={mean_dice_rv}, SD={std_dice_rv}\n')
+    file.write(f'MYO Dice Coefficient: Mean={mean_dice_myo}, SD={std_dice_myo}\n')
+    file.write(f'LV Dice Coefficient: Mean={mean_dice_lv}, SD={std_dice_lv}\n')
+
+    file.write(f'RV Dice Coefficient Without Shortcut: Mean={mean_dice_rv_no_shortcut}, SD={std_dice_rv_no_shortcut}\n')
+    file.write(f'MYO Dice Coefficient Without Shortcut: Mean={mean_dice_myo_no_shortcut}, SD={std_dice_myo_no_shortcut}\n')
+    file.write(f'LV Dice Coefficient Without Shortcut: Mean={mean_dice_lv_no_shortcut}, SD={std_dice_lv_no_shortcut}\n')
+
+    file.write(f'RV Dice Coefficient With Data Augmentation: Mean={mean_dice_rv_data_aug}, SD={std_dice_rv_data_aug}\n')
+    file.write(f'MYO Dice Coefficient With Data Augmentation: Mean={mean_dice_myo_data_aug}, SD={std_dice_myo_data_aug}\n')
+    file.write(f'LV Dice Coefficient With Data Augmentation: Mean={mean_dice_lv_data_aug}, SD={std_dice_lv_data_aug}\n')
+    
+    file.write(f'RV Accuracy: Mean={mean_accuracy_rv}, SD={std_accuracy_rv}\n')
+    file.write(f'MYO Accuracy: Mean={mean_accuracy_myo}, SD={std_accuracy_myo}\n')
+    file.write(f'LV Accuracy: Mean={mean_accuracy_lv}, SD={std_accuracy_lv}\n')
+    
+    file.write(f'RV Accuracy With Soft Dice Loss: Mean={mean_accuracy_rv_soft_dice}, SD={std_accuracy_rv_soft_dice}\n')
+    file.write(f'MYO Accuracy With Soft Dice Loss: Mean={mean_accuracy_myo_soft_dice}, SD={std_accuracy_myo_soft_dice}\n')
+    file.write(f'LV Accuracy With Soft Dice Loss: Mean={mean_accuracy_lv_soft_dice}, SD={std_accuracy_lv_soft_dice}\n')
\ No newline at end of file