Diff of /main.py [000000] .. [72db80]

Switch to side-by-side view

--- a
+++ b/main.py
@@ -0,0 +1,72 @@
+import os
+import torch
+
+from trainer import train_model
+import utils as ut
+
+from loss.diceloss import diceloss
+from metrics import m
+
+from models.UnetAttention import UnnetAttention
+
+def run_nn():
+    """
+        Version requirements:
+            PyTorch Version:  >1.2.0
+            Torchvision Version:  >0.4.0a0+6b959ee
+    """
+
+    """
+        Parameters to execute the method
+    """
+    root_dir = r'./data'
+
+    epochs = 100
+    batch_size = 8
+    # Filename of the final model weigths
+    weight_filename = "weights_final.pt"
+   
+    data_aug = 'online'
+
+    log_path = './weights/'
+ 
+    """
+        Main 
+    """
+
+    ut.create_nested_dir(log_path)
+
+    # Loads the distribution of the cases between train and val
+    cases = ut.load_dataset_dist()
+
+    # Create the dataloader
+    dataloaders = ut.get_data_loaders(
+        data_aug, cases, root_dir, batch_size)
+
+    model = UnnetAttention()
+
+    model.train()
+
+    # Load the loss object by name
+    criterion = diceloss()
+    # Specify the optimizer with a lower learning rate
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+
+    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
+        optimizer, step_size=7, gamma=0.9
+    )
+
+    # Specify the evalutation metrics
+    metrics = {'dice': m.mean_dice_coef,
+               'dice_target': m.mean_dice_coef_remove_empty}
+
+    train_model(model, criterion, dataloaders,
+                optimizer, exp_lr_scheduler, bpath=log_path, metrics=metrics, num_epochs=epochs)
+
+    # Save the trained model
+    torch.save(model, os.path.join(log_path, weight_filename))
+    print('\n\n ### ===> Training finished sucessfully!\n\n')
+
+
+if __name__ == '__main__':
+    run_nn()