Switch to unified view

a b/aggmap/aggmodel/explainer.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Fri Sep. 17 17:10:53 2021
4
5
@author: wanxiang.shen@u.nus.edu
6
"""
7
8
9
import numpy as np
10
import pandas as pd
11
12
from tqdm import tqdm
13
from copy import copy
14
import shap
15
16
17
from sklearn.metrics import mean_squared_error, log_loss
18
from sklearn.preprocessing import StandardScaler
19
20
from aggmap.utils.matrixopt import conv2
21
from aggmap.utils.logtools import print_warn, print_info
22
23
24
25
class shapley_explainer:
    """Kernel SHAP based model explanation.

    The limitations of this approach are discussed in
    https://christophm.github.io/interpretable-ml-book/shapley.html#disadvantages-16
    and in <Problems with Shapley-value-based explanations as feature
    importance measures>. SHAP values do not identify causality. The feature
    importance reported by this class is the mean absolute SHAP value of each
    feature over the explained samples, computed separately per model output.

    Parameters
    ----------
    estimator:
        model with a predict or predict_proba method. It must also expose the
        training feature maps as ``X_`` and a ``name`` attribute.
    mp:
        aggmap object
    backgroud: string or int
        {'min', 'global_min', 'all', int}.
        if 'min', then use the per-feature min value as the background data
        (equals to 1 sample)
        if 'global_min', then use the min value of all data as the background
        data (1 sample).
        if int, then sample K samples as the background data
        any other value ('all') uses all of the train data as the background
        data for shap.
    k_means_sampling: bool,
        whether to use k-means to summarise the background values
        (``shap.kmeans``) instead of sampling them (``shap.sample``); only
        used when ``backgroud`` is an int.
    link :
        {"identity", "logit"}. A generalized linear model link to connect the
        feature importance values to the model output. Since the feature
        importance values, phi, sum up to the model output, it often makes
        sense to connect them to the output with a link function where
        link(output) = sum(phi). If the model output is a probability then the
        LogitLink link function makes the feature importance values have
        log-odds units.
    args:
        Other parameters for shap.KernelExplainer.

    Examples
    --------
    >>> import seaborn as sns
    >>> from aggmap.aggmodel.explainer import shapley_explainer
    >>> ## shapley explainer
    >>> shap_explainer = shapley_explainer(estimator, mp)
    >>> global_imp_shap = shap_explainer.global_explain(clf.X_)
    >>> local_imp_shap = shap_explainer.local_explain(clf.X_[[0]])
    >>> ## S-map of shapley explainer
    >>> sns.heatmap(local_imp_shap.shapley_importance_class_1.values.reshape(mp.fmap_shape),
    >>> cmap = 'rainbow')
    >>> ## shapley plot
    >>> shap.summary_plot(shap_explainer.shap_values,
    >>> feature_names = shap_explainer.feature_names) # #global  plot_type='bar
    >>> shap.initjs()
    >>> shap.force_plot(shap_explainer.explainer.expected_value[1],
    >>> shap_explainer.shap_values[1], feature_names = shap_explainer.feature_names)

    """

    def __init__(self, estimator, mp, backgroud = 'min', k_means_sampling = True, link='identity', **args):
        '''
        Build the background data and the underlying shap.KernelExplainer.

        See the class docstring for the meaning of the parameters.
        '''
        self.estimator = estimator
        self.mp = mp
        self.link = link
        self.backgroud = backgroud
        self.k_means_sampling = k_means_sampling

        train_features = self.covert_mpX_to_shapely_df(self.estimator.X_)

        # Accept numpy integers as well as Python ints; bool is excluded so
        # that True/False keep falling through to the non-integer branches.
        wants_k_samples = (isinstance(backgroud, (int, np.integer))
                           and not isinstance(backgroud, bool))

        if wants_k_samples:
            k = int(backgroud)
            if self.k_means_sampling:
                self.backgroud_data = shap.kmeans(train_features, k)
            else:
                self.backgroud_data = shap.sample(train_features, k)
        elif backgroud == 'min':
            # one background row holding the per-feature minimum
            self.backgroud_data = train_features.min().to_frame().T.values
        elif backgroud == 'global_min':
            # one background row filled with the overall minimum
            gmin = train_features.min().min()
            self.backgroud_data = np.full(shape=(1, train_features.shape[1]),
                                          fill_value = gmin)
        else:
            # 'all': every training sample is a background sample
            self.backgroud_data = train_features

        self.explainer = shap.KernelExplainer(self._predict_warpper, self.backgroud_data, link=self.link, **args)
        self.feature_names = train_features.columns.tolist() # mp.alist

    def _predict_warpper(self, X):
        '''
        Prediction function handed to shap.KernelExplainer.

        The 2D feature matrix is passed through ``mp.batch_transform`` (with
        scale=False) before calling the estimator: ``predict`` for the
        regression estimator, ``predict_proba`` otherwise.
        '''
        X_new = self.mp.batch_transform(X, scale=False)
        if self.estimator.name == 'AggMap Regression Estimator': # case regression task
            return self.estimator.predict(X_new)
        return self.estimator.predict_proba(X_new)

    @staticmethod
    def _split_by_output(shap_values):
        '''
        Return a list with one (n_samples, n_features) array per model output.

        A list of arrays is passed through unchanged. A single 2D ndarray is
        treated as one output, and a 3D ndarray is split along its last axis
        (assumed layout: samples x features x outputs, as documented for
        recent shap releases -- TODO confirm against the installed version).
        '''
        if isinstance(shap_values, np.ndarray):
            if shap_values.ndim == 3:
                return [shap_values[..., j] for j in range(shap_values.shape[-1])]
            if shap_values.ndim == 2:
                return [shap_values]
        return list(shap_values)

    def get_shap_values(self, X, nsamples = 'auto', **args):
        '''
        Compute the SHAP values of X and the per-feature importance table.

        Parameters
        ----------
        X: 4D array, where the shape is (n, w, h, c)
        nsamples: {'auto', int}
            passed to ``KernelExplainer.shap_values``
        args: other parameters of ``KernelExplainer.shap_values``

        Returns
        -------
        The object returned by ``KernelExplainer.shap_values`` (also stored as
        ``self.shap_values``). The importance table, i.e. the grid of
        ``mp.df_grid_reshape`` joined with one ``shapley_importance_class_<i>``
        column per output, is stored as ``self.df_imp``.
        '''
        df_to_explain = self.covert_mpX_to_shapely_df(X)
        shap_values = self.explainer.shap_values(df_to_explain, nsamples=nsamples, **args)

        all_imps = []
        for i, data in enumerate(self._split_by_output(shap_values)):
            name = 'shapley_importance_class_' + str(i)
            # mean absolute SHAP value per feature over the explained samples
            imp = abs(pd.DataFrame(data, columns = self.feature_names)).mean().to_frame(name = name)
            all_imps.append(imp)

        df_reshape = self.mp.df_grid_reshape.set_index('v')
        df_reshape.index = self.mp.feature_names_reshape
        # grid cells without a matching feature get an importance of 0
        df_imp = df_reshape.join(pd.concat(all_imps, axis=1)).fillna(0)
        self.df_imp = df_imp
        self.shap_values = shap_values
        return shap_values

    def local_explain(self, X=None, idx=0, nsamples = 'auto', **args):
        '''
        Explanation of one sample only.

        Parameters
        ----------
        X: None or 4D array, where the shape is (1, w, h, c)
           the 4D array of AggMap multi-channel fmaps.
           Note that if X is None, estimator.X_[[idx]] is used instead, namely the first training sample if idx=0
        idx: int
            index of the training sample to explain when X is None
        nsamples: {'auto', int}
            Number of times to re-evaluate the model when explaining each prediction.
            More samples lead to lower variance estimates of the SHAP values. The "auto" setting uses nsamples = 2 * X.shape[1] + 2048
        args: other parameters in the shap_values method of shap.KernelExplainer

        Returns
        -------
        The importance DataFrame ``self.df_imp``; the raw SHAP values are
        available as ``self.shap_values``.
        '''
        if X is None:
            print_info('Explaining the first sample only')
            # the fitted model is stored as self.estimator (there is no self.clf)
            X = self.estimator.X_[[idx]]
        assert len(X.shape) == 4, 'input X mush a 4D array: (1, w, h, c)'
        assert len(X) == 1,  'Input X must has one sample only, but got %s' % len(X)

        # get_shap_values stores both self.shap_values and self.df_imp
        self.get_shap_values(X, nsamples = nsamples, **args)
        return self.df_imp

    def global_explain(self, X=None, nsamples = 'auto', **args):
        '''
        Explanation of many samples.

        Parameters
        ----------
        X: None or 4D array, where the shape is (n, w, h, c)
           the 4D array of AggMap multi-channel fmaps.
           Note that if X is None, estimator.X_ is used instead, namely the training set of the estimator
        nsamples: {'auto', int}
            Number of times to re-evaluate the model when explaining each prediction.
            More samples lead to lower variance estimates of the SHAP values. The "auto" setting uses nsamples = 2 * X.shape[1] + 2048
        args: other parameters in the shap_values method of shap.KernelExplainer

        Returns
        -------
        The importance DataFrame ``self.df_imp``; the raw SHAP values are
        available as ``self.shap_values``.
        '''
        if X is None:
            # the fitted model is stored as self.estimator (there is no self.clf)
            X = self.estimator.X_
            print_info('Explaining the whole samples of the training Set')
        assert len(X.shape) == 4, 'input X mush a 4D array: (n, w, h, c)'

        # get_shap_values stores both self.shap_values and self.df_imp
        self.get_shap_values(X, nsamples = nsamples, **args)
        return self.df_imp

    def _covert_x_2D(self, X, feature_names):
        '''
        Flatten (n, w, h, c) feature maps into an (n, w*h) DataFrame.

        The channel axis is summed, then each w x h map is flattened in row
        order; ``feature_names`` labels the resulting w*h columns.
        '''
        n, w,h, c = X.shape
        assert len(feature_names) == w*h, 'length of feature_names should be w*h of X.shape (n, w, h,c)'
        X_2D = X.sum(axis=-1).reshape(n, w*h)
        return pd.DataFrame(X_2D, columns = feature_names)

    def covert_mpX_to_shapely_df(self, X):
        '''
        Convert (n, w, h, c) feature maps into the 2D table given to SHAP.

        The result has one row per sample and one column per entry of
        ``mp.alist`` (in that order). Grid cells whose name is not in
        ``mp.alist`` are dropped; names in ``mp.alist`` that are missing from
        the grid are filled with 0.
        '''
        dfx_stack_reshape = self._covert_x_2D(X, feature_names = self.mp.feature_names_reshape)
        shapely_df = pd.DataFrame(index=self.mp.alist).join(dfx_stack_reshape.T).T
        shapely_df = shapely_df.fillna(0)
        return shapely_df
215
216
    
217
    
218
219
class simply_explainer:

    """Simply-explainer for model explanation.

    Each grid cell of the feature map is replaced in turn by its background
    value; the importance of the cell is the resulting change of the loss
    (mean squared error for the regression estimator, log loss otherwise).

    Parameters
    ----------
    estimator: object
        model with a predict or predict_proba method
    mp: object
        aggmap object
    backgroud: {'min', 'global_min', 'zeros'}, default: 'min'.
        if 'min', then use the min value of a vector of the training set,
        if 'zeros' (the alias 'zero' is accepted too), then use all zeros as
        the background data,
        any other value ('global_min') uses the min value of all training set.
    apply_logrithm: bool, default: False
        apply the logarithm to the feature importance score
    apply_smoothing: bool, default: False
        apply the gaussian smoothing on the feature importance score (Saliency map)
    kernel_size: int, default: 5.
        the kernel size for the smoothing
    sigma: float, default: 1.0.
        the sigma for the smoothing.

    Examples
    --------
    >>> import seaborn as sns
    >>> from aggmap.aggmodel.explainer import simply_explainer
    >>> simp_explainer = simply_explainer(estimator, mp)
    >>> global_imp_simp = simp_explainer.global_explain(clf.X_, clf.y_)
    >>> local_imp_simp = simp_explainer.local_explain(clf.X_[[0]], clf.y_[[0]])
    >>> ## S-map of simply explainer
    >>> sns.heatmap(local_imp_simp.simply_importance.values.reshape(mp.fmap_shape),
    >>> cmap = 'rainbow')

    """

    def __init__(self,
                 estimator,
                 mp,
                 backgroud = 'min',
                 apply_logrithm = False,
                 apply_smoothing = False,
                 kernel_size = 5,
                 sigma = 1.
                ):
        '''
        Store the settings and prepare the per-cell background values.

        See the class docstring for the meaning of the parameters.
        '''
        self.estimator = estimator
        self.mp = mp
        self.apply_logrithm = apply_logrithm
        self.apply_smoothing = apply_smoothing
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.backgroud = backgroud

        n_cells = len(mp.df_grid_reshape)
        if backgroud == 'min':
            self.backgroud_data = mp.transform_mpX_to_df(self.estimator.X_).min().values
        elif backgroud in ('zeros', 'zero'):
            # the documentation historically advertised 'zero' while the code
            # only recognised 'zeros'; both spellings are honoured
            self.backgroud_data = np.zeros(shape=(n_cells, ))
        else:
            # 'global_min' (also the fallback for any other value)
            gmin = self.estimator.X_.min()
            self.backgroud_data = np.full(shape=(n_cells, ),
                                          fill_value = gmin)

        self.scaler = StandardScaler()

        df_grid = mp.df_grid_reshape.set_index('v')
        df_grid.index = self.mp.feature_names_reshape
        self.df_grid = df_grid

        # loss used to score a prediction against the true labels
        if self.estimator.name == 'AggMap Regression Estimator':
            self._f = mean_squared_error
        else:
            self._f = log_loss

    def _sigmoid(self, x):
        '''Logistic function 1 / (1 + exp(-x)).'''
        return 1 / (1 + np.exp(-x))

    def _islice(self, lst, n):
        '''Split ``lst`` into consecutive chunks of length ``n`` (the last one may be shorter).'''
        return [lst[i:i + n] for i in range(0, len(lst), n)]

    def _postprocess_scores(self, scores):
        '''
        Turn raw per-cell loss changes into the final importance list.

        Steps: optional logarithm, replacement of NaN / infinite entries by
        the smallest (NaN, -inf) or largest (+inf) finite score, scaling with
        ``self.scaler``, reshaping to ``mp.fmap_shape`` and optional gaussian
        smoothing. The result is returned as a flat list.
        '''
        s = pd.DataFrame(scores).values
        if self.apply_logrithm:
            # non-positive scores produce NaN / -inf here on purpose; they are
            # replaced right below, so the numpy warnings are silenced
            with np.errstate(divide='ignore', invalid='ignore'):
                s = np.log(s)

        finite = s[np.isfinite(s)]
        # fall back to 0 when no finite score exists, so that the scaler
        # never receives NaN
        smin = finite.min() if finite.size else 0.0
        smax = finite.max() if finite.size else 0.0
        s = np.nan_to_num(s, nan=smin, posinf=smax, neginf=smin) #fillna with smin

        a = self.scaler.fit_transform(s)
        a = a.reshape(*self.mp.fmap_shape)
        if self.apply_smoothing:
            a = conv2(a, kernel_size=self.kernel_size, sigma=self.sigma)
        return a.reshape(-1,).tolist()

    def global_explain(self, X=None, y=None):
        '''
        Explanation of many samples.

        Parameters
        ----------
        X: None or 4D array, where the shape is (n, w, h, c)
           the 4D array of AggMap multi-channel fmaps
        y: None or 2D array, where the shape is (n, class_num)
           the True label
        Note that if X is None, estimator.X_ and estimator.y_ are used
        instead, namely the training set of the estimator is explained.

        Returns
        -------
        DataFrame: the grid table joined with one
        ``simply_importance_class_<col>`` column per column of y.

        Raises
        ------
        ValueError
            if X is given without y.
        '''

        if X is None:
            X = self.estimator.X_
            y = self.estimator.y_
            print_info('Explaining the whole samples of the training Set')
        elif y is None:
            # without labels no loss can be computed and the result would
            # silently contain no importance column at all
            raise ValueError('y must be provided together with X')

        assert len(X.shape) == 4, 'input X mush a 4D array: (n, w, h, c)'
        N = X.shape[0]

        dfY = pd.DataFrame(y)
        Y_true = y
        Y_prob = self.estimator._model.predict(X, verbose = 0)

        is_multilabel = self.estimator.name == 'AggMap MultiLabels Estimator'
        if is_multilabel:
            Y_prob = self._sigmoid(Y_prob)

        T = len(self.df_grid)
        batch_cells = 20 # number of perturbed copies of X predicted per call

        final_res = {}
        # NOTE: the perturbed inputs do not depend on the class, they are
        # nevertheless rebuilt and re-predicted for every class.
        for k, col in enumerate(dfY.columns):
            print_info('calculating feature importance for class %s ...' % col)
            results = []
            loss = self._f(Y_true[:, k].tolist(), Y_prob[:, k].tolist())

            pending = []
            for i in tqdm(range(T), ascii= True):
                ts = self.df_grid.iloc[i]
                gy, gx = ts.y, ts.x
                ## step 1: replace one grid cell by its background value
                X1 = np.array(X)
                X1[:, gy, gx, :] = self.backgroud_data[i]
                pending.append(X1)

                if len(pending) == batch_cells or i == T - 1:
                    ## step 2: make predictions (batched, one by one is not efficient)
                    Y_pred_prob = self.estimator._model.predict(np.concatenate(pending), verbose = 0)
                    if is_multilabel:
                        Y_pred_prob = self._sigmoid(Y_pred_prob)

                    ## step 3: calculate changes, one chunk of N rows per cell
                    for Y_pred in self._islice(Y_pred_prob, N):
                        mut_loss = self._f(Y_true[:, k].tolist(), Y_pred[:, k].tolist())
                        # if > 0, important, otherwise, not important
                        results.append(mut_loss - loss)
                    pending = []

            ## step 4: apply scaling or smoothing
            final_res[col] = self._postprocess_scores(results)

        df = pd.DataFrame(final_res, index = self.mp.feature_names_reshape)
        df.columns = df.columns.astype(str)
        df.columns = 'simply_importance_class_' + df.columns
        df = self.df_grid.join(df)
        return df

    def local_explain(self, X=None, y=None, idx=0):
        '''
        Explanation of one sample only.

        Parameters
        ----------
        X: None or 4D array, where the shape is (1, w, h, c)
        y: the True label, None or 2D array, where the shape is (1, class_num).
        idx: int,
             index of the sample to interpret
             Note that if X is None, estimator.X_[[idx]] and estimator.y_[[idx]] are used instead, namely the first training sample if idx=0.

        Returns
        -------
        DataFrame: the grid table joined with a ``simply_importance`` column.

        Raises
        ------
        ValueError
            if X is given without y.
        '''

        if X is None:
            X = self.estimator.X_[[idx]]
            y = self.estimator.y_[[idx]]
            print_info('Explaining the one sample of the training Set')
        elif y is None:
            raise ValueError('y must be provided together with X')

        assert len(X.shape) == 4, 'input X mush a 4D array: (1, w, h, c)'
        assert (len(X) == 1) & (len(y) == 1), 'Input X, y must have one sample only, but got %s, %s' % (len(X), len(y))

        Y_true = y
        Y_prob = self.estimator._model.predict(X, verbose = 0)

        is_multilabel = self.estimator.name == 'AggMap MultiLabels Estimator'
        if is_multilabel:
            Y_prob = self._sigmoid(Y_prob)

        y_true_flat = Y_true.ravel().tolist()
        loss = self._f(y_true_flat, Y_prob.ravel().tolist())

        # one perturbed copy of the sample per grid cell
        all_X1 = []
        for i in tqdm(range(len(self.df_grid)), ascii= True):
            ts = self.df_grid.iloc[i]
            gy, gx = ts.y, ts.x
            X1 = np.array(X)
            X1[:, gy, gx, :] = self.backgroud_data[i]
            all_X1.append(X1)

        all_Y_pred_prob = self.estimator._model.predict(np.concatenate(all_X1), verbose = 0)
        if is_multilabel:
            all_Y_pred_prob = self._sigmoid(all_Y_pred_prob)

        results = []
        for Y_pred_prob in all_Y_pred_prob:
            mut_loss = self._f(y_true_flat, Y_pred_prob.ravel().tolist())
            # if > 0, important, otherwise, not important
            results.append(mut_loss - loss)

        ## apply smoothing and scaling
        results = self._postprocess_scores(results)

        df = pd.DataFrame(results,
                          index = self.mp.feature_names_reshape,
                          columns = ['simply_importance'])
        df = self.df_grid.join(df)
        return df
495
496
    
497
498
def _demo(estimator, mp):
    '''
    Model explanation using two methods: simply explainer and shapley explainer.

    Parameters
    ----------
    estimator:
        fitted AggMap estimator; its training data is read from
        ``estimator.X_`` / ``estimator.y_``
        (NOTE(review): the original snippet read the data from an undefined
        ``clf``; it is assumed to be the same fitted object -- confirm).
    mp:
        the aggmap object that produced the feature maps
    '''

    import seaborn as sns

    ## simply explainer
    simp_explainer = simply_explainer(estimator, mp)
    global_imp_simp = simp_explainer.global_explain(estimator.X_, estimator.y_)
    local_imp_simp = simp_explainer.local_explain(estimator.X_[[0]], estimator.y_[[0]])

    ## S-map of simply explainer
    sns.heatmap(local_imp_simp.simply_importance.values.reshape(mp.fmap_shape), cmap = 'rainbow')

    ## shapley explainer
    shap_explainer = shapley_explainer(estimator, mp)
    global_imp_shap = shap_explainer.global_explain(estimator.X_)
    local_imp_shap = shap_explainer.local_explain(estimator.X_[[0]])

    ## S-map of shapley explainer
    sns.heatmap(local_imp_shap.shapley_importance_class_1.values.reshape(mp.fmap_shape), cmap = 'rainbow')

    ## shapley plot
    shap.summary_plot(shap_explainer.shap_values, feature_names = shap_explainer.feature_names) # #global  plot_type='bar
    shap.initjs()
    shap.force_plot(shap_explainer.explainer.expected_value[1], shap_explainer.shap_values[1], feature_names = shap_explainer.feature_names)

    return global_imp_simp, local_imp_simp, global_imp_shap, local_imp_shap


if __name__ == '__main__':
    # The demo needs a fitted estimator and an aggmap object, neither of which
    # can be built here; fail with an explicit message rather than a NameError.
    raise SystemExit('This module has no standalone entry point: '
                     'call _demo(estimator, mp) with a fitted AggMap estimator and its aggmap object.')