Diff of /train.py [000000] .. [6d4adb]

Switch to side-by-side view

--- a
+++ b/train.py
@@ -0,0 +1,97 @@
+import torch
+from tqdm import tqdm
+import torch.nn  as nn
+import torch.optim as optim
+from Unet.model import UNET
+from imutils import paths
+from Dataset.datasetloader import MRIDataset
+from Utils import utils
+from torch.utils.data import DataLoader
+from Preprocessing.preprocessing import Preprocessor
+import argparse
+# import wandb
+
+
+# Hyperparam etc.
+LEARNING_RATE = 1e-3
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+BATCH_SIZE = 8
+NUM_EPOCHS = 100
+LOAD_MODEL = False
+NUM_WORKERS = 0
+
+
+# wandb.init(project='Unet', entity='grubi')
+# config = wandb.config
+# config.learning_rate = LEARNING_RATE
+
+
+def train_fn(loader, model, optimizer, loss_fn, scaler):
+    loop = tqdm(loader)
+
+
+    for batch_idx, (data, targets) in enumerate(loop):
+        data = data.to(DEVICE)
+        targets = targets.to(DEVICE)
+
+        # forward
+        with torch.cuda.amp.autocast():
+            predictions = model(data)
+            loss = loss_fn(predictions, targets)
+
+
+        # backward
+        optimizer.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+
+        # update tqdm loop
+        loop.set_postfix(loss=loss.item())
+        wandb.log({"loss": loss})
+
+
+
+
+ap = argparse.ArgumentParser()
+ap.add_argument("-i", "--images", required=True, help="path to images directory")
+ap.add_argument("-l", "--labels", required=True, help="path to labels directory")
+ap.add_argument("-vi", "--val_images", required=True, help="path to images directory")
+ap.add_argument("-vl", "--val_labels", required=True, help="path to labels directory")
+args = vars(ap.parse_args())
+
+print("[INFO] loading images and labels...")
+imagePath = list(paths.list_images(args["images"]))
+labelPath = list(paths.list_images(args["labels"]))
+val_imagePath = list(paths.list_images(args["val_images"]))
+val_labelPath = list(paths.list_images(args["val_labels"]))
+
+prep = Preprocessor(128, 256)
+dl = DataLoader(MRIDataset(imgpath=imagePath, labelpath=labelPath, preprocessors=[prep], verbose=200),
+                batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
+dl_val = DataLoader(MRIDataset(imgpath=val_imagePath, labelpath=val_labelPath, preprocessors=[prep], verbose=200),
+                    batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
+
+model = UNET(in_channels=1, out_channels=1).to(DEVICE).float()
+loss_fnc = nn.BCEWithLogitsLoss()
+optimizer = optim.RMSprop(model.parameters(), lr=LEARNING_RATE)
+
+if LOAD_MODEL:
+    utils.load_checkpoint(torch.load('tmp/checkpoint.pth.tar'), model)
+
+scaler = torch.cuda.amp.GradScaler()
+
+# wandb.watch(model)
+for epoch in range(NUM_EPOCHS):
+    print(epoch+1)
+
+    train_fn(dl, model, optimizer, loss_fnc, scaler)
+
+    checkpoint = {
+        'state_dict': model.state_dict(),
+        'optimizer': optimizer.state_dict()
+    }
+
+    utils.save_checkpoint(checkpoint,filename='tmp/checkpoint.pth.tar')
+    utils.check_accuracy(dl_val, model, DEVICE)
+    utils.save_predictions_as_imgs(dl_val, model, DEVICE)