--- /dev/null
+++ b/utils.py
@@ -0,0 +1,275 @@
+import numpy as np
+import pandas as pd
+import os
+import torch
+from torchvision import transforms
+from torch.utils.data import DataLoader
+import sys
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+from sklearn.metrics import roc_auc_score, roc_curve, log_loss, auc
+from itertools import cycle
+sys.path.append("/home/anjum/PycharmProjects/kaggle")
+# sys.path.append("/home/anjum/rsna_code")  # GCP
+from rsna_intracranial_hemorrhage_detection.datasets import ICHDataset
+
+INPUT_DIR = "/mnt/storage_dimm2/kaggle_data/rsna-intracranial-hemorrhage-detection/"
+
+
+def build_tta_loaders(img_size, dataset, phase=1, image_filter=None, batch_size=32, num_workers=1,
+                      image_folder=None, png=True):
+    """Builds one DataLoader per test-time-augmentation (TTA) transform over the same dataset."""
+    if isinstance(img_size, int):
+        img_size = (img_size, img_size)
+
+    def null_transform(image):
+        image = transforms.functional.to_pil_image(image)
+        image = transforms.functional.resize(image, img_size)
+        tensor = transforms.functional.to_tensor(image)
+        # tensor = transforms.functional.normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        return tensor
+
+    def hflip(image):
+        image = transforms.functional.to_pil_image(image)
+        image = transforms.functional.hflip(image)
+        image = transforms.functional.resize(image, img_size)
+        tensor = transforms.functional.to_tensor(image)
+        # tensor = transforms.functional.normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        return tensor
+
+    def rotate_pos(image):
+        image = transforms.functional.to_pil_image(image)
+        image = transforms.functional.affine(image, angle=10, translate=(0, 0), scale=1.0, shear=0)
+        image = transforms.functional.resize(image, img_size)
+        tensor = transforms.functional.to_tensor(image)
+        # tensor = transforms.functional.normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        return tensor
+
+    def rotate_neg(image):
+        image = transforms.functional.to_pil_image(image)
+        image = transforms.functional.affine(image, angle=-10, translate=(0, 0), scale=1.0, shear=0)
+        image = transforms.functional.resize(image, img_size)
+        tensor = transforms.functional.to_tensor(image)
+        # tensor = transforms.functional.normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        return tensor
+
+    # tta = [null_transform, hflip]
+    tta = [null_transform, hflip, rotate_pos, rotate_neg]
+    loaders = []
+
+    for augmentation in tta:
+        tta_dataset = ICHDataset(dataset, phase=phase, image_filter=image_filter, transforms=augmentation,
+                                 image_folder=image_folder, png=png)
+        loaders.append(DataLoader(tta_dataset, batch_size=batch_size, num_workers=num_workers))
+    return loaders
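+
+# Illustrative usage sketch (not called anywhere in this module); `valid_df` and the
+# 224 px size are placeholder assumptions:
+#   loaders = build_tta_loaders(224, valid_df, phase=1, batch_size=32, num_workers=4)
+#   # -> a list of four DataLoaders, one per transform, in the order of `tta` above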
+
+
+def infer(model, loader, device, desc):
+    """Runs the model over a loader in eval mode and returns (predictions, targets) as CPU tensors."""
+    model.eval()
+    with torch.no_grad():
+        predictions, targets = [], []
+        for image, target in tqdm(loader, desc=desc):
+            image = image.to(device)
+            y_hat = model(image)
+            predictions.append(y_hat.cpu())
+            targets.append(target)
+
+        predictions = torch.cat(predictions)
+        targets = torch.cat(targets)
+    return predictions, targets
+
+
+def test_time_augmentation(model, loaders, device, desc):
+    """Averages the raw model outputs across all TTA loaders; targets are taken from the last loader."""
+    tta_predictions = []
+    for i, loader in enumerate(loaders):
+        predictions, targets = infer(model, loader, device, f"{desc} TTA {i}")
+        tta_predictions.append(torch.unsqueeze(predictions, -1))
+
+    tta_predictions = torch.cat(tta_predictions, -1)
+    return torch.mean(tta_predictions, dim=-1), targets
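+
+# Illustrative usage sketch (assumes a trained `model` already moved to `device` and the
+# loaders returned by build_tta_loaders):
+#   preds, targets = test_time_augmentation(model, loaders, device, "validation")
+#   # preds are the raw model outputs averaged over the TTA variants; targets come from
+#   # the last loader, which matches the others because the loaders are not shuffled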
+
+
+class EarlyStopping:
+    """
+    Early stops the training if validation loss doesn't improve after a given patience.
+    https://github.com/Bjarten/early-stopping-pytorch
+    """
+    def __init__(self, patience=7, verbose=False, delta=0, file_path='checkpoint.pt', parallel=False):
+        """
+        :param patience: How long to wait after last time validation loss improved. Default: 7
+        :param verbose: If True, prints a message for each validation loss improvement. Default: False
+        :param delta: Minimum change in the monitored quantity to qualify as an improvement. Default: 0
+        :param file_path: Path to save the checkpoint file
+        :param parallel: If True, the multi-GPU (DataParallel) model is saved as a single-GPU model
+        """
+
+        self.patience = patience
+        self.verbose = verbose
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.val_loss_min = np.inf
+        self.delta = delta
+        self.file_path = file_path
+        self.parallel = parallel
+        self.parallel_model = None
+
+    def __call__(self, val_loss, model, **kwargs):
+
+        score = -val_loss
+
+        if self.best_score is None:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, **kwargs)
+        elif score < self.best_score + self.delta:  # no improvement of at least delta
+            self.counter += 1
+            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, **kwargs)
+            self.counter = 0
+
+    def save_checkpoint(self, val_loss, model, **save_items):
+        """Saves model and addition items when validation loss decrease."""
+        if self.verbose:
+            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
+
+        if save_items == {}:
+            torch.save(model.state_dict(), self.file_path)
+        else:
+            if self.parallel:
+                self.parallel_model = model.state_dict()
+                save_items['model'] = model.module.state_dict()
+            else:
+                save_items['model'] = model.state_dict()
+            save_items["stopping_params"] = self.state_dict()
+            torch.save(save_items, self.file_path)
+        self.val_loss_min = val_loss
+
+    def state_dict(self):
+        state = {
+            # "counter": self.counter,
+            "best_score": self.best_score,
+            "val_loss_min": self.val_loss_min,
+        }
+        return state
+
+    def load_state_dict(self, state):
+        # self.counter = state["counter"]
+        self.best_score = state["best_score"]
+        self.val_loss_min = state["val_loss_min"]
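+
+# Illustrative usage sketch for EarlyStopping (`optimizer`, `epochs` and the training /
+# validation steps are assumed to exist in the calling script):
+#   stopper = EarlyStopping(patience=5, verbose=True, file_path="model.pt")
+#   for epoch in range(epochs):
+#       val_loss = ...  # train one epoch, then evaluate
+#       stopper(val_loss, model, optimizer=optimizer.state_dict(), epoch=epoch)
+#       if stopper.early_stop:
+#           break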
+
+
+def plot_roc_curve(target, predictions, file_path, metric=None):
+
+    if isinstance(target, torch.Tensor):
+        target = target.numpy()
+    if isinstance(predictions, torch.Tensor):
+        predictions = predictions.numpy()
+
+    plt.figure()
+    fpr, tpr, _ = roc_curve(target, predictions)
+    score = roc_auc_score(target, predictions)
+    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.4f)' % score)
+    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')
+    if metric is not None:
+        plt.title(f'Receiver operating characteristic. LogLoss: {metric:.4f}')
+    else:
+        plt.title('Receiver operating characteristic')
+    plt.legend(loc="lower right")
+    plt.savefig(file_path)
+    plt.close()
+
+
+def plot_multiclass_roc_curve(target, predictions, file_path, metric=None):
+    # https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
+
+    if isinstance(target, torch.Tensor):
+        target = target.numpy()
+    if isinstance(predictions, torch.Tensor):
+        predictions = predictions.numpy()
+
+    n_classes = target.shape[1]
+
+    # Compute ROC curve and ROC area for each class
+    fpr = dict()
+    tpr = dict()
+    roc_auc = dict()
+    log_loss_vals = dict()
+    for i in range(n_classes):
+        fpr[i], tpr[i], _ = roc_curve(target[:, i], predictions[:, i])
+        roc_auc[i] = auc(fpr[i], tpr[i])
+        log_loss_vals[i] = log_loss(target[:, i], predictions[:, i])
+
+    # Compute micro-average ROC curve and ROC area
+    fpr["micro"], tpr["micro"], _ = roc_curve(target.ravel(), predictions.ravel())
+    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
+
+    # First aggregate all false positive rates
+    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
+
+    # Then interpolate all ROC curves at these points
+    mean_tpr = np.zeros_like(all_fpr)
+    for i in range(n_classes):
+        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
+
+    # Finally average it and compute AUC
+    mean_tpr /= n_classes
+
+    fpr["macro"] = all_fpr
+    tpr["macro"] = mean_tpr
+    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
+
+    # Plot all ROC curves
+    plt.figure()
+    plt.plot(fpr["micro"], tpr["micro"],
+             label='micro-average ROC curve (area = {0:0.2f})'
+                   ''.format(roc_auc["micro"]),
+             color='deeppink', linestyle=':', linewidth=4)
+
+    plt.plot(fpr["macro"], tpr["macro"],
+             label='macro-average ROC curve (area = {0:0.2f})'
+                   ''.format(roc_auc["macro"]),
+             color='navy', linestyle=':', linewidth=4)
+
+    colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'magenta', 'lawngreen', 'gold'])
+    for i, color in zip(range(n_classes), colors):
+        plt.plot(fpr[i], tpr[i], color=color, lw=2, label='ROC curve of class {0} (area={1:0.2f}, log_loss={2:0.3f})'
+                                                          ''.format(i, roc_auc[i], log_loss_vals[i]))
+
+    plt.plot([0, 1], [0, 1], 'k--', lw=2)
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')
+
+    if metric is not None:
+        plt.title(f'Receiver operating characteristic. LogLoss: {metric:.4f}')
+    else:
+        plt.title('Receiver operating characteristic')
+
+    plt.legend(loc="lower right")
+    plt.savefig(file_path)
+    plt.close()
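+
+# Illustrative usage sketch for the ROC helpers (assumes the model emits raw logits, so a
+# sigmoid is applied first, and that column 0 corresponds to the "any" label):
+#   probs = torch.sigmoid(preds)
+#   plot_roc_curve(targets[:, 0], probs[:, 0], "roc_any.png", metric=val_loss)
+#   plot_multiclass_roc_curve(targets, probs, "roc_all_classes.png", metric=val_loss)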
+
+
+def reindex_submission(df, stage="test1"):
+    """Reorders a long-format submission so its rows match the order of the stage sample submission."""
+    if stage not in ["test1", "test2"]:
+        return df
+    elif stage == "test1":
+        sub = pd.read_csv(os.path.join(INPUT_DIR, "stage_1_sample_submission.csv"))
+    else:
+        sub = pd.read_csv(os.path.join(INPUT_DIR, "stage_2_sample_submission.csv"))
+    return df.sort_values(by="ID").set_index(sub.sort_values(by="ID").index).sort_index()
+
+
+def wide_to_long(df, id_var="ImageID"):
+    """Melts a wide prediction frame (one column per hemorrhage subtype) into the two-column submission format."""
+    categories = ["any", "epidural", "intraparenchymal", "intraventricular", "subarachnoid", "subdural"]
+    df_long = df.melt(id_vars=[id_var], value_vars=categories, value_name="Label")
+    df_long["ID"] = df_long[id_var] + "_" + df_long["variable"]
+    df_long = df_long.sort_values(by="ID")
+    return df_long[["ID", "Label"]]