Diff of /dl/utils/train.py [000000] .. [4807fa]

import sys
import os
import copy
lib_path = 'I:/code'
if not os.path.exists(lib_path):
  lib_path = '/media/6T/.tianle/.lib'
if os.path.exists(lib_path) and lib_path not in sys.path:
  sys.path.append(lib_path)

import numpy as np
import sklearn
import sklearn.metrics  # import the submodule explicitly; `import sklearn` alone does not expose sklearn.metrics

import torch
import torch.nn as nn

from dl.utils.visualization.visualization import plot_scatter

def cosine_similarity(x, y=None, eps=1e-8):
  """Calculate the cosine similarity between two matrices.

  Args:
    x: N*p tensor
    y: M*p tensor or None; if None, set y = x
    This function does not broadcast.

  Returns:
    N*M tensor
  """
  w1 = torch.norm(x, p=2, dim=1, keepdim=True)
  if y is None:
    w2 = w1.squeeze(dim=1)
    y = x
  else:
    w2 = torch.norm(y, p=2, dim=1)
  w12 = torch.mm(x, y.t())
  return w12 / (w1*w2).clamp(min=eps)
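# A minimal usage sketch (the tensors below are hypothetical, not part of this module):
#   a = torch.randn(4, 16)
#   b = torch.randn(6, 16)
#   sim = cosine_similarity(a, b)    # 4*6 matrix of cosine similarities
#   self_sim = cosine_similarity(a)  # 4*4 matrix; diagonal entries are ~1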


def adjust_learning_rate(optimizer, lr, epoch, reduce_every=2):
  """Decay the learning rate by a factor of 10 every reduce_every epochs.
  """
  lr = lr * (0.1 ** (epoch//reduce_every))
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr
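# Sketch of how this is used inside the training loops below (the optimizer and model are placeholders):
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
#   for epoch in range(num_epochs):
#     adjust_learning_rate(optimizer, lr=1e-2, epoch=epoch, reduce_every=200)
#     # epochs 0-199 train with lr=1e-2, epochs 200-399 with 1e-3, and so on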


def predict(model, x, batch_size=None, train=True, num_heads=1):
  """Calculate model(x) in batches

  Args:
    batch_size: default None; predict in one batch (fine for a small model and data).
    train: default True; if False, call model.eval() and torch.set_grad_enabled(False) first. 
    num_heads: default 1; if num_heads > 1, collect the multi-head output

  Returns:
    A tensor of predictions, or a list of tensors (one per head) if num_heads > 1.
  """
  if batch_size is None:
    batch_size = x.size(0)
  y_pred = []
  if num_heads > 1:
    y_pred = [[] for _ in range(num_heads)] # one list per head, e.g., classifier and decoder output
  prev = torch.is_grad_enabled()
  if train:
    model.train()
    torch.set_grad_enabled(True)
  else:
    model.eval()
    torch.set_grad_enabled(False)
  for i in range(0, len(x), batch_size):
    y_ = model(x[i:i+batch_size])
    if num_heads > 1:
      for j in range(num_heads): # use j to avoid shadowing the batch index i
        y_pred[j].append(y_[j])
    else:
      y_pred.append(y_)
  torch.set_grad_enabled(prev)

  if num_heads > 1:
    return [torch.cat(y, 0) for y in y_pred]
  else:
    return torch.cat(y_pred, 0)
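# Sketch of typical use (the model and data below are made up for illustration):
#   model = nn.Linear(16, 3)
#   x = torch.randn(100, 16)
#   scores = predict(model, x, batch_size=32, train=False)  # (100, 3) tensor
# For a two-head model, e.g. an autoencoder-classifier returning (scores, reconstruction):
#   scores, recon = predict(ae_model, x, batch_size=32, train=False, num_heads=2)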


def plot_data(model, x, y, title='', num_heads=1, batch_size=None):
  """Scatter plots of the input and of the model output, with colors corresponding to labels
  """
  if isinstance(y, torch.Tensor):
    y = y.cpu().detach().numpy()
  plot_scatter(x, labels=y, title=f'Input {title}')
  y_pred = predict(model, x, batch_size, train=False, num_heads=num_heads)
  if num_heads > 1:
    for i in range(num_heads):
      plot_scatter(y_pred[i], labels=y, title=f'Head {i}')
  else:
    plot_scatter(y_pred, labels=y, title='Output')


def plot_data_multi_splits(model, xs, ys, num_heads=1, titles=['Training', 'Validation', 'Test'], batch_size=None):
  """Call plot_data on multiple data splits, typically x_train, x_val, x_test

  Args:
    Most arguments are passed to plot_data
    xs: a list of model inputs
    ys: a list of target labels
    titles: a list of titles, one for each data split
  """
  if len(xs) != len(titles): # Make sure titles has the same length as xs
    titles = [f'Data split {i}' for i in range(len(xs))]
  for i, (x, y) in enumerate(zip(xs, ys)):
    if len(x) > 0 and len(x)==len(y):
      plot_data(model, x, y, title=titles[i], num_heads=num_heads, batch_size=batch_size)
    else:
      print(f'x for {titles[i]} is empty or len(x) != len(y)')
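# Hypothetical call, assuming x_train/x_val etc. are tensors on the same device as the model:
#   plot_data_multi_splits(model, [x_train, x_val], [y_train, y_val], num_heads=1,
#                          titles=['Training', 'Validation'], batch_size=128)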


def get_label_prob(labels, verbose=True):
  """Get the label distribution (relative frequency of each unique label)
  """
  if isinstance(labels, torch.Tensor):
    unique_labels = torch.unique(labels).sort()[0]
    label_prob = torch.stack([labels==i for i in unique_labels], dim=0).sum(dim=1)
    label_prob = label_prob.float()/len(labels)
  else:
    labels = np.array(labels) # if labels is a list, convert it to np.array
    unique_labels = sorted(np.unique(labels))
    label_prob = np.stack([labels==i for i in unique_labels], axis=0).sum(axis=1)
    label_prob = label_prob / len(labels)
  if verbose:
    msg = '\n'.join(map(lambda x: f'{x[0]}: {x[1].item():.2f}', 
                        zip(unique_labels, label_prob)))
    print(f'label distribution:\n{msg}')
  return label_prob
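# Illustration with a made-up label tensor:
#   labels = torch.tensor([0, 0, 1, 2, 2, 2])
#   get_label_prob(labels)  # prints the distribution and returns tensor([0.33, 0.17, 0.50])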


def eval_classification(y_true, y_pred=None, model=None, x=None, batch_size=None, multi_heads=False, 
  cls_head=0, average='weighted', predict_func=None, pred_kwargs=None, verbose=True):
  """Evaluate classification results

  Args:
    y_true: true labels; numpy array or torch.Tensor
    y_pred: if None, then y_pred = model(x)
    model: torch.nn.Module type
    x: input tensor
    batch_size: used for predict(model, x, batch_size)
    multi_heads: if True, the model outputs a list; the classification head is selected by cls_head (the first one by default)
    cls_head: only used when multi_heads is True; specifies which head is used for classification; default 0
    average: used by sklearn.metrics to calculate precision, recall, f1, auc and ap; default: 'weighted'
    predict_func: if not None, use predict_func(model, x, **pred_kwargs) instead of predict()
    pred_kwargs: dictionary of keyword arguments for predict_func

  Returns:
    np.array([acc, precision, recall, f1, adjusted_mutual_info, auc, average_precision]) and the confusion matrix;
    auc and average_precision are -1 for multi-class problems.
  """
  if isinstance(y_true, torch.Tensor): 
    y_true = y_true.cpu().detach().numpy().reshape(-1)
  num_cls = len(np.unique(y_true))
  auc = -1 # dummy value for multi-class classification
  average_precision = -1 # dummy value for multi-class classification
  y_score = None # only used to calculate auc and average_precision for binary classification; will be set later
  if y_pred is None: # Calculate y_pred = model(x) in batches
    if predict_func is None:
      # use predict() defined in this file
      num_heads = 2 if multi_heads else 1 # num_heads >= 2 makes predict() treat the model as multi-output
      y_ = predict(model, x, batch_size, train=False, num_heads=num_heads)
      y_pred = y_[cls_head] if multi_heads else y_
    else:
      # use a customized predict_func with variable keyword arguments
      y_pred = predict_func(model, x, **pred_kwargs)
  if isinstance(y_pred, torch.Tensor):
    # either the input argument was a torch.Tensor or it was just computed from model(x) above
    y_pred = y_pred.cpu().detach().numpy()
  if isinstance(y_pred, np.ndarray) and y_pred.ndim == 2 and y_pred.shape[1] > 1:
    # y_pred is the class score matrix: n_samples * n_classes
    if y_pred.shape[1] == 2: # for binary classification
      y_score = y_pred[:,1] - y_pred[:,0] # y_score is only used for calculating auc and average precision 
    y_pred = y_pred.argmax(axis=-1) # only consider the top-1 prediction
  if num_cls==2 and np.issubdtype(y_pred.dtype, np.floating): # the previous block was not executed
    # For binary classification, argument y_pred can be the scores for belonging to class 1.
    y_score = y_pred # used to calculate auc and average_precision
    y_pred = (y_score > 0).astype('int')
  acc = sklearn.metrics.accuracy_score(y_true, y_pred)
  precision = sklearn.metrics.precision_score(y_true, y_pred, average=average)
  recall = sklearn.metrics.recall_score(y_true, y_pred, average=average)
  f1_score = sklearn.metrics.f1_score(y_true=y_true, y_pred=y_pred, average=average)
  adjusted_mutual_info = sklearn.metrics.adjusted_mutual_info_score(labels_true=y_true, labels_pred=y_pred)
  confusion_mat = sklearn.metrics.confusion_matrix(y_true, y_pred)
  msg = f'acc={acc:.3f}, precision={precision:.3f}, recall={recall:.3f}, f1={f1_score:.3f}, adj_MI={adjusted_mutual_info:.3f}'
  if num_cls == 2:
    # When y_pred is given as an int np.array or tensor, model(x) is not called; 
    # set y_score = y_pred to calculate auc and average precision approximately; 
    # this may not be fully accurate because y_pred (binary labels) is used in place of y_score (which should be scores)
    if y_score is None: 
      y_score = y_pred
    auc = sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_score, average=average)
    average_precision = sklearn.metrics.average_precision_score(y_true=y_true, y_score=y_score, average=average)
    msg = msg + f', auc={auc:.3f}, ap={average_precision:.3f}'
  msg = msg + f', confusion_mat=\n{confusion_mat}'
  if verbose:
    print(msg)
    print('report', sklearn.metrics.classification_report(y_true=y_true, y_pred=y_pred))

  return np.array([acc, precision, recall, f1_score, adjusted_mutual_info, auc, average_precision]), confusion_mat
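# Two hedged usage sketches (all names below are placeholders):
#   1) With precomputed class scores (an (N, C) numpy array or tensor):
#      metrics, cm = eval_classification(y_true, y_pred=y_score)
#   2) Letting the function run the model itself, in batches:
#      metrics, cm = eval_classification(y_true, model=model, x=x, batch_size=128)
# metrics is [acc, precision, recall, f1, adj_MI, auc, ap]; auc/ap stay -1 for multi-class problems.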


def eval_classification_multi_splits(model, xs, ys, batch_size=None, multi_heads=False, cls_head=0, 
  average='weighted', return_result=True, split_names=['Train', 'Validation', 'Test'],
  predict_func=None, pred_kwargs=None, verbose=True):
  """Call eval_classification on multiple data splits, e.g., x_train, x_val, x_test, with a given model

  Args:
    Most arguments are passed to eval_classification
    xs: a list of model inputs, e.g., [x_train, x_val, x_test]
    ys: a list of targets, e.g., [y_train, y_val, y_test]
    return_result: if True, return the results on the non-empty data splits
    split_names: for printing; default: ['Train', 'Validation', 'Test']
  """
  res = []
  if len(xs) != len(split_names):
    split_names = [f'Data split {i}' for i in range(len(xs))]
  for i, (x, y) in enumerate(zip(xs, ys)):
    if len(x) > 0:
      print(split_names[i])
      metric = eval_classification(y_true=y, model=model, x=x, batch_size=batch_size, 
                          multi_heads=multi_heads, cls_head=cls_head, average=average,
                          predict_func=predict_func, pred_kwargs=pred_kwargs, verbose=verbose)
      res.append(metric)
  if return_result:
    return res
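# Example call on three splits (hypothetical tensors):
#   results = eval_classification_multi_splits(model, [x_train, x_val, x_test],
#               [y_train, y_val, y_test], batch_size=128)
#   results[i] is the (metrics, confusion_matrix) pair returned by eval_classification.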


def run_one_epoch_single_loss(model, x, y_true, loss_fn=nn.CrossEntropyLoss(), train=True, optimizer=None, 
  batch_size=None, return_loss=True, epoch=0, print_every=10, verbose=True):
  """Run one epoch, i.e., model(x), split into batches

  Args:
    model: torch.nn.Module
    x: torch.Tensor
    y_true: target torch.Tensor
    loss_fn: loss function
    train: if False, call model.eval() and torch.set_grad_enabled(False) to save time
    optimizer: needed when train is True
    batch_size: if None, batch_size = len(x)
    return_loss: if True, return the epoch loss
    epoch: for printing only
    print_every: print epoch_loss if epoch % print_every == 0
    verbose: if True, print the loss of each batch
  """

  is_grad_enabled = torch.is_grad_enabled()
  if train:
    model.train()
    torch.set_grad_enabled(True)
  else:
    model.eval()
    torch.set_grad_enabled(False)
  loss_history = []
  is_classification = isinstance(y_true.cpu(), torch.LongTensor)
  if is_classification:
    acc_history = []
  if batch_size is None:
    batch_size = len(x)
  for i in range(0, len(x), batch_size):
    y_pred = model(x[i:i+batch_size])
    loss = loss_fn(y_pred, y_true[i:i+batch_size])
    loss_history.append(loss.item())
    if is_classification:
      labels_pred = y_pred.topk(1, -1)[1].squeeze(-1) # only calculate top-1 accuracy
      acc = (labels_pred == y_true[i:i+batch_size]).float().mean().item()
      acc_history.append(acc)
    if verbose:
      msg = 'Epoch{} {}/{}: loss={:.2e}'.format(
        epoch, i//batch_size, (len(x)+batch_size-1)//batch_size, loss.item())
      if is_classification:
        msg = msg + f', acc={acc:.2f}'
      print(msg)
    if train:
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
  torch.set_grad_enabled(is_grad_enabled)

  loss_epoch = np.mean(loss_history)
  if is_classification:
    acc_epoch = np.mean(acc_history)
  if epoch % print_every == 0:  
    msg = 'Epoch{} {}: loss={:.2e}'.format(epoch, 'Train' if train else 'Test', np.mean(loss_history))
    if is_classification:
      msg = msg + f', acc={np.mean(acc_history):.2f}'
    print(msg)
  if return_loss:
    if is_classification:
      return loss_epoch, acc_epoch, loss_history, acc_history
    else:
      return loss_epoch, loss_history
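# A single supervised epoch, sketched with placeholder objects (y_train is assumed to be a LongTensor,
# so the classification branch returns four values):
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   loss_epoch, acc_epoch, _, _ = run_one_epoch_single_loss(
#       model, x_train, y_train, loss_fn=nn.CrossEntropyLoss(), train=True,
#       optimizer=optimizer, batch_size=64, epoch=0)
# For evaluation, pass train=False (no optimizer needed) so gradients are disabled.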


def train_single_loss(model, x_train, y_train, x_val=[], y_val=[], x_test=[], y_test=[], 
    loss_fn=nn.CrossEntropyLoss(), lr=1e-2, weight_decay=1e-4, amsgrad=True, batch_size=None, num_epochs=1, 
    reduce_every=200, eval_every=1, print_every=1, verbose=False, 
    loss_train_his=None, loss_val_his=None, loss_test_his=None, 
    acc_train_his=None, acc_val_his=None, acc_test_his=None, return_best_val=True):
  """Train the model for a number of epochs with a single loss

  Args:
    Most arguments are passed to run_one_epoch_single_loss
    lr, weight_decay, amsgrad are passed to torch.optim.Adam
    reduce_every: call adjust_learning_rate if cur_epoch % reduce_every == 0
    eval_every: call run_one_epoch_single_loss on the validation and test sets if cur_epoch % eval_every == 0
    print_every: print the epoch loss if cur_epoch % print_every == 0
    verbose: if True, print the batch loss
    loss_*_his, acc_*_his: optional lists that are filled in place with per-epoch losses/accuracies
    return_best_val: if True, return the best model on the validation set for a classification task 
  """
  # Use fresh lists when no history lists are passed in (avoids the mutable-default-argument pitfall)
  loss_train_his = [] if loss_train_his is None else loss_train_his
  loss_val_his = [] if loss_val_his is None else loss_val_his
  loss_test_his = [] if loss_test_his is None else loss_test_his
  acc_train_his = [] if acc_train_his is None else acc_train_his
  acc_val_his = [] if acc_val_his is None else acc_val_his
  acc_test_his = [] if acc_test_his is None else acc_test_his

  def eval_one_epoch(x, targets, loss_his, acc_his, epoch, train=False):
    """Inner helper; reuses the parameters of the enclosing scope
    """
    results = run_one_epoch_single_loss(model, x, targets, loss_fn=loss_fn, train=train, optimizer=optimizer, 
      batch_size=batch_size, return_loss=True, epoch=epoch, print_every=print_every, verbose=verbose)
    if is_classification:
      loss_epoch, acc_epoch, loss_history, acc_history = results
    else:
      loss_epoch, loss_history = results
    loss_his.append(loss_epoch)
    if is_classification:
      acc_his.append(acc_epoch)

  is_classification = isinstance(y_train.cpu(), torch.LongTensor)
  best_val_acc = -1 # best_val_acc >= 0 after the first epoch for a classification task
  for i in range(num_epochs):   
    if i == 0:
      optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 
        lr=lr, weight_decay=weight_decay, amsgrad=amsgrad)
    # Should a new torch.optim.Adam instance be created every time the learning rate is adjusted? 
    adjust_learning_rate(optimizer, lr, i, reduce_every=reduce_every)

    eval_one_epoch(x_train, y_train, loss_train_his, acc_train_his, i, train=True)
    if i % eval_every == 0:
      if len(x_val)>0 and len(y_val)>0:
        eval_one_epoch(x_val, y_val, loss_val_his, acc_val_his, i, train=False) # Setting train=False here is crucial!
        if is_classification:
          if acc_val_his[-1] > best_val_acc:
            best_val_acc = acc_val_his[-1]
            best_model = copy.deepcopy(model)
            best_epoch = i
            print('epoch {}, best_val_acc={:.2f}, train_acc={:.2f}'.format(
              best_epoch, best_val_acc, acc_train_his[-1]))
      if len(x_test)>0 and len(y_test)>0:
        eval_one_epoch(x_test, y_test, loss_test_his, acc_test_his, i, train=False) # Setting train=False here is crucial!

  if is_classification:
    if return_best_val and len(x_val)>0 and len(y_val)>0:
      return best_model, best_val_acc, best_epoch
    else:
      return model, acc_train_his[-1], i
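# End-to-end sketch (tensors and sizes are made up for illustration; y_tr/y_va are LongTensors):
#   model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 3))
#   best_model, best_val_acc, best_epoch = train_single_loss(
#       model, x_tr, y_tr, x_val=x_va, y_val=y_va,
#       lr=1e-2, batch_size=64, num_epochs=50)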


def run_one_epoch_multiloss(model, x, targets, heads=[0,1], loss_fns=[nn.CrossEntropyLoss(), nn.MSELoss()], 
  loss_weights=[1,0], other_loss_fns=[], other_loss_weights=[], return_loss=True, batch_size=None, 
  train=True, optimizer=None, epoch=0, print_every=10, verbose=True):
  """Run one epoch for a multi-head model with multiple losses, including losses between outputs and targets 
  (head losses) and regularizers on model parameters (non-head losses).

  Args:
    model: a multi-head model; for example, an autoencoder classifier that returns classification scores 
      (or a regression output) and the decoder output (a reconstruction of the input)
    x: input
    targets: a list of targets associated with the multi-head output specified by argument heads; 
      e.g., for an autoencoder with two heads, targets = [y_labels, x].
      targets need not pair one-to-one with all head outputs; 
      use argument heads to specify which heads are paired with targets.
      Elements of targets can also be None; 
      the length of targets must be compatible with that of loss_weights, loss_fns, and heads
    heads: the indices of the heads paired with targets for calculating losses; 
      if None, set heads = list(range(len(targets)))
    loss_fns: a list of loss functions, one for each corresponding head
    loss_weights: the (non-negative) weights for the above head losses;
      heads, loss_fns, and loss_weights are closely related to each other; handle them carefully
    other_loss_fns: a list of loss functions acting as regularizers on model parameters
    other_loss_weights: the corresponding weights for other_loss_fns
    return_loss: default True; return all losses
    batch_size: default None; split data into batches
    train: default True; if False, call model.eval() and torch.set_grad_enabled(False) to save time
    optimizer: must be given when train is True; default None (not used for evaluation)
    epoch: for printing only
    print_every: print epoch losses if epoch % print_every == 0
    verbose: if True, print the losses of each batch
  """

  is_grad_enabled = torch.is_grad_enabled()
  if train:
    model.train()
    torch.set_grad_enabled(True)
  else:
    model.eval()
    torch.set_grad_enabled(False)
  if batch_size is None:
    batch_size = len(x)

  if len(targets) < len(loss_weights):
    # Some losses do not require targets (they use 'implicit' targets in the objective);
    # pad targets with None so it lines up with loss_weights
    targets = targets + [None]*(len(loss_weights) - len(targets))
  is_classification = [] # indices of targets that correspond to classification tasks
  has_unequal_size = [] # indices of targets whose length differs from that of the input
  is_none = [] # indices of targets that are None
  for j, y_true in enumerate(targets):
    if y_true is not None:
      if len(y_true) == len(x):
        if isinstance(y_true.cpu(), torch.LongTensor):
          # if targets[j] is a LongTensor, treat it as a classification task
          is_classification.append(j)
      else:
        has_unequal_size.append(j)
    else:
      is_none.append(j)
  loss_history = []
  if len(is_classification) > 0:
    acc_history = []

  if heads is None: # If heads is not given, assume targets are paired with the model outputs in order
    heads = list(range(len(targets)))
  for i in range(0, len(x), batch_size):
    y_pred = model(x[i:i+batch_size])
    loss_batch = []
    for j, w in enumerate(loss_weights):
      if w>0: # only compute losses with a positive weight
        if j in is_none:
          loss_j = loss_fns[j](y_pred[heads[j]]) * w
        elif j in has_unequal_size:
          loss_j = loss_fns[j](y_pred[heads[j]], targets[j]) * w # targets[j] is the same for all batches
        else:
          loss_j = loss_fns[j](y_pred[heads[j]], targets[j][i:i+batch_size]) * w
        loss_batch.append(loss_j)
    for j, w in enumerate(other_loss_weights):
      if w>0:
        # The implicit 'target' is encoded in the loss function itself
        # todo: in addition to argument model, make loss_fns handle other 'dynamic' arguments as well
        loss_j = other_loss_fns[j](model) * w 
        loss_batch.append(loss_j)
    loss = sum(loss_batch)
    loss_batch = [v.item() for v in loss_batch]
    loss_history.append(loss_batch)
    # Calculate accuracy
    if len(is_classification) > 0:
      acc_batch = []
      for k, j in enumerate(is_classification):
        labels_pred = y_pred[heads[j]].topk(1, -1)[1].squeeze(-1)
        acc = (labels_pred == targets[j][i:i+batch_size]).float().mean().item()
        acc_batch.append(acc)
      acc_history.append(acc_batch)
    if verbose:
      msg = 'Epoch{} {}/{}: loss:{}'.format(epoch, i//batch_size, (len(x)+batch_size-1)//batch_size, 
        ', '.join(map(lambda v: f'{v:.2e}', loss_batch)))
      if len(is_classification) > 0:
        msg = msg + ', acc={}'.format(', '.join(map(lambda v: f'{v:.2f}', acc_batch)))
      print(msg)
    if train:
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
  torch.set_grad_enabled(is_grad_enabled)

  loss_epoch = np.mean(loss_history, axis=0)
  if len(is_classification) > 0:
    acc_epoch = np.mean(acc_history, axis=0)
  if epoch % print_every == 0:
    msg = 'Epoch{} {}: loss:{}'.format(epoch, 'Train' if train else 'Test', 
      ', '.join(map(lambda v: f'{v:.2e}', loss_epoch)))
    if len(is_classification) > 0:
      msg = msg + ', acc={}'.format(', '.join(map(lambda v: f'{v:.2f}', acc_epoch)))
    print(msg)

  if return_loss:
    if len(is_classification) > 0:
      return loss_epoch, acc_epoch, loss_history, acc_history
    else:
      return loss_epoch, loss_history
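# Sketch for an autoencoder-classifier with two heads (classification scores, reconstruction);
# every name below is a placeholder:
#   loss_epoch, acc_epoch, _, _ = run_one_epoch_multiloss(
#       ae_model, x, targets=[y_labels, x], heads=[0, 1],
#       loss_fns=[nn.CrossEntropyLoss(), nn.MSELoss()], loss_weights=[1, 0.1],
#       train=True, optimizer=optimizer, batch_size=64, epoch=0)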


def train_multiloss(model, x_train, y_train, x_val=[], y_val=[], x_test=[], y_test=[], heads=[0, 1], 
  loss_fns=[nn.CrossEntropyLoss(), nn.MSELoss()], loss_weights=[1,0], other_loss_fns=[], other_loss_weights=[], 
  lr=1e-2, weight_decay=1e-4, batch_size=None, num_epochs=1, reduce_every=100, eval_every=1, print_every=1,
  loss_train_his=None, loss_val_his=None, loss_test_his=None, acc_train_his=None, acc_val_his=None, acc_test_his=None, 
  return_best_val=True, amsgrad=True, verbose=False):
  """Train a multi-head model for a number of epochs.
  Most of the parameters are passed to run_one_epoch_multiloss

  Args:
    lr, weight_decay, amsgrad are passed to torch.optim.Adam
    reduce_every: call adjust_learning_rate if i % reduce_every == 0; i is the current epoch
    eval_every: call run_one_epoch_multiloss on the validation and test sets if i % eval_every == 0
    return_best_val: for a classification task, if a validation set is available, return the best model on the validation set
    print_every: print epoch losses if i % print_every == 0
    verbose: if True, print batch losses
    loss_*_his, acc_*_his: optional lists that are filled in place with per-epoch losses/accuracies
  """
  # Use fresh lists when no history lists are passed in (avoids the mutable-default-argument pitfall)
  loss_train_his = [] if loss_train_his is None else loss_train_his
  loss_val_his = [] if loss_val_his is None else loss_val_his
  loss_test_his = [] if loss_test_his is None else loss_test_his
  acc_train_his = [] if acc_train_his is None else acc_train_his
  acc_val_his = [] if acc_val_his is None else acc_val_his
  acc_test_his = [] if acc_test_his is None else acc_test_his

  def eval_one_epoch(x, targets, loss_his, acc_his, epoch, train=False):
    """Inner helper; reuses some parameters from the scope of the enclosing function
    """
    results = run_one_epoch_multiloss(model, x, targets=targets, heads=heads, loss_fns=loss_fns, 
                loss_weights=loss_weights, other_loss_fns=other_loss_fns, other_loss_weights=other_loss_weights, 
                return_loss=True, batch_size=batch_size, train=train, optimizer=optimizer, epoch=epoch, 
                print_every=print_every, verbose=verbose)
    if is_classification:
      loss_epoch, acc_epoch, loss_history, acc_history = results
    else:
      loss_epoch, loss_history = results
    # loss_train_his += loss_history
    # acc_train_his += acc_history
    loss_his.append(loss_epoch)
    if is_classification:
      acc_his.append(acc_epoch)

  cls_targets = []
  for i, y_true in enumerate(y_train):
    if isinstance(y_true.cpu(), torch.LongTensor):
      cls_targets.append(i)
  is_classification = len(cls_targets) > 0
  best_val_acc = -1 # After the first iteration, best_val_acc >= 0 

  for i in range(num_epochs):   
    if i == 0: # Create the optimizer once; should a new Adam instance be created after each learning-rate adjustment?
      optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, 
        weight_decay=weight_decay, amsgrad=amsgrad)
    adjust_learning_rate(optimizer, lr, i, reduce_every=reduce_every)

    eval_one_epoch(x_train, y_train, loss_train_his, acc_train_his, i, train=True)
    if i % eval_every == 0:
      if len(x_val)>0 and len(y_val)>0:
        # Must set train=False; otherwise the validation data would leak into training
        eval_one_epoch(x_val, y_val, loss_val_his, acc_val_his, i, train=False) 
        if is_classification:
          cur_val_acc = np.mean(acc_val_his[-1])
          if cur_val_acc > best_val_acc: # Use the mean accuracy over all classification tasks (in most cases just one)
            best_val_acc = cur_val_acc
            best_model = copy.deepcopy(model)
            best_epoch = i
            print('epoch {}, best_val_acc={:.2f}, train_acc={:.2f}'.format(
              best_epoch, best_val_acc, np.mean(acc_train_his[-1])))
      if len(x_test)>0 and len(y_test)>0:
        eval_one_epoch(x_test, y_test, loss_test_his, acc_test_his, i, train=False)

  if is_classification:
    if return_best_val and len(x_val)>0 and len(y_val)>0:
      return best_model, best_val_acc, best_epoch
    else:
      return model, np.mean(acc_train_his[-1]), i
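# Hedged usage sketch mirroring train_single_loss, for a two-head autoencoder-classifier
# (all objects below are placeholders; note that y_train and y_val are lists of targets here):
#   best_model, best_val_acc, best_epoch = train_multiloss(
#       ae_model, x_tr, [y_tr, x_tr], x_val=x_va, y_val=[y_va, x_va],
#       heads=[0, 1], loss_fns=[nn.CrossEntropyLoss(), nn.MSELoss()], loss_weights=[1, 0.1],
#       lr=1e-2, batch_size=64, num_epochs=50)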