
--- /dev/null
+++ b/5-Training with Ignite and Optuna/tuningfunctions.py
@@ -0,0 +1,274 @@
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import lr_scheduler
+from torch.utils.data import TensorDataset, DataLoader
+
+import optuna
+from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
+from ignite.metrics import Accuracy, Loss, Precision, Recall, Fbeta
+from ignite.contrib.metrics.roc_auc import ROC_AUC
+from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
+from ignite.handlers.early_stopping import EarlyStopping
+from ignite.contrib.handlers import TensorboardLogger
+
+# Local module that defines the model constructors, each taking an Optuna trial.
+import models
+
+
+def get_data_loaders(X_train, X_test, y_train, y_test):
+    """Wrap the train/test arrays in TensorDatasets and return DataLoaders."""
+    batch_size = 10
+
+    # Labels get shape (N, 1) to match the model's single-logit output.
+    train_data = TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train).unsqueeze(1))
+    train_loader = DataLoader(train_data, batch_size=batch_size, pin_memory=True, shuffle=True)
+
+    # Evaluation order does not matter, so the test loader is not shuffled.
+    test_data = TensorDataset(torch.FloatTensor(X_test), torch.FloatTensor(y_test).unsqueeze(1))
+    test_loader = DataLoader(test_data, batch_size=batch_size, pin_memory=True, shuffle=False)
+
+    return train_loader, test_loader
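+
+# Usage sketch (hypothetical shapes): X_train is a numpy array such as
+# (N, C, H, W) images and y_train is a (N,) array of 0/1 labels, e.g. the
+# outputs of sklearn's train_test_split.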
+
+
+def get_criterion(y_train):
+    """Return BCEWithLogitsLoss with pos_weight set from the class imbalance in y_train."""
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print(f'Train on: {device}')
+
+    class_counts = np.bincount(y_train).tolist()
+    weights = torch.tensor(np.array(class_counts) / sum(class_counts))
+    print("CLASS 0: {}, CLASS 1: {}".format(weights[0], weights[1]))
+
+    # Normalize by the class 0 fraction: weights becomes [1, count_0 / count_1],
+    # so class 1 is up-weighted whenever it is the rarer class.
+    weights = weights[0] / weights
+    print("WEIGHT 0: {}, WEIGHT 1: {}".format(weights[0], weights[1]))
+
+    # BCEWithLogitsLoss expects one pos_weight entry per output logit.
+    label_weights = torch.stack([weights[1]]).to(device)
+    print("Label Weights: ", label_weights)
+    criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
+    criterion.to(device)
+
+    return criterion
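+
+# Worked example (hypothetical counts): with 900 negatives and 100 positives,
+# the class fractions are [0.9, 0.1], so pos_weight = 0.9 / 0.1 = 9.0 and every
+# positive example counts nine times as much in the loss.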
+
+def thresholded_output_transform(output):
+    """Binarize logits at 0.5 for threshold-based metrics."""
+    y_pred, y = output
+    y_pred = torch.round(torch.sigmoid(y_pred))
+    return y_pred, y
+
+
+def class0_thresholded_output_transform(output):
+    """Flip predictions and labels so the same metrics score class 0."""
+    y_pred, y = output
+    y_pred = torch.round(torch.sigmoid(y_pred))
+    return 1 - y_pred, 1 - y
+
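+
+# ROC AUC ranks raw scores, so it should see sigmoid probabilities rather than
+# rounded 0/1 predictions; this transform is used for the "roc_auc" metric below.
+def sigmoid_output_transform(output):
+    y_pred, y = output
+    return torch.sigmoid(y_pred), y
+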
+class Objective(object):
+    """Optuna objective: trains `model_name` with trial-suggested hyperparameters
+    and returns the chosen validation metric."""
+
+    def __init__(self, model_name, criterion, train_loader, test_loader, optimizers,
+                 lr_lower, lr_upper, metric, max_epochs, early_stopping_patience=None,
+                 lr_scheduler=False, step_size=None, gamma=None):
+        # Hold the implementation-specific arguments as fields of the class.
+        self.model_name = model_name
+        self.train_loader = train_loader
+        self.test_loader = test_loader
+        self.optimizers = optimizers
+        self.criterion = criterion
+        self.metric = metric
+        self.max_epochs = max_epochs
+        self.lr_lower = lr_lower
+        self.lr_upper = lr_upper
+        self.early_stopping_patience = early_stopping_patience
+        self.lr_scheduler = lr_scheduler  # boolean flag; shadows the module name only here
+        self.step_size = step_size
+        self.gamma = gamma
+
+    def __call__(self, trial):
+        # Build a fresh model for this trial; the constructor in `models`
+        # reads its hyperparameters from the trial object.
+        model = getattr(models, self.model_name)(trial)
+
+        device = "cpu"
+        if torch.cuda.is_available():
+            device = "cuda:0"
+            model.cuda(device)
+
+        val_metrics = {
+            "accuracy": Accuracy(output_transform=thresholded_output_transform),
+            "loss": Loss(self.criterion),
+            # ROC AUC gets probabilities, not thresholded predictions.
+            "roc_auc": ROC_AUC(output_transform=sigmoid_output_transform),
+            "precision": Precision(output_transform=thresholded_output_transform),
+            "precision_0": Precision(output_transform=class0_thresholded_output_transform),
+            "recall": Recall(output_transform=thresholded_output_transform),
+            "recall_0": Recall(output_transform=class0_thresholded_output_transform),
+        }
+        val_metrics["f1"] = Fbeta(beta=1.0, average=False, precision=val_metrics["precision"], recall=val_metrics["recall"])
+        val_metrics["f1_0"] = Fbeta(beta=1.0, average=False, precision=val_metrics["precision_0"], recall=val_metrics["recall_0"])
+
+        # Hyperparameters searched by Optuna: the optimizer and its learning rate.
+        optimizer_name = trial.suggest_categorical("optimizer", self.optimizers)
+        learnrate = trial.suggest_float("lr", self.lr_lower, self.lr_upper, log=True)
+        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=learnrate)
+
+        trainer = create_supervised_trainer(model, optimizer, self.criterion, device=device)
+        train_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
+        evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
+
+        # Register a pruning handler to the evaluator.
+        pruning_handler = optuna.integration.PyTorchIgnitePruningHandler(trial, self.metric, trainer)
+        evaluator.add_event_handler(Events.COMPLETED, pruning_handler)
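+        # After each validation pass the handler reports the tracked metric to
+        # Optuna, whose pruner may raise TrialPruned to cut a weak trial short.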
+
+        def score_fn(engine):
+            # Higher is better for every metric except loss, which is negated.
+            score = engine.state.metrics[self.metric]
+            return score if self.metric != 'loss' else -score
+
+        # Early stopping on the validation score.
+        if self.early_stopping_patience is not None:
+            es_handler = EarlyStopping(patience=self.early_stopping_patience,
+                                       score_function=score_fn, trainer=trainer)
+            evaluator.add_event_handler(Events.COMPLETED, es_handler)
+
+        # Checkpointing: keep the best model of this trial on disk, in a
+        # directory named after the trial's hyperparameters (using '=' and '_'
+        # instead of ':' and spaces, which are not portable in paths).
+        to_save = {'model': model}
+        checkpointname = 'checkpoint_' + '_'.join(
+            '{}={}'.format(key, value) for key, value in trial.params.items())
+        checkpoint_handler = Checkpoint(to_save, DiskSaver(checkpointname, create_dir=True),
+                                        filename_prefix='best', score_function=score_fn, score_name="val_metric",
+                                        global_step_transform=global_step_from_engine(trainer))
+
+        evaluator.add_event_handler(Events.COMPLETED, checkpoint_handler)
+
+        # Optionally decay the learning rate on a fixed schedule.
+        if self.lr_scheduler:
+            scheduler = lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)
+            trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: scheduler.step())
+
+        # Print metrics at the end of every epoch.
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def log_training_results(engine):
+            train_evaluator.run(self.train_loader)
+            metrics = train_evaluator.state.metrics
+            print("Training Results - Epoch: {}  Avg accuracy: {:.4f} Avg loss: {:.4f} roc_auc: {:.4f} \n"
+                  .format(engine.state.epoch, metrics["accuracy"], metrics["loss"], metrics["roc_auc"]))
+
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def log_validation_results(engine):
+            evaluator.run(self.test_loader)
+            metrics = evaluator.state.metrics
+            print("Validation Results - Epoch: {}  Avg accuracy: {:.4f} Avg loss: {:.4f} ROC_AUC: {:.4f}"
+                  "\nClass 1 Precision: {:.4f} Class 1 Recall: {:.4f} Class 1 F1: {:.4f}"
+                  "\nClass 0 Precision: {:.4f} Class 0 Recall: {:.4f} Class 0 F1: {:.4f} \n"
+                  .format(engine.state.epoch, metrics["accuracy"], metrics["loss"], metrics["roc_auc"],
+                          metrics["precision"], metrics["recall"], metrics["f1"],
+                          metrics["precision_0"], metrics["recall_0"], metrics["f1_0"]))
+
+        # TensorBoard logs, one run directory per hyperparameter combination.
+        logname = 'tb_' + '_'.join('{}={}'.format(key, value) for key, value in trial.params.items())
+        tb_logger = TensorboardLogger(log_dir=logname)
+
+        # Use a distinct loop variable so `evaluator` is not shadowed; it is used again below.
+        for tag, ev in [("training", train_evaluator), ("validation", evaluator)]:
+            tb_logger.attach_output_handler(
+                ev,
+                event_name=Events.EPOCH_COMPLETED,
+                tag=tag,
+                metric_names="all",
+                global_step_transform=global_step_from_engine(trainer),
+            )
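+        # The logged runs can be compared afterwards with: tensorboard --logdir .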
+
+        # Run the trainer.
+        trainer.run(self.train_loader, max_epochs=self.max_epochs)
+
+        # Reload the checkpoint with the best validation metric seen in this
+        # trial, then score the trial with it.
+        to_load = to_save
+        checkpoint = torch.load(os.path.join(checkpointname, checkpoint_handler.last_checkpoint))
+        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
+
+        evaluator.run(self.test_loader)
+
+        tb_logger.close()
+        return evaluator.state.metrics[self.metric]
+
+
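+# A minimal usage sketch (hypothetical names and values; assumes models.py
+# defines a constructor such as `CNN(trial)` and that the loaders and criterion
+# come from the helpers above):
+#
+#   train_loader, test_loader = get_data_loaders(X_train, X_test, y_train, y_test)
+#   criterion = get_criterion(y_train)
+#   objective = Objective('CNN', criterion, train_loader, test_loader,
+#                         optimizers=['Adam', 'SGD'], lr_lower=1e-5, lr_upper=1e-1,
+#                         metric='roc_auc', max_epochs=20, early_stopping_patience=5)
+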
+def run_trials(objective, pruner, num_trials, direction):
+    """Create an Optuna study, optimize the objective, and report the best trial."""
+    study = optuna.create_study(direction=direction, pruner=pruner)
+    study.optimize(objective, n_trials=num_trials, gc_after_trial=True)
+
+    print("Number of finished trials: ", len(study.trials))
+
+    print("Best trial:")
+    trial = study.best_trial
+
+    print("  Value: ", trial.value)
+
+    print("  Params: ")
+    for key, value in trial.params.items():
+        print("    {}: {}".format(key, value))
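+
+# Example (hypothetical settings): maximize the validation metric over 50
+# trials, pruning weak trials with Optuna's MedianPruner:
+#
+#   run_trials(objective, optuna.pruners.MedianPruner(), num_trials=50, direction='maximize')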