[6f3ba0]: / U-Net / test_blood.py

Download this file

117 lines (93 with data), 4.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import logging
import os
import random
import sys
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from evaluate import evaluate
from unet.unet_model import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss
# Preferred GPU index on the machine this script was written for. If that GPU
# is not visible (fewer devices, or CUDA_VISIBLE_DEVICES hides it), fall back
# to the first visible GPU instead of failing later at `.to(device)`.
_PREFERRED_CUDA_INDEX = 4
if torch.cuda.is_available():
    device = torch.device(
        f'cuda:{_PREFERRED_CUDA_INDEX}'
        if torch.cuda.device_count() > _PREFERRED_CUDA_INDEX
        else 'cuda:0'
    )
else:
    device = torch.device('cpu')

# Default checkpoint path (also the default of the --model CLI option).
PRED_MODEL = './epoch_26_acc_0.90_best_val_acc.pth'
# Test split: image directory and matching mask directory.
dir_img = Path('./data/test/imgs/')
dir_mask = Path('./data/test/masks/')
def test_model(
        model, device,
        epochs: int = 1,
        batch_size: int = 1,
        learning_rate: float = 0.001,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
):
    """Evaluate ``model`` on the test images/masks and log the Dice score.

    Args:
        model: network to evaluate; passed unchanged to ``evaluate``.
        device: device for evaluation (a ``torch.device`` or any value
            accepted by ``torch.device``); forwarded to ``evaluate``.
        epochs, learning_rate, weight_decay: unused; kept only so callers
            that pass training-style arguments keep working.
        batch_size: batch size of the test DataLoader.
        img_scale: image rescale factor forwarded to the dataset.
        amp: mixed-precision flag forwarded to ``evaluate``.

    Returns:
        Whatever ``evaluate`` returns (logged as the test Dice score).
    """
    # NOTE(review): these are 3-channel (RGB) statistics while __main__ builds
    # the model with n_channels=1 — confirm they match what BasicDataset yields.
    data_transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale, data_transform)

    loader_args = dict(
        batch_size=batch_size,
        # os.cpu_count() may return None, which DataLoader rejects.
        num_workers=os.cpu_count() or 0,
        # Pinned host memory only benefits host->GPU copies.
        pin_memory=torch.device(device).type == 'cuda',
    )
    # Order does not matter for a test score; keep it deterministic.
    test_loader = DataLoader(dataset, shuffle=False, **loader_args)

    # Put dropout / batch-norm layers in inference behaviour before scoring.
    model.eval()
    test_score = evaluate(model, test_loader, device, amp)
    logging.info('Test Dice score: {}'.format(test_score))
    return test_score
def get_args(argv=None):
    """Parse command-line options for evaluating a trained UNet.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a trained UNet on test images and target masks')
    parser.add_argument('--model', '-m', default=PRED_MODEL, metavar='FILE',
                        help="Specify the file in which the model is stored")
    # NOTE(review): --input, --output, --viz, --no-save and --mask-threshold
    # are parsed but not read by this file's __main__ block.
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
                        help='Filenames of input images', default=dir_img)
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+',
                        help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true',
                        help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')
    # NOTE(review): the --bilinear / --classes defaults (False / 2) differ from
    # the architecture hard-coded in __main__ (bilinear=True, n_classes=5).
    parser.add_argument('--bilinear', action='store_true', default=False,
                        help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2,
                        help='Number of classes')
    parser.add_argument('--amp', action='store_true', default=False,
                        help='Use mixed precision')
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    logging.info(f'Using device {device}')

    # The architecture must match the checkpoint being loaded: n_channels is
    # the number of input image channels (1 here; 3 would be RGB), n_classes
    # the number of per-pixel output scores. Hard-coded; the --classes and
    # --bilinear CLI flags are not used.
    model = UNet(n_channels=1, n_classes=5, bilinear=True)

    # Load pre-trained weights from --model (its default is PRED_MODEL).
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(args.model, map_location=device)
    # Some checkpoints in this project store a 'mask_values' entry next to the
    # weights; it is not a model parameter and would break the strict load.
    state_dict.pop('mask_values', None)
    model.load_state_dict(state_dict)
    logging.info(f'Model loaded from {args.model}')

    model = model.to(memory_format=torch.channels_last)
    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    model.to(device=device)
    test_model(
        model=model,
        device=device,
        img_scale=args.scale,
        amp=args.amp
    )