a b/src/hybrid/hybrid_fit.py
1
2
import torch
3
from src.utils import train_metric
4
5
6
7
def hybrid_fit(epochs, model, hybrid_train_loader, hybrid_val_loader, icdtype,
               opt_fn, loss_fn, learning_rate, device, metric_fn=None):
    """Train a two-input model on one label set and report per-epoch metrics.

    Each epoch runs one optimisation pass over ``hybrid_train_loader`` and one
    gradient-free evaluation pass over ``hybrid_val_loader``, then prints the
    per-batch mean loss, accuracy, Hamming loss and F1 score of both passes.

    Args:
        epochs: Number of epochs to run (epochs are numbered from 1).
        model: Module called as ``model(rnn_x, cnn_x)``. NOTE(review): inputs
            are moved to ``device`` here but the model is not -- presumably
            the caller has already placed it there; confirm.
        hybrid_train_loader: Iterable of ``(rnn_x, cnn_x, y_dict)`` batches;
            ``y_dict[icdtype]`` holds the targets used for this run.
        hybrid_val_loader: Same batch layout as ``hybrid_train_loader``.
        icdtype: Key selecting the targets from ``y_dict``; it is also
            concatenated into the printed header, so it must be a ``str``.
        opt_fn: Optimizer factory, called as
            ``opt_fn(model.parameters(), lr=learning_rate)``.
        loss_fn: Called as ``loss_fn(preds, y)``; must return a scalar tensor.
        learning_rate: Forwarded to ``opt_fn`` as ``lr``.
        device: Passed to ``Tensor.to`` for the inputs and targets.
        metric_fn: Called as ``metric_fn(preds, y)``; must return
            ``(accuracy, hammingloss, f1score)``, each convertible with
            ``float``. Defaults to ``src.utils.train_metric``.

    Returns:
        A list with one dict per epoch containing ``epoch`` plus the
        ``train_*`` and ``val_*`` means (``loss``, ``accuracy``,
        ``hammingloss``, ``f1score``) that were printed.
    """
    if metric_fn is None:
        metric_fn = train_metric
    metric_names = ('loss', 'accuracy', 'hammingloss', 'f1score')

    optimizer = opt_fn(model.parameters(), lr=learning_rate)
    print('-' * 10 + icdtype + '-' * 10)

    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        train_stats = _run_hybrid_epoch(model, hybrid_train_loader, icdtype,
                                        loss_fn, metric_fn, device, optimizer)

        model.eval()
        with torch.no_grad():
            val_stats = _run_hybrid_epoch(model, hybrid_val_loader, icdtype,
                                          loss_fn, metric_fn, device)

        print("\n")
        print('-' * 100)
        print('Epoch = {}/{}:\n'
              ' train_loss = {:.4f}, train_accuracy = {:.4f}, '
              'train_hammingloss = {:.4f}, train_f1score = {:.4f}\n'
              ' val_loss = {:.4f}, val_accuracy = {:.4f}, '
              'val_hammingloss = {:.4f}, val_f1score = {:.4f}'
              .format(epoch, epochs, *train_stats, *val_stats))
        print('-' * 100)
        print("\n")

        record = {'epoch': epoch}
        record.update(('train_' + name, value) for name, value in zip(metric_names, train_stats))
        record.update(('val_' + name, value) for name, value in zip(metric_names, val_stats))
        history.append(record)

    return history


def _run_hybrid_epoch(model, loader, icdtype, loss_fn, metric_fn, device, optimizer=None):
    """Run one pass over ``loader``; take an optimizer step per batch if one is given.

    Returns the per-batch mean of ``(loss, accuracy, hammingloss, f1score)``
    (every batch weighted equally). An empty loader yields all zeros rather
    than raising ``ZeroDivisionError``.
    """
    totals = [0.0, 0.0, 0.0, 0.0]
    n_batches = 0
    for rnn_x, cnn_x, y_dict in loader:
        rnn_x = rnn_x.to(device)
        cnn_x = cnn_x.to(device)
        y = y_dict[icdtype].to(device)

        preds = model(rnn_x, cnn_x)
        loss = loss_fn(preds, y)
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Metrics never need gradients; detaching keeps them off the autograd
        # graph and lets numpy-based metric code convert the tensor.
        accuracy, hammingloss, f1score = metric_fn(preds.detach(), y)

        totals[0] += loss.item()
        totals[1] += float(accuracy)
        totals[2] += float(hammingloss)
        totals[3] += float(f1score)
        # Count batches actually seen instead of trusting len(loader), which
        # may be unavailable or inaccurate for iterable-style loaders.
        n_batches += 1

    if n_batches == 0:
        return (0.0, 0.0, 0.0, 0.0)
    return tuple(total / n_batches for total in totals)
84