Diff of /DL_CV/validation.py [000000] .. [cd187b]

Switch to unified view

a b/DL_CV/validation.py
1
# -*- coding: utf-8 -*-
2
# @Author  : chq_N
3
# @Time    : 2020/8/26
4
5
6
from datetime import datetime
7
8
import numpy as np
9
from scipy.special import softmax
10
from sklearn.metrics import roc_auc_score, roc_curve
11
from sklearn.model_selection import StratifiedKFold
12
13
from model import init_model
14
15
16
def cross_val(X, y, seed):
    """Run one 5-fold stratified cross-validation round with checkpoint selection.

    The data is split into five stratified folds.  For fold ``k`` the fold
    itself is the test set, the next fold (wrapping around to fold 0) is the
    validation set, and the remaining folds are merged into the training set.
    A model is built with ``init_model``, fitted, and then every saved
    checkpoint from iteration 500 to 5000 (step 100) is scored; the checkpoint
    with the highest validation AUC is used for the reported test predictions.

    Args:
        X: Samples; indexed with lists of integer positions (presumably a
            NumPy array -- confirm against the caller).
        y: Labels aligned with ``X``; indexed the same way.
        seed: Value handed to ``np.random.seed`` before the split is made.

    Returns:
        A 4-tuple ``(probs, labels, auc_mean, auc_std)``: the test-set scores
        of all folds concatenated, the matching labels concatenated, the mean
        of the per-fold test AUCs and their sample standard deviation
        (``ddof=1``).
    """
    # No random_state is given to StratifiedKFold, so the shuffle presumably
    # draws from NumPy's global RNG, which is seeded here.
    np.random.seed(seed)
    splitter = StratifiedKFold(n_splits=5, shuffle=True)
    folds = [held_out.tolist() for _, held_out in splitter.split(X, y)]
    n_folds = len(folds)

    fold_probs = []
    fold_labels = []
    fold_aucs = []

    for k, test_idx in enumerate(folds):
        fold = str(k) + '-5'

        # The validation fold is the successor of the test fold, wrapping
        # around to the first fold after the last one.
        val_k = (k + 1) % n_folds
        val_idx = folds[val_k]

        # Everything that is neither test nor validation is training data.
        train_idx = [
            sample
            for other, members in enumerate(folds)
            if other != k and other != val_k
            for sample in members
        ]

        model = init_model(
            fold, X[train_idx], y[train_idx],
            X[val_idx], y[val_idx],
            X[test_idx], y[test_idx])

        print('Training model:', fold)
        started = datetime.now()
        model.fit()
        print('Training Time:', datetime.now() - started)
        print('Testing model:', fold)

        # Sweep the saved checkpoints and remember the one with the best
        # validation AUC together with the test AUC observed at that point.
        # NOTE(review): the meaning of the second ``transform`` argument (24)
        # is defined by the model class and is passed through unchanged.
        best_step = -1
        best_val_auc = 0
        test_auc_at_best = 0
        for step in range(500, 5001, 100):
            model.load_model(step)

            scores, truth = model.transform('test', 24)
            step_test_auc = roc_auc_score(truth, softmax(scores, axis=1)[:, 1])

            scores, truth = model.transform('val', 24)
            step_val_auc = roc_auc_score(truth, softmax(scores, axis=1)[:, 1])

            if step_val_auc > best_val_auc:
                best_step = step
                best_val_auc = step_val_auc
                test_auc_at_best = step_test_auc

        print('Best iter:', best_step, 'Best V auc:', best_val_auc,
              'Corr T auc:', test_auc_at_best)

        # Reload the selected checkpoint and keep its test-set scores
        # (softmax column 1) for the pooled result.
        model.load_model(best_step)
        scores, truth = model.transform('test', 24)
        positive_prob = softmax(scores, axis=1)[:, 1]

        fold_probs.append(positive_prob)
        fold_aucs.append(roc_auc_score(truth, positive_prob))
        fold_labels.append(truth)

    pooled_probs = np.concatenate(fold_probs)
    pooled_labels = np.concatenate(fold_labels)
    return pooled_probs, pooled_labels, np.mean(fold_aucs), np.std(fold_aucs, ddof=1)
72
73
74
def detail_test(features, label, ppv_th=0.7):
    """Repeat cross-validation five times and collect summary metrics.

    ``cross_val`` is run with the seeds 0, 10, 20, 30 and 40.  For each run the
    pooled test predictions are scored with ``roc_auc_score``, resampled with
    ``inter_auc`` and turned into sensitivity / specificity / PPV / NPV at the
    first threshold whose PPV reaches ``ppv_th``.

    Args:
        features: Samples forwarded unchanged to ``cross_val``.
        label: Labels forwarded unchanged to ``cross_val``.
        ppv_th: PPV level at which the threshold scan stops.

    Returns:
        An 8-tuple of lists, one entry per repetition, in this order:
        ``(tpr, auc_all, auc_mean, auc_std, sensitivity, specificity, ppv,
        npv)``.
    """

    def get_sen_spe(pred, label):
        """Scan thresholds 0.000..0.999 and return metrics at the stopping point.

        The scan stops at the first threshold whose PPV is at least
        ``ppv_th``; if none qualifies, the metrics of the last threshold
        (0.999) are returned.
        """
        for step in range(1000):
            cutoff = step / 1000
            flagged = pred > cutoff
            cleared = ~flagged

            tp = ((label == 1) & flagged).sum()
            tn = ((label == 0) & cleared).sum()
            fp = ((label == 0) & flagged).sum()
            fn = ((label == 1) & cleared).sum()

            # The 1e-9 term keeps the ratios defined when a denominator is 0.
            sen = tp / (tp + fn + 1e-9)
            spe = tn / (tn + fp + 1e-9)
            pos_pv = tp / (tp + fp + 1e-9)
            neg_pv = tn / (tn + fn + 1e-9)

            if pos_pv >= ppv_th:
                break
        return sen, spe, pos_pv, neg_pv

    tpr = []
    auc_all = []
    auc_mean = []
    auc_std = []
    sensitivity = []
    specificity = []
    ppv = []
    npv = []

    for seed in range(0, 50, 10):
        y_pred, y_true, run_auc_mean, run_auc_std = cross_val(features, label, seed)

        # AUC over the predictions pooled across all folds of this run.
        pooled_auc = roc_auc_score(y_true, y_pred)
        tpr.append(inter_auc(y_pred, y_true))
        auc_all.append(pooled_auc)
        auc_mean.append(run_auc_mean)
        auc_std.append(run_auc_std)

        run_sen, run_spe, run_ppv, run_npv = get_sen_spe(y_pred, y_true)
        sensitivity.append(run_sen)
        specificity.append(run_spe)
        ppv.append(run_ppv)
        npv.append(run_npv)

    return tpr, auc_all, auc_mean, auc_std, sensitivity, specificity, ppv, npv
117
118
119
def inter_auc(y_pred, y):
    """Return the ROC curve's TPR resampled onto a fixed FPR grid.

    Args:
        y_pred: Predicted scores, passed to ``roc_curve`` as the score array.
        y: True labels, passed to ``roc_curve`` as the label array.

    Returns:
        A NumPy array of 1000 TPR values, linearly interpolated at 1000 evenly
        spaced FPR points from 0 to 1, with the first value forced to 0.0 and
        the last forced to 1.0.
    """
    fpr, tpr, _ = roc_curve(y, y_pred)
    fpr_grid = np.linspace(0, 1, 1000)
    resampled_tpr = np.interp(fpr_grid, fpr, tpr)
    # Pin both ends so every resampled curve runs from (0, 0) to (1, 1).
    resampled_tpr[0], resampled_tpr[-1] = 0.0, 1.0
    return resampled_tpr