main.py

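"""Training entry point: build the dataloaders, the attention U-Net, the Dice
loss, the optimizer and the LR scheduler, then run train_model and save the
final weights."""
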
import os

import torch

from trainer import train_model
import utils as ut
from loss.diceloss import diceloss
from metrics import m
from models.UnetAttention import UnnetAttention


def run_nn():
    """
    Version requirements:
        PyTorch version:     > 1.2.0
        Torchvision version: > 0.4.0a0+6b959ee
    """
    # Parameters to execute the method
    root_dir = r'./data'
    epochs = 100
    batch_size = 8
    # Filename of the final model weights
    weight_filename = "weights_final.pt"
    data_aug = 'online'
    log_path = './weights/'
"""
Main
"""
ut.create_nested_dir(log_path)
# Loads the distribution of the cases between train and val
cases = ut.load_dataset_dist()
# Create the dataloader
dataloaders = ut.get_data_loaders(
data_aug, cases, root_dir, batch_size)
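    # (assumption: get_data_loaders returns a dict mapping split names, e.g. 'train'/'val',
    # to torch DataLoader objects, which is what train_model consumes)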

    model = UnnetAttention()
    # Put the network in training mode (enables dropout / batch-norm updates)
    model.train()

    # Instantiate the Dice loss
    criterion = diceloss()
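    # (Dice loss maximises overlap between the predicted and ground-truth masks and is
    # robust to the class imbalance typical of segmentation targets)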

    # Specify the optimizer with a lower learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
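    # StepLR decays the learning rate by a factor of gamma=0.9 every step_size=7 epochs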
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=7, gamma=0.9
    )

    # Specify the evaluation metrics
    metrics = {'dice': m.mean_dice_coef,
               'dice_target': m.mean_dice_coef_remove_empty}
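    # (judging by its name, mean_dice_coef_remove_empty excludes cases with an empty ground-truth mask)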

    train_model(model, criterion, dataloaders,
                optimizer, exp_lr_scheduler, bpath=log_path,
                metrics=metrics, num_epochs=epochs)

    # Save the trained model
    torch.save(model, os.path.join(log_path, weight_filename))
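    # Note: torch.save(model, ...) pickles the whole module, so the UnnetAttention class
    # must be importable when the file is loaded; save model.state_dict() to keep weights only.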

    print('\n\n ### ===> Training finished successfully!\n\n')


if __name__ == '__main__':
    run_nn()