Switch to unified view

a b/singlecellmultiomics/bamProcessing/bamMutProfiler.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import matplotlib
5
matplotlib.rcParams['figure.dpi'] = 160
6
matplotlib.use('Agg')
7
import matplotlib.pyplot as plt
8
import pandas as pd
9
from glob import glob
10
import seaborn as sns
11
import pysam
12
import numpy as np
13
import multiprocessing
14
from datetime import datetime
15
from singlecellmultiomics.utils.plotting import GenomicPlot
16
from singlecellmultiomics.bamProcessing.bamBinCounts import count_fragments_binned, generate_commands, gc_correct_cn_frame, obtain_counts
17
import os
18
import argparse
19
from colorama import Fore, Style
20
21
22
def normalise_counts(df, norm_method='median', molecule_threshold=5_000,
                     pct_clip=99.999, max_cp=5):
    """Filter a raw count matrix and scale it to clipped copy-number values.

    The frame is expected in the orientation this script builds it in: one
    column per cell and one row per bin. The input frame is not modified.

    Steps, all applied per column:
      1. drop columns whose centre (median or mean) is not above zero;
      2. drop columns whose total is not above ``molecule_threshold``;
      3. divide each column by its ``pct_clip``-th percentile;
      4. divide by the column centre, multiply by two and clip the values
         to ``[0, max_cp]`` (the factor two presumably puts the typical
         bin at copy number 2 -- TODO confirm).

    Args:
        df: raw count frame (bins as rows, cells as columns).
        norm_method: ``'median'`` or ``'mean'``; selects the column centre.
        molecule_threshold: minimum column total required to keep a column.
        pct_clip: percentile (0-100) used to scale every column.
        max_cp: upper bound applied to the normalised values.

    Returns:
        The normalised frame, transposed (cells as rows, bins as columns).

    Raises:
        ValueError: if ``norm_method`` is neither ``'median'`` nor ``'mean'``.
    """
    if norm_method not in ('median', 'mean'):
        raise ValueError(f'Unknown normalisation method: {norm_method}')

    def centre(frame):
        return frame.median() if norm_method == 'median' else frame.mean()

    # Column selection with .loc avoids transposing the matrix back and forth.
    kept = df.loc[:, centre(df) > 0]
    kept = kept.loc[:, kept.sum() > molecule_threshold]
    kept = kept / np.percentile(kept, pct_clip, axis=0)
    # The original called np.clip(0, MAXCP, values), which only capped the
    # upper bound by accident of the argument order; clip explicitly instead.
    kept = ((kept / centre(kept)) * 2).clip(lower=0, upper=max_cp)
    return kept.T


def main():
    """Command line entry point.

    Counts fragments per bin for the supplied BAM file, optionally plots a
    per-cell molecule histogram, normalises the count matrix, and optionally
    exports and plots the raw and GC-corrected matrices.
    """
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""Export and plot copy number profiles
    """)
    argparser.add_argument('bamfile', metavar='bamfile', type=str)
    argparser.add_argument('-ref', help='path to reference fasta', type=str, required=True)
    argparser.add_argument('-bin_size', default=500_000, type=int)
    argparser.add_argument('-max_cp', default=5, type=int)
    argparser.add_argument('-threads', default=16, type=int)
    argparser.add_argument('-bins_per_job', default=5, type=int)
    argparser.add_argument('-pct_clip', default=99.999, type=float)
    argparser.add_argument('-min_mapping_qual', default=40, type=int)
    argparser.add_argument('-molecule_threshold', default=5_000, type=int)

    argparser.add_argument('-rawmatplot', type=str, help='Path to raw matrix, plot is not made when this path is not supplied ')
    argparser.add_argument('-gcmatplot', type=str, help='Path to gc corrected matrix, plot is not made when this path is not supplied ')
    argparser.add_argument('-histplot', type=str, help='Path to histogram ')

    argparser.add_argument('-rawmat', type=str)
    argparser.add_argument('-gcmat', type=str)

    argparser.add_argument('-norm_method', default='median', type=str)

    args = argparser.parse_args()

    # Status tags appended to every progress line.
    ok = f"[ {Fore.GREEN}OK{Style.RESET_ALL} ] "
    fail = f"[ {Fore.RED}FAIL{Style.RESET_ALL} ] "

    reference = pysam.FastaFile(args.ref)
    h = GenomicPlot(reference)

    print("Creating count matrix ... ", end="")
    commands = generate_commands(
        alignments_path=args.bamfile,
        bin_size=args.bin_size,
        key_tags=None,
        # Honour the -bins_per_job option (it used to be hard-coded to 5).
        bins_per_job=args.bins_per_job,
        head=None,
        min_mq=args.min_mapping_qual)

    counts = obtain_counts(commands,
                           reference=reference,
                           threads=args.threads,
                           live_update=False,
                           show_n_cells=None,
                           update_interval=None)
    print(f"\rCreating count matrix {ok}")

    # Convert the count dictionary to a dataframe (built once, used by both
    # the histogram and the normalisation below).
    df = pd.DataFrame(counts).T.fillna(0)
    del counts

    if args.histplot is not None:
        print("Creating molecule histogram ... ", end="")
        fig, ax = plt.subplots()
        cell_sums = df.sum()
        cell_sums.name = 'Frequency'
        cell_sums.plot.hist(bins=50, ax=ax)
        ax.set_xlabel('# molecules')
        ax.set_xscale('log')
        ax.axvline(args.molecule_threshold, c='r', label='Threshold')
        ax.legend()
        fig.savefig(args.histplot)
        plt.close(fig)  # release the figure before the heatmaps are drawn
        print(f"\rCreating molecule histogram {ok}")

    print("Filtering count matrix ... ", end="")
    norm_method = args.norm_method
    if norm_method == 'median':
        try:
            normalised = normalise_counts(
                df, 'median',
                molecule_threshold=args.molecule_threshold,
                pct_clip=args.pct_clip,
                max_cp=args.max_cp)
        except Exception as error:
            # Best effort: report the failure and retry with the mean, starting
            # again from the untouched raw matrix.
            print(f"\rMedian normalisation {fail}")
            print(f"{error}; falling back to mean normalisation")
            norm_method = 'mean'
        else:
            df = normalised

    if norm_method == 'mean':
        df = normalise_counts(
            df, 'mean',
            molecule_threshold=args.molecule_threshold,
            pct_clip=args.pct_clip,
            max_cp=args.max_cp)

    if df.shape[0] == 0:
        print(f"\rRaw count matrix {fail}")
        raise ValueError('Resulting count matrix is empty, review the filter settings')
    print(f"\rFiltering count matrix {ok}")
    print(f'{df.shape[0]} cells, and {df.shape[1]} bins remaining')

    if args.rawmat is not None:
        print("Exporting raw count matrix ... ", end="")
        if args.rawmat.endswith('.pickle.gz'):
            df.to_pickle(args.rawmat)
        else:
            df.to_csv(args.rawmat)
        print(f"\rExporting raw count matrix {ok}")

    if args.rawmatplot is not None:
        print("Creating raw heatmap ...", end="")
        h.cn_heatmap(df, figsize=(15, 15))
        plt.savefig(args.rawmatplot)
        print(f"\rCreating raw heatmap {ok}")
        plt.close('all')

    # GC correction is only needed when one of its outputs was requested.
    if args.gcmatplot is not None or args.gcmat is not None:
        print("Performing GC correction ...", end="")
        corrected_cells = gc_correct_cn_frame(
            df, reference, args.max_cp, args.threads, norm_method=norm_method)
        print(f"\rPerforming GC correction {ok}")

    if args.gcmatplot is not None:
        print("Creating heatmap ...", end="")
        h.cn_heatmap(corrected_cells, figsize=(15, 15))
        plt.savefig(args.gcmatplot)
        plt.close('all')
        print(f"\rCreating heatmap {ok}")

    if args.gcmat is not None:
        print("Exporting corrected count matrix ... ", end="")
        if args.gcmat.endswith('.pickle.gz'):
            corrected_cells.to_pickle(args.gcmat)
        else:
            corrected_cells.to_csv(args.gcmat)
        print(f"\rExporting corrected count matrix {ok}")


if __name__ == '__main__':
    main()