Diff of /evaluation.py [000000] .. [748954]

Switch to unified view

a b/evaluation.py
1
import torch
2
import torch.nn as nn
3
4
from tqdm import tqdm
5
from data import GIImageDataLoader
6
7
# TODO: other metrics
8
def evaluate(model: nn.Module, data_loader: GIImageDataLoader, criterion: nn.Module, device: str, use_fp16: bool) -> float:
9
    model.eval()
10
    model = model.to(device)
11
    criterion = criterion.to(device)
12
    total_loss = 0.0
13
    with torch.no_grad():
14
        for inputs, labels in tqdm(data_loader):
15
            inputs = inputs.to(device)
16
            labels = labels.to(device)
17
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_fp16):
18
                preds = model(inputs)
19
                loss = criterion(preds, labels)
20
            total_loss += loss.item() * inputs.size(0)
21
    return total_loss