Diff of /aggmap/map.py [000000] .. [9e8054]

Switch to unified view

a b/aggmap/map.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
"""
4
Created on Sun Aug 25 20:29:36 2019
5
6
@author: wanxiang.shen@u.nus.edu
7
8
main aggmap code
9
10
"""
11
from aggmap.utils.logtools import print_info, print_warn, print_error
12
from aggmap.utils.matrixopt import Scatter2Grid, Scatter2Array, smartpadding 
13
from aggmap.utils import vismap, summary, calculator
14
from aggmap.utils.gen_nwk import mp2newick
15
16
17
from sklearn.manifold import TSNE, MDS, Isomap, LocallyLinearEmbedding, SpectralEmbedding
18
        
19
from joblib import Parallel, delayed, load, dump
20
from scipy.spatial.distance import squareform, cdist, pdist
21
from scipy.cluster.hierarchy import fcluster, linkage, dendrogram
22
import matplotlib.pylab as plt
23
import seaborn as sns
24
from umap import UMAP
25
from tqdm import tqdm
26
from copy import copy, deepcopy
27
import pandas as pd
28
import numpy as np
29
30
31
class Base:
32
    
33
    def __init__(self):
34
        pass
35
        
36
    def _save(self, filename):
37
        return dump(self, filename)
38
        
39
    def _load(self, filename):
40
        return load(filename)
41
42
    
43
    def MinMaxScaleClip(self, x, xmin, xmax):
44
        scaled = (x - xmin) / ((xmax - xmin) + 1e-8)
45
        return scaled
46
47
    def StandardScaler(self, x, xmean, xstd):
48
        return (x-xmean) / (xstd + 1e-8) 
49
    
50
    
51
class Random_2DEmbedding:
52
    
53
    def __init__(self, random_state=123, n_components=2):
54
        self.random_state=random_state
55
        self.n_components = n_components
56
57
    def fit(self, X):
58
        M, N = X.shape
59
        np.random.seed(self.random_state)
60
        rho = np.sqrt(np.random.uniform(0, 1, N))
61
        phi = np.random.uniform(0, 4*np.pi, N)
62
        x = rho * np.cos(phi)
63
        y = rho * np.sin(phi)
64
        rd = pd.DataFrame([x,y]).T.sample(frac=1, random_state=123).reset_index(drop=True)
65
        self.embedding_ = rd.values
66
        return self
67
        
68
        
69
def _get_df_scatter(mp):
70
    xy = mp.embedded.embedding_
71
    colormaps = mp.colormaps
72
    df = pd.DataFrame(xy, columns = ['x', 'y'])
73
    bitsinfo = mp.bitsinfo.set_index('IDs')
74
    df = df.join(bitsinfo.loc[mp.flist].reset_index())
75
    df['colors'] = df['Subtypes'].map(colormaps)
76
    return df
77
78
79
def _get_df_grid(mp):
80
81
    p,q = mp._S.fmap_shape
82
    position = np.zeros(mp._S.fmap_shape, dtype='O').reshape(p*q,)
83
    position[mp._S.col_asses] = mp.flist
84
    position = position.reshape(p, q)
85
    if mp.fmap_shape != None:  
86
        m, n = mp.fmap_shape
87
        if (m > p) | (n > q):
88
            position = smartpadding(position, (m,n), constant_values=0)        
89
    M, N = position.shape
90
    
91
    x = []
92
    y = []
93
    for i in range(M):
94
        for j in range(N):
95
            x.append(j) #, position[j,i]
96
            y.append(i)
97
    v = position.reshape(M*N,)
98
99
    df = pd.DataFrame(list(zip(x,y, v)), columns = ['x', 'y', 'v'])
100
101
    bitsinfo = mp.bitsinfo
102
    subtypedict = bitsinfo.set_index('IDs')['Subtypes'].to_dict()
103
    subtypedict.update({0:'NaN'})
104
    df['Subtypes'] = df.v.map(subtypedict)
105
    df['colors'] = df['Subtypes'].map(mp.colormaps) 
106
    
107
    feature_list = df.v
108
    feature_names = []
109
    for i, j in feature_list.items():
110
        if j == 0:
111
            j = 'NaN-' + str(i)
112
        feature_names.append(j)
113
    df.v = feature_names
114
    return df
115
116
117
class AggMap(Base):
118
    
119
    """ The feature restructuring class AggMap
120
    
121
    
122
    Parameters
123
    ----------
124
    dfx: pandas DataFrame
125
        Input data frame. 
126
        
127
    metric: string,  default: 'correlation'
128
        measurement of feature distance, support {'cosine', 'correlation', 'euclidean', 'jaccard', 'hamming', 'dice'}
129
    
130
    info_distance: numpy array, defalt: None
131
        a vector-form distance vector of the feature points, shape should be: (n*(n-1)/2), where n is the number of the features. It can be useful when you have you own vector-form distance to pass
132
        
133
    by_scipy: bool, defalt: False.
134
        calculate the distance by using the scipy pdist fuction.
135
        It can bu useful when dfx.shape[1] > 20000, i.e., the number of features is very large
136
        Using pdist will increase the speed to calculate the distance.
137
    
138
    n_cpus: int, default: 16
139
        number of cpu cores to use to calculate the distance.        
140
    """
141
    
142
    def __init__(self, 
143
                 dfx,
144
                 metric = 'correlation',
145
                 by_scipy = False,
146
                 n_cpus = 16,
147
                 info_distance = None,
148
                ):
149
        
150
        assert type(dfx) == pd.core.frame.DataFrame, 'input dfx must be pandas DataFrame!'
151
        super().__init__()
152
        self.metric = metric
153
        self.by_scipy = by_scipy
154
        self.isfit = False
155
        self.alist = dfx.columns.tolist()
156
        self.ftype = 'feature points'
157
        self.cluster_flag = False
158
        m,n = dfx.shape
159
        info_distance_length = int(n*(n-1)/2)
160
        assert len(dfx.columns.unique()) == n, 'the column names of dataframe must be unique !'        
161
        
162
        ## calculating distance
163
        if np.array(info_distance).any():
164
            assert len(info_distance) == info_distance_length, 'shape of info_distance must be (%s,)' % info_distance_length
165
            print_info('Skipping the distance calculation, using the customized vector-form distance...')
166
            self.info_distance = np.array(info_distance)
167
            self.metric = 'precomputed'
168
        else:
169
            print_info('Calculating distance ...')
170
            
171
            if self.by_scipy:
172
                D = pdist(dfx.T, metric=metric)
173
                D = np.nan_to_num(D,copy=False)
174
                self.info_distance = D.clip(0, np.inf)
175
            else:
176
                D = calculator.pairwise_distance(dfx.values, n_cpus=n_cpus, method=metric)
177
                D = np.nan_to_num(D,copy=False)
178
                D_ = squareform(D)
179
                self.info_distance = D_.clip(0, np.inf)
180
            
181
        ## statistic info
182
        S = summary.Summary(n_jobs = 10)
183
        res= []
184
        for i in tqdm(range(dfx.shape[1]), ascii=True):
185
            r = S._statistics_one(dfx.values, i)
186
            res.append(r)
187
        dfs = pd.DataFrame(res, index = self.alist)
188
        self.info_scale = dfs
189
        
190
        
191
        
192
    def _fit_embedding(self, 
193
                       dist_matrix,
194
                       emb_method = 'umap', 
195
                       n_components = 2,
196
                       random_state = 32,  
197
                       verbose = 2,
198
                       n_neighbors = 15,
199
                       min_dist = 0.1,
200
                       a = 1.576943460405378,
201
                       b = 0.8950608781227859,
202
                       **kwargs):
203
        
204
        """
205
        parameters
206
        -----------------
207
        dist_matrix: distance matrix to fit
208
        emb_method: {'tsne', 'umap', 'mds'}, algorithm to embedd high-D to 2D
209
        kwargs: the extra parameters for the conresponding algorithm
210
        """
211
        
212
213
        if 'metric' in kwargs.keys():
214
            metric = kwargs.get('metric')
215
            kwargs.pop('metric')
216
            
217
        else:
218
            metric = 'precomputed'
219
220
        if emb_method == 'tsne':
221
            embedded = TSNE(n_components=n_components, 
222
                            random_state=random_state,
223
                            metric = metric,
224
                            verbose = verbose,
225
                            **kwargs)
226
            embedded = embedded.fit(dist_matrix)   
227
            
228
        elif emb_method == 'umap':
229
            embedded = UMAP(n_components = n_components, 
230
                            n_neighbors = n_neighbors,
231
                            min_dist = min_dist,
232
                            a = a,
233
                            b = b,
234
                            verbose = verbose,
235
                            random_state=random_state, 
236
                            metric = metric, **kwargs)
237
            embedded = embedded.fit(dist_matrix)   
238
            
239
        elif emb_method =='mds':
240
            if 'metric' in kwargs.keys():
241
                kwargs.pop('metric')
242
            if 'dissimilarity' in kwargs.keys():
243
                dissimilarity = kwargs.get('dissimilarity')
244
                kwargs.pop('dissimilarity')
245
            else:
246
                dissimilarity = 'precomputed'
247
                
248
            embedded = MDS(metric = True, 
249
                           n_components= n_components,
250
                           verbose = verbose,
251
                           dissimilarity = dissimilarity, 
252
                           random_state = random_state, **kwargs)
253
            embedded = embedded.fit(dist_matrix)          
254
        
255
        elif emb_method == 'random':
256
            embedded = Random_2DEmbedding(random_state=random_state, 
257
                                          n_components=n_components)
258
            embedded = embedded.fit(dist_matrix)   
259
            
260
        elif emb_method == 'isomap':
261
            embedded = Isomap(n_neighbors = n_neighbors,
262
                              n_components=n_components, 
263
                              metric = metric,
264
                              **kwargs)
265
            embedded = embedded.fit(dist_matrix)   
266
            
267
        elif emb_method == 'lle':
268
            embedded = LocallyLinearEmbedding(random_state=random_state, 
269
                                              n_neighbors = n_neighbors,
270
                                              n_components=n_components, 
271
                                              **kwargs)
272
            embedded = embedded.fit(dist_matrix)   
273
            
274
        elif emb_method == 'se':
275
            embedded = SpectralEmbedding(random_state=random_state, 
276
                                          n_neighbors = n_neighbors,
277
                                          n_components=n_components, 
278
                                          affinity = metric,
279
                                          **kwargs)
280
            affinity_matrix = np.exp(-dist_matrix**2)  #make more uniformly embedding  
281
            
282
            embedded = embedded.fit(affinity_matrix)    
283
    
284
        return embedded
285
286
287
    
288
    def fit(self, 
289
            feature_group_list = [],
290
            cluster_channels = 5,
291
            var_thr = -1, 
292
            split_channels = True, 
293
            fmap_shape = None, 
294
            emb_method = 'umap', 
295
            min_dist = 0.1, 
296
            n_neighbors = 15,
297
            a = 1.576943460405378,
298
            b = 0.8950608781227859,
299
            verbose = 2, 
300
            random_state = 32,
301
            group_color_dict  = {},
302
            lnk_method = 'complete',
303
            **kwargs): 
304
        """
305
        parameters
306
        -----------------
307
        feature_group_list: list of the group name for each feature point
308
        cluster_channels: int, number of the channels(clusters) if feature_group_list is empty
309
        var_thr: float, defalt is -1, meaning that feature will be included only if the conresponding variance larger than this value. Since some of the feature has pretty low variances, we can remove them by increasing this threshold
310
        split_channels: bool, if True, outputs will split into various channels using the types of feature
311
        fmap_shape: None or tuple, size of molmap, if None, the size of feature map will be calculated automatically
312
        emb_method: {'umap', 'tsne', 'mds', 'isomap', 'random', 'lle', 'se'}, algorithm to embedd high-D to 2D
313
        min_dist: float, UMAP parameters for the effective minimum distance between embedded points.
314
        n_neighbors: init, UMAP parameters of controlling the embedding. Larger values result in more global views of the manifold, while smaller values result in more local data being preserved.
315
        a: float, UMAP parameters of controlling the embedding. If None, it will automatically be determined by ``min_dist`` and ``spread``.
316
        b: float, UMAP parameters of controlling the embedding. If None, it will automatically be determined by ``min_dist`` and ``spread``.
317
        group_color_dict: dict of the group colors, keys are the group names, values are the colors
318
        lnk_method: {'complete', 'average', 'single', 'weighted', 'centroid'}, linkage method
319
        kwargs: the extra parameters for the conresponding embedding method
320
        """
321
            
322
        if 'n_components' in kwargs.keys():
323
            kwargs.pop('n_components')
324
            
325
            
326
        ## embedding  into a 2d 
327
        assert emb_method in ['tsne', 'umap', 'mds', 'isomap', 'random', 'lle', 'se'], 'No Such Method Supported: %s' % emb_method
328
        
329
        self.feature_group_list = feature_group_list
330
        self.var_thr = var_thr
331
        self.split_channels = split_channels
332
        self.fmap_shape = fmap_shape
333
        self.emb_method = emb_method
334
        self.lnk_method = lnk_method
335
        self.random_state = random_state
336
        
337
        if fmap_shape != None:
338
            assert len(fmap_shape) == 2, "fmap_shape must be a tuple with two elements!"
339
        
340
        # flist and distance
341
        flist = self.info_scale[self.info_scale['var'] > self.var_thr].index.tolist()
342
        
343
        dfd = pd.DataFrame(squareform(self.info_distance),
344
                           index=self.alist,
345
                           columns=self.alist)
346
        dist_matrix = dfd.loc[flist][flist]
347
        self.flist = flist
348
        
349
        self.x_mean = self.info_scale['mean'].values
350
        self.x_std =  self.info_scale['std'].values
351
        
352
        self.x_min = self.info_scale['min'].values
353
        self.x_max = self.info_scale['max'].values
354
     
355
        #bitsinfo
356
        dfb = pd.DataFrame(self.flist, columns = ['IDs'])
357
        if feature_group_list != []:
358
            
359
            self.cluster_flag = False
360
            
361
            assert len(feature_group_list) == len(self.alist), "the length of the input group list is not equal to length of the feature list"
362
            self.cluster_channels = len(set(feature_group_list))
363
            self.feature_group_list_ = feature_group_list
364
            dfb['Subtypes'] = dfb['IDs'].map(pd.Series(feature_group_list, index = self.alist))
365
            
366
            if set(feature_group_list).issubset(set(group_color_dict.keys())):
367
                self.group_color_dict = group_color_dict
368
                dfb['colors'] = dfb['Subtypes'].map(group_color_dict)
369
            else:
370
                unique_types = dfb['Subtypes'].unique()
371
                color_list = sns.color_palette("hsv", len(unique_types)).as_hex()
372
                group_color_dict = dict(zip(unique_types, color_list))
373
                dfb['colors'] = dfb['Subtypes'].map(group_color_dict)
374
                self.group_color_dict = group_color_dict
375
        else:
376
            self.cluster_channels = cluster_channels
377
            print_info('applying hierarchical clustering to obtain group information ...')
378
            self.cluster_flag = True
379
            
380
            Z = linkage(squareform(dist_matrix.values),  lnk_method)
381
            labels = fcluster(Z, cluster_channels, criterion='maxclust')
382
            
383
            feature_group_list_ = ['cluster_%s' % str(i).zfill(2) for i in labels]
384
            dfb['Subtypes'] = feature_group_list_
385
            dfb = dfb.sort_values('Subtypes')
386
            unique_types = dfb['Subtypes'].unique()
387
            
388
            if not set(unique_types).issubset(set(group_color_dict.keys())):
389
                color_list = sns.color_palette("hsv", len(unique_types)).as_hex()
390
                group_color_dict = dict(zip(unique_types, color_list))
391
            
392
            dfb['colors'] = dfb['Subtypes'].map(group_color_dict)
393
            self.group_color_dict = group_color_dict           
394
            self.Z = Z
395
            self.feature_group_list_ = feature_group_list_
396
            
397
398
        self.bitsinfo = dfb
399
        colormaps = dfb.set_index('Subtypes')['colors'].to_dict()
400
        colormaps.update({'NaN': '#000000'})
401
        self.colormaps = colormaps
402
        self._S = Scatter2Grid()
403
404
        ## 2d embedding first
405
        embedded = self._fit_embedding(dist_matrix,
406
                                       emb_method = emb_method,
407
                                       n_neighbors = n_neighbors,
408
                                       random_state = random_state,
409
                                       min_dist = min_dist, 
410
                                       a = a,
411
                                       b = b,
412
                                       verbose = verbose,
413
                                       n_components = 2, **kwargs)
414
        
415
        self.embedded = embedded 
416
        
417
        df = pd.DataFrame(embedded.embedding_, index = self.flist,columns=['x', 'y'])
418
        typemap = self.bitsinfo.set_index('IDs')
419
        df = df.join(typemap)
420
        df['Channels'] = df['Subtypes']
421
        self.df_embedding = df
422
      
423
        ## linear assignment algorithm 
424
        print_info('Applying grid assignment of feature points, this may take several minutes(1~30 min)')
425
        self._S.fit(self.df_embedding, self.split_channels, channel_col = 'Channels')
426
        print_info('Finished')
427
        
428
        ## fit flag
429
        self.isfit = True
430
        if self.fmap_shape == None:
431
            self.fmap_shape = self._S.fmap_shape        
432
        else:
433
            m, n = self.fmap_shape
434
            p, q = self._S.fmap_shape
435
            assert (m >= p) & (n >=q), "fmap_shape's width must >= %s, height >= %s " % (p, q)
436
437
    
438
        self.df_scatter = _get_df_scatter(self)
439
        self.df_grid = _get_df_grid(self)
440
        self.df_grid_reshape = _get_df_grid(self)
441
        self.feature_names_reshape = self.df_grid.v.tolist()
442
        return self
443
        
444
    
445
    def refit_c(self, cluster_channels = 10, lnk_method = 'complete', group_color_dict = {}):
446
        """
447
        re-fit the aggmap object to update the number of channels
448
        
449
        parameters
450
        --------------------
451
        cluster_channels: int, number of the channels(clusters)
452
        group_color_dict: dict of the group colors, keys are the group names, values are the colors
453
        lnk_method: {'complete', 'average', 'single', 'weighted', 'centroid'}, linkage method
454
        """
455
        
456
        if not self.isfit:
457
            print_error('please fit first!')
458
            return
459
            
460
461
        self.split_channels = True
462
        self.lnk_method = lnk_method
463
        self.cluster_channels = cluster_channels
464
        
465
        dfd = pd.DataFrame(squareform(self.info_distance),
466
                           index=self.alist,
467
                           columns=self.alist)
468
        dist_matrix = dfd.loc[self.flist][self.flist]
469
470
        dfb = pd.DataFrame(self.flist, columns = ['IDs'])
471
        print_info('applying hierarchical clustering to obtain group information ...')
472
        self.cluster_flag = True
473
474
        Z = linkage(squareform(dist_matrix.values),  self.lnk_method)
475
        labels = fcluster(Z, self.cluster_channels, criterion='maxclust')
476
477
        feature_group_list_ = ['cluster_%s' % str(i).zfill(2) for i in labels]
478
        dfb['Subtypes'] = feature_group_list_
479
        dfb = dfb.sort_values('Subtypes')
480
        unique_types = dfb['Subtypes'].unique()
481
482
        if not set(unique_types).issubset(set(group_color_dict.keys())):
483
            color_list = sns.color_palette("hsv", len(unique_types)).as_hex()
484
            group_color_dict = dict(zip(unique_types, color_list))
485
486
        dfb['colors'] = dfb['Subtypes'].map(group_color_dict)
487
        self.group_color_dict = group_color_dict           
488
        self.Z = Z
489
        self.feature_group_list_ = feature_group_list_
490
491
        # update self.bitsinfo
492
        self.bitsinfo = dfb
493
        colormaps = dfb.set_index('Subtypes')['colors'].to_dict()
494
        colormaps.update({'NaN': '#000000'})
495
        self.colormaps = colormaps
496
497
        # update self.df_embedding
498
        df = pd.DataFrame(self.embedded.embedding_, index = self.flist,columns=['x', 'y'])
499
        typemap = self.bitsinfo.set_index('IDs')
500
        df = df.join(typemap)
501
        df['Channels'] = df['Subtypes']
502
        self.df_embedding = df
503
504
        ## linear assignment not performed, only refit the channel number 
505
        print_info('skipping grid assignment of feature points, fitting to target channel number')
506
        self._S.refit_c(self.df_embedding)
507
        print_info('Finished')
508
509
        if self.fmap_shape == None:
510
            self.fmap_shape = self._S.fmap_shape        
511
        else:
512
            m, n = self.fmap_shape
513
            p, q = self._S.fmap_shape
514
            assert (m >= p) & (n >=q), "fmap_shape's width must >= %s, height >= %s " % (p, q)
515
516
517
        self.df_scatter = _get_df_scatter(self)
518
        self.df_grid = _get_df_grid(self)
519
        self.df_grid_reshape = _get_df_grid(self)
520
        self.feature_names_reshape = self.df_grid.v.tolist()
521
        return self
522
    
523
    def transform_mpX_to_df(self, X):
524
        '''
525
        input 4D X, output 2D dataframe
526
        '''
527
        n, w,h, c = X.shape
528
        X_2D = X.sum(axis=-1).reshape(n, w*h)
529
        return pd.DataFrame(X_2D, columns = self.feature_names_reshape)    
530
531
    
532
    def transform(self, 
533
                  arr_1d, 
534
                  scale = True, 
535
                  scale_method = 'minmax',
536
                  fillnan = 0):
537
    
538
    
539
        """
540
        parameters
541
        --------------------
542
        arr_1d: 1d numpy array feature points
543
        scale: bool, if True, we will apply MinMax scaling by the precomputed values
544
        scale_method: {'minmax', 'standard'}
545
        fillnan: fill nan value, default: 0
546
        """
547
        
548
        if not self.isfit:
549
            print_error('please fit first!')
550
            return
551
552
        if scale:
553
            if scale_method == 'standard':
554
                arr_1d = self.StandardScaler(arr_1d, self.x_mean, self.x_std)
555
            else:
556
                arr_1d = self.MinMaxScaleClip(arr_1d, self.x_min, self.x_max)
557
        
558
        df = pd.DataFrame(arr_1d).T
559
        df.columns = self.alist
560
561
        df = df[self.flist]
562
        vector_1d = df.values[0] #shape = (N, )
563
        fmap = self._S.transform(vector_1d)  
564
        p, q, c = fmap.shape
565
        
566
        if self.fmap_shape != None:        
567
            m, n = self.fmap_shape
568
            if (m > p) | (n > q):
569
                fps = []
570
                for i in range(c):
571
                    fp = smartpadding(fmap[:,:,i], self.fmap_shape)
572
                    fps.append(fp)
573
                fmap = np.stack(fps, axis=-1)        
574
575
        return np.nan_to_num(fmap, nan = fillnan)   
576
    
577
    
578
579
    
580
    def batch_transform(self, 
581
                        array_2d, 
582
                        scale = True, 
583
                        scale_method = 'minmax',
584
                        n_jobs=4,
585
                        fillnan = 0):
586
    
587
        """
588
        parameters
589
        --------------------
590
        array_2d: 2D numpy array feature points, M(samples) x N(feature ponits)
591
        scale: bool, if True, we will apply MinMax scaling by the precomputed values
592
        scale_method: {'minmax', 'standard'}
593
        n_jobs: number of parallel
594
        fillnan: fill nan value, default: 0
595
        """
596
        
597
        if not self.isfit:
598
            print_error('please fit first!')
599
            return
600
        
601
        assert type(array_2d) == np.ndarray, 'input must be numpy ndarray!' 
602
        assert array_2d.ndim == 2, 'input must be 2-D  numpy array!' 
603
        
604
        P = Parallel(n_jobs=n_jobs)
605
        res = P(delayed(self.transform)(arr_1d, 
606
                                        scale,
607
                                        scale_method,
608
                                        fillnan) for arr_1d in tqdm(array_2d, ascii=True)) 
609
        X = np.stack(res) 
610
        
611
        return X
612
    
613
    
614
    def plot_scatter(self, htmlpath='./', htmlname=None, radius = 2, enabled_data_labels = False):
615
        """Scatter plot, radius: the size of the scatter, must be int"""
616
        H_scatter = vismap.plot_scatter(self,  
617
                                        htmlpath=htmlpath, 
618
                                        htmlname=htmlname,
619
                                        radius = radius,
620
                                        enabled_data_labels = enabled_data_labels)
621
        return H_scatter   
622
        
623
        
624
    def plot_grid(self, htmlpath='./', htmlname=None, enabled_data_labels = False):
625
        """Grid plot"""
626
        
627
        H_grid = vismap.plot_grid(self,
628
                                  htmlpath=htmlpath, 
629
                                  htmlname=htmlname,
630
                                  enabled_data_labels = enabled_data_labels)
631
        return H_grid       
632
        
633
        
634
        
635
    def plot_tree(self, figsize=(16,8), add_leaf_labels = True, leaf_font_size = 18, leaf_rotation = 90):
636
        """Diagram tree plot"""
637
            
638
        fig = plt.figure(figsize=figsize)
639
        
640
        if self.cluster_flag:
641
            
642
            Z = self.Z
643
            
644
645
            D_leaf_colors = self.bitsinfo['colors'].to_dict() 
646
            link_cols = {}
647
            for i, i12 in enumerate(Z[:,:2].astype(int)):
648
                c1, c2 = (link_cols[x] if x > len(Z) else D_leaf_colors[x] for x in i12)
649
                link_cols[i+1+len(Z)] = c1
650
            
651
            if add_leaf_labels:
652
                labels = self.alist
653
            else:
654
                labels = None
655
            P =dendrogram(Z, labels = labels, 
656
                          leaf_rotation = leaf_rotation, 
657
                          leaf_font_size = leaf_font_size, 
658
                          link_color_func=lambda x: link_cols[x])
659
        
660
        return fig
661
        
662
        
663
    def to_nwk_tree(self, treefile = 'mytree', leaf_names = None):
664
        '''
665
        convert mp object to newick tree and the data file to submit to itol sever
666
        '''
667
        return mp2newick(self, treefile = treefile, leaf_names=leaf_names)
668
        
669
        
670
    def copy(self):
671
        """copy self"""
672
        return deepcopy(self)
673
        
674
        
675
    def load(self, filename):
676
        """load self"""
677
        return self._load(filename)
678
    
679
    
680
    def save(self, filename):
681
        """save self"""
682
        return self._save(filename)