--- a
+++ b/singlecellmultiomics/utils/base_call_covariates.py
@@ -0,0 +1,655 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from multiprocessing import Pool
+from pysam import AlignmentFile, FastaFile
+import pysam
+from singlecellmultiomics.bamProcessing.bamBinCounts import blacklisted_binning_contigs
+from singlecellmultiomics.utils.sequtils import reverse_complement, get_context
+from singlecellmultiomics.utils import prob_to_phred
+from pysamiterators import CachedFasta
+from array import array
+from uuid import uuid4
+from singlecellmultiomics.bamProcessing import merge_bams, get_contigs_with_reads, has_variant_reads
+import argparse
+import pickle
+import gzip
+import pandas as pd
+import numpy as np
+import os
+from dataclasses import dataclass, field
+from collections import defaultdict, Counter
+
+
+def get_covariate_key(read, qpos, refpos, reference, refbase, cycle_bin_size=3, k_rad=1):
+    """Map a single base call to its covariate key: (qual, is_read2, binned cycle, context).
+
+    refbase is unused, but kept so that extraction and recalibration share one signature."""
+    qual = read.query_qualities[qpos]
+    qbase = read.query_sequence[qpos]
+
+    if qbase == 'N':
+        return None
+
+    context = get_context(read.reference_name, refpos, reference, qbase, k_rad)
+    # Fall back to a single-base context when the full context is undefined:
+    if 'N' in context or len(context) != (2 * k_rad + 1):
+        context = get_context(read.reference_name, refpos, reference, qbase, 0)
+        assert len(context) == 1
+
+    if 'N' in context:
+        return None
+
+    if read.is_reverse:
+        cycle = len(read.query_sequence) - qpos
+        context = reverse_complement(context)
+    else:
+        cycle = qpos + 1
+
+    return qual, read.is_read2, int(round(cycle / cycle_bin_size)) * cycle_bin_size, context
+
+
+def add_keys_excluding_context_covar(covariates, covar_phreds, k_rad=1):
+    """Add fallback phred scores keyed on the central base only,
+    used when the full sequence context of a base call is not available."""
+    for base in 'ACTG':
+        index = []
+        values = []
+        for key, (n_mismatch, n_match) in [((key[0], key[1], key[2], key[3][k_rad]), value)
+                                           for key, value in covariates.items()
+                                           if len(key[3]) == k_rad * 2 + 1 and key[3][k_rad] == base]:
+            index.append(key)
+            values.append([n_mismatch, n_match])
+
+        for indices, base_data in pd.DataFrame(values, index=pd.MultiIndex.from_tuples(index)).groupby(level=(0, 1, 2, 3)):
+            n_mismatch, n_match = base_data.sum()
+            covar_phreds[indices] = prob_to_phred(n_match / (n_mismatch + n_match))
+
+
+def covariate_obs_to_phreds(covariates, k_rad):
+    """Convert (mismatch, match) counts per covariate key into phred scores.
+
+    Note: prob_to_phred receives the fraction of matching calls."""
+    covar_phreds = {}
+
+    for k, (n_mismatch, n_match) in covariates.items():
+        covar_phreds[k] = prob_to_phred(n_match / (n_mismatch + n_match))
+
+    # Add values for when the sequence context is not available:
+    add_keys_excluding_context_covar(covariates, covar_phreds, k_rad=k_rad)
+
+    # Add a value for when everything fails (just the mean):
+    n_mismatch, n_match = pd.DataFrame(covariates).mean(1)
+    covar_phreds[None] = prob_to_phred(n_match / (n_mismatch + n_match))
+
+    return covar_phreds
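+
+
+# Minimal sketch (not part of the pipeline): how a covariates dict of
+# (mismatch, match) counts is turned into a phred lookup table. The function
+# name and the counts below are made up for illustration only.
+def _covariate_phred_example():
+    example_counts = {
+        # (qual, is_read2, binned cycle, context): [mismatches, matches]
+        (37, False, 15, 'AAA'): array('l', [2, 998]),
+        (37, False, 15, 'ACA'): array('l', [5, 995]),
+        (37, False, 15, 'ATA'): array('l', [20, 980]),
+        (37, False, 15, 'AGA'): array('l', [10, 990]),
+    }
+    phreds = covariate_obs_to_phreds(example_counts, k_rad=1)
+    # phreds now maps each full key, the context-free fallback keys such as
+    # (37, False, 15, 'A'), and the catch-all key None to a phred score.
+    return phreds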
+
+
+# Container describing a single genomic position within a molecule:
+# the covering reads, their quals and cycles, and the IVT reactions involved.
+@dataclass
+class BaseInformation:
+    ref_base: str = None
+    majority_base: str = None
+    reads: set = field(default_factory=set)
+    qual: list = field(default_factory=list)
+    cycle_r1: list = field(default_factory=list)
+    cycle_r2: list = field(default_factory=list)
+    ivt_reactions: set = field(default_factory=set)
+
+
+# Nested counter:
+def nested_counter():
+    return defaultdict(Counter)
+
+
+def get_molecule_covariate_dict():
+    # Placeholder for molecule-level covariates; currently empty.
+    return {}
+
+
+def get_covariate_dict():
+    return {
+        ('r1', 'qual'): defaultdict(Counter),
+        ('r2', 'qual'): defaultdict(Counter),
+        ('r1', 'cycle'): defaultdict(Counter),
+        ('r2', 'cycle'): defaultdict(Counter),
+        ('m', 'n_ivt'): defaultdict(Counter),
+        ('m', 'n_reads'): defaultdict(Counter),
+        ('m', 'mean_qual'): defaultdict(Counter),
+        ('m', 'mean_cycle_r1'): defaultdict(Counter),
+        ('m', 'mean_cycle_r2'): defaultdict(Counter),
+
+        # Within-IVT duplicate comparison:
+        ('i1', 'qual'): defaultdict(Counter),
+        ('i2', 'qual'): defaultdict(Counter),
+        ('i1', 'cycle'): defaultdict(Counter),
+        ('i2', 'cycle'): defaultdict(Counter),
+        ('i', 'd_start'): defaultdict(Counter),
+
+        # Outer IVT duplicate comparison (comparing the consensus sequences of IVT duplicates):
+        ('O1', 'n_reads'): defaultdict(Counter),
+        ('O', 'conv'): defaultdict(Counter),
+
+        # Inner IVT duplicate comparison (IVT duplicates agree):
+        ('v', 'conv'): defaultdict(Counter),
+    }
+
+
+def pool_wrapper(args):
+    func, kwargs = args
+    return func(**kwargs)
+
+
+def get_jobs(alignments, job_size=10_000_000, select_contig=None, **kwargs):
+    """Yield job kwargs for genomic bins of job_size bp for every contig in the alignment file."""
+    for contig, size in zip(alignments.references, alignments.lengths):
+        if select_contig is not None and contig != select_contig:
+            continue
+        for start in range(0, size, job_size):
+            end = min(size, start + job_size)
+            args = {'contig': contig,
+                    'start': start,
+                    'stop': end,
+                    }
+            args.update(kwargs)
+            yield args
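+
+
+# Minimal sketch of how get_jobs() and pool_wrapper() combine with a
+# multiprocessing Pool. The worker below is a hypothetical example (it simply
+# counts reads per bin, and requires an indexed bam); it is not used by the
+# recalibration pipeline.
+def _example_count_reads(contig, start, stop, bam_path=None):
+    # Count the reads overlapping a single bin:
+    with AlignmentFile(bam_path) as alignments:
+        return alignments.count(contig, start, stop)
+
+
+def _example_parallel_run(bam_path):
+    with AlignmentFile(bam_path) as alignments:
+        jobs = [(_example_count_reads, job)
+                for job in get_jobs(alignments, job_size=10_000_000, bam_path=bam_path)]
+    with Pool() as workers:
+        # pool_wrapper unpacks the (function, kwargs) tuples inside the workers:
+        return sum(workers.imap_unordered(pool_wrapper, jobs))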
+
+
+# Legacy molecule/IVT-aware covariate extraction. Kept for reference but
+# disabled: it depends on MoleculeIterator machinery and on names
+# (bulk_bam_path, known, fragment_class_args, get_ivt_mismatches, complement)
+# that are not defined in this module.
+"""
+def extract_covariates(bam_path, contig, start, stop):
+    covariates = get_covariate_dict()
+    variant_discoveries = set()
+    no_variants = set()
+    with pysam.AlignmentFile(bam_path) as alignments, pysam.AlignmentFile(bulk_bam_path) as bulk, pysam.FastaFile(
+            reference_path) as reference:
+
+        read_covariates_enabled = False
+
+        reference_sequence = reference.fetch(contig, start, stop + 1).upper()
+        for i, molecule in enumerate(
+                MoleculeIterator(alignments,
+                                 molecule_class=NlaIIIMolecule,
+                                 fragment_class=NlaIIIFragment,
+                                 fragment_class_args=fragment_class_args,
+                                 contig=contig,
+                                 start=start,
+                                 stop=stop
+                                 )
+        ):
+
+            if len(molecule) < 4:
+                continue
+
+            found_allele = False
+            for read in molecule.iter_reads():
+                if read.has_tag('DA'):
+                    found_allele = True
+                    break
+            if not found_allele:
+                continue
+
+            information_container = defaultdict(BaseInformation)
+
+            majority = molecule.get_consensus()
+            for key, base in majority.items():
+                information_container[key].majority_base = base
+
+            ivt_base_obs = []
+            for ivt_id, fragments in molecule.get_rt_reactions().items():
+
+                # Look for inconsistencies between the fragments of the same IVT reaction.
+                # Calculate the IVT consensus (only when more fragments are available):
+                if len(fragments) > 1 and ivt_id[0] is not None:
+                    # IVT consensus:
+                    ivt_molecule = NlaIIIMolecule(fragments)
+                    ivt_base_obs.append(ivt_molecule.get_base_observation_dict())
+                    # ivt_consensus = ivt_molecule.get_consensus()
+                # else:
+                #     ivt_consensus = None
+
+                if read_covariates_enabled:
+                    for fragment in fragments:
+
+                        for read in fragment:
+                            if read is None:
+                                continue
+
+                            for qpos, refpos in read.get_aligned_pairs(with_seq=False):
+                                if refpos is None or refpos < start or refpos >= stop:
+                                    continue
+
+                                ref_base = reference_sequence[refpos - start]
+                                if qpos is None or ref_base not in 'ACGT':
+                                    continue
+                                qbase = read.query_sequence[qpos]
+                                if qbase == 'N':
+                                    continue
+
+                                key = (read.reference_name, refpos)
+                                if key in known:
+                                    continue
+
+                                if ref_base != qbase and key not in no_variants:
+                                    if (read.reference_name, refpos, qbase) in variant_discoveries:
+                                        continue
+
+                                    if has_variant_reads(bulk, *key, qbase):
+                                        variant_discoveries.add((read.reference_name, refpos, qbase))
+                                        continue
+                                    else:
+                                        no_variants.add(key)
+
+                                if ref_base is not None:
+                                    information_container[key].ref_base = ref_base
+
+                                information_container[key].reads.add(read)
+                                information_container[key].qual.append(read.query_qualities[qpos])
+
+                                cycle = read.query_length - qpos if read.is_reverse else qpos
+                                if read.is_read1:
+                                    information_container[key].cycle_r1.append(cycle)
+                                else:
+                                    information_container[key].cycle_r2.append(cycle)
+
+                                information_container[key].ivt_reactions.add(ivt_id)
+
+                                # Set read covariates:
+                                call_is_correct = ref_base == qbase
+                                covariates[f'r{"2" if read.is_read2 else "1"}', 'qual'][read.query_qualities[qpos]][
+                                    call_is_correct] += 1
+                                covariates[f'r{"2" if read.is_read2 else "1"}', 'cycle'][cycle][call_is_correct] += 1
+
+                                # Set IVT covariates:
+                                if ivt_consensus is None or key not in ivt_consensus:
+                                    continue
+
+                                ivt_majority_base = ivt_consensus[key]
+                                if ivt_majority_base != ref_base or ivt_majority_base == 'N':
+                                    continue
+                                call_matches_ivt_majority = ivt_majority_base == qbase
+                                covariates[f'i{"2" if read.is_read2 else "1"}', 'qual'][read.query_qualities[qpos]][
+                                    call_matches_ivt_majority] += 1
+                                covariates[f'i{"2" if read.is_read2 else "1"}', 'cycle'][cycle][
+                                    call_matches_ivt_majority] += 1
+
+                                # Distance to the start of the molecule (bins of 5 bp, capped at 400):
+                                dstart = np.clip(int(abs(refpos - molecule.get_cut_site()[1]) / 5), 0, 400)
+                                covariates['i', 'd_start'][dstart][call_matches_ivt_majority] += 1
+
+            if len(ivt_base_obs) >= 2:
+                for ivt_matches, gen_pos, base_A, base_B in get_ivt_mismatches(ivt_base_obs):
+                    # Ignore known variation:
+                    if gen_pos in variant_discoveries or gen_pos in known or has_variant_reads(
+                            bulk, *gen_pos, base_A) or has_variant_reads(bulk, *gen_pos, base_B):
+                        continue
+
+                    # Obtain the reference base:
+                    refpos = gen_pos[1]
+                    if refpos is None or refpos < start or refpos >= stop:
+                        continue
+
+                    try:
+                        context = reference_sequence[refpos - start - 1: refpos - start + 2]
+                        if 'N' in context:
+                            continue
+
+                        if molecule.strand:  # is reverse
+                            context = reverse_complement(context)
+                            base_A = complement(base_A)
+                            base_B = complement(base_B)
+
+                        refbase = context[1]
+                    except IndexError:
+                        continue
+
+                    if ivt_matches:
+                        # Inner comparison: the IVT duplicates agree with each other.
+                        if refbase != base_A:
+                            covariates['v', 'conv'][f'{context}>{context[0]}{base_A}{context[2]}'][False] += 1
+                        if refbase != base_B:
+                            covariates['v', 'conv'][f'{context}>{context[0]}{base_B}{context[2]}'][False] += 1
+                    else:
+                        # Outer comparison: the IVT duplicates do not match.
+                        print('Not matching IVT at', gen_pos)
+
+                        if refbase != base_A:
+                            covariates['O', 'conv'][f'{context}>{context[0]}{base_A}{context[2]}'][False] += 1
+                        if refbase != base_B:
+                            covariates['O', 'conv'][f'{context}>{context[0]}{base_B}{context[2]}'][False] += 1
+
+            if False:
+                # We have accumulated our molecule wisdom.
+                for location, bi in information_container.items():
+                    if location in known:
+                        continue
+
+                    if bi.ref_base is None or bi.majority_base is None or bi.ref_base not in 'ACGT':
+                        continue
+                    call_is_correct = bi.ref_base == bi.majority_base
+
+                    covariates['m', 'n_ivt'][len(bi.ivt_reactions)][call_is_correct] += 1
+                    covariates['m', 'n_reads'][len(bi.reads)][call_is_correct] += 1
+                    covariates['m', 'mean_qual'][int(np.mean(bi.qual))][call_is_correct] += 1
+                    if len(bi.cycle_r2):
+                        covariates['m', 'mean_cycle_r2'][int(np.mean(bi.cycle_r2))][call_is_correct] += 1
+                    if len(bi.cycle_r1):
+                        covariates['m', 'mean_cycle_r1'][int(np.mean(bi.cycle_r1))][call_is_correct] += 1
+
+    return covariates, variant_discoveries
+"""
+
+
+def extract_covariates(bam_path: str,
+                       reference_path: str,
+                       contig: str,
+                       start: int,
+                       end: int,
+                       start_fetch: int,
+                       end_fetch: int,
+                       filter_kwargs: dict,
+                       covariate_kwargs: dict):
+    """
+    Count mismatches and matches for base calls that share the same covariate key
+
+    Returns:
+        match_mismatch(dict) : dictionary ( covariate_key: [mismatches, matches], .. )
+    """
+    # known is a set() containing locations of known variation (SNPs)
+    # @todo: extend to indels
+    global known  # <- set of (contig, position) tuples to ignore
+
+    joined = dict()
+
+    # Filters which select the reads used to estimate the covariates:
+    min_mapping_quality = filter_kwargs.get('min_mapping_quality', 0)
+    deduplicate = filter_kwargs.get('deduplicate', False)
+    filter_qcfailed = filter_kwargs.get('filter_qcfailed', False)
+    variant_blacklist_vcf_files = filter_kwargs.get('variant_blacklist_vcf_files', None)
+
+    # Obtain all blacklisted variants in the selected range:
+    blacklist = set()
+    if variant_blacklist_vcf_files is not None:
+        for path in variant_blacklist_vcf_files:
+            try:
+                with pysam.VariantFile(path) as bf:
+                    for record in bf.fetch(contig, start_fetch, end_fetch):
+                        # record.start is zero-based, matching the refpos
+                        # coordinates yielded by get_aligned_pairs below:
+                        blacklist.add(record.start)
+            except ValueError:
+                print(f'Contig {contig} is missing in the vcf file {path}')
+
+    with AlignmentFile(bam_path) as alignments, FastaFile(reference_path) as fa:
+        reference = CachedFasta(fa)  # @todo: prefetch selected region
+        for read in alignments.fetch(contig, start_fetch, end_fetch):
+            if (deduplicate and read.is_duplicate) or \
+                    (read.is_qcfail and filter_qcfailed) or \
+                    (read.mapping_quality < min_mapping_quality):
+                continue
+
+            for qpos, refpos, refbase in read.get_aligned_pairs(matches_only=True, with_seq=True):
+
+                if refpos > end or refpos < start:  # Prevent the same location from being counted multiple times
+                    continue
+
+                if refpos in blacklist:
+                    continue
+
+                refbase = refbase.upper()
+                if refbase == 'N' or (read.reference_name, refpos) in known:
+                    continue
+
+                key = get_covariate_key(read, qpos, refpos, reference, refbase, **covariate_kwargs)
+                if key is None:
+                    continue
+
+                matched = (refbase == read.query_sequence[qpos])
+                try:
+                    joined[key][matched] += 1
+                except KeyError:
+                    # Index 0 counts mismatches, index 1 counts matches:
+                    joined[key] = array('l', [0, 1]) if matched else array('l', [1, 0])
+    return joined
+
+
+def extract_covariates_wrapper(kwargs):
+    return extract_covariates(**kwargs)
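+
+
+# Minimal single-region usage sketch (hypothetical file names; the `known`
+# global must be set before calling, as extract_covariates_from_bam does below):
+#
+#   known = set()
+#   counts = extract_covariates(
+#       'sample.bam', 'genome.fa', contig='chr1',
+#       start=0, end=10_000_000, start_fetch=0, end_fetch=10_000_000,
+#       filter_kwargs={'min_mapping_quality': 40, 'deduplicate': True,
+#                      'filter_qcfailed': True, 'variant_blacklist_vcf_files': None},
+#       covariate_kwargs={'cycle_bin_size': 3, 'k_rad': 1})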
+
+
+def extract_covariates_from_bam(bam_path, reference_path, known_variants, n_processes=None, bin_size=10_000_000,
+                                min_mapping_quality=40,
+                                deduplicate=True,
+                                filter_qcfailed=True,
+                                variant_blacklist_vcf_files=None
+                                ):
+    global known
+    known = known_variants
+
+    joined = dict()
+
+    job_generation_args = {
+        'contig_length_resource': bam_path,
+        'bin_size': bin_size,
+        'fragment_size': 0}
+
+    filter_kwargs = {
+        'min_mapping_quality': min_mapping_quality,
+        'deduplicate': deduplicate,
+        'filter_qcfailed': filter_qcfailed,
+        'variant_blacklist_vcf_files': variant_blacklist_vcf_files
+    }
+
+    covariate_kwargs = {
+        'cycle_bin_size': 3,
+        'k_rad': 1
+    }
+    jobs_total = sum(1 for _ in blacklisted_binning_contigs(**job_generation_args))
+
+    with Pool(n_processes) as workers:
+
+        for i, r in enumerate(
+                workers.imap_unordered(extract_covariates_wrapper, (
+                        {
+                            'bam_path': bam_path,
+                            'reference_path': reference_path,
+                            'contig': contig,
+                            'start': start,
+                            'end': end,
+                            'start_fetch': start_fetch,
+                            'end_fetch': end_fetch,
+                            'filter_kwargs': filter_kwargs,
+                            'covariate_kwargs': covariate_kwargs
+                        }
+                        for contig, start, end, start_fetch, end_fetch in
+                        blacklisted_binning_contigs(**job_generation_args)))):
+            print(round(100 * (i / jobs_total), 1), end='\r')
+
+            for key, tf in r.items():
+                try:
+                    joined[key][0] += tf[0]
+                    joined[key][1] += tf[1]
+                except KeyError:
+                    joined[key] = array('l', [tf[0], tf[1]])
+    return joined
+
+
+def recalibrate_base_calls(read, reference, joined_prob, covariate_kwargs):
+    # @todo: make a copy to preserve the phred scores of soft-clipped bases
+    # This array will contain all recalibrated phred scores:
+    new_qualities = array('B', [0] * len(read.query_qualities))
+
+    # Iterate all aligned pairs and replace their phred scores:
+    for qpos, refpos in read.get_aligned_pairs(matches_only=True, with_seq=False):
+
+        key = get_covariate_key(read, qpos, refpos, reference, None, **covariate_kwargs)
+        try:
+            phred = joined_prob[key]
+        except KeyError:
+            phred = joined_prob[None]
+        new_qualities[qpos] = phred
+
+    read.query_qualities = new_qualities
+
+
+def _recalibrate_reads(bam_path, reference_path, contig, start, end, covariate_kwargs, **kwargs):
+    # Recalibrate the reads of a single region of bam_path, writing them to a
+    # uniquely named intermediate bam file which is merged afterwards.
+
+    global joined_prob  # Global to share over multiprocessing
+    # joined_prob contains P(error | d), where d is a descriptor generated by get_covariate_key
+
+    o_path = f'out_{uuid4()}.bam'
+
+    # Open the source bam file:
+    with AlignmentFile(bam_path) as alignments, FastaFile(reference_path) as fa:
+        # @todo: extract only the selected region from the fasta file
+        reference = CachedFasta(fa)
+        # Open the target bam file:
+        with AlignmentFile(o_path, header=alignments.header, mode='wb') as out:
+            # Iterate all reads in the source bam file:
+            for read in alignments.fetch(contig, start, end):
+                recalibrate_base_calls(read, reference, joined_prob, covariate_kwargs)
+                out.write(read)
+
+    pysam.index(o_path)
+    return o_path
+
+
+def __recalibrate_reads(kwargs):
+    return _recalibrate_reads(**kwargs)
+
+
+def recalibrate_reads(bam_path, target_bam_path, reference_path, n_processes, covariates, covariate_kwargs,
+                      intermediate_bam_size=20_000_000):
+    # Note: the work is currently distributed per contig; intermediate_bam_size
+    # is kept for a planned switch to fixed-size bins.
+    global joined_prob
+    joined_prob = covariates
+
+    print(len(covariates), 'discrete covariate keys')
+    with Pool(n_processes) as workers:
+        intermediate_bams = list(workers.imap_unordered(__recalibrate_reads, (
+            {
+                'bam_path': bam_path,
+                'reference_path': reference_path,
+                'contig': contig,
+                'start': None,
+                'end': None,
+                'covariate_kwargs': covariate_kwargs
+            }
+            for contig in get_contigs_with_reads(bam_path))))
+
+    merge_bams(intermediate_bams, target_bam_path)
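+
+
+# For intuition on the recalibration arithmetic: a covariate bin with 2
+# mismatches out of 1000 calls corresponds to an empirical error rate of
+# 2/1000 = 0.002, i.e. a phred score of -10 * log10(0.002) ~ 27, which replaces
+# the instrument-reported quality of every base call falling in that bin.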
+
+
+if __name__ == '__main__':
+
+    argparser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        description="""Obtain base calling biases using the complete posterior distribution and perform corrections.
+Both the covariate identification and the recalibration are multithreaded.""")
+    argparser.add_argument('bamfile', metavar='bamfile', type=str)
+    argparser.add_argument('-reference', help="Path to the reference fasta file used to generate the bamfile",
+                           required=True)
+    argparser.add_argument('-known', help="VCF file with known variation", required=False)
+
+    argparser.add_argument('-threads', type=int,
+                           help="Number of threads to use. Uses all available cores when not set")
+
+    argparser.add_argument(
+        '-covariates_out',
+        type=str,
+        help='Write covariates to this file, ends in .pickle.gz')
+
+    argparser.add_argument(
+        '-covariates_in',
+        type=str,
+        help='Read existing covariates from this file, ends in .pickle.gz')
+
+    argparser.add_argument(
+        '-bam_out',
+        type=str,
+        required=False,
+        help='Write the corrected bam file here')
+
+    argparser.add_argument(
+        '--f',
+        action='store_true',
+        help='Overwrite an existing covariates file')
+
+    args = argparser.parse_args()
+
+    assert args.bam_out != args.bamfile, 'The input bam file path cannot match the output bam file path'
+
+    if args.covariates_in is None and args.covariates_out is not None and args.bam_out is None and \
+            os.path.exists(args.covariates_out) and not args.f:
+        print('Output covariate file already exists. Use --f to overwrite.')
+        exit()
+
+    # Set defaults when nothing is supplied:
+    if args.covariates_in is None and args.bam_out is None and args.covariates_out is None:
+        args.bam_out = args.bamfile.replace('.bam', '.recall.bam')
+        if args.covariates_out is None:
+            args.covariates_out = args.bamfile.replace('.bam', '.covariates.pickle.gz')
+
+    if args.covariates_in is not None:
+        print(f'Loading covariates from {args.covariates_in}')
+        with gzip.open(args.covariates_in, 'rb') as i:
+            covariates = pickle.load(i)
+    else:
+        covariates = extract_covariates_from_bam(args.bamfile,
+                                                 reference_path=args.reference,
+                                                 variant_blacklist_vcf_files=[] if args.known is None else [args.known],
+                                                 known_variants=set(),
+                                                 n_processes=args.threads
+                                                 )
+        if args.covariates_out is not None:
+            print(f'Writing covariates to {args.covariates_out}')
+            with gzip.open(args.covariates_out, 'wb') as o:
+                pickle.dump(covariates, o)
+
+    if args.bam_out is not None:
+        # Convert the raw (mismatch, match) counts into phred scores:
+        covar_phreds = covariate_obs_to_phreds(covariates, k_rad=1)
+
+        print(f'Writing corrected bam file to {args.bam_out}')
+        recalibrate_reads(bam_path=args.bamfile,
+                          target_bam_path=args.bam_out,
+                          reference_path=args.reference,
+                          n_processes=args.threads,
+                          covariates=covar_phreds,
+                          # Must match the covariate_kwargs used during extraction:
+                          covariate_kwargs={
+                              'cycle_bin_size': 3,
+                              'k_rad': 1
+                          })
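+
+
+# Example invocation (hypothetical file names):
+#
+#   python base_call_covariates.py sample.bam -reference genome.fa \
+#       -known known_variants.vcf.gz -covariates_out sample.covariates.pickle.gz \
+#       -bam_out sample.recall.bam -threads 8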