Diff of /train.py [000000] .. [ef4563]

Switch to side-by-side view

--- a
+++ b/train.py
@@ -0,0 +1,74 @@
+import math
+import torch
+from config import (
+    TRAINING_EPOCH, NUM_CLASSES, IN_CHANNELS, BCE_WEIGHTS, BACKGROUND_AS_CLASS, TRAIN_CUDA
+)
+from torch.nn import CrossEntropyLoss
+from dataset import get_train_val_test_Dataloaders
+from torch.optim import Adam
+from torch.utils.tensorboard import SummaryWriter
+from unet3d import UNet3D
+from transforms import (train_transform, train_transform_cuda,
+                        val_transform, val_transform_cuda)
+
# When the background is treated as its own class, the network needs one
# extra output channel for it.
if BACKGROUND_AS_CLASS:
    NUM_CLASSES += 1

# TensorBoard writer; event files go to ./runs
writer = SummaryWriter("runs")

model = UNet3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)
train_transforms = train_transform
val_transforms = val_transform

# Move the model (and pick GPU-aware transforms) only when CUDA was both
# requested and is actually available; otherwise warn and stay on CPU.
use_cuda = torch.cuda.is_available() and TRAIN_CUDA
if use_cuda:
    model = model.cuda()
    train_transforms = train_transform_cuda
    val_transforms = val_transform_cuda
elif TRAIN_CUDA:
    print('cuda not available! Training initialized on cpu ...')

train_dataloader, val_dataloader, _ = get_train_val_test_Dataloaders(
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    test_transforms=val_transforms,
)

# BUG FIX: the class-weight tensor must live on the same device as the
# model's outputs; leaving it on the CPU raises a device-mismatch error
# the first time the loss is computed on a CUDA run.
class_weights = torch.Tensor(BCE_WEIGHTS)
if use_cuda:
    class_weights = class_weights.cuda()
criterion = CrossEntropyLoss(weight=class_weights)
optimizer = Adam(params=model.parameters())

# Best validation loss seen so far; start at +inf so the first epoch
# always checkpoints.
min_valid_loss = math.inf
+
import os

for epoch in range(TRAINING_EPOCH):

    # ---- training phase ----
    train_loss = 0.0
    model.train()
    for data in train_dataloader:
        image, ground_truth = data['image'], data['label']
        optimizer.zero_grad()
        target = model(image)
        loss = criterion(target, ground_truth)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # ---- validation phase ----
    valid_loss = 0.0
    model.eval()
    # Disable autograd during validation: no gradients are needed, and
    # this avoids building computation graphs for every batch.
    with torch.no_grad():
        for data in val_dataloader:
            image, ground_truth = data['image'], data['label']
            target = model(image)
            loss = criterion(target, ground_truth)
            # BUG FIX: accumulate (+=) instead of overwrite (=); the
            # original kept only the last batch's loss, so the mean below
            # was wrong whenever the validation set had more than one batch.
            valid_loss += loss.item()

    writer.add_scalar("Loss/Train", train_loss / len(train_dataloader), epoch)
    writer.add_scalar("Loss/Validation", valid_loss / len(val_dataloader), epoch)

    print(f'Epoch {epoch+1} \t\t Training Loss: {train_loss / len(train_dataloader)} \t\t Validation Loss: {valid_loss / len(val_dataloader)}')

    # Checkpoint whenever the (summed) validation loss improves.
    if min_valid_loss > valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Ensure the target directory exists; torch.save does not create it.
        os.makedirs('checkpoints', exist_ok=True)
        # Saving State Dict
        torch.save(model.state_dict(), f'checkpoints/epoch{epoch}_valLoss{min_valid_loss}.pth')

writer.flush()
writer.close()
+