Diff of /src/fit.py [000000] .. [71ad2f]

Switch to unified view

a b/src/fit.py
1
import torch
2
from src.utils import train_metric
3
4
def fit(epochs,model,train_loader,val_loader, icdtype, opt_fn,loss_fn, learning_rate, device):
5
  
6
  optimizer = opt_fn(model.parameters(), lr=learning_rate)
7
  print('-'*10 + icdtype + '-'*10)
8
  for epoch in range(1,epochs+1):
9
10
    model.train()
11
12
    train_epoch_loss=0
13
    train_epoch_accuracy=0
14
    train_epoch_hammingloss=0
15
    train_epoch_f1score=0
16
17
    val_epoch_loss=0
18
    val_epoch_accuracy=0
19
    val_epoch_hammingloss=0
20
    val_epoch_f1score=0
21
22
    
23
    for x, y_dict in train_loader:
24
25
      x = x.to(device)
26
27
      y = y_dict[icdtype]
28
      y = y.to(device)
29
30
      
31
      preds=model(x)
32
33
      optimizer.zero_grad()
34
      loss=loss_fn(preds,y)
35
      loss.backward()
36
      optimizer.step()
37
      
38
      accuracy, hammingloss, f1score = train_metric(preds,y)
39
40
      train_epoch_loss+=loss.item()
41
      train_epoch_accuracy+=accuracy.item()
42
      train_epoch_hammingloss+=hammingloss
43
      train_epoch_f1score+=f1score
44
    
45
    model.eval()
46
    with torch.no_grad():
47
      for x,y_dict in val_loader:
48
        
49
        x=x.to(device)
50
51
        y = y_dict[icdtype]
52
        y = y.to(device)
53
54
        
55
        preds=model(x)
56
57
        loss=loss_fn(preds,y)
58
        accuracy, hammingloss, f1score  = train_metric(preds,y)
59
60
        val_epoch_loss+=loss.item()
61
        val_epoch_accuracy+=accuracy.item()
62
        val_epoch_hammingloss+=hammingloss
63
        val_epoch_f1score+=f1score
64
65
    
66
  
67
    print("\n")
68
    print('-'*100)
69
    print('Epoch = {}/{}:\n train_loss = {:.4f}, train_accuracy = {:.4f}, train_hammingloss = {:.4f}, train_f1score = {:.4f}\n val_loss = {:.4f}, val_accuracy = {:.4f}, val_hammmingloss = {:.4f}, val_f1score = {:.4f}'.format(epoch
70
                                                              ,epochs
71
                                                              ,train_epoch_loss/len(train_loader)
72
                                                              ,train_epoch_accuracy/len(train_loader)
73
                                                              ,train_epoch_hammingloss/len(train_loader)
74
                                                              ,train_epoch_f1score/len(train_loader)
75
                                                              ,val_epoch_loss/len(val_loader)
76
                                                              ,val_epoch_accuracy/len(val_loader)
77
                                                              ,val_epoch_hammingloss/len(val_loader)
78
                                                              ,val_epoch_f1score/len(val_loader)
79
                                                              ))
80
    print('-'*100)
81
    print("\n")
82