a b/src/test_results.py
1
import numpy as np 
2
import torch
3
from src.utils import calculate_metrics
4
5
def test_results(model, test_loader, icdtype, device):
    """Evaluate ``model`` on ``test_loader`` and report metrics for one target key.

    Puts ``model`` into eval mode and runs every batch through it with
    gradient tracking disabled. It then gathers the model outputs and the
    ``icdtype`` targets on the CPU as NumPy arrays and scores them with
    ``calculate_metrics(predictions, targets)``. The metrics are printed under
    a ``----------<icdtype>----------`` banner and also returned.

    Args:
        model: Module whose forward pass takes one batch of inputs.
        test_loader: Iterable of ``(inputs, targets)`` pairs. ``inputs`` must
            support ``.to(device)``. ``targets`` must be indexable by
            ``icdtype``; presumably it is a dict of tensors (confirm against
            the dataset).
        icdtype: Key selecting which targets to score; also shown in the banner.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Whatever ``calculate_metrics`` returns for the concatenated outputs and
        targets. Before this change the function returned ``None``; callers
        that ignore the return value are unaffected.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    # NOTE(review): the model is left in eval mode on return; callers that
    # resume training afterwards must call ``model.train()`` themselves.
    model.eval()

    output_batches = []
    target_batches = []
    with torch.no_grad():
        for x_test, batch_targets in test_loader:
            x_test = x_test.to(device)
            output_batches.append(model(x_test).cpu().numpy())
            target_batches.append(batch_targets[icdtype].cpu().numpy())

    # Fail loudly instead of handing empty arrays to the metric function.
    if not output_batches:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # Join whole batches along the first axis. This gives the same result as
    # stacking the rows one by one, without the per-row Python list.
    predictions = np.concatenate(output_batches)
    targets = np.concatenate(target_batches)

    result = calculate_metrics(predictions, targets)
    print('-' * 10 + str(icdtype) + '-' * 10)
    print(result)
    return result


# Both the function name and the module name match pytest's default ``test_*``
# discovery patterns. Without this flag pytest would collect the helper as a
# test and fail on the missing ``model``/``test_loader``/... fixtures.
test_results.__test__ = False