a b/singlecellmultiomics/statistic/plate.py
1
import matplotlib.patheffects as path_effects
2
import math
3
import string
4
import numpy as np
5
import pandas as pd
6
import matplotlib.pyplot as plt
7
from .statistic import StatisticHistogram
8
import singlecellmultiomics.pyutils as pyutils
9
import collections
10
11
import matplotlib
12
# Plot at 160 dpi and render through the non-interactive Agg backend, so
# figures can be saved without a display.
matplotlib.rcParams.update({'figure.dpi': 160})
matplotlib.use('Agg')
14
15
16
def human_readable(value, targetDigits=2, fp=0):
    """Format a number as a short human-readable string.

    Values in (0, 1) are shown with two decimals. Values in [1, 1000) are shown
    without a suffix; decimals are dropped when the rounded string is longer
    than ``targetDigits`` characters. Larger values are scaled by powers of
    1000 and suffixed with k, M, G, T, P, E, Z or Y. Negative values are
    formatted like their absolute value with a leading '-'.

    Args:
        value: number to format.
        targetDigits: preferred number of characters of the mantissa.
        fp: maximum number of decimals shown for suffixed values.

    Returns:
        The formatted string, e.g. ``human_readable(25000) == '25k'``.
    """
    if value < 0:
        return '-' + human_readable(-value, targetDigits=targetDigits, fp=fp)

    # Fractions:
    if 0 < value < 1:
        return '%.2f' % value

    if value == 0:
        return '0'

    suffixes = 'kMGTPEZY'
    baseId = int(math.floor(math.log10(float(value)) / 3.0))

    if baseId == 0:
        # 1 <= value < 1000: no suffix, keep only the integer part when the
        # rounded representation is too long.
        sVal = str(round(value, targetDigits))
        if len(sVal) > targetDigits and '.' in sVal:
            sVal = sVal.split('.')[0]
        return sVal

    # Values beyond the largest known prefix keep a large mantissa instead
    # of raising an IndexError.
    baseId = min(baseId, len(suffixes))

    def _format_scaled(exponent):
        # Scale by 1000**exponent and show at most `fp` decimals, fewer when
        # the integer part already uses up `targetDigits` characters.
        scaled = value / math.pow(10, exponent * 3)
        decimals = max(0, targetDigits - len('{:.0f}'.format(scaled)))
        return ('{:.%sf}' % min(fp, decimals)).format(scaled)

    sVal = _format_scaled(baseId)
    # Rounding can push the mantissa to 1000 (e.g. 999999 -> "1000k");
    # move to the next prefix in that case.
    if float(sVal) >= 1000 and baseId < len(suffixes):
        baseId += 1
        sVal = _format_scaled(baseId)

    return '%s%s' % (sVal, suffixes[baseId - 1])
54
55
56
# Visualize the following:
57
# PER LIBRARY / DEMUX method
58
# total fragments
59
# total fragments with correct site
60
# unique molecules
61
62
# Well <-> index lookup tables, keyed by plate format (384 or 96 wells).
# Wells are numbered 1-based in row-major order: A1, A2, ..., A<last>, B1, ...
rows = string.ascii_uppercase[:16]
columns = list(range(1, 25))
rows96 = string.ascii_uppercase[:8]
columns96 = list(range(1, 13))

well2index = collections.defaultdict(dict)  # format -> (row, column) -> index
index2well = collections.defaultdict(dict)  # format -> index -> (row, column)


def _register_plate_layout(plate_size, row_labels, column_labels):
    """Fill well2index/index2well for one plate format."""
    n_columns = len(column_labels)
    for zero_based in range(plate_size):
        row_idx, col_idx = divmod(zero_based, n_columns)
        well = (row_labels[row_idx], column_labels[col_idx])
        well2index[plate_size][well] = zero_based + 1
        index2well[plate_size][zero_based + 1] = well


_register_plate_layout(384, rows, columns)
_register_plate_layout(96, rows96, columns96)
89
90
91
class PlateStatistic(object):
    """Collect per-cell read and molecule counts and render them as plate maps.

    Counts are grouped per (library, mux) pair, taken from the read tags
    ``LY`` and ``MX``, and per cell, taken from the ``SM`` tag.
    """

    def __init__(self, args):
        self.args = args

        # (library, mux) -> cell -> counts
        self.rawFragmentCount = collections.defaultdict(collections.Counter)
        # (library, mux) -> cell -> counts
        self.usableCount = collections.defaultdict(collections.Counter)
        # (library, mux) -> cell -> counts
        self.moleculeCount = collections.defaultdict(collections.Counter)
        self.skipReasons = collections.Counter()

    @staticmethod
    def _with_suffix(path, extension, suffix):
        """Replace a trailing ``extension`` of ``path`` by ``suffix``.

        If ``path`` does not end with ``extension``, the suffix is appended, so
        different suffixes always yield different output paths.
        """
        if path.endswith(extension):
            path = path[:-len(extension)]
        return path + suffix

    def _tables(self):
        """Return the (counts, name) pairs for every tracked statistic."""
        return (
            (self.rawFragmentCount, 'raw_reads'),
            (self.usableCount, 'usable_reads'),
            (self.moleculeCount, 'unique_molecules'),
        )

    def to_csv(self, path):
        """Write the molecule, usable-read and raw-fragment tables as CSV.

        A trailing ``.csv`` of ``path`` is replaced by ``molecules.csv``,
        ``usable_reads.csv`` and ``raw_fragments.csv`` respectively; when
        ``path`` lacks that extension the suffixes are appended instead, so
        the three tables never overwrite each other.
        """
        for counts, suffix in (
                (self.moleculeCount, 'molecules.csv'),
                (self.usableCount, 'usable_reads.csv'),
                (self.rawFragmentCount, 'raw_fragments.csv')):
            pd.DataFrame(counts).to_csv(
                self._with_suffix(path, '.csv', suffix))

    def processRead(self, R1, R2):
        """Update the counters with one fragment (read pair or single read).

        Only the first non-None read is counted (the loop ends in ``break``).
        Fragments whose first read lacks an ``MX`` tag are ignored.
        """
        for read in (R1, R2):
            if read is None:
                continue

            if not read.has_tag('MX'):
                return

            mux = read.get_tag('MX')
            key = (read.get_tag('LY'), mux)
            cell = read.get_tag('SM')
            self.rawFragmentCount[key][cell] += 1

            if mux.startswith('CS2'):
                if read.has_tag('XT') or read.has_tag('EX'):
                    # NOTE(review): when R1 carries XT/EX this returns before
                    # R2 is inspected, so the fragment is never counted as
                    # usable. Kept as-is; confirm this is intended.
                    if read.is_read1:  # We only count reads2
                        return
                    self.usableCount[key][cell] += 1

                    if read.has_tag('RC') and read.get_tag('RC') == 1:
                        self.moleculeCount[key][cell] += 1
            elif read.has_tag('DS'):
                if not read.is_read1:
                    self.skipReasons['Not R1'] += 1
                    return

                self.usableCount[key][cell] += 1
                if not read.is_duplicate:
                    self.moleculeCount[key][cell] += 1
            else:
                self.skipReasons['No DS'] += 1
            break

    def __repr__(self):
        return 'Plate statistic'

    def cell_counts_to_dataframe(self, cell_counts, mux, name='raw_reads'):
        """Convert a cell -> count mapping into a DataFrame with plate coordinates.

        Cell names are parsed for the integer after their last '_' and used as
        1-based well index. The plate is treated as 384-well when ``mux``
        contains '384' or starts with 'CS2', otherwise as 96-well.

        Returns a DataFrame indexed by cell with the columns ``name`` (counts),
        ``col`` (plate column number), ``row`` (negated row index so that row
        A is drawn on top) and ``size`` (marker size, 99th percentile -> 200).
        """
        df = pd.DataFrame({name: cell_counts})

        offset = 0  # Offset is zero for all protocols since 0.1.12

        plate_format = 384 if ('384' in mux or mux.startswith('CS2')) else 96

        wells = [index2well[plate_format][offset + int(cell.rsplit('_', 1)[-1])]
                 for cell in df.index]
        df['col'] = [column for _, column in wells]
        df['row'] = [-rows.index(row) for row, _ in wells]
        df['size'] = (df[name] / np.percentile(df[name], 99) * 200)

        return df

    def __iter__(self):
        """Yield (cell, row) for every cell of every statistic and (library, mux)."""
        for data, name in self._tables():
            for (library, mux), cellCounts in data.items():
                df = self.cell_counts_to_dataframe(cellCounts, mux, name=name)
                yield from df.iterrows()

    def plot(self, target_path, title=None):
        """Save one plate scatter plot per statistic and (library, mux).

        Files are written to ``target_path`` (trailing ``.png`` removed) plus
        ``<name>_<mux>_<library>.png``. Non-zero wells below the 5th or above
        the 95th percentile are annotated with their value. ``title`` is
        currently unused.
        """
        for data, name in self._tables():
            for (library, mux), cellCounts in data.items():

                df = self.cell_counts_to_dataframe(cellCounts, mux, name=name)
                ax = df.plot.scatter(x='col', y='row', s=df['size'],
                                     c=[(0.2, 0.2, 0.5, 0.9)])

                # Annotate the outliers with values; the thresholds are
                # computed once per plot rather than once per well.
                low, high = np.percentile(df[name], [5, 95])
                for _, row in df.iterrows():
                    value = row[name]
                    if value > 0 and (value < low or value > high):
                        text = ax.annotate(human_readable(int(value)),
                                           (row['col'], row['row']),
                                           ha='center', va='baseline',
                                           color='w', size=7)
                        text.set_path_effects([
                            path_effects.Stroke(linewidth=3,
                                                foreground='black'),
                            path_effects.Normal()])

                # Derive tick labels from the plotted positions so labels
                # always match ticks, even when rows/columns are missing or
                # the plate is 96-well.
                ytick_positions = sorted(df['row'].unique())[::-1]
                plt.yticks(ytick_positions,
                           [rows[-int(pos)] for pos in ytick_positions],
                           rotation=0)
                xtick_positions = sorted(df['col'].unique())
                plt.xticks(xtick_positions,
                           [str(pos) for pos in xtick_positions],
                           rotation=0)
                plt.title(fr'{name} with ${mux}$ adapter' + f'\n{library}')

                plt.savefig(self._with_suffix(
                    target_path, '.png', f'{name}_{mux}_{library}.png'))
                plt.close()