--- a
+++ b/predict_3d.py
@@ -0,0 +1,219 @@
+import matplotlib.pyplot as plt
+
+from data_3d import *
+from train_3d import *
+
+torch.manual_seed(42)
+import monai.transforms as mt
+
+# ----------------------- Arguments -----------------------
+parser = argparse.ArgumentParser(description="Prediction")
+# Parse as int so a value given on the command line has the same type as the
+# integer default (it used to arrive as a string).
+parser.add_argument("--batch_size_test", type=int, default=1)
+
+args = parser.parse_args()
+# NOTE(review): the DataLoader below is pinned to batch_size=1 and does not
+# read this value -- confirm whether it should be wired in.
+test_batch_size = args.batch_size_test
+
+# Softmax over dim 1 of the network output
+soft = torch.nn.Softmax(dim=1).cuda()
+# IoU, one score per class; classes reported as absent get the score -1
+IOU_metric = IoU(num_classes=4, absent_score=-1., reduction="none").cuda()
+# F1 score, one score per class
+f1_metric = F1(num_classes=4, mdmc_average="samplewise", average='none').cuda()
+
+# --------- Post Processing ---------
+# Use the `mt` alias imported above instead of relying on the bare `monai`
+# name being leaked into this module by one of the star imports.
+keep_largest = mt.KeepLargestConnectedComponent(applied_labels=[0, 1, 2, 3], independent=True)
+fill_holes = mt.FillHoles()
+
+# --------- Test Data ---------
+test_transform = mt.Compose([
+    mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False)
+])
+# Test dataset: one sample per batch, in dataset order
+test_data = DataLoader(test_loader_ACDC3(transform=test_transform, test_index=None), batch_size=1, shuffle=False)
+
+
+#  padding: just pass the image
+def Pad_images(image, multiple=16):
+    """Pad the first two spatial axes of a 5-D tensor up to a multiple of ``multiple``.
+
+    The input is copied into the centre of a larger tensor whose remaining
+    voxels are filled with the minimum value of ``image``. The last axis is
+    left at its original size.
+
+    Args:
+        image: tensor with exactly five dimensions; axes 2 and 3 are padded.
+        multiple: positive integer the padded sizes must be divisible by.
+            Defaults to 16, the value that used to be hard-coded.
+
+    Returns:
+        A ``(padded, (xx, yy, zz))`` pair. ``padded`` is a new tensor created
+        with torch's default dtype and device, so the caller still has to move
+        it to the GPU. The tuple holds the start offsets of the original data
+        inside ``padded`` along the three trailing axes (``zz`` is always 0).
+
+    Raises:
+        ValueError: if ``image`` does not have exactly five dimensions.
+    """
+    b, c, h, w, d = image.shape
+    # (-n) % multiple is the distance to the next multiple and is 0 when n is
+    # already aligned; the previous formula added a whole extra block then.
+    new_h = h + (-h) % multiple
+    new_w = w + (-w) % multiple
+    # float() pulls the scalar out of the tensor, whatever device it lives on.
+    fill_value = float(image.min())
+    result = torch.full((b, c, new_h, new_w, d), fill_value)
+    xx = (new_h - h) // 2
+    yy = (new_w - w) // 2
+    zz = 0  # the last axis is never padded
+    result[:, :, xx:xx + h, yy:yy + w, zz:zz + d] = image
+    return result, (xx, yy, zz)
+
+
+# give it the padded tensor, the offsets from Pad_images and the pre-padding shape
+def UnPad_imges(image, indices, org_shape):
+    """Cut the original-sized window back out of a padded 5-D tensor.
+
+    Args:
+        image: padded tensor; its first two axes are kept whole.
+        indices: start offsets along the three trailing axes.
+        org_shape: five-element shape of the tensor before padding; only the
+            last three entries are used.
+
+    Returns:
+        The slice of ``image`` covering the original extent. It stays on
+        whatever device ``image`` is on.
+    """
+    _, _, height, width, depth = org_shape
+    x0, y0, z0 = indices[0], indices[1], indices[2]
+    window = (
+        slice(None),
+        slice(None),
+        slice(x0, x0 + height),
+        slice(y0, y0 + width),
+        slice(z0, z0 + depth),
+    )
+    return image[window]
+
+
+# def show_results(res):
+#     for slices in range(res.shape[2]):
+#         out_show_res = res[:, :, slices]
+#         plt.imshow(out_show_res)
+#         plt.show()
+
+
+# write one volume as per-slice PNG previews plus a single NIfTI file
+def _save_volume(volume, affine, folder, idx, tag):
+    """Save ``volume`` to ``folder`` as PNG slices and as ``{idx}_{tag}.nii.gz``.
+
+    Args:
+        volume: numpy array indexed as ``[:, :, slice]``.
+        affine: affine matrix handed to ``nib.Nifti1Image``.
+        folder: existing output directory.
+        idx: sample identifier used as the file-name prefix.
+        tag: file-name suffix (``"image"``, ``"gt"`` or ``"pred"``).
+    """
+    for slice_no in range(volume.shape[2]):
+        # The PNG path has its own name so it cannot clobber the NIfTI path.
+        png_path = os.path.join(folder, f"{idx}_{slice_no}_{tag}" + '.png')
+        plt.imsave(png_path, volume[:, :, slice_no], format='png', cmap='gray')
+    if volume.dtype == np.int64:
+        # nibabel does not pick a storage type for int64 arrays on its own
+        # (recent releases raise). int64 only occurs for label volumes here,
+        # presumably small class ids -- TODO confirm they fit in int16.
+        volume = volume.astype(np.int16)
+    nifti_path = os.path.join(folder, f"{idx}_{tag}" + '.nii.gz')
+    nib.save(nib.Nifti1Image(volume, affine), nifti_path)
+
+
+# save the predictions and ground truths
+def save_pred(img, mask, pred, outpath, name_model, idx, aff):
+    """Write image, ground truth and post-processed prediction for one sample.
+
+    For each of the three volumes this writes one grayscale PNG per slice
+    (``{idx}_{slice}_{image|gt|pred}.png``) and one NIfTI file
+    (``{idx}_{image|gt|pred}.nii.gz``) into ``outpath/name_model``.
+
+    Args:
+        img: image tensor; it is squeezed and must leave a 3-D volume.
+        mask: ground-truth tensor; it is squeezed and must leave a 3-D volume.
+        pred: network output with the class scores on dim 1.
+        outpath: root output directory.
+        name_model: sub-directory name, created if missing.
+        idx: sample identifier used as the file-name prefix.
+        aff: affine tensor; only its diagonal is kept for the saved files.
+    """
+    out_dir = os.path.join(outpath, name_model)
+    os.makedirs(out_dir, exist_ok=True)
+
+    # Keep only the first four diagonal entries of the affine; the
+    # off-diagonal terms are dropped, as before.
+    aff = aff.squeeze().cpu()
+    affine = np.diag(np.asarray(torch.diagonal(aff)[:4]))
+    print(affine)
+
+    # Save images
+    image_volume = np.array(img.squeeze().cpu())
+    _save_volume(image_volume, affine, out_dir, idx, "image")
+
+    # Save ground truths
+    mask_volume = np.array(mask.squeeze().cpu())
+    _save_volume(mask_volume, affine, out_dir, idx, "gt")
+
+    # Save predictions
+    # Post processing: take the arg-max class per voxel, then apply the
+    # module-level keep_largest transform to the label map.
+    final_prediction = torch.argmax(pred, dim=1)
+    final_prediction = keep_largest(final_prediction)
+    pred_volume = np.array(final_prediction.cpu().squeeze())
+    _save_volume(pred_volume, affine, out_dir, idx, "pred")
+
+
+def save_results(iou, dice, out_path, model_name):
+    """Print the highest IoU and Dice score and store them in ``results.txt``.
+
+    Args:
+        iou: non-empty sequence of comparable IoU values.
+        dice: non-empty sequence of comparable Dice values.
+        out_path: root output directory.
+        model_name: sub-directory of ``out_path`` that receives the file;
+            created when it does not exist yet.
+    """
+    # Make sure the destination folder is there
+    target_dir = os.path.join(out_path, model_name)
+    if not os.path.exists(target_dir):
+        os.makedirs(target_dir)
+
+    # Largest IoU is the last element after an ascending sort
+    best_iou = sorted(iou)[-1]
+    print("Top IoU:", best_iou)
+    # Same for the Dice scores
+    best_dice = sorted(dice)[-1]
+    print("Top Dice Score:", best_dice)
+
+    # The file holds the dict's string form, overwriting any earlier run
+    summary = {"IoU": best_iou, "Dice Score": best_dice}
+    with open(os.path.join(target_dir, 'results.txt'), 'w') as handle:
+        handle.write(str(summary))
+
+
+def test_results(model, out_path, model_name):
+    """Run ``model`` over the module-level ``test_data`` loader and save the outputs.
+
+    For every batch the image is padded, passed through the model, cropped
+    back to the input size and turned into class probabilities. The image,
+    ground truth and prediction are written with ``save_pred``. After the loop
+    the per-sample mean IoU and Dice values are passed to ``save_results``.
+
+    Args:
+        model: callable network; it is given a float CUDA tensor.
+        out_path: root output directory.
+        model_name: sub-directory name used for all files of this run.
+    """
+    all_iou = []
+    all_dice = []
+    for sample_idx, items in enumerate(test_data):
+        image = items["image"].cuda()
+        mask = items["mask"].long().cuda().squeeze(dim=1)
+        # Keep `image` untouched: the padded copy is only model input. Saving
+        # the padded tensor would not line up with the cropped mask/prediction.
+        padded, pad_offsets = Pad_images(image)
+        pred = model(padded.float().cuda())
+        # Crop the prediction back to the spatial size of the input
+        pred = UnPad_imges(pred, pad_offsets, image.shape).cuda()
+        pred = soft(pred)
+        # Affine stored by the loader under the image metadata
+        image_affine_original = items['image_meta_dict']['original_affine']
+        # Save results
+        save_pred(image, mask, pred, out_path, model_name, sample_idx, image_affine_original)
+        # IoU: drop classes scored -1 (the metric's absent_score), then average
+        iou_all_class = IOU_metric(pred, mask).cpu().numpy()
+        iou_all_class = iou_all_class[iou_all_class != -1.]
+        all_iou.append(iou_all_class.mean())
+        # Dice score: same -1 filter as for IoU, then average
+        dice_all_class = f1_metric(pred, mask).cpu().numpy()
+        dice_all_class = dice_all_class[dice_all_class != -1.]
+        all_dice.append(dice_all_class.mean())
+    # Save the summary
+    save_results(all_iou, all_dice, out_path, model_name)
+
+
+if __name__ == "__main__":
+    # Checkpoint name; also used as the sub-folder for this run's results
+    name = "UNet3D_Best_0.5_Fold_4"
+
+    # Output root, created on first use
+    result_path = r'../unet/test_results_3d/'
+    if not os.path.exists(result_path):
+        os.makedirs(result_path)
+
+    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
+    checkpoint_dir = Path(r"../unet/cluster_results/best_models")
+    model_path = str(checkpoint_dir / (name + ".pt"))
+    model = torch.load(model_path)
+
+    # Inference only: no gradients, eval mode, model on the GPU
+    with torch.no_grad():
+        model.eval().cuda()
+        test_results(model, result_path, name)