a b/aggmap/utils/matrixopt.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
"""
4
Created on Sun Aug 25 20:29:36 2019
5
6
@author: wanxiang.shen@u.nus.edu
7
8
matrix operation
9
10
"""
11
12
import numpy as np
13
from lapjv import lapjv
14
from scipy.signal import convolve2d
15
from scipy.spatial.distance import cdist
16
17
18
class Scatter2Grid:
    """Assign scattered 2D points to distinct cells of a regular grid.

    ``fit`` solves a linear assignment problem between the points and the
    first ``len(df)`` cells of a near-square grid; ``transform`` then writes
    a 1D feature vector into that grid, optionally one plane per channel.
    """

    def __init__(self):
        """assign x,y coords to grid numpy array"""
        self.fmap_shape = None
        self.indices = None
        self.indices_list = None
        # Populated by fit(); declared here so the full set of attributes
        # is visible in one place.
        self.channels = None
        self.channel_col = None
        self.split_channels = None
        self.df = None

    def _split_by_channel(self, df):
        """Store, per channel, the row positions and their grid cells.

        Sets ``self.channels`` (group keys in groupby order) and
        ``self.indices_list`` (one ``{'idx': [...], 'indices': [...]}``
        dict per channel, in the same order).
        """
        channels = []
        indices_list = []
        for channel, group in df.groupby(self.channel_col):
            channels.append(channel)
            indices_list.append(group[['idx', 'indices']].to_dict('list'))
        self.channels = channels
        self.indices_list = indices_list

    def fit(self, df, split_channels = True, channel_col = 'Channels'):
        """
        parameters
        ------------------
        df: dataframe with x, y columns; it is modified in place, gaining
            an 'idx' column (row position) and an 'indices' column
            (assigned flat grid cell)
        split_channels: bool, if True, will apply split by group
        channel_col: column in df.columns, split to groups by this col

        raises
        ------------------
        ValueError: if df has no rows
        """
        n_points = len(df)
        if n_points == 0:
            raise ValueError("df must contain at least one point")

        df['idx'] = range(n_points)
        embedding_2d = df[['x', 'y']].values

        # Smallest near-square grid holding at least n_points cells.
        n_rows = int(np.ceil(np.sqrt(n_points)))
        n_cols = int(np.ceil(n_points / n_rows))
        grid_size = (n_rows, n_cols)

        # Cell coordinates in the unit square, flattened row by row.
        grid = np.dstack(np.meshgrid(np.linspace(0, 1, n_cols),
                                     np.linspace(0, 1, n_rows))).reshape(-1, 2)
        grid_map = grid[:n_points]

        cost_matrix = cdist(grid_map, embedding_2d, "sqeuclidean").astype(float)
        # Rescale so the largest cost is 100000; skip when every cost is 0,
        # where the division would turn the whole matrix into NaN.
        max_cost = cost_matrix.max()
        if max_cost > 0:
            cost_matrix = cost_matrix * (100000 / max_cost)

        # NOTE(review): col_asses is used as "grid cell assigned to each
        # point" -- confirm against the return convention of the installed
        # lapjv version.
        row_asses, col_asses, _ = lapjv(cost_matrix)

        self.row_asses = row_asses
        self.col_asses = col_asses
        self.fmap_shape = grid_size
        self.indices = col_asses

        self.channel_col = channel_col
        self.split_channels = split_channels
        df['indices'] = self.indices
        self.df = df

        if self.split_channels:
            self._split_by_channel(df)

    def refit_c(self, df):
        """
        Rebuild the channel split for a new dataframe while reusing the
        grid assignment (``self.indices``) computed by a previous ``fit``.

        parameters
        ------------------
        df: dataframe containing the channel column chosen in ``fit``;
            it is modified in place ('idx' and 'indices' columns are set)
        """
        df['idx'] = range(len(df))
        df['indices'] = self.indices
        self.df = df

        if self.split_channels:
            self._split_by_channel(df)

    def transform(self, vector_1d):
        """
        Scatter a 1D feature vector onto the fitted grid.

        parameters
        ------------------
        vector_1d: extracted features, one value per fitted point
                   (any 1D array-like)

        returns
        ------------------
        array of shape (M, N, C) where (M, N) is ``self.fmap_shape`` and
        C is the number of channels (1 when channels are not split);
        cells without an assigned point stay 0
        """
        ### linear assignment map ###
        M, N = self.fmap_shape
        # Lists cannot be indexed with a list of positions; arrays can.
        values = np.asarray(vector_1d)

        if self.split_channels:
            planes = []
            for idict in self.indices_list:
                flat = np.zeros(M * N)
                flat[idict['indices']] = values[idict['idx']]
                planes.append(flat.reshape(M, N))
            return np.stack(planes, axis=-1)

        flat = np.zeros(M * N)
        flat[self.indices] = values
        return flat.reshape(M, N, 1)
118
    
119
120
    
121
class Scatter2Array:
    """Rasterise scattered 2D points onto a fixed-size array.

    Each point is snapped to the nearest node of a regular lattice spanning
    the fitted x/y range. x selects the column and y selects the row of the
    ``fmap_shape`` array; points that snap to the same cell overwrite each
    other (the last one wins).
    """

    def __init__(self, fmap_shape = (128,128)):
        """convert x,y coords to numpy array"""
        self.fmap_shape = fmap_shape
        self.indices = None
        self.indices_list = None

    def _fit(self, df):
        """df: dataframe with x, y columns

        Builds the lattice coordinates. The flat index used by
        ``_transform`` is ``N*row + col`` with y as the row, so there must
        be N x-nodes (columns) and M y-nodes (rows). For square shapes this
        equals the previous behaviour; for non-square shapes it keeps every
        index inside the M*N buffer.
        """
        M, N = self.fmap_shape
        self.X = np.linspace(df.x.min(), df.x.max(), N)
        self.Y = np.linspace(df.y.min(), df.y.max(), M)

    def _transform(self, dfnew):
        """dfnew: dataframe with x, y columns
           in case we need to split channels

        Returns a list with one flat (row-major) cell index per row of
        dfnew.
        """
        x = dfnew.x.values
        y = dfnew.y.values
        _, N = self.fmap_shape
        # Nearest lattice node along each axis (absolute distance); argmin
        # keeps the first node on ties, as a sequential scan would.
        idx = np.abs(self.X[np.newaxis, :] - x[:, np.newaxis]).argmin(axis=1)
        idy = np.abs(self.Y[np.newaxis, :] - y[:, np.newaxis]).argmin(axis=1)
        return (N * idy + idx).tolist()

    def fit(self, df, split_channels = True, channel_col = 'Channels'):
        """
        parameters
        ---------------
        df: embedding_df, dataframe with x, y columns; it is modified in
            place, gaining an 'idx' column holding the row position
        split_channels: bool, if True, will apply split by group
        channel_col: column in df.columns, split to groups by this col
        """
        df['idx'] = range(len(df))
        self.df = df
        self.channel_col = channel_col
        self.split_channels = split_channels
        self._idx_list = None
        self._fit(df)

        if self.split_channels:
            channels = []
            indices_list = []
            idx_list = []
            for channel, group in df.groupby(channel_col):
                channels.append(channel)
                indices_list.append(self._transform(group))
                idx_list.append(group.idx.tolist())
            self.channels = channels
            self.indices_list = indices_list
            # Row positions per channel, cached so transform() does not
            # have to filter the dataframe on every call.
            self._idx_list = idx_list
        else:
            self.indices = self._transform(df)

    def transform(self, vector_1d):
        """
        Scatter a 1D feature vector onto the fitted array.

        parameters
        ---------------
        vector_1d: feature values 1d array (any 1D array-like), one value
                   per fitted point

        returns
        ---------------
        array of shape (M, N, C) where (M, N) is ``self.fmap_shape`` and
        C is the number of channels (1 when channels are not split);
        cells no point maps to stay 0
        """
        M, N = self.fmap_shape
        # Lists cannot be indexed with a list of positions; arrays can.
        values = np.asarray(vector_1d)

        if self.split_channels:
            idx_list = getattr(self, '_idx_list', None)
            if idx_list is None:
                # No cache on this instance (its attributes were set
                # without going through this fit()): derive the row
                # positions from the stored dataframe once.
                channel_values = self.df[self.channel_col]
                idx_list = [self.df.idx[channel_values == channel].tolist()
                            for channel in self.channels]
                self._idx_list = idx_list

            planes = []
            for indices, idx in zip(self.indices_list, idx_list):
                flat = np.zeros(M * N)
                flat[indices] = values[idx]
                planes.append(flat.reshape(M, N))
            return np.stack(planes, axis=-1)

        flat = np.zeros(M * N)
        flat[self.indices] = values
        return flat.reshape(M, N, 1)
200
201
202
def smartpadding(array, target_size, mode='constant', constant_values=0):
    """
    Pad a 2d array on all sides so that it reaches ``target_size``.

    When the padding along an axis is odd, the extra row goes to the top
    and the extra column goes to the right.

    array: 2d array to be padded
    target_size: tuple of target array's shape
    mode: padding mode forwarded to ``np.pad``
    constant_values: fill value, used only when ``mode`` is 'constant'
                     (or a callable, which receives it as a keyword)

    raises ValueError if ``target_size`` is smaller than ``array.shape``
    along any axis
    """
    X, Y = array.shape
    M, N = target_size
    if M < X or N < Y:
        raise ValueError(
            "target_size %s is smaller than array shape %s"
            % ((M, N), (X, Y)))

    top = int(np.ceil((M - X) / 2))
    bottom = int(M - X - top)
    right = int(np.ceil((N - Y) / 2))
    left = int(N - Y - right)

    # np.pad rejects ``constant_values`` for every string mode except
    # 'constant', so only forward it where it is accepted.
    extra = {}
    if callable(mode) or mode == 'constant':
        extra['constant_values'] = constant_values

    return np.pad(array,
                  pad_width=[(top, bottom), (left, right)],
                  mode=mode,
                  **extra)
219
220
221
def fspecial_gauss(size = 31, sigma = 2):
    """Build a normalised 2D Gaussian kernel, in the spirit of MATLAB's
    'fspecial' gaussian.

    size: side length of the square kernel; should be an odd value so the
          peak sits on the centre cell
    sigma: standard deviation of the Gaussian, in cells

    Returns a (size, size) array whose entries sum to 1.
    """
    lo = -size//2 + 1
    hi = size//2 + 1
    # Open grid: a column of row offsets and a row of column offsets that
    # broadcast to the full square of squared distances.
    rows, cols = np.ogrid[lo:hi, lo:hi]
    sq_dist = rows**2 + cols**2
    weights = np.exp(-(sq_dist/(2.0*sigma**2)))
    return weights/weights.sum()
229
230
231
def conv2(array, kernel_size = 31, sigma = 2,  mode='same', fillvalue = 0):
    """Smooth a 2d array with a Gaussian kernel.

    array: 2d array to be convolved
    kernel_size, sigma: forwarded to fspecial_gauss to build the kernel
    mode, fillvalue: forwarded to scipy.signal.convolve2d

    Both the input and the kernel are rotated by 180 degrees before the
    convolution and the result is rotated back by 180 degrees.
    """
    gauss = fspecial_gauss(kernel_size, sigma)
    turned_input = np.rot90(array, 2)
    turned_kernel = np.rot90(gauss, 2)
    turned_result = convolve2d(turned_input, turned_kernel,
                               mode=mode,
                               fillvalue = fillvalue)
    return np.rot90(turned_result, 2)