import json
import torch
import os
from tqdm import tqdm
from resnet import ResNet1d
from dataloader import BatchDataloader
import torch.optim as optim
import numpy as np
def compute_loss(ages, pred_ages, weights):
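    """Return the weighted sum of squared age errors.

    The sum (not the mean) is returned so callers can accumulate it across
    batches and divide by the total number of samples at the end.
    """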
diff = ages.flatten() - pred_ages.flatten()
loss = torch.sum(weights.flatten() * diff * diff)
return loss
def compute_weights(ages, max_weight=np.inf):
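    """Return inverse-frequency sample weights, normalized to average 1.

    Rare ages get larger weights so the model is not dominated by common ones.
    E.g. compute_weights(np.array([20, 20, 30])) -> [0.75, 0.75, 1.5]: the single
    30-year-old weighs twice as much as each of the two 20-year-olds.
    If `max_weight` is finite, weights are truncated there and renormalized.
    """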
_, inverse, counts = np.unique(ages, return_inverse=True, return_counts=True)
weights = 1 / counts[inverse]
normalized_weights = weights / sum(weights)
w = len(ages) * normalized_weights
# Truncate weights to a maximum
if max_weight < np.inf:
w = np.minimum(w, max_weight)
w = len(ages) * w / sum(w)
return w
def train(ep, dataload):
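    """Run one training epoch; return the average weighted loss per sample."""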
model.train()
total_loss = 0
n_entries = 0
train_desc = "Epoch {:2d}: train - Loss: {:.6f}"
train_bar = tqdm(initial=0, leave=True, total=len(dataload),
                     desc=train_desc.format(ep, 0), position=0)
for traces, ages, weights in dataload:
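        # Traces presumably come from the HDF5 file as (batch, seq_len, n_leads);
        # the model expects (batch, n_leads, seq_len), hence the transpose below.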
        traces = traces.transpose(1, 2)
        # Send to device
        traces, ages, weights = traces.to(device), ages.to(device), weights.to(device)
        # Reinitialize grad
        model.zero_grad()
        # Forward pass
        pred_ages = model(traces)
loss = compute_loss(ages, pred_ages, weights)
# Backward pass
loss.backward()
# Optimize
optimizer.step()
# Update
bs = len(traces)
        total_loss += loss.item()
n_entries += bs
# Update train bar
train_bar.desc = train_desc.format(ep, total_loss / n_entries)
train_bar.update(1)
train_bar.close()
return total_loss / n_entries
def evaluate(ep, dataload):
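    """Run one validation epoch without gradient tracking; return the average weighted loss per sample."""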
model.eval()
total_loss = 0
n_entries = 0
eval_desc = "Epoch {:2d}: valid - Loss: {:.6f}"
eval_bar = tqdm(initial=0, leave=True, total=len(dataload),
                    desc=eval_desc.format(ep, 0), position=0)
for traces, ages, weights in dataload:
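        # Same (batch, seq_len, n_leads) -> (batch, n_leads, seq_len) transpose as in train()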
traces = traces.transpose(1, 2)
traces, ages, weights = traces.to(device), ages.to(device), weights.to(device)
        with torch.no_grad():
            # Forward pass
            pred_ages = model(traces)
            loss = compute_loss(ages, pred_ages, weights)
            # Update running loss
            bs = len(traces)
            total_loss += loss.item()
            n_entries += bs
# Print result
eval_bar.desc = eval_desc.format(ep, total_loss / n_entries)
eval_bar.update(1)
eval_bar.close()
return total_loss / n_entries
if __name__ == "__main__":
import h5py
import pandas as pd
import argparse
from warnings import warn
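    # Example invocation (hypothetical file names):
    #   python train.py path/to/traces.hdf5 path/to/attributes.csv --cuda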
# Arguments that will be saved in config file
parser = argparse.ArgumentParser(add_help=True,
                                     description='Train model to predict age from the raw ECG tracing.')
parser.add_argument('--epochs', type=int, default=70,
help='maximum number of epochs (default: 70)')
parser.add_argument('--seed', type=int, default=2,
help='random seed for number generator (default: 2)')
parser.add_argument('--sample_freq', type=int, default=400,
help='sample frequency (in Hz) in which all traces will be resampled at (default: 400)')
parser.add_argument('--seq_length', type=int, default=4096,
                        help='size (in # of samples) for all traces. If needed, traces will be zero-padded '
                             'to fit into the given size. (default: 4096)')
parser.add_argument('--scale_multiplier', type=int, default=10,
help='multiplicative factor used to rescale inputs.')
parser.add_argument('--batch_size', type=int, default=32,
help='batch size (default: 32).')
parser.add_argument('--lr', type=float, default=0.001,
help='learning rate (default: 0.001)')
parser.add_argument("--patience", type=int, default=7,
help='maximum number of epochs without reducing the learning rate (default: 7)')
parser.add_argument("--min_lr", type=float, default=1e-7,
help='minimum learning rate (default: 1e-7)')
parser.add_argument("--lr_factor", type=float, default=0.1,
                        help='reducing factor for the lr in a plateau (default: 0.1)')
parser.add_argument('--net_filter_size', type=int, nargs='+', default=[64, 128, 196, 256, 320],
help='filter size in resnet layers (default: [64, 128, 196, 256, 320]).')
    parser.add_argument('--net_seq_length', type=int, nargs='+', default=[4096, 1024, 256, 64, 16],
help='number of samples per resnet layer (default: [4096, 1024, 256, 64, 16]).')
parser.add_argument('--dropout_rate', type=float, default=0.8,
help='dropout rate (default: 0.8).')
parser.add_argument('--kernel_size', type=int, default=17,
help='kernel size in convolutional layers (default: 17).')
parser.add_argument('--folder', default='model/',
                        help='output folder (default: model/)')
parser.add_argument('--traces_dset', default='tracings',
help='traces dataset in the hdf5 file.')
parser.add_argument('--ids_dset', default='',
                        help='ids dataset in the hdf5 file. By default the ids are assumed to be the row order.')
parser.add_argument('--age_col', default='age',
help='column with the age in csv file.')
parser.add_argument('--ids_col', default=None,
help='column with the ids in csv file.')
parser.add_argument('--cuda', action='store_true',
help='use cuda for computations. (default: False)')
parser.add_argument('--n_valid', type=int, default=100,
                        help='the first `n_valid` exams in the hdf will be used for validation. '
                             'The rest is for training.')
parser.add_argument('path_to_traces',
help='path to file containing ECG traces')
parser.add_argument('path_to_csv',
help='path to csv file containing attributes.')
args, unk = parser.parse_known_args()
# Check for unknown options
if unk:
warn("Unknown arguments:" + str(unk) + ".")
torch.manual_seed(args.seed)
print(args)
# Set device
device = torch.device('cuda:0' if args.cuda else 'cpu')
folder = args.folder
# Generate output folder if needed
if not os.path.exists(args.folder):
os.makedirs(args.folder)
# Save config file
with open(os.path.join(args.folder, 'args.json'), 'w') as f:
json.dump(vars(args), f, indent='\t')
tqdm.write("Building data loaders...")
# Get csv data
df = pd.read_csv(args.path_to_csv, index_col=args.ids_col)
ages = df[args.age_col]
# Get h5 data
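    # h5py datasets stay on disk; slices are read into memory only on access,
    # so `traces` may be much larger than RAM.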
f = h5py.File(args.path_to_traces, 'r')
traces = f[args.traces_dset]
if args.ids_dset:
h5ids = f[args.ids_dset]
df = df.reindex(h5ids, fill_value=False, copy=True)
    # Train/val split
    valid_mask = np.arange(len(df)) < args.n_valid
train_mask = ~valid_mask
# weights
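    # Note: computed over the full age distribution, train and validation alike.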
weights = compute_weights(ages)
# Dataloader
train_loader = BatchDataloader(traces, ages, weights, bs=args.batch_size, mask=train_mask)
valid_loader = BatchDataloader(traces, ages, weights, bs=args.batch_size, mask=valid_mask)
tqdm.write("Done!")
tqdm.write("Define model...")
N_LEADS = 12 # the 12 leads
N_CLASSES = 1 # just the age
model = ResNet1d(input_dim=(N_LEADS, args.seq_length),
                     blocks_dim=list(zip(args.net_filter_size, args.net_seq_length)),
n_classes=N_CLASSES,
kernel_size=args.kernel_size,
dropout_rate=args.dropout_rate)
model.to(device=device)
tqdm.write("Done!")
tqdm.write("Define optimizer...")
optimizer = optim.Adam(model.parameters(), args.lr)
tqdm.write("Done!")
tqdm.write("Define scheduler...")
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=args.patience,
min_lr=args.lr_factor * args.min_lr,
factor=args.lr_factor)
tqdm.write("Done!")
tqdm.write("Training...")
start_epoch = 0
    best_loss = np.inf  # np.Inf was removed in NumPy 2.0
    history = pd.DataFrame(columns=['epoch', 'train_loss', 'valid_loss', 'lr'])
for ep in range(start_epoch, args.epochs):
train_loss = train(ep, train_loader)
        valid_loss = evaluate(ep, valid_loader)
# Save best model
if valid_loss < best_loss:
# Save model
torch.save({'epoch': ep,
'model': model.state_dict(),
'valid_loss': valid_loss,
'optimizer': optimizer.state_dict()},
os.path.join(folder, 'model.pth'))
# Update best validation loss
best_loss = valid_loss
# Get learning rate
for param_group in optimizer.param_groups:
learning_rate = param_group["lr"]
# Interrupt for minimum learning rate
if learning_rate < args.min_lr:
break
# Print message
tqdm.write('Epoch {:2d}: \tTrain Loss {:.6f} ' \
'\tValid Loss {:.6f} \tLearning Rate {:.7f}\t'
.format(ep, train_loss, valid_loss, learning_rate))
# Save history
history = history.append({"epoch": ep, "train_loss": train_loss,
"valid_loss": valid_loss, "lr": learning_rate}, ignore_index=True)
history.to_csv(os.path.join(folder, 'history.csv'), index=False)
# Update learning rate
scheduler.step(valid_loss)
tqdm.write("Done!")