a b/src/hybrid/hybrid_test_results.py
1
import numpy as np 
2
import torch
3
from src.utils import calculate_metrics
4
5
6
def hybrid_test_results(model, hybrid_test_loader, icdtype, device):
    """Evaluate a hybrid RNN/CNN model on a test loader and report metrics.

    The model is switched to eval mode and run without gradient tracking.
    Predictions and the targets selected by ``icdtype`` are gathered across
    all batches, passed to ``calculate_metrics``, printed under a banner,
    and returned.

    Args:
        model: Module invoked as ``model(rnn_x, cnn_x)``.
        hybrid_test_loader: Iterable yielding ``(rnn_x, cnn_x, batch_targets)``
            triples; ``batch_targets[icdtype]`` must be a tensor.
        icdtype: Key selecting which target to evaluate; also shown in the
            printed banner.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Whatever ``calculate_metrics`` returns for the collected predictions
        and targets (the same value that is printed).

    Raises:
        ValueError: If ``hybrid_test_loader`` yields no batches.
    """
    # NOTE: the model is intentionally left in eval mode afterwards, as before.
    model.eval()
    batch_outputs = []
    batch_labels = []
    with torch.no_grad():
        for rnn_x, cnn_x, batch_targets in hybrid_test_loader:
            output = model(rnn_x.to(device), cnn_x.to(device))
            # Keep whole batches and concatenate once at the end rather than
            # extending a Python list row by row; the stacked result has the
            # same shape and dtype as np.array(list_of_rows).
            batch_outputs.append(output.cpu().numpy())
            batch_labels.append(batch_targets[icdtype].cpu().numpy())

    if not batch_outputs:
        raise ValueError(
            "hybrid_test_loader yielded no batches; nothing to evaluate"
        )

    result = calculate_metrics(
        np.concatenate(batch_outputs), np.concatenate(batch_labels)
    )
    # f-string keeps the banner identical for str keys and avoids a
    # TypeError if icdtype is not a str.
    print(f"{'-' * 10}{icdtype}{'-' * 10}")
    print(result)
    return result