b/solver.py
import glob
import os

import numpy as np
import torch
from nn_common_modules import losses as additional_losses
from torch.optim import lr_scheduler

import utils.common_utils as common_utils
from utils.log_utils import LogWriter

CHECKPOINT_DIR = 'checkpoints'
CHECKPOINT_EXTENSION = 'pth.tar'


class Solver(object):

    def __init__(self,
                 model,
                 exp_name,
                 device,
                 num_class,
                 optim=torch.optim.Adam,
                 optim_args={},
                 loss_func=additional_losses.CombinedLoss(),
                 model_name='quicknat',
                 labels=None,
                 num_epochs=10,
                 log_nth=5,
                 lr_scheduler_step_size=5,
                 lr_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 exp_dir='experiments',
                 log_dir='logs'):

        self.device = device
        self.model = model

        self.model_name = model_name
        self.labels = labels
        self.num_epochs = num_epochs
        if torch.cuda.is_available():
            self.loss_func = loss_func.cuda(device)
        else:
            self.loss_func = loss_func
        self.optim = optim(model.parameters(), **optim_args)
        self.scheduler = lr_scheduler.StepLR(self.optim, step_size=lr_scheduler_step_size,
                                             gamma=lr_scheduler_gamma)

        exp_dir_path = os.path.join(exp_dir, exp_name)
        common_utils.create_if_not(exp_dir_path)
        common_utils.create_if_not(os.path.join(exp_dir_path, CHECKPOINT_DIR))
        self.exp_dir_path = exp_dir_path

        self.log_nth = log_nth
        self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)

        self.use_last_checkpoint = use_last_checkpoint

        self.start_epoch = 1
        self.start_iteration = 1

        self.best_ds_mean = 0
        self.best_ds_mean_epoch = 0

        if use_last_checkpoint:
            self.load_checkpoint()

    # TODO: Need to correct the CM and dice score calculation.
    def train(self, train_loader, val_loader):
        """
        Train a given model with the provided data.

        Inputs:
        - train_loader: train data in torch.utils.data.DataLoader
        - val_loader: val data in torch.utils.data.DataLoader
        """
        model, optim, scheduler = self.model, self.optim, self.scheduler
        dataloaders = {
            'train': train_loader,
            'val': val_loader
        }

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            model.cuda(self.device)

        device_name = torch.cuda.get_device_name(self.device) if torch.cuda.is_available() else 'cpu'
        print('START TRAINING: model name = %s, device = %s' % (self.model_name, device_name))
        current_iteration = self.start_iteration
        for epoch in range(self.start_epoch, self.num_epochs + 1):
            print("\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs))
            for phase in ['train', 'val']:
                print("<<<= Phase: %s =>>>" % phase)
                loss_arr = []
                out_list = []
                y_list = []
                if phase == 'train':
                    model.train()
                    # StepLR is advanced once per epoch, at the start of the training phase.
                    scheduler.step()
                else:
                    model.eval()
                for i_batch, sample_batched in enumerate(dataloaders[phase]):
                    X = sample_batched[0].type(torch.FloatTensor)
                    y = sample_batched[1].type(torch.LongTensor)
                    w = sample_batched[2].type(torch.FloatTensor)

                    if model.is_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        w = w.cuda(self.device, non_blocking=True)

                    output = model(X)
                    loss = self.loss_func(output, y, w)
                    if phase == 'train':
                        optim.zero_grad()
                        loss.backward()
                        optim.step()
                        if i_batch % self.log_nth == 0:
                            self.logWriter.loss_per_iter(loss.item(), i_batch, current_iteration)
                        current_iteration += 1

                    loss_arr.append(loss.item())

                    _, batch_output = torch.max(output, dim=1)
                    out_list.append(batch_output.cpu())
                    y_list.append(y.cpu())

                    del X, y, w, output, batch_output, loss
                    torch.cuda.empty_cache()
                    if phase == 'val':
                        if i_batch != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():
                    out_arr, y_arr = torch.cat(out_list), torch.cat(y_list)
                    self.logWriter.loss_per_epoch(loss_arr, phase, epoch)
                    # Log predictions for three randomly chosen samples of this phase.
                    index = np.random.choice(len(dataloaders[phase].dataset.X), 3, replace=False)
                    self.logWriter.image_per_epoch(model.predict(dataloaders[phase].dataset.X[index], self.device),
                                                   dataloaders[phase].dataset.y[index], phase, epoch)
                    self.logWriter.cm_per_epoch(phase, out_arr, y_arr, epoch)
                    ds_mean = self.logWriter.dice_score_per_epoch(phase, out_arr, y_arr, epoch)
                    if phase == 'val':
                        if ds_mean > self.best_ds_mean:
                            self.best_ds_mean = ds_mean
                            self.best_ds_mean_epoch = epoch

            print("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====")
            self.save_checkpoint({
                'epoch': epoch + 1,
                'start_iteration': current_iteration + 1,
                'arch': self.model_name,
                'state_dict': model.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_ds_mean': self.best_ds_mean,
                'best_ds_mean_epoch': self.best_ds_mean_epoch
            }, os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
                            'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION))

        print('FINISH.')
        self.logWriter.close()

    def save_best_model(self, path):
        """
        Save model with its parameters to the given path. Conventionally the
        path should end with "*.model".

        Inputs:
        - path: path string
        """
        print('Saving model... %s' % path)
        print('Best Model at Epoch: ' + str(self.best_ds_mean_epoch))
        self.load_checkpoint(self.best_ds_mean_epoch)

        torch.save(self.model, path)

    def save_checkpoint(self, state, filename):
        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        if epoch is not None:
            checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
                                           'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION)
            self._load_checkpoint_file(checkpoint_path)
        else:
            all_files_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION)
            list_of_files = glob.glob(all_files_path)
            if len(list_of_files) > 0:
                checkpoint_path = max(list_of_files, key=os.path.getctime)
                self._load_checkpoint_file(checkpoint_path)
            else:
                self.logWriter.log(
                    "=> no checkpoint found at '{}' folder".format(os.path.join(self.exp_dir_path, CHECKPOINT_DIR)))

    def _load_checkpoint_file(self, file_path):
        self.logWriter.log("=> loading checkpoint '{}'".format(file_path))
        checkpoint = torch.load(file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optim.load_state_dict(checkpoint['optimizer'])
        if 'best_ds_mean' in checkpoint.keys():
            self.best_ds_mean = checkpoint['best_ds_mean']
            self.best_ds_mean_epoch = checkpoint['best_ds_mean_epoch']

        # Move any optimizer state tensors back onto the target device after loading.
        for state in self.optim.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(self.device)

        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format(file_path, checkpoint['epoch']))
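
For orientation, below is a minimal usage sketch of the Solver above. It is not part of solver.py, and every name in it (ToyNet, SliceDataset, toy_run, the placeholder label names) is hypothetical. It only illustrates the contract the train loop implies: a model that is callable on a batch and exposes is_cuda and predict(), and DataLoaders whose dataset yields (input, label, class_weight) triples and exposes X and y arrays for the per-epoch image logging. It assumes the nn_common_modules and utils packages imported by the Solver are installed, and a single GPU with id 0 (the sketch also runs on CPU).

# Hypothetical usage sketch for the Solver above -- not part of solver.py.
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class SliceDataset(Dataset):
    """Toy dataset matching what the train loop expects: __getitem__ returns
    (input, label, class_weight) and the object exposes .X and .y arrays."""

    def __init__(self, n=8, num_class=4, size=32):
        self.X = np.random.rand(n, 1, size, size).astype(np.float32)
        self.y = np.random.randint(0, num_class, (n, size, size))
        self.w = np.ones((n, size, size), dtype=np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.w[idx]


class ToyNet(nn.Module):
    """Toy segmentation net exposing the attributes the Solver relies on."""

    def __init__(self, num_class=4):
        super().__init__()
        self.conv = nn.Conv2d(1, num_class, kernel_size=1)

    @property
    def is_cuda(self):
        # Mirrors the flag checked before batches are moved to the GPU.
        return next(self.parameters()).is_cuda

    def forward(self, x):
        return self.conv(x)

    def predict(self, X, device):
        # Used by LogWriter.image_per_epoch: hard labels for a small batch.
        self.eval()
        with torch.no_grad():
            inp = torch.as_tensor(X, dtype=torch.float32)
            if self.is_cuda:
                inp = inp.cuda(device)
            return torch.argmax(self(inp), dim=1).cpu().numpy()


train_loader = DataLoader(SliceDataset(), batch_size=4, shuffle=True)
val_loader = DataLoader(SliceDataset(), batch_size=4)

solver = Solver(ToyNet(), exp_name='toy_run', device=0, num_class=4,
                labels=['bg', 'c1', 'c2', 'c3'], optim_args={'lr': 1e-4},
                num_epochs=2, use_last_checkpoint=False)
solver.train(train_loader, val_loader)
solver.save_best_model('toy_run.model')  # full model from the best-Dice epoch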