
a b/neusomatic/python/train.py
#-------------------------------------------------------------------------
# train.py
# Train NeuSomatic network
#-------------------------------------------------------------------------

import os
import traceback
import argparse
import datetime
import logging

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import torchvision
from random import shuffle
import pickle

from network import NeuSomaticNet
from dataloader import NeuSomaticDataset, matrix_transform
from merge_tsvs import merge_tsvs

type_class_dict = {"DEL": 0, "INS": 1, "NONE": 2, "SNP": 3}
vartype_classes = ['DEL', 'INS', 'NONE', 'SNP']

import torch._utils
try:
    torch._utils._rebuild_tensor_v2
except AttributeError:
    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
        tensor = torch._utils._rebuild_tensor(
            storage, storage_offset, size, stride)
        tensor.requires_grad = requires_grad
        tensor._backward_hooks = backward_hooks
        return tensor
    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
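
# NOTE: the block above is a backward-compatibility shim. Checkpoints serialized
# by newer PyTorch releases reference torch._utils._rebuild_tensor_v2, which older
# PyTorch versions lack; the fallback rebuilds the tensor and restores
# requires_grad and the backward hooks manually.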


def make_weights_for_balanced_classes(count_class_t, count_class_l, nclasses_t, nclasses_l,
                                      none_count=None):
    logger = logging.getLogger(make_weights_for_balanced_classes.__name__)

    w_t = [0] * nclasses_t
    w_l = [0] * nclasses_l

    count_class_t = list(count_class_t)
    count_class_l = list(count_class_l)
    if none_count:
        count_class_t[type_class_dict["NONE"]] = none_count
        count_class_l[0] = none_count

    logger.info("count type classes: {}".format(
        list(zip(vartype_classes, count_class_t))))
    N = float(sum(count_class_t))
    for i in range(nclasses_t):
        w_t[i] = (1 - (float(count_class_t[i]) / float(N))) / float(nclasses_t)
    w_t = np.array(w_t)
    logger.info("weight type classes: {}".format(
        list(zip(vartype_classes, w_t))))

    logger.info("count length classes: {}".format(list(
        zip(range(nclasses_l), count_class_l))))
    N = float(sum(count_class_l))
    for i in range(nclasses_l):
        w_l[i] = (1 - (float(count_class_l[i]) / float(N))) / float(nclasses_l)
    w_l = np.array(w_l)
    logger.info("weight length classes: {}".format(list(
        zip(range(nclasses_l), w_l))))
    return w_t, w_l
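
# Illustration of the weighting scheme above (not part of the original code):
# each class weight is w_i = (1 - n_i / N) / K, where n_i is the class count,
# N the total count, and K the number of classes. For example, with type counts
# [DEL=10, INS=10, NONE=60, SNP=20], N=100 and K=4 give weights
# [0.225, 0.225, 0.1, 0.2], so rarer classes receive larger weights.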


def test(net, epoch, validation_loader, use_cuda):
    logger = logging.getLogger(test.__name__)
    net.eval()
    nclasses = len(vartype_classes)
    class_correct = list(0. for i in range(nclasses))
    class_total = list(0. for i in range(nclasses))
    class_p_total = list(0. for i in range(nclasses))

    len_class_correct = list(0. for i in range(4))
    len_class_total = list(0. for i in range(4))
    len_class_p_total = list(0. for i in range(4))

    falses = []
    for data in validation_loader:
        (matrices, labels, _, var_len_s, _), (paths) = data

        matrices = Variable(matrices)
        if use_cuda:
            matrices = matrices.cuda()

        outputs, _ = net(matrices)
        [outputs1, outputs2, outputs3] = outputs

        _, predicted = torch.max(outputs1.data.cpu(), 1)
        pos_pred = outputs2.data.cpu().numpy()
        _, len_pred = torch.max(outputs3.data.cpu(), 1)
        preds = {}
        for i, _ in enumerate(paths[0]):
            preds[i] = [vartype_classes[predicted[i]], pos_pred[i], len_pred[i]]

        if labels.size()[0] > 1:
            compare_labels = (predicted == labels).squeeze()
        else:
            compare_labels = (predicted == labels)
        false_preds = np.where(compare_labels.numpy() == 0)[0]
        if len(false_preds) > 0:
            for i in false_preds:
                falses.append([paths[0][i], vartype_classes[predicted[i]], pos_pred[i], len_pred[i],
                               list(
                                   np.round(F.softmax(outputs1[i, :], 0).data.cpu().numpy(), 4)),
                               list(
                                   np.round(F.softmax(outputs3[i, :], 0).data.cpu().numpy(), 4))])

        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += compare_labels[i].data.cpu().numpy()
            class_total[label] += 1
        for i in range(len(predicted)):
            label = predicted[i]
            class_p_total[label] += 1

        if var_len_s.size()[0] > 1:
            compare_len = (len_pred == var_len_s).squeeze()
        else:
            compare_len = (len_pred == var_len_s)

        for i in range(len(var_len_s)):
            len_ = var_len_s[i]
            len_class_correct[len_] += compare_len[i].data.cpu().numpy()
            len_class_total[len_] += 1
        for i in range(len(len_pred)):
            len_ = len_pred[i]
            len_class_p_total[len_] += 1

    # per-class sensitivity (SN), precision (PR), and F1; the 0.0001 terms
    # guard against division by zero for empty classes
    for i in range(nclasses):
        SN = 100 * class_correct[i] / (class_total[i] + 0.0001)
        PR = 100 * class_correct[i] / (class_p_total[i] + 0.0001)
        F1 = 2 * PR * SN / (PR + SN + 0.0001)
        logger.info('Epoch {}: Type Accuracy of {:>5} ({}) : {:.2f}  {:.2f} {:.2f}'.format(
            epoch,
            vartype_classes[i], class_total[i],
            SN, PR, F1))
    logger.info('Epoch {}: Type Accuracy of the network on the {} test candidates: {:.4f} %'.format(
        epoch, sum(class_total), (
            100 * sum(class_correct) / float(sum(class_total)))))

    for i in range(4):
        SN = 100 * len_class_correct[i] / (len_class_total[i] + 0.0001)
        PR = 100 * len_class_correct[i] / (len_class_p_total[i] + 0.0001)
        F1 = 2 * PR * SN / (PR + SN + 0.0001)
        logger.info('Epoch {}: Length Accuracy of {:>5} ({}) : {:.2f}  {:.2f} {:.2f}'.format(
            epoch, i, len_class_total[i],
            SN, PR, F1))
    logger.info('Epoch {}: Length Accuracy of the network on the {} test candidates: {:.4f} %'.format(
        epoch, sum(len_class_total), (
            100 * sum(len_class_correct) / float(sum(len_class_total)))))

    net.train()
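

# SubsetNoneSampler (below) yields all somatic ("var") candidate indices every
# epoch but only a rotating subset of size none_count of the non-somatic ("none")
# indices, advancing through the shuffled none pool so that successive epochs
# eventually visit all none candidates.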
class SubsetNoneSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, none_indices, var_indices, none_count):
        self.none_indices = none_indices
        self.var_indices = var_indices
        self.none_count = none_count
        self.current_none_id = 0

    def __iter__(self):
        logger = logging.getLogger(SubsetNoneSampler.__iter__.__name__)
        if self.current_none_id > (len(self.none_indices) - self.none_count):
            this_round_nones = self.none_indices[self.current_none_id:]
            self.none_indices = list(map(lambda i: self.none_indices[i],
                                         torch.randperm(len(self.none_indices)).tolist()))
            self.current_none_id = self.none_count - len(this_round_nones)
            this_round_nones += self.none_indices[0:self.current_none_id]
        else:
            this_round_nones = self.none_indices[
                self.current_none_id:self.current_none_id + self.none_count]
            self.current_none_id += self.none_count

        current_indices = this_round_nones + self.var_indices
        ret = iter(map(lambda i: current_indices[i],
                       torch.randperm(len(current_indices))))
        return ret

    def __len__(self):
        return len(self.var_indices) + self.none_count
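

# High-level flow of train_neusomatic (summary comment, not from the original
# source): candidate TSVs are optionally merged, sorted by size, and grouped
# into splits of roughly train_split_len candidates; each split gets its own
# NeuSomaticDataset and SubsetNoneSampler; class-balanced weights feed two
# weighted CrossEntropyLoss heads (variant type and length) plus a SmoothL1Loss
# on position, optimized with SGD; the learning rate is scaled by lr_drop_ratio
# every lr_drop_epochs epochs and checkpoints are written every save_freq epochs.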
def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpoint,
                     num_threads, batch_size, max_epochs, learning_rate, lr_drop_epochs,
                     lr_drop_ratio, momentum, boost_none, none_count_scale,
                     max_load_candidates, coverage_thr, save_freq, ensemble,
                     merged_candidates_per_tsv, merged_max_num_tsvs, overwrite_merged_tsvs,
                     train_split_len,
                     normalize_channels,
                     use_cuda):
    logger = logging.getLogger(train_neusomatic.__name__)

    logger.info("----------------Train NeuSomatic Network-------------------")
    logger.info("PyTorch Version: {}".format(torch.__version__))
    logger.info("Torchvision Version: {}".format(torchvision.__version__))

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    if not use_cuda:
        torch.set_num_threads(num_threads)

    data_transform = matrix_transform((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    num_channels = 119 if ensemble else 26
    net = NeuSomaticNet(num_channels)
    if use_cuda:
        logger.info("GPU training!")
        net.cuda()
    else:
        logger.info("CPU training!")

    if torch.cuda.device_count() > 1:
        logger.info("We use {} GPUs!".format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    if not os.path.exists("{}/models/".format(out_dir)):
        os.mkdir("{}/models/".format(out_dir))

    if checkpoint:
        logger.info(
            "Load pretrained model from checkpoint {}".format(checkpoint))
        pretrained_dict = torch.load(
            checkpoint, map_location=lambda storage, loc: storage)
        pretrained_state_dict = pretrained_dict["state_dict"]
        tag = pretrained_dict["tag"]
        sofar_epochs = pretrained_dict["epoch"]
        logger.info(
            "sofar_epochs from pretrained checkpoint: {}".format(sofar_epochs))
        coverage_thr = pretrained_dict["coverage_thr"]
        logger.info(
            "Override coverage_thr from pretrained checkpoint: {}".format(coverage_thr))
        if "normalize_channels" in pretrained_dict:
            normalize_channels = pretrained_dict["normalize_channels"]
        else:
            normalize_channels = False
        logger.info(
            "Override normalize_channels from pretrained checkpoint: {}".format(normalize_channels))
        prev_epochs = sofar_epochs + 1
        model_dict = net.state_dict()
        # 1. filter out unnecessary keys
        # pretrained_state_dict = {
        # k: v for k, v in pretrained_state_dict.items() if k in model_dict}
        # align "module." prefixes so checkpoints saved with or without
        # nn.DataParallel can be loaded either way
        if "module." in list(pretrained_state_dict.keys())[0] and "module." not in list(model_dict.keys())[0]:
            pretrained_state_dict = {k.split("module.")[1]: v for k, v in pretrained_state_dict.items(
            ) if k.split("module.")[1] in model_dict}
        elif "module." not in list(pretrained_state_dict.keys())[0] and "module." in list(model_dict.keys())[0]:
            pretrained_state_dict = {
                ("module." + k): v for k, v in pretrained_state_dict.items() if ("module." + k) in model_dict}
        else:
            pretrained_state_dict = {k: v for k,
                                     v in pretrained_state_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_state_dict)
        # 3. load the new state dict
        net.load_state_dict(model_dict)
    else:
        prev_epochs = 0
        time_now = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
        tag = "neusomatic_{}".format(time_now)
    logger.info("tag: {}".format(tag))

    shuffle(candidates_tsv)

    if len(candidates_tsv) > merged_max_num_tsvs:
        candidates_tsv = merge_tsvs(input_tsvs=candidates_tsv, out=out_dir,
                                    candidates_per_tsv=merged_candidates_per_tsv,
                                    max_num_tsvs=merged_max_num_tsvs,
                                    overwrite_merged_tsvs=overwrite_merged_tsvs,
                                    keep_none_types=True)

    Ls = []
    for tsv in candidates_tsv:
        idx = pickle.load(open(tsv + ".idx", "rb"))
        Ls.append(len(idx) - 1)

    Ls, candidates_tsv = list(zip(
        *sorted(zip(Ls, candidates_tsv), key=lambda x: x[0], reverse=True)))
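
    # Illustrative example (not from the original code): with train_split_len of
    # 10,000,000 and TSVs holding 7M, 6M, 4M, and 2M candidates (sorted largest
    # first), the loop below forms the split [7M, 6M] (13M >= 10M) and then
    # [4M, 2M], closing the last split when the final TSV is reached.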
    train_split_tsvs = []
    current_L = 0
    current_split_tsvs = []
    for i, (L, tsv) in enumerate(zip(Ls, candidates_tsv)):
        current_L += L
        current_split_tsvs.append(tsv)
        if current_L >= train_split_len or (i == len(candidates_tsv) - 1 and current_L > 0):
            logger.info("tsvs in split {}: {}".format(
                len(train_split_tsvs), current_split_tsvs))
            train_split_tsvs.append(current_split_tsvs)
            current_L = 0
            current_split_tsvs = []

    assert sum(map(lambda x: len(x), train_split_tsvs)) == len(candidates_tsv)
    train_sets = []
    none_counts = []
    var_counts = []
    none_indices_ = []
    var_indices_ = []
    samplers = []
    for split_i, tsvs in enumerate(train_split_tsvs):
        train_set = NeuSomaticDataset(roots=tsvs,
                                      max_load_candidates=int(
                                          max_load_candidates * len(tsvs) / float(len(candidates_tsv))),
                                      transform=data_transform, is_test=False,
                                      num_threads=num_threads, coverage_thr=coverage_thr,
                                      normalize_channels=normalize_channels)
        train_sets.append(train_set)
        none_indices = train_set.get_none_indices()
        var_indices = train_set.get_var_indices()
        if none_indices:
            none_indices = list(map(lambda i: none_indices[i],
                                    torch.randperm(len(none_indices)).tolist()))
        logger.info(
            "Non-somatic candidates in split {}: {}".format(split_i, len(none_indices)))
        if var_indices:
            var_indices = list(map(lambda i: var_indices[i],
                                   torch.randperm(len(var_indices)).tolist()))
        logger.info("Somatic candidates in split {}: {}".format(
            split_i, len(var_indices)))
        none_count = max(min(len(none_indices), len(
            var_indices) * none_count_scale), 1)
        logger.info(
            "Non-somatic considered in each epoch of split {}: {}".format(split_i, none_count))

        sampler = SubsetNoneSampler(none_indices, var_indices, none_count)
        samplers.append(sampler)
        none_counts.append(none_count)
        var_counts.append(len(var_indices))
        var_indices_.append(var_indices)
        none_indices_.append(none_indices)
    logger.info("# Total Train candidates: {}".format(
        sum(map(lambda x: len(x), train_sets))))

    if validation_candidates_tsv:
        validation_set = NeuSomaticDataset(roots=validation_candidates_tsv,
                                           max_load_candidates=max_load_candidates,
                                           transform=data_transform, is_test=True,
                                           num_threads=num_threads, coverage_thr=coverage_thr,
                                           normalize_channels=normalize_channels)
        validation_loader = torch.utils.data.DataLoader(validation_set,
                                                        batch_size=batch_size, shuffle=True,
                                                        num_workers=num_threads, pin_memory=True)
        logger.info("#Validation candidates: {}".format(len(validation_set)))

    count_class_t = [0] * 4
    count_class_l = [0] * 4
    for train_set in train_sets:
        for i in range(4):
            count_class_t[i] += train_set.count_class_t[i]
            count_class_l[i] += train_set.count_class_l[i]

    weights_type, weights_length = make_weights_for_balanced_classes(
        count_class_t, count_class_l, 4, 4, sum(none_counts))

    # boost the weight of the non-somatic class (type index 2 is "NONE" in
    # vartype_classes; length class 0 corresponds to non-somatic candidates)
    weights_type[2] *= boost_none
    weights_length[0] *= boost_none

    logger.info("weights_type:{}, weights_length:{}".format(
        weights_type, weights_length))

    loss_s = []
    gradients = torch.FloatTensor(weights_type)
    gradients2 = torch.FloatTensor(weights_length)
    if use_cuda:
        gradients = gradients.cuda()
        gradients2 = gradients2.cuda()
    criterion_crossentropy = nn.CrossEntropyLoss(gradients)
    criterion_crossentropy2 = nn.CrossEntropyLoss(gradients2)
    criterion_smoothl1 = nn.SmoothL1Loss()
    optimizer = optim.SGD(
        net.parameters(), lr=learning_rate, momentum=momentum)
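
    # The training loop below optimizes a multi-task objective (see the loss
    # expression further down): weighted cross-entropy on the variant type,
    # SmoothL1 on the predicted position, and weighted cross-entropy on the
    # variant length class, all summed with unit weights.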
    net.train()
    len_train_set = sum(none_counts) + sum(var_counts)
    logger.info("Number of candidates per epoch: {}".format(len_train_set))
    print_freq = max(1, int(len_train_set / float(batch_size) / 4.0))
    curr_epoch = prev_epochs
    torch.save({"state_dict": net.state_dict(),
                "tag": tag,
                "epoch": curr_epoch,
                "coverage_thr": coverage_thr,
                "normalize_channels": normalize_channels},
               '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch))

    if len(train_sets) == 1:
        train_sets[0].open_candidate_tsvs()
        train_loader = torch.utils.data.DataLoader(train_sets[0],
                                                   batch_size=batch_size,
                                                   num_workers=num_threads, pin_memory=True,
                                                   sampler=samplers[0])
    # loop over the dataset multiple times
    n_epoch = 0
    for epoch in range(max_epochs - prev_epochs):
        n_epoch += 1
        running_loss = 0.0
        i_ = 0
        for split_i, train_set in enumerate(train_sets):
            if len(train_sets) > 1:
                train_set.open_candidate_tsvs()
                train_loader = torch.utils.data.DataLoader(train_set,
                                                           batch_size=batch_size,
                                                           num_workers=num_threads, pin_memory=True,
                                                           sampler=samplers[split_i])
            for i, data in enumerate(train_loader, 0):
                i_ += 1
                # get the inputs
                (inputs, labels, var_pos_s, var_len_s, _), _ = data
                # wrap them in Variable
                inputs, labels, var_pos_s, var_len_s = Variable(inputs), Variable(
                    labels), Variable(var_pos_s), Variable(var_len_s)
                if use_cuda:
                    inputs, labels, var_pos_s, var_len_s = inputs.cuda(
                    ), labels.cuda(), var_pos_s.cuda(), var_len_s.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs, _ = net(inputs)
                [outputs_classification, outputs_pos, outputs_len] = outputs
                var_len_labels = Variable(
                    torch.LongTensor(var_len_s.cpu().data.numpy()))
                if use_cuda:
                    var_len_labels = var_len_labels.cuda()
                loss = criterion_crossentropy(outputs_classification, labels) + 1 * criterion_smoothl1(
                    outputs_pos.squeeze(1), var_pos_s[:, 1]
                ) + 1 * criterion_crossentropy2(outputs_len, var_len_labels)

                loss.backward()
                optimizer.step()
                loss_s.append(loss.data)

                running_loss += loss.data
                if i_ % print_freq == print_freq - 1:
                    logger.info('epoch: {}, iter: {:>7}, lr: {}, loss: {:.5f}'.format(
                                n_epoch + prev_epochs, len(loss_s),
                                learning_rate, running_loss / print_freq))
                    running_loss = 0.0
            if len(train_sets) > 1:
                train_set.close_candidate_tsvs()

        curr_epoch = n_epoch + prev_epochs
        if curr_epoch % save_freq == 0:
            torch.save({"state_dict": net.state_dict(),
                        "tag": tag,
                        "epoch": curr_epoch,
                        "coverage_thr": coverage_thr,
                        "normalize_channels": normalize_channels,
                        }, '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch))
            if validation_candidates_tsv:
                test(net, curr_epoch, validation_loader, use_cuda)
        if curr_epoch % lr_drop_epochs == 0:
            learning_rate *= lr_drop_ratio
            optimizer = optim.SGD(
                net.parameters(), lr=learning_rate, momentum=momentum)
    logger.info('Finished Training')

    if len(train_sets) == 1:
        train_sets[0].close_candidate_tsvs()

    curr_epoch = n_epoch + prev_epochs
    torch.save({"state_dict": net.state_dict(),
                "tag": tag,
                "epoch": curr_epoch,
                "coverage_thr": coverage_thr,
                "normalize_channels": normalize_channels,
                }, '{}/models/checkpoint_{}_epoch{}.pth'.format(
        out_dir, tag, curr_epoch))
    if validation_candidates_tsv:
        test(net, curr_epoch, validation_loader, use_cuda)
    logger.info("Total Epochs: {}".format(curr_epoch))

    logger.info("Training is Done.")

    return '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch)
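

# Example invocation (paths and values below are illustrative, not from the
# original source):
#   python train.py --candidates_tsv <candidates_dir>/*.tsv \
#                   --out <train_dir> \
#                   --num_threads 10 --batch_size 100 \
#                   [--validation_candidates_tsv <val_dir>/*.tsv] [--ensemble]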
if __name__ == '__main__':

    FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s'
    logging.basicConfig(level=logging.INFO, format=FORMAT)
    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser(
        description='Train NeuSomatic network')
    parser.add_argument('--candidates_tsv', nargs="*",
                        help='train candidate tsv files', required=True)
    parser.add_argument('--out', type=str,
                        help='output directory', required=True)
    parser.add_argument('--checkpoint', type=str,
                        help='pretrained network model checkpoint path', default=None)
    parser.add_argument('--validation_candidates_tsv', nargs="*",
                        help='validation candidate tsv files', default=[])
    parser.add_argument('--ensemble',
                        help='Enable training for ensemble mode',
                        action="store_true")
    parser.add_argument('--num_threads', type=int,
                        help='number of threads', default=1)
    parser.add_argument('--batch_size', type=int,
                        help='batch size', default=1000)
    parser.add_argument('--max_epochs', type=int,
                        help='maximum number of training epochs', default=1000)
    parser.add_argument('--lr', type=float, help='learning rate', default=0.01)
    parser.add_argument('--lr_drop_epochs', type=int,
                        help='number of epochs between learning rate drops', default=400)
    parser.add_argument('--lr_drop_ratio', type=float,
                        help='learning rate drop scale', default=0.1)
    parser.add_argument('--momentum', type=float,
                        help='SGD momentum', default=0.9)
    parser.add_argument('--boost_none', type=float,
                        help='the amount to boost non-somatic classification weight', default=100)
    parser.add_argument('--none_count_scale', type=float,
                        help='ratio of the none/somatic candidates to use in each training epoch \
                        (the none candidates will be subsampled in each epoch based on this ratio)',
                        default=2)
    parser.add_argument('--max_load_candidates', type=int,
                        help='maximum candidates to load in memory', default=1000000)
    parser.add_argument('--save_freq', type=int,
                        help='the frequency of saving checkpoints in terms of # epochs', default=50)
    parser.add_argument('--merged_candidates_per_tsv', type=int,
                        help='maximum number of candidates in each merged tsv file', default=10000000)
    parser.add_argument('--merged_max_num_tsvs', type=int,
                        help='maximum number of merged tsv files \
                        (higher priority than merged_candidates_per_tsv)', default=10)
    parser.add_argument('--overwrite_merged_tsvs',
                        help='if OUT/merged_tsvs/ folder exists, overwrite the merged tsvs',
                        action="store_true")
    parser.add_argument('--train_split_len', type=int,
                        help='maximum number of candidates used in each split of training (>=merged_candidates_per_tsv)',
                        default=10000000)
    parser.add_argument('--coverage_thr', type=int,
                        help='maximum coverage threshold to be used for network input \
                              normalization. \
                              Will be overridden if a pretrained model is provided. \
                              For ~50x WGS, coverage_thr=100 should work. \
                              For higher coverage WES, coverage_thr=300 should work.', default=100)
    parser.add_argument('--normalize_channels',
                        help='normalize BQ, MQ, and other bam-info channels by frequency of observed alleles. \
                              Will be overridden if a pretrained model is provided',
                        action="store_true")
    args = parser.parse_args()

    logger.info(args)

    use_cuda = torch.cuda.is_available()
    logger.info("use_cuda: {}".format(use_cuda))

    try:
        checkpoint = train_neusomatic(args.candidates_tsv, args.validation_candidates_tsv,
                                      args.out, args.checkpoint, args.num_threads, args.batch_size,
                                      args.max_epochs,
                                      args.lr, args.lr_drop_epochs, args.lr_drop_ratio, args.momentum,
                                      args.boost_none, args.none_count_scale,
                                      args.max_load_candidates, args.coverage_thr, args.save_freq,
                                      args.ensemble,
                                      args.merged_candidates_per_tsv, args.merged_max_num_tsvs,
                                      args.overwrite_merged_tsvs, args.train_split_len,
                                      args.normalize_channels,
                                      use_cuda)
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error("Aborting!")
        logger.error(
            "train.py failure on arguments: {}".format(args))
        raise e