--- /dev/null
+++ b/fetal/config_utils.py
@@ -0,0 +1,213 @@
+import argparse
+import json
+import os
+from pathlib import Path
+
+
+def get_config():
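+    """Parse command-line arguments and return the experiment config.
+
+    Loads <config_dir>/config.json when it exists (unless --overwrite_config
+    is passed); otherwise builds a fresh config and writes it there.
+    """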
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--overwrite_config", help="overwrite saved config",
+                        action="store_true")
+    parser.add_argument("--config_dir", help="specifies config dir path",
+                        type=str, required=True)
+    parser.add_argument("--split_dir", help="specifies config dir path",
+                        type=str, required=False)
+    opts = parser.parse_args()
+    config_path = os.path.join(opts.config_dir, 'config.json')
+    # Load the previous config if it exists
+    if Path(config_path).exists() and not opts.overwrite_config:
+        print('Loading previous config.json from {}'.format(opts.config_dir))
+        with open(config_path) as f:
+            config = json.load(f)
+    else:
+        config = dict()
+        config["base_dir"] = opts.config_dir
+        config["split_dir"] = './debug_split'
+        config['scans_dir'] = '../../Datasets/brain_new_cutted_window_1_99'
+        config['fake_scans_dir'] = '../../Datasets/brain_new_cutted_window_1_99'
+
+        Path(config["base_dir"]).mkdir(parents=True, exist_ok=True)
+        Path(config["split_dir"]).mkdir(parents=True, exist_ok=True)
+
+        # Training params
+        config["batch_size"] = 1
+        config["validation_batch_size"] = 1  # most of times should be equal to "batch_size"
+        config["patches_per_epoch"] = 800  # patches_per_epoch / batch_size = steps per epoch
+
+        config["n_epochs"] = 50  # cutoff the training after this many epochs
+        config["patience"] = 3  # learning rate will be reduced after this many epochs if the validation loss is not improving
+        config["early_stop"] = 7  # training will be stopped after this many epochs without the validation loss improving
+        config["initial_learning_rate"] = 1e-4
+        config["learning_rate_drop"] = 0.5  # factor by which the learning rate will be reduced
+        config["validation_split"] = 0.90  # portion of the data that will be used for training %
+
+        config["3D"] = False  # Enable for 3D Models
+        if config["3D"]:
+            # Model params (3D)
+            config["patch_shape"] = (16, 16)  # switch to None to train on the whole image
+            config["patch_depth"] = 16
+            config["truth_index"] = 0
+            config["truth_size"] = 16
+            model_name = 'isensee'  # or 'unet'
+        else:
+            # Model params (2D) - should increase "batch_size" and "patches_per_epoch"
+            config["patch_shape"] = (64, 64)  # switch to None to train on the whole image
+            config["patch_depth"] = 5
+            config["truth_index"] = 2
+            config["truth_size"] = 1
+            model_name = 'unet'  # or 'isensee'
+
+        # choose model
+        config["model_name"] = {
+            '3D': {
+                'unet': 'unet_model_3d',
+                'isensee': 'isensee2017_model_3d'
+            },
+            '2D': {
+                'unet': 'unet_model_2d',
+                'isensee': 'isensee2017_model'
+            }
+        }['3D' if config["3D"] else '2D'][model_name]
+        config["model_name"] = 'dis_net'
+
+        # choose loss
+        config["loss"] = {
+            0: 'binary_crossentropy_loss',
+            1: 'dice_coefficient_loss',
+            2: 'focal_loss',
+            3: 'dice_and_xent',
+            4: 'dice_and_xent_mask'
+        }[1]  # 1 -> 'dice_coefficient_loss'
+
+        config["augment"] = {
+            "flip": [0.5, 0.5, 0.5],  # augments the data by randomly flipping an axis during
+            "permute": False,
+            # NOT SUPPORTED (data shape must be a cube. Augments the data by permuting in various directions)
+            "translate": (15, 15, 7),  #
+            "scale": (0.1, 0.1, 0),  # i.e 0.20 for 20%, std of scaling factor, switch to None if you want no distortion
+            # "iso_scale": {
+            #     "max": 1
+            # },
+            "rotate": (0, 0, 90),  # std of angle rotation, switch to None if you want no rotation
+            "poisson_noise": 1,
+            "gaussian_filter": {
+                "prob": 0.0,
+                "max_sigma": 1
+            },
+            "contrast": {
+                'prob': 0,
+                'min_factor': 0.1,
+                'max_factor': 0.2
+            },
+            # "piecewise_affine": {
+            #     'scale': 2
+            # },
+            "elastic_transform": {
+                'alpha': 5,
+                'sigma': 10
+            },
+            # "intensity_multiplication": 0.2,
+            "coarse_dropout": {
+                "rate": 0.2,
+                "size_percent": [0.10, 0.30],
+                "per_channel": True
+            },
+            "gaussian_noise": {
+                "prob": 0.5,
+                "sigma": 0.05
+            },
+            "speckle_noise": {
+                "prob": 0.5,
+                "sigma": 0.05
+            }
+        }
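+        # (entries set to False/0/None are presumably skipped by the data
+        # generator; if every entry is falsy, "augment" collapses to None below)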
+
+        # Used when the model's output is smaller (x,y)-wise than its input
+        config["truth_downsample"] = None  # factor by which to downsample the ground truth
+        config["truth_crop"] = False  # if True, crop the truth to the target size; otherwise resize it
+        config["categorical"] = False  # one-hot encode the target
+
+        # Relevant only for previous slice truth training
+        config["prev_truth_index"] = None  # None for regular training
+        config["prev_truth_size"] = None  # None for regular training
+
+        config["labels"] = (1,)  # the label numbers on the input image - currently only 1 label supported
+
+        config["skip_blank_train"] = False  # if True, then patches without any target will be skipped
+        config["skip_blank_val"] = False  # if True, then patches without any target will be skipped
+        config["drop_easy_patches_train"] = False  # will randomly prefer balanced patches (50% 1, 50% 0)
+        config["drop_easy_patches_val"] = False  # will randomly prefer balanced patches (50% 1, 50% 0)
+
+        # Data normalization: False, 'all' (one mean/std over the whole dataset)
+        # or 'each' (per-scan mean/std)
+        config['normalization'] = {
+            0: False,
+            1: 'all',
+            2: 'each'
+        }[1]
+
+        # add ".gz" extension if needed
+        config["ext"] = ".gz"
+
+        # Not relevant at the moment...
+        config["dropout_rate"] = 0
+
+        # Weight masks (currently supported only with the isensee3d model and the dice_and_xent_weighted loss)
+        config["weight_mask"] = None  # ["dists"] # or []
+
+        # Auto set - do not touch
+        config["augment"] = config["augment"] if any(config["augment"].values()) else None
+        config["n_labels"] = len(config["labels"])
+        config["all_modalities"] = ["volume"]
+        config["training_modalities"] = config[
+            "all_modalities"]  # change this if you want to only use some of the modalities
+        config["nb_channels"] = len(config["training_modalities"])
+        config["input_shape"] = tuple(list(config["patch_shape"]) +
+                                      [config["patch_depth"] + (
+                                          config["prev_truth_size"] if config["prev_truth_index"] is not None else 0)])
+        config["truth_channel"] = config["nb_channels"]
+        # Auto set - do not touch
+        config["data_file"] = os.path.join(config["base_dir"], "fetal_data.h5")
+        config["model_file"] = os.path.join(config["base_dir"], "fetal_net_model")
+        config["training_file"] = os.path.join(config["split_dir"], "training_ids.pkl")
+        config["validation_file"] = os.path.join(config["split_dir"], "validation_ids.pkl")
+        config["test_file"] = os.path.join(config["split_dir"], "test_ids.pkl")
+        config["overwrite"] = False  # If True, will previous files. If False, will use previously written files.
+        config["scale_data"] = (0.33, 0.33, 1)
+
+        config["preproc"] = {
+            0: "laplace",
+            1: "laplace_norm",
+            2: "grad",
+            3: "grad_norm"
+        }[1]
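+        # NOTE: the assignment below disables preprocessing, overriding the selection above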
+        config["preproc"] = None
+
+        if config['3D']:
+            config["input_shape"] = [1] + list(config["input_shape"])
+
+        # relevant only to NormNet
+        config["old_model"] = '/home/galdude33/Lab/workspace/fetal_envelope2/brats/debug_normnet/old_model.h5'
+
+        with open(os.path.join(config["base_dir"], 'config.json'), mode='w') as f:
+            json.dump(config, f, indent=2)
+
+    return config
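+
+
+if __name__ == '__main__':
+    # Minimal sketch for inspecting the generated config from the shell
+    # (hypothetical entry point; training scripts would import get_config() instead):
+    #   python fetal/config_utils.py --config_dir ./debug_run
+    from pprint import pprint
+    pprint(get_config())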