a b/aggmap/_devmap.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
"""
4
Created on Sun Aug 25 20:29:36 2019
5
6
@author: wanxiang.shen@u.nus.edu
7
8
main aggmap code
9
10
11
"""
12
from aggmap.utils.logtools import print_info, print_warn, print_error
13
from aggmap.utils.matrixopt import Scatter2Grid, Scatter2Array, smartpadding 
14
from aggmap.utils import vismap, summary, calculator
15
16
from sklearn.cluster import AgglomerativeClustering
17
from sklearn.manifold import TSNE, MDS
18
from sklearn.utils import shuffle
19
from joblib import Parallel, delayed, load, dump
20
from scipy.spatial.distance import squareform
21
from scipy.cluster.hierarchy import fcluster, linkage, dendrogram
22
import matplotlib.pylab as plt
23
import seaborn as sns
24
from umap import UMAP
25
from tqdm import tqdm
26
from copy import copy
27
import pandas as pd
28
import numpy as np
29
import os
30
31
32
class Base:
    """Shared helpers for map objects: joblib persistence and element-wise scaling."""

    def __init__(self):
        # Stateless base; subclasses create all their own attributes.
        pass

    def _save(self, filename):
        """Serialise this whole object to ``filename`` with joblib and return joblib's result."""
        return dump(self, filename)

    def _load(self, filename):
        """Deserialise and return the object stored in ``filename``.

        ``self`` is not modified. joblib uses pickle, so only load files from
        trusted sources.
        """
        return load(filename)

    def MinMaxScaleClip(self, x, xmin, xmax):
        """Min-max scale ``x`` using precomputed ``xmin``/``xmax``.

        A 1e-8 term is added to the range to avoid division by zero for
        constant features. Despite the name, the result is NOT clipped:
        values outside ``[xmin, xmax]`` map outside ``[0, 1]``.
        """
        span = (xmax - xmin) + 1e-8
        return (x - xmin) / span

    def StandardScaler(self, x, xmean, xstd):
        """Standardise ``x`` with precomputed mean/std (1e-8 guards a zero std)."""
        centered = x - xmean
        return centered / (xstd + 1e-8)
50
    
51
52
    
53
def _cluster_model2linkage_matrix(model):    
54
    counts = np.zeros(model.children_.shape[0])
55
    n_samples = len(model.labels_)
56
    for i, merge in enumerate(model.children_):
57
        current_count = 0
58
        for child_idx in merge:
59
            if child_idx < n_samples:
60
                current_count += 1  # leaf node
61
            else:
62
                current_count += counts[child_idx - n_samples]
63
        counts[i] = current_count
64
65
    linkage_matrix = np.column_stack([model.children_, model.distances_,
66
                                      counts]).astype(float)
67
    return linkage_matrix
68
69
70
class AggMap(Base):
    """Aggregate feature points into 2-D feature maps.

    ``__init__`` computes pairwise feature distances, per-feature statistics
    and an agglomerative-clustering linkage. ``fit`` embeds the features in
    2-D and assigns them to a grid (or scatter array). ``transform`` and
    ``batch_transform`` turn sample vectors into feature maps.

    Note: t-SNE initialize method should be changed into 'pca': https://www.nature.com/articles/s41587-020-00809-z

    >>> mp = AggMap(dfx)
    >>> mp.fit(emb_method = 'tsne', init = 'pca')

    NOTE(review): scikit-learn's TSNE rejects ``init='pca'`` together with
    ``metric='precomputed'``. That is the metric ``_fit_embedding`` uses
    unless one is passed through ``kwargs``, so confirm the example above
    against the installed scikit-learn version.
    """

    def __init__(self,
                 dfx,
                 metric='correlation',
                 n_cpus=16,
                 ):
        """
        Parameters
        -----------------
        dfx: pandas DataFrame, one column per feature point (rows are samples)
        metric: {'cosine', 'correlation', 'euclidean', 'jaccard', 'hamming', 'dice'}, default: 'correlation', measurement of feature distance
        n_cpus: int, default: 16, number of workers passed to ``calculator.pairwise_distance``
        """
        assert isinstance(dfx, pd.DataFrame), 'input dfx mush be pandas DataFrame!'
        super().__init__()

        self.metric = metric
        self.isfit = False
        self.alist = dfx.columns.tolist()
        self.ftype = 'feature points'

        # Take the array once; DataFrame.values may copy on every access.
        X = dfx.values

        ## calculating distance
        print_info('Calculating distance ...')
        D = calculator.pairwise_distance(X, n_cpus=n_cpus, method=metric)
        D = np.nan_to_num(D, copy=False)
        # Stored in condensed form; fit() expands it back with squareform().
        self.info_distance = squareform(D).clip(0, np.inf)

        ## statistic info (fit() reads the 'var', 'mean', 'std', 'min', 'max' columns)
        S = summary.Summary(n_jobs=10)
        res = [S._statistics_one(X, i) for i in tqdm(range(X.shape[1]), ascii=True)]
        self.info_scale = pd.DataFrame(res, index=self.alist)

        print_info('Applying the Agglomerative Clustering ...')
        # Clusters the feature vectors (columns of dfx) with the model defaults;
        # used by fit() when lnk_method == 'ward' or groups are user-supplied.
        cluster_model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
        cluster_model.fit(X.T)
        self._intrinsic_Z = _cluster_model2linkage_matrix(cluster_model)

    def _fit_embedding(self,
                       dist_matrix,
                       method='umap',
                       n_components=2,
                       random_state=32,
                       verbose=2,
                       n_neighbors=15,
                       min_dist=0.1,
                       **kwargs):
        """Fit a low-dimensional embedding of the features.

        Parameters
        -----------------
        dist_matrix: distance matrix to fit
        method: {'tsne', 'umap', 'mds'}, algorithm to embed high-D into 2D
        kwargs: extra parameters for the corresponding algorithm. ``metric``
            defaults to 'precomputed' for tsne/umap and is ignored by mds;
            ``dissimilarity`` (mds only) defaults to 'precomputed'.

        Returns
        -----------------
        The fitted estimator; callers read its ``embedding_`` attribute.

        Raises
        -----------------
        ValueError: if ``method`` is not supported.
        """
        # dist_matrix already holds distances, hence 'precomputed' by default.
        metric = kwargs.pop('metric', 'precomputed')

        if method == 'tsne':
            embedded = TSNE(n_components=n_components,
                            random_state=random_state,
                            metric=metric,
                            verbose=verbose,
                            **kwargs)
        elif method == 'umap':
            embedded = UMAP(n_components=n_components,
                            n_neighbors=n_neighbors,
                            min_dist=min_dist,
                            verbose=verbose,
                            random_state=random_state,
                            metric=metric, **kwargs)
        elif method == 'mds':
            dissimilarity = kwargs.pop('dissimilarity', 'precomputed')
            embedded = MDS(metric=True,
                           n_components=n_components,
                           verbose=verbose,
                           dissimilarity=dissimilarity,
                           random_state=random_state, **kwargs)
        else:
            raise ValueError('No Such Method Supported: %s' % method)

        return embedded.fit(dist_matrix)

    @staticmethod
    def _resolve_group_colors(dfb, group_color_dict):
        """Return ``group_color_dict`` if it covers every subtype in ``dfb``,
        otherwise a new dict mapping each subtype (in ``dfb`` order) to an
        evenly spaced 'hsv' colour."""
        unique_types = dfb['Subtypes'].unique()
        if not set(unique_types).issubset(set(group_color_dict.keys())):
            color_list = sns.color_palette("hsv", len(unique_types)).as_hex()
            group_color_dict = dict(zip(unique_types, color_list))
        return group_color_dict

    def fit(self,
            feature_group_list=None,
            cluster_channels=5,
            var_thr=-1,
            split_channels=True,
            fmap_type='grid',
            fmap_shape=None,
            emb_method='umap',
            min_dist=0.1,
            n_neighbors=15,
            verbose=2,
            random_state=32,
            group_color_dict=None,
            lnk_method='ward',
            **kwargs):
        """
        Parameters
        -----------------
        feature_group_list: list of the group name for each feature point (same length
            and order as the DataFrame columns); None or empty means groups are derived
            by hierarchical clustering
        cluster_channels: int, number of the channels (clusters) if feature_group_list is empty
        var_thr: float, default is -1; a feature is included only if its variance is strictly
            larger than this value. Raise it to drop low-variance features
        split_channels: bool, if True, outputs will be split into channels by feature group
        fmap_type: {'scatter', 'grid'}, default: 'grid'; if 'scatter', returns a scatter map
            without an assignment to a grid
        fmap_shape: None or tuple of two ints. For 'scatter' it sets the array size (when None:
            (2*floor(sqrt(N)), 2*floor(sqrt(N))) for N kept features). For either type,
            transform() pads maps up to this shape, which must be at least the fitted shape
        emb_method: {'tsne', 'umap', 'mds'}, algorithm to embed high-D into 2D
        group_color_dict: dict of the group colors, keys are the group names, values are the colors
        lnk_method: {'ward', 'complete', 'average', 'single'}, linkage method; 'ward' reuses the
            linkage computed in __init__
        kwargs: extra parameters for the corresponding embedding method (n_components is ignored)

        Returns
        -----------------
        self
        """
        # The map is always 2-D; a user-supplied n_components would conflict.
        kwargs.pop('n_components', None)

        ## embedding into a 2d
        assert emb_method in ['tsne', 'umap', 'mds'], 'No Such Method Supported: %s' % emb_method
        assert fmap_type in ['scatter', 'grid'], 'No Such Feature Map Type Supported: %s' % fmap_type
        self.var_thr = var_thr
        self.split_channels = split_channels
        self.fmap_type = fmap_type
        self.fmap_shape = fmap_shape
        self.emb_method = emb_method
        self.lnk_method = lnk_method

        if fmap_shape is not None:
            assert len(fmap_shape) == 2, "fmap_shape must be a tuple with two elements!"

        if group_color_dict is None:
            group_color_dict = {}

        # flist and distance: keep only features whose variance exceeds var_thr
        flist = self.info_scale[self.info_scale['var'] > self.var_thr].index.tolist()
        dfd = pd.DataFrame(squareform(self.info_distance),
                           index=self.alist,
                           columns=self.alist)
        dist_matrix = dfd.loc[flist, flist]
        self.flist = flist

        # Precomputed per-feature statistics used by transform() for scaling.
        self.x_mean = self.info_scale['mean'].values
        self.x_std = self.info_scale['std'].values
        self.x_min = self.info_scale['min'].values
        self.x_max = self.info_scale['max'].values

        # bitsinfo: one row per feature (all features, not only flist)
        dfb = pd.DataFrame(self.alist, columns=['IDs'])
        has_groups = feature_group_list is not None and len(feature_group_list) > 0
        if has_groups:
            assert len(feature_group_list) == len(self.alist), "the length of the input group list is not equal to length of the feature list"
            self.Z = self._intrinsic_Z
            self.cluster_channels = len(set(feature_group_list))
            self.feature_group_list = feature_group_list
            # list() avoids index alignment if a pandas Series is passed.
            dfb['Subtypes'] = list(feature_group_list)
        else:
            self.cluster_channels = cluster_channels
            print_info('applying hierarchical clustering to obtain group information ...')
            if lnk_method != 'ward':
                # info_distance is already the condensed distance vector linkage() expects.
                Z = linkage(self.info_distance, lnk_method)
            else:
                Z = self._intrinsic_Z

            labels = fcluster(Z, cluster_channels, criterion='maxclust')
            feature_group_list = ['cluster_%s' % str(i).zfill(2) for i in labels]
            dfb['Subtypes'] = feature_group_list
            dfb = dfb.sort_values('Subtypes')
            self.Z = Z
            self.feature_group_list = feature_group_list

        group_color_dict = self._resolve_group_colors(dfb, group_color_dict)
        dfb['colors'] = dfb['Subtypes'].map(group_color_dict)
        self.group_color_dict = group_color_dict

        self.bitsinfo = dfb
        colormaps = dfb.set_index('Subtypes')['colors'].to_dict()
        colormaps.update({'NaN': '#000000'})
        self.colormaps = colormaps

        if fmap_type == 'grid':
            S = Scatter2Grid()
        else:
            if fmap_shape is None:
                N = len(self.flist)
                # np._int does not exist (and np.int was removed in NumPy 1.24);
                # the builtin int truncates the same way.
                l = int(np.sqrt(N)) * 2
                fmap_shape = (l, l)
            S = Scatter2Array(fmap_shape)

        self._S = S

        ## 2d embedding first
        embedded = self._fit_embedding(dist_matrix,
                                       method=emb_method,
                                       n_neighbors=n_neighbors,
                                       random_state=random_state,
                                       min_dist=min_dist,
                                       verbose=verbose,
                                       n_components=2, **kwargs)
        self.embedded = embedded

        df = pd.DataFrame(embedded.embedding_, index=self.flist, columns=['x', 'y'])
        typemap = self.bitsinfo.set_index('IDs')
        df = df.join(typemap)
        df['Channels'] = df['Subtypes']
        self.df_embedding = df

        if self.fmap_type == 'scatter':
            ## naive scatter algorithm
            print_info('Applying naive scatter feature map...')
        else:
            ## linear assignment algorithm
            print_info('Applying grid feature map(assignment), this may take several minutes(1~30 min)')
        self._S.fit(self.df_embedding, self.split_channels, channel_col='Channels')
        print_info('Finished')

        ## fit flag
        self.isfit = True
        if self.fmap_shape is None:
            self.fmap_shape = self._S.fmap_shape
        else:
            m, n = self.fmap_shape
            p, q = self._S.fmap_shape
            assert (m >= p) & (n >= q), "fmap_shape's width must >= %s, height >= %s " % (p, q)
        return self

    def transform(self,
                  arr_1d,
                  scale=True,
                  scale_method='minmax',
                  fillnan=0):
        """Transform one sample vector into a feature map.

        Parameters
        --------------------
        arr_1d: 1d numpy array of feature points, ordered like the fitted DataFrame columns
        scale: bool, if True, scale with the precomputed statistics using ``scale_method``
        scale_method: {'minmax', 'standard'}; any other value falls back to min-max
        fillnan: fill nan value, default: 0

        Returns
        --------------------
        ndarray of shape (height, width, channels), padded up to ``self.fmap_shape``
        when that is larger; None (after printing an error) if not fitted.
        """
        if not self.isfit:
            print_error('please fit first!')
            return

        if scale:
            if scale_method == 'standard':
                arr_1d = self.StandardScaler(arr_1d, self.x_mean, self.x_std)
            else:
                arr_1d = self.MinMaxScaleClip(arr_1d, self.x_min, self.x_max)

        df = pd.DataFrame(arr_1d).T
        df.columns = self.alist
        df = df[self.flist]
        vector_1d = df.values[0]  # shape = (N, )
        fmap = self._S.transform(vector_1d)
        p, q, c = fmap.shape

        if self.fmap_shape is not None:
            m, n = self.fmap_shape
            if (m > p) | (n > q):
                fmap = np.stack([smartpadding(fmap[:, :, i], self.fmap_shape)
                                 for i in range(c)], axis=-1)

        return np.nan_to_num(fmap, nan=fillnan)

    def batch_transform(self,
                        array_2d,
                        scale=True,
                        scale_method='minmax',
                        n_jobs=4,
                        fillnan=0):
        """Transform many samples in parallel with ``transform``.

        Parameters
        --------------------
        array_2d: 2D numpy array of feature points, M (samples) x N (feature points)
        scale: bool, if True, scale with the precomputed statistics using ``scale_method``
        scale_method: {'minmax', 'standard'}
        n_jobs: number of parallel joblib workers
        fillnan: fill nan value, default: 0

        Returns
        --------------------
        ndarray stacking one feature map per sample along a new first axis;
        None (after printing an error) if not fitted.
        """
        if not self.isfit:
            print_error('please fit first!')
            return

        assert type(array_2d) == np.ndarray, 'input must be numpy ndarray!'
        assert array_2d.ndim == 2, 'input must be 2-D  numpy array!'

        P = Parallel(n_jobs=n_jobs)
        res = P(delayed(self.transform)(arr_1d,
                                        scale,
                                        scale_method,
                                        fillnan) for arr_1d in tqdm(array_2d, ascii=True))
        return np.stack(res)

    def plot_scatter(self, htmlpath='./', htmlname=None, radius=2, enabled_data_labels=False):
        """Plot the 2-D embedding via ``vismap.plot_scatter``.

        radius: the size of the scatter, must be int.
        Stores the returned DataFrame on ``self.df_scatter`` and returns the
        second value from vismap (presumably the chart object; confirm).
        """
        df_scatter, H_scatter = vismap.plot_scatter(self,
                                                    htmlpath=htmlpath,
                                                    htmlname=htmlname,
                                                    radius=radius,
                                                    enabled_data_labels=enabled_data_labels)
        self.df_scatter = df_scatter
        return H_scatter

    def plot_grid(self, htmlpath='./', htmlname=None, enabled_data_labels=False):
        """Plot the grid assignment via ``vismap.plot_grid``.

        Returns None without plotting when ``fmap_type`` is not 'grid';
        otherwise stores the returned DataFrame on ``self.df_grid`` and returns
        the second value from vismap.
        """
        if self.fmap_type != 'grid':
            return

        df_grid, H_grid = vismap.plot_grid(self,
                                           htmlpath=htmlpath,
                                           htmlname=htmlname,
                                           enabled_data_labels=enabled_data_labels)
        self.df_grid = df_grid
        return H_grid

    def plot_tree(self, figsize=(16, 8), add_leaf_labels=True, leaf_font_size=18, leaf_rotation=90):
        """Draw the feature dendrogram of ``self.Z`` and return the matplotlib figure.

        Each link takes the colour of its first child; leaf colours come from
        ``self.bitsinfo['colors']`` (keyed by original feature position).
        """
        fig = plt.figure(figsize=figsize)

        Z = self.Z
        n_merges = len(Z)
        D_leaf_colors = self.bitsinfo['colors'].to_dict()
        link_cols = {}
        for i, i12 in enumerate(Z[:, :2].astype(int)):
            # Node ids > n_merges are earlier merges; smaller ids are leaves.
            c1, _ = (link_cols[x] if x > n_merges else D_leaf_colors[x] for x in i12)
            link_cols[i + 1 + n_merges] = c1

        labels = self.alist if add_leaf_labels else None

        dendrogram(Z, labels=labels,
                   leaf_rotation=leaf_rotation,
                   leaf_font_size=leaf_font_size,
                   link_color_func=lambda x: link_cols[x])

        return fig

    def copy(self):
        """Return a shallow copy of this map."""
        return copy(self)

    def load(self, filename):
        """Load and return a saved object from ``filename`` (``self`` is not modified).

        Uses joblib (pickle); only load trusted files.
        """
        return self._load(filename)

    def save(self, filename):
        """Save this map to ``filename`` with joblib."""
        return self._save(filename)