Switch to side-by-side view

--- a
+++ b/singlecellmultiomics/statistic/plate.py
@@ -0,0 +1,242 @@
+import matplotlib.patheffects as path_effects
+import math
+import string
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+from .statistic import StatisticHistogram
+import singlecellmultiomics.pyutils as pyutils
+import collections
+
+import matplotlib
+matplotlib.rcParams['figure.dpi'] = 160
+matplotlib.use('Agg')
+
+
+def human_readable(value, targetDigits=2, fp=0):
+
+    # Float:
+    if value < 1 and value > 0:
+        return('%.2f' % value)
+
+    if value == 0.0:
+        return('0')
+
+    baseId = int(math.floor(math.log10(float(value)) / 3.0))
+    suffix = ""
+    if baseId == 0:
+        sVal = str(round(value, targetDigits))
+        if len(sVal) > targetDigits and sVal.find('.'):
+            sVal = sVal.split('.')[0]
+
+    elif baseId > 0:
+
+        sStrD = max(0, targetDigits -
+                    len(str('{:.0f}'.format((value / (math.pow(10, baseId * 3)))))))
+
+        sVal = ('{:.%sf}' % min(fp, sStrD)).format(
+            (value / (math.pow(10, baseId * 3))))
+        suffix = 'kMGTYZ'[baseId - 1]
+    else:
+
+        sStrD = max(0, targetDigits -
+                    len(str('{:.0f}'.format((value * (math.pow(10, -baseId * 3)))))))
+        sVal = ('{:.%sf}' % min(fp, sStrD)).format(
+            (value * (math.pow(10, -baseId * 3))))
+        suffix = 'mnpf'[-baseId - 1]
+
+        if len(sVal) + 1 > targetDigits:
+            # :(
+            sVal = str(round(value, fp))[1:]
+            suffix = ''
+
+    return('%s%s' % (sVal, suffix))
+
+
+# Visualize the following:
+# PER LIBRARY / DEMUX method
+# total fragments
+# total fragments with correct site
+# unique molecules
+
+# 384 well format:
+
+well2index = collections.defaultdict(dict)
+index2well = collections.defaultdict(dict)
+rows = string.ascii_uppercase[:16]
+columns = list(range(1, 25))
+
+for ci in range(1, 385):
+    i = ci - 1
+    rowIndex = math.floor(i / len(columns))
+    row = rows[rowIndex]
+    column = columns[i % len(columns)]
+    well2index[384][(row, column)] = ci
+    index2well[384][ci] = (row, column)
+
+
+rows96 = string.ascii_uppercase[:8]
+
+columns96 = list(range(1, 13))
+
+for ci in range(1, 97):
+    i = ci - 1
+    rowIndex = math.floor(i / len(columns96))
+    row = rows96[rowIndex]
+    column = columns96[i % len(columns96)]
+    well2index[96][(row, column)] = ci
+    index2well[96][ci] = (row, column)
+
+
+class PlateStatistic(object):
+
+    def __init__(self, args):
+        self.args = args
+
+        self.rawFragmentCount = collections.defaultdict(
+            collections.Counter)  # (library, mux) -> cell -> counts
+        self.usableCount = collections.defaultdict(
+            collections.Counter)  # (library, mux) -> cell -> counts
+        self.moleculeCount = collections.defaultdict(
+            collections.Counter)  # (library, mux) -> cell -> counts
+        self.skipReasons = collections.Counter()
+
+    def to_csv(self, path):
+        pd.DataFrame(
+            self.moleculeCount).to_csv(
+            path.replace(
+                '.csv',
+                'molecules.csv'))
+        pd.DataFrame(
+            self.usableCount).to_csv(
+            path.replace(
+                '.csv',
+                'usable_reads.csv'))
+        pd.DataFrame(
+            self.rawFragmentCount).to_csv(
+            path.replace(
+                '.csv',
+                'raw_fragments.csv'))
+
+    def processRead(self, R1,R2):
+
+        for read in [R1,R2]:
+
+            if read is None:
+                continue
+
+            if not read.has_tag('MX'):
+                return
+
+            self.rawFragmentCount[(read.get_tag('LY'),
+                                   read.get_tag('MX'))][read.get_tag('SM')] += 1
+
+            if read.get_tag('MX').startswith('CS2'):
+                if read.has_tag('XT') or read.has_tag('EX'):
+                    if read.is_read1:  # We only count reads2
+                        return
+                    self.usableCount[(read.get_tag('LY'),
+                                      read.get_tag('MX'))][read.get_tag('SM')] += 1
+
+                    if read.has_tag('RC') and read.get_tag('RC') == 1:
+                        self.moleculeCount[(read.get_tag('LY'), read.get_tag(
+                            'MX'))][read.get_tag('SM')] += 1
+            else:
+
+                if read.has_tag('DS'):
+                    if not read.is_read1:
+                        self.skipReasons['Not R1'] += 1
+                        return
+
+                    self.usableCount[(read.get_tag('LY'),
+                                      read.get_tag('MX'))][read.get_tag('SM')] += 1
+                    if not read.is_duplicate:
+                        self.moleculeCount[(read.get_tag('LY'), read.get_tag(
+                            'MX'))][read.get_tag('SM')] += 1
+                else:
+                    self.skipReasons['No DS'] += 1
+            break
+
+
+    def __repr__(self):
+        return 'Plate statistic'
+
+    def cell_counts_to_dataframe(self, cell_counts, mux, name='raw_reads'):
+        df = pd.DataFrame({name: cell_counts})
+
+        offset = 0 # Offset is zero for all protocols since 0.1.12
+
+        format = 384 if ('384' in mux or mux.startswith('CS2')) else 96
+
+        df['col'] = [index2well[format]
+                     [(offset + int(x.rsplit('_')[-1]))][1] for x in df.index]
+        df['row'] = [-rows.index(index2well[format]
+                                 [(offset + int(x.rsplit('_')[-1]))][0]) for x in df.index]
+        df['size'] = (df[name] / np.percentile(df[name], 99) * 200)
+
+        return df
+
+    def __iter__(self):
+        for data, name in [
+            (self.rawFragmentCount, 'raw_reads'),
+            (self.usableCount, 'usable_reads'),
+                (self.moleculeCount, 'unique_molecules')]:
+            for (library, mux), cellCounts in data.items():
+                df = self.cell_counts_to_dataframe(cellCounts, mux, name=name)
+                for i, row in df.iterrows():
+                    yield i, row
+
+    def plot(self, target_path, title=None):
+        for data, name in [
+            (self.rawFragmentCount, 'raw_reads'),
+            (self.usableCount, 'usable_reads'),
+                (self.moleculeCount, 'unique_molecules')]:
+
+            for (library, mux), cellCounts in data.items():
+
+                df = self.cell_counts_to_dataframe(cellCounts, mux, name=name)
+                df.plot.scatter(x='col', y='row', s=df['size'],
+                                c=[(0.2, 0.2, 0.5, 0.9)]
+                                )
+
+                # Annotate the outliers with values:
+                ax = plt.gca()
+                for ii, row in df.iterrows():
+                    if row[name] > 0 and (
+                        row[name] < np.percentile(
+                            df[name],
+                            5) or row[name] > np.percentile(
+                            df[name],
+                            95)):
+                        text = ax.annotate(human_readable(int(row[name])), (row['col'], row['row']),
+                                           ha='center', va='baseline', color='w', size=7)
+                        text.set_path_effects([path_effects.Stroke(
+                            linewidth=3, foreground='black'), path_effects.Normal()])
+
+                plt.yticks(
+                    sorted(
+                        df['row'].unique())[
+                        ::-1],
+                    sorted(rows),
+                    rotation=0)
+                plt.xticks(
+                    sorted(
+                        df['col'].unique()),
+                    sorted(columns),
+                    rotation=0)
+                plt.title(fr'{name} with ${mux}$ adapter' + f'\n{library}')
+
+                # Create legend:
+                #ld = []
+                # for x in np.linspace(1, max(df[name]), 4):
+            #        size = (x/np.percentile(df[name],99))*200
+                #    ld.append( mlines.Line2D([], [], color='blue', marker='.', linestyle='None',
+                # markersize=np.sqrt(size), label=f'{int(x)}:{size}'))
+                # plt.legend(handles=ld)
+
+                plt.savefig(
+                    target_path.replace(
+                        '.png',
+                        '') +
+                    f'{name}_{mux}_{library}.png')
+                plt.close()