Diff of /main.py [000000] .. [6d4aaa]

Switch to unified view

a b/main.py
1
import os
2
import logging
3
import tensorflow as tf
4
from medseg_dl import parameters
5
from medseg_dl.utils import utils_data, utils_misc
6
from medseg_dl.model import input_fn
7
from medseg_dl.model import model_fn
8
from medseg_dl.model import training
9
import sys
10
11
12
def main(dir_model, non_local='disable', attgate='disable', device=None, idx_dataset=0):
13
    # Since this is the training script all variables should be trainable
14
    b_training = True  # hardcoded since scripts are not identical
15
16
    # Load / generate parameters from model file, if available
17
    file_params = os.path.join(dir_model, 'params.yaml')
18
    if os.path.isfile(file_params):
19
        params = parameters.Params(path_yaml=file_params)
20
    else:
21
        params = parameters.Params(model_dir=dir_model, idx_dataset=idx_dataset)
22
23
    # Set logger
24
    utils_misc.set_logger(os.path.join(params.dict['dir_model'], 'train.log'), params.dict['log_level'])
25
26
    # Set random seed
27
    if params.dict['b_use_seed']:
28
        tf.set_random_seed(params.dict['random_seed'])
29
30
    # Set device for graph calc
31
    if device:
32
        os.environ['CUDA_VISIBLE_DEVICES'] = device
33
    else:
34
        os.environ['CUDA_VISIBLE_DEVICES'] = params.dict['device']
35
36
    """ Fetch data, generate pipeline and model """
37
    tf.reset_default_graph()
38
39
    # Fetch datasets, atm. saved as json
40
    logging.info('Fetching the datasets...')
41
    filenames_train, filenames_eval = utils_data.load_sets(params.dict['dir_data'],
42
                                                           params.dict['dir_model'],
43
                                                           path_parser_cfg=params.dict['path_parser_cfg'],
44
                                                           set_split=params.dict['set_split'],
45
                                                           b_recreate=True)
46
47
    # Create a tf.data pipeline
48
    logging.info('Creating the pipeline...')
49
    spec_pipeline = input_fn.gen_pipeline_train(filenames=filenames_train,
50
                                                shape_image=params.dict['shape_image'],
51
                                                shape_input=params.dict['shape_input'],
52
                                                shape_output=params.dict['shape_output'],
53
                                                channels_out=params.dict['channels_out'],
54
                                                size_batch=params.dict['size_batch'],
55
                                                size_buffer=params.dict['size_buffer'],
56
                                                num_parallel_calls=params.dict['num_parallel_calls'],
57
                                                repeat=params.dict['repeat'],
58
                                                b_shuffle=params.dict['b_shuffle'],
59
                                                patches_per_class=params.dict['patches_per_class'],
60
                                                sigma_offset=params.dict['sigma_offset'],
61
                                                sigma_noise=params.dict['sigma_noise'],
62
                                                sigma_pos=params.dict['sigma_pos'],
63
                                                b_mirror=params.dict['b_mirror'],
64
                                                b_rotate=params.dict['b_rotate'],
65
                                                b_scale=params.dict['b_scale'],
66
                                                b_warp=params.dict['b_warp'],
67
                                                b_permute_labels=params.dict['b_permute_labels'],
68
                                                angle_max=params.dict['angle_max'],
69
                                                scale_factor=params.dict['scale_factor'],
70
                                                delta_max=params.dict['delta_max'],
71
                                                b_verbose=False)
72
73
74
75
    # Create the model (incorporating losses, optimizer, metrics)
76
    logging.info('Creating the model...')
77
    spec_model = model_fn.model_fn(spec_pipeline,
78
                                   input_metrics=None,
79
                                   b_training=b_training,
80
                                   channels=params.dict['channels'],
81
                                   channels_out=params.dict['channels_out'],
82
                                   batch_size=params.dict['size_batch'],
83
                                   b_dynamic_pos_mid=params.dict['b_dynamic_pos_mid'],
84
                                   b_dynamic_pos_end=params.dict['b_dynamic_pos_end'],
85
                                   non_local=non_local,
86
                                   non_local_num=1,
87
                                   attgate=attgate,
88
                                   filters=params.dict['filters'],
89
                                   dense_layers=params.dict['dense_layers'],
90
                                   alpha=params.dict['alpha'],
91
                                   dropout_rate=params.dict['rate_dropout'],
92
                                   rate_learning=params.dict['rate_learning'],
93
                                   beta1=params.dict['beta1'],
94
                                   beta2=params.dict['beta2'],
95
                                   epsilon=params.dict['epsilon'])
96
97
98
    # Train the actual model
99
    logging.info('Starting training for %i epoch(s)', params.dict['num_epochs'])
100
    training.sess_train(spec_pipeline, spec_model, params)
101
102
103
if __name__ == '__main__':
104
    # model
105
    model_dir = '/home/stage13_realshuffle_attgate_batch24_l2_repeat10_shuffle_eval_reminder_timout6000_K'
106
107
    # crucial parameters
108
    logging.info('Received the following input(s): %s', str(sys.argv))
109
110
    non_local = 'disable'
111
    if len(sys.argv) > 1:
112
        non_local = str(sys.argv[1])
113
    logging.info('Adding nonlocal block after %s', non_local)
114
115
    attgate = 'active'
116
    if len(sys.argv) > 2:
117
        attgate = istr(sys.argv[2])
118
    logging.info('Adding attention gate after %s', attgate)
119
120
    device = '3'
121
    if len(sys.argv) > 3:
122
        device = str(sys.argv[3])
123
    logging.info('Calculating on device %s', device)
124
125
    idx_dataset = 0
126
    if len(sys.argv) > 4:
127
        idx_dataset = int(sys.argv[4])
128
    logging.info('Using dataset %i', idx_dataset)
129
130
    main(model_dir, non_local=non_local, attgate=attgate, device=device, idx_dataset=idx_dataset)
131
132
133
134
135