Diff of /utils.py [000000] .. [1928b6]

Switch to side-by-side view

--- a
+++ b/utils.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision.models as models
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+import numpy as np
+
+
+def model_prediction(chpnt):
+    model = my_model()
+    ch = torch.load(chpnt)
+    model.load_state_dict(ch['state_dict'])
+    return model
+
+def get_model():
+    model = my_model()    
+    model = model.cuda()
+    return model
+
+# Function that returns the model
+class my_model(nn.Module):
+    def __init__(self):
+        super(my_model, self).__init__()
+        # CNN
+        model = models.resnet34(weights='DEFAULT')
+        conv1 = model._modules['conv1'].weight.detach().clone().mean(dim=1, keepdim=True)
+        model._modules['conv1'] = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        model._modules['conv1'].weight.data = conv1
+        model.fc = nn.Linear(model.fc.in_features, 2)
+        self.features = nn.Sequential(*list(model.children())[0:-1])
+        self.fc = list(model.children())[-1]
+        self.flat = True
+
+    def forward(self, x):
+        f = self.features(x)
+        if self.flat:
+            f = f.view(x.size(0), -1)
+        o = self.fc(f)
+        if not self.flat:
+            o = o.view(x.size(0), -1)
+        return o
+    
+    def forward_feat(self, x):
+        f = self.features(x)
+        if self.flat:
+            f = f.view(x.size(0), -1)
+        o = self.fc(f)
+        if not self.flat:
+            o = o.view(x.size(0), -1)
+        return f,o
+
+# Function that creates the optimizer
+def create_optimizer(model, mode, lr, momentum, wd):
+    if mode == 'sgd':
+        optimizer = optim.SGD(model.parameters(), lr,
+                              momentum=momentum, dampening=0,
+                              weight_decay=wd, nesterov=True)
+    elif mode == 'adam':
+        optimizer = optim.Adam(model.parameters(), lr=lr,
+                               weight_decay=wd)
+    return optimizer
+
+# Function to anneal learning rate
+def adjust_learning_rate(optimizer, epoch, period, start_lr):
+    """Sets the learning rate to the initial LR decayed by 10 every period epochs"""
+    lr = start_lr * (0.1 ** (epoch // period))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+class EarlyStopping():
+    """
+    Early stopping to stop the training when the loss does not improve after
+    certain epochs.
+    """
+    def __init__(self, patience=5, min_delta=0.001):
+        """
+        :param patience: how many epochs to wait before stopping when loss is
+               not improving
+        :param min_delta: minimum difference between new loss and old loss for
+               new loss to be considered as an improvement
+        """
+        self.patience = patience
+        self.min_delta = min_delta
+        self.counter = 0
+        self.best_loss = None
+        self.early_stop = False
+
+    def __call__(self, val_loss):
+        if self.best_loss == None:
+            self.best_loss = val_loss
+        elif self.best_loss - val_loss > self.min_delta:
+            self.best_loss = val_loss
+        elif self.best_loss - val_loss < self.min_delta:
+            self.counter += 1
+            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
+            if self.counter >= self.patience:
+                print('INFO: Early stopping')
+                self.early_stop = True
+
+class LRScheduler():
+    """
+    Learning rate scheduler. If the validation loss does not decrease for the 
+    given number of `patience` epochs, then the learning rate will decrease by
+    by given `factor`.
+    """
+    def __init__(self, optimizer, patience=1, min_lr=1e-5, factor=0.1, cooldown=1, threshold=0.001):
+        """
+        new_lr = old_lr * factor
+        :param optimizer: the optimizer we are using
+        :param patience: how many epochs to wait before updating the lr
+        :param min_lr: least lr value to reduce to while updating
+        :param factor: factor by which the lr should be updated
+        """
+        self.optimizer = optimizer
+        self.patience = patience
+        self.min_lr = min_lr
+        self.factor = factor
+        self.cooldown = cooldown
+        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 
+                self.optimizer,
+                mode='min',
+                patience=self.patience,
+                factor=self.factor,
+                min_lr=self.min_lr,
+                cooldown=self.cooldown,
+                verbose=True
+            )
+    def __call__(self, val_loss):
+        self.lr_scheduler.step(val_loss)
\ No newline at end of file