|
a |
|
b/src/run/trainer.py |
|
|
1 |
""" trainer python script |
|
|
2 |
|
|
|
3 |
This script allows training the proposed lfbnet model. |
|
|
4 |
This script requires to specify the directory path to the preprocessed PET MIP images. It could read the patient ids |
|
|
5 |
from |
|
|
6 |
the given directory path, or it could accept patient ids as .xls and .csv files. please provide the directory path to |
|
|
7 |
the |
|
|
8 |
csv or xls file. It assumes the csv/xls file have two columns with level 'train' and 'valid' indicating the training and |
|
|
9 |
validation patient ids respectively. |
|
|
10 |
|
|
|
11 |
Please see the _name__ == '__main__': as example which is equivalent to: |
|
|
12 |
|
|
|
13 |
e.g.train_valid_data_dir = r"E:\LFBNet\data\remarc_default_MIP_dir/" |
|
|
14 |
train_valid_ids_path_csv = r'E:\LFBNet\data\csv\training_validation_indexs\remarc/' |
|
|
15 |
train_ids, valid_ids = get_training_and_validation_ids_from_csv(train_valid_ids_path_csv) |
|
|
16 |
|
|
|
17 |
trainer = NetworkTrainer( |
|
|
18 |
folder_preprocessed_train=train_valid_data_dir, folder_preprocessed_valid=train_valid_data_dir, |
|
|
19 |
ids_to_read_train=train_ids, |
|
|
20 |
ids_to_read_valid=valid_ids |
|
|
21 |
) |
|
|
22 |
trainer.train() |
|
|
23 |
""" |
|
|
24 |
# Import libraries |
|
|
25 |
import os |
|
|
26 |
import glob |
|
|
27 |
import sys |
|
|
28 |
import time |
|
|
29 |
from datetime import datetime |
|
|
30 |
|
|
|
31 |
import numpy as np |
|
|
32 |
from numpy.random import seed |
|
|
33 |
from random import randint |
|
|
34 |
from tqdm import tqdm |
|
|
35 |
from typing import Tuple, List |
|
|
36 |
from numpy import ndarray |
|
|
37 |
from copy import deepcopy |
|
|
38 |
from medpy.metric import binary |
|
|
39 |
import matplotlib.pyplot as plt |
|
|
40 |
from keras import backend as K |
|
|
41 |
import re |
|
|
42 |
|
|
|
43 |
# make LFBNet as parent directory, for absolute import libraries. local application import. |
|
|
44 |
p = os.path.abspath('../..') |
|
|
45 |
if p not in sys.path: |
|
|
46 |
sys.path.append(p) |
|
|
47 |
|
|
|
48 |
# import LFBNet modules |
|
|
49 |
from src.LFBNet.data_loader import DataLoader |
|
|
50 |
from src.LFBNet.network_architecture import lfbnet |
|
|
51 |
from src.LFBNet.losses import losses |
|
|
52 |
from src.LFBNet.preprocessing import save_nii_images |
|
|
53 |
from src.LFBNet.utilities import train_valid_paths |
|
|
54 |
from src.LFBNet.postprocessing import remove_outliers_in_sagittal |
|
|
55 |
# choose cuda gpu |
|
|
56 |
CUDA_VISIBLE_DEVICES = 1 |
|
|
57 |
os.environ["CUDA_VISIBLE_DEVICES"] = "1" |
|
|
58 |
|
|
|
59 |
# set randomness repetable across experiments. |
|
|
60 |
seed(1) |
|
|
61 |
|
|
|
62 |
# Define the parameters of the data to process |
|
|
63 |
K.set_image_data_format('channels_last') |
|
|
64 |
|
|
|
65 |
|
|
|
66 |
def default_training_parameters( |
|
|
67 |
num_epochs: int = 5000, batch_size: int = 16, early_stop: int = None, fold_number: int = None, |
|
|
68 |
model_name_save: List[str] = None, loss: str = None, metric: str = None |
|
|
69 |
) -> dict: |
|
|
70 |
""" Configure default parameters for training. |
|
|
71 |
Training parameters are setted here. For other options, the user should modifier these values. |
|
|
72 |
Parameters |
|
|
73 |
---------- |
|
|
74 |
num_epochs: int, maximum number of epochs to train the model. |
|
|
75 |
batch_size: int, number of images per batch |
|
|
76 |
early_stop: int, the number of training epochs the model should train while it is not improving the accuracy. |
|
|
77 |
fold_number: int, optional, fold number while applying cross-validation-based training. |
|
|
78 |
model_name_save: str, model name to save |
|
|
79 |
loss: str, loss funciton |
|
|
80 |
metric: str, specify the metric, such as dice |
|
|
81 |
|
|
|
82 |
Returns |
|
|
83 |
------- |
|
|
84 |
Returns configured dictionary for the training. |
|
|
85 |
|
|
|
86 |
""" |
|
|
87 |
if early_stop is None: |
|
|
88 |
# early stop 50 % of the maximum number of epochs |
|
|
89 |
early_stop = int(num_epochs * 0.5) |
|
|
90 |
|
|
|
91 |
if fold_number is None: |
|
|
92 |
fold_number = 'fold_run_at_' + str(time.time()) |
|
|
93 |
|
|
|
94 |
if model_name_save is None: |
|
|
95 |
model_name_save = ["forward_" + str(time.time()), "feedback_" + str(time.time())] |
|
|
96 |
|
|
|
97 |
if loss is None: |
|
|
98 |
loss = losses.LossMetric.dice_plus_binary_cross_entropy_loss |
|
|
99 |
|
|
|
100 |
if metric is None: |
|
|
101 |
metric = losses.LossMetric.dice_metric |
|
|
102 |
|
|
|
103 |
config_trainer = {'num_epochs': num_epochs, 'batch_size': batch_size, 'num_early_stop': early_stop, |
|
|
104 |
'fold_number': fold_number, 'model_name_save_forward': model_name_save[0], |
|
|
105 |
'model_name_save_feedback': model_name_save[1], "custom_loss": loss, "custom_dice": metric} |
|
|
106 |
|
|
|
107 |
return config_trainer |
|
|
108 |
|
|
|
109 |
|
|
|
110 |
def get_training_and_validation_ids_from_csv(path): |
|
|
111 |
""" Get training and validation ids from a given csv or xls file. Assuming the training ids are given with column |
|
|
112 |
name 'train' and validation ids in 'valid' |
|
|
113 |
|
|
|
114 |
Parameters |
|
|
115 |
---------- |
|
|
116 |
path: directory path to the csv or xls file. |
|
|
117 |
|
|
|
118 |
Returns |
|
|
119 |
------- |
|
|
120 |
Returns training and validation patient ids. |
|
|
121 |
|
|
|
122 |
|
|
|
123 |
""" |
|
|
124 |
ids = train_valid_paths.read_csv_train_valid_index(path) |
|
|
125 |
train, valid = ids[0], ids[1] |
|
|
126 |
return train, valid |
|
|
127 |
|
|
|
128 |
|
|
|
129 |
def get_train_valid_ids_from_folder(path_train_valid, ratio_valid_data=0.25): |
|
|
130 |
""" Returns the randomly split training and validation patient ids. The percentage of validation is given by the |
|
|
131 |
ratio_valid_data. |
|
|
132 |
|
|
|
133 |
Parameters |
|
|
134 |
---------- |
|
|
135 |
path_train_valid |
|
|
136 |
ratio_valid_data |
|
|
137 |
|
|
|
138 |
Returns |
|
|
139 |
------- |
|
|
140 |
Returns training patient id and validation patient ids respectively as in two array. |
|
|
141 |
|
|
|
142 |
""" |
|
|
143 |
# given training and validation data on one folder, random splitting with .25% : train, valid |
|
|
144 |
if len(path_train_valid) == 1: |
|
|
145 |
all_cases_id = os.listdir(str(path_train_valid)) # all patients id |
|
|
146 |
|
|
|
147 |
# make permutation in the given list |
|
|
148 |
case_ids = np.array(all_cases_id) |
|
|
149 |
indices = np.random.permutation(len(case_ids)) |
|
|
150 |
num_valid_data = int(ratio_valid_data * len(all_cases_id)) |
|
|
151 |
|
|
|
152 |
train, valid = indices[num_valid_data:], indices[:num_valid_data] |
|
|
153 |
return [train, valid] |
|
|
154 |
|
|
|
155 |
|
|
|
156 |
class NetworkTrainer: |
|
|
157 |
""" |
|
|
158 |
class to train the lfb net |
|
|
159 |
""" |
|
|
160 |
# keep the best loss and dice while training : Value shared across all instances, methods |
|
|
161 |
BEST_METRIC_VALIDATION = 0 # KEEP THE BEST VALIDATION METRIC SUCH AS THE DICE METRIC (BEST_DICE) |
|
|
162 |
BEST_LOSS_VALIDATION = 100 # KEEP THE BEST VALIDATION LOSS SUCH AS THE LOSS VALUES (BEST_LOSS) |
|
|
163 |
EARLY_STOP_COUNT = 0 # COUNTS THE NUMBER OF TRAINING ITERATIONS THE MODEL DID NOT INCREASE, TO COMPARE WITH THE |
|
|
164 |
now = datetime.now() # current time, date, month, |
|
|
165 |
TRAINED_MODEL_IDENTIFIER = re.sub('[ :]', "_", now.ctime()) |
|
|
166 |
|
|
|
167 |
# EARLY STOP CRITERIA |
|
|
168 |
|
|
|
169 |
def __init__( |
|
|
170 |
self, config_trainer: dict = None, folder_preprocessed_train: str = '../data/train/', |
|
|
171 |
folder_preprocessed_valid: str = '../data/valid/', ids_to_read_train: ndarray = None, |
|
|
172 |
ids_to_read_valid: ndarray = None, task: str = 'valid', predicted_directory: str = '../data/predicted/', |
|
|
173 |
save_predicted: bool = False |
|
|
174 |
): |
|
|
175 |
""" |
|
|
176 |
|
|
|
177 |
:param config_trainer: |
|
|
178 |
:param folder_preprocessed_train: |
|
|
179 |
:param folder_preprocessed_valid: |
|
|
180 |
:param ids_to_read_train: |
|
|
181 |
:param ids_to_read_valid: |
|
|
182 |
:param task: |
|
|
183 |
:predicted_directory: |
|
|
184 |
:save_predicted |
|
|
185 |
""" |
|
|
186 |
|
|
|
187 |
if config_trainer is None: |
|
|
188 |
self.config_trainer = deepcopy(default_training_parameters()) |
|
|
189 |
|
|
|
190 |
# training data |
|
|
191 |
self.folder_preprocessed_train = folder_preprocessed_train |
|
|
192 |
if ids_to_read_train is None: |
|
|
193 |
ids_to_read_train = os.listdir(folder_preprocessed_train) |
|
|
194 |
|
|
|
195 |
self.ids_to_read_train = ids_to_read_train |
|
|
196 |
|
|
|
197 |
# validation data |
|
|
198 |
self.folder_preprocessed_valid = folder_preprocessed_valid |
|
|
199 |
if ids_to_read_valid is None: |
|
|
200 |
ids_to_read_valid = os.listdir(folder_preprocessed_valid) |
|
|
201 |
self.ids_to_read_valid = ids_to_read_valid |
|
|
202 |
|
|
|
203 |
# save predicted directory: |
|
|
204 |
self.save_all = save_predicted |
|
|
205 |
self.predicted_directory = predicted_directory |
|
|
206 |
# load the lfb_network architecture |
|
|
207 |
self.model = lfbnet.LfbNet() |
|
|
208 |
self.task = task |
|
|
209 |
|
|
|
210 |
# forward network decoder |
|
|
211 |
|
|
|
212 |
# latent feedback at zero time: means no feedback from feedback network |
|
|
213 |
self.latent_dim = self.model.latent_dim |
|
|
214 |
self.h_at_zero_time = np.zeros( |
|
|
215 |
(int(self.config_trainer['batch_size']), int(self.latent_dim[0]), int(self.latent_dim[1]), |
|
|
216 |
int(self.latent_dim[2])), np.float32 |
|
|
217 |
) |
|
|
218 |
|
|
|
219 |
@staticmethod |
|
|
220 |
def load_dataset(directory_: str = None, ids_to_read: List[str] = None): |
|
|
221 |
""" |
|
|
222 |
|
|
|
223 |
:param ids_to_read: |
|
|
224 |
:param directory_: |
|
|
225 |
""" |
|
|
226 |
# load batch of data |
|
|
227 |
data_loader = DataLoader(data_dir=directory_, ids_to_read=ids_to_read) |
|
|
228 |
image_batch_ground_truth_batch = data_loader.get_batch_of_data() |
|
|
229 |
|
|
|
230 |
batch_input_data, batch_output_data = image_batch_ground_truth_batch[0], image_batch_ground_truth_batch[1] |
|
|
231 |
# expand dimension for the channel |
|
|
232 |
batch_output_data = np.expand_dims(batch_output_data, axis=-1) |
|
|
233 |
batch_input_data = np.expand_dims(batch_input_data, axis=-1) |
|
|
234 |
|
|
|
235 |
return batch_input_data, batch_output_data |
|
|
236 |
|
|
|
237 |
def load_latest_weight(self): |
|
|
238 |
""" loads the weights of the model with the latest saved weight in the folder ./weight |
|
|
239 |
""" |
|
|
240 |
# load the last trained weight in the folder weight |
|
|
241 |
folder_path = r'./weight/' |
|
|
242 |
file_type = r'\*.h5' |
|
|
243 |
files = glob.glob(folder_path + file_type) |
|
|
244 |
try: |
|
|
245 |
max_file = max(files, key=os.path.getctime) |
|
|
246 |
except: |
|
|
247 |
raise Exception("weight could not found !") |
|
|
248 |
|
|
|
249 |
base_name = str(os.path.basename(max_file)) |
|
|
250 |
print(base_name) |
|
|
251 |
self.model.combine_and_train.load_weights('./weight/forward_system' + str(base_name.split('system')[1])) |
|
|
252 |
# f |
|
|
253 |
self.model.fcn_feedback.load_weights('./weight/feedback_system' + str(base_name.split('system')[1])) |
|
|
254 |
|
|
|
255 |
def train(self): |
|
|
256 |
"""Train the model |
|
|
257 |
""" |
|
|
258 |
|
|
|
259 |
batch_size = self.config_trainer['batch_size'] |
|
|
260 |
# self.load_latest_weight() |
|
|
261 |
# training |
|
|
262 |
if self.task == 'train': |
|
|
263 |
# training |
|
|
264 |
for current_epoch in range(self.config_trainer['num_epochs']): |
|
|
265 |
feedback_loss_dice = [] |
|
|
266 |
forward_loss_dice = [] |
|
|
267 |
forward_decoder_loss_dice = [] |
|
|
268 |
|
|
|
269 |
# shuffle the index of the training data |
|
|
270 |
index_read = np.random.permutation(int(len(self.ids_to_read_train))) |
|
|
271 |
# read data |
|
|
272 |
for selected_patient in range(len(index_read)): |
|
|
273 |
# get index of batch of data |
|
|
274 |
start = selected_patient * batch_size |
|
|
275 |
idx_list_batch = index_read[start:start + batch_size] |
|
|
276 |
# if there are still elements in the given batch |
|
|
277 |
if idx_list_batch.size > 0: |
|
|
278 |
# get index of Why not ? kk = indx_list_batch |
|
|
279 |
kk = [str(k) for i, k in enumerate(self.ids_to_read_train) if i in idx_list_batch] |
|
|
280 |
|
|
|
281 |
batch_input_data, batch_output_data = self.load_dataset( |
|
|
282 |
directory_=self.folder_preprocessed_train, ids_to_read=kk |
|
|
283 |
) |
|
|
284 |
|
|
|
285 |
assert len(batch_input_data) > 0, "batch of data not loaded correctly" |
|
|
286 |
|
|
|
287 |
# shuffle within the batch |
|
|
288 |
index_batch = np.random.permutation(int(batch_input_data.shape[0])) |
|
|
289 |
batch_input_data = batch_input_data[index_batch] |
|
|
290 |
batch_output_data = batch_output_data[index_batch] |
|
|
291 |
|
|
|
292 |
# batches per epoch: Selected batch might as in id could have more images than the batch size |
|
|
293 |
batch_per_epoch = int(batch_input_data.shape[0] / batch_size) |
|
|
294 |
for batch_per_epoch_ in range(batch_per_epoch): |
|
|
295 |
batch_input = batch_input_data[ |
|
|
296 |
batch_per_epoch_ * batch_size:(batch_per_epoch_ + 1) * batch_size] |
|
|
297 |
batch_output = batch_output_data[ |
|
|
298 |
batch_per_epoch_ * batch_size:(batch_per_epoch_ + 1) * batch_size] |
|
|
299 |
|
|
|
300 |
# Train forward models |
|
|
301 |
if current_epoch % 2 == 0: |
|
|
302 |
# step 1: train the forward network encoder and decoder |
|
|
303 |
loss, dice = self.model.combine_and_train.train_on_batch( |
|
|
304 |
[batch_input, self.h_at_zero_time], [batch_output] |
|
|
305 |
) # self.h_at_zero_time |
|
|
306 |
forward_loss_dice.append([loss, dice]) |
|
|
307 |
|
|
|
308 |
else: |
|
|
309 |
predicted_decoder = self.model.combine_and_train.predict( |
|
|
310 |
[batch_input, self.h_at_zero_time] |
|
|
311 |
) # , self.h_at_zero_time |
|
|
312 |
|
|
|
313 |
# step 2: train the feedback network, considering the output of the forward network |
|
|
314 |
loss, dice = self.model.fcn_feedback.train_on_batch(predicted_decoder, batch_output) |
|
|
315 |
feedback_loss_dice.append([loss, dice]) |
|
|
316 |
|
|
|
317 |
# Step 3: train the forward decoder, considering the trained |
|
|
318 |
feedback_latent_result = self.model.feedback_latent.predict([predicted_decoder]) |
|
|
319 |
forward_encoder_output = self.model.forward_encoder.predict([batch_input]) |
|
|
320 |
|
|
|
321 |
# forward_encoder_output.insert(1, feedback_latent_result) |
|
|
322 |
forward_encoder_output = forward_encoder_output[::-1] # bottleneck should be first |
|
|
323 |
forward_encoder_output.insert(1, feedback_latent_result) |
|
|
324 |
loss, dice = self.model.forward_decoder.train_on_batch( |
|
|
325 |
[output for output in forward_encoder_output], [batch_output] |
|
|
326 |
) |
|
|
327 |
forward_decoder_loss_dice.append([loss, dice]) |
|
|
328 |
|
|
|
329 |
forward_loss_dice = np.array(forward_loss_dice) |
|
|
330 |
feedback_loss_dice = np.array(feedback_loss_dice) |
|
|
331 |
forward_decoder_loss_dice = np.array(forward_decoder_loss_dice) |
|
|
332 |
|
|
|
333 |
if current_epoch % 2 == 0: |
|
|
334 |
loss, dice = np.mean(forward_loss_dice, axis=0) |
|
|
335 |
print( |
|
|
336 |
'Training_forward_system: >%d, ' |
|
|
337 |
' fwd_loss = %.3f, fwd_dice=%0.3f, ' % (current_epoch, loss, dice) |
|
|
338 |
) |
|
|
339 |
|
|
|
340 |
else: |
|
|
341 |
loss_forward, dice_forward = np.mean(forward_decoder_loss_dice, axis=0) |
|
|
342 |
loss_feedback, dice_feedback = np.mean(feedback_loss_dice, axis=0) |
|
|
343 |
|
|
|
344 |
print( |
|
|
345 |
'Training_forward_decoder_and_feedback_system: >%d, ' |
|
|
346 |
'fwd_decoder_loss=%03f, ' |
|
|
347 |
'fwd_decoder_dice=%0.3f ' |
|
|
348 |
|
|
|
349 |
'fdb_loss=%03f, ' |
|
|
350 |
'fdb_dice=%.3f ' % (current_epoch, loss_forward, dice_forward, loss_feedback, dice_feedback) |
|
|
351 |
) |
|
|
352 |
# validation test: |
|
|
353 |
self.validation(current_epoch=current_epoch) |
|
|
354 |
|
|
|
355 |
# CHECK TRAINING STOPPING CRITERIA: maximum number of epochs (epoch - 1), meet early stop |
|
|
356 |
if NetworkTrainer.EARLY_STOP_COUNT == self.config_trainer['num_early_stop']: |
|
|
357 |
# save model with early stop identification |
|
|
358 |
self.model.combine_and_train.save( |
|
|
359 |
'weight/forward_system_early_stopped_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5' |
|
|
360 |
) |
|
|
361 |
self.model.fcn_feedback.save( |
|
|
362 |
'weight/feedback_system_early_stopped_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5' |
|
|
363 |
) |
|
|
364 |
break # STOP TRAINING WITH BREAK, OR EXIT TRAINING |
|
|
365 |
|
|
|
366 |
# if not training load the last saved weights, and check validation |
|
|
367 |
elif self.task == 'valid': |
|
|
368 |
# load the last trained weight in the folder weight |
|
|
369 |
self.load_latest_weight() |
|
|
370 |
self.validation(current_epoch=self.config_trainer['num_epochs']) |
|
|
371 |
|
|
|
372 |
def validation(self, verbose: int = 0, current_epoch: int = None): |
|
|
373 |
""" |
|
|
374 |
Compute the validation dice, loss of the training from the validation data |
|
|
375 |
""" |
|
|
376 |
# path to the validation data, if not specified, the default path ../data/valid/ would be considered |
|
|
377 |
|
|
|
378 |
folder_preprocessed = self.folder_preprocessed_valid |
|
|
379 |
|
|
|
380 |
# image folder names, or identifier: if not specified the default values would be the name of the folder inside |
|
|
381 |
# the directory "folder processed" or the self.folder_processed_valid : |
|
|
382 |
|
|
|
383 |
valid_identifier = self.ids_to_read_valid |
|
|
384 |
|
|
|
385 |
''' |
|
|
386 |
WE CAN IMPLEMENT THE EVALUATION METHOD AS BATCH BASED, PATIENT BASED, OR THE WHOLE-VALIDATION DATA BASED. FOR |
|
|
387 |
THE LAST OPTION WE NEED TO IMPLEMENT THE EVALUATION() FUNCTION HERE. |
|
|
388 |
''' |
|
|
389 |
|
|
|
390 |
'''' |
|
|
391 |
declare variables to return: |
|
|
392 |
forward loss and dice with h0 (no feedback), |
|
|
393 |
feedback network loss and dice |
|
|
394 |
forward decoder loss and dice, |
|
|
395 |
forward loss and dice with ht (with feedback latent space) |
|
|
396 |
''' |
|
|
397 |
loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [], |
|
|
398 |
'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []} |
|
|
399 |
|
|
|
400 |
all_dice_sen_sep = {'dice': [], 'specificity': [], 'sensitivity': []} |
|
|
401 |
|
|
|
402 |
# load the dataset, |
|
|
403 |
# get the validation ids |
|
|
404 |
for id_to_validate in valid_identifier: |
|
|
405 |
try: |
|
|
406 |
id_to_validate = str(id_to_validate).split('.')[0] |
|
|
407 |
except: |
|
|
408 |
pass |
|
|
409 |
|
|
|
410 |
valid_input, valid_output = self.load_dataset(directory_=folder_preprocessed, ids_to_read=[id_to_validate]) |
|
|
411 |
|
|
|
412 |
if len(valid_input) == 0: |
|
|
413 |
print("data %s not read" % id_to_validate) |
|
|
414 |
continue |
|
|
415 |
|
|
|
416 |
results, dice_sen_sep = self.evaluation( |
|
|
417 |
input_image=valid_input.copy(), ground_truth=valid_output.copy(), case_name=str(id_to_validate) |
|
|
418 |
) |
|
|
419 |
|
|
|
420 |
# append all loss to loss and dice to dice from all cases in valid identifiers |
|
|
421 |
for keys in results.keys(): |
|
|
422 |
loss_dice[str(keys)].append(results[str(keys)][0]) |
|
|
423 |
|
|
|
424 |
for keys in dice_sen_sep.keys(): |
|
|
425 |
all_dice_sen_sep[str(keys)].append(dice_sen_sep[str(keys)][0]) |
|
|
426 |
|
|
|
427 |
print("\n Dice, sensitivity, specificity \t") |
|
|
428 |
for k, v in all_dice_sen_sep.items(): |
|
|
429 |
print('%s : %0.3f ' % (k, np.mean(list(v), axis=0)), end=" ") |
|
|
430 |
print("\n") |
|
|
431 |
|
|
|
432 |
""" |
|
|
433 |
print the mean of the validation loss and validation dice |
|
|
434 |
""" |
|
|
435 |
|
|
|
436 |
# FOR STOPPING CRITERIA WE ARE USING THE MODEL AT THE 3RD STEP |
|
|
437 |
dice_mean = np.mean(loss_dice['dice_fwd_ht']) |
|
|
438 |
loss_mean = np.mean(loss_dice['loss_fwd_ht']) |
|
|
439 |
|
|
|
440 |
# at the first epoch |
|
|
441 |
if current_epoch == 0: |
|
|
442 |
NetworkTrainer.BEST_METRIC_VALIDATION = dice_mean |
|
|
443 |
NetworkTrainer.BEST_LOSS_VALIDATION = loss_mean |
|
|
444 |
|
|
|
445 |
# compare the current dice and loss with the previous epoch's loss and dice: |
|
|
446 |
# NOW CONSIDER DICE AS OPTIMIZATION METRIC |
|
|
447 |
print("Current validation loss and metrics at epoch %d: >> " % current_epoch, end=" ") |
|
|
448 |
for k, v in loss_dice.items(): |
|
|
449 |
print('%s : %0.3f ' % (k, np.mean(v)), end=" ") |
|
|
450 |
print("\n") |
|
|
451 |
|
|
|
452 |
if NetworkTrainer.BEST_METRIC_VALIDATION <= dice_mean: |
|
|
453 |
# reset early stop count, best dice, and best loss values |
|
|
454 |
NetworkTrainer.BEST_LOSS_VALIDATION = loss_mean |
|
|
455 |
NetworkTrainer.BEST_METRIC_VALIDATION = dice_mean |
|
|
456 |
NetworkTrainer.EARLY_STOP_COUNT = 0 |
|
|
457 |
|
|
|
458 |
# save the best model weights |
|
|
459 |
if not os.path.exists('./weight'): |
|
|
460 |
os.mkdir('./weight') |
|
|
461 |
|
|
|
462 |
self.model.combine_and_train.save( |
|
|
463 |
'weight/forward_system_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5' |
|
|
464 |
) |
|
|
465 |
self.model.fcn_feedback.save( |
|
|
466 |
'weight/feedback_system_' + str(NetworkTrainer.TRAINED_MODEL_IDENTIFIER) + '_.h5' |
|
|
467 |
) |
|
|
468 |
else: # just print the current validation metric (dice) and loss, and count early stop |
|
|
469 |
# Increase the early stop count per epoch |
|
|
470 |
NetworkTrainer.EARLY_STOP_COUNT += 1 |
|
|
471 |
|
|
|
472 |
print( |
|
|
473 |
'\n Best model on validation data : %0.3f : Dice: %0.3f \n' % ( |
|
|
474 |
NetworkTrainer.BEST_LOSS_VALIDATION, NetworkTrainer.BEST_METRIC_VALIDATION) |
|
|
475 |
) |
|
|
476 |
|
|
|
477 |
def evaluation( |
|
|
478 |
self, verbose: int = 0, input_image: ndarray = None, ground_truth: ndarray = None, |
|
|
479 |
validation_or_test: str = 'test', case_name: str = None |
|
|
480 |
): |
|
|
481 |
""" |
|
|
482 |
|
|
|
483 |
:param case_name: |
|
|
484 |
:param validation_or_test: |
|
|
485 |
:param verbose: |
|
|
486 |
:param input_image: |
|
|
487 |
:param ground_truth: |
|
|
488 |
|
|
|
489 |
Parameters |
|
|
490 |
---------- |
|
|
491 |
save_all |
|
|
492 |
""" |
|
|
493 |
'''' |
|
|
494 |
declare variables to return: |
|
|
495 |
forward loss and dice with h0 (no feedback), |
|
|
496 |
feedback network loss and dice |
|
|
497 |
forward decoder loss and dice, |
|
|
498 |
forward loss and dice with ht (with feedback latent space) |
|
|
499 |
''' |
|
|
500 |
all_loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [], |
|
|
501 |
'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []} |
|
|
502 |
|
|
|
503 |
dice_sen_sp = {'dice': [], 'specificity': [], 'sensitivity': []} |
|
|
504 |
|
|
|
505 |
# latent feedback variable h0 |
|
|
506 |
# replace the first number of batches with the number of input images from the first channel |
|
|
507 |
h0_input = np.zeros( |
|
|
508 |
(len(input_image), int(self.latent_dim[0]), int(self.latent_dim[1]), int(self.latent_dim[2])), np.float32 |
|
|
509 |
) |
|
|
510 |
|
|
|
511 |
# step 0: |
|
|
512 |
# Loss and dice on the validation of the forward system |
|
|
513 |
loss, dice = self.model.combine_and_train.evaluate([input_image, h0_input], [ground_truth], verbose=verbose) |
|
|
514 |
all_loss_dice['loss_fwd_h0'].append(loss), all_loss_dice['dice__fwd_h0'].append(dice) |
|
|
515 |
|
|
|
516 |
# predict from the forward system |
|
|
517 |
predicted = self.model.combine_and_train.predict([input_image, h0_input]) |
|
|
518 |
|
|
|
519 |
# step 2: |
|
|
520 |
# Loss and dice on the validation of the feedback system |
|
|
521 |
loss, dice = self.model.fcn_feedback.evaluate([predicted], [ground_truth], verbose=verbose) |
|
|
522 |
all_loss_dice['loss_fdb_h0'].append(loss), all_loss_dice['dice_fdb_h0'].append(dice) |
|
|
523 |
|
|
|
524 |
# step 3: |
|
|
525 |
feedback_latent = self.model.feedback_latent.predict(predicted) # feedback: hf |
|
|
526 |
forward_encoder_output = self.model.forward_encoder.predict([input_image]) # forward system's encoder output |
|
|
527 |
|
|
|
528 |
forward_encoder_output = forward_encoder_output[::-1] # bottleneck should be first |
|
|
529 |
forward_encoder_output.insert(1, feedback_latent) |
|
|
530 |
loss, dice = self.model.forward_decoder.evaluate( |
|
|
531 |
[output for output in forward_encoder_output], [ground_truth], verbose=verbose |
|
|
532 |
) |
|
|
533 |
all_loss_dice['loss_fwd_decoder'].append(loss), all_loss_dice['dice_fwd_decoder'].append(dice) |
|
|
534 |
|
|
|
535 |
# loss and dice from the combined and feed back latent space : input [input_image, fdb_latent_space] |
|
|
536 |
loss, dice = self.model.combine_and_train.evaluate( |
|
|
537 |
[input_image, feedback_latent], [ground_truth], verbose=verbose |
|
|
538 |
) |
|
|
539 |
all_loss_dice['loss_fwd_ht'].append(loss), all_loss_dice['dice_fwd_ht'].append(dice) |
|
|
540 |
""" |
|
|
541 |
For the testing time, we use defined metrics on the predicted images instead of using model.evaluate during |
|
|
542 |
the validation cases |
|
|
543 |
""" |
|
|
544 |
predicted = self.model.combine_and_train.predict([input_image, feedback_latent]) |
|
|
545 |
|
|
|
546 |
# binary.dc, sen, and specificty works only on binary images |
|
|
547 |
dice_sen_sp['dice'].append( |
|
|
548 |
binary.dc(NetworkTrainer.threshold_image(predicted), NetworkTrainer.threshold_image(ground_truth)) |
|
|
549 |
) |
|
|
550 |
dice_sen_sp['sensitivity'].append( |
|
|
551 |
binary.sensitivity(NetworkTrainer.threshold_image(predicted), NetworkTrainer.threshold_image(ground_truth)) |
|
|
552 |
) |
|
|
553 |
dice_sen_sp['specificity'].append( |
|
|
554 |
binary.specificity(NetworkTrainer.threshold_image(predicted), NetworkTrainer.threshold_image(ground_truth)) |
|
|
555 |
) |
|
|
556 |
# all = np.concatenate((ground_truth, predicted, input_image), axis=0) |
|
|
557 |
# display_image(all) |
|
|
558 |
|
|
|
559 |
# Sometimes save predictions |
|
|
560 |
if self.save_all: |
|
|
561 |
predicted = self.model.combine_and_train.predict([input_image, feedback_latent]) |
|
|
562 |
save_nii_images( |
|
|
563 |
[predicted, ground_truth, input_image], identifier=str(case_name), |
|
|
564 |
name=[case_name + "_predicted", case_name + "_ground_truth", case_name + "_image"], path_save=self.predicted_directory |
|
|
565 |
) |
|
|
566 |
else: |
|
|
567 |
|
|
|
568 |
n = randint(0, 10) |
|
|
569 |
if n % 3 == 0: |
|
|
570 |
predicted = self.model.combine_and_train.predict([input_image, feedback_latent]) |
|
|
571 |
save_nii_images( |
|
|
572 |
[predicted, ground_truth, input_image], identifier=str(case_name), |
|
|
573 |
name=[case_name + "_predicted", case_name + "_ground_truth", case_name + "_image"], path_save=self.predicted_directory |
|
|
574 |
) |
|
|
575 |
|
|
|
576 |
return all_loss_dice, dice_sen_sp |
|
|
577 |
|
|
|
578 |
@staticmethod |
|
|
579 |
def display_image(im_display: ndarray): |
|
|
580 |
""" display given images |
|
|
581 |
|
|
|
582 |
:param all: 2D image arrays to display |
|
|
583 |
:returns: display images |
|
|
584 |
""" |
|
|
585 |
plt.figure(figsize=(10, 8)) |
|
|
586 |
plt.subplots_adjust(hspace=0.5) |
|
|
587 |
plt.suptitle("Daily closing prices", fontsize=18, y=0.95) |
|
|
588 |
# loop through the length of tickers and keep track of index |
|
|
589 |
for n, im in enumerate(im_display): |
|
|
590 |
# add a new subplot iteratively |
|
|
591 |
plt.subplot(3, 2, n + 1) |
|
|
592 |
plt.imshow(im) # chart formatting |
|
|
593 |
plt.show() |
|
|
594 |
|
|
|
595 |
@staticmethod |
|
|
596 |
# binary.dc, sen, and specificty works only on binary images |
|
|
597 |
def threshold_image(im_: ndarray, thr_value: float = 0.5) -> ndarray: |
|
|
598 |
""" threshold given input array with the given thresholding value |
|
|
599 |
|
|
|
600 |
:param im_: ndarray of images |
|
|
601 |
:param thr_value: thresholding value |
|
|
602 |
:return: threshold array image |
|
|
603 |
""" |
|
|
604 |
im_[im_ > thr_value] = 1 |
|
|
605 |
im_[im_ < thr_value] = 0 |
|
|
606 |
return im_ |
|
|
607 |
|
|
|
608 |
|
|
|
609 |
class ModelTesting: |
|
|
610 |
""" performs prediction on a given data set. It predicts the segmentation results, and save the results, calculate |
|
|
611 |
the clinical metrics such as TMTV, Dmax, sTMTV, sDmax. |
|
|
612 |
|
|
|
613 |
""" |
|
|
614 |
now = datetime.now() # current time, date, month, |
|
|
615 |
TRAINED_MODEL_IDENTIFIER = re.sub('[ :]', "_", now.ctime()) |
|
|
616 |
print("current directory", os.getcwd()) |
|
|
617 |
|
|
|
618 |
def __init__( |
|
|
619 |
self, config_test: dict = None, preprocessed_dir: str = '../data/test/', data_list: List[str] = None, |
|
|
620 |
predicted_dir: str = "../data/predicted" |
|
|
621 |
): |
|
|
622 |
""" |
|
|
623 |
|
|
|
624 |
:param config_trainer: |
|
|
625 |
:param folder_preprocessed_train: |
|
|
626 |
:param folder_preprocessed_valid: |
|
|
627 |
:param ids_to_read_train: |
|
|
628 |
:param ids_to_read_valid: |
|
|
629 |
:param task: |
|
|
630 |
:param predicted_dir: |
|
|
631 |
""" |
|
|
632 |
|
|
|
633 |
if config_test is None: |
|
|
634 |
self.config_test = deepcopy(default_training_parameters()) |
|
|
635 |
|
|
|
636 |
# training data |
|
|
637 |
self.preprocessed_dir = preprocessed_dir |
|
|
638 |
self.predicted_dir = predicted_dir |
|
|
639 |
|
|
|
640 |
# if the list of testing cases are not given, get from the directory |
|
|
641 |
if data_list is None: |
|
|
642 |
data_list = os.listdir(preprocessed_dir) |
|
|
643 |
|
|
|
644 |
self.data_list = data_list |
|
|
645 |
|
|
|
646 |
# load the lfb_network architecture |
|
|
647 |
self.model = lfbnet.LfbNet() |
|
|
648 |
|
|
|
649 |
# latent feedback at zero time: means no feedback from feedback network |
|
|
650 |
self.latent_dim = self.model.latent_dim |
|
|
651 |
|
|
|
652 |
# load the last trained weight in the folder weight |
|
|
653 |
print(os.getcwd()) |
|
|
654 |
folder_path = os.path.join(os.getcwd(), 'src/weight') |
|
|
655 |
print(folder_path) |
|
|
656 |
|
|
|
657 |
full_path = [path_i for path_i in glob.glob(str(folder_path) + '/*.h5')] |
|
|
658 |
|
|
|
659 |
print("files \n", full_path) |
|
|
660 |
try: |
|
|
661 |
max_file = max(full_path, key=os.path.getctime) |
|
|
662 |
except: |
|
|
663 |
raise Exception("weight could not found !") |
|
|
664 |
|
|
|
665 |
base_name = str(os.path.basename(max_file)) |
|
|
666 |
print(base_name) |
|
|
667 |
self.model.combine_and_train.load_weights( |
|
|
668 |
str(folder_path) + '/forward_system' + str(base_name.split('system')[1]) |
|
|
669 |
) |
|
|
670 |
# f |
|
|
671 |
self.model.fcn_feedback.load_weights(str(folder_path) + '/feedback_system' + str(base_name.split('system')[1])) |
|
|
672 |
|
|
|
673 |
self.test() |
|
|
674 |
|
|
|
675 |
def test(self): |
|
|
676 |
""" |
|
|
677 |
Compute the validation dice, loss of the training from the validation data |
|
|
678 |
""" |
|
|
679 |
# path to the validation data, if not specified, the default path ../data/valid/ would be considered |
|
|
680 |
# |
|
|
681 |
folder_preprocessed = self.preprocessed_dir |
|
|
682 |
# image folder names, or identifier: if not specified the default values would be the name of the folder inside |
|
|
683 |
# the directory "folder processed" or the self.folder_processed_valid : |
|
|
684 |
test_identifier = self.data_list |
|
|
685 |
|
|
|
686 |
'''' |
|
|
687 |
declare variables to return if there is a reference segmentation or ground truth : |
|
|
688 |
forward loss and dice with h0 (no feedback), |
|
|
689 |
feedback network loss and dice |
|
|
690 |
forward decoder loss and dice, |
|
|
691 |
forward loss and dice with ht (with feedback latent space) |
|
|
692 |
''' |
|
|
693 |
loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [], |
|
|
694 |
'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []} |
|
|
695 |
|
|
|
696 |
# get the validation ids |
|
|
697 |
test_output = [] |
|
|
698 |
for id_to_test in tqdm(list(test_identifier)): |
|
|
699 |
test_input, test_output = NetworkTrainer.load_dataset( |
|
|
700 |
directory_=folder_preprocessed, ids_to_read=[id_to_test] |
|
|
701 |
) |
|
|
702 |
|
|
|
703 |
if len(test_input) == 0: |
|
|
704 |
print("data %s not read" % id_to_test) |
|
|
705 |
continue |
|
|
706 |
|
|
|
707 |
''' |
|
|
708 |
if there is a ground truth segmentation (gt), and you would like to compare with the predicted segmentation |
|
|
709 |
by the deep learning model |
|
|
710 |
''' |
|
|
711 |
|
|
|
712 |
if len(test_output): |
|
|
713 |
results = self.evaluation_test( |
|
|
714 |
input_image=test_input.copy(), ground_truth=test_output.copy(), case_name=str(id_to_test) |
|
|
715 |
) |
|
|
716 |
|
|
|
717 |
# append all loss to loss and dice to dice from all cases in valid identifiers |
|
|
718 |
for keys in results.keys(): |
|
|
719 |
loss_dice[str(keys)].append(results[str(keys)][0]) |
|
|
720 |
|
|
|
721 |
print("Results (sagittal and coronal) for case id: %s : >> " % id_to_test, end=" ") |
|
|
722 |
for k, v in loss_dice.items(): |
|
|
723 |
print('%s : %0.3f ' % (k, np.mean(v)), end=" ") |
|
|
724 |
print("\n") |
|
|
725 |
|
|
|
726 |
# Predict the segmentation and save in the folder predicted, dataset identifier |
|
|
727 |
else: |
|
|
728 |
self.prediction(input_image=test_input.copy(), case_name=str(id_to_test)) |
|
|
729 |
|
|
|
730 |
""" |
|
|
731 |
print the mean of the testing loss and dice if there is a ground truth, for all cases |
|
|
732 |
""" |
|
|
733 |
if len(test_output): |
|
|
734 |
print("Total dataset metrics: : >> ", end=" ") |
|
|
735 |
for k, v in loss_dice.items(): |
|
|
736 |
print('%s : %0.3f ' % (k, np.mean(v)), end=" ") |
|
|
737 |
print("\n") |
|
|
738 |
|
|
|
739 |
def evaluation_test( |
|
|
740 |
self, verbose: int = 0, input_image: ndarray = None, ground_truth: ndarray = None, |
|
|
741 |
validation_or_test: str = 'validate', case_name: str = None |
|
|
742 |
): |
|
|
743 |
""" |
|
|
744 |
|
|
|
745 |
:param case_name: |
|
|
746 |
:param validation_or_test: |
|
|
747 |
:param verbose: |
|
|
748 |
:param input_image: |
|
|
749 |
:param ground_truth: |
|
|
750 |
""" |
|
|
751 |
'''' |
|
|
752 |
declare variables to return: |
|
|
753 |
forward loss and dice with h0 (no feedback), |
|
|
754 |
feedback network loss and dice |
|
|
755 |
forward decoder loss and dice, |
|
|
756 |
forward loss and dice with ht (with feedback latent space) |
|
|
757 |
''' |
|
|
758 |
all_loss_dice = {'loss_fwd_h0': [], 'dice__fwd_h0': [], 'loss_fdb_h0': [], 'dice_fdb_h0': [], |
|
|
759 |
'loss_fwd_decoder': [], 'dice_fwd_decoder': [], 'loss_fwd_ht': [], 'dice_fwd_ht': []} |
|
|
760 |
# latent feedback variable h0 |
|
|
761 |
# replace the first number of batches with the number of input images from the first channel |
|
|
762 |
h0_input = np.zeros( |
|
|
763 |
(len(input_image), int(self.latent_dim[0]), int(self.latent_dim[1]), int(self.latent_dim[2])), np.float32 |
|
|
764 |
) |
|
|
765 |
|
|
|
766 |
# step 0: |
|
|
767 |
# Loss and dice on the validation of the forward system |
|
|
768 |
loss, dice = self.model.combine_and_train.evaluate([input_image, h0_input], [ground_truth], verbose=verbose) |
|
|
769 |
all_loss_dice['loss_fwd_h0'].append(loss), all_loss_dice['dice__fwd_h0'].append(dice) |
|
|
770 |
|
|
|
771 |
# predict from the forward system |
|
|
772 |
predicted = self.model.combine_and_train.predict([input_image, h0_input]) |
|
|
773 |
|
|
|
774 |
# step 2: |
|
|
775 |
# Loss and dice on the validation of the feedback system |
|
|
776 |
loss, dice = self.model.fcn_feedback.evaluate([predicted], [ground_truth], verbose=verbose) |
|
|
777 |
all_loss_dice['loss_fdb_h0'].append(loss), all_loss_dice['dice_fdb_h0'].append(dice) |
|
|
778 |
|
|
|
779 |
# step 3: |
|
|
780 |
feedback_latent = self.model.feedback_latent.predict(predicted) # feedback: hf |
|
|
781 |
forward_encoder_output = self.model.forward_encoder.predict([input_image]) # forward system's encoder output |
|
|
782 |
|
|
|
783 |
forward_encoder_output = forward_encoder_output[::-1] # bottleneck should be first |
|
|
784 |
forward_encoder_output.insert(1, feedback_latent) |
|
|
785 |
loss, dice = self.model.forward_decoder.evaluate( |
|
|
786 |
[output for output in forward_encoder_output], [ground_truth], verbose=verbose |
|
|
787 |
) |
|
|
788 |
all_loss_dice['loss_fwd_decoder'].append(loss), all_loss_dice['dice_fwd_decoder'].append(dice) |
|
|
789 |
|
|
|
790 |
# loss and dice from the combined and feed back latent space : input [input_image, fdb_latent_space] |
|
|
791 |
loss, dice = self.model.combine_and_train.evaluate( |
|
|
792 |
[input_image, feedback_latent], [ground_truth], verbose=verbose |
|
|
793 |
) |
|
|
794 |
all_loss_dice['loss_fwd_ht'].append(loss), all_loss_dice['dice_fwd_ht'].append(dice) |
|
|
795 |
|
|
|
796 |
""" |
|
|
797 |
For the testing time, we use defined metrics on the predicted images instead of using model.evaluate during |
|
|
798 |
the validation cases |
|
|
799 |
""" |
|
|
800 |
if validation_or_test == "test": |
|
|
801 |
# return [dice, specificity, and sensitivity |
|
|
802 |
return {'dice': binary.dc(predicted, ground_truth), |
|
|
803 |
'specificity': binary.specificity(predicted, ground_truth), |
|
|
804 |
'sensitivity': binary.sensitivity(predicted, ground_truth)} |
|
|
805 |
|
|
|
806 |
predicted = self.model.combine_and_train.predict([input_image, feedback_latent]) |
|
|
807 |
predicted = remove_outliers_in_sagittal(predicted) |
|
|
808 |
save_nii_images( |
|
|
809 |
[predicted, ground_truth, input_image], identifier=str(case_name), |
|
|
810 |
name=[case_name + "_predicted", case_name + "_ground_truth", case_name + "_pet"], |
|
|
811 |
path_save= os.path.join(str(self.predicted_dir), 'predicted_data') |
|
|
812 |
) |
|
|
813 |
|
|
|
814 |
return all_loss_dice |
|
|
815 |
|
|
|
816 |
def prediction(self, input_image: ndarray = None, case_name: str = None): |
|
|
817 |
""" |
|
|
818 |
:param case_name: |
|
|
819 |
:param input_image: |
|
|
820 |
""" |
|
|
821 |
# latent feedback variable h0 |
|
|
822 |
# replace the first number of batches with the number of input images from the first channel |
|
|
823 |
h0_input = np.zeros( |
|
|
824 |
(len(input_image), int(self.latent_dim[0]), int(self.latent_dim[1]), int(self.latent_dim[2])), np.float32 |
|
|
825 |
) |
|
|
826 |
|
|
|
827 |
# STEP 1: forward system prediction |
|
|
828 |
# predict from the forward system |
|
|
829 |
predicted = self.model.combine_and_train.predict([input_image, h0_input]) |
|
|
830 |
|
|
|
831 |
# step 2: Feedback system prediction |
|
|
832 |
feedback_latent = self.model.feedback_latent.predict(predicted) # feedback: hf |
|
|
833 |
|
|
|
834 |
predicted = self.model.combine_and_train.predict([input_image, feedback_latent]) |
|
|
835 |
predicted = remove_outliers_in_sagittal(predicted) |
|
|
836 |
save_nii_images( |
|
|
837 |
image=[predicted, input_image], identifier=str(case_name), name=[case_name + "_predicted", |
|
|
838 |
case_name + "_pet"], |
|
|
839 |
path_save= os.path.join(str(self.predicted_dir), 'predicted_data') |
|
|
840 |
) |
|
|
841 |
|
|
|
842 |
|
|
|
843 |
if __name__ == '__main__': |
|
|
844 |
train_valid_data_dir = r"E:\LFBNet\data\remarc_default_MIP_dir/" |
|
|
845 |
train_valid_ids_path_csv = r'E:\LFBNet\data\csv\training_validation_indexs\remarc/' |
|
|
846 |
train_ids, valid_ids = get_training_and_validation_ids_from_csv(train_valid_ids_path_csv) |
|
|
847 |
|
|
|
848 |
trainer = NetworkTrainer( |
|
|
849 |
folder_preprocessed_train=train_valid_data_dir, folder_preprocessed_valid=train_valid_data_dir, |
|
|
850 |
ids_to_read_train=train_ids, ids_to_read_valid=valid_ids |
|
|
851 |
) |
|
|
852 |
trainer.train() |