--- a
+++ b/src/train.py
@@ -0,0 +1,172 @@
+import importlib
+import sys
+import torch
+from src.models.unet3d import unet3d
+from torchvision import transforms
+
+
+from src.dataset.train_val_split import train_val_split
+from src.losses.ce_dice_loss import CrossEntropyDiceLoss3D
+
+from src.losses import dice_loss, region_based_loss, new_losses
+
+from src.models.io_model import load_model
+from src.train.trainer import Trainer, TrainerArgs
+from torch.optim import lr_scheduler
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from src.config import BratsConfiguration
+from src.dataset.augmentations import color_augmentations, spatial_augmentations
+
+from src.dataset.utils import dataset, visualization as visualization
+from src.models.vnet import vnet, asymm_vnet
+from src.logging_conf import logger
+from src.dataset.loaders.brats_dataset import BratsDataset
+
+
+def num_params(net_params):
+    """Log the total element count of the tensors in ``net_params`` (e.g. ``model.parameters()``)."""
+    n_params = sum(p.numel() for p in net_params)
+    logger.info(f"Number of params: {n_params}")
+
+######## PARAMS
+logger.info("Processing Parameters...")
+# The configuration file path is read from the first command-line argument.
+if len(sys.argv) < 2:
+    sys.exit(f"Usage: {sys.argv[0]} <config_file>")
+config = BratsConfiguration(sys.argv[1])
+model_config = config.get_model_config()
+dataset_config = config.get_dataset_config()
+basic_config = config.get_basic_config()
+
+patch_size = config.patch_size
+tensorboard_logdir = basic_config.get("tensorboard_logs")
+checkpoint_path = model_config.get("checkpoint")  # only used when resuming from a checkpoint
+batch_size = dataset_config.getint("batch_size")
+n_patches = dataset_config.getint("n_patches")
+n_classes = dataset_config.getint("classes")
+loss = model_config.get("loss")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logger.info(f"Device: {device}")
+
+######## DATASET
+logger.info("Creating Dataset...")
+data, _ = dataset.read_brats(dataset_config.get("train_csv"), lgg_only=dataset_config.getboolean("lgg_only"))
+data_train, data_val = train_val_split(data, val_size=0.2)
+# Repeat each split n_patches times, presumably so every case contributes n_patches
+# sampled patches per epoch (assumes list-like splits) - TODO confirm.
+data_train = data_train * n_patches
+data_val = data_val * n_patches
+
+n_modalities = dataset_config.getint("n_modalities")  # like color channels
+sampling_method = importlib.import_module(dataset_config.get("sampling_method"))  # dotted module path from config
+# Random intensity shift/scale, mirror flip and 90-degree rotation augmentations.
+transform = transforms.Compose([color_augmentations.RandomIntensityShift(),
+                                color_augmentations.RandomIntensityScale(),
+                                spatial_augmentations.RandomMirrorFlip(p=0.5),
+                                spatial_augmentations.RandomRotation90(p=0.5)])
+
+compute_patch = basic_config.getboolean("compute_patches")
+train_dataset = BratsDataset(data_train, sampling_method, patch_size, compute_patch=compute_patch, transform=transform)
+train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+# NOTE(review): validation reuses the random training `transform`, which makes its loss noisy;
+# consider a deterministic transform. shuffle=False: evaluation has no use for reshuffling.
+val_dataset = BratsDataset(data_val, sampling_method, patch_size, compute_patch=compute_patch, transform=transform)
+val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+if basic_config.getboolean("plot"):
+    # Save a slice (index 30) of one training batch for visual inspection.
+    data_batch, labels_batch = next(iter(train_loader))
+    # Merge the first two dims (presumably batch x patches-per-case) into one. flatten()
+    # returns a new tensor, so the result must be assigned back.
+    data_batch = data_batch.flatten(0, 1)
+    labels_batch = labels_batch.flatten(0, 1)
+    logger.info(f"Plot batch shape: {data_batch.shape}")
+    logger.info('Plotting images')
+    visualization.plot_batch_slice(data_batch, labels_batch, slice=30, save=True)
+
+
+######## MODEL
+logger.info("Initiating Model...")
+
+# Build the network selected by the `network` option of the model config.
+config_network = model_config["network"]
+if config_network == "vnet":
+    network = vnet.VNet(elu=model_config.getboolean("use_elu"),
+                        in_channels=n_modalities,
+                        classes=n_classes,
+                        init_features_maps=model_config.getint("init_features_maps"))
+
+elif config_network == "vnet_asymm":
+    network = asymm_vnet.VNet(non_linearity=model_config.get("non_linearity"), in_channels=n_modalities,
+                              classes=n_classes, init_features_maps=model_config.getint("init_features_maps"),
+                              kernel_size=model_config.getint("kernel_size"), padding=model_config.getint("padding"))
+
+elif config_network == "3dunet_residual":
+    network = unet3d.ResidualUNet3D(in_channels=n_modalities, out_channels=n_classes, final_sigmoid=False,
+                                    f_maps=model_config.getint("init_features_maps"), layer_order="crg",
+                                    num_levels=4, num_groups=4, conv_padding=1)
+
+elif config_network == "3dunet":
+    network = unet3d.UNet3D(in_channels=n_modalities, out_channels=n_classes, final_sigmoid=False,
+                            f_maps=model_config.getint("init_features_maps"), layer_order="crg",
+                            num_levels=4, num_groups=4, conv_padding=1)
+
+else:
+    raise ValueError(f"Bad parameter for network {config_network}. "
+                     "Current options: [vnet, vnet_asymm, 3dunet_residual, 3dunet]")
+
+num_params(network.parameters())  # logs the total parameter count
+
+
+##### TRAIN
+logger.info("Start Training")
+network.to(device)
+optimizer_name = model_config.get("optimizer")
+if optimizer_name == "SGD":
+    optimizer = torch.optim.SGD(network.parameters(), lr=model_config.getfloat("learning_rate"),
+                                momentum=model_config.getfloat("momentum"), weight_decay=model_config.getfloat("weight_decay"))
+elif optimizer_name == "ADAM":
+    optimizer = torch.optim.Adam(network.parameters(), lr=model_config.getfloat("learning_rate"),
+                                 weight_decay=model_config.getfloat("weight_decay"), amsgrad=False)
+else:
+    raise ValueError("Bad optimizer. Current options: [SGD, ADAM]")
+
+# inf rather than a magic 1000 so any finite loss improves on it (assumes Trainer keeps the minimum).
+best_loss = float("inf")
+if basic_config.getboolean("resume"):
+    logger.info("Loading model from checkpoint..")
+    # Rebind `network` so the model returned by load_model is the one handed to the Trainer.
+    network, optimizer, start_epoch, best_loss = load_model(network, checkpoint_path, device, optimizer, True)
+    logger.info(f"Loaded model with starting epoch {start_epoch}")
+else:
+    start_epoch = 0
+
+writer = SummaryWriter(tensorboard_logdir)
+scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=model_config.getfloat("lr_decay"),
+                                           patience=model_config.getint("patience"))
+
+# Training criterion selected by the `loss` option of the model config.
+if loss == "dice":
+    criterion = dice_loss.DiceLoss(classes=n_classes, eval_regions=model_config.getboolean("eval_regions"),
+                                   sigmoid_normalization=True)
+elif loss == "combined":
+    # Cross-entropy class weights (0: back, 1: ncr, 2: ed, 3: et); four entries, so this
+    # presumably requires classes == 4 - TODO confirm. Created on `device`, where the network
+    # lives, because nothing in this script moves the criterion there.
+    ce_weigh = torch.tensor([0.1, 0.35, 0.2, 0.35], device=device)
+    criterion = CrossEntropyDiceLoss3D(weight=ce_weigh, classes=n_classes,
+                                       eval_regions=model_config.getboolean("eval_regions"), sigmoid_normalization=True)
+elif loss == "both_dice":
+    criterion = region_based_loss.RegionBasedDiceLoss3D(classes=n_classes, sigmoid_normalization=True)
+elif loss == "gdl":
+    criterion = new_losses.GeneralizedDiceLoss()
+else:
+    raise ValueError(f"Bad loss value {loss}. Expected ['dice', 'combined', 'both_dice', 'gdl']")
+
+args = TrainerArgs(model_config.getint("n_epochs"), device, model_config.get("model_path"), loss)
+trainer = Trainer(args, network, optimizer, criterion, start_epoch, train_loader, val_loader, scheduler, writer)
+trainer.start(best_loss=best_loss)
+writer.close()  # flush any pending TensorBoard events to disk (safe if already closed)
+
+print("Finished!")