--- a
+++ b/adpkd_segmentation/train.py
@@ -0,0 +1,359 @@
+"""
+Model train script
+
+python -m adpkd_segmentation.train --config path_to_config_yaml --makelinks # noqa
+
+If using a specific GPU (e.g. device 2):
+CUDA_VISIBLE_DEVICES=2 python -m adpkd_segmentation.train --config path_to_config_yaml --makelinks
+
+The makelinks flag is needed only once to create symbolic links to the data.
+
+To create and activate the conda environment for this repository:
+conda env create --file adpkd-segmentation.yml
+conda activate adpkd-segmentation
+"""
+
+# %%
+import argparse
+import json
+import numpy as np
+import os
+import yaml
+import random
+from collections import OrderedDict
+
+from matplotlib import pyplot as plt
+import torch
+import torch.nn as nn
+from torch.utils.tensorboard import SummaryWriter
+from catalyst.contrib.nn import Lookahead
+
+from adpkd_segmentation.create_eval_configs import create_config
+from adpkd_segmentation.config.config_utils import get_object_instance
+from adpkd_segmentation.data.link_data import makelinks
+from adpkd_segmentation.evaluate import validate
+from adpkd_segmentation.utils.train_utils import (
+    load_model_data,
+    save_model_data,
+)
+from adpkd_segmentation.data.data_utils import tensor_dict_to_device
+
+CHECKPOINTS = "checkpoints"
+RESULTS = "results"
+TB_LOGS = "tb_logs"
+
+
+# %%
def get_losses_str(losses_and_metrics, tensors=True):
    """Format a metrics mapping as a ``name:value`` comma-separated string.

    When ``tensors`` is True, each value is a 0-dim tensor and is converted
    with ``.item()`` before formatting.
    """
    if tensors:
        pairs = ((name, value.item()) for name, value in losses_and_metrics.items())
    else:
        pairs = losses_and_metrics.items()
    return ", ".join("{}:{:.5f}".format(name, value) for name, value in pairs)
+
+
+# %%
def is_better(current, previous, metric_type):
    """Return True if ``current`` improves on ``previous``.

    Parameters
    ----------
    metric_type : str
        ``"low"`` if lower values are better (e.g. loss), ``"high"`` if
        higher values are better (e.g. Dice score).

    Raises
    ------
    ValueError
        If ``metric_type`` is neither ``"low"`` nor ``"high"``.
        (Previously a bare ``Exception`` with no detail; ``ValueError`` is
        more precise and still caught by any existing ``except Exception``.)
    """
    if metric_type == "low":
        return current < previous
    if metric_type == "high":
        return current > previous
    raise ValueError("unknown metric_type: {!r}".format(metric_type))
+
+
+# %%
def save_val_metrics(metrics, results_dir, epoch, global_step):
    """Write validation metrics as pretty-printed JSON into ``results_dir``.

    The filename encodes the epoch and global step:
    ``val_results_ep_<epoch>_gs_<global_step>.json``.
    """
    out_path = "{}/val_results_ep_{}_gs_{}.json".format(
        results_dir, epoch, global_step
    )
    with open(out_path, "w") as fp:
        json.dump(metrics, fp, indent=4)
+
+
+# %%
def tb_log_metrics(writer, metrics, global_step):
    """Log every entry of ``metrics`` to TensorBoard at ``global_step``."""
    for name, value in metrics.items():
        writer.add_scalar(name, value, global_step)
+
+
+# %%
def plot_image_from_batch(
    writer, batch, prediction, target, global_step, idx=0
):
    """Log one example's input, target, and prediction to TensorBoard.

    Each channel of the (C, H, W) tensors is unsqueezed to (C, 1, H, W) so
    that TensorBoard renders every channel as a separate grayscale image.
    """
    image = batch[idx].unsqueeze(1)
    mask = target[idx].unsqueeze(1)
    pred_mask = prediction[idx].unsqueeze(1)
    for tag, tensor in (
        ("input_images", image),
        ("mask_channels", mask),
        ("prediction_channels", pred_mask),
    ):
        writer.add_images(tag, tensor, global_step, dataformats="NCHW")
+
+
+# %%
def plot_fig_from_batch(
    writer,
    batch,
    prediction,
    target,
    global_step,
    idx=0,
    title="fig: img_target_pred",
):
    """Log a 3-panel matplotlib figure to TensorBoard: the input image,
    the target mask overlaid on it, and the predicted mask overlaid on it.

    # TODO add standardizaton to convert all setups into single channel format
    Currently supports 1- and 2-channel sigmoid-based targets; for two
    channels the mask is the sum of both channels.
    """
    image = batch[idx][1]  # middle channel of the input stack
    num_channels = target.shape[1]
    if num_channels == 2:
        mask = target[idx][0] + target[idx][1]  # combine channels
        pred_mask = prediction[idx][0] + prediction[idx][1]  # combine channels
    else:
        # single channel by default
        mask = target[idx][0]
        pred_mask = prediction[idx][0]

    image = image.detach().cpu()
    mask = mask.detach().cpu()
    pred_mask = pred_mask.detach().cpu()

    fig, axes = plt.subplots(1, 3)
    axes[0].imshow(image, cmap="gray")
    for ax, overlay in ((axes[1], mask), (axes[2], pred_mask)):
        ax.imshow(image, cmap="gray")  # grayscale background for the overlay
        ax.imshow(overlay, alpha=0.5)

    writer.add_figure(title, fig, global_step)
+
+
+# %%
def train(config, config_save_name):
    """Run the full training loop described by ``config``.

    Parameters
    ----------
    config : dict
        Parsed experiment YAML; ``_``-prefixed keys select the model,
        dataloaders, loss/metric bundle, optimizer, LR scheduler, and
        experiment output directory.
    config_save_name : str
        Filename under which the config is copied into the experiment dir.

    Side effects: creates checkpoint/results/TensorBoard directories
    (fails if they already exist), writes val/test evaluation configs,
    saves the best checkpoint, and logs metrics and figures to TensorBoard.
    """
    # reproducibility
    seed = config.get("_SEED", 42)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Deterministic cuDNN trades speed for reproducibility.
    # NOTE(review): torch.cuda.manual_seed_all is not set here, so
    # multi-GPU RNG state may still vary — confirm whether that matters.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Pull apart the config: object-instantiation sub-configs ...
    model_config = config["_MODEL_CONFIG"]
    train_dataloader_config = config["_TRAIN_DATALOADER_CONFIG"]
    val_dataloader_config = config["_VAL_DATALOADER_CONFIG"]
    loss_metric_config = config["_LOSSES_METRICS_CONFIG"]
    experiment_dir = config["_EXPERIMENT_DIR"]
    # ... and output locations derived from the experiment directory.
    checkpoints_dir = os.path.join(experiment_dir, CHECKPOINTS)
    results_dir = os.path.join(experiment_dir, RESULTS)
    tb_logs_dir_train = os.path.join(experiment_dir, TB_LOGS, "train")
    tb_logs_dir_val = os.path.join(experiment_dir, TB_LOGS, "val")
    config_out = os.path.join(experiment_dir, config_save_name)

    # Checkpointing, optimization, and experiment hyperparameters.
    saved_checkpoint = config["_MODEL_CHECKPOINT"]
    checkpoint_format = config["_NEW_CKP_FORMAT"]
    loss_key = config["_OPTIMIZATION_LOSS"]
    optim_config = config["_OPTIMIZER"]
    lookahead_config = config["_LOOKAHEAD_OPTIM"]
    lr_scheduler_config = config["_LR_SCHEDULER"]
    experiment_data = config["_EXPERIMENT_DATA"]
    val_plotting_dict = config.get("_VAL_PLOTTING")

    model = get_object_instance(model_config)()
    global_step = 0
    # Optionally resume: load_model_data returns the step the checkpoint
    # was saved at, so logging/step counting continues from there.
    if saved_checkpoint is not None:
        global_step = load_model_data(
            saved_checkpoint, model, new_format=checkpoint_format
        )
    train_loader = get_object_instance(train_dataloader_config)()
    val_loader = get_object_instance(val_dataloader_config)()

    print("Train dataset length: {}".format(len(train_loader.dataset)))
    print("Validation dataset length: {}".format(len(val_loader.dataset)))
    print(
        "Valiation dataset patients:\n{}".format(val_loader.dataset.patients)
    )

    # loss_metric maps (prediction, target, extra) -> dict of named losses
    # and metrics; the getters below are factories applied to the model
    # parameters / optimizer later.
    loss_metric = get_object_instance(loss_metric_config)()
    optimizer_getter = get_object_instance(optim_config)
    lr_scheduler_getter = get_object_instance(lr_scheduler_config)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # No exist_ok: presumably intended to fail fast rather than overwrite
    # a previous experiment's outputs — confirm before changing.
    os.makedirs(checkpoints_dir)
    os.makedirs(results_dir)
    os.makedirs(tb_logs_dir_train)
    os.makedirs(tb_logs_dir_val)
    # Persist the exact config used for this run.
    with open(config_out, "w") as f:
        yaml.dump(config, f, default_flow_style=False)

    # create configs for val and test
    val_config, val_out_dir = create_config(config, "val")
    test_config, test_out_dir = create_config(config, "test")
    os.makedirs(val_out_dir)
    os.makedirs(test_out_dir)

    val_path = os.path.join(val_out_dir, "val.yaml")
    print("Creating evaluation config for val: {}".format(val_path))
    with open(val_path, "w") as f:
        yaml.dump(val_config, f, default_flow_style=False)

    test_path = os.path.join(test_out_dir, "test.yaml")
    print("Creating evaluation config for test: {}".format(test_path))
    with open(test_path, "w") as f:
        yaml.dump(test_config, f, default_flow_style=False)

    # Separate writers so train/val curves appear as separate TB runs.
    train_writer = SummaryWriter(tb_logs_dir_train)
    val_writer = SummaryWriter(tb_logs_dir_val)

    # Optional hook to select/group parameters (e.g. per-group LRs);
    # defaults to all model parameters.
    model_params = model.parameters()
    if config.get("_MODEL_PARAM_PREP") is not None:
        model_prep = get_object_instance(config.get("_MODEL_PARAM_PREP"))
        model_params = model_prep(model)

    optimizer = optimizer_getter(model_params)
    # Optionally wrap the base optimizer with Lookahead (catalyst).
    if lookahead_config["use_lookahead"]:
        optimizer = Lookahead(optimizer, **lookahead_config["params"])
    lr_scheduler = lr_scheduler_getter(optimizer)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model = model.to(device)
    model.train()
    num_epochs = experiment_data["num_epochs"]
    batch_log_interval = experiment_data["batch_log_interval"]
    # "low" or "high"
    best_metric_type = experiment_data["best_metric_type"]
    saving_metric = experiment_data["saving_metric"]
    # Seed the "best so far" value so the first validation always improves.
    previous = float("inf") if best_metric_type == "low" else float("-inf")

    # Some datasets also yield the sample index so per-sample extras can
    # be fetched for the loss (see get_extra_dict below).
    output_example_idx = (
        hasattr(train_loader.dataset, "output_idx")
        and train_loader.dataset.output_idx
    )

    for epoch in range(num_epochs):
        for output in train_loader:
            if output_example_idx:
                x_batch, y_batch, index = output
                extra_dict = train_loader.dataset.get_extra_dict(index)
                extra_dict = tensor_dict_to_device(extra_dict, device)
            else:
                x_batch, y_batch = output
                extra_dict = None

            optimizer.zero_grad()
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            y_batch_hat = model(x_batch)
            losses_and_metrics = loss_metric(y_batch_hat, y_batch, extra_dict)
            # Only the configured loss is optimized; the rest are logged.
            loss = losses_and_metrics[loss_key]
            loss.backward()
            optimizer.step()
            global_step += 1
            if global_step % batch_log_interval == 0:
                print("TRAIN:", get_losses_str(losses_and_metrics))
                tb_log_metrics(train_writer, losses_and_metrics, global_step)
                # TODO: add support for softmax processing
                prediction = torch.sigmoid(y_batch_hat)
                plot_fig_from_batch(
                    train_writer, x_batch, prediction, y_batch, global_step,
                )
            # lr change after each batch
            if lr_scheduler_getter.step_type == "after_batch":
                lr_scheduler.step()

        # done with one epoch
        # let's validate (use code from the validation script)
        model.eval()
        all_losses_and_metrics = validate(
            val_loader,
            model,
            loss_metric,
            device,
            plotting_func=plot_fig_from_batch,
            plotting_dict=val_plotting_dict,
            writer=val_writer,
            global_step=global_step,
            val_metric_to_check=saving_metric,
            output_losses_list=False,
        )

        print("Validation results for epoch {}".format(epoch))
        print("VAL:", get_losses_str(all_losses_and_metrics, tensors=False))
        model.train()

        # Checkpoint only when the tracked validation metric improves.
        current = all_losses_and_metrics[saving_metric]
        if is_better(current, previous, best_metric_type):
            print(
                "Validation metric improved "
                "at the end of epoch {}".format(epoch)
            )
            previous = current
            save_val_metrics(
                all_losses_and_metrics, results_dir, epoch, global_step
            )
            out_path = os.path.join(checkpoints_dir, "best_val_checkpoint.pth")
            save_model_data(out_path, model, global_step)

        tb_log_metrics(val_writer, all_losses_and_metrics, global_step)

        # learning rate schedule step at the end of epoch
        if lr_scheduler_getter.step_type != "after_batch":
            if lr_scheduler_getter.step_type == "use_val":
                lr_scheduler.step(all_losses_and_metrics[loss_key])
            elif lr_scheduler_getter.step_type == "use_epoch":
                lr_scheduler.step(epoch)
            else:
                lr_scheduler.step()

        # plot distinct learning rates in order they appear in the optimizer
        # (OrderedDict keyed by lr acts as an insertion-ordered set).
        lr_dict = OrderedDict()
        for param_group in optimizer.param_groups:
            lr = param_group.get("lr")
            lr_dict[lr] = None
        for idx, lr in enumerate(lr_dict):
            tb_log_metrics(val_writer, {"lr_{}".format(idx): lr}, global_step)
            tb_log_metrics(
                train_writer, {"lr_{}".format(idx): lr}, global_step
            )

    train_writer.close()
    val_writer.close()
+
+
+# %%
def quick_check(run_makelinks=False):
    """Smoke-test training with the bundled example config.

    Parameters
    ----------
    run_makelinks : bool
        If True, create the data symlinks before training.
    """
    if run_makelinks:
        makelinks()
    path = "./config/examples/train_example.yaml"
    with open(path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    # BUG FIX: train() requires two positional arguments; the old call
    # train(config) always raised TypeError. Mirror the __main__ entry
    # point by passing the config file's basename as the save name.
    train(config, os.path.basename(path))
+
+
+# %%
if __name__ == "__main__":
    # CLI entry point: load the YAML config, optionally create data
    # symlinks, and launch training.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config", help="YAML config path", type=str, required=True
    )
    cli.add_argument(
        "--makelinks", help="Make data links", action="store_true"
    )
    args = cli.parse_args()

    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    if args.makelinks:
        makelinks()
    # The config's basename is the name it is saved under in the
    # experiment directory.
    train(config, os.path.basename(args.config))