Diff of /train.py [000000] .. [c621c3]

--- a
+++ b/train.py
@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+
+from tqdm import tqdm
+from UNET import UNET
+from utils import (
+    load_checkpoint,
+    save_checkpoint,
+    get_loaders,
+    check_accuracy,
+    save_predictions_as_images
+)
+
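+# Training configuration: device, hyperparameters, and dataset paths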
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+learning_rate = 1e-04
+batch_size = 16
+num_epochs = 3
+num_workers = 2
+image_height = 160
+image_width = 240
+pin_memory = True
+load_model = False
+train_img_dir = 'dataset/image/'
+train_mask_dir = 'dataset/mask/'
+test_img_dir = 'dataset/test_image/'
+test_mask_dir = 'dataset/test_mask/'
+
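+# Run one training epoch: forward pass under autocast, backward pass through the GradScaler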
+def train_fn(loader, model, optimizer, loss_fn, scaler):
+    loop = tqdm(loader)
+
+    for batch_idx, (data, targets) in enumerate(loop):
+        # Move the batch to the target device; masks get a channel dim for BCEWithLogitsLoss
+        data = data.to(device=device)
+        targets = targets.float().unsqueeze(1).to(device=device)
+
+        # Forward
+        with torch.cuda.amp.autocast():
+            predictions = model(data)
+            loss = loss_fn(predictions, targets)
+
+        # Backward
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+        # Update the tqdm progress bar with the current batch loss
+        loop.set_postfix(loss=loss.item())
+
+def main():
+    train_transform = A.Compose(
+        [
+            A.Resize(height=image_height, width=image_width),
+            A.Rotate(limit=35, p=1.0),
+            A.HorizontalFlip(p=0.5),
+            A.VerticalFlip(p=0.1),
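+            # With mean=0 and std=1, Normalize only rescales pixels from [0, 255] to [0, 1]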
+            A.Normalize(
+                mean=[0.0,0.0,0.0],
+                std=[1.0,1.0,1.0],
+                max_pixel_value=255.0
+            ),
+            ToTensorV2()
+        ]
+    )
+
+    test_transforms = A.Compose(
+        [
+            A.Resize(height=image_height, width=image_width),
+            A.Normalize(
+                mean=[0.0,0.0,0.0],
+                std=[1.0,1.0,1.0],
+                max_pixel_value=255.0
+            ),
+            ToTensorV2()
+        ]
+    )
+
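+    # Binary segmentation: 3-channel RGB input, 1-channel mask logits as output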
+    model = UNET(in_channels=3, out_channels=1).to(device)
+    loss_fn = nn.BCEWithLogitsLoss()  # Applies the sigmoid internally, so the model outputs raw logits
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+    train_loader, test_loader = get_loaders(
+        train_img_dir,
+        train_mask_dir,
+        test_img_dir,
+        test_mask_dir,
+        batch_size,
+        train_transform,
+        test_transforms,
+        num_workers,
+        pin_memory
+    )
+
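+    # Optionally resume from a previously saved checkpoint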
+    if load_model:
+        load_checkpoint(torch.load('my_checkpoint.pth.tar', map_location=device), model)
+
+    check_accuracy(test_loader, model)
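+    # GradScaler handles loss scaling for mixed-precision training (disabled automatically on CPU)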
+    scaler = torch.cuda.amp.GradScaler()
+    for epoch in range(num_epochs):
+        train_fn(train_loader, model, optimizer, loss_fn, scaler)
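+        # Save model and optimizer state after every epoch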
+        checkpoint = {
+            'state_dict': model.state_dict(),
+            'optimizer': optimizer.state_dict()
+        }
+        save_checkpoint(checkpoint)
+        check_accuracy(test_loader, model)
+        save_predictions_as_images(test_loader, model, folder='saved_images/')
+
+
+if __name__ == "__main__":
+    main()