Diff of /utils.py [000000] .. [77dc1e]

Switch to unified view

a b/utils.py
1
import numpy as np
2
import pandas as pd
3
import os
4
import torch
5
from torchvision import transforms
6
from torch.utils.data import DataLoader
7
import sys
8
from tqdm import tqdm
9
import matplotlib.pyplot as plt
10
from sklearn.metrics import roc_auc_score, roc_curve, log_loss, auc
11
from scipy import interp
12
from itertools import cycle
13
sys.path.append("/home/anjum/PycharmProjects/kaggle")
14
# sys.path.append("/home/anjum/rsna_code")  # GCP
15
from rsna_intracranial_hemorrhage_detection.datasets import ICHDataset
16
17
INPUT_DIR = "/mnt/storage_dimm2/kaggle_data/rsna-intracranial-hemorrhage-detection/"
18
19
20
def build_tta_loaders(img_size, dataset, phase=1, image_filter=None, batch_size=32, num_workers=1,
                      image_folder=None, png=True):
    """Build one DataLoader per test-time-augmentation (TTA) transform.

    The four TTA views are: identity, horizontal flip, +10 degree rotation and
    -10 degree rotation. Each view gets its own ICHDataset/DataLoader pair so
    predictions can later be averaged by ``test_time_augmentation``.

    :param img_size: int or (h, w) tuple; an int is expanded to a square size
    :param dataset: dataset identifier forwarded to ICHDataset
    :param phase: dataset phase forwarded to ICHDataset (default 1)
    :param image_filter: optional filter forwarded to ICHDataset
    :param batch_size: DataLoader batch size
    :param num_workers: DataLoader worker count
    :param image_folder: optional image directory forwarded to ICHDataset
    :param png: whether images are stored as PNG (forwarded to ICHDataset)
    :return: list of DataLoader, one per TTA transform, in a fixed order
    """
    if isinstance(img_size, int):
        img_size = (img_size, img_size)

    def make_transform(flip=False, angle=0):
        # Factory replacing four duplicated closures: optional hflip and/or
        # rotation, then resize + to_tensor. All variants share this pipeline.
        def _transform(image):
            image = transforms.functional.to_pil_image(image)
            if flip:
                image = transforms.functional.hflip(image)
            if angle != 0:
                image = transforms.functional.affine(image, angle=angle, translate=(0, 0), scale=1.0, shear=0)
            image = transforms.functional.resize(image, img_size)
            # NOTE: ImageNet normalisation was deliberately disabled in the
            # original (left commented out); keep tensors un-normalised.
            return transforms.functional.to_tensor(image)
        return _transform

    # Order matters: downstream averaging assumes a stable loader order.
    tta = [
        make_transform(),             # identity (resize only)
        make_transform(flip=True),    # horizontal flip
        make_transform(angle=10),     # rotate +10 degrees
        make_transform(angle=-10),    # rotate -10 degrees
    ]
    loaders = []

    for augmentation in tta:
        tta_dataset = ICHDataset(dataset, phase=phase, image_filter=image_filter, transforms=augmentation,
                                 image_folder=image_folder, png=png)
        loaders.append(DataLoader(tta_dataset, batch_size=batch_size, num_workers=num_workers))
    return loaders
65
66
67
def infer(model, loader, device, desc):
    """Run the model over every batch of ``loader`` and collect outputs.

    The model is put into eval mode and gradients are disabled. Predictions
    are moved back to the CPU before concatenation.

    :param model: torch module producing predictions from an image batch
    :param loader: iterable of (image, target) batches
    :param device: device the image batches are moved to
    :param desc: progress-bar description for tqdm
    :return: (predictions, targets) as two concatenated tensors
    """
    model.eval()
    batch_preds, batch_targets = [], []
    with torch.no_grad():
        for image, target in tqdm(loader, desc=desc):
            output = model(image.to(device))
            batch_preds.append(output.cpu())
            batch_targets.append(target)
    return torch.cat(batch_preds), torch.cat(batch_targets)
80
81
82
def test_time_augmentation(model, loaders, device, desc):
    """Average model predictions over several TTA DataLoaders.

    Runs :func:`infer` once per loader and averages the prediction tensors
    element-wise across the augmentation axis.

    :param model: torch module to evaluate
    :param loaders: list of DataLoaders, one per augmentation (as built by
        ``build_tta_loaders``); all are assumed to yield identical targets
    :param device: device for inference
    :param desc: base progress-bar description ("... TTA i" is appended)
    :return: (mean predictions, targets from the last loader)
    """
    per_augmentation = []
    for idx, tta_loader in enumerate(loaders):
        preds, targets = infer(model, tta_loader, device, f"{desc} TTA {idx}")
        per_augmentation.append(preds)

    # Stack along a new trailing axis, then reduce it by averaging.
    stacked = torch.stack(per_augmentation, dim=-1)
    return stacked.mean(dim=-1), targets
90
91
92
class EarlyStopping:
    """
    Early stops the training if validation loss doesn't improve after a given patience.
    https://github.com/Bjarten/early-stopping-pytorch
    """
    def __init__(self, patience=7, verbose=False, delta=0, file_path='checkpoint.pt', parallel=False):
        """
        :param patience: How long to wait after last time validation loss improved. Default: 7
        :param verbose: If True, prints a message for each validation loss improvement. Default: False
        :param delta: Minimum change in the monitored quantity to qualify as an improvement. Default: 0
        :param file_path: Path to save checkpoint file
        :param parallel: If True, the multi-GPU model is saved as a single GPU model
        """

        self.patience = patience
        self.verbose = verbose
        self.counter = 0            # epochs elapsed since the last improvement
        self.best_score = None      # best (negated) validation loss seen so far
        self.early_stop = False     # flag polled by the training loop
        # NOTE(review): np.Inf was removed in NumPy 2.0 (use np.inf); fine on NumPy 1.x.
        self.val_loss_min = np.Inf
        self.delta = delta
        self.file_path = file_path
        self.parallel = parallel
        self.parallel_model = None  # multi-GPU state_dict retained in memory (see save_checkpoint)

    def __call__(self, val_loss, model, **kwargs):
        """Record one validation result; checkpoint on improvement, count otherwise.

        Extra keyword items (optimizer state etc.) are forwarded to
        save_checkpoint and stored alongside the model.
        """

        # Negate so that "higher score" always means "lower validation loss".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, **kwargs)
        # NOTE(review): upstream Bjarten uses `best_score + delta` here; with
        # `- delta`, a score slightly WORSE than best (within delta) still
        # refreshes best_score below. Identical to upstream when delta == 0
        # (the default) — confirm the sign if a non-zero delta is ever used.
        elif score < self.best_score - self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, **kwargs)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, **save_items):
        """Saves model and addition items when validation loss decrease.

        With no extra items, only the model state_dict is written. Otherwise a
        dict is written containing the model ('model'), the extra items, and
        this object's resumable state ('stopping_params').
        """
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')

        if save_items == {}:
            torch.save(model.state_dict(), self.file_path)
        else:
            if self.parallel:
                # Keep the full DataParallel state_dict in memory, but write
                # the unwrapped (.module) single-GPU weights to disk.
                self.parallel_model = model.state_dict()
                save_items['model'] = model.module.state_dict()
            else:
                save_items['model'] = model.state_dict()
            save_items["stopping_params"] = self.state_dict()
            torch.save(save_items, self.file_path)
        self.val_loss_min = val_loss

    def state_dict(self):
        """Return the resumable part of this object's state (for checkpoints)."""
        state = {
            # "counter": self.counter,
            "best_score": self.best_score,
            "val_loss_min": self.val_loss_min,
        }
        return state

    def load_state_dict(self, state):
        """Restore state previously produced by :meth:`state_dict`."""
        # self.counter = state["counter"]
        self.best_score = state["best_score"]
        self.val_loss_min = state["val_loss_min"]
163
164
165
def plot_roc_curve(target, predictions, file_path, metric=None):
    """Plot a binary ROC curve with its AUC and save the figure.

    :param target: ground-truth binary labels (torch.Tensor or array-like)
    :param predictions: predicted scores, same length as ``target``
    :param file_path: where the figure is saved
    :param metric: optional log-loss value shown in the plot title
    """

    # Accept tensors transparently; sklearn/matplotlib want numpy arrays.
    if isinstance(target, torch.Tensor):
        target = target.numpy()
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.numpy()

    plt.figure()
    fpr, tpr, _ = roc_curve(target, predictions)
    score = roc_auc_score(target, predictions)
    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.4f)' % score)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    if metric is not None:
        plt.title(f'Receiver operating characteristic. LogLoss: {metric:.4f}')
    else:
        plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    plt.savefig(file_path)
    # Close the figure so repeated calls don't accumulate open figures.
    plt.close()
187
188
189
def plot_multiclass_roc_curve(target, predictions, file_path, metric=None):
    """Plot per-class, micro- and macro-averaged ROC curves and save the figure.

    Based on https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html

    :param target: multi-label ground truth of shape (n_samples, n_classes)
        (torch.Tensor or numpy array)
    :param predictions: predicted probabilities, same shape as ``target``
    :param file_path: where the figure is saved
    :param metric: optional overall log-loss value shown in the plot title
    """

    # Accept tensors transparently; sklearn/matplotlib want numpy arrays.
    if isinstance(target, torch.Tensor):
        target = target.numpy()
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.numpy()

    n_classes = target.shape[1]

    # Compute ROC curve, ROC area, and log-loss for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    log_loss_vals = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(target[:, i], predictions[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        log_loss_vals[i] = log_loss(target[:, i], predictions[:, i])

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(target.ravel(), predictions.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # First aggregate all false positive rates
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

    # Then interpolate all ROC curves at these points.
    # np.interp replaces scipy.interp, which was removed in SciPy 1.8.
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

    # Finally average it and compute AUC
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Plot all ROC curves
    plt.figure()
    plt.plot(fpr["micro"], tpr["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["micro"]),
             color='deeppink', linestyle=':', linewidth=4)

    plt.plot(fpr["macro"], tpr["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["macro"]),
             color='navy', linestyle=':', linewidth=4)

    colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'magenta', 'lawngreen', 'gold'])
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=2, label='ROC curve of class {0} (area={1:0.2f}, log_loss={2:0.3f})'
                                                          ''.format(i, roc_auc[i], log_loss_vals[i]))

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')

    if metric is not None:
        plt.title(f'Receiver operating characteristic. LogLoss: {metric:.4f}')
    else:
        plt.title('Receiver operating characteristic')

    plt.legend(loc="lower right")
    plt.savefig(file_path)
    # Close the figure so repeated calls don't accumulate open figures.
    plt.close()
258
259
260
def reindex_submission(df, stage="test1"):
    """Re-align a submission DataFrame with the official sample-submission order.

    Both frames are sorted by the "ID" column; the sample submission's index
    is transplanted onto ``df`` and the result is returned in index order.
    Unknown stages return ``df`` unchanged.

    :param df: submission DataFrame with an "ID" column
    :param stage: "test1" or "test2"; anything else is a no-op
    :return: the re-indexed DataFrame (or ``df`` itself for unknown stages)
    """
    stage_files = {
        "test1": "stage_1_sample_submission.csv",
        "test2": "stage_2_sample_submission.csv",
    }
    if stage not in stage_files:
        return df

    sub = pd.read_csv(os.path.join(INPUT_DIR, stage_files[stage]))
    ordered = df.sort_values(by="ID")
    ordered = ordered.set_index(sub.sort_values(by="ID").index)
    return ordered.sort_index()
268
269
270
def wide_to_long(df, id_var="ImageID"):
    """Melt a wide per-image prediction frame into Kaggle's long submission format.

    Each image row with one column per hemorrhage subtype becomes six rows,
    keyed by "<image_id>_<subtype>" in the "ID" column.

    :param df: wide DataFrame with ``id_var`` plus one column per subtype
    :param id_var: name of the image-identifier column (default "ImageID")
    :return: DataFrame with just ["ID", "Label"], sorted by "ID"
    """
    subtype_columns = ["any", "epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural"]
    melted = df.melt(id_vars=[id_var], value_vars=subtype_columns, value_name="Label")
    melted["ID"] = melted[id_var] + "_" + melted["variable"]
    return melted.sort_values(by="ID")[["ID", "Label"]]