a b/Retrieval/utils/metrics.py
1
import numpy as np
2
3
4
def RSE(pred, true):
    """Return the root relative squared error of ``pred`` with respect to ``true``.

    This is the Euclidean norm of the residual ``true - pred`` divided by the
    Euclidean norm of ``true``'s deviation from its global (all-element) mean.
    A value of 1.0 means the prediction is no closer than predicting that
    mean everywhere. If ``true`` is constant the denominator is zero and
    NumPy yields ``inf``/``nan``.
    """
    residual_norm = np.sqrt(np.sum((true - pred) ** 2))
    # Global mean (no axis) -- the baseline is a single scalar for the whole array.
    spread_norm = np.sqrt(np.sum((true - true.mean()) ** 2))
    return residual_norm / spread_norm
6
7
8
def CORR(pred, true):
    """Return the Pearson correlation between ``pred`` and ``true``, averaged.

    The correlation is computed along axis 0, with each position on the
    remaining axes treated as a separate series. The per-series
    coefficients are then averaged over the last axis of that result.
    Inputs are assumed to be at least 2-D with samples on axis 0
    (e.g. ``(samples, features)``) -- TODO confirm against callers.

    A series whose ``true`` or ``pred`` values are constant along axis 0 has
    zero variance. For that series the division yields ``nan``/``inf``
    (NumPy semantics), which propagates into the mean.
    """
    true_dev = true - true.mean(0)
    pred_dev = pred - pred.mean(0)
    u = (true_dev * pred_dev).sum(0)
    # Pearson denominator: product of the two root-sums-of-squares.
    # (The previous form sqrt(sum(a**2 * b**2)) is not a correlation and was
    # not bounded to [-1, 1]; e.g. CORR(x, x) != 1.)
    d = np.sqrt((true_dev ** 2).sum(0) * (pred_dev ** 2).sum(0))
    return (u / d).mean(-1)
12
13
14
def MAE(pred, true):
    """Return the mean absolute error between ``pred`` and ``true``.

    The mean is taken over every element (no axis), so the result is a scalar.
    """
    absolute_error = np.abs(pred - true)
    return np.mean(absolute_error)
16
17
18
def MSE(pred, true):
    """Return the mean squared error between ``pred`` and ``true``.

    The mean is taken over every element (no axis), so the result is a scalar.
    """
    residual = pred - true
    return np.mean(residual ** 2)
20
21
22
def RMSE(pred, true):
    """Return the root mean squared error: the square root of ``MSE(pred, true)``."""
    mean_squared = MSE(pred, true)
    return np.sqrt(mean_squared)
24
25
26
def MAPE(pred, true):
    """Return the mean absolute percentage error of ``pred`` relative to ``true``.

    The value is a fraction, not scaled by 100. Each error is divided by the
    corresponding ``true`` value, so any zero in ``true`` produces ``inf``/``nan``
    in the result.
    """
    relative_error = (pred - true) / true
    return np.mean(np.abs(relative_error))
28
29
30
def MSPE(pred, true):
    """Return the mean squared percentage error of ``pred`` relative to ``true``.

    The value is a fraction, not scaled by 100. Each error is divided by the
    corresponding ``true`` value before squaring, so any zero in ``true``
    produces ``inf``/``nan`` in the result.
    """
    relative_error = (pred - true) / true
    return np.mean(np.square(relative_error))
32
33
34
def metric(pred, true):
    """Compute the standard error metrics for ``pred`` against ``true``.

    Returns:
        A 5-tuple ``(mae, mse, rmse, mape, mspe)``, computed in that order by
        the corresponding module-level functions.
    """
    return (
        MAE(pred, true),
        MSE(pred, true),
        RMSE(pred, true),
        MAPE(pred, true),
        MSPE(pred, true),
    )