Diff of /trainer.py [000000] .. [70e190]

Switch to side-by-side view

--- a
+++ b/trainer.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+from tqdm import tqdm
+from utils import write_csv
+import sys
+
+def trainer(config, train_loader, optimizer, model, ce, dice, iou, hd):        
+    model.train()
+    steps = len(train_loader)
+    pbar = tqdm(total=steps)
+    
+    total_ce_loss, total_dice_score, total_dice_loss, \
+    total_iou_score, total_iou_loss, total_loss, total_hausdorff = 0.0,0.0,0.0,0.0,0.0,0.0,0.0
+    
+    for iter, (input, target) in tqdm(enumerate(train_loader)):
+        sys.stdout.write(f"\riter: {iter+1}/{steps}")
+        sys.stdout.flush()
+        
+        input = input.unsqueeze(1).cuda()
+        target = target.cuda()
+        logits, _, _, _ = model(input)
+        
+        ce_loss = ce(logits, target)
+        dice_score, dice_loss, class_dice_score, class_dice_loss = dice(logits, target)
+        iou_score, class_iou = iou(logits, target)
+        hausdorff = hd(logits, target)
+        loss = dice_loss*0.4 + (1 - iou_score)*0.6
+        
+        total_ce_loss += ce_loss.item()
+        total_dice_score += dice_score.item()
+        total_dice_loss += dice_loss.item()
+        total_iou_score += iou_score.item()
+        total_iou_loss += 1.0-iou_score.item()
+        total_hausdorff += hausdorff
+        total_loss += loss.item()
+            
+        write_csv(f'outputs/{config.name}/iter_log.csv', [
+            ce_loss.item(),
+            dice_score.item(),
+            dice_loss.item(),
+            iou_score.item(),
+            1.0-iou_score.item(),
+            hausdorff,
+            loss.item()
+        ])
+        
+        write_csv(f'outputs/{config.name}/ds_class_iter.csv', class_dice_score)
+        write_csv(f'outputs/{config.name}/dl_loss_iter.csv', class_dice_loss)
+        write_csv(f'outputs/{config.name}/iou_class_iter.csv', class_iou)
+                                        
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        pbar.update(1)
+    
+    pbar.close()           
+    return OrderedDict([
+        ('ce_loss', total_ce_loss / steps),
+        ('dice_score', total_dice_score / steps),
+        ('dice_loss', total_dice_loss / steps),
+        ('iou_score', total_iou_score / steps),
+        ('iou_loss', total_iou_loss / steps),
+        ('hausdorff', total_hausdorff / steps),
+        ('loss', total_loss / steps)
+    ])
+    
+    
+def validate(config, val_loader, model, ce, dice, iou, hd):
+    model.eval()
+    steps = len(val_loader)
+    
+    total_ce_loss, total_dice_score, total_dice_loss, \
+    total_iou_score, total_iou_loss, total_loss, total_hausdorff = 0.0,0.0,0.0,0.0,0.0,0.0,0.0
+    
+    with torch.no_grad():
+        for input, target in val_loader:
+            input = input.unsqueeze(1).cuda()
+            target = target.cuda()
+            logits, _, _, _ = model(input)
+        
+            ce_loss = ce(logits, target)
+            dice_score, dice_loss, _, _ = dice(logits, target)
+            iou_score, _ = iou(logits, target)
+            hausdorff = hd(logits, target)
+            loss = dice_loss*0.4 + (1 - iou_score)*0.6
+            
+            total_ce_loss += ce_loss.item()
+            total_dice_score += dice_score.item()
+            total_dice_loss += dice_loss.item()
+            total_iou_score += iou_score.item()
+            total_iou_loss += 1.0-iou_score.item()
+            total_hausdorff += hausdorff
+            total_loss += loss.item()
+            
+    return OrderedDict([
+        ('ce_loss', total_ce_loss / steps),
+        ('dice_score', total_dice_score / steps),
+        ('dice_loss', total_dice_loss / steps),
+        ('iou_score', total_iou_score / steps),
+        ('iou_loss', total_iou_loss / steps),
+        ('hausdorff', total_hausdorff / steps),
+        ('loss', total_loss / steps)
+    ])
\ No newline at end of file