--- a +++ b/semseg_train/callbacks.py @@ -0,0 +1,65 @@ +import numpy as np +import torch + +class EarlyStopping: + """Early stops the training if validation loss doesn't improve after a given patience.""" + def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): + """ + Args: + patience (int): How long to wait after last time validation loss improved. + Default: 7 + verbose (bool): If True, prints a message for each validation loss improvement. + Default: False + delta (float): Minimum change in the monitored quantity to qualify as an improvement. + Default: 0 + path (str): Path for the checkpoint to be saved to. + Default: 'checkpoint.pt' + trace_func (function): trace print function. + Default: print + """ + self.patience = patience + self.verbose = verbose + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.Inf + self.delta = delta + self.path = path + self.trace_func = trace_func + def __call__(self, val_loss#): + , model, ep, step, model_path): + + score = -val_loss + + if self.best_score is None: + self.best_score = score + #self.save_checkpoint(val_loss, model) + #self.save_checkpoint_all(val_loss, model, ep, step, model_path) + elif score < self.best_score + self.delta: + self.counter += 1 + self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + #self.save_checkpoint(val_loss, model) + self.save_checkpoint_all(val_loss, model, ep, step, model_path) + self.counter = 0 + + def save_checkpoint(self, val_loss, model): + '''Saves model when validation loss decrease.''' + if self.verbose: + self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') + torch.save(model.state_dict(), self.path) + self.val_loss_min = val_loss + + def save_checkpoint_all(self, val_loss, model, ep, step, model_path): + '''Saves model when validation loss decrease.''' + if self.verbose: + self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') + torch.save({ + 'model': model.state_dict(), + 'epoch': ep, + 'step': step, + }, str(model_path)) + self.val_loss_min = val_loss \ No newline at end of file