Diff of /src/hybrid/hybrid_fit.py [000000] .. [71ad2f]

Switch to side-by-side view

--- a
+++ b/src/hybrid/hybrid_fit.py
@@ -0,0 +1,84 @@
+
+import torch
+from src.utils import train_metric
+
+
+
+def hybrid_fit(epochs, model, hybrid_train_loader, hybrid_val_loader, icdtype, opt_fn,loss_fn, learning_rate, device):
+  optimizer = opt_fn(model.parameters(), lr=learning_rate)
+  print('-'*10 + icdtype + '-'*10)
+  for epoch in range(1,epochs+1):
+
+    model.train()
+
+    train_epoch_loss=0
+    train_epoch_accuracy=0
+    train_epoch_hammingloss=0
+    train_epoch_f1score=0
+
+    val_epoch_loss=0
+    val_epoch_accuracy=0
+    val_epoch_hammingloss=0
+    val_epoch_f1score=0
+    
+    for rnn_x, cnn_x, y_dict in hybrid_train_loader:
+
+      rnn_x = rnn_x.to(device)
+      cnn_x = cnn_x.to(device)
+
+      y = y_dict[icdtype]
+      y = y.to(device)
+      
+
+      
+      preds=model(rnn_x, cnn_x)
+
+      optimizer.zero_grad()
+      loss=loss_fn(preds,y)
+      loss.backward()
+      optimizer.step()
+      
+      accuracy, hammingloss, f1score  = train_metric(preds,y)
+
+      train_epoch_loss+=loss.item()
+      train_epoch_accuracy+=accuracy.item()
+      train_epoch_hammingloss+=hammingloss
+      train_epoch_f1score+=f1score
+    
+    model.eval()
+    with torch.no_grad():
+      for rnn_x, cnn_x, y_dict in hybrid_val_loader:
+        
+        rnn_x = rnn_x.to(device)
+        cnn_x = cnn_x.to(device)
+
+        y = y_dict[icdtype]
+        y = y.to(device)
+        
+        preds=model(rnn_x, cnn_x)
+
+        loss=loss_fn(preds,y)
+        accuracy, hammingloss, f1score  = train_metric(preds,y)
+        val_epoch_loss+=loss.item()
+        val_epoch_accuracy+=accuracy.item()
+        val_epoch_hammingloss+=hammingloss
+        val_epoch_f1score+=f1score
+
+    
+    
+    print("\n")
+    print('-'*100)
+    print('Epoch = {}/{}:\n train_loss = {:.4f}, train_accuracy = {:.4f}, train_hammingloss = {:.4f}, train_f1score = {:.4f}\n val_loss = {:.4f}, val_accuracy = {:.4f}, val_hammmingloss = {:.4f}, val_f1score = {:.4f}'.format(epoch
+                                                              ,epochs
+                                                              ,train_epoch_loss/len(hybrid_train_loader)
+                                                              ,train_epoch_accuracy/len(hybrid_train_loader)
+                                                              ,train_epoch_hammingloss/len(hybrid_train_loader)
+                                                              ,train_epoch_f1score/len(hybrid_train_loader)
+                                                              ,val_epoch_loss/len(hybrid_val_loader)
+                                                              ,val_epoch_accuracy/len(hybrid_val_loader)
+                                                              ,val_epoch_hammingloss/len(hybrid_val_loader)
+                                                              ,val_epoch_f1score/len(hybrid_val_loader)
+                                                              ))
+    print('-'*100)
+    print("\n")
+    
\ No newline at end of file