--- a +++ b/evaluate.py @@ -0,0 +1,77 @@ +# Imports +from resnet import ResNet1d +import tqdm +import h5py +import torch +import os +import json +import numpy as np +import argparse +from warnings import warn +import pandas as pd + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument('mdl', type=str, + help='folder containing model.') + parser.add_argument('path_to_traces', type=str, default='../data/ecg_tracings.hdf5', + help='path to hdf5 containing ECG traces') + parser.add_argument('--batch_size', type=int, default=8, + help='number of exams per batch.') + parser.add_argument('--output', type=str, default='predicted_age.csv', + help='output file.') + parser.add_argument('--traces_dset', default='tracings', + help='traces dataset in the hdf5 file.') + parser.add_argument('--ids_dset', + help='ids dataset in the hdf5 file.') + args, unk = parser.parse_known_args() + # Check for unknown options + if unk: + warn("Unknown arguments:" + str(unk) + ".") + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + # Get checkpoint + ckpt = torch.load(os.path.join(args.mdl, 'model.pth'), map_location=lambda storage, loc: storage) + # Get config + config = os.path.join(args.mdl, 'config.json') + with open(config, 'r') as f: + config_dict = json.load(f) + # Get model + N_LEADS = 12 + model = ResNet1d(input_dim=(N_LEADS, config_dict['seq_length']), + blocks_dim=list(zip(config_dict['net_filter_size'], config_dict['net_seq_lengh'])), + n_classes=1, + kernel_size=config_dict['kernel_size'], + dropout_rate=config_dict['dropout_rate']) + # load model checkpoint + model.load_state_dict(ckpt["model"]) + model = model.to(device) + # Get traces + ff = h5py.File(args.path_to_traces, 'r') + traces = ff[args.traces_dset] + n_total = len(traces) + if args.ids_dset: + ids = ff[args.ids_dset] + else: + ids = range(n_total) + # Get dimension + predicted_age = np.zeros((n_total,)) + # Evaluate on test data + model.eval() + n_total, n_samples, n_leads = traces.shape + n_batches = int(np.ceil(n_total/args.batch_size)) + # Compute gradients + predicted_age = np.zeros((n_total,)) + end = 0 + for i in tqdm.tqdm(range(n_batches)): + start = end + end = min((i + 1) * args.batch_size, n_total) + with torch.no_grad(): + x = torch.tensor(traces[start:end, :, :]).transpose(-1, -2) + x = x.to(device, dtype=torch.float32) + y_pred = model(x) + predicted_age[start:end] = y_pred.detach().cpu().numpy().flatten() + # Save predictions + df = pd.DataFrame({'ids': ids, 'predicted_age': predicted_age}) + df = df.set_index('ids') + df.to_csv(args.output) \ No newline at end of file