Diff of /train.py [000000] .. [c621c3]

Switch to unified view

a b/train.py
1
import torch
2
import torch.nn as nn
3
import torch.optim as optim
4
5
import albumentations as A
6
from albumentations.pytorch import ToTensorV2
7
8
from tqdm import tqdm
9
from UNET import UNET
10
from utils import (
11
    load_checkpoint,
12
    save_checkpoint,
13
    get_loaders,
14
    check_accuracy,
15
    save_predictions_as_images
16
)
17
18
# --- Optimisation hyperparameters -------------------------------------------
learning_rate: float = 1e-4  # step size handed to the Adam optimizer in main()
batch_size: int = 16         # forwarded to get_loaders()
num_epochs: int = 3          # number of passes main() makes over the train loader

# --- Data loading -------------------------------------------------------------
num_workers: int = 2         # forwarded to get_loaders()
pin_memory: bool = True      # forwarded to get_loaders()

# --- Image geometry: every sample is resized to (image_height, image_width) ---
image_height: int = 160
image_width: int = 240

# --- Checkpointing ------------------------------------------------------------
# When True, main() restores 'my_checkpoint.pth.tar' before training starts.
load_model: bool = False

# --- Dataset locations (image / mask directories for each split) --------------
train_img_dir: str = 'dataset/image/'
train_mask_dir: str = 'dataset/mask/'
test_img_dir: str = 'dataset/test_image/'
test_mask_dir: str = 'dataset/test_mask/'
30
31
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training pass over ``loader``.

    For every batch the forward pass and loss are computed inside
    ``torch.cuda.amp.autocast``; the loss is then back-propagated through
    ``scaler`` (scale -> step -> update) and the current loss value is shown
    on the tqdm progress bar.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches.
        model: Callable mapping ``data`` to predictions. If it exposes
            ``parameters()``, each batch is moved to the device of its first
            parameter before the forward pass.
        optimizer: Optimizer whose gradients are cleared and stepped per batch.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning the loss.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.

    Returns:
        None.
    """
    # Keep the batch on the same device as the weights. For a model that lives
    # on the CPU this is a no-op; a model without parameters is left alone.
    params = model.parameters() if hasattr(model, "parameters") else ()
    first_param = next(iter(params), None)
    device = first_param.device if first_param is not None else None

    loop = tqdm(loader)

    for data, targets in loop:
        # Cast targets to float and insert a dimension at index 1
        # (presumably the channel axis expected by the loss -- TODO confirm).
        targets = targets.float().unsqueeze(1)
        if device is not None:
            data = data.to(device)
            targets = targets.to(device)

        # Forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # Backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update the tqdm progress bar with the latest loss
        loop.set_postfix(loss=loss.item())
50
51
def main():
    """Train the UNET and evaluate it after every epoch.

    Steps, in order:
      1. Build the train and test augmentation pipelines.
      2. Create the model, the loss and the Adam optimizer.
      3. Obtain the train/test loaders from ``get_loaders``.
      4. If ``load_model`` is set, restore 'my_checkpoint.pth.tar'.
      5. Check accuracy once on the test loader, then for each of
         ``num_epochs`` epochs: train, save a checkpoint, check accuracy and
         write predictions to 'saved_images/'.
    """

    def make_resize():
        # Both pipelines resize to the same configured geometry.
        return A.Resize(height=image_height, width=image_width)

    def make_normalize():
        # Zero mean / unit std with max_pixel_value=255; a fresh instance is
        # built per pipeline so the two Compose objects share no transform.
        return A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        )

    # Training pipeline: resize, random rotation/flips, normalize, to tensor.
    train_transform = A.Compose(
        [
            make_resize(),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            make_normalize(),
            ToTensorV2(),
        ]
    )

    # Test pipeline: same resize/normalize, no random augmentation.
    test_transforms = A.Compose(
        [
            make_resize(),
            make_normalize(),
            ToTensorV2(),
        ]
    )

    model = UNET(in_channels=3, out_channels=1)
    # BCEWithLogitsLoss is used because no sigmoid is applied to the model output.
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_loader, test_loader = get_loaders(
        train_img_dir,
        train_mask_dir,
        test_img_dir,
        test_mask_dir,
        batch_size,
        train_transform,
        test_transforms,
        num_workers,
        pin_memory,
    )

    if load_model:
        load_checkpoint(torch.load('my_checkpoint.pth.tar'), model)

    # Accuracy check before any training in this run.
    check_accuracy(test_loader, model)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(num_epochs):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # Persist both model and optimizer state after every epoch.
        save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
        )

        check_accuracy(test_loader, model)
        save_predictions_as_images(test_loader, model, folder='saved_images/')
109
110
111
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()