--- a
+++ b/neusomatic/python/train.py
@@ -0,0 +1,576 @@
+#-------------------------------------------------------------------------
+# train.py
+# Train NeuSomatic network
+#-------------------------------------------------------------------------
+
+import os
+import traceback
+import argparse
+import datetime
+import logging
+
+import numpy as np
+import torch
+from torch.autograd import Variable
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import transforms
+import torchvision
+from random import shuffle
+import pickle
+
+from network import NeuSomaticNet
+from dataloader import NeuSomaticDataset, matrix_transform
+from merge_tsvs import merge_tsvs
+
+type_class_dict = {"DEL": 0, "INS": 1, "NONE": 2, "SNP": 3}
+vartype_classes = ['DEL', 'INS', 'NONE', 'SNP']
+
+import torch._utils
+try:
+    torch._utils._rebuild_tensor_v2
+except AttributeError:
+    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
+        tensor = torch._utils._rebuild_tensor(
+            storage, storage_offset, size, stride)
+        tensor.requires_grad = requires_grad
+        tensor._backward_hooks = backward_hooks
+        return tensor
+    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
+
+
+def make_weights_for_balanced_classes(count_class_t, count_class_l, nclasses_t, nclasses_l,
+                                      none_count=None):
+    logger = logging.getLogger(make_weights_for_balanced_classes.__name__)
+
+    w_t = [0] * nclasses_t
+    w_l = [0] * nclasses_l
+
+    count_class_t = list(count_class_t)
+    count_class_l = list(count_class_l)
+    if none_count:
+        count_class_t[type_class_dict["NONE"]] = none_count
+        count_class_l[0] = none_count
+
+    logger.info("count type classes: {}".format(
+        list(zip(vartype_classes, count_class_t))))
+    N = float(sum(count_class_t))
+    for i in range(nclasses_t):
+        w_t[i] = (1 - (float(count_class_t[i]) / float(N))) / float(nclasses_t)
+    w_t = np.array(w_t)
+    logger.info("weight type classes: {}".format(
+        list(zip(vartype_classes, w_t))))
+
+    logger.info("count length classes: {}".format(list(
+        zip(range(nclasses_l), count_class_l))))
+    N = float(sum(count_class_l))
+    for i in range(nclasses_l):
+        w_l[i] = (1 - (float(count_class_l[i]) / float(N))) / float(nclasses_l)
+    w_l = np.array(w_l)
+    logger.info("weight length classes: {}".format(list(
+        zip(range(nclasses_l), w_l))))
+    return w_t, w_l
+
+
+def test(net, epoch, validation_loader, use_cuda):
+    logger = logging.getLogger(test.__name__)
+    net.eval()
+    nclasses = len(vartype_classes)
+    class_correct = list(0. for i in range(nclasses))
+    class_total = list(0. for i in range(nclasses))
+    class_p_total = list(0. for i in range(nclasses))
+
+    len_class_correct = list(0. for i in range(4))
+    len_class_total = list(0. for i in range(4))
+    len_class_p_total = list(0. for i in range(4))
+
+    falses = []
+    for data in validation_loader:
+        (matrices, labels, _, var_len_s, _), (paths) = data
+
+        matrices = Variable(matrices)
+        if use_cuda:
+            matrices = matrices.cuda()
+
+        outputs, _ = net(matrices)
+        [outputs1, outputs2, outputs3] = outputs
+
+        _, predicted = torch.max(outputs1.data.cpu(), 1)
+        pos_pred = outputs2.data.cpu().numpy()
+        _, len_pred = torch.max(outputs3.data.cpu(), 1)
+        preds = {}
+        for i, _ in enumerate(paths[0]):
+            preds[i] = [vartype_classes[predicted[i]], pos_pred[i], len_pred[i]]
+
+        if labels.size()[0] > 1:
+            compare_labels = (predicted == labels).squeeze()
+        else:
+            compare_labels = (predicted == labels)
+        false_preds = np.where(compare_labels.numpy() == 0)[0]
+        if len(false_preds) > 0:
+            for i in false_preds:
+                falses.append([paths[0][i], vartype_classes[predicted[i]], pos_pred[i], len_pred[i],
+                               list(
+                                   np.round(F.softmax(outputs1[i, :], 0).data.cpu().numpy(), 4)),
+                               list(
+                                   np.round(F.softmax(outputs3[i, :], 0).data.cpu().numpy(), 4))])
+
+        for i in range(len(labels)):
+            label = labels[i]
+            class_correct[label] += compare_labels[i].data.cpu().numpy()
+            class_total[label] += 1
+        for i in range(len(predicted)):
+            label = predicted[i]
+            class_p_total[label] += 1
+
+        if var_len_s.size()[0] > 1:
+            compare_len = (len_pred == var_len_s).squeeze()
+        else:
+            compare_len = (len_pred == var_len_s)
+
+        for i in range(len(var_len_s)):
+            len_ = var_len_s[i]
+            len_class_correct[len_] += compare_len[i].data.cpu().numpy()
+            len_class_total[len_] += 1
+        for i in range(len(len_pred)):
+            len_ = len_pred[i]
+            len_class_p_total[len_] += 1
+
+    for i in range(nclasses):
+        SN = 100 * class_correct[i] / (class_total[i] + 0.0001)
+        PR = 100 * class_correct[i] / (class_p_total[i] + 0.0001)
+        F1 = 2 * PR * SN / (PR + SN + 0.0001)
+        logger.info('Epoch {}: Type Accuracy of {:>5} ({}) : {:.2f}  {:.2f} {:.2f}'.format(
+            epoch,
+            vartype_classes[i], class_total[i],
+            SN, PR, F1))
+    logger.info('Epoch {}: Type Accuracy of the network on the {} test candidates: {:.4f} %'.format(
+        epoch, sum(class_total), (
+            100 * sum(class_correct) / float(sum(class_total)))))
+
+    for i in range(4):
+        SN = 100 * len_class_correct[i] / (len_class_total[i] + 0.0001)
+        PR = 100 * len_class_correct[i] / (len_class_p_total[i] + 0.0001)
+        F1 = 2 * PR * SN / (PR + SN + 0.0001)
+        logger.info('Epoch {}: Length Accuracy of {:>5} ({}) : {:.2f}  {:.2f} {:.2f}'.format(
+            epoch, i, len_class_total[i],
+            SN, PR, F1))
+    logger.info('Epoch {}: Length Accuracy of the network on the {} test candidates: {:.4f} %'.format(
+        epoch, sum(len_class_total), (
+            100 * sum(len_class_correct) / float(sum(len_class_total)))))
+
+    net.train()
+
+
+class SubsetNoneSampler(torch.utils.data.sampler.Sampler):
+
+    def __init__(self, none_indices, var_indices, none_count):
+        self.none_indices = none_indices
+        self.var_indices = var_indices
+        self.none_count = none_count
+        self.current_none_id = 0
+
+    def __iter__(self):
+        logger = logging.getLogger(SubsetNoneSampler.__iter__.__name__)
+        if self.current_none_id > (len(self.none_indices) - self.none_count):
+            this_round_nones = self.none_indices[self.current_none_id:]
+            self.none_indices = list(map(lambda i: self.none_indices[i],
+                                         torch.randperm(len(self.none_indices)).tolist()))
+            self.current_none_id = self.none_count - len(this_round_nones)
+            this_round_nones += self.none_indices[0:self.current_none_id]
+        else:
+            this_round_nones = self.none_indices[
+                self.current_none_id:self.current_none_id + self.none_count]
+            self.current_none_id += self.none_count
+
+        current_indices = this_round_nones + self.var_indices
+        ret = iter(map(lambda i: current_indices[i],
+                       torch.randperm(len(current_indices))))
+        return ret
+
+    def __len__(self):
+        return len(self.var_indices) + self.none_count
+
+
+def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpoint,
+                     num_threads, batch_size, max_epochs, learning_rate, lr_drop_epochs,
+                     lr_drop_ratio, momentum, boost_none, none_count_scale,
+                     max_load_candidates, coverage_thr, save_freq, ensemble,
+                     merged_candidates_per_tsv, merged_max_num_tsvs, overwrite_merged_tsvs,
+                     train_split_len,
+                     normalize_channels,
+                     use_cuda):
+    logger = logging.getLogger(train_neusomatic.__name__)
+
+    logger.info("----------------Train NeuSomatic Network-------------------")
+    logger.info("PyTorch Version: {}".format(torch.__version__))
+    logger.info("Torchvision Version: {}".format(torchvision.__version__))
+
+    if not os.path.exists(out_dir):
+        os.mkdir(out_dir)
+
+    if not use_cuda:
+        torch.set_num_threads(num_threads)
+
+    data_transform = matrix_transform((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    num_channels = 119 if ensemble else 26
+    net = NeuSomaticNet(num_channels)
+    if use_cuda:
+        logger.info("GPU training!")
+        net.cuda()
+    else:
+        logger.info("CPU training!")
+
+    if torch.cuda.device_count() > 1:
+        logger.info("We use {} GPUs!".format(torch.cuda.device_count()))
+        net = nn.DataParallel(net)
+
+    if not os.path.exists("{}/models/".format(out_dir)):
+        os.mkdir("{}/models/".format(out_dir))
+
+    if checkpoint:
+        logger.info(
+            "Load pretrained model from checkpoint {}".format(checkpoint))
+        pretrained_dict = torch.load(
+            checkpoint, map_location=lambda storage, loc: storage)
+        pretrained_state_dict = pretrained_dict["state_dict"]
+        tag = pretrained_dict["tag"]
+        sofar_epochs = pretrained_dict["epoch"]
+        logger.info(
+            "sofar_epochs from pretrained checkpoint: {}".format(sofar_epochs))
+        coverage_thr = pretrained_dict["coverage_thr"]
+        logger.info(
+            "Override coverage_thr from pretrained checkpoint: {}".format(coverage_thr))
+        if "normalize_channels" in pretrained_dict:
+            normalize_channels = pretrained_dict["normalize_channels"]
+        else:
+            normalize_channels = False
+        logger.info(
+            "Override normalize_channels from pretrained checkpoint: {}".format(normalize_channels))
+        prev_epochs = sofar_epochs + 1
+        model_dict = net.state_dict()
+        # 1. filter out unnecessary keys
+        # pretrained_state_dict = {
+        # k: v for k, v in pretrained_state_dict.items() if k in model_dict}
+        if "module." in list(pretrained_state_dict.keys())[0] and "module." not in list(model_dict.keys())[0]:
+            pretrained_state_dict = {k.split("module.")[1]: v for k, v in pretrained_state_dict.items(
+            ) if k.split("module.")[1] in model_dict}
+        elif "module." not in list(pretrained_state_dict.keys())[0] and "module." in list(model_dict.keys())[0]:
+            pretrained_state_dict = {
+                ("module." + k): v for k, v in pretrained_state_dict.items() if ("module." + k) in model_dict}
+        else:
+            pretrained_state_dict = {k: v for k,
+                                     v in pretrained_state_dict.items() if k in model_dict}
+        # 2. overwrite entries in the existing state dict
+        model_dict.update(pretrained_state_dict)
+        # 3. load the new state dict
+        net.load_state_dict(pretrained_state_dict)
+    else:
+        prev_epochs = 0
+        time_now = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
+        tag = "neusomatic_{}".format(time_now)
+    logger.info("tag: {}".format(tag))
+
+    shuffle(candidates_tsv)
+
+    if len(candidates_tsv) > merged_max_num_tsvs:
+        candidates_tsv = merge_tsvs(input_tsvs=candidates_tsv, out=out_dir,
+                                    candidates_per_tsv=merged_candidates_per_tsv,
+                                    max_num_tsvs=merged_max_num_tsvs,
+                                    overwrite_merged_tsvs=overwrite_merged_tsvs,
+                                    keep_none_types=True)
+
+    Ls = []
+    for tsv in candidates_tsv:
+        idx = pickle.load(open(tsv + ".idx", "rb"))
+        Ls.append(len(idx) - 1)
+
+    Ls, candidates_tsv = list(zip(
+        *sorted(zip(Ls, candidates_tsv), key=lambda x: x[0], reverse=True)))
+
+    train_split_tsvs = []
+    current_L = 0
+    current_split_tsvs = []
+    for i, (L, tsv) in enumerate(zip(Ls, candidates_tsv)):
+        current_L += L
+        current_split_tsvs.append(tsv)
+        if current_L >= train_split_len or (i == len(candidates_tsv) - 1 and current_L > 0):
+            logger.info("tsvs in split {}: {}".format(
+                len(train_split_tsvs), current_split_tsvs))
+            train_split_tsvs.append(current_split_tsvs)
+            current_L = 0
+            current_split_tsvs = []
+
+    assert sum(map(lambda x: len(x), train_split_tsvs)) == len(candidates_tsv)
+    train_sets = []
+    none_counts = []
+    var_counts = []
+    none_indices_ = []
+    var_indices_ = []
+    samplers = []
+    for split_i, tsvs in enumerate(train_split_tsvs):
+        train_set = NeuSomaticDataset(roots=tsvs,
+                                      max_load_candidates=int(
+                                          max_load_candidates * len(tsvs) / float(len(candidates_tsv))),
+                                      transform=data_transform, is_test=False,
+                                      num_threads=num_threads, coverage_thr=coverage_thr,
+                                      normalize_channels=normalize_channels)
+        train_sets.append(train_set)
+        none_indices = train_set.get_none_indices()
+        var_indices = train_set.get_var_indices()
+        if none_indices:
+            none_indices = list(map(lambda i: none_indices[i],
+                                    torch.randperm(len(none_indices)).tolist()))
+        logger.info(
+            "Non-somatic candidates in split {}: {}".format(split_i, len(none_indices)))
+        if var_indices:
+            var_indices = list(map(lambda i: var_indices[i],
+                                   torch.randperm(len(var_indices)).tolist()))
+        logger.info("Somatic candidates in split {}: {}".format(
+            split_i, len(var_indices)))
+        none_count = max(min(len(none_indices), len(
+            var_indices) * none_count_scale), 1)
+        logger.info(
+            "Non-somatic considered in each epoch of split {}: {}".format(split_i, none_count))
+
+        sampler = SubsetNoneSampler(none_indices, var_indices, none_count)
+        samplers.append(sampler)
+        none_counts.append(none_count)
+        var_counts.append(len(var_indices))
+        var_indices_.append(var_indices)
+        none_indices_.append(none_indices)
+    logger.info("# Total Train cadidates: {}".format(
+        sum(map(lambda x: len(x), train_sets))))
+
+    if validation_candidates_tsv:
+        validation_set = NeuSomaticDataset(roots=validation_candidates_tsv,
+                                           max_load_candidates=max_load_candidates,
+                                           transform=data_transform, is_test=True,
+                                           num_threads=num_threads, coverage_thr=coverage_thr,
+                                           normalize_channels=normalize_channels)
+        validation_loader = torch.utils.data.DataLoader(validation_set,
+                                                        batch_size=batch_size, shuffle=True,
+                                                        num_workers=num_threads, pin_memory=True)
+        logger.info("#Validation candidates: {}".format(len(validation_set)))
+
+    count_class_t = [0] * 4
+    count_class_l = [0] * 4
+    for train_set in train_sets:
+        for i in range(4):
+            count_class_t[i] += train_set.count_class_t[i]
+            count_class_l[i] += train_set.count_class_l[i]
+
+    weights_type, weights_length = make_weights_for_balanced_classes(
+        count_class_t, count_class_l, 4, 4, sum(none_counts))
+
+    weights_type[2] *= boost_none
+    weights_length[0] *= boost_none
+
+    logger.info("weights_type:{}, weights_length:{}".format(
+        weights_type, weights_length))
+
+    loss_s = []
+    gradients = torch.FloatTensor(weights_type)
+    gradients2 = torch.FloatTensor(weights_length)
+    if use_cuda:
+        gradients = gradients.cuda()
+        gradients2 = gradients2.cuda()
+    criterion_crossentropy = nn.CrossEntropyLoss(gradients)
+    criterion_crossentropy2 = nn.CrossEntropyLoss(gradients2)
+    criterion_smoothl1 = nn.SmoothL1Loss()
+    optimizer = optim.SGD(
+        net.parameters(), lr=learning_rate, momentum=momentum)
+
+    net.train()
+    len_train_set = sum(none_counts) + sum(var_counts)
+    logger.info("Number of candidater per epoch: {}".format(len_train_set))
+    print_freq = max(1, int(len_train_set / float(batch_size) / 4.0))
+    curr_epoch = prev_epochs
+    torch.save({"state_dict": net.state_dict(),
+                "tag": tag,
+                "epoch": curr_epoch,
+                "coverage_thr": coverage_thr,
+                "normalize_channels": normalize_channels},
+               '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch))
+
+    if len(train_sets) == 1:
+        train_sets[0].open_candidate_tsvs()
+        train_loader = torch.utils.data.DataLoader(train_sets[0],
+                                                   batch_size=batch_size,
+                                                   num_workers=num_threads, pin_memory=True,
+                                                   sampler=samplers[0])
+    # loop over the dataset multiple times
+    n_epoch = 0
+    for epoch in range(max_epochs - prev_epochs):
+        n_epoch += 1
+        running_loss = 0.0
+        i_ = 0
+        for split_i, train_set in enumerate(train_sets):
+            if len(train_sets) > 1:
+                train_set.open_candidate_tsvs()
+                train_loader = torch.utils.data.DataLoader(train_set,
+                                                           batch_size=batch_size,
+                                                           num_workers=num_threads, pin_memory=True,
+                                                           sampler=samplers[split_i])
+            for i, data in enumerate(train_loader, 0):
+                i_ += 1
+                # get the inputs
+                (inputs, labels, var_pos_s, var_len_s, _), _ = data
+                # wrap them in Variable
+                inputs, labels, var_pos_s, var_len_s = Variable(inputs), Variable(
+                    labels), Variable(var_pos_s), Variable(var_len_s)
+                if use_cuda:
+                    inputs, labels, var_pos_s, var_len_s = inputs.cuda(
+                    ), labels.cuda(), var_pos_s.cuda(), var_len_s.cuda()
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                outputs, _ = net(inputs)
+                [outputs_classification, outputs_pos, outputs_len] = outputs
+                var_len_labels = Variable(
+                    torch.LongTensor(var_len_s.cpu().data.numpy()))
+                if use_cuda:
+                    var_len_labels = var_len_labels.cuda()
+                loss = criterion_crossentropy(outputs_classification, labels) + 1 * criterion_smoothl1(
+                    outputs_pos.squeeze(1), var_pos_s[:, 1]
+                ) + 1 * criterion_crossentropy2(outputs_len, var_len_labels)
+
+                loss.backward()
+                optimizer.step()
+                loss_s.append(loss.data)
+
+                running_loss += loss.data
+                if i_ % print_freq == print_freq - 1:
+                    logger.info('epoch: {}, iter: {:>7}, lr: {}, loss: {:.5f}'.format(
+                                n_epoch + prev_epochs, len(loss_s),
+                                learning_rate, running_loss / print_freq))
+                    running_loss = 0.0
+            if len(train_sets) > 1:
+                train_set.close_candidate_tsvs()
+
+        curr_epoch = n_epoch + prev_epochs
+        if curr_epoch % save_freq == 0:
+            torch.save({"state_dict": net.state_dict(),
+                        "tag": tag,
+                        "epoch": curr_epoch,
+                        "coverage_thr": coverage_thr,
+                        "normalize_channels": normalize_channels,
+                        }, '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch))
+            if validation_candidates_tsv:
+                test(net, curr_epoch, validation_loader, use_cuda)
+        if curr_epoch % lr_drop_epochs == 0:
+            learning_rate *= lr_drop_ratio
+            optimizer = optim.SGD(
+                net.parameters(), lr=learning_rate, momentum=momentum)
+    logger.info('Finished Training')
+
+    if len(train_sets) == 1:
+        train_sets[0].close_candidate_tsvs()
+
+    curr_epoch = n_epoch + prev_epochs
+    torch.save({"state_dict": net.state_dict(),
+                "tag": tag,
+                "epoch": curr_epoch,
+                "coverage_thr": coverage_thr,
+                "normalize_channels": normalize_channels,
+                }, '{}/models/checkpoint_{}_epoch{}.pth'.format(
+        out_dir, tag, curr_epoch))
+    if validation_candidates_tsv:
+        test(net, curr_epoch, validation_loader, use_cuda)
+    logger.info("Total Epochs: {}".format(curr_epoch))
+    logger.info("Total Epochs: {}".format(curr_epoch))
+
+    logger.info("Training is Done.")
+
+    return '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch)
+
+if __name__ == '__main__':
+
+    FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s'
+    logging.basicConfig(level=logging.INFO, format=FORMAT)
+    logger = logging.getLogger(__name__)
+
+    parser = argparse.ArgumentParser(
+        description='simple call variants from bam')
+    parser.add_argument('--candidates_tsv', nargs="*",
+                        help=' train candidate tsv files', required=True)
+    parser.add_argument('--out', type=str,
+                        help='output directory', required=True)
+    parser.add_argument('--checkpoint', type=str,
+                        help='pretrained network model checkpoint path', default=None)
+    parser.add_argument('--validation_candidates_tsv', nargs="*",
+                        help=' validation candidate tsv files', default=[])
+    parser.add_argument('--ensemble',
+                        help='Enable training for ensemble mode',
+                        action="store_true")
+    parser.add_argument('--num_threads', type=int,
+                        help='number of threads', default=1)
+    parser.add_argument('--batch_size', type=int,
+                        help='batch size', default=1000)
+    parser.add_argument('--max_epochs', type=int,
+                        help='maximum number of training epochs', default=1000)
+    parser.add_argument('--lr', type=float, help='learning_rate', default=0.01)
+    parser.add_argument('--lr_drop_epochs', type=int,
+                        help='number of epochs to drop learning rate', default=400)
+    parser.add_argument('--lr_drop_ratio', type=float,
+                        help='learning rate drop scale', default=0.1)
+    parser.add_argument('--momentum', type=float,
+                        help='SGD momentum', default=0.9)
+    parser.add_argument('--boost_none', type=float,
+                        help='the amount to boost none-somatic classification weight', default=100)
+    parser.add_argument('--none_count_scale', type=float,
+                        help='ratio of the none/somatic canidates to use in each training epoch \
+                        (the none candidate will be subsampled in each epoch based on this ratio',
+                        default=2)
+    parser.add_argument('--max_load_candidates', type=int,
+                        help='maximum candidates to load in memory', default=1000000)
+    parser.add_argument('--save_freq', type=int,
+                        help='the frequency of saving checkpoints in terms of # epochs', default=50)
+    parser.add_argument('--merged_candidates_per_tsv', type=int,
+                        help='Maximum number of candidates in each merged tsv file ', default=10000000)
+    parser.add_argument('--merged_max_num_tsvs', type=int,
+                        help='Maximum number of merged tsv files \
+                        (higher priority than merged_candidates_per_tsv)', default=10)
+    parser.add_argument('--overwrite_merged_tsvs',
+                        help='if OUT/merged_tsvs/ folder exists overwrite the merged tsvs',
+                        action="store_true")
+    parser.add_argument('--train_split_len', type=int,
+                        help='Maximum number of candidates used in each split of training (>=merged_candidates_per_tsv)',
+                        default=10000000)
+    parser.add_argument('--coverage_thr', type=int,
+                        help='maximum coverage threshold to be used for network input \
+                              normalization. \
+                              Will be overridden if pretrained model is provided\
+                              For ~50x WGS, coverage_thr=100 should work. \
+                              For higher coverage WES, coverage_thr=300 should work.', default=100)
+    parser.add_argument('--normalize_channels',
+                        help='normalize BQ, MQ, and other bam-info channels by frequency of observed alleles. \
+                              Will be overridden if pretrained model is provided',
+                        action="store_true")
+    args = parser.parse_args()
+
+    logger.info(args)
+
+    use_cuda = torch.cuda.is_available()
+    logger.info("use_cuda: {}".format(use_cuda))
+
+    try:
+        checkpoint = train_neusomatic(args.candidates_tsv, args.validation_candidates_tsv,
+                                      args.out, args.checkpoint, args.num_threads, args.batch_size,
+                                      args.max_epochs,
+                                      args.lr, args.lr_drop_epochs, args.lr_drop_ratio, args.momentum,
+                                      args.boost_none, args.none_count_scale,
+                                      args.max_load_candidates, args.coverage_thr, args.save_freq,
+                                      args.ensemble,
+                                      args.merged_candidates_per_tsv, args.merged_max_num_tsvs,
+                                      args.overwrite_merged_tsvs, args.train_split_len,
+                                      args.normalize_channels,
+                                      use_cuda)
+    except Exception as e:
+        logger.error(traceback.format_exc())
+        logger.error("Aborting!")
+        logger.error(
+            "train.py failure on arguments: {}".format(args))
+        raise e