a b/main.py
1
import os
2
import torch
3
4
from trainer import train_model
5
import utils as ut
6
7
from loss.diceloss import diceloss
8
from metrics import m
9
10
from models.UnetAttention import UnnetAttention
11
12
def run_nn(
    *,
    root_dir='./data',
    epochs=100,
    batch_size=8,
    weight_filename='weights_final.pt',
    data_aug='online',
    log_path='./weights/',
):
    """Train an ``UnnetAttention`` model with Dice loss and save it to disk.

    Version requirements:
        PyTorch version: >1.2.0
        Torchvision version: >0.4.0a0+6b959ee

    Every parameter defaults to the value this script originally hard-coded,
    so calling ``run_nn()`` with no arguments behaves exactly as before.

    Args:
        root_dir: Data root passed to ``ut.get_data_loaders``.
        epochs: Number of epochs passed to ``train_model`` as ``num_epochs``.
        batch_size: Batch size passed to ``ut.get_data_loaders``.
        weight_filename: File name (inside ``log_path``) under which the
            trained model is saved with ``torch.save``.
        data_aug: Augmentation mode passed to ``ut.get_data_loaders``.
            The accepted values are defined in ``utils`` — TODO confirm there.
        log_path: Output directory. It is passed to ``ut.create_nested_dir``
            before training and to ``train_model`` as ``bpath``.
    """
    ut.create_nested_dir(log_path)

    # Load the distribution of cases between the train and val splits.
    cases = ut.load_dataset_dist()

    # Build the data loaders for the splits above.
    dataloaders = ut.get_data_loaders(data_aug, cases, root_dir, batch_size)

    model = UnnetAttention()
    model.train()

    # Dice loss as the training objective.
    criterion = diceloss()

    # Adam with a low learning rate.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Multiply the learning rate by 0.9 once every 7 calls to scheduler.step().
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=7, gamma=0.9
    )

    # Evaluation metrics reported by train_model, keyed by display name.
    metrics = {
        'dice': m.mean_dice_coef,
        'dice_target': m.mean_dice_coef_remove_empty,
    }

    train_model(
        model,
        criterion,
        dataloaders,
        optimizer,
        exp_lr_scheduler,
        bpath=log_path,
        metrics=metrics,
        num_epochs=epochs,
    )

    # Save the trained model.
    # NOTE(review): this pickles the whole nn.Module, not just its state_dict.
    # Loading it requires the UnnetAttention class to be importable under the
    # same module path. Consider saving model.state_dict() instead.
    torch.save(model, os.path.join(log_path, weight_filename))
    print('\n\n ### ===> Training finished sucessfully!\n\n')


if __name__ == '__main__':
    # Run training with the default configuration when executed as a script.
    run_nn()