--- a +++ b/utils.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision.models as models +import torchvision.transforms as transforms +import torch.nn.functional as F +import numpy as np + + +def model_prediction(chpnt): + model = my_model() + ch = torch.load(chpnt) + model.load_state_dict(ch['state_dict']) + return model + +def get_model(): + model = my_model() + model = model.cuda() + return model + +# Function that returns the model +class my_model(nn.Module): + def __init__(self): + super(my_model, self).__init__() + # CNN + model = models.resnet34(weights='DEFAULT') + conv1 = model._modules['conv1'].weight.detach().clone().mean(dim=1, keepdim=True) + model._modules['conv1'] = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) + model._modules['conv1'].weight.data = conv1 + model.fc = nn.Linear(model.fc.in_features, 2) + self.features = nn.Sequential(*list(model.children())[0:-1]) + self.fc = list(model.children())[-1] + self.flat = True + + def forward(self, x): + f = self.features(x) + if self.flat: + f = f.view(x.size(0), -1) + o = self.fc(f) + if not self.flat: + o = o.view(x.size(0), -1) + return o + + def forward_feat(self, x): + f = self.features(x) + if self.flat: + f = f.view(x.size(0), -1) + o = self.fc(f) + if not self.flat: + o = o.view(x.size(0), -1) + return f,o + +# Function that creates the optimizer +def create_optimizer(model, mode, lr, momentum, wd): + if mode == 'sgd': + optimizer = optim.SGD(model.parameters(), lr, + momentum=momentum, dampening=0, + weight_decay=wd, nesterov=True) + elif mode == 'adam': + optimizer = optim.Adam(model.parameters(), lr=lr, + weight_decay=wd) + return optimizer + +# Function to anneal learning rate +def adjust_learning_rate(optimizer, epoch, period, start_lr): + """Sets the learning rate to the initial LR decayed by 10 every period epochs""" + lr = start_lr * (0.1 ** (epoch // period)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +class EarlyStopping(): + """ + Early stopping to stop the training when the loss does not improve after + certain epochs. + """ + def __init__(self, patience=5, min_delta=0.001): + """ + :param patience: how many epochs to wait before stopping when loss is + not improving + :param min_delta: minimum difference between new loss and old loss for + new loss to be considered as an improvement + """ + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.best_loss = None + self.early_stop = False + + def __call__(self, val_loss): + if self.best_loss == None: + self.best_loss = val_loss + elif self.best_loss - val_loss > self.min_delta: + self.best_loss = val_loss + elif self.best_loss - val_loss < self.min_delta: + self.counter += 1 + print(f"INFO: Early stopping counter {self.counter} of {self.patience}") + if self.counter >= self.patience: + print('INFO: Early stopping') + self.early_stop = True + +class LRScheduler(): + """ + Learning rate scheduler. If the validation loss does not decrease for the + given number of `patience` epochs, then the learning rate will decrease by + by given `factor`. + """ + def __init__(self, optimizer, patience=1, min_lr=1e-5, factor=0.1, cooldown=1, threshold=0.001): + """ + new_lr = old_lr * factor + :param optimizer: the optimizer we are using + :param patience: how many epochs to wait before updating the lr + :param min_lr: least lr value to reduce to while updating + :param factor: factor by which the lr should be updated + """ + self.optimizer = optimizer + self.patience = patience + self.min_lr = min_lr + self.factor = factor + self.cooldown = cooldown + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, + mode='min', + patience=self.patience, + factor=self.factor, + min_lr=self.min_lr, + cooldown=self.cooldown, + verbose=True + ) + def __call__(self, val_loss): + self.lr_scheduler.step(val_loss) \ No newline at end of file