Diff of /train.py [000000] .. [6d4adb]

Switch to unified view

a b/train.py
1
import torch
2
from tqdm import tqdm
3
import torch.nn  as nn
4
import torch.optim as optim
5
from Unet.model import UNET
6
from imutils import paths
7
from Dataset.datasetloader import MRIDataset
8
from Utils import utils
9
from torch.utils.data import DataLoader
10
from Preprocessing.preprocessing import Preprocessor
11
import argparse
12
# import wandb
13
14
15
# Hyperparam etc.
16
LEARNING_RATE = 1e-3
17
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
BATCH_SIZE = 8
19
NUM_EPOCHS = 100
20
LOAD_MODEL = False
21
NUM_WORKERS = 0
22
23
24
# wandb.init(project='Unet', entity='grubi')
25
# config = wandb.config
26
# config.learning_rate = LEARNING_RATE
27
28
29
def train_fn(loader, model, optimizer, loss_fn, scaler):
30
    loop = tqdm(loader)
31
32
33
    for batch_idx, (data, targets) in enumerate(loop):
34
        data = data.to(DEVICE)
35
        targets = targets.to(DEVICE)
36
37
        # forward
38
        with torch.cuda.amp.autocast():
39
            predictions = model(data)
40
            loss = loss_fn(predictions, targets)
41
42
43
        # backward
44
        optimizer.zero_grad()
45
        scaler.scale(loss).backward()
46
        scaler.step(optimizer)
47
        scaler.update()
48
49
        # update tqdm loop
50
        loop.set_postfix(loss=loss.item())
51
        wandb.log({"loss": loss})
52
53
54
55
56
ap = argparse.ArgumentParser()
57
ap.add_argument("-i", "--images", required=True, help="path to images directory")
58
ap.add_argument("-l", "--labels", required=True, help="path to labels directory")
59
ap.add_argument("-vi", "--val_images", required=True, help="path to images directory")
60
ap.add_argument("-vl", "--val_labels", required=True, help="path to labels directory")
61
args = vars(ap.parse_args())
62
63
print("[INFO] loading images and labels...")
64
imagePath = list(paths.list_images(args["images"]))
65
labelPath = list(paths.list_images(args["labels"]))
66
val_imagePath = list(paths.list_images(args["val_images"]))
67
val_labelPath = list(paths.list_images(args["val_labels"]))
68
69
prep = Preprocessor(128, 256)
70
dl = DataLoader(MRIDataset(imgpath=imagePath, labelpath=labelPath, preprocessors=[prep], verbose=200),
71
                batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
72
dl_val = DataLoader(MRIDataset(imgpath=val_imagePath, labelpath=val_labelPath, preprocessors=[prep], verbose=200),
73
                    batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
74
75
model = UNET(in_channels=1, out_channels=1).to(DEVICE).float()
76
loss_fnc = nn.BCEWithLogitsLoss()
77
optimizer = optim.RMSprop(model.parameters(), lr=LEARNING_RATE)
78
79
if LOAD_MODEL:
80
    utils.load_checkpoint(torch.load('tmp/checkpoint.pth.tar'), model)
81
82
scaler = torch.cuda.amp.GradScaler()
83
84
# wandb.watch(model)
85
for epoch in range(NUM_EPOCHS):
86
    print(epoch+1)
87
88
    train_fn(dl, model, optimizer, loss_fnc, scaler)
89
90
    checkpoint = {
91
        'state_dict': model.state_dict(),
92
        'optimizer': optimizer.state_dict()
93
    }
94
95
    utils.save_checkpoint(checkpoint,filename='tmp/checkpoint.pth.tar')
96
    utils.check_accuracy(dl_val, model, DEVICE)
97
    utils.save_predictions_as_imgs(dl_val, model, DEVICE)