[71ad2f]: / src / bert / bert_utils.py

Download this file

21 lines (18 with data), 811 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
import torch
from src.utils import calculate_metrics
#################### Printing Test Metrics ##############################
def bert_test_results(model, test_loader, icdtype, device):
    """Run ``model`` over ``test_loader`` and print metrics for one target set.

    Inference runs without gradient tracking. The per-batch model outputs and
    the ``icdtype`` targets are collected, scored with ``calculate_metrics``,
    and the result is printed between banner lines.

    Args:
        model: Model called as ``model(ids, mask, token_type_ids)`` that
            returns a tensor. It is switched to eval mode and is NOT switched
            back afterwards.
        test_loader: Iterable of ``(resps, batch_targets)`` pairs. ``resps``
            maps "ids", "mask" and "token_type_ids" to tensors;
            ``batch_targets`` maps target-set names (``icdtype``) to tensors.
        icdtype: Key selecting which entry of ``batch_targets`` is scored;
            also used as the title in the printed banner (must be a str).
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        The value produced by ``calculate_metrics`` (the same value that is
        printed), so callers can reuse the metrics programmatically.
    """
    model.eval()
    output_batches = []
    target_batches = []
    with torch.no_grad():
        for resps, batch_targets in test_loader:
            outputs = model(
                resps["ids"].to(device),
                resps["mask"].to(device),
                resps["token_type_ids"].to(device),
            )
            output_batches.append(outputs.cpu().numpy())
            target_batches.append(batch_targets[icdtype].cpu().numpy())

    # Stack whole batches once rather than splitting them into rows and
    # re-stacking; an empty loader yields the same empty array as before
    # (np.concatenate would raise on an empty list).
    predictions = np.concatenate(output_batches) if output_batches else np.array([])
    targets = np.concatenate(target_batches) if target_batches else np.array([])
    # NOTE(review): whether calculate_metrics expects raw logits or
    # probabilities is not visible here — confirm against src.utils.
    result = calculate_metrics(predictions, targets)

    header = '-' * 20 + icdtype + '-' * 20
    print(header)
    print(result)
    print('-' * len(header))
    return result
########################################################################