a b/deepdta-toy/emetrics.py
1
import subprocess

import numpy as np
2
3
4
def get_aupr(Y, P):
    """Return the area under the precision-recall curve (AUPR) computed by ``auc.jar``.

    Targets are binarised: every ``Y > 0`` becomes label 1, everything else 0.
    Each (score, label) pair is written as a ``"%f %d"`` line to ``temp.txt`` in
    the current working directory, ``java -jar auc.jar temp.txt list`` is run
    with its stdout redirected to ``foo.txt``, and the AUPR is parsed from the
    last token of the second-to-last line of that output.

    Args:
        Y: true values (array-like, or an object exposing a dense ``.A``).
        P: predicted scores with the same number of elements as ``Y``.

    Returns:
        The AUPR reported by ``auc.jar``, as a float.

    Raises:
        ValueError: if ``Y`` and ``P`` contain different numbers of elements.
    """
    # Objects exposing ``.A`` (presumably np.matrix / scipy sparse) are densified.
    if hasattr(Y, 'A'):
        Y = Y.A
    if hasattr(P, 'A'):
        P = P.A
    # asarray lets plain Python lists through (list > 0 / list.ravel would fail).
    Y = np.where(np.asarray(Y) > 0, 1, 0).ravel()
    P = np.asarray(P).ravel()
    if Y.shape[0] != P.shape[0]:
        raise ValueError("Y and P must have the same number of elements")

    with open("temp.txt", 'w') as f:
        f.writelines("%f %d\n" % (p, y) for p, y in zip(P, Y))
    # auc.jar reports on stdout; capture it to foo.txt, then parse it back.
    with open("foo.txt", 'w') as f:
        subprocess.call(["java", "-jar", "auc.jar", "temp.txt", "list"], stdout=f)
    with open("foo.txt") as f:
        lines = f.readlines()
    return float(lines[-2].split()[-1])
22
23
24
25
def get_cindex(Y, P):
    """Return the concordance index (CI) of predictions ``P`` against targets ``Y``.

    Every unordered pair of samples whose targets differ is a comparable pair,
    regardless of the order in which samples appear in the input. A pair scores
    1 if the predictions are ordered the same way as the targets, 0.5 if the
    predictions are tied, and 0 otherwise. The CI is the mean pair score.

    Args:
        Y: true values (1-D array-like; flattened).
        P: predicted values with the same number of elements as ``Y``.

    Returns:
        The CI as a float, or 0 when there is no comparable pair (fewer than
        two samples, or all targets equal).

    Raises:
        ValueError: if ``Y`` and ``P`` contain different numbers of elements.
    """
    Y = np.asarray(Y).ravel()
    P = np.asarray(P).ravel()
    if Y.shape[0] != P.shape[0]:
        raise ValueError("Y and P must have the same number of elements")

    summ = 0.0
    pair = 0
    # Compare sample i against all earlier samples j < i in one vectorised step,
    # so each unordered pair is visited exactly once. Both directions must be
    # counted, otherwise the result depends on the input ordering.
    for i in range(1, Y.shape[0]):
        y_prev = Y[:i]
        p_prev = P[:i]
        cur_larger = Y[i] > y_prev   # target of i is the larger one
        prev_larger = Y[i] < y_prev  # target of the earlier j is the larger one
        pair += np.count_nonzero(cur_larger) + np.count_nonzero(prev_larger)
        # Concordant: the sample with the larger target also has the larger prediction.
        summ += np.count_nonzero(P[i] > p_prev[cur_larger])
        summ += np.count_nonzero(P[i] < p_prev[prev_larger])
        # Tied predictions on a comparable pair count as half-concordant.
        summ += 0.5 * np.count_nonzero((P[i] == p_prev) & (cur_larger | prev_larger))

    if pair != 0:
        return summ / pair
    return 0
41
42
43
def r_squared_error(y_obs, y_pred):
    """Return the squared Pearson correlation r^2 between observed and predicted values.

        r^2 = (sum((p - mean(p)) * (o - mean(o))))^2
              / (sum((o - mean(o))^2) * sum((p - mean(p))^2))

    Args:
        y_obs: observed values (array-like; flattened to 1-D).
        y_pred: predicted values with the same number of elements as ``y_obs``.

    Returns:
        r^2 as a float. If either input has zero variance the denominator is 0
        and the result is nan/inf (numpy emits a RuntimeWarning).
    """
    # Flattening makes (n, 1) column vectors behave like (n,) vectors instead of
    # broadcasting against each other into an (n, n) matrix.
    y_obs = np.asarray(y_obs, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    # Each mean is computed once as a scalar and broadcast.
    obs_dev = y_obs - y_obs.mean()
    pred_dev = y_pred - y_pred.mean()

    mult = np.sum(pred_dev * obs_dev)
    mult = mult * mult

    y_obs_sq = np.sum(obs_dev * obs_dev)
    y_pred_sq = np.sum(pred_dev * pred_dev)

    return mult / (y_obs_sq * y_pred_sq)
56
57
58
def get_k(y_obs, y_pred):
    """Return the zero-intercept least-squares slope k = sum(o * p) / sum(p * p).

    Args:
        y_obs: observed values (array-like).
        y_pred: predicted values (array-like, elementwise-compatible with ``y_obs``).

    Returns:
        The factor ``k`` by which ``y_pred`` is scaled to best fit ``y_obs``
        through the origin.
    """
    observed = np.array(y_obs)
    predicted = np.array(y_pred)
    cross = sum(observed * predicted)
    energy = float(sum(predicted * predicted))
    return cross / energy
63
64
65
def squared_error_zero(y_obs, y_pred):
    """Return r0^2, the coefficient of determination for a fit through the origin.

    The predictions are scaled by ``k = get_k(y_obs, y_pred)`` and

        r0^2 = 1 - sum((o - k * p)^2) / sum((o - mean(o))^2)

    Args:
        y_obs: observed values (array-like; flattened to 1-D).
        y_pred: predicted values with the same number of elements as ``y_obs``.

    Returns:
        r0^2 as a float. If ``y_obs`` has zero variance the denominator is 0
        and the result is nan/-inf (numpy emits a RuntimeWarning).
    """
    # Flatten first so (n, 1) inputs don't broadcast into an (n, n) matrix and
    # so get_k returns a scalar slope.
    y_obs = np.asarray(y_obs, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    k = get_k(y_obs, y_pred)

    residual = y_obs - k * y_pred
    obs_dev = y_obs - y_obs.mean()  # scalar mean, computed once
    upp = np.sum(residual * residual)
    down = np.sum(obs_dev * obs_dev)

    return 1 - (upp / down)
75
76
77
def get_rm2(ys_orig, ys_line):
    """Return the modified squared correlation coefficient rm^2.

    With ``r2 = r_squared_error(ys_orig, ys_line)`` and
    ``r02 = squared_error_zero(ys_orig, ys_line)`` this returns
    ``r2 * (1 - sqrt(|r2**2 - r02**2|))``.

    NOTE(review): the commonly cited rm^2 definition is
    ``r^2 * (1 - sqrt(|r^2 - r0^2|))``, i.e. without squaring r2 and r02 a
    second time. The existing formula is kept here because changing it alters
    every reported rm^2 value; confirm which one is intended before editing.
    """
    r2 = r_squared_error(ys_orig, ys_line)
    r02 = squared_error_zero(ys_orig, ys_line)
    gap = np.absolute(r2 * r2 - r02 * r02)
    return r2 * (1 - np.sqrt(gap))