--- a
+++ b/U-Net/test_blood.py
@@ -0,0 +1,159 @@
+import argparse
+import logging
+import os
+import random
+import sys
+import copy
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as TF
+from pathlib import Path
+from torch import optim
+from torch.utils.data import DataLoader, random_split
+from tqdm import tqdm
+
+from evaluate import evaluate
+from unet.unet_model import UNet
+from utils.data_loading import BasicDataset, CarvanaDataset
+from utils.dice_score import dice_loss
+
+
+def _select_device(preferred_index: int = 4) -> torch.device:
+    """Return the device to run on, preferring CUDA ordinal ``preferred_index``.
+
+    Falls back to the default CUDA device when that ordinal does not exist on
+    this host (a hard-coded ``cuda:4`` fails on machines with fewer than five
+    GPUs), and to the CPU when CUDA is unavailable.
+    """
+    if not torch.cuda.is_available():
+        return torch.device('cpu')
+    if preferred_index < torch.cuda.device_count():
+        return torch.device(f'cuda:{preferred_index}')
+    return torch.device('cuda')
+
+
+device = _select_device()
+
+# Default checkpoint evaluated when --model is not given.
+PRED_MODEL = './epoch_26_acc_0.90_best_val_acc.pth'
+
+# Test-split image and mask directories.
+dir_img = Path('./data/test/imgs/')
+dir_mask = Path('./data/test/masks/')
+
+
+def test_model(
+        model, device,
+        epochs: int = 1,
+        batch_size: int = 1,
+        learning_rate: float = 0.001,
+        img_scale: float = 0.5,
+        amp: bool = False,
+        weight_decay: float = 1e-8,
+    ):
+    """Evaluate ``model`` on the test split and log its Dice score.
+
+    Args:
+        model: network to evaluate; expected to already be on ``device``.
+        device: device forwarded to ``evaluate``.
+        epochs, learning_rate, weight_decay: unused; accepted only so that
+            existing callers passing them keep working.
+        batch_size: samples per test batch.
+        img_scale: image scale factor forwarded to the dataset.
+        amp: forwarded to ``evaluate`` to enable mixed precision.
+
+    Returns:
+        The score returned by ``evaluate``.
+    """
+    # NOTE(review): these are 3-channel (ImageNet RGB) statistics, while the
+    # model built in __main__ has n_channels=1, and the transform is only
+    # passed to the BasicDataset fallback -- confirm this is intended.
+    data_transform = transforms.Compose([
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+
+    try:
+        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
+    except (AssertionError, RuntimeError, IndexError):
+        dataset = BasicDataset(dir_img, dir_mask, img_scale, data_transform)
+
+    loader_args = dict(
+        batch_size=batch_size,
+        # os.cpu_count() may return None, which DataLoader does not accept.
+        num_workers=os.cpu_count() or 0,
+        # Pinned host memory only speeds up host-to-GPU copies.
+        pin_memory=str(device).startswith('cuda'),
+    )
+    # No shuffling: evaluation should visit the test set in a fixed order.
+    test_loader = DataLoader(dataset, shuffle=False, **loader_args)
+
+    # Inference mode for dropout / batch-norm layers (harmless if evaluate()
+    # also switches modes).
+    model.eval()
+    test_score = evaluate(model, test_loader, device, amp)
+
+    logging.info('Test Dice score: {}'.format(test_score))
+    return test_score
+
+
+def get_args():
+    """Parse the command-line options for evaluating a trained UNet."""
+    parser = argparse.ArgumentParser(description='Evaluate a trained UNet on test images and target masks')
+    parser.add_argument('--model', '-m', default=PRED_MODEL, metavar='FILE', help="Specify the file in which the model is stored")
+    # NOTE(review): --input, --output, --viz, --no-save, --mask-threshold and
+    # --bilinear are parsed but not used by this script; they are kept so
+    # existing command lines still parse.
+    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', default=dir_img)
+    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
+    parser.add_argument('--viz', '-v', action='store_true',
+                        help='Visualize the images as they are processed')
+    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
+    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
+                        help='Minimum probability value to consider a mask pixel white')
+    parser.add_argument('--scale', '-s', type=float, default=0.5,
+                        help='Scale factor for the input images')
+    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
+    # The model in __main__ is built with this many output classes.
+    parser.add_argument('--classes', '-c', type=int, default=5, help='Number of classes')
+    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = get_args()
+
+    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
+    logging.info(f'Using device {device}')
+
+    # Change here to adapt to your data: n_channels is the number of input
+    # image channels (1 = grayscale, 3 = RGB); n_classes (--classes) is the
+    # number of probabilities predicted per pixel.
+    model = UNet(n_channels=1, n_classes=args.classes, bilinear=True)
+
+    # Load the pre-trained weights named by --model (default: PRED_MODEL).
+    # torch.load unpickles the file, so only load trusted checkpoints.
+    state_dict = torch.load(args.model, map_location=device)
+    # Some checkpoints also store a non-parameter 'mask_values' entry; drop
+    # it so the (strict) load_state_dict call does not reject it.
+    state_dict.pop('mask_values', None)
+    model.load_state_dict(state_dict)
+    logging.info(f'Model loaded from {args.model}')
+
+    model = model.to(memory_format=torch.channels_last)
+
+    logging.info(f'Network:\n'
+                 f'\t{model.n_channels} input channels\n'
+                 f'\t{model.n_classes} output channels (classes)\n'
+                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')
+
+    model.to(device=device)
+
+    test_model(
+        model=model,
+        device=device,
+        img_scale=args.scale,
+        amp=args.amp,
+    )