|
a |
|
b/fetal/train_fetal.py |
|
|
1 |
import json |
|
|
2 |
import os |
|
|
3 |
import glob |
|
|
4 |
|
|
|
5 |
import fetal_net |
|
|
6 |
import fetal_net.preprocess |
|
|
7 |
import fetal_net.metrics |
|
|
8 |
from fetal.config_utils import get_config |
|
|
9 |
from fetal.utils import get_last_model_path, fetch_training_data_files, create_data_file |
|
|
10 |
from fetal_net.data import write_data_to_file, open_data_file |
|
|
11 |
from fetal_net.generator import get_training_and_validation_generators |
|
|
12 |
from fetal_net.model.fetal_net import fetal_envelope_model |
|
|
13 |
from fetal_net.training import load_old_model, train_model |
|
|
14 |
|
|
|
15 |
config = get_config() |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def main(overwrite=False): |
|
|
19 |
# convert input images into an hdf5 file |
|
|
20 |
if overwrite or not os.path.exists(config["data_file"]): |
|
|
21 |
create_data_file(config) |
|
|
22 |
|
|
|
23 |
data_file_opened = open_data_file(config["data_file"]) |
|
|
24 |
|
|
|
25 |
if not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0: |
|
|
26 |
model_path = get_last_model_path(config["model_file"]) |
|
|
27 |
print('Loading model from: {}'.format(model_path)) |
|
|
28 |
model = load_old_model(model_path) |
|
|
29 |
else: |
|
|
30 |
# instantiate new model |
|
|
31 |
loss_func = getattr(fetal_net.metrics, config['loss']) |
|
|
32 |
model_func = getattr(fetal_net.model, config['model_name']) |
|
|
33 |
model = model_func(input_shape=config["input_shape"], |
|
|
34 |
initial_learning_rate=config["initial_learning_rate"], |
|
|
35 |
**{'dropout_rate': config['dropout_rate'], |
|
|
36 |
'loss_function': loss_func, |
|
|
37 |
'mask_shape': None if config["weight_mask"] is None else config["input_shape"], |
|
|
38 |
# TODO: change to output shape |
|
|
39 |
'old_model_path': config['old_model']}) |
|
|
40 |
if not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0: |
|
|
41 |
model_path = get_last_model_path(config["model_file"]) |
|
|
42 |
print('Loading model from: {}'.format(model_path)) |
|
|
43 |
model.load_weights(model_path) |
|
|
44 |
model.summary() |
|
|
45 |
|
|
|
46 |
# get training and testing generators |
|
|
47 |
train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators( |
|
|
48 |
data_file_opened, |
|
|
49 |
batch_size=config["batch_size"], |
|
|
50 |
data_split=config["validation_split"], |
|
|
51 |
overwrite=overwrite, |
|
|
52 |
validation_keys_file=config["validation_file"], |
|
|
53 |
training_keys_file=config["training_file"], |
|
|
54 |
test_keys_file=config["test_file"], |
|
|
55 |
n_labels=config["n_labels"], |
|
|
56 |
labels=config["labels"], |
|
|
57 |
patch_shape=(*config["patch_shape"], config["patch_depth"]), |
|
|
58 |
validation_batch_size=config["validation_batch_size"], |
|
|
59 |
augment=config["augment"], |
|
|
60 |
skip_blank_train=config["skip_blank_train"], |
|
|
61 |
skip_blank_val=config["skip_blank_val"], |
|
|
62 |
truth_index=config["truth_index"], |
|
|
63 |
truth_size=config["truth_size"], |
|
|
64 |
prev_truth_index=config["prev_truth_index"], |
|
|
65 |
prev_truth_size=config["prev_truth_size"], |
|
|
66 |
truth_downsample=config["truth_downsample"], |
|
|
67 |
truth_crop=config["truth_crop"], |
|
|
68 |
patches_per_epoch=config["patches_per_epoch"], |
|
|
69 |
categorical=config["categorical"], is3d=config["3D"], |
|
|
70 |
drop_easy_patches_train=config["drop_easy_patches_train"], |
|
|
71 |
drop_easy_patches_val=config["drop_easy_patches_val"]) |
|
|
72 |
|
|
|
73 |
# run training |
|
|
74 |
train_model(model=model, |
|
|
75 |
model_file=config["model_file"], |
|
|
76 |
training_generator=train_generator, |
|
|
77 |
validation_generator=validation_generator, |
|
|
78 |
steps_per_epoch=n_train_steps, |
|
|
79 |
validation_steps=n_validation_steps, |
|
|
80 |
initial_learning_rate=config["initial_learning_rate"], |
|
|
81 |
learning_rate_drop=config["learning_rate_drop"], |
|
|
82 |
learning_rate_patience=config["patience"], |
|
|
83 |
early_stopping_patience=config["early_stop"], |
|
|
84 |
n_epochs=config["n_epochs"], |
|
|
85 |
output_folder=config["base_dir"]) |
|
|
86 |
data_file_opened.close() |
|
|
87 |
|
|
|
88 |
|
|
|
89 |
if __name__ == "__main__": |
|
|
90 |
main(overwrite=config["overwrite"]) |