[71ad2f]: / src / hybrid / hybrid_test_results.py

Download this file

22 lines (17 with data), 634 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
import torch
from src.utils import calculate_metrics
def hybrid_test_results(model, hybrid_test_loader, icdtype, device, *, metrics_fn=None):
    """Run ``model`` over a test loader, then print and return its metrics.

    Args:
        model: Module invoked as ``model(rnn_x, cnn_x)``. It is switched to
            eval mode and is left in eval mode when this function returns.
        hybrid_test_loader: Iterable of ``(rnn_x, cnn_x, batch_targets)``
            triples. ``rnn_x`` and ``cnn_x`` are tensors moved to ``device``.
        icdtype: Key or index into ``batch_targets`` selecting which target
            to score (presumably an ICD code family name -- confirm with callers).
        device: Device the two input tensors are moved to before the forward pass.
        metrics_fn: Callable ``(predictions, targets) -> result``. Defaults to
            ``src.utils.calculate_metrics`` when omitted.

    Returns:
        Whatever ``metrics_fn`` returns; the same value is also printed under a
        ``----------<icdtype>----------`` header. Predictions and targets are
        passed as NumPy arrays concatenated along the first (batch) axis; an
        empty loader yields ``np.array([])`` for both.
    """
    if metrics_fn is None:
        # Resolved lazily so the default stays the project's metric helper.
        metrics_fn = calculate_metrics

    model.eval()
    pred_batches = []
    target_batches = []
    with torch.no_grad():
        for rnn_x, cnn_x, batch_targets in hybrid_test_loader:
            outputs = model(rnn_x.to(device), cnn_x.to(device))
            pred_batches.append(outputs.cpu().numpy())
            target_batches.append(batch_targets[icdtype].cpu().numpy())

    result = metrics_fn(_concat_batches(pred_batches), _concat_batches(target_batches))

    # f-string rather than `+` so non-str selectors (e.g. an int index) don't raise TypeError.
    bar = "-" * 10
    print(f"{bar}{icdtype}{bar}")
    print(result)
    return result


def _concat_batches(batches):
    """Join per-batch arrays along axis 0; an empty list yields ``np.array([])``."""
    if not batches:
        return np.array([])
    return np.concatenate(batches)