Diff of /solver.py [000000] .. [6f9c00]

import glob
import os

import numpy as np
import torch
from nn_common_modules import losses as additional_losses
from torch.optim import lr_scheduler

import utils.common_utils as common_utils
from utils.log_utils import LogWriter

CHECKPOINT_DIR = 'checkpoints'
CHECKPOINT_EXTENSION = 'pth.tar'


class Solver(object):
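    """
    Wraps training and validation of a segmentation model: runs the epoch loop,
    logs progress through LogWriter, and saves/restores checkpoints per experiment.
    """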

    def __init__(self,
                 model,
                 exp_name,
                 device,
                 num_class,
                 optim=torch.optim.Adam,
                 optim_args={},
                 loss_func=additional_losses.CombinedLoss(),
                 model_name='quicknat',
                 labels=None,
                 num_epochs=10,
                 log_nth=5,
                 lr_scheduler_step_size=5,
                 lr_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 exp_dir='experiments',
                 log_dir='logs'):

        self.device = device
        self.model = model

        self.model_name = model_name
        self.labels = labels
        self.num_epochs = num_epochs
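        # Keep the loss on the GPU when CUDA is available; build the optimizer and
        # StepLR schedule from the arguments passed to the constructor.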
        if torch.cuda.is_available():
            self.loss_func = loss_func.cuda(device)
        else:
            self.loss_func = loss_func
        self.optim = optim(model.parameters(), **optim_args)
        self.scheduler = lr_scheduler.StepLR(self.optim, step_size=lr_scheduler_step_size,
                                             gamma=lr_scheduler_gamma)

        exp_dir_path = os.path.join(exp_dir, exp_name)
        common_utils.create_if_not(exp_dir_path)
        common_utils.create_if_not(os.path.join(exp_dir_path, CHECKPOINT_DIR))
        self.exp_dir_path = exp_dir_path

        self.log_nth = log_nth
        self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)

        self.use_last_checkpoint = use_last_checkpoint

        self.start_epoch = 1
        self.start_iteration = 1

        self.best_ds_mean = 0
        self.best_ds_mean_epoch = 0

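        # Optionally resume from the latest checkpoint saved for this experiment.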
        if use_last_checkpoint:
            self.load_checkpoint()

    # TODO: Need to correct the CM and dice score calculation.
    def train(self, train_loader, val_loader):
        """
        Train a given model with the provided data.

        Inputs:
        - train_loader: train data in torch.utils.data.DataLoader
        - val_loader: val data in torch.utils.data.DataLoader
        """
        model, optim, scheduler = self.model, self.optim, self.scheduler
        dataloaders = {
            'train': train_loader,
            'val': val_loader
        }

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            model.cuda(self.device)
            device_name = torch.cuda.get_device_name(self.device)
        else:
            device_name = str(self.device)

        print('START TRAINING: model name = %s, device = %s' % (self.model_name, device_name))
        current_iteration = self.start_iteration
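        # Main loop: every epoch runs a 'train' phase and then a 'val' phase over the
        # corresponding data loader.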
        for epoch in range(self.start_epoch, self.num_epochs + 1):
            print("\n==== Epoch [ %d  /  %d ] START ====" % (epoch, self.num_epochs))
            for phase in ['train', 'val']:
                print("<<<= Phase: %s =>>>" % phase)
                loss_arr = []
                out_list = []
                y_list = []
                if phase == 'train':
                    model.train()
                    scheduler.step()
                else:
                    model.eval()
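                # Each batch provides the input images, integer label maps, and the
                # weights that the combined loss expects as its third argument.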
                for i_batch, sample_batched in enumerate(dataloaders[phase]):
                    X = sample_batched[0].type(torch.FloatTensor)
                    y = sample_batched[1].type(torch.LongTensor)
                    w = sample_batched[2].type(torch.FloatTensor)

                    if model.is_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        w = w.cuda(self.device, non_blocking=True)

                    output = model(X)
                    loss = self.loss_func(output, y, w)
                    if phase == 'train':
                        optim.zero_grad()
                        loss.backward()
                        optim.step()
                        if i_batch % self.log_nth == 0:
                            self.logWriter.loss_per_iter(loss.item(), i_batch, current_iteration)
                        current_iteration += 1

                    loss_arr.append(loss.item())

                    _, batch_output = torch.max(output, dim=1)
                    out_list.append(batch_output.cpu())
                    y_list.append(y.cpu())

                    del X, y, w, output, batch_output, loss
                    torch.cuda.empty_cache()
                    if phase == 'val':
                        if i_batch != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

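                # End-of-phase bookkeeping: concatenate predictions and targets, then log
                # the epoch loss, a few sample predictions, the confusion matrix, and the
                # mean Dice score; track the best validation Dice for save_best_model().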
                with torch.no_grad():
                    out_arr, y_arr = torch.cat(out_list), torch.cat(y_list)
                    self.logWriter.loss_per_epoch(loss_arr, phase, epoch)
                    index = np.random.choice(len(dataloaders[phase].dataset.X), 3, replace=False)
                    self.logWriter.image_per_epoch(model.predict(dataloaders[phase].dataset.X[index], self.device),
                                                   dataloaders[phase].dataset.y[index], phase, epoch)
                    self.logWriter.cm_per_epoch(phase, out_arr, y_arr, epoch)
                    ds_mean = self.logWriter.dice_score_per_epoch(phase, out_arr, y_arr, epoch)
                    if phase == 'val':
                        if ds_mean > self.best_ds_mean:
                            self.best_ds_mean = ds_mean
                            self.best_ds_mean_epoch = epoch

            print("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====")
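            # 'epoch' and 'start_iteration' are stored as the next values to run so that a
            # resumed job continues exactly where this one stopped.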
            self.save_checkpoint({
                'epoch': epoch + 1,
                'start_iteration': current_iteration + 1,
                'arch': self.model_name,
                'state_dict': model.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_ds_mean': self.best_ds_mean,
                'best_ds_mean_epoch': self.best_ds_mean_epoch
            }, os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
                            'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION))

        print('FINISH.')
        self.logWriter.close()

    def save_best_model(self, path):
        """
        Save model with its parameters to the given path. Conventionally the
        path should end with "*.model".

        Inputs:
        - path: path string
        """
        print('Saving model... %s' % path)
        print('Best Model at Epoch: ' + str(self.best_ds_mean_epoch))
        self.load_checkpoint(self.best_ds_mean_epoch)

        torch.save(self.model, path)

    def save_checkpoint(self, state, filename):
        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        if epoch is not None:
            checkpoint_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR,
                                           'checkpoint_epoch_' + str(epoch) + '.' + CHECKPOINT_EXTENSION)
            self._load_checkpoint_file(checkpoint_path)
        else:
            all_files_path = os.path.join(self.exp_dir_path, CHECKPOINT_DIR, '*.' + CHECKPOINT_EXTENSION)
            list_of_files = glob.glob(all_files_path)
            if len(list_of_files) > 0:
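                # Resume from the most recently created checkpoint file.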
                checkpoint_path = max(list_of_files, key=os.path.getctime)
                self._load_checkpoint_file(checkpoint_path)
            else:
                self.logWriter.log(
                    "=> no checkpoint found at '{}' folder".format(os.path.join(self.exp_dir_path, CHECKPOINT_DIR)))

    def _load_checkpoint_file(self, file_path):
        self.logWriter.log("=> loading checkpoint '{}'".format(file_path))
        checkpoint = torch.load(file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optim.load_state_dict(checkpoint['optimizer'])
        if 'best_ds_mean' in checkpoint:
            self.best_ds_mean = checkpoint['best_ds_mean']
            self.best_ds_mean_epoch = checkpoint['best_ds_mean_epoch']

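        # Optimizer state tensors are restored on whatever device they were saved from;
        # move them onto the solver's device before continuing training.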
        for state in self.optim.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(self.device)

        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.logWriter.log("=> loaded checkpoint '{}' (epoch {})".format(file_path, checkpoint['epoch']))