--- a +++ b/neusomatic/python/train.py @@ -0,0 +1,576 @@ +#------------------------------------------------------------------------- +# train.py +# Train NeuSomatic network +#------------------------------------------------------------------------- + +import os +import traceback +import argparse +import datetime +import logging + +import numpy as np +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import transforms +import torchvision +from random import shuffle +import pickle + +from network import NeuSomaticNet +from dataloader import NeuSomaticDataset, matrix_transform +from merge_tsvs import merge_tsvs + +type_class_dict = {"DEL": 0, "INS": 1, "NONE": 2, "SNP": 3} +vartype_classes = ['DEL', 'INS', 'NONE', 'SNP'] + +import torch._utils +try: + torch._utils._rebuild_tensor_v2 +except AttributeError: + def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): + tensor = torch._utils._rebuild_tensor( + storage, storage_offset, size, stride) + tensor.requires_grad = requires_grad + tensor._backward_hooks = backward_hooks + return tensor + torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 + + +def make_weights_for_balanced_classes(count_class_t, count_class_l, nclasses_t, nclasses_l, + none_count=None): + logger = logging.getLogger(make_weights_for_balanced_classes.__name__) + + w_t = [0] * nclasses_t + w_l = [0] * nclasses_l + + count_class_t = list(count_class_t) + count_class_l = list(count_class_l) + if none_count: + count_class_t[type_class_dict["NONE"]] = none_count + count_class_l[0] = none_count + + logger.info("count type classes: {}".format( + list(zip(vartype_classes, count_class_t)))) + N = float(sum(count_class_t)) + for i in range(nclasses_t): + w_t[i] = (1 - (float(count_class_t[i]) / float(N))) / float(nclasses_t) + w_t = np.array(w_t) + logger.info("weight type classes: {}".format( + list(zip(vartype_classes, w_t)))) + + logger.info("count length classes: {}".format(list( + zip(range(nclasses_l), count_class_l)))) + N = float(sum(count_class_l)) + for i in range(nclasses_l): + w_l[i] = (1 - (float(count_class_l[i]) / float(N))) / float(nclasses_l) + w_l = np.array(w_l) + logger.info("weight length classes: {}".format(list( + zip(range(nclasses_l), w_l)))) + return w_t, w_l + + +def test(net, epoch, validation_loader, use_cuda): + logger = logging.getLogger(test.__name__) + net.eval() + nclasses = len(vartype_classes) + class_correct = list(0. for i in range(nclasses)) + class_total = list(0. for i in range(nclasses)) + class_p_total = list(0. for i in range(nclasses)) + + len_class_correct = list(0. for i in range(4)) + len_class_total = list(0. for i in range(4)) + len_class_p_total = list(0. 
for i in range(4)) + + falses = [] + for data in validation_loader: + (matrices, labels, _, var_len_s, _), (paths) = data + + matrices = Variable(matrices) + if use_cuda: + matrices = matrices.cuda() + + outputs, _ = net(matrices) + [outputs1, outputs2, outputs3] = outputs + + _, predicted = torch.max(outputs1.data.cpu(), 1) + pos_pred = outputs2.data.cpu().numpy() + _, len_pred = torch.max(outputs3.data.cpu(), 1) + preds = {} + for i, _ in enumerate(paths[0]): + preds[i] = [vartype_classes[predicted[i]], pos_pred[i], len_pred[i]] + + if labels.size()[0] > 1: + compare_labels = (predicted == labels).squeeze() + else: + compare_labels = (predicted == labels) + false_preds = np.where(compare_labels.numpy() == 0)[0] + if len(false_preds) > 0: + for i in false_preds: + falses.append([paths[0][i], vartype_classes[predicted[i]], pos_pred[i], len_pred[i], + list( + np.round(F.softmax(outputs1[i, :], 0).data.cpu().numpy(), 4)), + list( + np.round(F.softmax(outputs3[i, :], 0).data.cpu().numpy(), 4))]) + + for i in range(len(labels)): + label = labels[i] + class_correct[label] += compare_labels[i].data.cpu().numpy() + class_total[label] += 1 + for i in range(len(predicted)): + label = predicted[i] + class_p_total[label] += 1 + + if var_len_s.size()[0] > 1: + compare_len = (len_pred == var_len_s).squeeze() + else: + compare_len = (len_pred == var_len_s) + + for i in range(len(var_len_s)): + len_ = var_len_s[i] + len_class_correct[len_] += compare_len[i].data.cpu().numpy() + len_class_total[len_] += 1 + for i in range(len(len_pred)): + len_ = len_pred[i] + len_class_p_total[len_] += 1 + + for i in range(nclasses): + SN = 100 * class_correct[i] / (class_total[i] + 0.0001) + PR = 100 * class_correct[i] / (class_p_total[i] + 0.0001) + F1 = 2 * PR * SN / (PR + SN + 0.0001) + logger.info('Epoch {}: Type Accuracy of {:>5} ({}) : {:.2f} {:.2f} {:.2f}'.format( + epoch, + vartype_classes[i], class_total[i], + SN, PR, F1)) + logger.info('Epoch {}: Type Accuracy of the network on the {} test candidates: {:.4f} %'.format( + epoch, sum(class_total), ( + 100 * sum(class_correct) / float(sum(class_total))))) + + for i in range(4): + SN = 100 * len_class_correct[i] / (len_class_total[i] + 0.0001) + PR = 100 * len_class_correct[i] / (len_class_p_total[i] + 0.0001) + F1 = 2 * PR * SN / (PR + SN + 0.0001) + logger.info('Epoch {}: Length Accuracy of {:>5} ({}) : {:.2f} {:.2f} {:.2f}'.format( + epoch, i, len_class_total[i], + SN, PR, F1)) + logger.info('Epoch {}: Length Accuracy of the network on the {} test candidates: {:.4f} %'.format( + epoch, sum(len_class_total), ( + 100 * sum(len_class_correct) / float(sum(len_class_total))))) + + net.train() + + +class SubsetNoneSampler(torch.utils.data.sampler.Sampler): + + def __init__(self, none_indices, var_indices, none_count): + self.none_indices = none_indices + self.var_indices = var_indices + self.none_count = none_count + self.current_none_id = 0 + + def __iter__(self): + logger = logging.getLogger(SubsetNoneSampler.__iter__.__name__) + if self.current_none_id > (len(self.none_indices) - self.none_count): + this_round_nones = self.none_indices[self.current_none_id:] + self.none_indices = list(map(lambda i: self.none_indices[i], + torch.randperm(len(self.none_indices)).tolist())) + self.current_none_id = self.none_count - len(this_round_nones) + this_round_nones += self.none_indices[0:self.current_none_id] + else: + this_round_nones = self.none_indices[ + self.current_none_id:self.current_none_id + self.none_count] + self.current_none_id += self.none_count + + 
current_indices = this_round_nones + self.var_indices + ret = iter(map(lambda i: current_indices[i], + torch.randperm(len(current_indices)))) + return ret + + def __len__(self): + return len(self.var_indices) + self.none_count + + +def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir, checkpoint, + num_threads, batch_size, max_epochs, learning_rate, lr_drop_epochs, + lr_drop_ratio, momentum, boost_none, none_count_scale, + max_load_candidates, coverage_thr, save_freq, ensemble, + merged_candidates_per_tsv, merged_max_num_tsvs, overwrite_merged_tsvs, + train_split_len, + normalize_channels, + use_cuda): + logger = logging.getLogger(train_neusomatic.__name__) + + logger.info("----------------Train NeuSomatic Network-------------------") + logger.info("PyTorch Version: {}".format(torch.__version__)) + logger.info("Torchvision Version: {}".format(torchvision.__version__)) + + if not os.path.exists(out_dir): + os.mkdir(out_dir) + + if not use_cuda: + torch.set_num_threads(num_threads) + + data_transform = matrix_transform((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + num_channels = 119 if ensemble else 26 + net = NeuSomaticNet(num_channels) + if use_cuda: + logger.info("GPU training!") + net.cuda() + else: + logger.info("CPU training!") + + if torch.cuda.device_count() > 1: + logger.info("We use {} GPUs!".format(torch.cuda.device_count())) + net = nn.DataParallel(net) + + if not os.path.exists("{}/models/".format(out_dir)): + os.mkdir("{}/models/".format(out_dir)) + + if checkpoint: + logger.info( + "Load pretrained model from checkpoint {}".format(checkpoint)) + pretrained_dict = torch.load( + checkpoint, map_location=lambda storage, loc: storage) + pretrained_state_dict = pretrained_dict["state_dict"] + tag = pretrained_dict["tag"] + sofar_epochs = pretrained_dict["epoch"] + logger.info( + "sofar_epochs from pretrained checkpoint: {}".format(sofar_epochs)) + coverage_thr = pretrained_dict["coverage_thr"] + logger.info( + "Override coverage_thr from pretrained checkpoint: {}".format(coverage_thr)) + if "normalize_channels" in pretrained_dict: + normalize_channels = pretrained_dict["normalize_channels"] + else: + normalize_channels = False + logger.info( + "Override normalize_channels from pretrained checkpoint: {}".format(normalize_channels)) + prev_epochs = sofar_epochs + 1 + model_dict = net.state_dict() + # 1. filter out unnecessary keys + # pretrained_state_dict = { + # k: v for k, v in pretrained_state_dict.items() if k in model_dict} + if "module." in list(pretrained_state_dict.keys())[0] and "module." not in list(model_dict.keys())[0]: + pretrained_state_dict = {k.split("module.")[1]: v for k, v in pretrained_state_dict.items( + ) if k.split("module.")[1] in model_dict} + elif "module." not in list(pretrained_state_dict.keys())[0] and "module." in list(model_dict.keys())[0]: + pretrained_state_dict = { + ("module." + k): v for k, v in pretrained_state_dict.items() if ("module." + k) in model_dict} + else: + pretrained_state_dict = {k: v for k, + v in pretrained_state_dict.items() if k in model_dict} + # 2. overwrite entries in the existing state dict + model_dict.update(pretrained_state_dict) + # 3. 
load the new state dict
+        net.load_state_dict(model_dict)
+    else:
+        prev_epochs = 0
+        time_now = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
+        tag = "neusomatic_{}".format(time_now)
+    logger.info("tag: {}".format(tag))
+
+    shuffle(candidates_tsv)
+
+    if len(candidates_tsv) > merged_max_num_tsvs:
+        candidates_tsv = merge_tsvs(input_tsvs=candidates_tsv, out=out_dir,
+                                    candidates_per_tsv=merged_candidates_per_tsv,
+                                    max_num_tsvs=merged_max_num_tsvs,
+                                    overwrite_merged_tsvs=overwrite_merged_tsvs,
+                                    keep_none_types=True)
+
+    Ls = []
+    for tsv in candidates_tsv:
+        idx = pickle.load(open(tsv + ".idx", "rb"))
+        Ls.append(len(idx) - 1)
+
+    Ls, candidates_tsv = list(zip(
+        *sorted(zip(Ls, candidates_tsv), key=lambda x: x[0], reverse=True)))
+
+    train_split_tsvs = []
+    current_L = 0
+    current_split_tsvs = []
+    for i, (L, tsv) in enumerate(zip(Ls, candidates_tsv)):
+        current_L += L
+        current_split_tsvs.append(tsv)
+        if current_L >= train_split_len or (i == len(candidates_tsv) - 1 and current_L > 0):
+            logger.info("tsvs in split {}: {}".format(
+                len(train_split_tsvs), current_split_tsvs))
+            train_split_tsvs.append(current_split_tsvs)
+            current_L = 0
+            current_split_tsvs = []
+
+    assert sum(map(lambda x: len(x), train_split_tsvs)) == len(candidates_tsv)
+    train_sets = []
+    none_counts = []
+    var_counts = []
+    none_indices_ = []
+    var_indices_ = []
+    samplers = []
+    for split_i, tsvs in enumerate(train_split_tsvs):
+        train_set = NeuSomaticDataset(roots=tsvs,
+                                      max_load_candidates=int(
+                                          max_load_candidates * len(tsvs) / float(len(candidates_tsv))),
+                                      transform=data_transform, is_test=False,
+                                      num_threads=num_threads, coverage_thr=coverage_thr,
+                                      normalize_channels=normalize_channels)
+        train_sets.append(train_set)
+        none_indices = train_set.get_none_indices()
+        var_indices = train_set.get_var_indices()
+        if none_indices:
+            none_indices = list(map(lambda i: none_indices[i],
+                                    torch.randperm(len(none_indices)).tolist()))
+        logger.info(
+            "Non-somatic candidates in split {}: {}".format(split_i, len(none_indices)))
+        if var_indices:
+            var_indices = list(map(lambda i: var_indices[i],
+                                   torch.randperm(len(var_indices)).tolist()))
+        logger.info("Somatic candidates in split {}: {}".format(
+            split_i, len(var_indices)))
+        none_count = max(min(len(none_indices), len(
+            var_indices) * none_count_scale), 1)
+        logger.info(
+            "Non-somatic considered in each epoch of split {}: {}".format(split_i, none_count))
+
+        sampler = SubsetNoneSampler(none_indices, var_indices, none_count)
+        samplers.append(sampler)
+        none_counts.append(none_count)
+        var_counts.append(len(var_indices))
+        var_indices_.append(var_indices)
+        none_indices_.append(none_indices)
+    logger.info("# Total Train candidates: {}".format(
+        sum(map(lambda x: len(x), train_sets))))
+
+    if validation_candidates_tsv:
+        validation_set = NeuSomaticDataset(roots=validation_candidates_tsv,
+                                           max_load_candidates=max_load_candidates,
+                                           transform=data_transform, is_test=True,
+                                           num_threads=num_threads, coverage_thr=coverage_thr,
+                                           normalize_channels=normalize_channels)
+        validation_loader = torch.utils.data.DataLoader(validation_set,
+                                                        batch_size=batch_size, shuffle=True,
+                                                        num_workers=num_threads, pin_memory=True)
+        logger.info("#Validation candidates: {}".format(len(validation_set)))
+
+    count_class_t = [0] * 4
+    count_class_l = [0] * 4
+    for train_set in train_sets:
+        for i in range(4):
+            count_class_t[i] += train_set.count_class_t[i]
+            count_class_l[i] += train_set.count_class_l[i]
+
+    weights_type, weights_length = make_weights_for_balanced_classes(
+        count_class_t, count_class_l, 4, 4, sum(none_counts))
+
+    weights_type[2] *= boost_none
+    weights_length[0] *= boost_none
+
+    logger.info("weights_type:{}, weights_length:{}".format(
+        weights_type, weights_length))
+
+    loss_s = []
+    gradients = torch.FloatTensor(weights_type)
+    gradients2 = torch.FloatTensor(weights_length)
+    if use_cuda:
+        gradients = gradients.cuda()
+        gradients2 = gradients2.cuda()
+    criterion_crossentropy = nn.CrossEntropyLoss(gradients)
+    criterion_crossentropy2 = nn.CrossEntropyLoss(gradients2)
+    criterion_smoothl1 = nn.SmoothL1Loss()
+    optimizer = optim.SGD(
+        net.parameters(), lr=learning_rate, momentum=momentum)
+
+    net.train()
+    len_train_set = sum(none_counts) + sum(var_counts)
+    logger.info("Number of candidates per epoch: {}".format(len_train_set))
+    print_freq = max(1, int(len_train_set / float(batch_size) / 4.0))
+    curr_epoch = prev_epochs
+    torch.save({"state_dict": net.state_dict(),
+                "tag": tag,
+                "epoch": curr_epoch,
+                "coverage_thr": coverage_thr,
+                "normalize_channels": normalize_channels},
+               '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch))
+
+    if len(train_sets) == 1:
+        train_sets[0].open_candidate_tsvs()
+        train_loader = torch.utils.data.DataLoader(train_sets[0],
+                                                   batch_size=batch_size,
+                                                   num_workers=num_threads, pin_memory=True,
+                                                   sampler=samplers[0])
+    # loop over the dataset multiple times
+    n_epoch = 0
+    for epoch in range(max_epochs - prev_epochs):
+        n_epoch += 1
+        running_loss = 0.0
+        i_ = 0
+        for split_i, train_set in enumerate(train_sets):
+            if len(train_sets) > 1:
+                train_set.open_candidate_tsvs()
+                train_loader = torch.utils.data.DataLoader(train_set,
+                                                           batch_size=batch_size,
+                                                           num_workers=num_threads, pin_memory=True,
+                                                           sampler=samplers[split_i])
+            for i, data in enumerate(train_loader, 0):
+                i_ += 1
+                # get the inputs
+                (inputs, labels, var_pos_s, var_len_s, _), _ = data
+                # wrap them in Variable
+                inputs, labels, var_pos_s, var_len_s = Variable(inputs), Variable(
+                    labels), Variable(var_pos_s), Variable(var_len_s)
+                if use_cuda:
+                    inputs, labels, var_pos_s, var_len_s = inputs.cuda(
+                    ), labels.cuda(), var_pos_s.cuda(), var_len_s.cuda()
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                outputs, _ = net(inputs)
+                [outputs_classification, outputs_pos, outputs_len] = outputs
+                var_len_labels = Variable(
+                    torch.LongTensor(var_len_s.cpu().data.numpy()))
+                if use_cuda:
+                    var_len_labels = var_len_labels.cuda()
+                loss = criterion_crossentropy(outputs_classification, labels) + 1 * criterion_smoothl1(
+                    outputs_pos.squeeze(1), var_pos_s[:, 1]
+                ) + 1 * criterion_crossentropy2(outputs_len, var_len_labels)
+
+                loss.backward()
+                optimizer.step()
+                loss_s.append(loss.data)
+
+                running_loss += loss.data
+                if i_ % print_freq == print_freq - 1:
+                    logger.info('epoch: {}, iter: {:>7}, lr: {}, loss: {:.5f}'.format(
+                        n_epoch + prev_epochs, len(loss_s),
+                        learning_rate, running_loss / print_freq))
+                    running_loss = 0.0
+            if len(train_sets) > 1:
+                train_set.close_candidate_tsvs()
+
+        curr_epoch = n_epoch + prev_epochs
+        if curr_epoch % save_freq == 0:
+            torch.save({"state_dict": net.state_dict(),
+                        "tag": tag,
+                        "epoch": curr_epoch,
+                        "coverage_thr": coverage_thr,
+                        "normalize_channels": normalize_channels,
+                        }, '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch))
+            if validation_candidates_tsv:
+                test(net, curr_epoch, validation_loader, use_cuda)
+        if curr_epoch % lr_drop_epochs == 0:
+            learning_rate *= lr_drop_ratio
+            optimizer = optim.SGD(
+                net.parameters(), lr=learning_rate, momentum=momentum)
+    logger.info('Finished Training')
+
+    if len(train_sets) == 1:
+        train_sets[0].close_candidate_tsvs()
+
+    curr_epoch = n_epoch + prev_epochs
+    torch.save({"state_dict": net.state_dict(),
+                "tag": tag,
+                "epoch": curr_epoch,
+                "coverage_thr": coverage_thr,
+                "normalize_channels": normalize_channels,
+                }, '{}/models/checkpoint_{}_epoch{}.pth'.format(
+        out_dir, tag, curr_epoch))
+    if validation_candidates_tsv:
+        test(net, curr_epoch, validation_loader, use_cuda)
+    logger.info("Total Epochs: {}".format(curr_epoch))
+
+    logger.info("Training is Done.")
+
+    return '{}/models/checkpoint_{}_epoch{}.pth'.format(out_dir, tag, curr_epoch)
+
+if __name__ == '__main__':
+
+    FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s'
+    logging.basicConfig(level=logging.INFO, format=FORMAT)
+    logger = logging.getLogger(__name__)
+
+    parser = argparse.ArgumentParser(
+        description='Train NeuSomatic network')
+    parser.add_argument('--candidates_tsv', nargs="*",
+                        help='train candidate tsv files', required=True)
+    parser.add_argument('--out', type=str,
+                        help='output directory', required=True)
+    parser.add_argument('--checkpoint', type=str,
+                        help='pretrained network model checkpoint path', default=None)
+    parser.add_argument('--validation_candidates_tsv', nargs="*",
+                        help='validation candidate tsv files', default=[])
+    parser.add_argument('--ensemble',
+                        help='Enable training for ensemble mode',
+                        action="store_true")
+    parser.add_argument('--num_threads', type=int,
+                        help='number of threads', default=1)
+    parser.add_argument('--batch_size', type=int,
+                        help='batch size', default=1000)
+    parser.add_argument('--max_epochs', type=int,
+                        help='maximum number of training epochs', default=1000)
+    parser.add_argument('--lr', type=float, help='learning rate', default=0.01)
+    parser.add_argument('--lr_drop_epochs', type=int,
+                        help='number of epochs to drop learning rate', default=400)
+    parser.add_argument('--lr_drop_ratio', type=float,
+                        help='learning rate drop scale', default=0.1)
+    parser.add_argument('--momentum', type=float,
+                        help='SGD momentum', default=0.9)
+    parser.add_argument('--boost_none', type=float,
+                        help='the amount to boost the weight of the non-somatic (NONE) class', default=100)
+    parser.add_argument('--none_count_scale', type=float,
+                        help='ratio of the none/somatic candidates to use in each training epoch \
+                        (the none candidates will be subsampled in each epoch based on this ratio)',
+                        default=2)
+    parser.add_argument('--max_load_candidates', type=int,
+                        help='maximum candidates to load in memory', default=1000000)
+    parser.add_argument('--save_freq', type=int,
+                        help='the frequency of saving checkpoints in terms of # epochs', default=50)
+    parser.add_argument('--merged_candidates_per_tsv', type=int,
+                        help='Maximum number of candidates in each merged tsv file', default=10000000)
+    parser.add_argument('--merged_max_num_tsvs', type=int,
+                        help='Maximum number of merged tsv files \
+                        (higher priority than merged_candidates_per_tsv)', default=10)
+    parser.add_argument('--overwrite_merged_tsvs',
+                        help='if the OUT/merged_tsvs/ folder exists, overwrite the merged tsvs',
+                        action="store_true")
+    parser.add_argument('--train_split_len', type=int,
+                        help='Maximum number of candidates used in each split of training (>=merged_candidates_per_tsv)',
+                        default=10000000)
+    parser.add_argument('--coverage_thr', type=int,
+                        help='maximum coverage threshold to be used for network input \
+                        normalization. \
+                        Will be overridden if pretrained model is provided. \
+                        For ~50x WGS, coverage_thr=100 should work. \
+                        For higher coverage WES, coverage_thr=300 should work.', default=100)
+    parser.add_argument('--normalize_channels',
+                        help='normalize BQ, MQ, and other bam-info channels by frequency of observed alleles. \
+                        Will be overridden if pretrained model is provided',
+                        action="store_true")
+    args = parser.parse_args()
+
+    logger.info(args)
+
+    use_cuda = torch.cuda.is_available()
+    logger.info("use_cuda: {}".format(use_cuda))
+
+    try:
+        checkpoint = train_neusomatic(args.candidates_tsv, args.validation_candidates_tsv,
+                                      args.out, args.checkpoint, args.num_threads, args.batch_size,
+                                      args.max_epochs,
+                                      args.lr, args.lr_drop_epochs, args.lr_drop_ratio, args.momentum,
+                                      args.boost_none, args.none_count_scale,
+                                      args.max_load_candidates, args.coverage_thr, args.save_freq,
+                                      args.ensemble,
+                                      args.merged_candidates_per_tsv, args.merged_max_num_tsvs,
+                                      args.overwrite_merged_tsvs, args.train_split_len,
+                                      args.normalize_channels,
+                                      use_cuda)
+    except Exception as e:
+        logger.error(traceback.format_exc())
+        logger.error("Aborting!")
+        logger.error(
+            "train.py failure on arguments: {}".format(args))
+        raise e
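
For quick reference, the snippet below (not part of the diff) is a minimal sketch of driving the trainer programmatically with the same defaults the argument parser above exposes; it mirrors `python train.py --candidates_tsv ... --out ...`. The candidate TSV path and output directory are placeholder values.

# Sketch only: programmatic equivalent of the command-line entry point above.
# "work/candidates.tsv" and "work/train" are placeholder paths.
import logging

import torch

from train import train_neusomatic

if __name__ == "__main__":
    # Mirror the logging setup used by train.py so progress messages are visible.
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s %(asctime)-15s %(name)-20s %(message)s')

    checkpoint_path = train_neusomatic(
        candidates_tsv=["work/candidates.tsv"],   # placeholder training TSV(s)
        validation_candidates_tsv=[],             # no validation set in this sketch
        out_dir="work/train",                     # placeholder output directory
        checkpoint=None,                          # start from scratch (no pretrained model)
        num_threads=1,
        batch_size=1000,
        max_epochs=1000,
        learning_rate=0.01,
        lr_drop_epochs=400,
        lr_drop_ratio=0.1,
        momentum=0.9,
        boost_none=100,
        none_count_scale=2,
        max_load_candidates=1000000,
        coverage_thr=100,
        save_freq=50,
        ensemble=False,
        merged_candidates_per_tsv=10000000,
        merged_max_num_tsvs=10,
        overwrite_merged_tsvs=False,
        train_split_len=10000000,
        normalize_channels=False,
        use_cuda=torch.cuda.is_available(),
    )
    print("Final checkpoint:", checkpoint_path)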