--- a +++ b/src/hybrid/hybrid_test_results.py @@ -0,0 +1,22 @@ +import numpy as np +import torch +from src.utils import calculate_metrics + + +def hybrid_test_results(model, hybrid_test_loader, icdtype, device): + + model.eval() + with torch.no_grad(): + model_result = [] + targets = [] + for rnn_x, cnn_x, batch_targets in hybrid_test_loader: + rnn_x = rnn_x.to(device) + cnn_x = cnn_x.to(device) + + model_batch_result = model(rnn_x, cnn_x) + model_result.extend(model_batch_result.cpu().numpy()) + targets.extend(batch_targets[icdtype].cpu().numpy()) + + result = calculate_metrics(np.array(model_result), np.array(targets)) + print('-'*10 + icdtype + '-'*10) + print(result) \ No newline at end of file