b/evaluation.py
|
import torch
import torch.nn as nn

from tqdm import tqdm
from data import GIImageDataLoader


# TODO: other metrics
def evaluate(model: nn.Module, data_loader: GIImageDataLoader, criterion: nn.Module, device: str, use_fp16: bool) -> float:
    """Compute the mean per-sample loss of `model` over `data_loader`."""
    model.eval()
    model = model.to(device)
    criterion = criterion.to(device)
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Run the forward pass under fp16 autocast when requested (assumes a CUDA device).
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_fp16):
                preds = model(inputs)
                loss = criterion(preds, labels)
            # Weight the (mean-reduced) batch loss by batch size so the dataset
            # average stays correct even when the last batch is smaller.
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
    # Return the mean loss per sample rather than the raw running sum.
    return total_loss / max(total_samples, 1)
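

# Minimal usage sketch, assuming a dummy linear classifier and random tensors in
# place of the project's real model and GIImageDataLoader; the names below are
# illustrative only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    dummy_data = TensorDataset(torch.randn(64, 3 * 32 * 32), torch.randint(0, 2, (64,)))
    dummy_loader = DataLoader(dummy_data, batch_size=16)  # stand-in for GIImageDataLoader
    dummy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 2))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mean_loss = evaluate(dummy_model, dummy_loader, nn.CrossEntropyLoss(), device, use_fp16=False)
    print(f"mean validation loss: {mean_loss:.4f}")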