# dl/utils/train.py
import sys
import os
import copy

lib_path = 'I:/code'
if not os.path.exists(lib_path):
    lib_path = '/media/6T/.tianle/.lib'
if os.path.exists(lib_path) and lib_path not in sys.path:
    sys.path.append(lib_path)

import numpy as np
import sklearn
import sklearn.metrics  # needed explicitly: `import sklearn` alone does not expose sklearn.metrics

import torch
import torch.nn as nn

from dl.utils.visualization.visualization import plot_scatter
|
|
def cosine_similarity(x, y=None, eps=1e-8):
    """Calculate cosine similarity between the rows of two matrices.

    Args:
        x: N*p tensor
        y: M*p tensor or None; if None, set y = x
        This function does not broadcast.

    Returns:
        N*M tensor
    """
    w1 = torch.norm(x, p=2, dim=1, keepdim=True)
    if y is None:
        w2 = w1.squeeze(dim=1)
        y = x
    else:
        w2 = torch.norm(y, p=2, dim=1)
    w12 = torch.mm(x, y.t())
    return w12 / (w1*w2).clamp(min=eps)
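
# Usage sketch (illustrative only; these tensors are made-up examples, not part of the library):
#   a = torch.randn(4, 16)
#   b = torch.randn(6, 16)
#   sim = cosine_similarity(a, b)    # shape (4, 6), entries in [-1, 1]
#   self_sim = cosine_similarity(a)  # shape (4, 4), diagonal close to 1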
|
|
def adjust_learning_rate(optimizer, lr, epoch, reduce_every=2):
    """Decay the learning rate by a factor of 10 every reduce_every epochs.
    """
    lr = lr * (0.1 ** (epoch//reduce_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
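
# Schedule sketch (assuming a base lr of 1e-2 and reduce_every=2; call once per epoch):
#   adjust_learning_rate(optimizer, 1e-2, epoch, reduce_every=2)
#   # epochs 0-1 -> lr=1e-2, epochs 2-3 -> lr=1e-3, epochs 4-5 -> lr=1e-4, ...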
|
|
def predict(model, x, batch_size=None, train=True, num_heads=1):
    """Calculate model(x)

    Args:
        batch_size: default None; predict in a single batch (suitable for small models and data).
        train: default True; if False, call model.eval() and torch.set_grad_enabled(False) first.
        num_heads: default 1; if num_heads > 1, collect the multi-head output
    """
    if batch_size is None:
        batch_size = x.size(0)
    y_pred = []
    if num_heads > 1:
        y_pred = [[] for _ in range(num_heads)]  # one list of batch outputs per head
    prev = torch.is_grad_enabled()
    if train:
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)
    for i in range(0, len(x), batch_size):
        y_ = model(x[i:i+batch_size])
        if num_heads > 1:
            for j in range(num_heads):  # use j so the batch index i is not shadowed
                y_pred[j].append(y_[j])
        else:
            y_pred.append(y_)
    torch.set_grad_enabled(prev)

    if num_heads > 1:
        return [torch.cat(y, 0) for y in y_pred]
    else:
        return torch.cat(y_pred, 0)
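
# Usage sketch (hypothetical model and data; 'autoencoder' stands for any model returning a list of heads):
#   net = nn.Linear(16, 3)
#   x = torch.randn(100, 16)
#   scores = predict(net, x, batch_size=32, train=False)      # (100, 3) tensor, computed without gradients
#   outs = predict(autoencoder, x, train=False, num_heads=2)  # [head_0_output, head_1_output]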
|
|
def plot_data(model, x, y, title='', num_heads=1, batch_size=None):
    """Scatter plot for input layer and output with colors corresponding to labels
    """
    if isinstance(y, torch.Tensor):
        y = y.cpu().detach().numpy()
    plot_scatter(x, labels=y, title=f'Input {title}')
    y_pred = predict(model, x, batch_size, train=False, num_heads=num_heads)
    if num_heads > 1:
        for i in range(num_heads):
            plot_scatter(y_pred[i], labels=y, title=f'Head {i}')
    else:
        plot_scatter(y_pred, labels=y, title='Output')
|
|
def plot_data_multi_splits(model, xs, ys, num_heads=1, titles=['Training', 'Validation', 'Test'], batch_size=None):
    """Call plot_data on multiple data splits, typically x_train, x_val, x_test

    Args:
        Most arguments are passed to plot_data
        xs: a list of model inputs
        ys: a list of target labels
        titles: a list of titles, one per data split
    """
    if len(xs) != len(titles):  # make sure titles has the same length as xs
        titles = [f'Data split {i}' for i in range(len(xs))]
    for i, (x, y) in enumerate(zip(xs, ys)):
        if len(x) > 0 and len(x) == len(y):
            plot_data(model, x, y, title=titles[i], num_heads=num_heads, batch_size=batch_size)
        else:
            print(f'x for {titles[i]} is empty or len(x) != len(y)')
|
|
def get_label_prob(labels, verbose=True):
    """Get the label distribution
    """
    if isinstance(labels, torch.Tensor):
        unique_labels = torch.unique(labels).sort()[0]
        label_prob = torch.stack([labels == i for i in unique_labels], dim=0).sum(dim=1)
        label_prob = label_prob.float() / len(labels)
    else:
        labels = np.array(labels)  # if labels is a list, convert it to np.array
        unique_labels = sorted(np.unique(labels))
        label_prob = np.stack([labels == i for i in unique_labels], axis=0).sum(axis=1)
        label_prob = label_prob / len(labels)
    if verbose:
        msg = '\n'.join(map(lambda x: f'{x[0]}: {x[1].item():.2f}',
                            zip(unique_labels, label_prob)))
        print(f'label distribution:\n{msg}')
    return label_prob
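
# Usage sketch (made-up labels):
#   get_label_prob(torch.tensor([0, 0, 1, 2]))
#   # prints the distribution (0: 0.50, 1: 0.25, 2: 0.25) and returns tensor([0.5000, 0.2500, 0.2500])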
|
|
def eval_classification(y_true, y_pred=None, model=None, x=None, batch_size=None, multi_heads=False,
                        cls_head=0, average='weighted', predict_func=None, pred_kwargs=None, verbose=True):
    """Evaluate classification results

    Args:
        y_true: true labels; numpy array or torch.Tensor
        y_pred: if None, then y_pred = model(x)
        model: torch.nn.Module type
        x: input tensor
        batch_size: used for predict(model, x, batch_size)
        multi_heads: if True, the model outputs a list; the classification head is assumed to be the first one
        cls_head: only used when multi_heads is True; specifies which head is used for classification; default 0
        average: used by sklearn.metrics to calculate precision, recall, f1, auc and ap; default 'weighted'
        predict_func: if not None, use predict_func(model, x, **pred_kwargs) instead of predict()
        pred_kwargs: dictionary of keyword arguments for predict_func
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.cpu().detach().numpy().reshape(-1)
    num_cls = len(np.unique(y_true))
    auc = -1  # dummy value for multi-class classification
    average_precision = -1  # dummy value for multi-class classification
    y_score = None  # only used to calculate auc and average_precision for binary classification; set below
    if y_pred is None:  # calculate y_pred = model(x) in batches
        if predict_func is None:
            # use predict() defined in this file
            num_heads = 2 if multi_heads else 1  # num_heads >= 2 makes predict() treat the model as multi-output
            y_ = predict(model, x, batch_size, train=False, num_heads=num_heads)
            y_pred = y_[cls_head] if multi_heads else y_
        else:
            # use a customized predict_func with keyword arguments
            y_pred = predict_func(model, x, **pred_kwargs)
    if isinstance(y_pred, torch.Tensor):
        # either the input argument was a torch.Tensor or it was just computed from model(x) above
        y_pred = y_pred.cpu().detach().numpy()
    if isinstance(y_pred, np.ndarray) and y_pred.ndim == 2 and y_pred.shape[1] > 1:
        # y_pred is the class score matrix: n_samples * n_classes
        if y_pred.shape[1] == 2:  # for binary classification
            y_score = y_pred[:, 1] - y_pred[:, 0]  # y_score is only used for auc and average precision
        y_pred = y_pred.argmax(axis=-1)  # only consider the top-1 prediction
    if num_cls == 2 and y_pred.dtype == np.dtype('float'):  # the block above did not run
        # For binary classification, the argument y_pred can be the scores for belonging to class 1.
        y_score = y_pred  # used to calculate auc and average_precision
        y_pred = (y_score > 0).astype('int')
    acc = sklearn.metrics.accuracy_score(y_true, y_pred)
    precision = sklearn.metrics.precision_score(y_true, y_pred, average=average)
    recall = sklearn.metrics.recall_score(y_true, y_pred, average=average)
    f1_score = sklearn.metrics.f1_score(y_true=y_true, y_pred=y_pred, average=average)
    adjusted_mutual_info = sklearn.metrics.adjusted_mutual_info_score(labels_true=y_true, labels_pred=y_pred)
    confusion_mat = sklearn.metrics.confusion_matrix(y_true, y_pred)
    msg = f'acc={acc:.3f}, precision={precision:.3f}, recall={recall:.3f}, f1={f1_score:.3f}, adj_MI={adjusted_mutual_info:.3f}'
    if num_cls == 2:
        # When y_pred is given as an int np.array or tensor, model(x) is not called;
        # set y_score = y_pred to calculate auc and average precision approximately;
        # this may not be fully accurate because y_pred holds binary labels rather than probabilities.
        if y_score is None:
            y_score = y_pred
        auc = sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score, average=average)
        average_precision = sklearn.metrics.average_precision_score(y_true=y_true, y_score=y_score, average=average)
        msg = msg + f', auc={auc:.3f}, ap={average_precision:.3f}'
    msg = msg + f', confusion_mat=\n{confusion_mat}'
    if verbose:
        print(msg)
        print('report', sklearn.metrics.classification_report(y_true=y_true, y_pred=y_pred))

    return np.array([acc, precision, recall, f1_score, adjusted_mutual_info, auc, average_precision]), confusion_mat
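
# Usage sketch (hypothetical tensors; y_pred can also be passed directly as an (n, num_classes) score array):
#   metrics, conf_mat = eval_classification(y_true=y_test, model=net, x=x_test, batch_size=128)
#   # metrics = [acc, precision, recall, f1, adjusted_MI, auc, average_precision]; auc/ap stay -1 unless binary
#   metrics, conf_mat = eval_classification(y_true=y_test, y_pred=scores, verbose=False)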
|
|
def eval_classification_multi_splits(model, xs, ys, batch_size=None, multi_heads=False, cls_head=0,
                                     average='weighted', return_result=True, split_names=['Train', 'Validation', 'Test'],
                                     predict_func=None, pred_kwargs=None, verbose=True):
    """Call eval_classification on multiple data splits, e.g., x_train, x_val, x_test, with the given model

    Args:
        Most arguments are passed to eval_classification
        xs: a list of model inputs, e.g., [x_train, x_val, x_test]
        ys: a list of targets, e.g., [y_train, y_val, y_test]
        return_result: if True, return the results on the non-empty data splits
        split_names: for printing; default: ['Train', 'Validation', 'Test']
    """
    res = []
    if len(xs) != len(split_names):
        split_names = [f'Data split {i}' for i in range(len(xs))]
    for i, (x, y) in enumerate(zip(xs, ys)):
        if len(x) > 0:
            print(split_names[i])
            metric = eval_classification(y_true=y, model=model, x=x, batch_size=batch_size,
                                         multi_heads=multi_heads, cls_head=cls_head, average=average,
                                         predict_func=predict_func, pred_kwargs=pred_kwargs, verbose=verbose)
            res.append(metric)
    if return_result:
        return res
|
|
def run_one_epoch_single_loss(model, x, y_true, loss_fn=nn.CrossEntropyLoss(), train=True, optimizer=None,
                              batch_size=None, return_loss=True, epoch=0, print_every=10, verbose=True):
    """Run one epoch, i.e., model(x), split into batches

    Args:
        model: torch.nn.Module
        x: torch.Tensor
        y_true: target torch.Tensor
        loss_fn: loss function
        train: if False, call model.eval() and torch.set_grad_enabled(False) to save time
        optimizer: required when train is True
        batch_size: if None, batch_size = len(x)
        return_loss: if True, return the epoch loss
        epoch: current epoch index, used for printing
        print_every: print the epoch loss if epoch % print_every == 0
        verbose: if True, print the batch loss
    """

    is_grad_enabled = torch.is_grad_enabled()
    if train:
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)
    loss_history = []
    is_classification = isinstance(y_true.cpu(), torch.LongTensor)
    if is_classification:
        acc_history = []
    if batch_size is None:
        batch_size = len(x)
    for i in range(0, len(x), batch_size):
        y_pred = model(x[i:i+batch_size])
        loss = loss_fn(y_pred, y_true[i:i+batch_size])
        loss_history.append(loss.item())
        if is_classification:
            labels_pred = y_pred.topk(1, -1)[1].squeeze()  # only calculate top-1 accuracy
            acc = (labels_pred == y_true[i:i+batch_size]).float().mean().item()
            acc_history.append(acc)
        if verbose:
            msg = 'Epoch{} {}/{}: loss={:.2e}'.format(
                epoch, i//batch_size, (len(x)+batch_size-1)//batch_size, loss.item())
            if is_classification:
                msg = msg + f', acc={acc:.2f}'
            print(msg)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    torch.set_grad_enabled(is_grad_enabled)

    loss_epoch = np.mean(loss_history)
    if is_classification:
        acc_epoch = np.mean(acc_history)
    if epoch % print_every == 0:
        msg = 'Epoch{} {}: loss={:.2e}'.format(epoch, 'Train' if train else 'Test', np.mean(loss_history))
        if is_classification:
            msg = msg + f', acc={np.mean(acc_history):.2f}'
        print(msg)
    if return_loss:
        if is_classification:
            return loss_epoch, acc_epoch, loss_history, acc_history
        else:
            return loss_epoch, loss_history
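
# Usage sketch (hypothetical classifier; an optimizer is only required when train=True):
#   opt = torch.optim.Adam(net.parameters(), lr=1e-3)
#   loss_ep, acc_ep, loss_hist, acc_hist = run_one_epoch_single_loss(
#       net, x_train, y_train, train=True, optimizer=opt, batch_size=64, epoch=0)
#   val_loss, val_acc, _, _ = run_one_epoch_single_loss(net, x_val, y_val, train=False, batch_size=64, epoch=0)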
|
|
def train_single_loss(model, x_train, y_train, x_val=[], y_val=[], x_test=[], y_test=[],
                      loss_fn=nn.CrossEntropyLoss(), lr=1e-2, weight_decay=1e-4, amsgrad=True, batch_size=None, num_epochs=1,
                      reduce_every=200, eval_every=1, print_every=1, verbose=False,
                      loss_train_his=[], loss_val_his=[], loss_test_his=[],
                      acc_train_his=[], acc_val_his=[], acc_test_his=[], return_best_val=True):
    """Run a number of epochs of backpropagation

    Args:
        Most arguments are passed to run_one_epoch_single_loss
        lr, weight_decay, amsgrad are passed to torch.optim.Adam
        reduce_every: passed to adjust_learning_rate; the learning rate is decayed by 10x every reduce_every epochs
        eval_every: call run_one_epoch_single_loss on the validation and test sets if cur_epoch % eval_every == 0
        print_every: print the epoch loss if cur_epoch % print_every == 0
        verbose: if True, print the batch loss
        loss_*_his, acc_*_his: history lists appended to in place; pass fresh lists to keep separate runs apart
        return_best_val: if True, return the best model on the validation set for a classification task
    """

    def eval_one_epoch(x, targets, loss_his, acc_his, epoch, train=False):
        """Function within function; reuse parameters within the proper scope
        """
        results = run_one_epoch_single_loss(model, x, targets, loss_fn=loss_fn, train=train, optimizer=optimizer,
                                            batch_size=batch_size, return_loss=True, epoch=epoch, print_every=print_every, verbose=verbose)
        if is_classification:
            loss_epoch, acc_epoch, loss_history, acc_history = results
        else:
            loss_epoch, loss_history = results
        loss_his.append(loss_epoch)
        if is_classification:
            acc_his.append(acc_epoch)

    is_classification = isinstance(y_train.cpu(), torch.LongTensor)
    best_val_acc = -1  # best_val_acc >= 0 after the first epoch for a classification task
    for i in range(num_epochs):
        if i == 0:
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                         lr=lr, weight_decay=weight_decay, amsgrad=amsgrad)
        # Should I create a new torch.optim.Adam instance every time I adjust the learning rate?
        adjust_learning_rate(optimizer, lr, i, reduce_every=reduce_every)

        eval_one_epoch(x_train, y_train, loss_train_his, acc_train_his, i, train=True)
        if i % eval_every == 0:
            if len(x_val) > 0 and len(y_val) > 0:
                # Setting train=False is crucial: no parameter updates on validation data
                eval_one_epoch(x_val, y_val, loss_val_his, acc_val_his, i, train=False)
                if is_classification:
                    if acc_val_his[-1] > best_val_acc:
                        best_val_acc = acc_val_his[-1]
                        best_model = copy.deepcopy(model)
                        best_epoch = i
                        print('epoch {}, best_val_acc={:.2f}, train_acc={:.2f}'.format(
                            best_epoch, best_val_acc, acc_train_his[-1]))
            if len(x_test) > 0 and len(y_test) > 0:
                eval_one_epoch(x_test, y_test, loss_test_his, acc_test_his, i, train=False)  # set train=False

    if is_classification:
        if return_best_val and len(x_val) > 0 and len(y_val) > 0:
            return best_model, best_val_acc, best_epoch
        else:
            return model, acc_train_his[-1], i
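
# Usage sketch (hypothetical data; pass fresh history lists so separate runs do not share state):
#   loss_tr, acc_tr, loss_va, acc_va = [], [], [], []
#   best_model, best_val_acc, best_epoch = train_single_loss(
#       net, x_train, y_train, x_val=x_val, y_val=y_val, lr=1e-2, batch_size=64, num_epochs=50,
#       loss_train_his=loss_tr, acc_train_his=acc_tr, loss_val_his=loss_va, acc_val_his=acc_va)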
|
|
def run_one_epoch_multiloss(model, x, targets, heads=[0, 1], loss_fns=[nn.CrossEntropyLoss(), nn.MSELoss()],
                            loss_weights=[1, 0], other_loss_fns=[], other_loss_weights=[], return_loss=True, batch_size=None,
                            train=True, optimizer=None, epoch=0, print_every=10, verbose=True):
    """Run one epoch of a multi-head model with multiple losses, including losses between outputs and targets
    (head losses) and regularizers on model parameters (non-head losses).

    Args:
        model: a multi-head model; for example, an autoencoder classifier that returns classification scores
            (or a regression target) and the decoder output (reconstruction of the input)
        x: input
        targets: a list of targets associated with the multi-head output specified by the argument heads;
            e.g., for an autoencoder with two heads, targets = [y_labels, x];
            targets need not pair one-to-one with all head outputs;
            use the argument heads to specify which heads are paired with targets;
            elements of targets can also be None;
            the length of targets must be compatible with that of loss_weights, loss_fns, and heads
        heads: the indices of the heads paired with targets for calculating losses;
            if None, set heads = list(range(len(targets)))
        loss_fns: a list of loss functions for the corresponding heads
        loss_weights: the (non-negative) weights for the above head losses;
            heads, loss_fns, and loss_weights are closely related to each other; handle them carefully
        other_loss_fns: a list of loss functions acting as regularizers on model parameters
        other_loss_weights: the corresponding weights for other_loss_fns
        return_loss: default True; return all losses
        batch_size: default None; split the data into batches
        train: default True; if False, call model.eval() and torch.set_grad_enabled(False) to save time
        optimizer: must be given when train is True; default None; not used for evaluation
        epoch: for printing only
        print_every: print the epoch losses if epoch % print_every == 0
        verbose: if True, print the losses for each batch
    """

    is_grad_enabled = torch.is_grad_enabled()
    if train:
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)
    if batch_size is None:
        batch_size = len(x)

    if len(targets) < len(loss_weights):
        # Some losses do not require targets (they use 'implicit' targets in the objective);
        # pad targets with None so that indexing by loss index works below
        targets = targets + [None]*(len(loss_weights) - len(targets))
    is_classification = []  # indices of targets that correspond to classification tasks
    has_unequal_size = []  # indices of targets whose size differs from the input size
    is_none = []  # indices of targets that are None
    for j, y_true in enumerate(targets):
        if y_true is not None:
            if len(y_true) == len(x):
                if isinstance(y_true.cpu(), torch.LongTensor):
                    # if targets[j] is a LongTensor, treat it as a classification task
                    is_classification.append(j)
            else:
                has_unequal_size.append(j)
        else:
            is_none.append(j)
    loss_history = []
    if len(is_classification) > 0:
        acc_history = []

    if heads is None:  # if heads is not given, assume the targets are paired with the model outputs in order
        heads = list(range(len(targets)))
    for i in range(0, len(x), batch_size):
        y_pred = model(x[i:i+batch_size])
        loss_batch = []
        for j, w in enumerate(loss_weights):
            if w > 0:  # only compute the loss when w > 0
                if j in is_none:
                    loss_j = loss_fns[j](y_pred[heads[j]]) * w
                elif j in has_unequal_size:
                    loss_j = loss_fns[j](y_pred[heads[j]], targets[j]) * w  # targets[j] is the same for all batches
                else:
                    loss_j = loss_fns[j](y_pred[heads[j]], targets[j][i:i+batch_size]) * w
                loss_batch.append(loss_j)
        for j, w in enumerate(other_loss_weights):
            if w > 0:
                # The implicit 'target' is encoded in the loss function itself
                # todo: in addition to the argument model, make loss_fns handle other 'dynamic' arguments as well
                loss_j = other_loss_fns[j](model) * w
                loss_batch.append(loss_j)
        loss = sum(loss_batch)
        loss_batch = [v.item() for v in loss_batch]
        loss_history.append(loss_batch)
        # Calculate accuracy
        if len(is_classification) > 0:
            acc_batch = []
            for k, j in enumerate(is_classification):
                labels_pred = y_pred[heads[j]].topk(1, -1)[1].squeeze()
                acc = (labels_pred == targets[j][i:i+batch_size]).float().mean().item()
                acc_batch.append(acc)
            acc_history.append(acc_batch)
        if verbose:
            msg = 'Epoch{} {}/{}: loss:{}'.format(epoch, i//batch_size, (len(x)+batch_size-1)//batch_size,
                                                  ', '.join(map(lambda x: f'{x:.2e}', loss_batch)))
            if len(is_classification) > 0:
                msg = msg + ', acc={}'.format(', '.join(map(lambda x: f'{x:.2f}', acc_batch)))
            print(msg)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    torch.set_grad_enabled(is_grad_enabled)

    loss_epoch = np.mean(loss_history, axis=0)
    if len(is_classification) > 0:
        acc_epoch = np.mean(acc_history, axis=0)
    if epoch % print_every == 0:
        msg = 'Epoch{} {}: loss:{}'.format(epoch, 'Train' if train else 'Test',
                                           ', '.join(map(lambda x: f'{x:.2e}', loss_epoch)))
        if len(is_classification) > 0:
            msg = msg + ', acc={}'.format(', '.join(map(lambda x: f'{x:.2f}', acc_epoch)))
        print(msg)

    if return_loss:
        if len(is_classification) > 0:
            return loss_epoch, acc_epoch, loss_history, acc_history
        else:
            return loss_epoch, loss_history
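
# Usage sketch (hypothetical autoencoder-classifier whose forward returns [class_scores, reconstruction];
# the loss weights below are illustrative):
#   loss_ep, acc_ep, _, _ = run_one_epoch_multiloss(
#       ae_clf, x_train, targets=[y_train, x_train], heads=[0, 1],
#       loss_fns=[nn.CrossEntropyLoss(), nn.MSELoss()], loss_weights=[1, 0.1],
#       train=True, optimizer=opt, batch_size=64, epoch=0)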
|
|
def train_multiloss(model, x_train, y_train, x_val=[], y_val=[], x_test=[], y_test=[], heads=[0, 1],
                    loss_fns=[nn.CrossEntropyLoss(), nn.MSELoss()], loss_weights=[1, 0], other_loss_fns=[], other_loss_weights=[],
                    lr=1e-2, weight_decay=1e-4, batch_size=None, num_epochs=1, reduce_every=100, eval_every=1, print_every=1,
                    loss_train_his=[], loss_val_his=[], loss_test_his=[], acc_train_his=[], acc_val_his=[], acc_test_his=[],
                    return_best_val=True, amsgrad=True, verbose=False):
    """Train for a number of epochs
    Most of the parameters are passed to run_one_epoch_multiloss

    Args:
        lr, weight_decay, amsgrad are passed to torch.optim.Adam
        reduce_every: passed to adjust_learning_rate; the learning rate is decayed by 10x every reduce_every epochs
        eval_every: call run_one_epoch_multiloss on the validation and test sets if i % eval_every == 0; i is the current epoch
        return_best_val: for a classification task, if a validation set is available, return the best model on the validation set
        print_every: print the epoch losses if i % print_every == 0
        verbose: if True, print the batch losses
    """
    def eval_one_epoch(x, targets, loss_his, acc_his, epoch, train=False):
        """This is a function within a function; reuse some parameters in the scope of the "outer" function
        """
        results = run_one_epoch_multiloss(model, x, targets=targets, heads=heads, loss_fns=loss_fns,
                                          loss_weights=loss_weights, other_loss_fns=other_loss_fns, other_loss_weights=other_loss_weights,
                                          return_loss=True, batch_size=batch_size, train=train, optimizer=optimizer, epoch=epoch,
                                          print_every=print_every, verbose=verbose)
        if is_classification:
            loss_epoch, acc_epoch, loss_history, acc_history = results
        else:
            loss_epoch, loss_history = results
        # loss_train_his += loss_history
        # acc_train_his += acc_history
        loss_his.append(loss_epoch)
        if is_classification:
            acc_his.append(acc_epoch)

    cls_targets = []
    for i, y_true in enumerate(y_train):
        if y_true is not None and isinstance(y_true.cpu(), torch.LongTensor):  # targets can be None; skip those
            cls_targets.append(i)
    is_classification = len(cls_targets) > 0
    best_val_acc = -1  # after the first evaluation, best_val_acc >= 0

    for i in range(num_epochs):
        if i == 0:  # the optimizer state is kept when the learning rate is adjusted later; this works, but is it better to reset it?
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr,
                                         weight_decay=weight_decay, amsgrad=amsgrad)
        adjust_learning_rate(optimizer, lr, i, reduce_every=reduce_every)

        eval_one_epoch(x_train, y_train, loss_train_his, acc_train_his, i, train=True)
        if i % eval_every == 0:
            if len(x_val) > 0 and len(y_val) > 0:
                # Must set train=False, otherwise we would train on the validation data
                eval_one_epoch(x_val, y_val, loss_val_his, acc_val_his, i, train=False)
                if is_classification:
                    cur_val_acc = np.mean(acc_val_his[-1])
                    if cur_val_acc > best_val_acc:  # use the mean accuracy over all classification tasks (in most cases just one)
                        best_val_acc = cur_val_acc
                        best_model = copy.deepcopy(model)
                        best_epoch = i
                        print('epoch {}, best_val_acc={:.2f}, train_acc={:.2f}'.format(
                            best_epoch, best_val_acc, np.mean(acc_train_his[-1])))
            if len(x_test) > 0 and len(y_test) > 0:
                eval_one_epoch(x_test, y_test, loss_test_his, acc_test_his, i, train=False)

    if is_classification:
        if return_best_val and len(x_val) > 0 and len(y_val) > 0:
            return best_model, best_val_acc, best_epoch
        else:
            return model, np.mean(acc_train_his[-1]), i
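
# Usage sketch (hypothetical autoencoder-classifier; note that each split's targets are a list, e.g. [labels, inputs]):
#   best_model, best_val_acc, best_epoch = train_multiloss(
#       ae_clf, x_train, [y_train, x_train], x_val=x_val, y_val=[y_val, x_val], heads=[0, 1],
#       loss_fns=[nn.CrossEntropyLoss(), nn.MSELoss()], loss_weights=[1, 0.1],
#       lr=1e-2, batch_size=64, num_epochs=50)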