[6d4aaa]: /medseg_dl/model/training.py

import tensorflow as tf
import os
import logging
from medseg_dl.utils import utils_misc
# from tqdm import tqdm


def sess_train(spec_pipeline, spec_model, params):
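    """Run the training loop on an already-built pipeline/model graph.

    Expected inputs, inferred from the accesses below: spec_pipeline and
    spec_model are dicts of graph ops/tensors ('init_op_iter', 'images',
    'loss', 'train_op', ...); params exposes its settings as params.dict.
    """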
    # Add an op to initialize the variables
    init_op_vars = tf.global_variables_initializer()

    # Fetch global step of default graph
    global_step = tf.train.get_global_step()

    # Add ops to save and restore all variables
    saver_recent = tf.train.Saver(max_to_keep=10)  # keeps the last 10 ckpts

    # generate summary writer
    writer = tf.summary.FileWriter(params.dict['dir_logs_train'])
    logging.info(f'saving log to {params.dict["dir_logs_train"]}')

    # Define fetched variables
    fetched_train = {'loss_value': spec_model['loss'],
                     'train_op': spec_model['train_op'],
                     'update_metrics_op_train': spec_model['update_op_metrics'],
                     'summary_train': spec_model['summary_op_train'],
                     'gstep': global_step}

    fetched_metrics_train = {'metrics': spec_model['metrics_values'],
                             'summary_metrics': spec_model['summary_op_metrics']}
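    # Note: sess.run accepts these dicts directly; all fetches are evaluated
    # in a single graph execution and returned under the same keys.
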
    if params.dict['b_viewer_train']:
        fetched_train.update({'images': spec_pipeline['images'],
                              'labels': spec_pipeline['labels'],
                              'positions': spec_pipeline['positions'],
                              'probs': spec_model['probs']})
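    # Fetching images/labels/probs only when the viewer is enabled avoids
    # pulling large patch tensors out of the graph on every ordinary step.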

    # let GPU memory grow on demand instead of pre-allocating all of it
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        _ = tf.summary.FileWriter(params.dict['dir_graphs_train'], sess.graph)
        logging.info(f'Graph saved in {params.dict["dir_graphs_train"]}')

        sess.run(init_op_vars)  # init global variables
        first_epoch = 0
        total_steps = 0  # only consumed by the commented-out tqdm progress bar

        # Reload weights if a previous training run is to be restored
        if params.dict['b_restore']:
            if os.path.isdir(params.dict['dir_ckpts']):
                file_ckpt = tf.train.latest_checkpoint(params.dict['dir_ckpts'])
                if file_ckpt is not None:  # the dir may exist but hold no ckpt yet
                    # ckpts are written as model.ckpt-<epoch> (see saver_recent.save
                    # below), so the filename suffix yields the last finished epoch
                    first_epoch = int(os.path.basename(file_ckpt).split('-')[1]) + 1
                    logging.info(f'Restoring parameters from {file_ckpt}')
                    saver_recent.restore(sess, file_ckpt)

        # Epochs
        for epoch in range(first_epoch, params.dict['num_epochs']):

            # training
            logging.info(f'Epoch {epoch + 1}/{params.dict["num_epochs"]}: training')
            sess.run(spec_pipeline['init_op_iter'])  # initialize dataset
            sess.run(spec_model['init_op_metrics'])  # reset metrics

            # training steps
            # pbar = tqdm(total=total_steps)
            while True:
                try:
                    results = sess.run(fetched_train)  # perform mini-batch update
                    gstep = results['gstep']
                    loss_value = results['loss_value']
                    logging.info(f'Step {gstep}, loss: {loss_value}')
                    writer.add_summary(results['summary_train'], global_step=gstep)  # write loss per batch

                    # allow viewing of data
                    if params.dict['b_viewer_train']:
                        shown_index = 2  # assumes a batch size of at least 3
                        print(f'Visualized patch has a position of {results["positions"][shown_index, ...]}')
                        utils_misc.show_results(results['images'][shown_index, ...],
                                                results['labels'][shown_index, ...],
                                                results['probs'][shown_index, ...])
                    # pbar.update(1)
                except tf.errors.OutOfRangeError:
                    # the dataset iterator raises OutOfRangeError once the epoch's
                    # data is exhausted, which ends the inner loop
                    # pbar.close()
                    break

            # fetch aggregated metrics values
            results = sess.run(fetched_metrics_train)
            writer.add_summary(results['summary_metrics'], global_step=epoch)  # write metrics per epoch

            # save model every epoch
            # if epoch % 5 == 0:
            save_path = saver_recent.save(sess,
                                          os.path.join(params.dict['dir_ckpts'], 'model.ckpt'),
                                          global_step=epoch)
            logging.info(f'Model saved in {save_path}')
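
A minimal sketch of how this entry point might be driven, assuming hypothetical build_pipeline/build_model helpers that stand in for the project's real graph construction; the Params wrapper only mirrors the params.dict access pattern used above:

class Params:
    def __init__(self, **settings):
        self.dict = settings  # sess_train reads every setting via params.dict[...]

params = Params(dir_logs_train='logs/train',
                dir_graphs_train='graphs/train',
                dir_ckpts='ckpts',
                b_viewer_train=False,
                b_restore=False,
                num_epochs=50)

spec_pipeline = build_pipeline(params)           # hypothetical: must yield 'init_op_iter', 'images', 'labels', 'positions'
spec_model = build_model(spec_pipeline, params)  # hypothetical: must yield 'loss', 'train_op', metric/summary ops, 'probs'
sess_train(spec_pipeline, spec_model, params)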