Diff of /utils.py [000000] .. [4b8af8]

Switch to unified view

a b/utils.py
1
import os
2
import io
3
import base64
4
5
import numpy as np
6
import pandas as pd
7
import cv2
8
9
import matplotlib.pyplot as plt
10
11
import torch
12
import torch.nn as nn
13
from albumentations import Normalize
14
15
import time
16
from IPython.display import clear_output
17
from IPython.display import HTML
18
19
from loss_metric import dice_coef_metric_per_classes, jaccard_coef_metric_per_classes
20
21
def get_one_slice_data(img_name: str,
22
                       mask_name: str,
23
                       root_imgs_path: str = "images/",
24
                       root_masks_path: str = "masks/",) -> np.ndarray:
25
26
    img_path = os.path.join('images/', img_name)
27
    mask_path = os.path.join('masks/', mask_name)
28
    one_slice_img = cv2.imread(img_path)#[:,:,0] uncomment for grayscale
29
    one_slice_mask = cv2.imread(mask_path)
30
    one_slice_mask[one_slice_mask < 240] = 0  # remove artifacts
31
    one_slice_mask[one_slice_mask >= 240] = 255
32
33
    return one_slice_img, one_slice_mask
34
35
36
def get_id_predictions(net: nn.Module,
37
                       ct_scan_id_df: pd.DataFrame,
38
                       root_imgs_dir: str,
39
                       treshold: float = 0.3) -> list:
40
41
    """
42
    Factory for getting predictions and storing them and images in lists as uint8 images.
43
    Params:
44
        net: model for prediction.
45
        ct_scan_id_df: df with unique patient id.
46
        root_imgs_dir: root path for images.
47
        treshold: threshold for probabilities.
48
    """
49
    sigmoid = lambda x: 1 / (1 + np.exp(-x))
50
    images = []
51
    predictions = []
52
    net.eval()
53
    device = "cuda" if torch.cuda.is_available() else "cpu"
54
    print("device:", device)
55
    with torch.no_grad():
56
        for idx in range(len(ct_scan_id_df)):
57
            img_name = ct_scan_id_df.loc[idx, "ImageId"]
58
            path = os.path.join(root_imgs_dir, img_name)
59
60
            img_ = cv2.imread(path)
61
    
62
            img = Normalize().apply(img_)
63
            tensor = torch.FloatTensor(img).permute(2, 0, 1).unsqueeze(0)
64
            prediction = net.forward(tensor.to(device))
65
            prediction = prediction.cpu().detach().numpy()
66
            prediction = prediction.squeeze(0).transpose(1, 2, 0)
67
            prediction = sigmoid(prediction)
68
            prediction = (prediction >= treshold).astype(np.float32)
69
70
            predictions.append((prediction * 255).astype("uint8"))
71
            images.append(img_)
72
73
    return images, predictions
74
75
76
# Save image in original resolution
77
# helpful link - https://stackoverflow.com/questions/34768717/matplotlib-unable-to-save-image-in-same-resolution-as-original-image
78
79
def get_overlaid_masks_on_image(
80
                one_slice_image: np.ndarray,
81
                one_slice_mask: np.ndarray, 
82
                w: float = 512,
83
                h: float = 512, 
84
                dpi: float = 100,
85
                write: bool = False,
86
                path_to_save: str = '/content/',
87
                name_to_save: str = 'img_name'):
88
    """overlap masks on image and save this as a new image."""
89
90
    path_to_save_ = os.path.join(path_to_save, name_to_save)
91
    lung, heart, trachea = [one_slice_mask[:, :, i] for i in range(3)]
92
    figsize = (w / dpi), (h / dpi)
93
    fig = plt.figure(figsize=(figsize))
94
    fig.add_axes([0, 0, 1, 1])
95
96
    # image
97
    plt.imshow(one_slice_image, cmap="bone")
98
99
    # overlaying segmentation masks
100
    plt.imshow(np.ma.masked_where(lung == False, lung),
101
            cmap='cool', alpha=0.3)
102
    plt.imshow(np.ma.masked_where(heart == False, heart),
103
            cmap='autumn', alpha=0.3)
104
    plt.imshow(np.ma.masked_where(trachea == False, trachea),
105
               cmap='autumn_r', alpha=0.3) 
106
107
    plt.axis('off')
108
    fig.savefig(f"{path_to_save_}.png",bbox_inches='tight', 
109
                pad_inches=0.0, dpi=dpi,  format="png")
110
    if write:
111
        plt.close()
112
    else:
113
        plt.show()
114
        
115
        
116
def get_overlaid_masks_on_full_ctscan(ct_scan_id_df: pd.DataFrame, path_to_save: str):
117
    """
118
    Creating images with overlaid masks on each slice of CT scan.
119
    Params:
120
         ct_scan_id_df: df with unique patient id.
121
         path_to_save: path to save images.
122
    """
123
    num_slice = len(ct_scan_id_df)
124
    for slice_ in range(num_slice):
125
        img_name = ct_scan_id_df.loc[slice_, "ImageId"]
126
        mask_name = ct_scan_id_df.loc[slice_, "MaskId"]
127
        one_slice_img, one_slice_mask = get_one_slice_data(img_name, mask_name)
128
        get_overlaid_masks_on_image(one_slice_img,
129
                                one_slice_mask,
130
                                write=True, 
131
                                path_to_save=path_to_save,
132
                                name_to_save=str(slice_)
133
                                )
134
135
def create_video(path_to_imgs: str, video_name: str, framerate: int):
136
    """
137
    Create video from images.
138
    Params:
139
        path_to_imgs: path to dir with images.
140
        video_name: name for saving video.
141
        framerate: num frames per sec in video.
142
    """
143
    img_names = sorted(os.listdir(path_to_imgs), key=lambda x: int(x[:-4]))  # img_name must be numbers
144
    img_path = os.path.join(path_to_imgs, img_names[0])
145
    frame_width, frame_height, _ = cv2.imread(img_path).shape
146
    fourc = cv2.VideoWriter_fourcc(*'MP4V')
147
    video = cv2.VideoWriter(video_name + ".mp4", 
148
                            fourc, 
149
                            framerate, 
150
                            (frame_width, frame_height))
151
152
    for img_name in img_names:
153
        img_path = os.path.join(path_to_imgs, img_name)
154
        image = cv2.imread(img_path)
155
        video.write(image)
156
            
157
    cv2.destroyAllWindows()
158
    video.release()
159
160
    
161
def compute_scores_per_classes(model,
162
                               dataloader,
163
                               classes):
164
    """
165
    Compute Dice and Jaccard coefficients for each class.
166
    Params:
167
        model: neural net for make predictions.
168
        dataloader: dataset object to load data from.
169
        classes: list with classes.
170
        Returns: dictionaries with dice and jaccard coefficients for each class for each slice.
171
    """
172
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
173
    dice_scores_per_classes = {key: list() for key in classes}
174
    iou_scores_per_classes = {key: list() for key in classes}
175
176
    with torch.no_grad():
177
        for i, (imgs, targets) in enumerate(dataloader):
178
            imgs, targets = imgs.to(device), targets.to(device)
179
            logits = model(imgs)
180
            logits = logits.detach().cpu().numpy()
181
            targets = targets.detach().cpu().numpy()
182
            
183
            dice_scores = dice_coef_metric_per_classes(logits, targets)
184
            iou_scores = jaccard_coef_metric_per_classes(logits, targets)
185
186
            for key in dice_scores.keys():
187
                dice_scores_per_classes[key].extend(dice_scores[key])
188
189
            for key in iou_scores.keys():
190
                iou_scores_per_classes[key].extend(iou_scores[key])
191
192
    return dice_scores_per_classes, iou_scores_per_classes