--- /dev/null
+++ b/train.py
@@ -0,0 +1,252 @@
+import json
+import torch
+import os
+from tqdm import tqdm
+from resnet import ResNet1d
+from dataloader import BatchDataloader
+import torch.optim as optim
+import numpy as np
+
+
+def compute_loss(ages, pred_ages, weights):
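+    # Weighted sum of squared errors; the training loop divides the running
+    # total by the number of samples, so this behaves as a weighted MSE per epoch.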
+    diff = ages.flatten() - pred_ages.flatten()
+    loss = torch.sum(weights.flatten() * diff * diff)
+    return loss
+
+
+def compute_weights(ages, max_weight=np.inf):
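+    # Inverse-frequency weights: rare ages get larger weights so that every
+    # age value contributes roughly equally to the training loss.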
+    _, inverse, counts = np.unique(ages, return_inverse=True, return_counts=True)
+    weights = 1 / counts[inverse]
+    normalized_weights = weights / sum(weights)
+    w = len(ages) * normalized_weights
+    # Truncate weights to a maximum
+    if max_weight < np.inf:
+        w = np.minimum(w, max_weight)
+        w = len(ages) * w / sum(w)
+    return w
+
+
+def train(ep, dataload):
+    model.train()
+    total_loss = 0
+    n_entries = 0
+    train_desc = "Epoch {:2d}: train - Loss: {:.6f}"
+    train_bar = tqdm(initial=0, leave=True, total=len(dataload),
+                     desc=train_desc.format(ep, 0), position=0)
+    for traces, ages, weights in dataload:
+        # Rearrange traces to (batch, n_leads, seq_length), as expected by the model
+        traces = traces.transpose(1, 2)
+        # Send to device
+        traces, ages, weights = traces.to(device), ages.to(device), weights.to(device)
+        # Reinitialize grad
+        model.zero_grad()
+        # Forward pass
+        pred_ages = model(traces)
+        loss = compute_loss(ages, pred_ages, weights)
+        # Backward pass
+        loss.backward()
+        # Optimize
+        optimizer.step()
+        # Accumulate loss over the batch
+        bs = len(traces)
+        total_loss += loss.detach().cpu().numpy()
+        n_entries += bs
+        # Update train bar
+        train_bar.desc = train_desc.format(ep, total_loss / n_entries)
+        train_bar.update(1)
+    train_bar.close()
+    return total_loss / n_entries
+
+
+def evaluate(ep, dataload):
+    model.eval()
+    total_loss = 0
+    n_entries = 0
+    eval_desc = "Epoch {:2d}: valid - Loss: {:.6f}"
+    eval_bar = tqdm(initial=0, leave=True, total=len(dataload),
+                    desc=eval_desc.format(ep, 0), position=0)
+    for traces, ages, weights in dataload:
+        traces = traces.transpose(1, 2)
+        traces, ages, weights = traces.to(device), ages.to(device), weights.to(device)
+        with torch.no_grad():
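+            # No gradient tracking is needed for validation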
+            # Forward pass
+            pred_ages = model(traces)
+            loss = compute_loss(ages, pred_ages, weights)
+            # Accumulate loss over the batch
+            bs = len(traces)
+            total_loss += loss.detach().cpu().numpy()
+            n_entries += bs
+            # Update progress bar
+            eval_bar.desc = eval_desc.format(ep, total_loss / n_entries)
+            eval_bar.update(1)
+    eval_bar.close()
+    return total_loss / n_entries
+
+
+if __name__ == "__main__":
+    import h5py
+    import pandas as pd
+    import argparse
+    from warnings import warn
+
+    # Arguments that will be saved in config file
+    parser = argparse.ArgumentParser(add_help=True,
+                                     description='Train model to predict age from the raw ECG tracing.')
+    parser.add_argument('--epochs', type=int, default=70,
+                        help='maximum number of epochs (default: 70)')
+    parser.add_argument('--seed', type=int, default=2,
+                        help='random seed for number generator (default: 2)')
+    parser.add_argument('--sample_freq', type=int, default=400,
+                        help='sample frequency (in Hz) at which all traces will be resampled (default: 400)')
+    parser.add_argument('--seq_length', type=int, default=4096,
+                        help='size (in # of samples) for all traces. If needed, traces will be zero-padded '
+                             'to fit into the given size. (default: 4096)')
+    parser.add_argument('--scale_multiplier', type=int, default=10,
+                        help='multiplicative factor used to rescale inputs (default: 10).')
+    parser.add_argument('--batch_size', type=int, default=32,
+                        help='batch size (default: 32).')
+    parser.add_argument('--lr', type=float, default=0.001,
+                        help='learning rate (default: 0.001)')
+    parser.add_argument("--patience", type=int, default=7,
+                        help='epochs without improvement before the learning rate is reduced (default: 7)')
+    parser.add_argument("--min_lr", type=float, default=1e-7,
+                        help='minimum learning rate (default: 1e-7)')
+    parser.add_argument("--lr_factor", type=float, default=0.1,
+                        help='reducing factor for the lr in a plateau (default: 0.1)')
+    parser.add_argument('--net_filter_size', type=int, nargs='+', default=[64, 128, 196, 256, 320],
+                        help='filter size in resnet layers (default: [64, 128, 196, 256, 320]).')
+    parser.add_argument('--net_seq_length', type=int, nargs='+', default=[4096, 1024, 256, 64, 16],
+                        help='number of samples per resnet layer (default: [4096, 1024, 256, 64, 16]).')
+    parser.add_argument('--dropout_rate', type=float, default=0.8,
+                        help='dropout rate (default: 0.8).')
+    parser.add_argument('--kernel_size', type=int, default=17,
+                        help='kernel size in convolutional layers (default: 17).')
+    parser.add_argument('--folder', default='model/',
+                        help='output folder (default: model/)')
+    parser.add_argument('--traces_dset', default='tracings',
+                        help='traces dataset in the hdf5 file.')
+    parser.add_argument('--ids_dset', default='',
+                        help='ids dataset in the hdf5 file. If empty, the row order is used as the id.')
+    parser.add_argument('--age_col', default='age',
+                        help='column with the age in csv file.')
+    parser.add_argument('--ids_col', default=None,
+                        help='column with the ids in csv file.')
+    parser.add_argument('--cuda', action='store_true',
+                        help='use cuda for computations. (default: False)')
+    parser.add_argument('--n_valid', type=int, default=100,
+                        help='the first `n_valid` exams in the hdf will be used for validation. '
+                             'The rest is used for training. (default: 100)')
+    parser.add_argument('path_to_traces',
+                        help='path to file containing ECG traces')
+    parser.add_argument('path_to_csv',
+                        help='path to csv file containing attributes.')
+    args, unk = parser.parse_known_args()
+    # Check for unknown options
+    if unk:
+        warn("Unknown arguments: " + str(unk) + ".")
+
+    torch.manual_seed(args.seed)
+    print(args)
+    # Set device
+    device = torch.device('cuda:0' if args.cuda else 'cpu')
+    folder = args.folder
+
+    # Generate output folder if needed
+    if not os.path.exists(args.folder):
+        os.makedirs(args.folder)
+    # Save config file
+    with open(os.path.join(args.folder, 'args.json'), 'w') as f:
+        json.dump(vars(args), f, indent='\t')
+
+    tqdm.write("Building data loaders...")
+    # Get csv data
+    df = pd.read_csv(args.path_to_csv, index_col=args.ids_col)
+    # Get h5 data
+    f = h5py.File(args.path_to_traces, 'r')
+    traces = f[args.traces_dset]
+    if args.ids_dset:
+        # Reorder the csv rows to match the order of the traces in the hdf5 file
+        h5ids = f[args.ids_dset]
+        df = df.reindex(h5ids, fill_value=False, copy=True)
+    # Get ages after the (optional) reindexing, so they stay aligned with the traces
+    ages = df[args.age_col]
+    # Train/val split: the first `n_valid` exams are used for validation
+    valid_mask = np.arange(len(df)) < args.n_valid
+    train_mask = ~valid_mask
+    # Inverse-frequency weights (computed over the full dataset)
+    weights = compute_weights(ages)
+    # Dataloader
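+    # (the masks select which exams feed each split)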
+    train_loader = BatchDataloader(traces, ages, weights, bs=args.batch_size, mask=train_mask)
+    valid_loader = BatchDataloader(traces, ages, weights, bs=args.batch_size, mask=valid_mask)
+    tqdm.write("Done!")
+
+    tqdm.write("Define model...")
+    N_LEADS = 12  # the 12 leads
+    N_CLASSES = 1  # just the age
+    model = ResNet1d(input_dim=(N_LEADS, args.seq_length),
+                     blocks_dim=list(zip(args.net_filter_size, args.net_seq_length)),
+                     n_classes=N_CLASSES,
+                     kernel_size=args.kernel_size,
+                     dropout_rate=args.dropout_rate)
+    model.to(device=device)
+    tqdm.write("Done!")
+
+    tqdm.write("Define optimizer...")
+    optimizer = optim.Adam(model.parameters(), args.lr)
+    tqdm.write("Done!")
+
+    tqdm.write("Define scheduler...")
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=args.patience,
+                                                     min_lr=args.lr_factor * args.min_lr,
+                                                     factor=args.lr_factor)
+    tqdm.write("Done!")
+
+    tqdm.write("Training...")
+    start_epoch = 0
+    best_loss = np.inf
+    history = pd.DataFrame(columns=['epoch', 'train_loss', 'valid_loss', 'lr',
+                                    'weighted_rmse', 'weighted_mae', 'rmse', 'mse'])
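+    # Only epoch, train_loss, valid_loss and lr are filled in the loop below;
+    # the remaining metric columns are declared but left empty by this script.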
+    for ep in range(start_epoch, args.epochs):
+        train_loss = train(ep, train_loader)
+        valid_loss = evaluate(ep, valid_loader)
+        # Save best model
+        if valid_loss < best_loss:
+            # Save model
+            torch.save({'epoch': ep,
+                        'model': model.state_dict(),
+                        'valid_loss': valid_loss,
+                        'optimizer': optimizer.state_dict()},
+                       os.path.join(folder, 'model.pth'))
+            # Update best validation loss
+            best_loss = valid_loss
+        # Get learning rate
+        for param_group in optimizer.param_groups:
+            learning_rate = param_group["lr"]
+        # Interrupt for minimum learning rate
+        if learning_rate < args.min_lr:
+            break
+        # Print message
+        tqdm.write('Epoch {:2d}: \tTrain Loss {:.6f} '
+                   '\tValid Loss {:.6f} \tLearning Rate {:.7f}\t'
+                   .format(ep, train_loss, valid_loss, learning_rate))
+        # Save history
+        history = pd.concat([history,
+                             pd.DataFrame([{"epoch": ep, "train_loss": train_loss,
+                                            "valid_loss": valid_loss, "lr": learning_rate}])],
+                            ignore_index=True)
+        history.to_csv(os.path.join(folder, 'history.csv'), index=False)
+        # Update learning rate
+        scheduler.step(valid_loss)
+    tqdm.write("Done!")
+
+