|
a |
|
b/metrics/regression_metrics.py |
|
|
1 |
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, R2Score |
|
|
2 |
|
|
|
3 |
def get_regression_metrics(preds, labels):
    """Compute MSE, RMSE, MAE and R^2 for a set of regression predictions.

    Args:
        preds: Predicted values. Tensors are used as-is; other array-likes
            (lists, NumPy arrays) are converted with ``torch.as_tensor``.
        labels: Ground-truth targets, shaped like ``preds``. They are placed
            on the same device as ``preds``.

    Returns:
        dict: Mapping ``"mse"``, ``"rmse"``, ``"mae"`` and ``"r2"`` (in that
        order) to Python floats.
    """
    import torch

    preds = torch.as_tensor(preds)
    # Keep both inputs on one device so the metric update does not mix devices.
    labels = torch.as_tensor(labels, device=preds.device)

    metrics = {
        "mse": MeanSquaredError(squared=True),
        "rmse": MeanSquaredError(squared=False),
        "mae": MeanAbsoluteError(),
        "r2": R2Score(),
    }

    results = {}
    for name, metric in metrics.items():
        # Metric states start on the CPU; move them to the inputs' device,
        # otherwise updating with e.g. CUDA tensors raises a device mismatch.
        metric = metric.to(preds.device)
        # update() only accumulates state; calling the metric (forward) would
        # also compute a throwaway per-batch value before compute().
        metric.update(preds, labels)
        results[name] = metric.compute().item()
    return results