b/utils.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np


# Rebuild the model from a saved checkpoint (expects a dict with a 'state_dict' key).
def model_prediction(chpnt):
    model = my_model()
    ch = torch.load(chpnt)
    model.load_state_dict(ch['state_dict'])
    return model


# Build a fresh model and move it to the GPU.
def get_model():
    model = my_model()
    model = model.cuda()
    return model


# ResNet-34 backbone adapted to single-channel (grayscale) input with a
# two-class classification head.
class my_model(nn.Module):
    def __init__(self):
        super(my_model, self).__init__()
        # Start from an ImageNet-pretrained ResNet-34.
        model = models.resnet34(weights='DEFAULT')
        # Average the pretrained RGB filters of the first conv layer so the
        # network accepts 1-channel input while keeping the pretrained prior.
        conv1 = model.conv1.weight.detach().clone().mean(dim=1, keepdim=True)
        model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        model.conv1.weight.data = conv1
        # Replace the 1000-way ImageNet head with a 2-way classifier.
        model.fc = nn.Linear(model.fc.in_features, 2)
        # Everything up to (and including) global average pooling is the
        # feature extractor; the new fc layer is kept separately.
        self.features = nn.Sequential(*list(model.children())[:-1])
        self.fc = list(model.children())[-1]
        self.flat = True

    def forward(self, x):
        f = self.features(x)
        if self.flat:
            # Flatten pooled features to (batch, 512) before the linear head.
            f = f.view(x.size(0), -1)
        o = self.fc(f)
        if not self.flat:
            o = o.view(x.size(0), -1)
        return o

    def forward_feat(self, x):
        # Same as forward(), but also returns the pre-classifier features.
        f = self.features(x)
        if self.flat:
            f = f.view(x.size(0), -1)
        o = self.fc(f)
        if not self.flat:
            o = o.view(x.size(0), -1)
        return f, o


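# Illustrative usage sketch (not part of the original utilities): running the
# adapted model on a dummy grayscale batch. The helper name _demo_model_forward
# and the 224x224 input size are assumptions.
def _demo_model_forward():
    model = my_model().eval()
    x = torch.randn(4, 1, 224, 224)           # 4 single-channel images
    with torch.no_grad():
        logits = model(x)                      # (4, 2) class scores
        feats, logits = model.forward_feat(x)  # (4, 512) features and (4, 2) scores
    return feats.shape, logits.shape

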
# Create the optimizer (SGD with Nesterov momentum, or Adam).
def create_optimizer(model, mode, lr, momentum, wd):
    if mode == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr,
                              momentum=momentum, dampening=0,
                              weight_decay=wd, nesterov=True)
    elif mode == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr,
                               weight_decay=wd)
    else:
        raise ValueError(f"Unknown optimizer mode: {mode}")
    return optimizer


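# Illustrative usage sketch: constructing an optimizer for the model. The
# hyperparameter values here are assumptions, not values from the original code.
def _demo_create_optimizer():
    model = my_model()
    optimizer = create_optimizer(model, mode='sgd', lr=0.01, momentum=0.9, wd=1e-4)
    return optimizer

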
# Anneal the learning rate with a fixed step-decay schedule.
def adjust_learning_rate(optimizer, epoch, period, start_lr):
    """Set the learning rate to the initial LR decayed by a factor of 10 every `period` epochs."""
    lr = start_lr * (0.1 ** (epoch // period))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


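# Illustrative usage sketch of the step decay: with start_lr=0.01 and period=10,
# epochs 0-9 run at 0.01, epochs 10-19 at 0.001, epochs 20-29 at 0.0001 (up to
# float rounding). The helper name _demo_step_decay is an assumption.
def _demo_step_decay():
    model = my_model()
    optimizer = create_optimizer(model, mode='adam', lr=0.01, momentum=0.0, wd=0.0)
    schedule = []
    for epoch in range(30):
        adjust_learning_rate(optimizer, epoch, period=10, start_lr=0.01)
        schedule.append(optimizer.param_groups[0]['lr'])
    return schedule

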
class EarlyStopping():
    """
    Early stopping to stop the training when the loss does not improve after
    a certain number of epochs.
    """
    def __init__(self, patience=5, min_delta=0.001):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and best loss for
               the new loss to be considered an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            # Loss improved enough: record it and reset the patience counter.
            self.best_loss = val_loss
            self.counter = 0
        else:
            # No meaningful improvement this epoch.
            self.counter += 1
            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                print('INFO: Early stopping')
                self.early_stop = True


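# Illustrative usage sketch: driving EarlyStopping from a training loop, one
# call per epoch with that epoch's validation loss. The losses passed in here
# are made-up example values.
def _demo_early_stopping(val_losses=(0.9, 0.7, 0.69, 0.69, 0.69, 0.69, 0.69, 0.69)):
    early_stopping = EarlyStopping(patience=5, min_delta=0.001)
    for epoch, val_loss in enumerate(val_losses):
        early_stopping(val_loss)
        if early_stopping.early_stop:
            return epoch  # training would stop after this epoch
    return len(val_losses) - 1

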
class LRScheduler():
    """
    Learning rate scheduler. If the validation loss does not decrease for the
    given number of `patience` epochs, the learning rate is reduced by the
    given `factor`.
    """
    def __init__(self, optimizer, patience=1, min_lr=1e-5, factor=0.1, cooldown=1, threshold=0.001):
        """
        new_lr = old_lr * factor
        :param optimizer: the optimizer we are using
        :param patience: how many epochs to wait before updating the lr
        :param min_lr: lowest lr value to reduce to while updating
        :param factor: factor by which the lr should be updated
        :param cooldown: epochs to wait after a reduction before resuming normal operation
        :param threshold: minimum decrease in the loss to count as an improvement
        """
        self.optimizer = optimizer
        self.patience = patience
        self.min_lr = min_lr
        self.factor = factor
        self.cooldown = cooldown
        self.threshold = threshold
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=self.patience,
            factor=self.factor,
            min_lr=self.min_lr,
            cooldown=self.cooldown,
            threshold=self.threshold,
            verbose=True
        )

    def __call__(self, val_loss):
        self.lr_scheduler.step(val_loss)
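

# Illustrative usage sketch: wiring LRScheduler (a thin wrapper around PyTorch's
# ReduceLROnPlateau) into a training loop. The hyperparameters and loss values
# below are assumptions for demonstration only.
def _demo_lr_scheduler():
    model = my_model()
    optimizer = create_optimizer(model, mode='sgd', lr=0.1, momentum=0.9, wd=1e-4)
    lr_scheduler = LRScheduler(optimizer, patience=2, factor=0.5)
    for val_loss in (1.0, 0.9, 0.9, 0.9, 0.9, 0.9):
        lr_scheduler(val_loss)  # reduces the lr once the loss plateaus
    return optimizer.param_groups[0]['lr']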