Switch to unified view

a b/aggmap/aggmodel/xAI/perturb.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Tue Feb  2 14:54:38 2021
4
5
@author: wanxiang.shen@u.nus.edu
6
"""
7
8
9
import numpy as np
10
import pandas as pd
11
12
from tqdm import tqdm
13
from copy import copy
14
15
from aggmap.utils.matrixopt import conv2
16
from sklearn.metrics import mean_squared_error, log_loss
17
from sklearn.preprocessing import StandardScaler
18
19
20
21
22
def islice(lst, n):
    '''
    Split a sliceable sequence into consecutive chunks of at most ``n`` items.

    Parameters
    ----------
    lst : sequence supporting ``len()`` and slicing (list, str, numpy array, ...)
    n : int, chunk size; must be a positive integer

    Returns
    -------
    list of slices of ``lst``; every chunk has ``n`` items except possibly
    the last one, which holds the remainder. An empty input gives ``[]``.

    Raises
    ------
    ValueError: if ``n`` is not positive. A zero step would otherwise surface
        as an opaque ``range()`` error and a negative one would silently
        return an empty list, dropping all the data.
    '''
    if n <= 0:
        raise ValueError('chunk size n must be a positive integer, got %r' % (n,))
    return [lst[i:i + n] for i in range(0, len(lst), n)]
24
25
26
def GetGlobalIMP(model, mp, arrX, dfY, task_type = 'classification',
                sigmoidy = False,
                apply_logrithm = False,
                apply_smoothing = False,
                kernel_size = 5,
                sigma = 1.6,
                skip_columns = ('Healthy',)):
    '''
    Forward-propagation (perturbation) feature importance over a whole dataset.

    For every cell of the feature-map grid, that cell is overwritten with the
    global minimum of ``arrX`` in all samples and all channels, the model is
    re-evaluated, and the increase of the loss with respect to the unperturbed
    prediction is recorded. A positive increase means the cell matters.

    Parameters
    ----------
    model : object exposing ``predict(X)``; the result is indexed as
        ``[:, k]``, one column per column of ``dfY``.
    mp : map object exposing ``plot_grid()``, ``df_grid`` (a DataFrame with
        ``'y'`` and ``'x'`` columns) and ``_S.fmap_shape``.
    arrX : 4-D array; axis 0 indexes samples, axes 1 and 2 are addressed by
        the grid's ``y`` and ``x`` values, the last axis is fully overwritten.
    dfY : DataFrame of targets, one column per output.
    task_type : ``'classification'`` uses ``log_loss``; any other value uses
        ``mean_squared_error``.
    sigmoidy : if true and ``task_type == 'classification'``, a sigmoid is
        applied to every model output before the loss is computed.
    apply_logrithm : if true, the natural log of the loss changes is taken
        before scaling (non-finite results are replaced, see below).
    apply_smoothing : if true, the scaled map is passed through ``conv2``
        with ``kernel_size`` and ``sigma``.
    kernel_size, sigma : forwarded to ``conv2``.
    skip_columns : iterable of ``dfY`` column names for which no importance
        is computed. Defaults to ``('Healthy',)``, the column the original
        implementation skipped unconditionally; pass ``()`` to keep all.

    Returns
    -------
    DataFrame: ``mp.df_grid`` sorted by ``['y', 'x']`` joined with one
    ``'<column>_importance'`` column per processed ``dfY`` column. The scores
    are standardised with ``StandardScaler``.
    '''

    f = log_loss if task_type == 'classification' else mean_squared_error

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # Plain boolean logic instead of bitwise '&': a truthy non-bool flag
    # (e.g. 2) must not be masked to False.
    use_sigmoid = bool(sigmoidy) and task_type == 'classification'
    skip_columns = () if skip_columns is None else skip_columns

    scaler = StandardScaler()
    # Return value was never used by the original code; the call itself is
    # kept in case callers rely on whatever it does (TODO confirm it is needed).
    mp.plot_grid()
    Y_true = dfY.values
    df_grid = mp.df_grid.sort_values(['y', 'x']).reset_index(drop=True)
    Y_prob = model.predict(arrX)
    N, _W, _H, _C = arrX.shape  # also enforces that arrX is 4-D
    T = len(df_grid)
    nX = 20  # number of perturbed copies of arrX sent to model.predict at once
    vmin = arrX.min()  # fill value; computed once instead of once per grid cell

    # Read the coordinates column-wise so they keep the columns' own dtype.
    ys = df_grid['y'].values
    xs = df_grid['x'].values

    if use_sigmoid:
        Y_prob = sigmoid(Y_prob)

    final_res = {}
    for k, col in enumerate(dfY.columns):
        if col in skip_columns:
            continue  # omit this feature imp

        print('calculating feature importance for %s ...' % col)

        y_true_k = Y_true[:, k].tolist()
        loss = f(y_true_k, Y_prob[:, k].tolist())

        results = []
        pending = []
        for i in tqdm(range(T), ascii=True):

            ## step 1: perturb one grid cell in a fresh copy of the data
            X1 = np.array(arrX)
            X1[:, ys[i], xs[i], :] = vmin
            pending.append(X1)

            # Flush on a full batch or at the last cell. Counting the pending
            # copies keeps every batch at exactly nX.
            if len(pending) == nX or i == T - 1:
                ## step 2: predict the whole batch at once (one call per
                ## perturbed copy would be much slower)
                Y_pred_prob = model.predict(np.concatenate(pending))
                if use_sigmoid:
                    Y_pred_prob = sigmoid(Y_pred_prob)

                ## step 3: loss change per perturbed copy (N rows each);
                ## > 0 means important, otherwise not important
                for Y_pred in islice(Y_pred_prob, N):
                    mut_loss = f(y_true_k, Y_pred[:, k].tolist())
                    results.append(mut_loss - loss)

                pending = []

        ## step 4: apply scaling or smoothing
        s = np.asarray(results, dtype=float).reshape(-1, 1)
        if apply_logrithm:
            # log of a non-positive change gives nan / -inf; replaced below
            s = np.log(s)
        finite = s[np.isfinite(s)]
        if finite.size:
            smin, smax = finite.min(), finite.max()
        else:
            # nothing finite to borrow a fill value from
            smin = smax = 0.0
        s = np.nan_to_num(s, nan=smin, posinf=smax, neginf=smin)  # fillna with smin
        a = scaler.fit_transform(s)
        a = a.reshape(*mp._S.fmap_shape)
        if apply_smoothing:
            a = conv2(a, kernel_size=kernel_size, sigma=sigma)

        # Build the final column name here so non-string column labels and an
        # empty result set do not break the renaming step.
        final_res[str(col) + '_importance'] = a.reshape(-1,).tolist()

    df = pd.DataFrame(final_res)
    df = df_grid.join(df)
    return df
122
123
124
125
def GetLocalIMP(model, mp, arrX, dfY,
                    task_type = 'classification',
                    sigmoidy = False,
                    apply_logrithm = False,
                    apply_smoothing = False,
                    kernel_size = 3, sigma = 1.2):
    '''
    Forward-propagation (perturbation) feature importance for a single sample.

    Every cell of the feature-map grid is overwritten, one at a time, with the
    minimum of ``arrX`` across all channels. All perturbed copies are predicted
    in one ``model.predict`` call and the loss increase relative to the
    unperturbed prediction is the raw importance of that cell.

    Parameters
    ----------
    model : object exposing ``predict(X)``.
    mp : map object exposing ``plot_grid()``, ``df_grid`` (a DataFrame with
        ``'y'`` and ``'x'`` columns) and ``_S.fmap_shape``.
    arrX : 4-D array holding exactly one sample on axis 0; axes 1 and 2 are
        addressed by the grid's ``y`` and ``x`` values.
    dfY : DataFrame with the target(s) of that sample; all values are
        flattened and compared with the flattened model output.
    task_type : ``'classification'`` uses ``log_loss``; any other value uses
        ``mean_squared_error``.
    sigmoidy : if true and ``task_type == 'classification'``, a sigmoid is
        applied to every model output before the loss is computed.
    apply_logrithm : if true, the natural log of the loss changes is taken
        before scaling (non-finite results are replaced by the finite
        minimum / maximum).
    apply_smoothing : if true, the scaled map is passed through ``conv2``
        with ``kernel_size`` and ``sigma``.
    kernel_size, sigma : forwarded to ``conv2``.

    Returns
    -------
    DataFrame: ``mp.df_grid`` sorted by ``['y', 'x']`` joined with an ``'imp'``
    column holding the ``StandardScaler``-standardised scores.

    Raises
    ------
    AssertionError: if ``arrX`` does not contain exactly one sample.
    '''

    # Raised explicitly (same exception type as the former assert) so the
    # check is not stripped when Python runs with -O.
    if len(arrX) != 1:
        raise AssertionError('each for only one image!')

    f = log_loss if task_type == 'classification' else mean_squared_error

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # Plain boolean logic instead of bitwise '&' on the flags.
    use_sigmoid = bool(sigmoidy) and task_type == 'classification'

    scaler = StandardScaler()

    # Return value was never used by the original code; the call itself is
    # kept in case callers rely on whatever it does (TODO confirm it is needed).
    mp.plot_grid()
    Y_true = dfY.values
    df_grid = mp.df_grid.sort_values(['y', 'x']).reset_index(drop=True)
    Y_prob = model.predict(arrX)
    N, _W, _H, _C = arrX.shape  # also enforces that arrX is 4-D

    if use_sigmoid:
        Y_prob = sigmoid(Y_prob)

    y_true_flat = Y_true.ravel().tolist()  # loop-invariant, computed once
    loss = f(y_true_flat, Y_prob.ravel().tolist())

    vmin = arrX.min()  # fill value; computed once instead of once per grid cell
    ys = df_grid['y'].values
    xs = df_grid['x'].values

    # One perturbed copy of the sample per grid cell.
    all_X1 = []
    for i in tqdm(range(len(df_grid)), ascii=True):
        X1 = np.array(arrX)
        X1[:, ys[i], xs[i], :] = vmin
        all_X1.append(X1)

    all_X = np.concatenate(all_X1)
    all_Y_pred_prob = model.predict(all_X)
    if use_sigmoid:
        # element-wise, so applying it once equals applying it row by row
        all_Y_pred_prob = sigmoid(all_Y_pred_prob)

    results = []
    for Y_pred_prob in all_Y_pred_prob:
        mut_loss = f(y_true_flat, Y_pred_prob.ravel().tolist())
        # > 0 means important, otherwise not important
        results.append(mut_loss - loss)

    ## apply smoothing and scaling
    s = np.asarray(results, dtype=float).reshape(-1, 1)
    if apply_logrithm:
        # log of a non-positive change gives nan / -inf; replaced below
        s = np.log(s)
    finite = s[np.isfinite(s)]
    if finite.size:
        smin, smax = finite.min(), finite.max()
    else:
        # nothing finite to borrow a fill value from
        smin = smax = 0.0
    s = np.nan_to_num(s, nan=smin, posinf=smax, neginf=smin)  # fillna with smin
    a = scaler.fit_transform(s)
    a = a.reshape(*mp._S.fmap_shape)
    if apply_smoothing:
        a = conv2(a, kernel_size=kernel_size, sigma=sigma)
    results = a.reshape(-1,).tolist()

    df = pd.DataFrame(results, columns = ['imp'])
    df = df_grid.join(df)
    return df