Switch to unified view

a b/singlecellmultiomics/utils/bdbplot.py
1
from lxml import etree
2
import math
3
import collections
4
from collections import Counter
5
from collections import OrderedDict
6
import numpy as np
7
from singlecellmultiomics.utils import bdbbio
8
import os
9
import matplotlib.cm
10
from Bio import SeqIO
11
from Bio.Seq import Seq
12
import matplotlib.pyplot as plt
13
from colorama import Fore #,Back, Style
14
from colorama import Back
15
from colorama import Style
16
from colorama import init
17
import scipy
18
import scipy.cluster
19
import time
20
import itertools
21
22
init(autoreset=True)
23
24
25
26
#Convert a nested dictionary to a matrix
27
# ({'A':{'1':2}, 'B':{'1':3, '2':4}}) will become
28
#(array([[  2.,  nan],
29
#        [  3.,   4.]]), ['A', 'B'], ['1', '2'])
30
31
32
def interpolateBezier( points, steps=10, t=None):
33
    if len(points)==3:
34
        mapper = lambda t,p: (1-t)**2 * p[0] + 2*(1-t)*t*p[1] + t**2*p[2]
35
    elif len(points)==4:
36
        mapper = lambda t,p: (np.power( (1-t),3)*p[0] +\
37
         3* np.power((1-t),2) *t *p[1] +\
38
         3*(1-t)*np.power(t,2)*p[2] +\
39
         np.power(t,3)*p[3])
40
41
    if t is not None:
42
        return   mapper(t, [q[0] for q in points]), mapper(t, [q[1] for q in points])
43
    xGen = ( mapper(t, [q[0] for q in points]) for t in np.linspace(0, 1, steps) )
44
    yGen = ( mapper(t, [q[1] for q in points]) for t in np.linspace(0, 1, steps) )
45
46
    return zip(xGen, yGen)
47
48
def interpolateBezierAngle(points, t, ds=0.001):
49
    x0, y0 = interpolateBezier(points, t=t-ds)
50
    x1, y1 = interpolateBezier(points, t=t+ds)
51
    return np.arctan2( y1-y0, x0-x1)
52
53
54
def initMatrix(rowNames,columnNames, mtype="obj"):
55
    if mtype=="obj":
56
        matrix = np.empty( (len(rowNames), len(columnNames)), dtype=object)
57
    elif mtype=="npzeros":
58
        matrix = np.zeros( (len(rowNames), len(columnNames)))
59
    return(matrix)
60
61
def nestedDictionaryToNumpyMatrix( nestedDictionary, setNan=True, mtype="obj", transpose=False, indicateProgress=False):
62
63
    rowNames = sorted( list(nestedDictionary.keys()), key=int )
64
    columnNames = set()
65
    for key in nestedDictionary:
66
        columnNames.update( set(nestedDictionary[key].keys() ))
67
    columnNames = sorted( list(columnNames) )
68
69
    keys = list(columnNames)
70
71
    if ':' in keys[0]:
72
        sargs = np.argsort( [ int(k.split(':')[1]) for k in keys] )
73
        print(sargs)
74
75
        columnNames = [ keys[index] for index in sargs ]
76
        print(columnNames)
77
78
    matrix = initMatrix(rowNames,columnNames, mtype)
79
80
    if setNan:
81
        matrix[:] = np.nan
82
83
    prevTime = time.time()
84
    for rowIndex,rowName in enumerate(rowNames):
85
        if indicateProgress and (time.time()-prevTime)>1:
86
            prevTime = time.time()
87
            print("Matrix creation progress: %.2f%%" % (100.0*rowIndex/len(rowNames)))
88
89
        for colIndex,colName in enumerate(columnNames):
90
            try:
91
                matrix[rowIndex,colIndex] = nestedDictionary[rowName][colName]
92
            except:
93
                pass
94
95
    if transpose:
96
        matrix = matrix.transpose()
97
        columnNames, rowNames =  rowNames, columnNames
98
99
    if indicateProgress:
100
        print("Matrix finished")
101
    return( (matrix, rowNames, columnNames) )
102
103
def pruneNonUniqueColumnsFromMatrix(matrix,rows,columns, minInstances=1,minOccurence=1):
104
    colsToKeep = []
105
    for columnIndex in range(matrix.shape[1]):
106
        if len( set(np.unique( matrix[:,columnIndex].astype(str) ))-set( ["nan"]) )>minInstances:
107
108
            cnts = Counter( list(matrix[:,columnIndex].astype(str)) )
109
110
            counts = Counter({k: cnts[k] for k in  cnts if cnts[k] >= minOccurence})
111
            del counts['nan']
112
113
            if len(counts.values())>minInstances:
114
                colsToKeep.append(columnIndex)
115
116
    matrix = matrix[:,colsToKeep]
117
    columns = np.array(columns)[colsToKeep]
118
    return(matrix, rows, columns)
119
120
121
122
# Convert dictionary of tuples to a numpy matrix
123
def tupleAnnotationsToNumpyMatrix( originRowNames, originColNames, tuples, setNan=True, mtype="obj" ):
124
125
    m = initMatrix(originRowNames,originColNames, mtype) #np.zeros( (len(originRowNames), len(originColNames)))
126
    if setNan:
127
        m[:] = np.nan
128
    for tup in tuples:
129
        value = tuples[tup]
130
        if tup[0] in originRowNames and tup[1] in originColNames:
131
         m[originRowNames.index(tup[0]), originColNames.index(tup[1])  ] = value
132
    return(m)
133
134
def dictAnnotationsToNumpyMatrix( originRowNames, originColNames, dictionary, mtype="obj" ):
135
    tuples = {}
136
    for rowKey in dictionary:
137
        for columnKey in dictionary[rowKey]:
138
            tuples[ (rowKey, columnKey) ] = dictionary[rowKey][columnKey]
139
    return(tupleAnnotationsToNumpyMatrix(originRowNames, originColNames, tuples, mtype=mtype))
140
141
142
def getSomeColors(n):
143
    return( plt.cm.Set1(np.linspace(0, 1, n)) )
144
145
146
147
def _ipol(a, b, first, last,  interpolateValue):
148
    #Due to floating point rounding errors the interpolate value can be very close to last,
149
    # it is ok to return last in those cases
150
    if last>first and interpolateValue>=last:
151
        return(b)
152
    if last<first and interpolateValue>=first:
153
        return(a)
154
155
    y_interp = scipy.interpolate.interp1d([first, last], [a,b])
156
    return( y_interp(interpolateValue) )
157
158
def interpolate(interpolateValue,  colorScaleKeys, nodeColorMapping):
159
        #Seek positions around value to interpolate
160
        first = colorScaleKeys[0]
161
        index = 0
162
        last = first
163
        for value in colorScaleKeys:
164
            if value>=interpolateValue:
165
                last = value
166
                break
167
            else:
168
                first = value
169
            index+=1
170
        if value==interpolateValue:
171
            return(nodeColorMapping[value])
172
173
        #Do interpolation
174
        colorA = nodeColorMapping[first]
175
        colorB = nodeColorMapping[last]
176
        dx = last-first
177
178
        # Check out of bounds condition
179
        if interpolateValue< first:
180
            return(colorA)
181
        if interpolateValue>last:
182
            return(colorB)
183
184
185
186
        return( _ipol(colorA[0], colorB[0], first, last, interpolateValue), _ipol(colorA[1], colorB[1], first, last, interpolateValue), _ipol(colorA[2], colorB[2], first, last, interpolateValue))
187
188
189
190
def plotFeatureSpace(features, classLabels, featureNames, path=None, bins = 50, title=None):
191
192
    print(Fore.GREEN + "Feature space plotter:")
193
    #1d mode:
194
    print(features.shape)
195
    classAbundance = Counter(classLabels)
196
    print(classAbundance)
197
198
    classes = set(classLabels)
199
    classColors = ['#FF6A00','#0066FF','#FF33FF','#666666']
200
    classColors = [ tuple(float(int(hexColor.replace('#','')[i:i+2], 16))/255.0 for i in (0, 2 ,4)) for hexColor in classColors]
201
202
    if features.shape[1]==1:
203
        print("Performing 1-D density plot of %s samples" % ( features.shape[0]))
204
205
        plt.close('all')
206
207
        histStart = features.min()
208
        histEnd = features.max()
209
        if histStart == histEnd:
210
            histEnd += 1
211
            histStart -= 1
212
        precision = (histEnd-histStart)/bins
213
214
215
        print("Histogram will be plotted from %s to %s " % (histStart, histEnd))
216
        fig, ax = plt.subplots() #figsize=(120, 10))
217
218
        print(classColors)
219
        for classIndex,className in enumerate(list(classes)):
220
            boolList = np.array(classLabels)==np.array(className)
221
            classSize = classAbundance[className]
222
            ax.hist(
223
                features[boolList,0],
224
                np.arange(histStart,histEnd+precision,precision),
225
                normed=False, fc= (classColors[classIndex]+ (0.5,)),
226
                ec=classColors[classIndex],
227
                lw=1.5, histtype='stepfilled',
228
                label='%s[%s]' % (className,classSize)
229
                )
230
231
        plt.ylabel("Density")
232
        plt.xlabel(featureNames[0])
233
        if title is not None:
234
            plt.title(title)
235
        ax.legend(loc='upper right')
236
        #plt.yscale('log', nonposy='clip')
237
        if path is None:
238
            plt.show()
239
        else:
240
            plt.savefig(path)
241
        return(True)
242
243
244
    if features.shape[1]==2:
245
        #2d
246
        print("Performing 2-D density plot of %s samples" % ( features.shape[0]))
247
        fig, ax = plt.subplots()
248
249
        for classIndex,className in enumerate(list(classes)):
250
            #print(np.where( classLabels==className, features ))
251
            boolList = np.array(classLabels)==np.array(className)
252
            plt.plot( features[boolList,0], features[boolList,1], ".",label='%s[%s]' % (className,sum(boolList)), c= (classColors[classIndex] + (0.5,)))
253
254
        plt.xlabel(featureNames[0])
255
        plt.ylabel(featureNames[1])
256
        plt.legend(loc="lower right")
257
        plt.tight_layout()
258
        try:
259
            plt.savefig(path)
260
261
        except Exception as e:
262
            print(e)
263
        return(True)
264
    print(Fore.RED + "Invalid amount of dimensions for feature space plotting (%s)" % features.shape[1])
265
266
def matplotHeatmap( D, YC, figsize=(10,10), clust=True, xLab=None, yLab=None, show=True, colormap=plt.cm.YlGnBu_r, colorbarLabel=None ):
267
    plt.rcParams["axes.grid"] = False
268
    import scipy
269
    import scipy.cluster.hierarchy as sch
270
    # Compute and plot first dendrogram.
271
    fig = plt.figure(figsize=figsize)
272
    if not clust:
273
        idx1 = range(0, len(YC))
274
        idx1 = range(0, len(YC))
275
276
    if clust:
277
        ax1 = fig.add_axes([0.09,0.1,0.2,0.6])
278
279
        L = sch.linkage(D, method='centroid')
280
        Z1 = sch.dendrogram(L, orientation='right')
281
        ax1.set_xticks([])
282
        ax1.set_yticks([])
283
284
        # Compute and plot second dendrogram.
285
        ax2 = fig.add_axes([0.3,0.71,0.6,0.2])
286
        Z2 = sch.dendrogram(L)
287
        ax2.set_xticks([])
288
        ax2.set_yticks([])
289
        idx1 = Z1['leaves']
290
        idx2 = Z2['leaves']
291
        D = D[idx1,:]
292
        D = D[:,idx1]
293
294
    # Plot distance matrix.
295
    axmatrix = fig.add_axes([0.3,0.1,0.6,0.6])
296
297
    im = axmatrix.matshow(D, aspect='auto', origin='lower', cmap=colormap)
298
    axmatrix.set_xticks([])
299
    axmatrix.set_yticks([])
300
    ###
301
302
    if xLab is None:
303
        axmatrix.set_xticks(range(len(YC)))
304
        axmatrix.set_xticklabels(YC[idx1])
305
    else:
306
        axmatrix.set_xticks(range(len(xLab)))
307
        axmatrix.set_xticklabels(xLab)
308
    axmatrix.xaxis.set_label_position('bottom')
309
    axmatrix.xaxis.tick_bottom()
310
311
    plt.xticks(rotation=-90)
312
313
    if yLab is None:
314
        axmatrix.set_yticks(range(len(YC)))
315
        axmatrix.set_yticklabels(YC[idx1], minor=False)
316
    else:
317
        axmatrix.set_yticks(range(len(yLab)))
318
        axmatrix.set_yticklabels(yLab, minor=False)
319
    axmatrix.yaxis.set_label_position('right')
320
    axmatrix.yaxis.tick_right()
321
322
323
    # Plot colorbar.
324
    axcolor = fig.add_axes([1.05,0.1,0.02,0.6])
325
    cbar = fig.colorbar(im, cax=axcolor)
326
    if colorbarLabel is not None:
327
        cbar.set_label(colorbarLabel, rotation=270,  labelpad=15)
328
    if show:
329
        fig.savefig('dendrogram.png')
330
331
def tsnePlot(data, labels=None, components=2, perplexity=30.0, iterations=1000):
332
    from sklearn.manifold import TSNE
333
    #from MulticoreTSNE import MulticoreTSNE as TSNE
334
    model = TSNE(n_components=components, perplexity=perplexity, n_iter=iterations ) #random_state=0, n_jobs=8,
335
    transformedPoints = model.fit_transform(data.astype(np.float64))
336
337
338
    classes = list(set(labels))
339
    classColors = getSomeColors(len(classes))
340
    color = np.array([ classColors[classes.index(label)] for label in labels ])
341
    nplabels = np.array(labels)
342
    print("TSNE input:")
343
    print(data.shape)
344
    print("TSNE plotting for %s classes " %  len(classes))
345
346
    #print("Color mapping is %s" % ",".join(color))
347
    #Plot the data:
348
    if components==2:
349
        fig = plt.figure()
350
        #plt.style.use('ggplot')
351
        ax = fig.add_subplot(111)
352
353
        print(color)
354
        #plt.scatter(transformedPoints[:, 0], transformedPoints[:, 1], c=color, cmap=plt.cm.Spectral, s=1, alpha=0.5) #labels=labels,
355
        for classIndex, className in enumerate(classes):
356
            classColor = classColors[classIndex]
357
            print(className)
358
            print(classColor)
359
            plt.scatter(transformedPoints[className==nplabels, 0], transformedPoints[className==nplabels, 1],   s=3, alpha=0.9, label=className)  #c=classColor,, cmap=plt.cm.Spectral,
360
361
        #plt.axis('tight')
362
        box = ax.get_position()
363
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
364
365
        # Put a legend to the right of the current axis
366
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
367
368
        plt.show()
369
370
371
    elif components==3:
372
        from mpl_toolkits.mplot3d import Axes3D
373
        fig = plt.figure()
374
        ax = fig.add_subplot(111, projection='3d')
375
376
        for classIndex,className in enumerate(classes):
377
            samplesForClass = (className==nplabels)
378
            ax.scatter(transformedPoints[samplesForClass,0], transformedPoints[samplesForClass,1], transformedPoints[samplesForClass,2], c=classColors[classIndex], label=classIndex)
379
        ax.legend()
380
        plt.show()
381
382
383
class BDBcolor():
384
385
    def __init__(self, r=0, g=0, b=0, a=1.0 ):
386
387
        if str(r)[0]=='#':
388
            #parse hex colour:
389
            #parts = r.replace('#','').replace('(','').replace(')','').split(',')
390
            cleaned = r.replace('#','')
391
392
            #parts = [int(i) for i in parts]
393
            r =  int(cleaned[0:2], 16)
394
            g =  int(cleaned[2:4], 16)
395
            b =  int(cleaned[4:6], 16)
396
397
        self.r = max(0,min(255,r))
398
        self.g = max(0,min(255,g))
399
        self.b = max(0,min(255,b))
400
        self.a = max(0,min(1.0,a))
401
402
    def getRGBStr(self):
403
        return('rgb(%s,%s,%s)' % (int(self.r), int(self.g), int(self.b)))
404
405
    def getRGBAStr(self):
406
        return('rgba(%s,%s,%s,%s)' % (self.r, self.g, self.b, self.a))
407
408
409
    def getReadableInverted(self):
410
411
        hsv = self.getHSV()
412
        hsv['v'] = (  255.0-( hsv['v'] ) )
413
        rgb = self.HSVtoRGB(0,0,hsv['v'])
414
        return( BDBcolor( rgb['r'],rgb['g'],rgb['b'] ))
415
416
    def getHSV(self):
417
        h=0
418
        s=0
419
        v=0
420
        minV = min( self.r, self.g, self.b )
421
        maxV = max( self.r, self.g, self.b )
422
        v = maxV
423
        delta = maxV - minV
424
        if maxV != 0:
425
            s = delta / float(maxV)
426
        else:
427
            s = 0
428
            h = -1
429
            return({'h':h, 's':s, 'v':v})
430
431
        if delta==0:
432
                h = 255
433
        else:
434
                if self.r == maxV:
435
                    h = ( self.g - self.b ) / float(delta)
436
437
                else:
438
                        if self.g == maxV:
439
440
                                h = 2.0 + float( self.b - self.r ) / float(delta)
441
                        else:
442
                                h = 4.0 + float( self.r - self.g ) / float(delta)
443
444
445
        h *= 60
446
        if h < 0:
447
            h += 360
448
449
        return({'h':h, 's':s, 'v':v})
450
451
452
    def HSVtoRGB( self, h,s,v ):
453
454
455
        i=0
456
        f=0
457
        p=0
458
        q=0
459
        t = 0
460
461
        if s == 0:
462
            #Grey
463
            r = g = b = v
464
            return({'r':round(r), 'g':round(g), 'b':round(b)})
465
466
        h /= 60         # sector 0 to 5
467
        i = math.floor( h )
468
        f = h - i           # factorial part of h
469
        p = v * ( 1 - s )
470
        q = v * ( 1 - s * f )
471
        t = v * ( 1 - s * ( 1 - f ) )
472
473
        if i==0:
474
            r = v
475
            g = t
476
            b = p
477
        elif i==1:
478
            r = q
479
            g = v
480
            b = p
481
        elif i==2:
482
            r = p
483
            g = v
484
            b = t
485
        elif i==3:
486
            r = p
487
            g = q
488
            b = v
489
        elif i==4:
490
            r = t
491
            g = p
492
            b = v
493
        else:
494
            r = v
495
            g = p
496
            b = q
497
498
        return({'r':round(r), 'g':round(g), 'b':round(b)})
499
500
501
class BDBPlot():
502
503
504
    def __init__(self):
505
506
        # We need to declare the xlink namespace, to create references to things in our own file
507
        self.xlink =  'http://www.w3.org/1999/xlink'
508
509
        NSMAP = {'xlink':self.xlink }
510
511
        self.svgTree = etree.Element('svg',nsmap = NSMAP)
512
        self.svgTree.set('xmlns','http://www.w3.org/2000/svg')
513
        self.svgTree.set('version','1.2')
514
515
516
517
        self.root = self.svgTree.getroottree()
518
519
        self.nextFilterId = 0
520
        self.nextObjId = 0
521
        self.nextTspanId = 0
522
        #Create definition element
523
        self.svgTree.append( self.getDefinitionBlock() )
524
        self.debug = 2 # 2 all
525
526
        self.xMin = 0
527
        self.xMax = 10
528
        self.yMin = 0
529
        self.yMax = 10
530
        self.plotStartX = 30
531
        self.plotStartY = 30
532
        self.plotHeight = 400
533
        self.plotWidth = 600
534
535
        self.setWidth(800)
536
        self.setHeight(1000)
537
        self.script = ""
538
539
    def clear(self):
540
541
        toRm = []
542
        for child in self.svgTree:
543
            toRm.append(child)
544
        for child in toRm:
545
            self.svgTree.remove(child)
546
547
        self.svgTree.append( self.getDefinitionBlock() )
548
        self.nextFilterId = 0
549
        self.nextObjId = 0
550
551
    def getGroup(self, identifier, zIndex=0 ):
552
        g = etree.Element('g')
553
        g.set('id', str(identifier))
554
        self.svgTree.append(g)
555
        return(g)
556
557
558
    def getTspan(self):
559
        tspan = etree.Element('tspan')
560
        tspan.set('id', str(self.nextTspanId))
561
        self.nextTspanId+=1
562
        return(tspan)
563
564
    def addLegend(self,colorMapping):
565
566
        y = 0
567
        c = self.getYLabelCoord(y)
568
        yp = c[1]+10 + 80
569
        for color in colorMapping:
570
571
572
            text = self.getText(str(colorMapping[color]), c[0]+self.plotStartX, yp,BDBcolor(80,80,80,1))
573
            text.set('text-anchor','begin')
574
            text.set('dominant-baseline','middle')
575
            text.set('font-family','Cambria Math')
576
            text.set('fill','%s' % color)
577
            self.svgTree.append( text )
578
            yp+=20
579
580
581
582
    def getGroupColors(self, n):
583
584
        if n==1:
585
            return(['#3770C4'])
586
        if n==2:
587
            return(['#3770C4','#66A43E'])
588
        if n==3:
589
            return(['#3770C4','#66A43E','#F6853A'])
590
        if n==4:
591
            return(['#3770C4','#A43E3E','#66A43E','#F6853A'])
592
        if n==5:
593
            return(['#3770C4','#A43E3E','#66A43E','#F6853A','#A33DA2'])
594
        if n==6:
595
            return(['#3770C4','#A43E3E','#66A43E','#F6853A','#A33DA2','#AAD400'])
596
        if n==7:
597
            return(['#3770C4','#A43E3E','#66A43E','#F6853A','#A33DA2','#AAD400','#9DAC93'])
598
        if n==8:
599
            return(['#3770C4','#A43E3E','#66A43E','#F6853A','#A33DA2','#AAD400','#9DAC93','#7FCADF'])
600
        if n==9:
601
            return(['#3770C4','#A43E3E','#66A43E','#F6853A','#A33DA2','#AAD400','#9DAC93','#7FCADF','#D1AC17'])
602
603
604
605
        return(['#3770C4','#A43E3E','#66A43E','#F6853A','#A33DA2','#AAD400','#9DAC93','#7FCADF','#D1AC17','#000080','#FF0066','#6C5D53'] + ['#333333']*n)
606
607
608
    #Macro to set a title quickly
609
    def setTitle(self, string, x=None, y=10, size=25, fill='#333333' ):
610
611
        centerX = x is None
612
        if x is None:
613
            x,_= self.getPlottingCoord(self.xMin + 0.5*(self.xMax - self.xMin), 0)
614
615
616
        text = self.getText(str(string), x,y, fill=fill)
617
        if centerX:
618
            text.set('text-anchor','middle')
619
        else:
620
            text.set('text-anchor','begin')
621
        text.set('dominant-baseline','central')
622
        text.set('font-family','Gill Sans MT')
623
        text.set('font-size', str(size))
624
        self.svgTree.append(text)
625
        return(self)
626
627
    def setSubtitle(self, string, x=None, y=40, size=15, fill='#222222'):
628
        self.setTitle(string,x,y,size,fill)
629
630
    def setWidth(self, width):
631
        self.width = width
632
        self.svgTree.set('width','%s' % width)
633
634
    def setHeight(self, height):
635
        self.height = height
636
        self.svgTree.set('height','%s' % height)
637
638
639
    def getDx(self):
640
        return( float(self.plotWidth )/float((self.xMax - self.xMin)))
641
642
643
    def getDy(self):
644
        return( float(self.plotHeight )/float((self.yMax - self.yMin)))
645
646
647
    def getPlottingCoord(self, x,y,z=0):
648
649
        return( (self.plotStartX + (float(x)/(self.xMax - self.xMin))*self.plotWidth, self.plotHeight+self.plotStartY -(( float(y)/(self.yMax - self.yMin)))*self.plotHeight))
650
651
652
    def getXLabelCoord(self, x):
653
        return( (self.plotStartX + (float(x)/(self.xMax - self.xMin))*self.plotWidth, self.plotHeight+self.plotStartY+2 ))
654
655
656
    def getYLabelCoord(self, y):
657
        return( (self.plotStartX, self.plotHeight+self.plotStartY -(( float(y)/(self.yMax - self.yMin)))*self.plotHeight ))
658
659
    def getNextObjId(self):
660
        self.nextObjId+=1
661
        return(str(self.nextObjId))
662
663
664
    def getDefinitionBlock(self):
665
666
        self.defs = etree.Element('defs')
667
        self.defs.set('id','defs0')
668
        return(self.defs)
669
670
671
    def filter(self):
672
        filterDef = etree.Element('filter')
673
        filterDef.set('id','filter_%s' % (self.nextFilterId))
674
        self.nextFilterId+=1
675
        return(filterDef)
676
677
    def getAxis(self,hv=0):
678
679
        if hv==1:
680
            p = self.getPath(self.getPathDefinition([self.getPlottingCoord(self.xMin, self.yMin), self.getPlottingCoord(self.xMax, self.yMin)]))
681
        elif hv==2:
682
            p = self.getPath(self.getPathDefinition([self.getPlottingCoord(self.xMin, self.yMax),self.getPlottingCoord(self.xMin, self.yMin)]))
683
        else:
684
            p = self.getPath(self.getPathDefinition([self.getPlottingCoord(self.xMin, self.yMax),self.getPlottingCoord(self.xMin, self.yMin), self.getPlottingCoord(self.xMax, self.yMin)]))
685
        return(p)
686
687
    def getPathDefinition(self, coordinates, preventAliasing=False ):
688
689
        definition = []
690
        for idx,coordinateTuple in enumerate(coordinates):
691
692
            if preventAliasing:
693
                coordinateTuple = ( round(coordinateTuple[0])+0.5, round(coordinateTuple[1])+0.5 )
694
695
            if idx==0:
696
                definition.append('M%s,%s' % (coordinateTuple[0],coordinateTuple[1]))
697
            else:
698
                definition.append('L%s,%s' % (coordinateTuple[0],coordinateTuple[1]))
699
        return(' '.join(definition))
700
701
702
    def getLinearGradientDefinition(self, tuplesWithStops): #Format: %x, color
703
704
        definitionElement =  etree.Element('linearGradient')
705
        definitionElement.set('id', self.getNextObjId()) #not really needed
706
        definitionElement.set('x1', tuplesWithStops[0][0])
707
        definitionElement.set('y1', tuplesWithStops[0][0])
708
        definitionElement.set('x2', tuplesWithStops[-1][0])
709
        definitionElement.set('y2', tuplesWithStops[-1][0])
710
711
        for i, tup in enumerate(tuplesWithStops):
712
713
            stop = etree.SubElement(definitionElement, 'stop')
714
            stop.set('offset',tup[0])
715
            stop.set('stop-color',tup[1])
716
            if i==0:
717
                stop.set('class','start')
718
        stop.set('class','stop')
719
        return(definitionElement)
720
721
    def shadow(self, dy=2, dx=2, gaussStd=2,color='rgb(0,0,0)', floodOpacity=0.9,
722
     width=None, # Width: set a pixel region around the filter to prevent clipping (https://stackoverflow.com/questions/17883655/svg-shadow-cut-off)
723
     height=None):
724
        f = self.filter()
725
726
        if width is not None:
727
            f.set('width', str(width))
728
            f.set('x', '-%s' % (width*0.5))
729
        if height is not None:
730
            f.set('height', str(height))
731
            f.set('y', '-%s' % (height*0.5))
732
733
        f.set('color-interpolation-filters','sRGB')
734
        self.nextFilterId+=1
735
        #Flood
736
        flood = etree.SubElement(f, 'feFlood')
737
        flood.set('result','flood')
738
        flood.set('flood-color',color)
739
        flood.set('flood-opacity','%s' % floodOpacity)
740
741
        #Composite filter
742
        composite1 = etree.SubElement(f, 'feComposite')
743
        composite1.set('in2','SourceGraphic')
744
        composite1.set('operator','in')
745
        composite1.set('in','flood')
746
        composite1.set('result','composite1')
747
748
        #Gaussian blur
749
        gauss = etree.SubElement(f, 'feGaussianBlur')
750
        gauss.set('result','blur')
751
        gauss.set('stdDeviation','%s' % gaussStd)
752
753
        #Shadow offset
754
        offset = etree.SubElement(f, 'feOffset')
755
        offset.set('result','offset')
756
        offset.set('dy','%s' % dy)
757
        offset.set('dx','%s' % dx)
758
759
        #Final composite filter
760
        composite2 = etree.SubElement(f, 'feComposite')
761
        composite2.set('in2','offset')
762
        composite2.set('operator','over')
763
        composite2.set('in','SourceGraphic')
764
        composite2.set('result','composite2')
765
        return(f)
766
767
768
    def makeInnerShadow(self, shadow ):
769
770
        i = etree.SubElement(shadow, 'feComposite')
771
        i.set('operator','in')
772
        i.set('in2','SourceGraphic')
773
774
775
776
    def addDef(self, filterDef, defId=None):
777
        if defId is not None:
778
            filterDef.set('id',defId)
779
        self.defs.append(filterDef)
780
        return( filterDef.get('id') )
781
782
    def hasDef(self, defId):
783
        return( len(self.defs.findall(".//*[@id='%s']" % defId))>0 )
784
785
    def getDef(self,defId):
786
        if not self.hasDef(defId):
787
            print(('Definition %s was not found' % defId))
788
            exit()
789
        return( self.defs.findall(".//*[@id='%s']" % defId)[0] )
790
791
    def warn(self, msg):
792
        print(('[WARN] %s' % msg))
793
794
    def getRectangle(self, x,y, width, height):
795
        rectangle = etree.Element('rect')
796
        rectangle.set('id', self.getNextObjId())
797
        rectangle.set('x',str(x))
798
        rectangle.set('y',str(y))
799
        rectangle.set('width',str(width))
800
        rectangle.set('height',str(height))
801
        rectangle.set('style',"fill:rgba(100,100,100,1);stroke:#1b1b1b")
802
        return(rectangle)
803
804
    def getImage( self, path, x=None, y=None, width=None, height=None, preserveAspectRatio=None):
805
        image = etree.Element('image')
806
        image.set('id', self.getNextObjId())
807
        if x is not None:
808
            image.set('x',str(x))
809
        if y is not None:
810
            image.set('y',str(y))
811
        if width is not None:
812
            image.set('width',str(width))
813
814
        if preserveAspectRatio is not None:
815
            image.set('preserveAspectRatio',str(preserveAspectRatio))
816
817
        if height is not None:
818
            image.set('height',str(height))
819
820
        image.set('{%s}href'% self.xlink ,str(path))
821
822
        return(image)
823
824
825
826
    #Modify attribute in style
827
    def modifyStyleString(self, style, setAttr={},remove=[]):
828
829
        attributes = {}
830
831
        if style is not None and style.strip()!='':
832
            parts = style.split(';')
833
834
            for part in parts:
835
                kvPair = part.split(':')
836
                if len(kvPair)==2:
837
                    key = kvPair[0]
838
                    value = kvPair[1]
839
840
                    if key not in remove:
841
                        attributes[kvPair[0]] = kvPair[1]
842
                else:
843
                    self.warn('Style parsing %s failed (ignoring)' % part)
844
845
846
847
            if self.debug>=3:
848
                print('Style decomposition')
849
                for key in attributes:
850
                    print(('%s\t:\t%s' % (key, attributes[key]) ))
851
852
        #Roll changes:
853
        for attribute in setAttr:
854
            attributes[attribute] = setAttr[attribute]
855
856
        #Create new style string
857
        newStyle = []
858
        for attr in attributes:
859
            newStyle.append('%s:%s' % (attr, attributes[attr]))
860
861
        return(';'.join(newStyle))
862
863
    def modifyStyle(self, element,setAttr={},remove=[]):
864
865
        if not 'style' in element.attrib:
866
            element.set('style','')
867
868
        element.set('style', self.modifyStyleString(element.get('style'), setAttr, remove ))
869
870
871
    def setTextRotation(self, element, angle ):
872
        element.set('transform','rotate(%s, %s, %s)'%(angle,element.get('x'), element.get('y')))
873
874
875
876
    def humanReadable(self, value, targetDigits=2,fp=0):
877
878
        #Float:
879
        if value<1 and value>0:
880
            return('%.2f' % value )
881
882
        if value == 0.0:
883
            return('0')
884
885
        baseId = int(math.floor( math.log10(float(value))/3.0 ))
886
        suffix = ""
887
        if baseId==0:
888
            sVal =  str(round(value,targetDigits))
889
            if len(sVal)>targetDigits and sVal.find('.'):
890
                sVal = sVal.split('.')[0]
891
892
        elif baseId>0:
893
894
            sStrD = max(0,targetDigits-len(str( '{:.0f}'.format((value/(math.pow(10,baseId*3)))) )))
895
896
897
            sVal = ('{:.%sf}' % min(fp, sStrD)).format((value/(math.pow(10,baseId*3))))
898
            suffix = 'kMGTYZ'[baseId-1]
899
        else:
900
901
            sStrD = max(0,targetDigits-len(str( '{:.0f}'.format((value*(math.pow(10,-baseId*3)))) )))
902
            sVal = ('{:.%sf}' %  min(fp, sStrD)).format((value*(math.pow(10,-baseId*3))))
903
            suffix = 'mnpf'[-baseId-1]
904
905
            if len(sVal)+1>targetDigits:
906
                # :(
907
                sVal = str(round(value,fp))[1:]
908
                suffix = ''
909
910
911
        return('%s%s' % (sVal,suffix))
912
913
914
915
    def getText(self, text, x=0, y=0, fill=BDBcolor(), pathId=None):
916
917
918
        textElement =  etree.Element('text')
919
        textElement.set('id', self.getNextObjId())
920
921
        if pathId != None:
922
            tp =  etree.Element('textPath')
923
            tp.text = text
924
            tp.set('{%s}href'%self.xlink,  '#%s' % (pathId))
925
            tp.set("startOffset", "50%")
926
927
            textElement.append(tp)
928
        else:
929
            textElement.text = str(text)
930
931
        textElement.set('x',str(x))
932
        textElement.set('y',str(y))
933
934
        if type(fill) is str:
935
            textElement.set('fill',fill)
936
        else:
937
            textElement.set('fill',fill.getRGBStr())
938
        textElement.set('shape-rendering','crispEdges')
939
940
941
        return(textElement)
942
943
    def getCenteredText(self,text, x,y, bold=False, fontSize=14, fill='rgba(50,50,50,1)', **kwargs):
944
        text = self.getText(text,x,y,**kwargs)
945
        text.set('text-anchor','middle')
946
        text.set('dominant-baseline','middle')
947
        text.set('font-family','Helvetica')
948
        #text.set('font-family','Cambria Math')
949
        if bold:
950
            text.set('font-weight', 'bold')
951
        text.set('font-size', str(fontSize))
952
        text.set('fill', fill)
953
        return text
954
955
    def addTspan(self, textObject, text=None):
956
        tspan = self.getTspan()
957
        textObject.append(tspan)
958
959
        #Fill with text if supplied:
960
        if text is not None:
961
            tspan.text = text
962
963
        return(tspan)
964
    #superscript
965
    def addSuper(self, text, superText, offset=-10):
966
967
        superElement =  etree.Element('tspan')
968
        superElement.text = superText
969
        superElement.set('dy', '%s' % offset)
970
        text.append(superElement)
971
972
973
974
975
    def polarToCartesian(self, centerX, centerY, radius, angleInDegrees = None, angleInRadians = None):
976
977
        angleInRadians = (angleInDegrees-90) * math.pi / 180.0 if angleInDegrees is not None else angleInRadians
978
        return({
979
          'x': centerX + (radius * math.cos(angleInRadians)),
980
          'y': centerY + (radius * math.sin(angleInRadians))
981
        })
982
983
    def describeArc(self, x, y, radius, startAngle, endAngle, sweep=0, largeArcFlag=None):
984
985
        start = self.polarToCartesian(x, y, radius, endAngle)
986
        end = self.polarToCartesian(x, y, radius, startAngle)
987
988
        if largeArcFlag==None:
989
            if endAngle - startAngle <= 180:
990
                largeArcFlag  = "0"
991
            else:
992
                largeArcFlag  = "1"
993
994
        d = " ".join([str(x) for x in [
995
            "M", start['x'], start['y'],
996
            "A", radius, radius, 0, largeArcFlag, sweep, end['x'], end['y']
997
        ]])
998
999
        return(d)
1000
1001
    def describeArcRad(self, x, y, radius, startAngle, endAngle, sweep=0, largeArcFlag=None):
1002
1003
        start = self.polarToCartesian(x, y, radius, angleInRadians=startAngle )
1004
        end = self.polarToCartesian(x, y, radius, angleInRadians=endAngle  )
1005
1006
        if largeArcFlag==None:
1007
            if endAngle - startAngle <= math.pi:
1008
                largeArcFlag  = "0"
1009
            else:
1010
                largeArcFlag  = "1"
1011
1012
        d = " ".join([str(x) for x in [
1013
            "M", start['x'], start['y'],
1014
            "A", radius, radius, 0, largeArcFlag, sweep, end['x'], end['y']
1015
        ]])
1016
1017
        return(d)
1018
1019
1020
    def getCircle(self, centerX, centerY, radius):
1021
        circle = etree.Element('circle')
1022
        circle.set('id', self.getNextObjId())
1023
        circle.set('cx',str(centerX))
1024
        circle.set('cy',str(centerY))
1025
        circle.set('r',str(radius))
1026
        circle.set('style',"fill:none;stroke:#1b1b1b;stroke-width:1.29999995;stroke-linecap:round;stroke-miterlimit:4;stroke-opacity:1;stroke-dasharray:5.2, 5.2;stroke-dashoffset:0")
1027
        return(circle)
1028
1029
1030
    def getPath(self, pathDef):
1031
        path = etree.Element('path')
1032
        path.set('id', self.getNextObjId())
1033
        path.set('d', pathDef)
1034
        path.set('style',"fill:none;stroke:#1b1b1b;stroke-width:1;stroke-linecap:round")
1035
        return(path)
1036
1037
1038
    def dump(self):
1039
        print(( etree.toString(self.root, pretty_print=True)))
1040
1041
1042
1043
1044
    def write(self, path, pretty=False, htmlCallback=None, bodyCallback=None):
1045
1046
        try:
1047
            os.makedirs(os.path.dirname(path),exist_ok=True)
1048
        except:
1049
            pass
1050
1051
        if len(self.script)>0:
1052
            html = etree.Element('html')
1053
1054
1055
            head = etree.Element('head')
1056
            body = etree.Element('body')
1057
1058
            s = etree.Element('script')
1059
            s.set('type','text/javascript')
1060
            s.text = self.script
1061
1062
            jquery = etree.Element('script')
1063
            jquery.set('src','https://ajax.googleapis.com/ajax/libs/jquery/2.2.3/jquery.min.js')
1064
            jquery.text = ' '
1065
            head.append(jquery)
1066
1067
1068
            jqueryUi = etree.Element('link')
1069
            jqueryUi.set('rel', 'stylesheet')
1070
            jqueryUi.set('href','https://ajax.googleapis.com/ajax/libs/jqueryui/1.12.1/themes/smoothness/jquery-ui.css')
1071
            jqueryUi.text = ' '
1072
            head.append(jqueryUi)
1073
1074
            jqueryUi = etree.Element('script')
1075
            jqueryUi.set('src','https://ajax.googleapis.com/ajax/libs/jqueryui/1.12.1/jquery-ui.min.js')
1076
            jqueryUi.text = ' '
1077
            head.append(jqueryUi)
1078
1079
            jqueryColor = etree.Element('script')
1080
            jqueryColor.set('src','http://code.jquery.com/color/jquery.color-2.1.2.js')
1081
            jqueryColor.text = ' '
1082
            head.append(jqueryColor)
1083
1084
            body.append(self.svgTree)
1085
            html.append(head)
1086
1087
            body.append(s)
1088
1089
            if htmlCallback is not None:
1090
                htmlCallback(html)
1091
1092
            if bodyCallback is not None:
1093
                bodyCallback(body)
1094
            html.append(body)
1095
1096
            import html as pyhtml
1097
            if pretty:
1098
                import xml.dom.minidom as minidom
1099
                with open(path, 'w') as f:
1100
                    f.write( pyhtml.unescape(minidom.parseString(etree.tostring(html.getroottree())).toprettyxml(indent=" ").decode('utf-8') ))
1101
            else:
1102
1103
                #html.getroottree().write(path)
1104
                with open(path, 'w') as f:
1105
                    f.write( pyhtml.unescape( etree.tostring(html.getroottree()).decode('utf-8') ) )
1106
        else:
1107
            try:
1108
                self.svgTree.getroottree().write(path)
1109
            except:
1110
                print("failed saving %s" % path)
1111
        return(path)
1112
1113
    def SVGtoPNG(self,svgPath, pngPath, width=None, inkscapePath="C:\Program Files (x86)\Inkscape\inkscape.exe"):
1114
1115
        if width is not None:
1116
            pass
1117
        else:
1118
            width = self.width
1119
1120
1121
        #cmd = '"%(INKSCAPE_PATH)s" -z --verb=org.ekips.filter.embedimage --verb=FileSave --verb=FileClose -f %(source_svg)s -w %(width)s -j -e %(dest_png)s' %  {'INKSCAPE_PATH':inkscapePath, 'source_svg':svgPath, 'dest_png':pngPath, 'width':width}
1122
        cmd = '"%(INKSCAPE_PATH)s" -z  --verb=FileSave --verb=FileClose -f %(source_svg)s -w %(width)s -e %(dest_png)s' %  {'INKSCAPE_PATH':inkscapePath, 'source_svg':svgPath, 'dest_png':pngPath, 'width':width}
1123
        os.system('%s' % cmd)
1124
        os.system('%s' % cmd)
1125
1126
1127
1128
1129
1130
#circle = bdbplot.getCircle(200,200,50)
1131
#bdbplot.modifyStyle(circle, {'filter':'url(#%s)'%shadow.get('id')})
1132
#bdbplot.svgTree.append( circle )
1133
1134
#circle = bdbplot.getCircle(250,250,50)
1135
#bdbplot.modifyStyle(circle, {'filter':'url(#%s)'%shadow.get('id')})
1136
#bdbplot.svgTree.append( circle )
1137
1138
#path = bdbplot.getPath('M100 100 L300 100 L300 300 L300 100')
1139
#bdbplot.svgTree.append( path )
1140
1141
1142
##
1143
# Spaghettogram
1144
##
1145
1146
1147
1148
##
1149
# Histogram
1150
##
1151
1152
#
1153
# dictionary of read abundace->freq
1154
def readCountHistogram(abundanceFreqDict, logAbundance=True):
1155
1156
    #We expect a distribution which is very steep.
1157
    lookAhead = 3
1158
1159
1160
    if logAbundance:
1161
        f = abundanceFreqDict
1162
        abundanceFreqDict = Counter({})
1163
        for a in f:
1164
            #print(("%s %s" % (a, f[a])))
1165
            try:
1166
                logA = int(round( math.log10(int(a)*100), 0))
1167
            except Exception as e:
1168
                logA = 0
1169
                print(e)
1170
            abundanceFreqDict[logA] += f[a]
1171
           # print(("%s %s -> %s %s" % (a, f[a], logA, abundanceFreqDict[logA] )))
1172
1173
    #Find the highest abundant read:
1174
    hfreq = max([n for n in abundanceFreqDict])
1175
1176
    #Find closed distribution:
1177
    #Mapping from abundance value to plotting X coordinate
1178
    xxMapping = {}
1179
1180
    closedEnd = 1
1181
    perBin = 100
1182
    for c in range(1,perBin):
1183
1184
        if abundanceFreqDict[c]==0 and 0==sum(abundanceFreqDict[q] for q in range(c,c+lookAhead+1)):
1185
            closedEnd = c-1
1186
            break
1187
1188
        else:
1189
            xxMapping[c] = c-0.5
1190
1191
    #check how many extra blocks we need
1192
1193
    extraBlocks = 0
1194
    prevX = closedEnd-0.5
1195
    maxX = 1
1196
    prevAbundance = closedEnd
1197
    extraBlockCoords = []
1198
    extraBlockContinuous = {}
1199
    for abundance in sorted(abundanceFreqDict.keys()):
1200
        if abundance > closedEnd:
1201
1202
1203
            if (abundance-prevAbundance) > 1:
1204
                xxMapping[abundance] = prevX+2
1205
                extraBlockContinuous[extraBlocks] = True
1206
            else:
1207
                xxMapping[abundance] = prevX+1
1208
                extraBlockContinuous[extraBlocks] = False
1209
1210
            prevAbundance=abundance
1211
            extraBlockCoords.append(abundance)
1212
            prevX = xxMapping[abundance]
1213
            maxX= xxMapping[abundance]+1
1214
            extraBlocks+=1
1215
    #print(xxMapping)
1216
    bdbplot = BDBPlot()
1217
    bdbplot.plotStartX = 100
1218
    bdbplot.plotStartY = 150
1219
1220
    bdbplot.plotHeight =400
1221
    bdbplot.plotWidth = max(600, (maxX) * 25)
1222
1223
    bdbplot.setWidth(bdbplot.plotWidth+bdbplot.plotStartX+10)
1224
    bdbplot.setHeight(800)
1225
1226
1227
    bdbplot.xMax = max(1,maxX) # prevent 0 (breaks everything, 0 divisions and such)
1228
    bdbplot.yMax = max(1,int(math.log10(hfreq)))
1229
1230
    axis = bdbplot.getAxis(2)
1231
    bdbplot.svgTree.append( axis )
1232
1233
    #Draw specialised x-axis
1234
1235
    p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(bdbplot.xMin, bdbplot.yMin),bdbplot.getPlottingCoord(xxMapping[closedEnd]+1, bdbplot.yMin)]))
1236
    bdbplot.svgTree.append( p )
1237
1238
    for extraBlock in range(0,extraBlocks):
1239
        x = xxMapping[extraBlockCoords[extraBlock]]
1240
1241
        #p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(x, bdbplot.yMin),bdbplot.getPlottingCoord(x+1, bdbplot.yMin)]))
1242
1243
        if  extraBlockContinuous[extraBlock]:
1244
            d = 0.15
1245
            p = bdbplot.getPath(bdbplot.getPathDefinition([
1246
                bdbplot.getPlottingCoord(x-1, bdbplot.yMin),
1247
                bdbplot.getPlottingCoord(x-0.75, bdbplot.yMin+d),
1248
                bdbplot.getPlottingCoord(x-0.25, bdbplot.yMin-d),
1249
                bdbplot.getPlottingCoord(x, bdbplot.yMin)
1250
                ]))
1251
1252
            bdbplot.modifyStyle(p, {'stroke-width':'1', 'stroke-linecap':'round', 'stroke-dasharray':'2 2','stroke-dashoffset':'0'} )
1253
            bdbplot.svgTree.append( p )
1254
1255
1256
        p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(x, bdbplot.yMin),bdbplot.getPlottingCoord(x+1, bdbplot.yMin)]))
1257
        bdbplot.svgTree.append( p )
1258
1259
1260
1261
        #p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(x+2, bdbplot.yMin),bdbplot.getPlottingCoord(x+3, bdbplot.yMin)]))
1262
        #bdbplot.modifyStyle(p, {'stroke-width':'1', 'stroke-linecap':'round', 'stroke-dasharray':'2 2','stroke-dashoffset':'0'} )
1263
        #bdbplot.svgTree.append( p )
1264
1265
1266
    #Draw fine grid
1267
    for y in range(1,bdbplot.yMax+1):
1268
        p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(bdbplot.xMin, y),bdbplot.getPlottingCoord(bdbplot.xMax, y)]))
1269
1270
        bdbplot.modifyStyle(p, {'stroke-width':'0.5', 'stroke-linecap':'round', 'stroke-dasharray':'2 2','stroke-dashoffset':'0'} )
1271
        bdbplot.svgTree.append( p )
1272
1273
1274
    ### Block plotting ###
1275
    rectangles = []
1276
    barShadow = bdbplot.shadow(1,1)
1277
    bdbplot.addDef(barShadow)
1278
1279
    for abundance in range(1,int(hfreq)+1):
1280
1281
        if abundanceFreqDict[abundance]>0:
1282
            plotX = xxMapping[abundance]
1283
            frequency = abundanceFreqDict[abundance]
1284
1285
            if frequency==1:
1286
                value=0.20
1287
            else:
1288
                value = math.log10(frequency)
1289
1290
            c = bdbplot.getPlottingCoord(plotX,value)
1291
            origin = bdbplot.getPlottingCoord(plotX,0)
1292
1293
            barWidth = float(bdbplot.plotWidth)/(bdbplot.xMax+1)
1294
1295
            rectangleParams = (c[0], c[1], barWidth,  (float(value)/bdbplot.yMax) * bdbplot.plotHeight-3)
1296
            rectangles.append(rectangleParams)
1297
            bar = bdbplot.getRectangle( *rectangleParams )
1298
            bdbplot.modifyStyle(bar, {'filter':'url(#%s)'%barShadow.get('id'),'fill':'rgba(255,255,255,1)'})
1299
            bdbplot.svgTree.append( bar )
1300
1301
            text = bdbplot.getText(str( bdbplot.humanReadable(frequency,1 ) ), c[0]+0.5*barWidth, c[1]-10,BDBcolor(0,0,0,1))
1302
            text.set('text-anchor','middle')
1303
            text.set('dominant-baseline','middle')
1304
            text.set('font-family','Gill Sans MT')
1305
            #text.set('font-family','Cambria Math')
1306
            text.set('font-size', '14')
1307
            text.set('font-weight', 'bold')
1308
            text.set('fill', 'rgba(50,50,50,1)')
1309
            bdbplot.svgTree.append( text )
1310
1311
            #AXIS LABEL
1312
            c = bdbplot.getXLabelCoord(plotX+0.5)
1313
            text = bdbplot.getText(str(bdbplot.humanReadable(abundance)), c[0], c[1]+15,BDBcolor(80,80,80,1))
1314
            text.set('text-anchor','middle')
1315
            text.set('dominant-baseline','middle')
1316
            text.set('font-family','Cambria Math')
1317
            bdbplot.svgTree.append( text )
1318
1319
1320
1321
1322
    for rect in rectangles:
1323
        bar = bdbplot.getRectangle( *rect )
1324
        bdbplot.modifyStyle(bar, {'fill':bdbplot.getGroupColors(1)[0], 'stroke':'#FFFFFF','stroke-width':'1.5'})
1325
        bdbplot.svgTree.append( bar )
1326
1327
1328
1329
    #Y axis label
1330
    for y in range(1,bdbplot.yMax+1):
1331
        c = bdbplot.getYLabelCoord(y)
1332
1333
        value = math.pow(10,y)
1334
1335
        text = bdbplot.getText(str(10), c[0]-10, c[1],BDBcolor(80,80,80,1))
1336
        text.set('text-anchor','end')
1337
        text.set('dominant-baseline','middle')
1338
        text.set('font-family','Cambria Math')
1339
        bdbplot.addSuper(text,str(y))
1340
        bdbplot.svgTree.append( text )
1341
1342
1343
1344
    c = bdbplot.getYLabelCoord( (bdbplot.yMax/2))
1345
    text = bdbplot.getText('Frequency', c[0]-60, c[1]-30,BDBcolor(0,0,0,1))
1346
    text.set('text-anchor','middle')
1347
    text.set('dominant-baseline','middle')
1348
    text.set('font-family','Gill Sans MT')
1349
    text.set('font-size', '25')
1350
    #bdbplot.modifyStyle(text, {'font-size': '20'})
1351
    bdbplot.setTextRotation(text,270)
1352
    bdbplot.svgTree.append( text )
1353
1354
    c = bdbplot.getXLabelCoord(bdbplot.xMax/2)
1355
    text = bdbplot.getText('Read abundance', c[0], c[1]+50,BDBcolor(0,0,0,1))
1356
    text.set('text-anchor','middle')
1357
    text.set('dominant-baseline','middle')
1358
    text.set('font-family','Gill Sans MT')
1359
    text.set('font-size', '25')
1360
    bdbplot.svgTree.append( text )
1361
1362
1363
    return(bdbplot)
1364
1365
1366
1367
1368
class subdividedHistClass():
1369
1370
    def __init__(self, name, dataPoints, logTransform=False, offset=0):
1371
        self.logTransform = logTransform
1372
        self.totalValue = 0
1373
        self.barSpacerWidth = 5
1374
        self.barWidth = 40
1375
        self.maxValue = 0
1376
        self.bars = []
1377
        self.name = name
1378
        self.startX = offset
1379
        x = offset
1380
        self.width = 0
1381
        for dName,count in dataPoints.most_common():
1382
            self.totalValue+=count
1383
            x+=self.barSpacerWidth
1384
            self.bars.append({'x':x, 'y':count,'name':dName})
1385
            x+=self.barWidth
1386
1387
            self.maxValue = max(self.maxValue, count)
1388
1389
        self.width = x + self.barSpacerWidth - offset
1390
1391
    def plot(self, bdbplot, scarAliases,subClassColors ):
1392
        #Draw full class rectangle:
1393
        value = 1
1394
        if self.logTransform and self.totalValue>0:
1395
            value = math.log(self.totalValue)
1396
1397
        if not self.logTransform:
1398
            value = self.totalValue
1399
1400
        c = bdbplot.getPlottingCoord(self.startX,value)
1401
        origin = bdbplot.getPlottingCoord(self.startX,0)
1402
1403
        rectangleParams = (c[0], c[1], self.width,  (float(value)/bdbplot.yMax) * bdbplot.plotHeight)
1404
        bar = bdbplot.getRectangle( *rectangleParams )
1405
        bdbplot.modifyStyle(bar, {'fill':'rgba(150,150,150,0.8)','stroke-width':'0'})
1406
        bdbplot.svgTree.append( bar )
1407
1408
1409
        text = bdbplot.getText(self.name, c[0]+0.5*self.width, c[1]- 10,BDBcolor(0,0,0,1))
1410
        text.set('text-anchor','middle')
1411
        text.set('dominant-baseline','middle')
1412
        text.set('font-family','Gill Sans MT')
1413
        #text.set('font-family','Cambria Math')
1414
        text.set('font-size', '14')
1415
        text.set('font-weight', 'bold')
1416
        text.set('fill', 'rgba(50,50,50,1)')
1417
        bdbplot.svgTree.append( text )
1418
1419
        barShadow = bdbplot.shadow(1,1)
1420
        bdbplot.addDef(barShadow)
1421
        rectangles= []
1422
        for bar in self.bars:
1423
            #Add bar:
1424
            plotX = bar['x']
1425
            frequency= bar['y']
1426
            className = scarAliases[ bar['name'] ]
1427
            barColor = subClassColors[bar['name']]
1428
1429
            #Add class label to X axis:
1430
            c = bdbplot.getXLabelCoord(plotX+self.barWidth*0.5)
1431
            text = bdbplot.getText(className, c[0], c[1]+15,BDBcolor(80,80,80,1))
1432
            text.set('text-anchor','middle')
1433
            text.set('dominant-baseline','middle')
1434
            text.set('font-family','Cambria Math')
1435
            bdbplot.svgTree.append( text )
1436
1437
            if self.logTransform:
1438
                if frequency==1:
1439
                    value=0.20
1440
                else:
1441
                    value = math.log10(frequency)
1442
            else:
1443
                value = frequency
1444
1445
            c = bdbplot.getPlottingCoord(plotX,value)
1446
            origin = bdbplot.getPlottingCoord(plotX,0)
1447
1448
            #barWidth = float(bdbplot.plotWidth)/(bdbplot.xMax+1)
1449
            barWidth = self.barWidth
1450
            rectangleParams = (c[0], c[1], barWidth,  (float(value)/bdbplot.yMax) * bdbplot.plotHeight)
1451
            rectangles.append(rectangleParams)
1452
            bar = bdbplot.getRectangle( *rectangleParams )
1453
            bdbplot.modifyStyle(bar, {'stroke-width':'0','filter':'url(#%s)'%barShadow.get('id'),'fill':barColor})
1454
            bdbplot.svgTree.append( bar )
1455
1456
            text = bdbplot.getText(str( bdbplot.humanReadable(frequency,3,2 ) ), c[0]+0.5*barWidth, c[1]-10,BDBcolor(0,0,0,1))
1457
            text.set('text-anchor','middle')
1458
            text.set('dominant-baseline','middle')
1459
            text.set('font-family','Gill Sans MT')
1460
            #text.set('font-family','Cambria Math')
1461
            text.set('font-size', '14')
1462
            text.set('font-weight', 'bold')
1463
            text.set('fill', 'rgba(50,50,50,1)')
1464
1465
            bdbplot.svgTree.append( text )
1466
1467
            #Percentile:
1468
            text = bdbplot.getText(str( bdbplot.humanReadable(100*(float(frequency)/self.totalValue),3,2 )+'%' ), c[0]+0.5*barWidth, c[1]+10,BDBcolor(255,255,255,1))
1469
            text.set('text-anchor','middle')
1470
            text.set('dominant-baseline','middle')
1471
            text.set('font-family','Gill Sans MT')
1472
            #text.set('font-family','Cambria Math')
1473
            text.set('font-size', '14')
1474
            text.set('font-weight', 'bold')
1475
            text.set('fill', 'rgba(255,255,255,1)')
1476
1477
            bdbplot.svgTree.append( text )
1478
1479
1480
1481
def subdividedClassHistogram( classes, logTransform= False, scarAliases={}):
1482
1483
    classSpacerWidth = 20
1484
1485
    currentX = classSpacerWidth
1486
    classIndex = 0
1487
    maxValue = 0
1488
1489
    subHistClassList = []
1490
1491
    for className in classes:
1492
        shc = subdividedHistClass(className,classes[className], logTransform, currentX)
1493
        currentX += shc.width + classSpacerWidth
1494
1495
        if logTransform:
1496
            if shc.totalValue>0:
1497
                maxValue = max(maxValue,math.log(shc.totalValue))
1498
        else:
1499
            maxValue = max(maxValue,shc.totalValue)
1500
1501
        subHistClassList.append( shc )
1502
1503
1504
    bdbplot = BDBPlot()
1505
1506
    ## color list:
1507
    subClassColors = {}
1508
    idx = 0
1509
    #print(len(scarAliases))
1510
    gc = bdbplot.getGroupColors( len(scarAliases) )
1511
    for scar in scarAliases:
1512
        subClassColors[scar] = gc[idx]
1513
        print(('%s -> %s' %(scar, gc[idx])))
1514
        idx+=1
1515
1516
    ## Plot area preparation
1517
    bdbplot.plotStartX = 100
1518
    bdbplot.plotStartY = 100
1519
1520
    bdbplot.plotHeight =400
1521
    bdbplot.plotWidth = currentX
1522
1523
    bdbplot.setWidth(bdbplot.plotWidth+bdbplot.plotStartX+10)
1524
    bdbplot.setHeight(800)
1525
    bdbplot.xMax = max(1,currentX) # prevent 0 (breaks everything, 0 divisions and such)
1526
    if logTransform:
1527
        bdbplot.yMax = max(1,int(maxValue)+1)
1528
    else:
1529
        bdbplot.yMax = max(1,int(maxValue+1))
1530
1531
    axis = bdbplot.getAxis()
1532
    bdbplot.svgTree.append( axis )
1533
    if logTransform:
1534
        for y in range(1,bdbplot.yMax+1):
1535
            p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(bdbplot.xMin, y),bdbplot.getPlottingCoord(bdbplot.xMax, y)], True))
1536
1537
            bdbplot.modifyStyle(p, {'stroke-width':'0.5', 'stroke-linecap':'round', 'stroke-dasharray':'2 2','stroke-dashoffset':'0'} )
1538
            bdbplot.svgTree.append( p )
1539
1540
    #Draw fine grid
1541
    if logTransform:
1542
        for y in range(1,bdbplot.yMax+1):
1543
            p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(bdbplot.xMin, y),bdbplot.getPlottingCoord(bdbplot.xMax, y)], True))
1544
1545
            bdbplot.modifyStyle(p, {'stroke-width':'0.5', 'stroke-linecap':'round', 'stroke-dasharray':'2 2','stroke-dashoffset':'0'} )
1546
            bdbplot.svgTree.append( p )
1547
1548
    else:
1549
        stepSize = 50000
1550
        if bdbplot.yMax<101:
1551
            stepSize=10
1552
        for y in range(0,bdbplot.yMax+1,stepSize):
1553
            p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(bdbplot.xMin, y),bdbplot.getPlottingCoord(bdbplot.xMax, y)], True))
1554
1555
            bdbplot.modifyStyle(p, {'stroke-width':'0.5', 'stroke-linecap':'round', 'stroke-dasharray':'2 2','stroke-dashoffset':'0'} )
1556
            bdbplot.svgTree.append( p )
1557
1558
            c = bdbplot.getYLabelCoord(y)
1559
1560
            text = bdbplot.getText(bdbplot.humanReadable(y,2,3 ), c[0]-10, c[1],BDBcolor(80,80,80,1))
1561
            text.set('text-anchor','end')
1562
            text.set('dominant-baseline','middle')
1563
            text.set('font-family','Cambria Math')
1564
            bdbplot.svgTree.append( text )
1565
1566
1567
    if logTransform:
1568
        for y in range(1,bdbplot.yMax+1):
1569
            c = bdbplot.getYLabelCoord(y)
1570
1571
1572
            value = math.pow(10,y)
1573
1574
            text = bdbplot.getText(str(10), c[0]-10, c[1],BDBcolor(80,80,80,1))
1575
            text.set('text-anchor','end')
1576
            text.set('dominant-baseline','middle')
1577
            text.set('font-family','Cambria Math')
1578
            bdbplot.addSuper(text,str(y))
1579
            bdbplot.svgTree.append( text )
1580
1581
1582
    for shc in subHistClassList:
1583
        shc.plot(bdbplot,scarAliases,subClassColors)
1584
1585
1586
    c = bdbplot.getYLabelCoord( (bdbplot.yMax/2))
1587
    text = bdbplot.getText('Reads', c[0]-60, c[1]-30,BDBcolor(0,0,0,1))
1588
    text.set('text-anchor','middle')
1589
    text.set('dominant-baseline','middle')
1590
    text.set('font-family','Gill Sans MT')
1591
    text.set('font-size', '25')
1592
    #bdbplot.modifyStyle(text, {'font-size': '20'})
1593
    bdbplot.setTextRotation(text,270)
1594
    bdbplot.svgTree.append( text )
1595
1596
    c = bdbplot.getXLabelCoord(bdbplot.xMax/2)
1597
    text = bdbplot.getText('Sample', c[0], c[1]+50,BDBcolor(0,0,0,1))
1598
    text.set('text-anchor','middle')
1599
    text.set('dominant-baseline','middle')
1600
    text.set('font-family','Gill Sans MT')
1601
    text.set('font-size', '25')
1602
    bdbplot.svgTree.append( text )
1603
1604
1605
    return({'plot':bdbplot,'colorMapping':subClassColors})
1606
1607
1608
1609
def classHistogram( classCountMapping, logTransform = False, classColors=None,  placeLeft=None,placeRight=None, reverseOrder=True, height=400, xLabel='Sample', yLabel='Reads', yStepper = 50000, classWidth=50, freqFontSize=14, xLabelFontSize=10,  rotateClassLabels=0,zebraFillMode=False, defaultFillColor='#404040', barSpacing=5, xLabelOffset=50, showZeros=False, axisLabelFontSize=25, drawFreqLabels=True, title=None, freqMethod='humanReadable' ): # freqmethod 'humanReadable' ,'float'
1610
1611
    amountOfClasses = len(classCountMapping)
1612
    #classWidth = 50
1613
    maxValue = 0
1614
    for className in classCountMapping:
1615
        maxValue = max(classCountMapping[className],maxValue)
1616
1617
1618
    bdbplot = BDBPlot()
1619
    bdbplot.plotStartX = 100
1620
    bdbplot.plotStartY = 100
1621
1622
    bdbplot.plotHeight = height
1623
    bdbplot.plotWidth = max(600, (amountOfClasses) * classWidth)
1624
1625
    bdbplot.setWidth(bdbplot.plotWidth+bdbplot.plotStartX+10)
1626
    bdbplot.setHeight(700)
1627
1628
    if title is not None:
1629
        text = bdbplot.getText(title, 10,20, fill='#666666')
1630
        text.set('text-anchor','begin')
1631
        text.set('dominant-baseline','central')
1632
        text.set('font-family','Gill Sans MT')
1633
        text.set('font-size', '25')
1634
        bdbplot.svgTree.append(text)
1635
1636
1637
    bdbplot.xMax = max(1,amountOfClasses) # prevent 0 (breaks everything, 0 divisions and such)
1638
1639
    if logTransform:
1640
        bdbplot.yMax = max(1,int(math.log10(maxValue)+1))
1641
    else:
1642
        bdbplot.yMax = max(1,int(maxValue+1))
1643
1644
    axis = bdbplot.getAxis(2)
1645
    bdbplot.svgTree.append( axis )
1646
1647
    classIndex = 0
1648
1649
    rectangles = []
1650
    barShadow = bdbplot.shadow(1,1)
1651
    bdbplot.addDef(barShadow)
1652
1653
    whiteShadow =  bdbplot.shadow(1,1,3,'rgb(255,255,255)')
1654
    bdbplot.addDef(whiteShadow)
1655
1656
    #Draw fine grid
1657
    if logTransform:
1658
        for y in range(1,bdbplot.yMax+1):
1659
            p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(bdbplot.xMin, y),bdbplot.getPlottingCoord(bdbplot.xMax, y)], True))
1660
1661
            bdbplot.modifyStyle(p, {'stroke-width':'0.5', 'stroke-linecap':'round', 'stroke-dasharray':'2 2','stroke-dashoffset':'0'} )
1662
            bdbplot.svgTree.append( p )
1663
1664
    else:
1665
        for y in np.arange(0,bdbplot.yMax+yStepper,yStepper):
1666
            p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(bdbplot.xMin, y),bdbplot.getPlottingCoord(bdbplot.xMax, y)], True))
1667
1668
            bdbplot.modifyStyle(p, {'stroke-width':'0.5', 'stroke-linecap':'round', 'stroke-dasharray':'2 2','stroke-dashoffset':'0'} )
1669
            bdbplot.svgTree.append( p )
1670
1671
            c = bdbplot.getYLabelCoord(y)
1672
            if(freqMethod=='humanReadable'):
1673
                text = bdbplot.getText(bdbplot.humanReadable(y,2,3 ), c[0]-10, c[1],BDBcolor(80,80,80,1))
1674
            else:
1675
                text = bdbplot.getText(float(y), c[0]-10, c[1],BDBcolor(80,80,80,1))
1676
            text.set('text-anchor','end')
1677
            text.set('dominant-baseline','middle')
1678
            text.set('font-family','Cambria Math')
1679
1680
            bdbplot.svgTree.append( text )
1681
1682
1683
1684
    if isinstance(classCountMapping, collections.OrderedDict):
1685
1686
        if reverseOrder:
1687
            classOrderKeys = list(reversed(list(classCountMapping.keys())))
1688
        else:
1689
            classOrderKeys = list(classCountMapping.keys())
1690
1691
        classOrder = [
1692
            (key, classCountMapping[key]) for key in classOrderKeys
1693
        ]
1694
1695
    else:
1696
        if reverseOrder:
1697
            classOrder = list(reversed(classCountMapping.most_common()))
1698
        else:
1699
            classOrder = list(classCountMapping.most_common())
1700
    classOrderKeys = {}
1701
    for className,freq in classOrder:
1702
        classOrderKeys[className] = (className, freq)
1703
    #Prepend left desired classes to the left
1704
    if placeLeft is not None:
1705
        for className in placeLeft:
1706
            #print(className)
1707
            tup = classOrderKeys[className]
1708
            classOrder.remove(tup)
1709
            classOrder.insert(0, tup)
1710
    if placeRight is not None:
1711
        for className in placeRight:
1712
            #print(className)
1713
            tup = classOrderKeys[className]
1714
            classOrder.remove(tup)
1715
            classOrder.append( tup)
1716
1717
    barWidth = (float(bdbplot.plotWidth)/(bdbplot.xMax+1))
1718
    for className, frequency in classOrder:
1719
1720
        #Add class label to X axis:
1721
        c = bdbplot.getXLabelCoord(classIndex)
1722
        text = bdbplot.getText(str(className), c[0] + ( 0.5*(barWidth)), c[1]+15,BDBcolor(80,80,80,1))
1723
        text.set('text-anchor','middle')
1724
        text.set('dominant-baseline','middle')
1725
        text.set('font-family','Cambria Math')
1726
        text.set('font-size',str(xLabelFontSize))
1727
1728
        if rotateClassLabels!=0:
1729
            bdbplot.setTextRotation(text, rotateClassLabels)
1730
            text.set('text-anchor','start')
1731
        if rotateClassLabels>90:
1732
            bdbplot.setTextRotation(text, rotateClassLabels)
1733
            text.set('text-anchor','end')
1734
1735
1736
        bdbplot.svgTree.append( text )
1737
1738
        #Add bar:
1739
        plotX = classIndex
1740
1741
        if logTransform:
1742
1743
1744
            if frequency==1:
1745
                value=0.20
1746
            elif frequency==0:
1747
                value= 0
1748
            else:
1749
                value = math.log10(frequency)
1750
        else:
1751
            value = frequency
1752
1753
        c = bdbplot.getPlottingCoord(plotX,value)
1754
        origin = bdbplot.getPlottingCoord(plotX,0)
1755
1756
1757
1758
1759
        #barWidth = float(bdbplot.plotWidth)/(bdbplot.xMax+1)
1760
1761
1762
        rectangleParams = (c[0]+0.5*barSpacing, c[1], barWidth-barSpacing,  (float(value)/bdbplot.yMax) * bdbplot.plotHeight-3)
1763
        rectangles.append(rectangleParams)
1764
        bar = bdbplot.getRectangle( *rectangleParams )
1765
        bdbplot.modifyStyle(bar, {'filter':'url(#%s)'%barShadow.get('id'),'fill':'rgba(255,255,255,1)'})
1766
        bdbplot.svgTree.append( bar )
1767
1768
        if (showZeros or frequency>0) and drawFreqLabels:
1769
            if(freqMethod=='humanReadable'):
1770
                text = bdbplot.getText(str( bdbplot.humanReadable(frequency,2,1 ) ), c[0]+0.5*barWidth, c[1]-10,BDBcolor(0,0,0,1))
1771
            else:
1772
                text = bdbplot.getText('%.2f' % frequency  , c[0]+0.5*barWidth, c[1]-10,BDBcolor(0,0,0,1))
1773
            text.set('text-anchor','middle')
1774
            text.set('dominant-baseline','middle')
1775
            text.set('font-family','Gill Sans MT')
1776
            #text.set('font-family','Cambria Math')
1777
            text.set('font-size', str(freqFontSize))
1778
            text.set('font-weight', 'bold')
1779
            text.set('fill', 'rgba(50,50,50,1)')
1780
            bdbplot.modifyStyle(text, {'filter':'url(#%s)'%whiteShadow.get('id')})
1781
1782
            bdbplot.svgTree.append( text )
1783
1784
        classIndex += 1
1785
1786
1787
    if logTransform:
1788
        for y in range(1,bdbplot.yMax+1):
1789
            c = bdbplot.getYLabelCoord(y)
1790
1791
1792
            value = math.pow(10,y)
1793
1794
            text = bdbplot.getText(str(10), c[0]-10, c[1],BDBcolor(80,80,80,1))
1795
            text.set('text-anchor','end')
1796
            text.set('dominant-baseline','middle')
1797
            text.set('font-family','Cambria Math')
1798
            bdbplot.addSuper(text,str(y))
1799
            bdbplot.svgTree.append( text )
1800
1801
1802
    for idx,rect in enumerate(rectangles):
1803
        bar = bdbplot.getRectangle( *rect )
1804
1805
        if zebraFillMode:
1806
            fillColor = bdbplot.getGroupColors(3)[idx%2]
1807
        else:
1808
            fillColor = defaultFillColor
1809
        if classColors is not None:
1810
            if classOrder[idx][0] in classColors:
1811
                fillColor = classColors[classOrder[idx][0]]
1812
            else:
1813
                #print(('Setting %s to default color' % classOrder[idx][0]))
1814
                pass
1815
1816
        bdbplot.modifyStyle(bar, {'fill':fillColor, 'stroke':'#FFFFFF','stroke-width':'1.75'})
1817
        bdbplot.svgTree.append( bar )
1818
1819
1820
1821
1822
    c = bdbplot.getYLabelCoord( (bdbplot.yMax/2))
1823
    text = bdbplot.getText(yLabel, c[0]-60, c[1]-30,BDBcolor(0,0,0,1))
1824
    text.set('text-anchor','middle')
1825
    text.set('dominant-baseline','middle')
1826
    text.set('font-family','Gill Sans MT')
1827
    text.set('font-size', str(axisLabelFontSize))
1828
    #bdbplot.modifyStyle(text, {'font-size': '20'})
1829
    bdbplot.setTextRotation(text,270)
1830
    bdbplot.svgTree.append( text )
1831
1832
    c = bdbplot.getXLabelCoord(bdbplot.xMax/2)
1833
    text = bdbplot.getText(xLabel, c[0], c[1]+xLabelOffset,BDBcolor(0,0,0,1))
1834
    text.set('text-anchor','middle')
1835
    text.set('dominant-baseline','middle')
1836
    text.set('font-family','Gill Sans MT')
1837
    text.set('font-size', str(axisLabelFontSize))
1838
    bdbplot.svgTree.append( text )
1839
1840
1841
    return(bdbplot)
1842
1843
1844
1845
class Heatmap(object):
1846
1847
    def __init__(self, npMatrix, colorMatrix=None, rowNames=None,  rowColors=None, columnColors=None, cellFormat=None, cellIdentifiers=None, cellRotations=None, cellStrings=None, columnNames=None, cellSize=25, title=None, subtitle=None, nominalColoring=False, rotateColumnLabels=90, cellAnnot=None, cellAnnotFormat=None, metaDataMatrix=None, cluster=False, groupSize=10 ):
1848
1849
        self.nominalColoring = nominalColoring
1850
        self.nominalColoringMapping = None
1851
        self.zeroColor = None
1852
        self.colormap = matplotlib.cm.get_cmap('plasma')
1853
        self.NanColor = (0.9,0.9,0.9,1)
1854
        self.rotateColumnLabels = rotateColumnLabels
1855
        self.footerHeight = 400
1856
        self.groupSpacerSize = 10
1857
        self.cellFont = 'Cambria'
1858
        self.labelFont = 'Cambria'
1859
1860
        print("Plotting %s by %s matrix" % npMatrix.shape)
1861
        if rowNames is not None:
1862
            print("Supplied %s rownames" % len(rowNames))
1863
            rowNames = [ "%s: %s"%t for t in enumerate(rowNames) ]
1864
1865
        if columnNames is not None:
1866
            print("Supplied %s column names" % len(columnNames))
1867
1868
        if cluster:
1869
            print("Clustering")
1870
            self.matrix = npMatrix
1871
            clusterMatrix = np.zeros( npMatrix.shape )
1872
            for (column,row), value in np.ndenumerate(npMatrix):
1873
                l = self.nominalIndex(value)
1874
                clusterMatrix[column,row] = l
1875
1876
            distances = scipy.spatial.distance.pdist( np.nan_to_num(clusterMatrix.transpose()), 'cityblock' )
1877
            mdistMatrix = scipy.spatial.distance.squareform(distances)
1878
            clustering = scipy.cluster.hierarchy.linkage( mdistMatrix, 'ward' )
1879
            leavesList = list( scipy.cluster.hierarchy.leaves_list(clustering) )
1880
            #npMatrix  = clusterMatrix
1881
            print(leavesList)
1882
        else:
1883
            leavesList = list(range(npMatrix.shape[1]))
1884
        print("Ranging %s" % len(leavesList))
1885
1886
        self.matrix = npMatrix[:,leavesList]
1887
        if colorMatrix is not None:
1888
            self.colorMatrix = colorMatrix[:,leavesList] #Values between zero and one
1889
        else:
1890
            self.colorMatrix = None
1891
        self.rowColors = np.array(rowColors)[leavesList] if rowColors is not None else []
1892
        self.columnColors = columnColors if columnColors is not None else []
1893
        self.cellIdentifiers = cellIdentifiers if cellIdentifiers is not None else None
1894
        self.cellAnnot = cellAnnot
1895
        self.rowNames =np.array(rowNames)[leavesList] if rowNames is not None else []
1896
        self.columnNames = columnNames if columnNames is not None else []
1897
        self.cellSize = cellSize
1898
        self.cellFormat = cellFormat if cellFormat is not None else lambda x: x
1899
        self.cellAnnotFormat  = cellAnnotFormat if cellAnnotFormat is not None else lambda x: x
1900
        self.cellStrings = cellStrings[:,leavesList] if cellStrings is not None else []
1901
        self.cellRotations = cellRotations[:,leavesList] if cellRotations is not None else None
1902
1903
        self.metaDataMatrix = metaDataMatrix[:,leavesList] if metaDataMatrix is not None else None
1904
        self.title = title
1905
        self.subtitle = subtitle
1906
        self.leftMargin = 80
1907
        self.labelWidth = 20
1908
        self.topMargin = 150
1909
        self.cellSpacing = 1
1910
        self.cellFontSize = 10
1911
        self.labelFontSize = 15
1912
        self.groupSize = groupSize
1913
        #self.colormap = matplotlib.cm.get_cmap('inferno')
1914
        print(self.rowNames)
1915
1916
1917
    def getRowName(self, index):
1918
        try:
1919
            return(str(self.rowNames[index]))
1920
        except:
1921
            return('')
1922
    def getColName(self, index):
1923
        try:
1924
            return(str(self.columnNames[index]))
1925
        except:
1926
            return('')
1927
1928
    def getRowColor(self,index):
1929
        try:
1930
            c = self.rowColors[index]
1931
1932
            return( c )
1933
        except:
1934
            return( BDBcolor( 50,50,50, 1) )
1935
1936
    def getColumnColor(self,index):
1937
        try:
1938
            c = self.columnColors[index]
1939
            return( c )
1940
        except:
1941
            return( BDBcolor( 50,50,50, 1) )
1942
1943
    def getCellString(self, row,column):
1944
        try:
1945
            return( self.cellFormat(self.cellStrings[column,row]))
1946
        except:
1947
            return('')
1948
1949
    def getCellAnnotString(self, row,column):
1950
        try:
1951
            return( str(self.cellAnnotFormat(self.cellAnnot[column, row])))
1952
        except:
1953
            return('')
1954
1955
    def getCellId(self, row,column):
1956
        try:
1957
            return( str(self.cellIdentifiers[column, row]) )
1958
        except:
1959
            return(None)
1960
    def getColor(self, value):
1961
1962
        theValueIsNan = self.isnan(value)
1963
1964
        if theValueIsNan:
1965
            r,g,b,a  = self.NanColor
1966
        elif self.zeroColor is not None and value==0:
1967
            r,g,b,a  = self.zeroColor
1968
        else:
1969
1970
            if self.nominalColoring:
1971
                r,g,b,a = self.nominalColor(value)
1972
            else:
1973
                try:
1974
                    r,g,b,a = self.colormap(value)
1975
                except:
1976
1977
                    print("Reverted to nominal coloring mode")
1978
                    self.nominalColoring = True
1979
                    r,g,b,a = self.nominalColor(value)
1980
        return(r,g,b,a)
1981
1982
1983
    def nominalIndex(self,value):
1984
        #Force build of nominal matrix
1985
        r,g,b,a = self.nominalColor(value)
1986
        #get index:
1987
        try:
1988
            idx = list(self.nominalColoringMapping.keys()).index(value)
1989
        except:
1990
            idx=0.01
1991
        return(idx / len(self.nominalColoringMapping.keys()))
1992
1993
1994
    def getCellMetaData(self, row,column):
1995
        try:
1996
            return( str(self.metaDataMatrix[column, row]) )
1997
        except:
1998
            return(None)
1999
2000
    def addTitle(self, plot):
2001
        if self.title is not None:
2002
            text = plot.getText(self.title, 10,20, fill='#666666')
2003
            text.set('text-anchor','begin')
2004
            text.set('dominant-baseline','central')
2005
            text.set('font-family','Gill Sans MT')
2006
            text.set('font-size', '25')
2007
            plot.svgTree.append(text)
2008
2009
    def getCellRotation(self, row, column):
2010
        try:
2011
            if self.cellRotations[column, row] == np.nan:
2012
                return(None)
2013
2014
            return( str(self.cellRotations[column, row]) )
2015
        except:
2016
            return(None)
2017
2018
2019
    def addSubtitle(self, plot):
2020
        if self.subtitle is not None:
2021
            text = plot.getText(self.subtitle, 10,40, fill='#222222')
2022
            text.set('text-anchor','begin')
2023
            text.set('dominant-baseline','central')
2024
            text.set('font-family','Cambria')
2025
            text.set('font-size', '15')
2026
            plot.svgTree.append(text)
2027
2028
    def addGroupTitle(self, plot, title, indexStart, indexEnd):
2029
2030
        matrixGroup = plot.svgTree.findall(".//g[@id='matrix']")[0]
2031
        #Calculate x-starting coordinate
2032
        xStart = self.columnIndexToXCoordinate(indexStart)
2033
        #Calculate x-ending coordinate
2034
        xEnd = self.columnIndexToXCoordinate(indexEnd)+self.cellSize
2035
        yStart = self.rowIndexToYCoordinate(0)-self.cellSize
2036
        yEnd = yStart - self.cellSize*0.5
2037
2038
        p = plot.getPath(plot.getPathDefinition([(xStart, yStart),(xStart,yEnd),(xEnd, yEnd), (xEnd, yStart)]))
2039
        plot.modifyStyle(p, {'stroke-width':'1', 'stroke-linecap':'round','stroke-dashoffset':'0', 'stroke':'#333333'} )
2040
        matrixGroup.append( p )
2041
2042
        text = plot.getText(title, xStart+0.5*( (xEnd-xStart)), yEnd - 8, fill='#000000')
2043
        text.set('text-anchor','middle')
2044
        text.set('dominant-baseline','central')
2045
        text.set('font-family','Gill Sans MT')
2046
        #stext.set('font-family','Cambria')
2047
        matrixGroup.append(text)
2048
2049
2050
    def columnIndexToXCoordinate(self, columnIndex, objSize=None):
2051
        objSize = self.cellSize if objSize is None else objSize
2052
        return(  (self.cellSize + (self.cellSize-objSize)*0.5)  + (columnIndex-1) * self.cellSize+ self.cellSpacing*columnIndex + int(columnIndex/self.groupSize)*self.groupSpacerSize)
2053
        #return(columnIndex * (self.cellSize+self.cellSpacing) + int(columnIndex/4)*self.groupSpacerSize)
2054
2055
2056
    def rowIndexToYCoordinate(self, rowIndex, objSize=None):
2057
        objSize = self.cellSize if objSize is None else objSize
2058
        return( (self.cellSize + (self.cellSize-objSize)*0.5) + (rowIndex-1) * (self.cellSize) + self.cellSpacing*rowIndex + int(rowIndex/self.groupSize)*self.groupSpacerSize)
2059
2060
2061
    def nominalColor(self, value):
2062
        if self.nominalColoringMapping == None:
2063
            #Find all unique values in the matrix:
2064
2065
            uniqueValues = list(set( v for _,v in np.ndenumerate( self.matrix.astype(str ) ) ))
2066
            try:
2067
                uniqueValues = sorted(uniqueValues)
2068
            except:
2069
                pass
2070
2071
            if len(uniqueValues)!=0:
2072
                self.nominalColoringMapping = {val:self.colormap(float(index)/float(len(uniqueValues))) for index,val in enumerate(uniqueValues)}
2073
            else:
2074
                return(self.NanColor)
2075
            print(self.nominalColoringMapping)
2076
2077
2078
        return( self.nominalColoringMapping.get(value, self.NanColor) )
2079
2080
    def isnan(self,value):
2081
2082
        theValueIsNan = False
2083
        if value is None:
2084
            theValueIsNan = True
2085
        else:
2086
            try:
2087
                theValueIsNan = np.isnan(value)
2088
            except:
2089
                theValueIsNan = False
2090
        return(theValueIsNan)
2091
2092
    def getPlot(self):
2093
2094
        plot = BDBPlot()
2095
2096
        matrixGroup = plot.getGroup('matrix')
2097
        matrixGroup.set('transform', 'translate(%s, %s)' % (self.leftMargin,self.topMargin ))
2098
        columnCount, rowCount = self.matrix.shape
2099
2100
        plotWidth = self.leftMargin + (columnCount*(self.cellSize+self.cellSpacing )) + self.groupSpacerSize*(columnCount/4)
2101
        plotHeight = self.topMargin + self.rowIndexToYCoordinate(rowCount) +self.cellSize+self.cellSpacing  + self.footerHeight
2102
        plot.setWidth(plotWidth)
2103
        plot.setHeight(plotHeight)
2104
        ySlack = self.cellSize/2
2105
2106
        cellShadow = plot.shadow(0.5,0.5,1, 'rgb(0,0,0)', 0.98)
2107
2108
        plot.addDef(cellShadow)
2109
2110
2111
        foreGroundTilesGroup = plot.getGroup('foreGroundTiles')
2112
        foreGroundTilesGroup.set('style', 'filter:url(#%s);'%cellShadow.get('id') )
2113
        matrixGroup.append(foreGroundTilesGroup)
2114
        for (column,row), value in np.ndenumerate(self.matrix):
2115
2116
            zIndex = 0
2117
            cellSize = self.cellSize*0.75 if self.getCellRotation(row, column)!=None and self.getCellRotation(row, column)!=0 and self.getCellRotation(row, column)!='nan' else self.cellSize
2118
            x = self.columnIndexToXCoordinate(column,cellSize)
2119
            y = self.rowIndexToYCoordinate(row,cellSize)
2120
            rect = plot.getRectangle(x,y,cellSize, cellSize)
2121
2122
            try:
2123
                cVal = self.colorMatrix[column,row]
2124
            except:
2125
                cVal = value
2126
2127
2128
            theValueIsNan = self.isnan(cVal)
2129
            r,g,b,a = self.getColor(value)
2130
2131
            r = int(r*255.0)
2132
            g = int(g*255.0)
2133
            b = int(b*255.0)
2134
            #print('rgb(%s,%s,%s)' %  (r,g,b))
2135
2136
2137
            rect.set( 'fill','rgb(%s,%s,%s)' %  (r,g,b))
2138
            plot.modifyStyle(rect, {'fill': 'rgba(%s,%s,%s,1)' % (r,g,b), 'stroke-width':'0', 'stroke':'rgba(%s,%s,%s,1)' % (0,0,0)})
2139
2140
            if self.getCellId(row,column) is not None:
2141
                rect.set('cell_id',self.getCellId(row,column))
2142
            if self.getCellMetaData(row,column) is not None:
2143
                rect.set('meta',self.getCellMetaData(row,column))
2144
2145
            if self.getCellRotation(row, column)!=None and self.getCellRotation(row, column)!=0 and self.getCellRotation(row, column)!='nan':
2146
                rect.set('transform','rotate(%s, %s, %s)'%(self.getCellRotation(row, column),x+cellSize*0.5, y+cellSize*0.5))
2147
                zIndex = len(matrixGroup)-1
2148
2149
            if theValueIsNan:
2150
                matrixGroup.insert( 0, rect)
2151
            else:
2152
                foreGroundTilesGroup.insert( zIndex, rect)
2153
2154
            brightness = 255 - ((r+g+b)/3.0)
2155
            if ((r+g+b)/3.0)<100:
2156
                c = BDBcolor( brightness, brightness, brightness, 1)
2157
            else:
2158
                c = BDBcolor( 0, 0, 0, 1)
2159
2160
            if theValueIsNan==False:
2161
                if self.cellAnnot is None:
2162
                    text = plot.getText(str(self.getCellString(row,column)), x+0.5*cellSize, y+0.5*cellSize, fill=c.getRGBStr())
2163
                    text.set('text-anchor','middle')
2164
                    text.set('dominant-baseline','central')
2165
                    #text.set('font-family','Gill Sans MT')
2166
                    text.set('font-family',self.cellFont)
2167
                    text.set('font-size', str(self.cellFontSize))
2168
                    matrixGroup.append(text)
2169
                else:
2170
                    text = plot.getText(str(self.getCellString(row,column)), x+0.5*cellSize, y+0.3*cellSize, fill=c.getRGBStr())
2171
                    text.set('text-anchor','middle')
2172
                    text.set('dominant-baseline','central')
2173
                    #text.set('font-family','Gill Sans MT')
2174
                    text.set('font-family',self.cellFont)
2175
                    text.set('font-size', str(self.cellFontSize))
2176
                    #text.set('font-weight','bold')
2177
                    matrixGroup.append(text)
2178
2179
                    text = plot.getText(str(self.getCellAnnotString(row,column)), x+0.5*cellSize, y+0.75*cellSize, fill=c.getRGBStr())
2180
                    text.set('text-anchor','middle')
2181
                    text.set('dominant-baseline','central')
2182
                    #text.set('font-family','Gill Sans MT')
2183
                    text.set('font-family',self.cellFont)
2184
                    text.set('font-size', str(self.cellFontSize*0.90))
2185
                    matrixGroup.append(text)
2186
2187
2188
            if column==0:
2189
                #x -= self.labelWidth
2190
2191
                self.leftMargin = max( len( self.getRowName(row) )*8, self.leftMargin )
2192
2193
                text = plot.getText(self.getRowName(row), self.columnIndexToXCoordinate(column)-self.labelWidth,  self.rowIndexToYCoordinate(row)+ySlack, fill=self.getRowColor(row))
2194
                text.set('text-anchor','end')
2195
                text.set('dominant-baseline','central')
2196
                #text.set('font-family','Gill Sans MT')
2197
                text.set('font-family',self.labelFont)
2198
                text.set('font-size', str(self.labelFontSize))
2199
2200
                matrixGroup.append(text)
2201
            if row==(rowCount-1) or row==0:
2202
                offset = self.labelWidth+self.cellSize if row==(rowCount-1) else -self.cellSize*0.5
2203
                text = plot.getText(self.getColName(column), self.columnIndexToXCoordinate(column)+self.cellSize*0.5,  self.rowIndexToYCoordinate(row)+offset, fill=self.getColumnColor(column))
2204
                text.set('text-anchor','middle')
2205
                text.set('dominant-baseline','central')
2206
                #text.set('font-family','Gill Sans MT')
2207
                text.set('font-family',self.labelFont)
2208
                text.set('font-size', str(self.labelFontSize))
2209
                if self.rotateColumnLabels!=0:
2210
                    plot.setTextRotation(text, self.rotateColumnLabels)
2211
                    if self.rotateColumnLabels>90:
2212
                        text.set('text-anchor','end' if row==(rowCount-1) else 'start')
2213
                    else:
2214
                        text.set('text-anchor','start' if row==(rowCount-1) else 'end')
2215
                matrixGroup.append(text)
2216
2217
        plot.svgTree.append(matrixGroup)
2218
        self.addTitle(plot)
2219
        self.addSubtitle(plot)
2220
2221
        matrixGroup.set('transform', 'translate(%s, %s)' % (self.leftMargin,self.topMargin ))
2222
        plotWidth = self.leftMargin + (columnCount*(self.cellSize+self.cellSpacing )) + self.groupSpacerSize*(columnCount/4)
2223
        plotHeight = self.topMargin + self.rowIndexToYCoordinate(rowCount) +self.cellSize+self.cellSpacing  + self.footerHeight
2224
        plot.setWidth(plotWidth)
2225
        plot.setHeight(plotHeight)
2226
        return(plot)
2227
2228
2229
2230
def testHeatmap():
2231
    xmin = 0.0
2232
    xmax = 10.0
2233
    dx = 1.0
2234
    ymin=0.0
2235
    ymax=5.0
2236
    dy = 1.0
2237
    x,y = np.meshgrid(np.arange(xmin,xmax,dx),np.arange(ymin,ymax,dy))
2238
    npMat = (x*y)
2239
    m = npMat.max()
2240
    print(m)
2241
    npMat /= m
2242
    print(npMat)
2243
    xlabels = ["X"+str(x) for x in range(npMat.shape[1])]
2244
    ylabels = ["Y"+str(y) for y in range(npMat.shape[0])]
2245
    h = Heatmap(npMat, npMat, rowNames=xlabels, columnNames=ylabels)
2246
    p = h.getPlot()
2247
    p.write('test.svg')
2248
2249
2250
def histogram(values = [1,7,3,2,1,0,0,0,1], rebin=False, binCount=9, reScale=False, logScale=False, logScaleData=False):
2251
2252
    if rebin:
2253
        newBars = {}
2254
        bars = [0]*(binCount+1)
2255
        frequencies = dict()
2256
        for v in values:
2257
            if not v in frequencies:
2258
                frequencies[v]=1
2259
            else:
2260
                frequencies[v]+=1
2261
2262
        #Take log of frequencies
2263
        if logScale:
2264
            for q in frequencies:
2265
2266
                if frequencies[q]>0:
2267
                    frequencies[q] = math.log10(frequencies[q])
2268
                else:
2269
                    frequencies[q] = -1
2270
2271
2272
        minValue = min([float(x) for x in list(frequencies.keys())])
2273
        maxValue = max([float(x) for x in list(frequencies.keys())])
2274
        binSize = (maxValue - minValue)/binCount
2275
2276
        sampleTotal = sum(frequencies.values())
2277
2278
        for binIndex in range(0,binCount+1):
2279
            binStart = minValue+ binIndex*binSize
2280
            binEnd = binStart+binSize
2281
            binTotal = 0
2282
            for d in frequencies.irange(binStart, binEnd, (True,False)):
2283
                binTotal+=frequencies[d]
2284
                if reScale:
2285
                    newBars[binStart+0.5*binSize] =  float(binTotal)/float(sampleTotal)
2286
                    bars[binIndex] = float(binTotal)/float(sampleTotal)
2287
                else:
2288
                    newBars[binStart+0.5*binSize] =  binTotal
2289
                    bars[binIndex] = binTotal
2290
2291
2292
    else:
2293
        bars=values
2294
2295
2296
    bdbplot = BDBPlot()
2297
2298
2299
    bdbplot.plotStartX = 100
2300
    bdbplot.plotStartY = 100
2301
    bdbplot.xMax = max(1,int(math.ceil(len(bars)))) # prevent 0 (breaks everything, 0 divisions and such)
2302
    bdbplot.yMax = max(1,int(math.ceil(max(bars))))
2303
2304
2305
    #plotWall = bdbplot.getRectangle(bdbplot.plotStartX,bdbplot.plotStartY,bdbplot.plotWidth,bdbplot.plotHeight)
2306
    #bdbplot.svgTree.append( plotWall )
2307
    #bdbplot.modifyStyle(plotWall, {'filter':'url(#%s)'%shadow.get('id'),'fill':'rgba(255,255,255,1)'})
2308
    #bdbplot.modifyStyle(plotWall, {'fill':'rgba(255,255,255,1)'})
2309
2310
    axis = bdbplot.getAxis()
2311
    bdbplot.svgTree.append( axis )
2312
2313
    #Draw fine grid
2314
    for y in range(1,bdbplot.yMax+1):
2315
        p = bdbplot.getPath(bdbplot.getPathDefinition([bdbplot.getPlottingCoord(bdbplot.xMin, y),bdbplot.getPlottingCoord(bdbplot.xMax, y)]))
2316
2317
        bdbplot.modifyStyle(p, {'stroke-width':'0.5', 'stroke-linecap':'round', 'stroke-dasharray':'2 2','stroke-dashoffset':'0'} )
2318
        bdbplot.svgTree.append( p )
2319
2320
2321
    # Add axis labels:
2322
    if logScaleData:
2323
        for x in range(0,bdbplot.xMax+1):
2324
            c = bdbplot.getXLabelCoord(x)
2325
2326
            text = bdbplot.getText(str(10), c[0], c[1]+15,BDBcolor(80,80,80,1))
2327
            text.set('text-anchor','middle')
2328
            text.set('dominant-baseline','middle')
2329
            text.set('font-family','Cambria Math')
2330
            bdbplot.addSuper(text,str(x))
2331
            bdbplot.svgTree.append( text )
2332
2333
    else:
2334
        for x in range(0,bdbplot.xMax+1):
2335
            c = bdbplot.getXLabelCoord(x)
2336
            text = bdbplot.getText(str(x), c[0], c[1]+15,BDBcolor(80,80,80,1))
2337
            text.set('text-anchor','middle')
2338
            text.set('dominant-baseline','middle')
2339
            text.set('font-family','Cambria Math')
2340
            bdbplot.svgTree.append( text )
2341
    if logScale:
2342
        for y in range(1,bdbplot.yMax+1):
2343
            c = bdbplot.getYLabelCoord(y)
2344
2345
            value = math.pow(10,y)
2346
2347
            text = bdbplot.getText(str(10), c[0]-10, c[1],BDBcolor(80,80,80,1))
2348
            text.set('text-anchor','end')
2349
            text.set('dominant-baseline','middle')
2350
            text.set('font-family','Cambria Math')
2351
            bdbplot.addSuper(text,str(y))
2352
            bdbplot.svgTree.append( text )
2353
2354
    else:
2355
        for y in range(0,bdbplot.yMax+1):
2356
            c = bdbplot.getYLabelCoord(y)
2357
            text = bdbplot.getText(str(y), c[0]-10, c[1],BDBcolor(80,80,80,1))
2358
            text.set('text-anchor','end')
2359
            text.set('dominant-baseline','middle')
2360
            text.set('font-family','Cambria Math')
2361
            bdbplot.svgTree.append( text )
2362
2363
2364
2365
    c = bdbplot.getYLabelCoord( (bdbplot.yMax/2))
2366
    text = bdbplot.getText('Frequency', c[0]-60, c[1]-30,BDBcolor(0,0,0,1))
2367
    text.set('text-anchor','middle')
2368
    text.set('dominant-baseline','middle')
2369
    text.set('font-family','Gill Sans MT')
2370
    text.set('font-size', '25')
2371
    #bdbplot.modifyStyle(text, {'font-size': '20'})
2372
    bdbplot.setTextRotation(text,270)
2373
    bdbplot.svgTree.append( text )
2374
2375
    c = bdbplot.getXLabelCoord(bdbplot.xMax/2)
2376
    text = bdbplot.getText('Read abundance', c[0], c[1]+50,BDBcolor(0,0,0,1))
2377
    text.set('text-anchor','middle')
2378
    text.set('dominant-baseline','middle')
2379
    text.set('font-family','Gill Sans MT')
2380
    text.set('font-size', '25')
2381
    bdbplot.svgTree.append( text )
2382
2383
2384
    barShadow = bdbplot.shadow(1,1)
2385
    bdbplot.addDef(barShadow)
2386
2387
    for barIndex,barValue in enumerate(bars):
2388
        c = bdbplot.getPlottingCoord(barIndex,barValue)
2389
        origin = bdbplot.getPlottingCoord(barIndex,0)
2390
        bar = bdbplot.getRectangle(  c[0], c[1], float(bdbplot.plotWidth)/(len(bars)+1),  (float(barValue)/bdbplot.yMax) * bdbplot.plotHeight )
2391
2392
        bdbplot.modifyStyle(bar, {'filter':'url(#%s)'%barShadow.get('id'),'fill':'#FFFFFF'})
2393
        bdbplot.svgTree.append( bar )
2394
2395
    for barIndex,barValue in enumerate(bars):
2396
        c = bdbplot.getPlottingCoord(barIndex,barValue)
2397
        origin = bdbplot.getPlottingCoord(barIndex,0)
2398
2399
        barWidth = float(bdbplot.plotWidth)/(len(bars)+1)
2400
2401
        bar = bdbplot.getRectangle(  c[0], c[1], barWidth,  (float(barValue)/bdbplot.yMax) * bdbplot.plotHeight )
2402
        bdbplot.modifyStyle(bar, {'fill':bdbplot.getGroupColors(1)[0], 'stroke':'#FFFFFF','stroke-width':'1.5'})
2403
        bdbplot.svgTree.append( bar )
2404
2405
        if barValue!=-1:
2406
            text = bdbplot.getText(str( bdbplot.humanReadable( int(math.pow(10,barValue)) ) ), c[0]+0.5*barWidth, c[1]-10,BDBcolor(0,0,0,1))
2407
            text.set('text-anchor','middle')
2408
            text.set('dominant-baseline','middle')
2409
            text.set('font-family','Gill Sans MT')
2410
            #text.set('font-family','Cambria Math')
2411
            text.set('font-size', '14')
2412
            text.set('font-weight', 'bold')
2413
            text.set('fill', 'rgba(50,50,50,1)')
2414
            bdbplot.svgTree.append( text )
2415
2416
    return(bdbplot)
2417
2418
def densityXY(scatterData, plotPath, xlabel='x', ylabel='y', logX=False, forceShow=False, logY=False):
2419
2420
2421
    from scipy.stats import gaussian_kde
2422
2423
    if len(scatterData['x'])==0:
2424
        #self.warn('No datapoints left for comparison')
2425
        return(False)
2426
2427
    #@todo: cast this earlier.
2428
    scatterData['x'] = np.array(scatterData['x'])
2429
    scatterData['y'] = np.array(scatterData['y'])
2430
2431
    xy = np.vstack([scatterData['x'],scatterData['y']])
2432
    z = gaussian_kde(xy)(xy)
2433
2434
    # Sort the points by density, so that the densest points are plotted last
2435
    idx = z.argsort()
2436
    scatterData['x'] = scatterData['x'][idx]
2437
    scatterData['y'] = scatterData['y'][idx]
2438
    z = z[idx]
2439
2440
    plt.close('all')
2441
    fig, ax = plt.subplots()
2442
    ax.scatter(scatterData['x'], scatterData['y'], c=z, s=50, edgecolor='')
2443
    if logX:
2444
        ax.set_xscale('log')
2445
    if logY:
2446
        ax.set_yscale('log')
2447
2448
    plt.ylabel(ylabel)
2449
    plt.xlabel(xlabel)
2450
    if plotPath and forceShow:
2451
        plt.show()
2452
    if plotPath is None:
2453
        plt.show()
2454
    else:
2455
        plt.savefig(plotPath, bbox_inches='tight')
2456
        plt.close('all')
2457
2458
2459
##
2460
# SIMPLE X Y PLOT
2461
##
2462
def simpleXY():
2463
2464
    bdbplot = BDBPlot()
2465
2466
    shadow = bdbplot.shadow()
2467
    bdbplot.addDef(shadow)
2468
    bdbplot.plotStartX = 100
2469
    bdbplot.plotStartY = 100
2470
    bdbplot.xMax = 10
2471
    bdbplot.yMax = 10
2472
2473
    #plotWall = bdbplot.getRectangle(bdbplot.plotStartX,bdbplot.plotStartY,bdbplot.plotWidth,bdbplot.plotHeight)
2474
    #bdbplot.svgTree.append( plotWall )
2475
    #bdbplot.modifyStyle(plotWall, {'filter':'url(#%s)'%shadow.get('id'),'fill':'rgba(255,255,255,1)'})
2476
    #bdbplot.modifyStyle(plotWall, {'fill':'rgba(255,255,255,1)'})
2477
2478
    axis = bdbplot.getAxis()
2479
    bdbplot.svgTree.append( axis )
2480
2481
    # Add axis labels:
2482
    for x in range(0,11):
2483
        c = bdbplot.getXLabelCoord(x)
2484
        text = bdbplot.getText(str(x), c[0], c[1]+15,BDBcolor(80,80,80,1))
2485
        text.set('text-anchor','middle')
2486
        text.set('dominant-baseline','middle')
2487
        text.set('font-family','Cambria Math')
2488
        bdbplot.addSuper(text,'y')
2489
        bdbplot.svgTree.append( text )
2490
2491
    for y in range(1,11):
2492
        c = bdbplot.getYLabelCoord(y)
2493
        text = bdbplot.getText(str(y), c[0]-15, c[1],BDBcolor(80,80,80,1))
2494
        text.set('text-anchor','middle')
2495
        text.set('dominant-baseline','middle')
2496
        text.set('font-family','Cambria Math')
2497
        bdbplot.svgTree.append( text )
2498
2499
2500
2501
    c = bdbplot.getYLabelCoord(5)
2502
    text = bdbplot.getText('Y Axis', c[0]-40, c[1],BDBcolor(0,0,0,1))
2503
    text.set('text-anchor','middle')
2504
    text.set('dominant-baseline','middle')
2505
    text.set('font-family','Gill Sans MT')
2506
    bdbplot.setTextRotation(text,270)
2507
    bdbplot.svgTree.append( text )
2508
2509
    c = bdbplot.getXLabelCoord(5)
2510
    text = bdbplot.getText('X Axis', c[0], c[1]+40,BDBcolor(0,0,0,1))
2511
    text.set('text-anchor','middle')
2512
    text.set('dominant-baseline','middle')
2513
    text.set('font-family','Gill Sans MT')
2514
    bdbplot.svgTree.append( text )
2515
2516
2517
2518
2519
    for x in range(0,10):
2520
2521
        c = bdbplot.getPlottingCoord(x,x)
2522
        circle = bdbplot.getCircle(c[0],c[1],2)
2523
        #bdbplot.modifyStyle(circle, {'filter':'url(#%s)'%shadow.get('id')})
2524
        bdbplot.svgTree.append( circle )
2525
2526
    #
2527
    #for i in range(0,10):
2528
    #
2529
    #   a = ((math.pi*2)/10.0) * i
2530
    #   text = bdbplot.getText('%s' % i,250 + 50*math.cos(a),250 + 50*math.sin(a),BDBcolor(80,80,80,1))
2531
    #   text.set('text-anchor','middle')
2532
    #   text.set('dominant-baseline','middle')
2533
    #   text.set('font-family','Cambria Math')
2534
    #   bdbplot.svgTree.append( text )
2535
2536
2537
    bdbplot.dump()
2538
    bdbplot.write('test.svg')
2539
2540
2541
#vals = [0,11]
2542
#
2543
#import random
2544
#for i in range(0,16):
2545
#   vals += [i]* int(math.ceil(math.exp(i/2+1)))
2546
#
2547
#
2548
##print('c(%s)' % ','.join(str(i) for i in vals))
2549
#
2550
#plot = histogram(vals, True,15, False, True)
2551
#
2552
#
2553
#plot.write('test.svg')
2554
2555
#
2556
#d = Counter({1:100,2:50,3:10,5:5,6:10,  1000:1, 500:2, 10000:3})
2557
#plot = readCountHistogram(d)
2558
#text = plot.getText('Embryo 1, component 1',10,40)
2559
#text.set('font-family','Gill Sans MT')
2560
#text.set('font-size', '42')
2561
#plot.svgTree.append( text )
2562
#
2563
#totalCount = sum( [v*d[v]for v in d] )
2564
#text = plot.getText( '%s reads total' % plot.humanReadable(totalCount),10,75, BDBcolor(77,77,77,1))
2565
#text.set('font-family','Gill Sans MT')
2566
#text.set('font-size', '23')
2567
#plot.svgTree.append( text )
2568
#
2569
#
2570
#plot.write('test.svg')
2571
2572
import networkx as nx
2573
import scipy.interpolate
2574
from Bio import pairwise2
2575
2576
2577
class GraphRenderer():
2578
2579
    def interpolate(self, interpolateValue,  colorScaleKeys, nodeColorMapping):
2580
2581
        #Seek positions around value to interpolate
2582
        first = colorScaleKeys[0]
2583
        index = 0
2584
        last = first
2585
        for value in colorScaleKeys:
2586
2587
            if value>=interpolateValue:
2588
                last = value
2589
                break
2590
            else:
2591
                first = value
2592
            index+=1
2593
        if value==interpolateValue:
2594
            return(nodeColorMapping[value])
2595
2596
        #Do interpolation
2597
        colorA = nodeColorMapping[first]
2598
        colorB = nodeColorMapping[last]
2599
        dx = last-first
2600
2601
        return( self._ipol(colorA[0], colorB[0], first, last, interpolateValue), self._ipol(colorA[1], colorB[1], first, last, interpolateValue), self._ipol(colorA[2], colorB[2], first, last, interpolateValue))
2602
2603
2604
    def _ipol(self,a, b, first, last,  interpolateValue):
2605
        #Due to floating point rounding errors the interpolate value can be very close to last,
2606
        # it is ok to return last in those cases
2607
        if last>first and interpolateValue>=last:
2608
            return(b)
2609
        if last<first and interpolateValue>=first:
2610
            return(a)
2611
2612
        y_interp = scipy.interpolate.interp1d([first, last], [a,b])
2613
        return( y_interp(interpolateValue) )
2614
2615
2616
    def sortByIndexAndBase(self, value):
2617
2618
        parts = value.split('_')
2619
        pos = ['A','T','C','G','N'].index(parts[1])
2620
        if pos==None:
2621
            pos=0
2622
        else:
2623
            pos+=1
2624
        return( int(parts[0]) + pos*0.1 )
2625
2626
2627
    def __init__(self, nxGraph, coloringMode = 'nodeRGB', coloringAttribute='confidence', performDistanceMeasure=True, performFrequencyMeasure=True, alias='none'):
2628
        self.g = nxGraph
2629
        self.undirectedG = self.g.to_undirected()
2630
        self.plot = BDBPlot()
2631
2632
        self.nodeShadow = self.plot.shadow(0.5,0.5,1, 'rgb(0,0,0)', 0.98)
2633
        self.plot.addDef(self.nodeShadow)
2634
2635
        minX = 0
2636
        maxX = 0
2637
        minY = 0
2638
        maxY = 0
2639
        for nodeName in self.g:
2640
            node = self.g.node[nodeName]
2641
            if 'x' in node and 'y' in node and 'size' in node :
2642
                if (node['x']-node['size'])<minX:
2643
                    minX = node['x']-node['size']
2644
                if (node['x']+node['size'])>maxX:
2645
                    maxX = node['x']+node['size']
2646
2647
                if (node['y']-node['size'])<minY:
2648
                    minY = node['y']-node['size']
2649
                if (node['y']+node['size'])>maxY:
2650
                    maxY = node['y']+node['size']
2651
2652
2653
2654
        #Estimate color scale: #############
2655
        colorScaleY = 0
2656
        colorScaleHeight = 0
2657
        colorScaleGraphSpacing = 0
2658
        createdColorScale = False
2659
        if coloringAttribute is not None and coloringMode == 'nodeRGB':
2660
            createdColorScale= True
2661
            colorScaleParts = 10
2662
            colorScaleWidth = 400
2663
            colorScaleHeight = 35
2664
            colorScaleSpacing = 5
2665
            colorScaleShadow = self.plot.shadow(1,1,1)
2666
            colorScaleDeltaX = float(colorScaleWidth-colorScaleSpacing)/colorScaleParts
2667
            self.plot.addDef(colorScaleShadow)
2668
            labelHeight = 15
2669
            colorScaleX = 5
2670
            colorScaleGraphSpacing = 10
2671
            colorScaleY = (maxY-minY)+colorScaleGraphSpacing
2672
2673
            nodeColorMapping = {}
2674
            abundanceMapping = Counter({})
2675
            lowestValue = 100000
2676
            highestValue = -lowestValue
2677
            for nodeName in self.g:
2678
                node = self.g.node[nodeName]
2679
                if coloringAttribute in node and 'r' in node:
2680
                    if node[coloringAttribute]==1:
2681
                        value = 0
2682
                    else:
2683
                        value = -math.log( 1.0 - node[coloringAttribute],10 )
2684
2685
                    abundanceMapping[value]+= node['abundance']
2686
                    lowestValue = min(lowestValue, value)
2687
                    highestValue = max(highestValue, value)
2688
                    nodeColorMapping[value] = (node['r'], node['g'], node['b'])
2689
2690
            #Create color scale; ##########
2691
            lowestValue = 2.5
2692
            highestValue = 3.4
2693
2694
            #lowestValue = 3.0
2695
            #highestValue = 3.6
2696
            print((lowestValue, highestValue))
2697
            colorScale = self.plot.getGroup('colorScale')
2698
2699
            self.plot.svgTree.append(colorScale)
2700
            r = self.plot.getRectangle(colorScaleX,colorScaleY,colorScaleWidth,colorScaleHeight+labelHeight+colorScaleSpacing )
2701
            colorScale.append(r)
2702
2703
            r.set('style', 'fill:#FFFFFF' )
2704
            #r.set('style', 'filter:url(#%s);fill:#FFFFFF'%colorScaleShadow.get('id') )
2705
            colorScaleKeys = sorted(OrderedDict(sorted(nodeColorMapping.items())) )
2706
            #print(nodeColorMapping)
2707
            deltaValue = (highestValue-lowestValue) / (colorScaleParts-1)
2708
            currentValue = lowestValue
2709
            x = colorScaleSpacing+colorScaleX
2710
            y = colorScaleY + colorScaleSpacing
2711
2712
            idx = 0
2713
            while idx<colorScaleParts:
2714
                print(('Interpolating for %s' % currentValue ))
2715
                c = self.interpolate(currentValue, colorScaleKeys, nodeColorMapping)
2716
                r = self.plot.getRectangle(x,y,colorScaleDeltaX-colorScaleSpacing,colorScaleHeight-2*colorScaleSpacing )
2717
                r.set('fill',  'rgb(%s, %s, %s)' % c)
2718
                r.set('style', 'filter:url(#%s);'%  self.nodeShadow.get('id') )
2719
                r.set('stroke',  'None')
2720
                colorScale.append( r )
2721
2722
                text = self.plot.getText('%.1f' %  (10.0*currentValue) ,x+0.5*(colorScaleDeltaX-colorScaleSpacing),y+colorScaleHeight+colorScaleSpacing, BDBcolor(0,0,0,1))
2723
                #self.plot.setTextRotation(text, 90)
2724
                text.set('text-anchor','middle')
2725
                text.set('dominant-baseline','central')
2726
                text.set('font-family','Gill Sans MT')
2727
                #text.set('font-family','Cambria Math')
2728
                text.set('font-size', '10')
2729
                #text.set('font-weight', 'bold')
2730
                text.set('fill', 'rgba(0,0,0,1)')
2731
                self.plot.svgTree.append( text )
2732
2733
2734
                currentValue+=deltaValue
2735
                x += colorScaleDeltaX
2736
                idx+=1
2737
2738
2739
        #######################
2740
2741
2742
        #Colorize nodes by distance to center
2743
2744
2745
        distancesFound = Counter({})
2746
        distancesReads = Counter({})
2747
        for node in self.g:
2748
            if 'idStr' in self.g.node[node] and  'Wt' == self.g.node[node]['idStr']:
2749
                centerNode = node
2750
                break
2751
2752
        ldistThreshold = 4
2753
        maxHamming = 8
2754
        classColors = {'H0':'#0000DD','H1': '#66A43E', 'H2': '#0D40DB', 'H3': '#3970DD', 'H4': '#769AE0', 'H5': '#A6B9DD', 'H6': '#D7DAE0', 'H7': '#FFFFFF', 'N1':'#FFCC00','N2':'#FF6600', 'N3':'#C83737','N4':'#800000'}
2755
        if performDistanceMeasure:
2756
2757
            print('Estimating all distances...')
2758
            weirdSequences = []
2759
            for nodeIndex,node in enumerate(self.undirectedG):
2760
2761
                if nodeIndex%100==0:
2762
                    completion = 100.0*(float(nodeIndex)/len(self.undirectedG))
2763
                print('\rcompletion %s    ' % completion, end=' ')
2764
2765
                abundance = self.g.node[node]['abundance']
2766
                hammingDistance = bdbbio.getHammingDistance(node, centerNode)
2767
                unformattedAlignments = pairwise2.align.localxx(node,centerNode) #bdbbio.getLevenshteinDistance(node,centerNode)
2768
                ldist = len(node)
2769
2770
                self.g.node[node]['exactHammingDistance'] =  hammingDistance
2771
                self.g.node[node]['exactNWDistance'] =  ldist
2772
                if len(unformattedAlignments)>0:
2773
                    ldist = int(round(len(node)-float(unformattedAlignments[0][2])))
2774
                    #pairwise2.format_alignment(*unformattedAlignments[0])
2775
2776
                if hammingDistance<maxHamming and ldist<=hammingDistance:
2777
                    distancesReads['H%s'%hammingDistance]+=abundance
2778
2779
                    if hammingDistance==0:
2780
                        self.g.node[node]['color'] = '#66A43E'
2781
                        distancesFound['H0']+=1
2782
                    elif hammingDistance==1:
2783
                        self.g.node[node]['color'] = '#0D40DB'
2784
                        distancesFound['H1']+=1
2785
                    elif hammingDistance==2:
2786
                        self.g.node[node]['color'] = '#2362E0'
2787
                        distancesFound['H2']+=1
2788
                    elif hammingDistance==3:
2789
                        self.g.node[node]['color'] = '#769AE0'
2790
                        distancesFound['H3']+=1
2791
                    elif hammingDistance==4:
2792
                        self.g.node[node]['color'] = '#A6B9DD'
2793
                        distancesFound['H4']+=1
2794
                    elif hammingDistance==5:
2795
                        self.g.node[node]['color'] = '#D7DAE0'
2796
                        distancesFound['H5']+=1
2797
                    elif hammingDistance==6:
2798
                        self.g.node[node]['color'] = '#FFFFFF'
2799
                        distancesFound['H6']+=1
2800
                else:
2801
2802
2803
                    if ldist<=ldistThreshold:
2804
2805
                        distancesFound['N%s'%ldist]+=1
2806
                        distancesReads['N%s'%ldist]+=abundance
2807
                        if 'N%s'%ldist not in classColors:
2808
2809
                            brightness = 255-int(round((float(ldist)/ldistThreshold)*100))
2810
                            self.g.node[node]['color'] = 'rgb(%s,%s,%s)' % (brightness,0,0)
2811
                            classColors['N%s'%ldist] = self.g.node[node]['color']
2812
                        self.g.node[node]['color'] = classColors['N%s'%ldist]
2813
                    else:
2814
                        self.g.node[node]['color'] = '#404040'
2815
                        #distancesFound['l%s'%ldist]+=1
2816
                        #distancesReads['l%s'%ldist]+=abundance
2817
                        distancesFound['N>%s'%ldistThreshold]+=1
2818
                        distancesReads['N>%s'%ldistThreshold]+=abundance
2819
                        weirdSequences.append(SeqIO.SeqRecord(Seq(node), 'NW%s-a%s-%s' % (str(ldist),abundance,str(nodeIndex))))
2820
2821
            #classHistogram( classCountMapping, logTransform = False, classColors=None, placeLeft=None ):
2822
2823
            classHistogram(distancesFound, True, classColors, None,['N>%s'%ldistThreshold], False).write('%s_distancesByNodes.svg' % alias)
2824
            classHistogram(distancesReads, True, classColors, None,['N>%s'%ldistThreshold], False).write('%s_distancesByCountB.svg' % alias)
2825
            nx.write_graphml( self.g, './%s-distanceAnnotated.graphml' % alias)
2826
            fastaPath = './%s-weirdSequences.fa' % alias
2827
            SeqIO.write(weirdSequences, fastaPath, "fasta")
2828
2829
2830
2831
2832
        if performFrequencyMeasure:
2833
2834
            sequenceColors = {'rest':'#404040'}
2835
            sequenceFrequencies = Counter({})
2836
            frequencyTable = []
2837
            frequencyCounter = Counter({})
2838
            for nodeIndex,node in enumerate(self.undirectedG):
2839
                abundance = self.g.node[node]['abundance']
2840
                nodeName = nodeIndex
2841
                if 'idStr' in self.g.node[node]:
2842
                    nodeName = self.g.node[node]['idStr']
2843
                #if 'r' in self.g.node[node]:
2844
                #   sequenceColors[str(nodeName)] = 'rgb(%s, %s, %s)' % (self.g.node[node]['r'],self.g.node[node]['g'],self.g.node[node]['b'])
2845
                if 'color' in self.g.node[node]:
2846
                    sequenceColors[str(nodeName)] = self.g.node[node]['color']
2847
2848
                if abundance>3:
2849
                    sequenceFrequencies[str(nodeName)] += abundance
2850
2851
                else:
2852
                    #sequenceFrequencies['rest'] += abundance
2853
                    pass
2854
                for index,base in enumerate(str(node)):
2855
2856
                    if index>=len(frequencyTable):
2857
                        frequencyTable.append(Counter({}))
2858
                        for b in ['A','T','C','G']:
2859
                            frequencyCounter['%s_%s' % (index, b)] = 0
2860
                    frequencyCounter['%s_%s' % (index, base)]+=abundance
2861
                    frequencyTable[index][base]+= abundance
2862
2863
2864
            freqKeys = sorted(list(frequencyCounter.keys()), key=lambda x: self.sortByIndexAndBase(x))
2865
            colors = {}
2866
            for key in freqKeys:
2867
                parts = key.split('_')
2868
                i = ['N','A','T','C','G'].index(parts[1])
2869
                if i==None:
2870
                    i = 0
2871
                colors[key] = ['#404040', '#336bbd','#ff6600','#aa0000','#5aa02c'][i]
2872
2873
2874
            classHistogram(sequenceFrequencies, True, sequenceColors, None,[], False, classWidth=40, rotateClassLabels=90).write('%s_sequenceFrequenciesLog.svg' % alias)
2875
            classHistogram(sequenceFrequencies, False, sequenceColors, None,[], False, classWidth=40, rotateClassLabels=90).write('%s_sequenceFrequenciesLin.svg' % alias)
2876
            classHistogram(frequencyCounter, True, colors, None,freqKeys, False, classWidth=30, rotateClassLabels=90).write('%s_baseFrequencies.svg' % alias)
2877
2878
2879
2880
        print(('Offsetting %s %s' % (minX, minY)))
2881
        self.plot.setWidth( maxX-minX )
2882
        if createdColorScale:
2883
            self.plot.setHeight( ( (colorScaleY+colorScaleHeight+colorScaleGraphSpacing)-minY ))
2884
        else:
2885
            self.plot.setHeight( maxY - minY + 10)
2886
        h = maxY-minY
2887
        ## plotting
2888
        edgeGroup = self.plot.getGroup('edges')
2889
        nodeGroup = self.plot.getGroup('nodes')
2890
        labelGroup = self.plot.getGroup('labels')
2891
2892
        print('Adding edges')
2893
2894
        for fromNode,toNode,data in self.g.edges(data=True):
2895
2896
            if 'hdist' in data and data['hdist']==1:
2897
                fn = self.g.node[fromNode]
2898
                tn = self.g.node[toNode]
2899
                p = self.plot.getPath( self.plot.getPathDefinition([ (int(round(fn['x']-minX)), h-int(round(fn['y']-minY))), (int(round(tn['x']-minX)), h-int(round(tn['y']-minY))) ]) )
2900
2901
                etree.strip_attributes(p,'style')
2902
                if coloringMode == 'nodeRGB':
2903
                    if 'r' in tn:
2904
                        p.set('stroke', 'rgb(%s, %s, %s)' % (tn['r'],tn['g'],tn['b']))
2905
                else:
2906
                    if 'color' in tn:
2907
                        p.set('stroke', '%s' % (tn['color']))
2908
                p.set('stroke-width', '0.75')
2909
                p.set('stroke-opacity', '0.8')
2910
                #self.plot.modifyStyle(p, { 'stroke-width':'0.5', 'stroke-opacity':"0.5"  } ) #,'stroke-width':'0.5' 'stroke':'rgba(80,80,80,0.8)'})
2911
                #self.plot.modifyStyle(p, { 'stroke':'rgba(80,80,80,0.8)' } ) #,'stroke-width':'0.5' 'stroke':'rgba(80,80,80,0.8)'})
2912
2913
                edgeGroup.append(p)
2914
2915
        #self.plot.modifyStyle(edgeGroup, {'stroke-width':'0.5', 'stroke':'rgba(80,80,80,0.8)'})
2916
        self.plot.svgTree.append(edgeGroup)
2917
        self.plot.svgTree.append(nodeGroup)
2918
        self.plot.svgTree.append(labelGroup)
2919
        print('Adding nodes')
2920
2921
2922
2923
2924
2925
        for componentIndex,connectedComponent in enumerate(nx.connected_component_subgraphs(self.undirectedG)):
2926
            componentGroup = self.plot.getGroup('component_%s' % componentIndex)
2927
            smallNodes = self.plot.getGroup('component_%s_smallNodes' % componentIndex)
2928
            bigNodes = self.plot.getGroup('component_%s_bigNodes' % componentIndex)
2929
            componentGroup.append(smallNodes)
2930
            componentGroup.append(bigNodes)
2931
            nodeGroup.append(componentGroup)
2932
            for nodeName in connectedComponent:
2933
                node = self.g.node[nodeName]
2934
                if 'x' in node and 'y' in node and 'size' in node :
2935
2936
2937
                    if node['size']>0:
2938
                        circle = self.plot.getCircle(int(round(node['x']-minX)), h-int(round(node['y']-minY)), int(round(node['size'])))
2939
                        circle.set('style','fill:none')
2940
                        if coloringMode == 'nodeRGB' and 'r' in node:
2941
                            self.plot.modifyStyle(circle, {'fill':'rgb(%s,%s,%s)' % (node['r'], node['g'], node['b']), 'stroke':'rgba(0,0,247,0.8)'})
2942
                        else:
2943
                            if 'color' in node:
2944
                                self.plot.modifyStyle(circle, {'fill':'%s' % (node['color']), 'stroke':'rgba(0,0,247,0.8)'})
2945
2946
                        if node['size']>0.0:
2947
                            bigNodes.append(circle)
2948
                            bigNodes.append(circle)
2949
                            if node['abundance']>500:
2950
                                if 'idStr' in node:
2951
                                    label = node['idStr']
2952
                                else:
2953
                                    label= ''
2954
2955
                                text = self.plot.getText('%s %s' % (label,node['abundance']) ,int(round(node['x']-minX)), h-int(round(node['y']-minY))+4,BDBcolor(80,80,80,1))
2956
                                text.set('text-anchor','middle')
2957
                                text.set('dominant-baseline','central')
2958
                                text.set('font-family','Gill Sans MT')
2959
                                #text.set('font-family','Cambria Math')
2960
                                text.set('font-size', '14')
2961
                                #text.set('font-weight', 'bold')
2962
                                text.set('fill', 'rgba(50,50,50,1)')
2963
                                labelGroup.append( text )
2964
2965
2966
                        else:
2967
                            smallNodes.append(circle)
2968
                            smallNodes.append(circle)
2969
2970
                else:
2971
                    print('Skipped a node; could not find coordinates')
2972
2973
            bigNodes.set('style', 'filter:url(#%s);'%self.nodeShadow.get('id') )
2974
            #self.plot.modifyStyle(bigNodes, {'filter':'url(#%s)'%self.nodeShadow.get('id')})
2975
2976
2977
2978
2979
2980
2981
2982
2983
def testGraphRenderer():
2984
    p = GraphRenderer( nx.read_graphml("C:\\Users\BuysDB\Desktop\Control1-g3.graphml"),'distances',alias='controlDist')
2985
    p.plot.write('control_spaghettogramRenderDistancesB.svg')
2986
    p.plot.SVGtoPNG('control_spaghettogramRenderDistancesB.svg', 'control_spaghettogramRenderDistancesB.png',2048)
2987
2988
    p = GraphRenderer( nx.read_graphml("C:\\Users\BuysDB\Desktop\Control1-g3.graphml"),'nodeRGB',alias='controlConf')
2989
    p.plot.write('control_spaghettogramRenderConfidenceB.svg')
2990
    p.plot.SVGtoPNG('control_spaghettogramRenderConfidenceB.svg', 'control_spaghettogramRenderConfidenceB.png',2048)
2991
2992
2993
def artGraphRenderer():
2994
    p = GraphRenderer( nx.read_graphml("C:\\Users\BuysDB\Desktop\ArtSimulatedGraphHC26k.graphml"),'nodeRGB','confidence',False, True, 'simulated')
2995
    p.plot.write('ArtSimulatedGraphHC26k.svg')
2996
    p.plot.SVGtoPNG('ArtSimulatedGraphHC26k.svg', 'ArtSimulatedGraphHC26k.png',2048)
2997
2998
def embryoGraphRenderer():
2999
    p = GraphRenderer( nx.read_graphml("C:\\Users\BuysDB\Desktop\embryo7remapped2.graphml"),'nodeRGB','confidence',False, True, 'embryo7')
3000
    p.plot.write('embryo7remapped3.svg')
3001
    p.plot.SVGtoPNG('embryo7remapped3.svg', 'embryo7remapped3.png',2048)
3002
3003
def midbrainGraphRenderer():
3004
    p = GraphRenderer( nx.read_graphml("C:\\Users\BuysDB\Desktop\midbrain.graphml"),'nodeRGB','confidence',False, True, 'brain')
3005
    p.plot.write('midbrain.svg')
3006
    p.plot.SVGtoPNG('midbrain.svg', 'midbrain.png',2048)
3007
3008
class SequenceBin():
3009
3010
    def __init__(self, sequence, abundance):
3011
        self.sequence = sequence
3012
        self.abundance = abundance
3013
        self.confidences = []
3014
        self.diffIndices = []
3015
        self.x = 0
3016
        self.y = 0
3017
3018
3019
3020
3021
class HammingBin():
3022
3023
    def __init__(self, index, hammingDistance=0):
3024
        self.hammingDistance = hammingDistance
3025
        self.index = index
3026
        self.sequences = []
3027
3028
    def addSequence(self, sequence, abundance):
3029
        self.sequences.append(sequence)
3030
3031
3032
3033
#SVG table renderer
3034
class SVGTable(BDBPlot):
3035
    #@param datamatrix list of lists containing values
3036
    #@param header list containing column names
3037
    def __init__(self, dataMatrix, header):
3038
        BDBPlot.__init__(self)
3039
        self.data = dataMatrix
3040
        self.header = header
3041
        self.cellPointer = 0
3042
3043
    #Render a cell in the matrix
3044
    def cellRenderFunction(self,x,y):
3045
        self.svgTree.getGroup()
3046
3047
3048
3049
class SpaghettoPlot():
3050
3051
    def __init__(self, networkxGraph):
3052
        self.g = networkxGraph
3053
        self.minRadius = 1
3054
        self.nucleotideColours = {'A':'#FF2222','T':'#22FF22', 'G':'#2222FF','C':'#FFFF22'}
3055
3056
3057
    def getSurfaceBasedNodeRadius(self, abundance, maxRadius):
3058
3059
        #return(math.log(abundance+1)+1)
3060
        maxO = math.pi*math.pow(float(maxRadius),2)
3061
        return( max(self.minRadius,math.sqrt( float(abundance*maxO)/math.pi) ) )
3062
3063
3064
3065
    def layout(self):
3066
3067
3068
        maxRadius = 100
3069
        minDistance = 15 #Distance between the nodes
3070
3071
        components = []
3072
        self.readLen = 0
3073
3074
        toRemove = []
3075
        for nodeA,nodeB,d in self.g.edges_iter(data='hdist'):
3076
            if d!=1:
3077
                toRemove.append( (nodeA, nodeB) )
3078
        self.g.remove_edges_from(toRemove)
3079
3080
        for componentIndex,connectedComponent in enumerate(nx.connected_component_subgraphs(self.g)):
3081
            if len(connectedComponent)>3:
3082
                #Find center(s)
3083
                plot = BDBPlot()
3084
3085
                centerNode = None
3086
                centerAbundance = 0
3087
                for node in connectedComponent:
3088
                    a = connectedComponent.node[node]['abundance']
3089
                    if a > centerAbundance:
3090
                        centerNode = node
3091
                        centerAbundance = a
3092
                    #Todo: compat for more centers
3093
                    self.readLen = max(self.readLen, len(node))
3094
                #Estimate radial size of connected component:
3095
                #longest path from center to member
3096
                radialSize = 1
3097
3098
                distanceMap = {0:[centerNode]}
3099
                for targetNode in connectedComponent:
3100
                    if targetNode!=centerNode:
3101
3102
                        pathLen = nx.shortest_path_length(self.g,source=centerNode,target=targetNode)
3103
                        radialSize = max( radialSize, pathLen)
3104
                        if not pathLen in distanceMap:
3105
                            distanceMap[pathLen] = []
3106
                        distanceMap[pathLen].append(targetNode)
3107
3108
3109
                #print(distanceMap)
3110
3111
                #Read index to angle mapping
3112
                phis = []
3113
                for index in range(0,self.readLen):
3114
                    phis.append( -math.pi*0.5 + float(math.pi*2) * (float(index)/self.readLen) )
3115
3116
3117
3118
                currentRadius = self.getSurfaceBasedNodeRadius(float(self.g.node[centerNode]['abundance'])/centerAbundance,maxRadius)
3119
                centerRadius = currentRadius
3120
                # Construct coordinates
3121
                coordinates = {}
3122
                coordinates[centerNode] = {'x':0,'y':0,'r':currentRadius}
3123
3124
                hammingRadials = []
3125
                for distance in range(1,radialSize+1):
3126
3127
                    print(('Distance %s radius: %s' % (distance, currentRadius)))
3128
                    if distance in distanceMap:
3129
                        #Find the radius of this circle
3130
                        maxNodeRadius = 0
3131
                        for node in distanceMap[distance]:
3132
                            maxNodeRadius = max(maxNodeRadius, self.getSurfaceBasedNodeRadius(float(self.g.node[node]['abundance'])/centerAbundance,maxRadius))
3133
3134
                        currentRadius += (maxNodeRadius+ minDistance) #added *.5!
3135
3136
                        #Calculate x and y coordinates:
3137
3138
                        for node in distanceMap[distance]:
3139
                            r = self.getSurfaceBasedNodeRadius(float(self.g.node[node]['abundance'])/centerAbundance,maxRadius)
3140
3141
                            #Retrieve the hamming index of the node
3142
                            indices = bdbbio.getHammingIndices(node,centerNode) #[ int(x) for x in self.g.node[node]['indices'].split(',') ]
3143
3144
                            #Take the first index as base
3145
                            for index in indices:
3146
                                angle = phis[ index ]
3147
3148
                                coordinates[node] = {'x':math.cos(angle)*currentRadius, 'y':math.sin(angle)*currentRadius, 'r':r, 'a':angle }
3149
                            #print(coordinates[node])
3150
                    hammingRadials.append(currentRadius)
3151
3152
                translateX = -currentRadius
3153
                translateY = -currentRadius
3154
                #for nodeName in coordinates:
3155
                #   translateX = min(translateX, coordinates[nodeName]['x']-coordinates[nodeName]['r'] )
3156
                #   translateY = min(translateY, coordinates[nodeName]['y']-coordinates[nodeName]['r'] )
3157
                #Draw component
3158
3159
                for i,r in enumerate(hammingRadials):
3160
3161
                    circle = plot.getCircle( -translateX,  -translateY, r)
3162
                    plot.modifyStyle(circle, {'stroke-width':'0.3', 'stroke-miterlimit':'4', 'stroke-dasharray':'2.08,2.08'})
3163
                    plot.svgTree.append( circle )
3164
3165
                    text = plot.getText('h'+str(i+1),-translateX, -r-translateY+5, fill=BDBcolor(30,30,30,0))
3166
                    text.set('text-anchor','middle')
3167
                    text.set('dominant-baseline','central')
3168
                    text.set('font-family','Cambria')
3169
                    text.set('font-size','6')
3170
                    plot.svgTree.append(text)
3171
3172
3173
                    if i==0:
3174
                        r=centerRadius*2 + 10
3175
                        for index, angle in enumerate(phis):
3176
                            text = plot.getText(centerNode[index], math.cos(angle)*r*0.5 - translateX, math.sin(angle)*r*0.5-translateY, fill=BDBcolor(30,30,30,0))
3177
                            text.set('text-anchor','middle')
3178
                            text.set('dominant-baseline','central')
3179
                            text.set('font-family','Gill Sans MT')
3180
                            text.set('font-size','10')
3181
                            plot.svgTree.append(text)
3182
3183
                            p = r - 35
3184
                            text = plot.getText(str(index), math.cos(angle)*p*0.5 - translateX, math.sin(angle)*p*0.5-translateY, fill=BDBcolor(30,30,30,0))
3185
                            text.set('text-anchor','middle')
3186
                            text.set('dominant-baseline','central')
3187
                            text.set('font-family','Cambria')
3188
                            text.set('font-size','8')
3189
                            plot.svgTree.append(text)
3190
3191
3192
3193
                for nodeName in coordinates:
3194
3195
                    circle = plot.getCircle( coordinates[nodeName]['x']-translateX,  coordinates[nodeName]['y']-translateY, coordinates[nodeName]['r'])
3196
                    plot.modifyStyle(circle, {'stroke-width':'0', 'fill':'#FF6655', 'fill-opacity':'0.30'})
3197
                    plot.svgTree.append( circle )
3198
3199
                plot.write('./components/component%s.svg' % componentIndex )