|
a |
|
b/5-Training with Ignite and Optuna/tuningfunctions.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
from torch.utils.data import TensorDataset |
|
|
4 |
import torch.optim as optim |
|
|
5 |
from torch.optim import lr_scheduler |
|
|
6 |
import numpy as np |
|
|
7 |
import torchvision |
|
|
8 |
import torch.nn.functional as F |
|
|
9 |
from torch.utils.data.sampler import SubsetRandomSampler |
|
|
10 |
from torch.utils.data import DataLoader |
|
|
11 |
from torchvision import datasets, models, transforms |
|
|
12 |
from torchvision.transforms import Resize, ToTensor, Normalize |
|
|
13 |
import matplotlib.pyplot as plt |
|
|
14 |
# from imblearn.under_sampling import RandomUnderSampler |
|
|
15 |
import cv2 |
|
|
16 |
from scipy import stats |
|
|
17 |
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, roc_auc_score, \ |
|
|
18 |
average_precision_score |
|
|
19 |
from sklearn.model_selection import train_test_split |
|
|
20 |
import time |
|
|
21 |
import os |
|
|
22 |
from pathlib import Path |
|
|
23 |
from skimage import io |
|
|
24 |
import copy |
|
|
25 |
from torch import optim, cuda |
|
|
26 |
import pandas as pd |
|
|
27 |
import glob |
|
|
28 |
from collections import Counter |
|
|
29 |
# Useful for examining network |
|
|
30 |
from functools import reduce |
|
|
31 |
from operator import __add__ |
|
|
32 |
# from torchsummary import summary |
|
|
33 |
import seaborn as sns |
|
|
34 |
import warnings |
|
|
35 |
# warnings.filterwarnings('ignore', category=FutureWarning) |
|
|
36 |
from PIL import Image |
|
|
37 |
from timeit import default_timer as timer |
|
|
38 |
import matplotlib.pyplot as plt |
|
|
39 |
|
|
|
40 |
# Useful for examining network |
|
|
41 |
from functools import reduce |
|
|
42 |
from operator import __add__ |
|
|
43 |
from torchsummary import summary |
|
|
44 |
|
|
|
45 |
# from IPython.core.interactiveshell import InteractiveShell |
|
|
46 |
import seaborn as sns |
|
|
47 |
|
|
|
48 |
import warnings |
|
|
49 |
# warnings.filterwarnings('ignore', category=FutureWarning) |
|
|
50 |
|
|
|
51 |
# Image manipulations |
|
|
52 |
from PIL import Image |
|
|
53 |
|
|
|
54 |
# Timing utility |
|
|
55 |
from timeit import default_timer as timer |
|
|
56 |
|
|
|
57 |
# Visualizations |
|
|
58 |
import matplotlib.pyplot as plt |
|
|
59 |
|
|
|
60 |
|
|
|
61 |
|
|
|
62 |
|
|
|
63 |
import optuna |
|
|
64 |
from ignite.engine import Engine |
|
|
65 |
from ignite.engine import create_supervised_evaluator |
|
|
66 |
from ignite.engine import create_supervised_trainer |
|
|
67 |
from ignite.engine import Events |
|
|
68 |
from ignite.metrics import Accuracy, Loss, Precision, Recall, Fbeta |
|
|
69 |
from ignite.contrib.metrics.roc_auc import ROC_AUC |
|
|
70 |
from ignite.handlers import ModelCheckpoint, global_step_from_engine, Checkpoint, DiskSaver |
|
|
71 |
from ignite.handlers.early_stopping import EarlyStopping |
|
|
72 |
from ignite.contrib.handlers import TensorboardLogger |
|
|
73 |
|
|
|
74 |
import models |
|
|
75 |
|
|
|
76 |
|
|
|
77 |
|
|
|
78 |
def get_data_loaders(X_train, X_test, y_train, y_test, batch_size=10):
    """Wrap train/test features and labels in shuffled ``DataLoader`` objects.

    Features and labels are converted with ``torch.FloatTensor``; labels
    additionally get a new axis at position 1 via ``unsqueeze(1)``
    (presumably so they line up with a single-logit model output — confirm
    against the model definitions).

    Args:
        X_train: Training features, convertible by ``torch.FloatTensor``.
        X_test: Test/validation features, convertible by ``torch.FloatTensor``.
        y_train: Training labels, convertible by ``torch.FloatTensor``.
        y_test: Test/validation labels, convertible by ``torch.FloatTensor``.
        batch_size: Number of samples per batch for both loaders. Defaults
            to 10, the value that used to be hard-coded.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """

    def _make_loader(features, labels):
        # One code path for both splits so they cannot drift apart.
        targets = torch.FloatTensor(labels).unsqueeze(1)
        dataset = TensorDataset(torch.FloatTensor(features), targets)
        return DataLoader(dataset, batch_size=batch_size, pin_memory=True, shuffle=True)

    test_loader = _make_loader(X_test, y_test)
    train_loader = _make_loader(X_train, y_train)

    return train_loader, test_loader
|
|
93 |
|
|
|
94 |
|
|
|
95 |
def get_criterion(y_train):
    """Build a ``BCEWithLogitsLoss`` whose positive class is re-weighted.

    Class frequencies are taken from ``np.bincount(y_train)``. The positive
    weight is ``freq(class 0) / freq(class 1)``, so the rarer the positive
    class, the more each positive sample contributes to the loss.

    Args:
        y_train: Non-negative integer class labels accepted by
            ``np.bincount``; both class 0 and class 1 must be representable
            (i.e. the largest label must be at least 1).

    Returns:
        An ``nn.BCEWithLogitsLoss`` with a one-element ``pos_weight``, moved
        to ``cuda:0`` when CUDA is available and to the CPU otherwise.

    Raises:
        ValueError: If ``y_train`` contains no label >= 1, in which case no
            weight for the positive class can be computed.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Train on: {device}')

    class_counts = np.bincount(y_train).tolist()
    if len(class_counts) < 2:
        # Without this guard the weights[1] lookup below fails with an
        # IndexError that says nothing about the real problem.
        raise ValueError(
            "y_train must contain samples of class 1 to compute a positive class weight"
        )

    # Relative class frequencies.
    weights = torch.tensor(np.array(class_counts) / sum(class_counts))
    print("CLASS 0: {}, CLASS 1: {}".format(weights[0], weights[1]))

    # Normalise so class 0 has weight 1 and class 1 gets freq0 / freq1.
    weights = weights[0] / weights
    print("WEIGHT 0: {}, WEIGHT 1: {}".format(weights[0], weights[1]))

    label_weights = [weights[1]]
    print("Label Weights: ", label_weights)

    # stack() turns the single 0-d weight into a 1-element tensor.
    pos_weight = torch.stack(label_weights).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    criterion.to(device)

    return criterion
|
|
119 |
|
|
|
120 |
def thresholded_output_transform(output):
    """Turn the logits of a ``(y_pred, y)`` pair into hard 0/1 predictions.

    The logits go through a sigmoid and are then rounded; the targets are
    returned untouched.

    Args:
        output: A ``(y_pred, y)`` pair where ``y_pred`` holds raw logits.

    Returns:
        A ``(rounded_predictions, y)`` tuple.
    """
    logits, targets = output
    probabilities = torch.sigmoid(logits)
    return torch.round(probabilities), targets
|
|
124 |
def class0_thresholded_output_transform(output):
    """Threshold logits and flip both predictions and targets.

    After the sigmoid + rounding step, predictions and targets are each
    replaced by ``1 - value``, so class 0 plays the role of the positive
    class for any metric fed with the result.

    Args:
        output: A ``(y_pred, y)`` pair where ``y_pred`` holds raw logits.

    Returns:
        A ``(flipped_predictions, flipped_targets)`` tuple.
    """
    logits, targets = output
    hard_predictions = torch.round(torch.sigmoid(logits))
    flipped_targets = 1 - targets
    flipped_predictions = 1 - hard_predictions
    return flipped_predictions, flipped_targets
|
|
130 |
|
|
|
131 |
|
|
|
132 |
class Objective(object):
    """Callable Optuna objective that trains one model per trial with Ignite.

    Each call builds a model through ``getattr(models, model_name)(trial)``,
    samples an optimizer name and learning rate from the trial, trains for up
    to ``max_epochs`` epochs, evaluates on the train and test loaders after
    every epoch, and returns the value of ``metric`` measured on the test
    loader after reloading the last checkpoint written during the trial.
    """

    def __init__(self, model_name, criterion, train_loader, test_loader, optimizers,
                 lr_lower, lr_upper, metric, max_epochs, early_stopping_patience=None,
                 lr_scheduler=False, step_size=None, gamma=None):
        """Store the trial-independent configuration.

        Args:
            model_name: Name of a callable in the ``models`` module; it is
                called with the Optuna trial and must return the model.
            criterion: Loss used for training and for the ``"loss"`` metric.
            train_loader: Data loader used for training and train-set metrics.
            test_loader: Data loader used for validation metrics.
            optimizers: Candidate optimizer names (attributes of
                ``torch.optim``) offered to ``trial.suggest_categorical``.
            lr_lower: Lower bound of the log-uniform learning-rate search.
            lr_upper: Upper bound of the log-uniform learning-rate search.
            metric: Key into the evaluator metrics (e.g. ``"loss"``,
                ``"accuracy"``, ``"roc_auc"``, ``"f1"``) used for pruning,
                early stopping, checkpoint scoring and the returned value.
            max_epochs: Maximum number of training epochs per trial.
            early_stopping_patience: Patience passed to ``EarlyStopping``;
                ``None`` disables early stopping.
            lr_scheduler: When exactly ``True``, a ``StepLR`` scheduler is
                stepped after every epoch.
            step_size: ``StepLR`` step size (only used with ``lr_scheduler``).
            gamma: ``StepLR`` decay factor (only used with ``lr_scheduler``).
        """
        self.model_name = model_name
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.optimizers = optimizers
        self.criterion = criterion
        self.metric = metric
        self.max_epochs = max_epochs
        self.lr_lower = lr_lower
        self.lr_upper = lr_upper
        self.early_stopping_patience = early_stopping_patience
        self.lr_scheduler = lr_scheduler
        self.step_size = step_size
        self.gamma = gamma

    @staticmethod
    def _probability_output_transform(output):
        """Map the logits of a ``(y_pred, y)`` pair to sigmoid probabilities.

        Used for ROC AUC, which needs continuous scores: rounding the
        predictions to 0/1 first would throw away the ranking information
        the metric is computed from.
        """
        y_pred, y = output
        return torch.sigmoid(y_pred), y

    def __call__(self, trial):
        """Run one trial and return the value of ``self.metric`` on the test set.

        Args:
            trial: The Optuna trial used to sample hyperparameters and to
                report intermediate values for pruning.

        Returns:
            ``evaluator.state.metrics[self.metric]`` from a final evaluation
            on ``self.test_loader`` after the last saved checkpoint has been
            loaded back into the model.
        """
        model = getattr(models, self.model_name)(trial)

        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda:0"
            model.cuda(device)

        val_metrics = {
            "accuracy": Accuracy(output_transform=thresholded_output_transform),
            "loss": Loss(self.criterion),
            # ROC AUC is computed from probabilities, not thresholded labels.
            "roc_auc": ROC_AUC(output_transform=self._probability_output_transform),
            "precision": Precision(output_transform=thresholded_output_transform),
            "precision_0": Precision(output_transform=class0_thresholded_output_transform),
            "recall": Recall(output_transform=thresholded_output_transform),
            "recall_0": Recall(output_transform=class0_thresholded_output_transform),
        }
        val_metrics["f1"] = Fbeta(beta=1.0, average=False,
                                  precision=val_metrics['precision'], recall=val_metrics['recall'])
        val_metrics["f1_0"] = Fbeta(beta=1.0, average=False,
                                    precision=val_metrics['precision_0'], recall=val_metrics['recall_0'])

        optimizer_name = trial.suggest_categorical("optimizer", self.optimizers)
        learnrate = trial.suggest_loguniform("lr", self.lr_lower, self.lr_upper)
        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=learnrate)

        trainer = create_supervised_trainer(model, optimizer, self.criterion, device=device)
        train_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
        evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

        # Register a pruning handler to the evaluator.
        pruning_handler = optuna.integration.PyTorchIgnitePruningHandler(trial, self.metric, trainer)
        evaluator.add_event_handler(Events.COMPLETED, pruning_handler)

        def score_fn(engine):
            # Handlers below treat higher scores as better, so loss is negated.
            score = engine.state.metrics[self.metric]
            return score if self.metric != 'loss' else -score

        # Early stopping
        if self.early_stopping_patience is not None:
            es_handler = EarlyStopping(patience=self.early_stopping_patience,
                                       score_function=score_fn, trainer=trainer)
            evaluator.add_event_handler(Events.COMPLETED, es_handler)

        # Checkpointing: the directory name encodes the parameters sampled so far.
        to_save = {'model': model}

        checkpointname = 'checkpoint' + ''.join(
            key + ': ' + str(value) + ', ' for key, value in trial.params.items())
        checkpoint_handler = Checkpoint(to_save, DiskSaver(checkpointname, create_dir=True),
                                        filename_prefix='best', score_function=score_fn,
                                        score_name="val_metric",
                                        global_step_transform=global_step_from_engine(trainer))

        evaluator.add_event_handler(Events.COMPLETED, checkpoint_handler)

        # Add lr scheduler
        if self.lr_scheduler is True:
            scheduler = lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: scheduler.step())

        # Print metrics on each completed epoch
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_training_results(engine):
            train_evaluator.run(self.train_loader)
            metrics = train_evaluator.state.metrics
            print("Training Results - Epoch: {} Avg accuracy: {:.4f} Avg loss: {:.4f} roc_auc: {:.4f} \n"
                  .format(engine.state.epoch, metrics["accuracy"], metrics["loss"], metrics['roc_auc']))

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(self.test_loader)
            metrics = evaluator.state.metrics
            print("Validation Results - Epoch: {} Avg accuracy: {:.4f} Avg loss: {:.4f} ROC_AUC: {:.4f}"
                  "\nClass 1 Precision: {:.4f} Class 1 Recall: {:.4f} Class 1 F1: {:.4f}"
                  "\nClass 0 Precision: {:.4f} Class 0 Recall: {:.4f} Class 0 F1: {:.4f} \n"
                  .format(engine.state.epoch, metrics["accuracy"], metrics["loss"], metrics['roc_auc'],
                          metrics['precision'], metrics['recall'], metrics['f1'],
                          metrics['precision_0'], metrics['recall_0'], metrics["f1_0"]))

        # Tensorboard logs: the log directory name encodes the sampled parameters.
        logname = ''.join(key + ': ' + str(value) + ',' for key, value in trial.params.items())
        tb_logger = TensorboardLogger(log_dir=logname)

        try:
            # The loop variable is deliberately not called ``evaluator``: that
            # name is used by the validation closure and the final evaluation.
            for tag, logged_engine in [("training", train_evaluator), ("validation", evaluator)]:
                tb_logger.attach_output_handler(
                    logged_engine,
                    event_name=Events.EPOCH_COMPLETED,
                    tag=tag,
                    metric_names="all",
                    global_step_transform=global_step_from_engine(trainer),
                )

            # Run the trainer
            trainer.run(self.train_loader, max_epochs=self.max_epochs)

            # Load the last checkpoint written by the checkpoint handler.
            # NOTE(review): depending on the installed ignite version,
            # last_checkpoint may already include the directory — confirm.
            to_load = to_save
            checkpoint = torch.load(checkpointname + '/' + checkpoint_handler.last_checkpoint)
            Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

            # NOTE: the pruning, early-stopping and checkpoint handlers are
            # attached to this evaluator's COMPLETED event, so they fire once
            # more on this final run.
            evaluator.run(self.test_loader)
        finally:
            # Close the logger even when the trial is pruned or training fails.
            tb_logger.close()

        return evaluator.state.metrics[self.metric]
|
|
258 |
|
|
|
259 |
|
|
|
260 |
def run_trials(objective, pruner, num_trials, direction):
    """Create an Optuna study, optimize ``objective`` and print the best trial.

    Args:
        objective: Callable taking an Optuna trial and returning the value
            to optimize.
        pruner: Optuna pruner passed to ``optuna.create_study``.
        num_trials: Number of trials passed to ``study.optimize``.
        direction: Optimization direction passed to ``optuna.create_study``.

    Returns:
        The study, so callers can inspect trials and best parameters beyond
        the summary printed here.
    """
    study = optuna.create_study(direction=direction, pruner=pruner)
    # gc_after_trial=True asks Optuna to run garbage collection between trials.
    study.optimize(objective, n_trials=num_trials, gc_after_trial=True)

    print("Number of finished trials: ", len(study.trials))

    print("Best trial:")
    trial = study.best_trial

    print(" Value: ", trial.value)

    print(" Params: ")
    for key, value in trial.params.items():
        print(" {}: {}".format(key, value))

    return study