|
a |
|
b/main_ACL.py |
|
|
1 |
""" |
|
|
2 |
Created on December 11, 2021. |
|
|
3 |
main_ACL.py |
|
|
4 |
|
|
|
5 |
@author: Soroosh Tayebi Arasteh <soroosh.arasteh@fau.de> |
|
|
6 |
https://github.com/tayebiarasteh/ |
|
|
7 |
""" |
|
|
8 |
|
|
|
9 |
import pdb |
|
|
10 |
from torch.nn import CrossEntropyLoss |
|
|
11 |
import torch |
|
|
12 |
import os |
|
|
13 |
|
|
|
14 |
from models.ACL_model import ACL_net |
|
|
15 |
from config.serde import open_experiment, create_experiment, delete_experiment |
|
|
16 |
from Train_Valid_ACL import Training |
|
|
17 |
|
|
|
18 |
import warnings |
|
|
19 |
warnings.filterwarnings('ignore') |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
|
|
|
23 |
def main_train_3D(global_config_path="/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml", valid=False, |
|
|
24 |
resume=False, augment=False, experiment_name='name'): |
|
|
25 |
"""Main function for training + validation for directly 3d-wise |
|
|
26 |
|
|
|
27 |
Parameters |
|
|
28 |
---------- |
|
|
29 |
global_config_path: str |
|
|
30 |
always global_config_path="/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml" |
|
|
31 |
|
|
|
32 |
valid: bool |
|
|
33 |
if we want to do validation |
|
|
34 |
|
|
|
35 |
resume: bool |
|
|
36 |
if we are resuming training on a model |
|
|
37 |
|
|
|
38 |
experiment_name: str |
|
|
39 |
name of the experiment, in case of resuming training. |
|
|
40 |
name of new experiment, in case of new training. |
|
|
41 |
""" |
|
|
42 |
if resume == True: |
|
|
43 |
params = open_experiment(experiment_name, global_config_path) |
|
|
44 |
else: |
|
|
45 |
params = create_experiment(experiment_name, global_config_path) |
|
|
46 |
cfg_path = params["cfg_path"] |
|
|
47 |
|
|
|
48 |
# Changeable network parameters |
|
|
49 |
model = ACL_net() |
|
|
50 |
loss_function = CrossEntropyLoss |
|
|
51 |
optimizer = torch.optim.Adam(model.parameters(), lr=float(params['Network']['lr']), |
|
|
52 |
weight_decay=float(params['Network']['weight_decay']), amsgrad=params['Network']['amsgrad']) |
|
|
53 |
|
|
|
54 |
trainer = Training(cfg_path, num_iterations=params['num_iterations'], resume=resume) |
|
|
55 |
if resume == True: |
|
|
56 |
trainer.load_checkpoint(model=model, optimiser=optimizer, loss_function=loss_function) |
|
|
57 |
else: |
|
|
58 |
trainer.setup_model(model=model, optimiser=optimizer, |
|
|
59 |
loss_function=loss_function) |
|
|
60 |
|
|
|
61 |
# loading the data |
|
|
62 |
train_loader = torch.ones((1, 2, 110, 281, 285)) |
|
|
63 |
valid_loader = torch.ones((1, 2, 110, 281, 285)) |
|
|
64 |
|
|
|
65 |
trainer.execute_training(train_loader=train_loader, valid_loader=valid_loader, augmentation=augment) |
|
|
66 |
|
|
|
67 |
|
|
|
68 |
|
|
|
69 |
|
|
|
70 |
|
|
|
71 |
if __name__ == '__main__': |
|
|
72 |
delete_experiment(experiment_name='testtest', global_config_path="/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml") |
|
|
73 |
main_train_3D(global_config_path="/home/soroosh/Documents/Repositories/ACL_tear/config/config.yaml", |
|
|
74 |
valid=False, resume=False, augment=False, experiment_name='testtest') |