--- a +++ b/evaluation.py @@ -0,0 +1,21 @@ +import torch +import torch.nn as nn + +from tqdm import tqdm +from data import GIImageDataLoader + +# TODO: other metrics +def evaluate(model: nn.Module, data_loader: GIImageDataLoader, criterion: nn.Module, device: str, use_fp16: bool) -> float: + model.eval() + model = model.to(device) + criterion = criterion.to(device) + total_loss = 0.0 + with torch.no_grad(): + for inputs, labels in tqdm(data_loader): + inputs = inputs.to(device) + labels = labels.to(device) + with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_fp16): + preds = model(inputs) + loss = criterion(preds, labels) + total_loss += loss.item() * inputs.size(0) + return total_loss