Diff of /fetal/train_fetal.py [000000] .. [ccb1dd]

Switch to unified view

a b/fetal/train_fetal.py
1
import json
2
import os
3
import glob
4
5
import fetal_net
6
import fetal_net.preprocess
7
import fetal_net.metrics
8
from fetal.config_utils import get_config
9
from fetal.utils import get_last_model_path, fetch_training_data_files, create_data_file
10
from fetal_net.data import write_data_to_file, open_data_file
11
from fetal_net.generator import get_training_and_validation_generators
12
from fetal_net.model.fetal_net import fetal_envelope_model
13
from fetal_net.training import load_old_model, train_model
14
15
config = get_config()
16
17
18
def main(overwrite=False):
19
    # convert input images into an hdf5 file
20
    if overwrite or not os.path.exists(config["data_file"]):
21
        create_data_file(config)
22
23
    data_file_opened = open_data_file(config["data_file"])
24
25
    if not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0:
26
        model_path = get_last_model_path(config["model_file"])
27
        print('Loading model from: {}'.format(model_path))
28
        model = load_old_model(model_path)
29
    else:
30
        # instantiate new model
31
        loss_func = getattr(fetal_net.metrics, config['loss'])
32
        model_func = getattr(fetal_net.model, config['model_name'])
33
        model = model_func(input_shape=config["input_shape"],
34
                           initial_learning_rate=config["initial_learning_rate"],
35
                           **{'dropout_rate': config['dropout_rate'],
36
                              'loss_function': loss_func,
37
                              'mask_shape': None if config["weight_mask"] is None else config["input_shape"],
38
                              # TODO: change to output shape
39
                              'old_model_path': config['old_model']})
40
        if not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0:
41
            model_path = get_last_model_path(config["model_file"])
42
            print('Loading model from: {}'.format(model_path))
43
            model.load_weights(model_path)
44
    model.summary()
45
46
    # get training and testing generators
47
    train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
48
        data_file_opened,
49
        batch_size=config["batch_size"],
50
        data_split=config["validation_split"],
51
        overwrite=overwrite,
52
        validation_keys_file=config["validation_file"],
53
        training_keys_file=config["training_file"],
54
        test_keys_file=config["test_file"],
55
        n_labels=config["n_labels"],
56
        labels=config["labels"],
57
        patch_shape=(*config["patch_shape"], config["patch_depth"]),
58
        validation_batch_size=config["validation_batch_size"],
59
        augment=config["augment"],
60
        skip_blank_train=config["skip_blank_train"],
61
        skip_blank_val=config["skip_blank_val"],
62
        truth_index=config["truth_index"],
63
        truth_size=config["truth_size"],
64
        prev_truth_index=config["prev_truth_index"],
65
        prev_truth_size=config["prev_truth_size"],
66
        truth_downsample=config["truth_downsample"],
67
        truth_crop=config["truth_crop"],
68
        patches_per_epoch=config["patches_per_epoch"],
69
        categorical=config["categorical"], is3d=config["3D"],
70
        drop_easy_patches_train=config["drop_easy_patches_train"],
71
        drop_easy_patches_val=config["drop_easy_patches_val"])
72
73
    # run training
74
    train_model(model=model,
75
                model_file=config["model_file"],
76
                training_generator=train_generator,
77
                validation_generator=validation_generator,
78
                steps_per_epoch=n_train_steps,
79
                validation_steps=n_validation_steps,
80
                initial_learning_rate=config["initial_learning_rate"],
81
                learning_rate_drop=config["learning_rate_drop"],
82
                learning_rate_patience=config["patience"],
83
                early_stopping_patience=config["early_stop"],
84
                n_epochs=config["n_epochs"],
85
                output_folder=config["base_dir"])
86
    data_file_opened.close()
87
88
89
if __name__ == "__main__":
90
    main(overwrite=config["overwrite"])