Switch to unified view

a b/metrics/regression_metrics.py
1
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, R2Score
2
3
def get_regression_metrics(preds, labels):
4
    # get regression metrics: mse, mae, rmse, r2
5
    mse = MeanSquaredError(squared=True)
6
    rmse = MeanSquaredError(squared=False)
7
    mae = MeanAbsoluteError()
8
    r2 = R2Score()
9
10
    mse(preds, labels)
11
    rmse(preds, labels)
12
    mae(preds, labels)
13
    r2(preds, labels)
14
15
    # return a dictionary
16
    return {
17
        "mse": mse.compute().item(),
18
        "rmse": rmse.compute().item(),
19
        "mae": mae.compute().item(),
20
        "r2": r2.compute().item(),
21
    }