"""
Model train script

python -m adpkd_segmentation.train --config path_to_config_yaml --makelinks  # noqa

If using a specific GPU (e.g. device 2):
CUDA_VISIBLE_DEVICES=2 python -m train --config path_to_config_yaml --makelinks

The makelinks flag is needed only once to create symbolic links to the data.

To create and activate the conda environment for this repository:
conda env create --file adpkd-segmentation.yml
conda activate adpkd-segmentation
"""

# %%
import argparse
import json
import numpy as np
import os
import yaml
import random
from collections import OrderedDict

from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from catalyst.contrib.nn import Lookahead

from adpkd_segmentation.create_eval_configs import create_config
from adpkd_segmentation.config.config_utils import get_object_instance
from adpkd_segmentation.data.link_data import makelinks
from adpkd_segmentation.evaluate import validate
from adpkd_segmentation.utils.train_utils import (
    load_model_data,
    save_model_data,
)
from adpkd_segmentation.data.data_utils import tensor_dict_to_device

# Subdirectory names created under the experiment directory.
CHECKPOINTS = "checkpoints"
RESULTS = "results"
TB_LOGS = "tb_logs"


# %%
def get_losses_str(losses_and_metrics, tensors=True):
    """Format a dict of losses/metrics as a single ``"k:v, k:v"`` string.

    Args:
        losses_and_metrics: mapping of metric name to value.
        tensors: if True, values are 0-dim tensors and ``.item()`` is
            called to extract the Python scalar.

    Returns:
        Comma-separated string with each value printed to 5 decimals.
    """
    out = []
    for k, v in losses_and_metrics.items():
        if tensors:
            v = v.item()
        out.append("{}:{:.5f}".format(k, v))
    return ", ".join(out)


# %%
def is_better(current, previous, metric_type):
    """Return True if ``current`` improves on ``previous``.

    ``metric_type`` is "low" when smaller is better (e.g. a loss) or
    "high" when larger is better (e.g. Dice score).

    Raises:
        Exception: if ``metric_type`` is neither "low" nor "high".
    """
    if metric_type == "low":
        return current < previous
    elif metric_type == "high":
        return current > previous
    else:
        raise Exception("unknown metric_type")


# %%
def save_val_metrics(metrics, results_dir, epoch, global_step):
    """Dump validation metrics to a JSON file tagged by epoch and step."""
    with open(
        "{}/val_results_ep_{}_gs_{}.json".format(
            results_dir, epoch, global_step
        ),
        "w",
    ) as fp:
        json.dump(metrics, fp, indent=4)


# %%
def tb_log_metrics(writer, metrics, global_step):
    """Log each metric in ``metrics`` as a TensorBoard scalar."""
    for k, v in metrics.items():
        writer.add_scalar(k, v, global_step)


# %%
def plot_image_from_batch(
    writer, batch, prediction, target, global_step, idx=0
):
    """Log one example's input/target/prediction channels to TensorBoard.

    Each channel is shown as a separate grayscale image by unsqueezing
    the channel dim into the batch dim: (C, H, W) -> (C, 1, H, W).
    """
    image = batch[idx]
    pred_mask = prediction[idx]
    mask = target[idx]
    # plot each of the image channels as a separate grayscale image
    # (C, 1, H, W) shape to treat each of the channels as a new image
    image = image.unsqueeze(1)
    # same approach for mask and prediction
    mask = mask.unsqueeze(1)
    pred_mask = pred_mask.unsqueeze(1)
    writer.add_images("input_images", image, global_step, dataformats="NCHW")
    writer.add_images("mask_channels", mask, global_step, dataformats="NCHW")
    writer.add_images(
        "prediction_channels", pred_mask, global_step, dataformats="NCHW"
    )


# %%
def plot_fig_from_batch(
    writer,
    batch,
    prediction,
    target,
    global_step,
    idx=0,
    title="fig: img_target_pred",
):
    """Log a 3-panel matplotlib figure (image / target / prediction).

    Assumes a 3-channel input where the middle channel is the image of
    interest, and sigmoid-based targets with 1 or 2 channels
    (2-channel masks are summed into one overlay) -- TODO confirm.
    """
    image = batch[idx][1]  # middle channel
    # single channel by default
    mask = target[idx][0]
    pred_mask = prediction[idx][0]
    # TODO add standardizaton to convert all setups into single channel format
    # currently supporting 1 and 2 channels, sigmoid based
    num_channels = target.shape[1]
    if num_channels == 2:
        pred_mask = prediction[idx][0] + prediction[idx][1]  # combine channels
        mask = target[idx][0] + target[idx][1]  # combine channels

    image = image.detach().cpu()
    mask = mask.detach().cpu()
    pred_mask = pred_mask.detach().cpu()

    f, axarr = plt.subplots(1, 3)
    axarr[0].imshow(image, cmap="gray")
    axarr[1].imshow(image, cmap="gray")  # background for mask
    axarr[1].imshow(mask, alpha=0.5)
    axarr[2].imshow(image, cmap="gray")  # background for mask
    axarr[2].imshow(pred_mask, alpha=0.5)

    writer.add_figure(title, f, global_step)


# %%
def train(config, config_save_name):
    """Run the full training loop described by ``config``.

    Builds model, dataloaders, losses, optimizer, and LR scheduler from
    the config via ``get_object_instance``; trains for the configured
    number of epochs; validates after each epoch; saves the best
    checkpoint (per the configured saving metric) and TensorBoard logs
    under ``config["_EXPERIMENT_DIR"]``.

    Args:
        config: parsed YAML configuration dict.
        config_save_name: filename under which the config is copied
            into the experiment directory.
    """
    # reproducibility
    seed = config.get("_SEED", 42)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # deterministic cuDNN kernels (slower, but reproducible)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model_config = config["_MODEL_CONFIG"]
    train_dataloader_config = config["_TRAIN_DATALOADER_CONFIG"]
    val_dataloader_config = config["_VAL_DATALOADER_CONFIG"]
    loss_metric_config = config["_LOSSES_METRICS_CONFIG"]
    experiment_dir = config["_EXPERIMENT_DIR"]
    checkpoints_dir = os.path.join(experiment_dir, CHECKPOINTS)
    results_dir = os.path.join(experiment_dir, RESULTS)
    tb_logs_dir_train = os.path.join(experiment_dir, TB_LOGS, "train")
    tb_logs_dir_val = os.path.join(experiment_dir, TB_LOGS, "val")
    config_out = os.path.join(experiment_dir, config_save_name)

    saved_checkpoint = config["_MODEL_CHECKPOINT"]
    checkpoint_format = config["_NEW_CKP_FORMAT"]
    loss_key = config["_OPTIMIZATION_LOSS"]
    optim_config = config["_OPTIMIZER"]
    lookahead_config = config["_LOOKAHEAD_OPTIM"]
    lr_scheduler_config = config["_LR_SCHEDULER"]
    experiment_data = config["_EXPERIMENT_DATA"]
    val_plotting_dict = config.get("_VAL_PLOTTING")

    model = get_object_instance(model_config)()
    global_step = 0
    if saved_checkpoint is not None:
        # resume: restore weights and the global step counter
        global_step = load_model_data(
            saved_checkpoint, model, new_format=checkpoint_format
        )
    train_loader = get_object_instance(train_dataloader_config)()
    val_loader = get_object_instance(val_dataloader_config)()

    print("Train dataset length: {}".format(len(train_loader.dataset)))
    print("Validation dataset length: {}".format(len(val_loader.dataset)))
    print(
        "Validation dataset patients:\n{}".format(val_loader.dataset.patients)
    )

    loss_metric = get_object_instance(loss_metric_config)()
    optimizer_getter = get_object_instance(optim_config)
    lr_scheduler_getter = get_object_instance(lr_scheduler_config)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # intentionally fails if the experiment dir already exists,
    # so a previous run is never silently overwritten
    os.makedirs(checkpoints_dir)
    os.makedirs(results_dir)
    os.makedirs(tb_logs_dir_train)
    os.makedirs(tb_logs_dir_val)
    with open(config_out, "w") as f:
        yaml.dump(config, f, default_flow_style=False)

    # create configs for val and test
    val_config, val_out_dir = create_config(config, "val")
    test_config, test_out_dir = create_config(config, "test")
    os.makedirs(val_out_dir)
    os.makedirs(test_out_dir)

    val_path = os.path.join(val_out_dir, "val.yaml")
    print("Creating evaluation config for val: {}".format(val_path))
    with open(val_path, "w") as f:
        yaml.dump(val_config, f, default_flow_style=False)

    test_path = os.path.join(test_out_dir, "test.yaml")
    print("Creating evaluation config for test: {}".format(test_path))
    with open(test_path, "w") as f:
        yaml.dump(test_config, f, default_flow_style=False)

    train_writer = SummaryWriter(tb_logs_dir_train)
    val_writer = SummaryWriter(tb_logs_dir_val)

    # optional hook to select/group parameters (e.g. per-layer LRs)
    model_params = model.parameters()
    if config.get("_MODEL_PARAM_PREP") is not None:
        model_prep = get_object_instance(config.get("_MODEL_PARAM_PREP"))
        model_params = model_prep(model)

    optimizer = optimizer_getter(model_params)
    if lookahead_config["use_lookahead"]:
        optimizer = Lookahead(optimizer, **lookahead_config["params"])
    lr_scheduler = lr_scheduler_getter(optimizer)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    model = model.to(device)
    model.train()
    num_epochs = experiment_data["num_epochs"]
    batch_log_interval = experiment_data["batch_log_interval"]
    # "low" or "high"
    best_metric_type = experiment_data["best_metric_type"]
    saving_metric = experiment_data["saving_metric"]
    previous = float("inf") if best_metric_type == "low" else float("-inf")

    # datasets that also yield sample indices allow per-sample extra
    # inputs (retrieved via get_extra_dict) to be fed to the loss
    output_example_idx = (
        hasattr(train_loader.dataset, "output_idx")
        and train_loader.dataset.output_idx
    )

    for epoch in range(num_epochs):
        for output in train_loader:
            if output_example_idx:
                x_batch, y_batch, index = output
                extra_dict = train_loader.dataset.get_extra_dict(index)
                extra_dict = tensor_dict_to_device(extra_dict, device)
            else:
                x_batch, y_batch = output
                extra_dict = None

            optimizer.zero_grad()
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            y_batch_hat = model(x_batch)
            losses_and_metrics = loss_metric(y_batch_hat, y_batch, extra_dict)
            loss = losses_and_metrics[loss_key]
            loss.backward()
            optimizer.step()
            global_step += 1
            if global_step % batch_log_interval == 0:
                print("TRAIN:", get_losses_str(losses_and_metrics))
                tb_log_metrics(train_writer, losses_and_metrics, global_step)
                # TODO: add support for softmax processing
                prediction = torch.sigmoid(y_batch_hat)
                plot_fig_from_batch(
                    train_writer, x_batch, prediction, y_batch, global_step,
                )
            # lr change after each batch
            if lr_scheduler_getter.step_type == "after_batch":
                lr_scheduler.step()

        # done with one epoch
        # let's validate (use code from the validation script)
        model.eval()
        all_losses_and_metrics = validate(
            val_loader,
            model,
            loss_metric,
            device,
            plotting_func=plot_fig_from_batch,
            plotting_dict=val_plotting_dict,
            writer=val_writer,
            global_step=global_step,
            val_metric_to_check=saving_metric,
            output_losses_list=False,
        )

        print("Validation results for epoch {}".format(epoch))
        print("VAL:", get_losses_str(all_losses_and_metrics, tensors=False))
        model.train()

        # checkpoint only when the tracked validation metric improves
        current = all_losses_and_metrics[saving_metric]
        if is_better(current, previous, best_metric_type):
            print(
                "Validation metric improved "
                "at the end of epoch {}".format(epoch)
            )
            previous = current
            save_val_metrics(
                all_losses_and_metrics, results_dir, epoch, global_step
            )
            out_path = os.path.join(checkpoints_dir, "best_val_checkpoint.pth")
            save_model_data(out_path, model, global_step)

        tb_log_metrics(val_writer, all_losses_and_metrics, global_step)

        # learning rate schedule step at the end of epoch
        if lr_scheduler_getter.step_type != "after_batch":
            if lr_scheduler_getter.step_type == "use_val":
                lr_scheduler.step(all_losses_and_metrics[loss_key])
            elif lr_scheduler_getter.step_type == "use_epoch":
                lr_scheduler.step(epoch)
            else:
                lr_scheduler.step()

        # plot distinct learning rates in order they appear in the optimizer
        # (OrderedDict keys deduplicate while preserving insertion order)
        lr_dict = OrderedDict()
        for param_group in optimizer.param_groups:
            lr = param_group.get("lr")
            lr_dict[lr] = None
        for idx, lr in enumerate(lr_dict):
            tb_log_metrics(val_writer, {"lr_{}".format(idx): lr}, global_step)
            tb_log_metrics(
                train_writer, {"lr_{}".format(idx): lr}, global_step
            )

    train_writer.close()
    val_writer.close()


# %%
def quick_check(run_makelinks=False):
    """Smoke-test the training loop with the bundled example config."""
    if run_makelinks:
        makelinks()
    path = "./config/examples/train_example.yaml"
    with open(path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    # BUG FIX: train() requires a config_save_name second argument;
    # the original call train(config) raised TypeError
    train(config, os.path.basename(path))


# %%
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", help="YAML config path", type=str, required=True
    )
    parser.add_argument(
        "--makelinks", help="Make data links", action="store_true"
    )

    args = parser.parse_args()
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    if args.makelinks:
        makelinks()
    config_save_name = os.path.basename(args.config)
    train(config, config_save_name)