Diff of /utils.py [000000] .. [1928b6]

Switch to unified view

a b/utils.py
1
import torch
2
import torch.nn as nn
3
import torch.optim as optim
4
import torchvision.models as models
5
import torchvision.transforms as transforms
6
import torch.nn.functional as F
7
import numpy as np
8
9
10
def model_prediction(chpnt):
11
    model = my_model()
12
    ch = torch.load(chpnt)
13
    model.load_state_dict(ch['state_dict'])
14
    return model
15
16
def get_model():
17
    model = my_model()    
18
    model = model.cuda()
19
    return model
20
21
# Function that returns the model
22
class my_model(nn.Module):
23
    def __init__(self):
24
        super(my_model, self).__init__()
25
        # CNN
26
        model = models.resnet34(weights='DEFAULT')
27
        conv1 = model._modules['conv1'].weight.detach().clone().mean(dim=1, keepdim=True)
28
        model._modules['conv1'] = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
29
        model._modules['conv1'].weight.data = conv1
30
        model.fc = nn.Linear(model.fc.in_features, 2)
31
        self.features = nn.Sequential(*list(model.children())[0:-1])
32
        self.fc = list(model.children())[-1]
33
        self.flat = True
34
35
    def forward(self, x):
36
        f = self.features(x)
37
        if self.flat:
38
            f = f.view(x.size(0), -1)
39
        o = self.fc(f)
40
        if not self.flat:
41
            o = o.view(x.size(0), -1)
42
        return o
43
    
44
    def forward_feat(self, x):
45
        f = self.features(x)
46
        if self.flat:
47
            f = f.view(x.size(0), -1)
48
        o = self.fc(f)
49
        if not self.flat:
50
            o = o.view(x.size(0), -1)
51
        return f,o
52
53
# Function that creates the optimizer
54
def create_optimizer(model, mode, lr, momentum, wd):
55
    if mode == 'sgd':
56
        optimizer = optim.SGD(model.parameters(), lr,
57
                              momentum=momentum, dampening=0,
58
                              weight_decay=wd, nesterov=True)
59
    elif mode == 'adam':
60
        optimizer = optim.Adam(model.parameters(), lr=lr,
61
                               weight_decay=wd)
62
    return optimizer
63
64
# Function to anneal learning rate
65
def adjust_learning_rate(optimizer, epoch, period, start_lr):
66
    """Sets the learning rate to the initial LR decayed by 10 every period epochs"""
67
    lr = start_lr * (0.1 ** (epoch // period))
68
    for param_group in optimizer.param_groups:
69
        param_group['lr'] = lr
70
71
class EarlyStopping():
72
    """
73
    Early stopping to stop the training when the loss does not improve after
74
    certain epochs.
75
    """
76
    def __init__(self, patience=5, min_delta=0.001):
77
        """
78
        :param patience: how many epochs to wait before stopping when loss is
79
               not improving
80
        :param min_delta: minimum difference between new loss and old loss for
81
               new loss to be considered as an improvement
82
        """
83
        self.patience = patience
84
        self.min_delta = min_delta
85
        self.counter = 0
86
        self.best_loss = None
87
        self.early_stop = False
88
89
    def __call__(self, val_loss):
90
        if self.best_loss == None:
91
            self.best_loss = val_loss
92
        elif self.best_loss - val_loss > self.min_delta:
93
            self.best_loss = val_loss
94
        elif self.best_loss - val_loss < self.min_delta:
95
            self.counter += 1
96
            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
97
            if self.counter >= self.patience:
98
                print('INFO: Early stopping')
99
                self.early_stop = True
100
101
class LRScheduler():
102
    """
103
    Learning rate scheduler. If the validation loss does not decrease for the 
104
    given number of `patience` epochs, then the learning rate will decrease by
105
    by given `factor`.
106
    """
107
    def __init__(self, optimizer, patience=1, min_lr=1e-5, factor=0.1, cooldown=1, threshold=0.001):
108
        """
109
        new_lr = old_lr * factor
110
        :param optimizer: the optimizer we are using
111
        :param patience: how many epochs to wait before updating the lr
112
        :param min_lr: least lr value to reduce to while updating
113
        :param factor: factor by which the lr should be updated
114
        """
115
        self.optimizer = optimizer
116
        self.patience = patience
117
        self.min_lr = min_lr
118
        self.factor = factor
119
        self.cooldown = cooldown
120
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 
121
                self.optimizer,
122
                mode='min',
123
                patience=self.patience,
124
                factor=self.factor,
125
                min_lr=self.min_lr,
126
                cooldown=self.cooldown,
127
                verbose=True
128
            )
129
    def __call__(self, val_loss):
130
        self.lr_scheduler.step(val_loss)