Diff of /U-Net/test_blood.py [000000] .. [6f3ba0]


b/U-Net/test_blood.py
import argparse
import logging
import os
import random
import sys
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from evaluate import evaluate
from unet.unet_model import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss

device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')

PRED_MODEL = './epoch_26_acc_0.90_best_val_acc.pth'

dir_img = Path('./data/test/imgs/')
dir_mask = Path('./data/test/masks/')
#dir_checkpoint = Path('./out_checkpoints/')

def test_model(
        model, device,
        epochs: int = 1,
        batch_size: int = 1,
        learning_rate: float = 0.001,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
    ):

    # ImageNet normalization statistics
    data_transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Fall back to the generic dataset if the Carvana-style loader cannot parse the data
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale, data_transform)

    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    test_loader = DataLoader(dataset, shuffle=True, **loader_args)
    # Optimizer and scheduler mirror the training setup; only evaluate() determines the reported score
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)

    test_score = evaluate(model, test_loader, device, amp)
    scheduler.step(test_score)

    logging.info('Test Dice score: {}'.format(test_score))


def get_args():
    parser = argparse.ArgumentParser(description='Test the UNet on images and target masks')
    parser.add_argument('--model', '-m', default=PRED_MODEL, metavar='FILE', help='Specify the file in which the model is stored')
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', default=dir_img)
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    #device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    """
    Change here to adapt to your data:
    n_channels=3 for RGB images (this script uses n_channels=1 for single-channel input)
    n_classes is the number of probabilities you want to get per pixel
    """
    model = UNet(n_channels=1, n_classes=5, bilinear=True)

    # Load the pre-trained checkpoint (defaults to PRED_MODEL)
    model.load_state_dict(torch.load(args.model, map_location=device))

    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    # if args.load:
    #     state_dict = torch.load(args.load, map_location=device)
    #     del state_dict['mask_values']
    #     model.load_state_dict(state_dict)
    #     logging.info(f'Model loaded from {args.load}')

    model.to(device=device)

    test_model(
        model=model,
        device=device,
        img_scale=args.scale,
        amp=args.amp
    )
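
For context, evaluate, UNet, and the dataset classes imported above come from the surrounding repository and are not part of this diff. The sketch below shows what evaluate(model, test_loader, device, amp) is assumed to compute here, namely a mean multi-class Dice score over the loader, which matches the "Test Dice score" log line. The helper name evaluate_sketch and the batch['image'] / batch['mask'] keys are illustrative assumptions, not the repository's actual implementation.

# Hypothetical sketch (not the repository's evaluate.py): mean multi-class Dice
# over a DataLoader, assuming each batch is a dict with 'image' and 'mask' keys.
import torch
import torch.nn.functional as F

@torch.inference_mode()
def evaluate_sketch(model, loader, device, amp):
    model.eval()
    dice_total, n_batches = 0.0, 0
    with torch.autocast(device.type, enabled=amp):
        for batch in loader:
            images = batch['image'].to(device, memory_format=torch.channels_last)
            true_masks = batch['mask'].to(device).long()
            logits = model(images)                       # (N, C, H, W)
            pred = logits.argmax(dim=1)                  # (N, H, W) class indices
            # One-hot both prediction and target, then compute Dice per class
            pred_1h = F.one_hot(pred, model.n_classes).permute(0, 3, 1, 2).float()
            true_1h = F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float()
            inter = (pred_1h * true_1h).sum(dim=(2, 3))
            union = pred_1h.sum(dim=(2, 3)) + true_1h.sum(dim=(2, 3))
            dice_total += ((2 * inter + 1e-6) / (union + 1e-6)).mean().item()
            n_batches += 1
    return dice_total / max(n_batches, 1)

With the hard-coded paths above, the script would typically be run from the U-Net directory as "python test_blood.py --scale 0.5" (optionally adding --amp for mixed precision), assuming the checkpoint file and the data/test/imgs and data/test/masks folders are present.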