--- a
+++ b/fetal/train_fetal.py
@@ -0,0 +1,90 @@
+import json
+import os
+import glob
+
+import fetal_net
+import fetal_net.preprocess
+import fetal_net.metrics
+from fetal.config_utils import get_config
+from fetal.utils import get_last_model_path, fetch_training_data_files, create_data_file
+from fetal_net.data import write_data_to_file, open_data_file
+from fetal_net.generator import get_training_and_validation_generators
+from fetal_net.model.fetal_net import fetal_envelope_model
+from fetal_net.training import load_old_model, train_model
+
+config = get_config()  # module-level settings mapping, read by main() and the script entry point
+
+
+def main(overwrite=False):
+    """Create/open the HDF5 data file, build or resume a model, and train it.
+
+    All settings are read from the module-level ``config``.  ``overwrite``
+    regenerates the data file, skips resuming from the latest
+    ``config["model_file"]*.h5`` checkpoint, and is forwarded to the generators.
+    """
+    # convert input images into an hdf5 file (only if missing or forced)
+    if overwrite or not os.path.exists(config["data_file"]):
+        create_data_file(config)
+
+    data_file_opened = open_data_file(config["data_file"])
+    try:  # close the data file even if model building or training fails
+        # Resume only when allowed and at least one checkpoint exists.
+        resume = not overwrite and len(glob.glob(config["model_file"] + '*.h5')) > 0
+        if resume:
+            model_path = get_last_model_path(config["model_file"])
+            print('Loading model from: {}'.format(model_path))
+            model = load_old_model(model_path)
+        else:
+            # instantiate new model; loss and architecture are looked up by name
+            loss_func = getattr(fetal_net.metrics, config['loss'])
+            model_func = getattr(fetal_net.model, config['model_name'])
+            model = model_func(input_shape=config["input_shape"],
+                               initial_learning_rate=config["initial_learning_rate"],
+                               dropout_rate=config['dropout_rate'],
+                               loss_function=loss_func,
+                               mask_shape=None if config["weight_mask"] is None else config["input_shape"],
+                               # TODO: change mask_shape to the output shape
+                               old_model_path=config['old_model'])
+        model.summary()
+        # get training and validation generators
+        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
+            data_file_opened,
+            batch_size=config["batch_size"],
+            data_split=config["validation_split"],
+            overwrite=overwrite,
+            validation_keys_file=config["validation_file"],
+            training_keys_file=config["training_file"],
+            test_keys_file=config["test_file"],
+            n_labels=config["n_labels"], labels=config["labels"],
+            patch_shape=(*config["patch_shape"], config["patch_depth"]),
+            validation_batch_size=config["validation_batch_size"],
+            augment=config["augment"],
+            skip_blank_train=config["skip_blank_train"], skip_blank_val=config["skip_blank_val"],
+            truth_index=config["truth_index"], truth_size=config["truth_size"],
+            prev_truth_index=config["prev_truth_index"], prev_truth_size=config["prev_truth_size"],
+            truth_downsample=config["truth_downsample"],
+            truth_crop=config["truth_crop"],
+            patches_per_epoch=config["patches_per_epoch"],
+            categorical=config["categorical"], is3d=config["3D"],
+            drop_easy_patches_train=config["drop_easy_patches_train"],
+            drop_easy_patches_val=config["drop_easy_patches_val"])
+
+        # run training
+        train_model(model=model,
+                    model_file=config["model_file"],
+                    training_generator=train_generator,
+                    validation_generator=validation_generator,
+                    steps_per_epoch=n_train_steps,
+                    validation_steps=n_validation_steps,
+                    initial_learning_rate=config["initial_learning_rate"],
+                    learning_rate_drop=config["learning_rate_drop"],
+                    learning_rate_patience=config["patience"],
+                    early_stopping_patience=config["early_stop"],
+                    n_epochs=config["n_epochs"],
+                    output_folder=config["base_dir"])
+    finally:
+        data_file_opened.close()
+
+
+if __name__ == "__main__":  # script entry point; config's "overwrite" flag controls reuse of existing files
+    main(config["overwrite"])