#!/usr/bin/env python
# -*- coding: utf-8 -*-
from multiprocessing import Pool
from pysam import AlignmentFile, FastaFile
import pysam
from singlecellmultiomics.bamProcessing.bamBinCounts import blacklisted_binning_contigs
from singlecellmultiomics.utils.sequtils import reverse_complement, get_context
from singlecellmultiomics.utils import prob_to_phred
from pysamiterators import CachedFasta
from array import array
from uuid import uuid4
from singlecellmultiomics.bamProcessing import merge_bams, get_contigs_with_reads, has_variant_reads
import argparse
import pickle
import gzip
import pandas as pd
import numpy as np
import os
from dataclasses import dataclass, field
from collections import defaultdict, Counter


def get_covariate_key(read, qpos, refpos, reference, refbase, cycle_bin_size=3, k_rad=1):
    """Build the covariate key (qual, is_read2, binned cycle, reference context) for a single base call.

    Returns None when the called base or its reference context contains an N.
    """
qual = read.query_qualities[qpos]
qbase = read.query_sequence[qpos]
if qbase == 'N':
return None
context = get_context(read.reference_name, refpos, reference, qbase, k_rad)
if 'N' in context or len(context)!=(2*k_rad+1):
context=get_context(read.reference_name, refpos, reference, qbase, 0)
assert len(context)==1
if 'N' in context:
return None
if read.is_reverse:
        cycle = len(read.query_sequence) - qpos  # recover the original sequencing cycle for reverse-mapped reads
context = reverse_complement(context)
else:
cycle = qpos + 1
return qual, read.is_read2, int(round(cycle / cycle_bin_size)) * cycle_bin_size, context
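
# A worked example of the key layout (hypothetical values): a forward read-1
# call with phred score 37 at query position 10 inside reference context 'ACA'
# yields, with cycle_bin_size=3 and k_rad=1:
#     (37, False, 12, 'ACA')   # qual, is_read2, binned cycle round(11/3)*3, context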


def add_keys_excluding_context_covar(covariates, covar_phreds, k_rad=1):
    """Add fallback phred values for covariate keys whose full sequence context is not available.

    For every (qual, is_read2, cycle_bin, center_base) combination, the observations
    of all contexts sharing the same center base are pooled.
    """
    for base in 'ACTG':
        index = []
        values = []
        # Collapse every context to its center base and collect the (mismatch, match) counts:
        for key, (n_mismatch, n_match) in [((key[0], key[1], key[2], key[3][k_rad]), value)
                                           for key, value in covariates.items()
                                           if len(key[3]) == k_rad * 2 + 1 and key[3][k_rad] == base]:
            index.append(key)
            values.append([n_mismatch, n_match])
        if not index:
            continue
        for indices, base_data in pd.DataFrame(values, index=pd.MultiIndex.from_tuples(index)).groupby(level=(0, 1, 2, 3)):
            n_mismatch, n_match = base_data.sum()
            covar_phreds[indices] = prob_to_phred(n_mismatch / (n_mismatch + n_match))


def covariate_obs_to_phreds(covariates, k_rad):
    """Convert the (mismatch, match) counts per covariate key into phred scaled error probabilities."""
    covar_phreds = {}
    for key, (n_mismatch, n_match) in covariates.items():
        covar_phreds[key] = prob_to_phred(n_mismatch / (n_mismatch + n_match))
    # Add values for keys for which the full context is not available:
    add_keys_excluding_context_covar(covariates, covar_phreds, k_rad=k_rad)
    # Add a fallback value for when every lookup fails (the mean error rate):
    n_mismatch, n_match = pd.DataFrame(covariates).mean(1)
    covar_phreds[None] = prob_to_phred(n_mismatch / (n_mismatch + n_match))
    return covar_phreds
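
# Assuming prob_to_phred implements the standard transform Q = -10*log10(P(error)),
# a covariate bin with 1 mismatch and 999 matches has an empirical error rate
# of 1/1000 and is therefore assigned Q30.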


# Container collecting per-base information across all reads of a molecule:
# the reads covering the base, their base qualities and cycles, and the
# IVT reactions they originate from.
@dataclass
class BaseInformation:
    ref_base: str = None
    majority_base: str = None
    reads: set = field(default_factory=set)
    qual: list = field(default_factory=list)
    cycle_r1: list = field(default_factory=list)
    cycle_r2: list = field(default_factory=list)
    ivt_reactions: set = field(default_factory=set)
# Nested counter:
def nested_counter():
return defaultdict(Counter)
def get_molecule_covariate_dict():
    return {}
def get_covariate_dict():
return {
('r1', 'qual'): defaultdict(Counter),
('r2', 'qual'): defaultdict(Counter),
('r1', 'cycle'): defaultdict(Counter),
('r2', 'cycle'): defaultdict(Counter),
('m', 'n_ivt'): defaultdict(Counter),
('m', 'n_reads'): defaultdict(Counter),
('m', 'mean_qual'): defaultdict(Counter),
('m', 'mean_cycle_r1'): defaultdict(Counter),
('m', 'mean_cycle_r2'): defaultdict(Counter),
# Within Ivt duplicate comparison:
('i1', 'qual'): defaultdict(Counter),
('i2', 'qual'): defaultdict(Counter),
('i1', 'cycle'): defaultdict(Counter),
('i2', 'cycle'): defaultdict(Counter),
('i', 'd_start'): defaultdict(Counter),
# Outer IVT duplicate comparison (comparing the consensus sequences of IVT duplicates)
('O1', 'n_reads'): defaultdict(Counter),
('O', 'conv'): defaultdict(Counter),
# Inner IVT duplicate comparison (IVT duplicates agree)
('v', 'conv'): defaultdict(Counter),
}
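
# Sketch of how these nested counters are filled (hypothetical observation):
# a read-1 base call with phred score 30 that matches the reference would be
# registered as:
#     covariates = get_covariate_dict()
#     covariates['r1', 'qual'][30][True] += 1   # -> defaultdict value Counter({True: 1})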
def pool_wrapper(args):
func, kwargs = args
return func(**kwargs)
def get_jobs(alignments, job_size=10_000_000, select_contig=None, **kwargs):
for contig, size in zip(alignments.references, alignments.lengths):
if select_contig is not None and contig != select_contig:
continue
for start in range(0, size, job_size):
end = min(size, start + job_size)
args = {'contig': contig,
'start': start,
'stop': end,
}
args.update(kwargs)
yield args
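
# Example: for a hypothetical 25 Mb contig 'chr1' and the default job_size,
# get_jobs yields dictionaries like:
#     {'contig': 'chr1', 'start': 0, 'stop': 10_000_000}
#     {'contig': 'chr1', 'start': 10_000_000, 'stop': 20_000_000}
#     {'contig': 'chr1', 'start': 20_000_000, 'stop': 25_000_000}
# Any extra keyword arguments are merged into every job dictionary.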
"""
def extract_covariates(bam_path, contig, start, stop):
covariates = get_covariate_dict()
variant_discoveries = set()
no_variants = set()
max_show = 1
shown = 0
with pysam.AlignmentFile(bam_path) as alignments, pysam.AlignmentFile(bulk_bam_path) as bulk, pysam.FastaFile(
reference_path) as reference:
# reference = CachedFasta(ref)
read_covariates_enabled = False
reference_sequence = reference.fetch(contig, start, stop + 1).upper()
for i, molecule in enumerate(
MoleculeIterator(alignments,
molecule_class=NlaIIIMolecule,
fragment_class=NlaIIIFragment,
fragment_class_args=fragment_class_args,
contig=contig,
start=start,
stop=stop
)
):
if len(molecule) < 4:
continue
found_allele = False
for read in molecule.iter_reads():
if read.has_tag('DA'):
found_allele = True
break
if not found_allele:
continue
information_container = defaultdict(BaseInformation)
majority = molecule.get_consensus()
for key, base in majority.items():
information_container[key].majority_base = base
ivt_base_obs = []
for ivt_id, fragments in molecule.get_rt_reactions().items():
# Look for inconsistencies between the fragments of the same IVT reaction
# Calculate IVT consensus: (Only when more fragments are available)
if len(fragments) > 1 and ivt_id[0] is not None:
# IVT consensus:
ivt_molecule = NlaIIIMolecule(fragments)
ivt_base_obs.append(ivt_molecule.get_base_observation_dict())
# ivt_consensus = ivt_molecule.get_consensus()
# else:
# ivt_consensus = None
if read_covariates_enabled:
for fragment in fragments:
for read in fragment:
if read is None:
continue
for qpos, refpos in read.get_aligned_pairs(with_seq=False):
if refpos is None or refpos < start or refpos >= stop:
continue
ref_base = reference_sequence[refpos - start]
# if ref_base is None:
# continue
# ref_base = ref_base.upper()
if qpos is None or refpos is None or ref_base not in 'ACGT':
continue
qbase = read.query_sequence[qpos]
if qbase == 'N':
continue
key = (read.reference_name, refpos)
if key in known:
continue
if ref_base != qbase and not key in no_variants:
if (read.reference_name, refpos, qbase) in variant_discoveries:
continue
if has_variant_reads(bulk, *key, qbase):
# print(f'Discovered variant at {read.reference_name}:{refpos+1} {ref_base}>{qbase}')
# known.add((read.reference_name, refpos))
variant_discoveries.add((read.reference_name, refpos, qbase))
continue
else:
no_variants.add(key)
if ref_base is not None:
information_container[key].ref_base = ref_base
information_container[key].reads.add(read)
information_container[key].qual.append(read.query_qualities[qpos])
cycle = read.query_length - qpos if read.is_reverse else qpos
if read.is_read1:
information_container[key].cycle_r1.append(cycle)
else:
information_container[key].cycle_r2.append(cycle)
information_container[key].ivt_reactions.add(ivt_id)
# Set read covariates:
call_is_correct = ref_base == qbase
covariates[f'r{"2" if read.is_read2 else "1"}', 'qual'][read.query_qualities[qpos]][
call_is_correct] += 1
covariates[f'r{"2" if read.is_read2 else "1"}', 'cycle'][cycle][call_is_correct] += 1
# Set IVT covariates
if ivt_consensus is None or not key in ivt_consensus:
continue
ivt_majority_base = ivt_consensus[key]
if ivt_majority_base != ref_base or ivt_majority_base == 'N':
continue
call_matches_ivt_mayority = ivt_majority_base == qbase
covariates[f'i{"2" if read.is_read2 else "1"}', 'qual'][read.query_qualities[qpos]][
call_matches_ivt_mayority] += 1
covariates[f'i{"2" if read.is_read2 else "1"}', 'cycle'][cycle][
call_matches_ivt_mayority] += 1
# Calculate distance to start of molecule (in bins of 5 bp, to a maximum of 400)
dstart = np.clip(int(abs(refpos - molecule.get_cut_site()[1]) / 5), 0, 400)
covariates[f'i', 'd_start'][dstart][call_matches_ivt_mayority] += 1
# if not call_matches_ivt_mayority and shown<max_show:
# #print(key, molecule.sample, f'{ivt_majority_base}>{qbase}' )
# shown+=1
if len(ivt_base_obs) >= 2:
for ivt_matches, gen_pos, base_A, base_B in get_ivt_mismatches(ivt_base_obs):
# Ignore known variation:
if gen_pos in variant_discoveries or gen_pos in known or has_variant_reads(bulk, *gen_pos,
base_A) or has_variant_reads(
bulk, *gen_pos, base_B):
continue
# Obtain reference base:
refpos = gen_pos[1]
if refpos is None or refpos < start or refpos >= stop:
continue
# ref_base = reference_sequence[refpos-start]
try:
context = reference_sequence[refpos - start - 1: refpos - start + 2]
if 'N' in context:
continue
if molecule.strand: # is reverse..
context = reverse_complement(context)
base_A = complement(base_A)
base_B = complement(base_B)
refbase = context[1]
except IndexError:
continue
if ivt_matches:
if refbase != base_A:
covariates['v', 'conv'][f'{context}>{context[0]}{base_A}{context[2]}'][False] += 1
# else:
# covariates['O','conv'][f'{context}>{context[0]}{base_A}{context[2]}'][True]+=1
if refbase != base_B:
covariates['v', 'conv'][f'{context}>{context[0]}{base_B}{context[2]}'][False] += 1
else:
# IVT not matching:
print('Not matching IVT at', gen_pos)
if refbase != base_A:
covariates['O', 'conv'][f'{context}>{context[0]}{base_A}{context[2]}'][False] += 1
# else:
# covariates['O','conv'][f'{context}>{context[0]}{base_A}{context[2]}'][True]+=1
if refbase != base_B:
covariates['O', 'conv'][f'{context}>{context[0]}{base_B}{context[2]}'][False] += 1
# else:
# covariates['O','conv'][f'{context}>{context[0]}{base_B}{context[2]}'][True]+=1
if False:
# We have accumulated our molecule wisdom.
for location, bi in information_container.items():
if location in known:
continue
if bi.ref_base is None or bi.majority_base is None or bi.ref_base not in 'ACGT':
continue
call_is_correct = bi.ref_base == bi.majority_base
covariates['m', 'n_ivt'][len(bi.ivt_reactions)][call_is_correct] += 1
covariates['m', 'n_reads'][len(bi.reads)][call_is_correct] += 1
covariates['m', 'mean_qual'][int(np.mean(bi.qual))][call_is_correct] += 1
if len(bi.cycle_r2):
covariates['m', 'mean_cycle_r2'][int(np.mean(bi.cycle_r2))][call_is_correct] += 1
if len(bi.cycle_r1):
covariates['m', 'mean_cycle_r1'][int(np.mean(bi.cycle_r1))][call_is_correct] += 1
return covariates, variant_discoveries
"""


def extract_covariates(bam_path: str,
                       reference_path: str,
                       contig: str,
                       start: int,
                       end: int,
                       start_fetch: int,
                       end_fetch: int,
                       filter_kwargs: dict,
                       covariate_kwargs: dict):
    """
    Count mismatching and matching base-calls per covariate key in the selected region.

    Returns:
        match_mismatch(dict) : dictionary { covariate_key: array [n_mismatches, n_matches], .. }
    """
# known is a set() containing locations of known variation (snps)
# @todo: extend to indels
global known # <- Locations, set of (contig, position) tuples to ignore
joined = dict()
# Filters which select which reads are used to estimate covariates:
min_mapping_quality = filter_kwargs.get('min_mapping_quality', 0)
deduplicate = filter_kwargs.get('deduplicate', False)
filter_qcfailed = filter_kwargs.get('filter_qcfailed', False)
variant_blacklist_vcf_files = filter_kwargs.get('variant_blacklist_vcf_files', None)
# Obtain all variants in the selected range:
blacklist = set()
if variant_blacklist_vcf_files is not None:
for path in variant_blacklist_vcf_files:
            try:
                with pysam.VariantFile(path) as bf:
                    for record in bf.fetch(contig, start_fetch, end_fetch):
                        # record.start is zero-based and therefore comparable to refpos below:
                        blacklist.add(record.start)
            except ValueError:
                print(f'Contig {contig} is missing in the vcf file {path}')
with AlignmentFile(bam_path) as alignments, FastaFile(reference_path) as fa:
reference = CachedFasta(fa) # @todo: prefetch selected region
for read in alignments.fetch(contig, start_fetch, end_fetch):
if (deduplicate and read.is_duplicate) or \
(read.is_qcfail and filter_qcfailed) or \
(read.mapping_quality < min_mapping_quality):
continue
for qpos, refpos, refbase in read.get_aligned_pairs(matches_only=True, with_seq=True):
                if refpos > end or refpos < start:  # Prevent the same location from being counted multiple times
                    continue
if refpos in blacklist:
continue
refbase = refbase.upper()
if refbase == 'N' or (read.reference_name, refpos) in known:
continue
key = get_covariate_key(read, qpos, refpos, reference, refbase, **covariate_kwargs)
if key is None:
continue
matched = (refbase == read.query_sequence[qpos])
try:
joined[key][matched] += 1
except KeyError:
if matched:
joined[key] = array('l', [0, 1])
else:
joined[key] = array('l', [1, 0])
return joined
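
# Each entry of the returned dictionary maps a covariate key to a length-2
# array [n_mismatches, n_matches], for example (hypothetical counts):
#     {(37, False, 12, 'ACA'): array('l', [2, 1532]), ...}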
def extract_covariates_wrapper(kwargs):
return extract_covariates(**kwargs)


def extract_covariates_from_bam(bam_path, reference_path, known_variants, n_processes=None, bin_size=10_000_000,
                                min_mapping_quality=40,
                                deduplicate=True,
                                filter_qcfailed=True,
                                variant_blacklist_vcf_files=None
                                ):
    global known
    known = known_variants
    joined = dict()
    job_generation_args = {
        'contig_length_resource': bam_path,
        'bin_size': bin_size,
        'fragment_size': 0}
    # Forward the supplied filter settings to the workers:
    filter_kwargs = {
        'min_mapping_quality': min_mapping_quality,
        'deduplicate': deduplicate,
        'filter_qcfailed': filter_qcfailed,
        'variant_blacklist_vcf_files': variant_blacklist_vcf_files
    }
covariate_kwargs = {
'cycle_bin_size': 3,
'k_rad' : 1
}
jobs_total = sum(1 for _ in (blacklisted_binning_contigs(**job_generation_args)))
with Pool(n_processes) as workers:
for i, r in enumerate(
workers.imap_unordered(extract_covariates_wrapper, (
{
'bam_path': bam_path,
'reference_path': reference_path,
'contig': contig,
'start': start,
'end': end,
'start_fetch': start_fetch,
'end_fetch': end_fetch,
'filter_kwargs': filter_kwargs,
'covariate_kwargs': covariate_kwargs
}
for contig, start, end, start_fetch, end_fetch in
blacklisted_binning_contigs(**job_generation_args)))):
print(round(100 * (i / jobs_total), 1), end='\r')
            for key, tf in r.items():
                try:
                    joined[key][0] += tf[0]
                    joined[key][1] += tf[1]
                except KeyError:
                    joined[key] = array('l', tf)
return joined
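
# Minimal usage sketch (hypothetical paths; the bam file must be indexed):
#     known_variants = set()  # {(contig, position), ...} to exclude
#     obs = extract_covariates_from_bam('sample.bam', 'reference.fa',
#                                       known_variants, n_processes=8)
# The workers read the module-level `known` set, which is shared on platforms
# where multiprocessing forks the parent process.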


def recalibrate_base_calls(read, reference, joined_prob, covariate_kwargs):
    # Start from a copy of the current qualities so that soft-clipped and
    # unaligned bases keep their original phred scores:
    new_qualities = array('B', read.query_qualities)
# Iterate all aligned pairs and replace phred score:
for qpos, refpos in read.get_aligned_pairs(matches_only=True, with_seq=False):
key = get_covariate_key(read, qpos, refpos, reference, None, **covariate_kwargs)
try:
phred = joined_prob[key]
except KeyError:
phred = joined_prob[None]
new_qualities[qpos] = phred
read.query_qualities = new_qualities
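
# Lookup order per aligned base: the exact covariate key first, then the None
# fallback holding the mean error rate computed in covariate_obs_to_phreds.
# For example, a key that was never seen during extraction, such as
# (41, True, 99, 'TGT'), is assigned joined_prob[None].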
def _recalibrate_reads(bam_path, reference_path, contig, start, end, covariate_kwargs, **kwargs):
# Recalibrate the reads in bam_path
global joined_prob # Global to share over multiprocessing
    # joined_prob maps a descriptor d generated by get_covariate_key to the phred scaled P(error | d)
o_path = f'out_{uuid4()}.bam'
# Open source bam file:
with AlignmentFile(bam_path) as alignments, FastaFile(reference_path) as fa:
# @todo: extract only selected region from fasta file:
reference = CachedFasta(fa)
# Open target bam file:
with AlignmentFile(o_path, header=alignments.header, mode='wb') as out:
# Iterate all reads in the source bam file:
for read in alignments.fetch(contig, start, end):
recalibrate_base_calls(read, reference, joined_prob, covariate_kwargs)
out.write(read)
pysam.index(o_path)
return o_path
def __recalibrate_reads(kwargs):
return _recalibrate_reads(**kwargs)


def recalibrate_reads(bam_path, target_bam_path, reference_path, n_processes, covariates, covariate_kwargs,
                      intermediate_bam_size=20_000_000):
    # Jobs are currently generated per contig, so intermediate_bam_size is not used.
    global joined_prob  # Global to share over multiprocessing
    joined_prob = covariates
    print(len(covariates), 'discrete elements')
    with Pool(n_processes) as workers:
        intermediate_bams = list(workers.imap_unordered(__recalibrate_reads, (
            {
                'bam_path': bam_path,
                'reference_path': reference_path,
                'contig': contig,
                'start': None,
                'end': None,
                'covariate_kwargs': covariate_kwargs
            }
            for contig in get_contigs_with_reads(bam_path))))
    merge_bams(intermediate_bams, target_bam_path)
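
# Usage sketch (hypothetical paths), after converting the observation counts
# to phred scores with covariate_obs_to_phreds:
#     recalibrate_reads('sample.bam', 'sample.recall.bam', 'reference.fa',
#                       n_processes=8, covariates=covar_phreds,
#                       covariate_kwargs={'cycle_bin_size': 3, 'k_rad': 1})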


if __name__ == '__main__':
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""Obtain base calling biases using the complete posterior distribution and perform corrections. Both identifying the covariates and the recalibration are multithreaded""")
    argparser.add_argument('bamfile', metavar='bamfile', type=str)
    argparser.add_argument('-reference', help="Path to the reference fasta file used to generate the bamfile", required=True)
    argparser.add_argument('-known', help="vcf file with known variation", required=False)
    argparser.add_argument('-threads', type=int, default=None, help="Number of threads to use. Uses all available cores when not set")
    argparser.add_argument(
        '-covariates_out',
        type=str,
        help='Write covariates to this file (ends in .pickle.gz)')
    argparser.add_argument(
        '-covariates_in',
        type=str,
        help='Read existing covariates from this file (ends in .pickle.gz)')
    argparser.add_argument(
        '-bam_out',
        type=str,
        required=False,
        help='Write the corrected bam file here')
    argparser.add_argument(
        '--f',
        action='store_true',
        help='Overwrite existing output files')
args = argparser.parse_args()
assert args.bam_out!=args.bamfile, 'The input bam file name cannot match the output bam file'
if args.covariates_in is None and args.covariates_out is not None and args.bam_out is None and os.path.exists(args.covariates_out) and not args.f:
print('Output covariate file already exists. Use -f to overwrite.')
exit()
# Set defaults when nothing is supplied
if args.covariates_in is None and args.bam_out is None and args.covariates_out is None:
args.bam_out = args.bamfile.replace('.bam','.recall.bam')
if args.covariates_out is None:
args.covariates_out = args.bamfile.replace('.bam','.covariates.pickle.gz')
if args.covariates_in is not None:
print(f'Loading covariates from {args.covariates_in} ')
with gzip.open(args.covariates_in,'rb') as i:
covariates = pickle.load(i)
    else:
        covariates = extract_covariates_from_bam(args.bamfile,
                                                 reference_path=args.reference,
                                                 variant_blacklist_vcf_files=[] if args.known is None else [args.known],
                                                 known_variants=set(),
                                                 n_processes=args.threads
                                                 )
if args.covariates_out is not None:
print(f'Writing covariates to {args.covariates_out}')
with gzip.open(args.covariates_out,'wb') as o:
pickle.dump(covariates, o)
    if args.bam_out is not None:
        # Convert the observation counts into phred scaled error probabilities:
        covar_phreds = covariate_obs_to_phreds(covariates, k_rad=1)
        print(f'Writing corrected bam file to {args.bam_out}')
        recalibrate_reads(bam_path=args.bamfile,
                          target_bam_path=args.bam_out,
                          reference_path=args.reference,
                          n_processes=args.threads,
                          covariates=covar_phreds,
                          covariate_kwargs={
                              # These must match the settings used during covariate extraction:
                              'cycle_bin_size': 3,
                              'k_rad': 1
                          })