Diff of /U-Net/black_blood.py [000000] .. [6f3ba0]

Switch to unified view

a b/U-Net/black_blood.py
1
import argparse
2
import logging
3
import os
4
5
import numpy as np
6
import torch
7
import torch.nn.functional as F
8
from PIL import Image
9
from torchvision import transforms
10
11
from utils.data_loading import BasicDataset
12
from unet.unet_model import UNet
13
from utils.utils import plot_img_and_mask
14
15
device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
16
17
directory = r'/home/data/spleen_blood/data/test/imgs'
18
19
def predict_img(net,
20
                full_img,
21
                device,
22
                scale_factor=1,
23
                out_threshold=0.5):
24
    net.eval()
25
    img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False))
26
    img = img.unsqueeze(0)
27
    img = img.to(device=device, dtype=torch.float32)
28
29
    with torch.no_grad():
30
        output = net(img).cpu()
31
        output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear')
32
        if net.n_classes > 1:
33
            mask = output.argmax(dim=1)
34
        else:
35
            mask = torch.sigmoid(output) > out_threshold
36
37
    return mask[0].long().squeeze().numpy()
38
39
def get_output_filenames(in_files):
40
    return f'{os.path.splitext(in_files)[0]}_OUT.png'
41
42
43
def mask_to_image(mask: np.ndarray, mask_values):
44
    if isinstance(mask_values[0], list):
45
        out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
46
    elif mask_values == [0, 1]:
47
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
48
    else:
49
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)
50
51
    if mask.ndim == 3:
52
        mask = np.argmax(mask, axis=0)
53
54
    for i, v in enumerate(mask_values):
55
        out[mask == i] = v
56
57
    return Image.fromarray(out)
58
59
60
if __name__ == '__main__':
61
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
62
    
63
    model = './epoch_26_acc_0.90_best_val_acc.pth'
64
65
    for fn in os.listdir(directory):
66
        in_files = os.path.join(directory, fn)
67
        out_files = get_output_filenames(in_files)
68
        #print(fn, in_files, out_files)
69
70
        net = UNet(n_channels=1, n_classes=5, bilinear=True)
71
        
72
        logging.info(f'Loading model {model}')
73
        logging.info(f'Using device {device}')
74
75
        net.to(device=device)
76
        state_dict = torch.load(model, map_location=device)
77
        mask_values = state_dict.pop('mask_values', [0, 1])
78
        net.load_state_dict(state_dict)
79
80
        logging.info('Model loaded!')
81
        
82
        filename = in_files
83
        logging.info(f'Predicting image {filename} ...')
84
        img = Image.open(filename)
85
86
        mask = predict_img(net=net,
87
                           full_img=img,
88
                           scale_factor=0.5,
89
                           out_threshold=0.5,
90
                           device=device)
91
                           
92
        out_filename = out_files
93
        result = mask_to_image(mask, mask_values)
94
        result.save(out_filename)
95
        logging.info(f'Mask saved to {out_filename}')
96
97
        #logging.info(f'Visualizing results for image {filename}, close to continue...')
98
        #plot_img_and_mask(img, mask)
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125