# 5-Training with Ignite and Optuna/tuningfunctions.py

import copy
import glob
import os
import time
import warnings
from collections import Counter
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from scipy import stats
from skimage import io
from sklearn.metrics import (accuracy_score, average_precision_score,
                             confusion_matrix, precision_recall_fscore_support,
                             roc_auc_score)
from sklearn.model_selection import train_test_split
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
from torchvision.transforms import Normalize, Resize, ToTensor
# from imblearn.under_sampling import RandomUnderSampler

# Useful for examining the network
from functools import reduce
from operator import __add__
from torchsummary import summary

# Image manipulations
from PIL import Image

# Timing utility
from timeit import default_timer as timer

# warnings.filterwarnings('ignore', category=FutureWarning)
# from IPython.core.interactiveshell import InteractiveShell

import optuna
from ignite.contrib.handlers import TensorboardLogger
from ignite.contrib.metrics.roc_auc import ROC_AUC
from ignite.engine import (Engine, Events, create_supervised_evaluator,
                           create_supervised_trainer)
from ignite.handlers import (Checkpoint, DiskSaver, ModelCheckpoint,
                             global_step_from_engine)
from ignite.handlers.early_stopping import EarlyStopping
from ignite.metrics import Accuracy, Fbeta, Loss, Precision, Recall

# Local module defining the candidate architectures; imported after torchvision
# so that `models` refers to this module, not torchvision.models.
import models


def get_data_loaders(X_train, X_test, y_train, y_test):
    """Wrap the train/test arrays in TensorDatasets and return DataLoaders."""
    batch_size = 10

    # Targets become float column vectors of shape (N, 1) to match the
    # single-logit output expected by BCEWithLogitsLoss.
    y_test = torch.FloatTensor(y_test).unsqueeze(1)
    test_dataset = TensorDataset(torch.FloatTensor(X_test), y_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)

    y_train = torch.FloatTensor(y_train).unsqueeze(1)
    train_dataset = TensorDataset(torch.FloatTensor(X_train), y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)

    return train_loader, test_loader
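
# A minimal usage sketch (hypothetical names; any NumPy feature matrix X and
# binary label vector y split with sklearn's train_test_split would do):
#
#   X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
#   train_loader, test_loader = get_data_loaders(X_train, X_test, y_train, y_test)
#   xb, yb = next(iter(train_loader))   # xb: (10, ...), yb: (10, 1)
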

def get_criterion(y_train):
    """Build a class-weighted BCEWithLogitsLoss for an imbalanced binary task."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Train on: {device}')

    # Class frequencies, then relative weights normalised so class 0 gets 1.0.
    class_counts = np.bincount(y_train).tolist()
    weights = torch.tensor(np.array(class_counts) / sum(class_counts))
    print("CLASS 0: {}, CLASS 1: {}".format(weights[0], weights[1]))
    weights = weights[0] / weights
    print("WEIGHT 0: {}, WEIGHT 1: {}".format(weights[0], weights[1]))

    # pos_weight rescales only the positive (class 1) term of the loss.
    label_weights = torch.stack([weights[1]]).to(device)
    print("Label Weights: ", label_weights)
    criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
    criterion.to(device)

    return criterion
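
# Worked example: with 900 negatives and 100 positives the frequencies are
# (0.9, 0.1), so pos_weight = 0.9 / 0.1 = 9 and every positive sample
# contributes nine times as much to the loss, offsetting the imbalance.
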

def thresholded_output_transform(output):
    """Turn raw logits into hard 0/1 predictions before a metric is computed."""
    y_pred, y = output
    y_pred = torch.round(torch.sigmoid(y_pred))
    return y_pred, y


def class0_thresholded_output_transform(output):
    """As above, but with predictions and targets flipped to score class 0."""
    y_pred, y = output
    y_pred = torch.round(torch.sigmoid(y_pred))
    y = 1 - y
    y_pred = 1 - y_pred
    return y_pred, y
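
# E.g. logits tensor([0.3, -1.2]) -> sigmoid -> tensor([0.5744, 0.2315])
# -> round -> tensor([1., 0.]); the class-0 variant then flips both
# predictions and targets so Precision/Recall describe the negative class.
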

class Objective(object):
    def __init__(self, model_name, criterion, train_loader, test_loader, optimizers,
                 lr_lower, lr_upper, metric, max_epochs, early_stopping_patience=None,
                 lr_scheduler=False, step_size=None, gamma=None):
        # Store the implementation-specific arguments as fields of the class.
        self.model_name = model_name
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.optimizers = optimizers
        self.criterion = criterion
        self.metric = metric
        self.max_epochs = max_epochs
        self.lr_lower = lr_lower
        self.lr_upper = lr_upper
        self.early_stopping_patience = early_stopping_patience
        self.lr_scheduler = lr_scheduler
        self.step_size = step_size
        self.gamma = gamma

    def __call__(self, trial):
        # Calculate an objective value by using the extra arguments.
        # The model class, looked up by name in the local models module,
        # draws its own architecture hyperparameters from the trial.
        model = getattr(models, self.model_name)(trial)

        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda:0"
            model.cuda(device)

        val_metrics = {
            "accuracy": Accuracy(output_transform=thresholded_output_transform),
            "loss": Loss(self.criterion),
            "roc_auc": ROC_AUC(output_transform=thresholded_output_transform),
            "precision": Precision(output_transform=thresholded_output_transform),
            "precision_0": Precision(output_transform=class0_thresholded_output_transform),
            "recall": Recall(output_transform=thresholded_output_transform),
            "recall_0": Recall(output_transform=class0_thresholded_output_transform),
        }
        val_metrics["f1"] = Fbeta(beta=1.0, average=False,
                                  precision=val_metrics['precision'],
                                  recall=val_metrics['recall'])
        val_metrics["f1_0"] = Fbeta(beta=1.0, average=False,
                                    precision=val_metrics['precision_0'],
                                    recall=val_metrics['recall_0'])

        # Sample the optimizer and learning rate for this trial.
        optimizer_name = trial.suggest_categorical("optimizer", self.optimizers)
        learnrate = trial.suggest_loguniform("lr", self.lr_lower, self.lr_upper)
        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=learnrate)

        trainer = create_supervised_trainer(model, optimizer, self.criterion, device=device)
        train_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
        evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

        # Register a pruning handler to the evaluator.
        pruning_handler = optuna.integration.PyTorchIgnitePruningHandler(trial, self.metric, trainer)
        evaluator.add_event_handler(Events.COMPLETED, pruning_handler)

        def score_fn(engine):
            # Higher is better for every metric except loss, which is negated.
            score = engine.state.metrics[self.metric]
            return score if self.metric != 'loss' else -score

        # Early stopping
        if self.early_stopping_patience is not None:
            es_handler = EarlyStopping(patience=self.early_stopping_patience,
                                       score_function=score_fn, trainer=trainer)
            evaluator.add_event_handler(Events.COMPLETED, es_handler)

        # Checkpointing: one directory per trial, named after its parameters
        # ('=' and ',' instead of ': ' keeps the directory name portable).
        to_save = {'model': model}

        checkpointname = 'checkpoint'
        for key, value in trial.params.items():
            checkpointname += '{}={},'.format(key, value)
        checkpoint_handler = Checkpoint(to_save, DiskSaver(checkpointname, create_dir=True),
                                        filename_prefix='best', score_function=score_fn,
                                        score_name="val_metric",
                                        global_step_transform=global_step_from_engine(trainer))

        evaluator.add_event_handler(Events.COMPLETED, checkpoint_handler)

        # Add lr scheduler
        if self.lr_scheduler is True:
            scheduler = lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: scheduler.step())

        # Print metrics after each completed epoch.
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_training_results(engine):
            train_evaluator.run(self.train_loader)
            metrics = train_evaluator.state.metrics
            print("Training Results - Epoch: {}  Avg accuracy: {:.4f} Avg loss: {:.4f} roc_auc: {:.4f} \n"
                  .format(engine.state.epoch, metrics["accuracy"], metrics["loss"], metrics['roc_auc']))

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(self.test_loader)
            metrics = evaluator.state.metrics
            print("Validation Results - Epoch: {}  Avg accuracy: {:.4f} Avg loss: {:.4f} ROC_AUC: {:.4f}"
                  "\nClass 1 Precision: {:.4f} Class 1 Recall: {:.4f} Class 1 F1: {:.4f}"
                  "\nClass 0 Precision: {:.4f} Class 0 Recall: {:.4f} Class 0 F1: {:.4f} \n"
                  .format(engine.state.epoch, metrics["accuracy"], metrics["loss"], metrics['roc_auc'],
                          metrics['precision'], metrics['recall'], metrics['f1'],
                          metrics['precision_0'], metrics['recall_0'], metrics["f1_0"]))

        # TensorBoard logs: one log directory per trial, named after its parameters.
        logname = ''
        for key, value in trial.params.items():
            logname += '{}={},'.format(key, value)
        tb_logger = TensorboardLogger(log_dir=logname)

        # Use a distinct loop variable so the validation `evaluator` above is
        # not shadowed; it is still needed for the final evaluation below.
        for tag, ev in [("training", train_evaluator), ("validation", evaluator)]:
            tb_logger.attach_output_handler(
                ev,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names="all",
                global_step_transform=global_step_from_engine(trainer),
            )

        # Run the trainer.
        trainer.run(self.train_loader, max_epochs=self.max_epochs)

        # Load the checkpoint with the best validation metric in the trial.
        to_load = to_save
        checkpoint = torch.load(os.path.join(checkpointname, checkpoint_handler.last_checkpoint))
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

        # Report the best model's metric on the test set as the trial's value.
        evaluator.run(self.test_loader)

        tb_logger.close()
        return evaluator.state.metrics[self.metric]


def run_trials(objective, pruner, num_trials, direction):
    study = optuna.create_study(direction=direction, pruner=pruner)
    study.optimize(objective, n_trials=num_trials, gc_after_trial=True)

    print("Number of finished trials: ", len(study.trials))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))