Diff of /main_ACL.py [000000] .. [3ed61c]

Switch to unified view

a b/main_ACL.py
1
"""
2
Created on December 11, 2021.
3
main_ACL.py
4
5
@author: Soroosh Tayebi Arasteh <soroosh.arasteh@fau.de>
6
https://github.com/tayebiarasteh/
7
"""
8
9
import pdb
10
from torch.nn import CrossEntropyLoss
11
import torch
12
import os
13
14
from models.ACL_model import ACL_net
15
from config.serde import open_experiment, create_experiment, delete_experiment
16
from Train_Valid_ACL import Training
17
18
import warnings
19
warnings.filterwarnings('ignore')
20
21
22
23
def main_train_3D(global_config_path="/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml", valid=False,
24
                  resume=False, augment=False, experiment_name='name'):
25
    """Main function for training + validation for directly 3d-wise
26
27
        Parameters
28
        ----------
29
        global_config_path: str
30
            always global_config_path="/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml"
31
32
        valid: bool
33
            if we want to do validation
34
35
        resume: bool
36
            if we are resuming training on a model
37
38
        experiment_name: str
39
            name of the experiment, in case of resuming training.
40
            name of new experiment, in case of new training.
41
    """
42
    if resume == True:
43
        params = open_experiment(experiment_name, global_config_path)
44
    else:
45
        params = create_experiment(experiment_name, global_config_path)
46
    cfg_path = params["cfg_path"]
47
48
    # Changeable network parameters
49
    model = ACL_net()
50
    loss_function = CrossEntropyLoss
51
    optimizer = torch.optim.Adam(model.parameters(), lr=float(params['Network']['lr']),
52
                                 weight_decay=float(params['Network']['weight_decay']), amsgrad=params['Network']['amsgrad'])
53
54
    trainer = Training(cfg_path, num_iterations=params['num_iterations'], resume=resume)
55
    if resume == True:
56
        trainer.load_checkpoint(model=model, optimiser=optimizer, loss_function=loss_function)
57
    else:
58
        trainer.setup_model(model=model, optimiser=optimizer,
59
                        loss_function=loss_function)
60
61
    # loading the data
62
    train_loader = torch.ones((1, 2, 110, 281, 285))
63
    valid_loader = torch.ones((1, 2, 110, 281, 285))
64
65
    trainer.execute_training(train_loader=train_loader, valid_loader=valid_loader, augmentation=augment)
66
67
68
69
70
71
if __name__ == '__main__':
72
    delete_experiment(experiment_name='testtest', global_config_path="/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml")
73
    main_train_3D(global_config_path="/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml",
74
                  valid=False, resume=False, augment=False, experiment_name='testtest')