"""Train an auto-encoder (AE) on median-beat ECG data.

Builds a double-residual encoder/decoder pair, wraps them in the ``AE``
Lightning system, and fits on a 90/10 train/test split of the (currently
non-public) median-beat dataset, logging to Neptune.
"""

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision.transforms import Compose

import pytorch_lightning as pl
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from sklearn.model_selection import train_test_split

from ecgxai.utils.dataset import UniversalECGDataset
from ecgxai.network.AE_encoder_decoder import AEDoubleResidualEncoder, DoubleResidualDecoder

from ecgxai.utils.loss import TW
from ecgxai.utils.transforms import ApplyGain, ToTensor, Resample
from ecgxai.systems.AE_system import AE


# Please note that this configuration requires median beat data which is not currently publicly available
params = {
    "median_data_dir": "/median",   # directory containing the median-beat waveform files
    "one_mili_csv": "header_info.csv",  # CSV with one header row per ECG recording
}

# Seed python/numpy/torch RNGs; this also makes the un-seeded
# train_test_split below deterministic via numpy's global RNG.
pl.seed_everything(0)

# Set data transforms: tensorize, apply per-recording gain, resample to 500 samples.
train_transform = Compose([ToTensor(), ApplyGain(), Resample(500)])
test_transform = Compose([ToTensor(), ApplyGain(), Resample(500)])

# Get data and split into 90% train / 10% test.
full_set_df = pd.read_csv(params['one_mili_csv'])

trainset_df, testset_df = train_test_split(full_set_df, test_size=0.1)

trainset = UniversalECGDataset(
    'umcu',
    params['median_data_dir'],
    trainset_df,
    transform=train_transform,
)

testset = UniversalECGDataset(
    'umcu',
    params['median_data_dir'],
    testset_df,
    transform=test_transform,
)

batchsize = 64
trainLoader = DataLoader(trainset, batch_size=batchsize, shuffle=True, num_workers=12)
# NOTE(review): shuffle=True on the evaluation loader is unusual (it makes
# per-epoch metrics order-dependent); preserved to keep original behaviour.
testLoader = DataLoader(testset, batch_size=batchsize, shuffle=True, num_workers=8)

# Remember to properly configure your logger here.
# You can change the neptune logger to any logger supported by pytorch lightning.
# Read the API token via a context manager so the file handle is closed
# (the original `open(...).read()` leaked the handle).
with open("../neptune_token.txt", "r") as token_file:
    neptune_api_key = token_file.read()

neptune_logger = NeptuneLogger(
    api_key=neptune_api_key,
    project="%project%"
)

lr = 0.001

# ---- Shared model dimensions ---------------------------------------------
latent_dim = 32
in_sample_dim = 600
out_sample_dim = 600
sample_channels = 8
out_sample_channels = 8

# ---- Encoder hyperparameters ---------------------------------------------
enc_pre_block_1_out_channels = 16
enc_pre_block_1_kernel_size = 3
enc_pre_block_1_bn = True
enc_pre_block_1_dropout_rate = 0.0
enc_pre_block_1_act_func = None

enc_pre_block_2_out_channels = 32
enc_pre_block_2_kernel_size = 3
enc_pre_block2_act_func = None
enc_pre_block_2_dropout_rate = 0.1
enc_pre_block2_bn = True

enc_cnn_num_layers = 12
enc_cnn_kernel_size = 16
enc_cnn_dropout_rate = 0.1
enc_cnn_sub_sample_every = 4
enc_cnn_double_channel_every = 4
enc_cnn_act_func = nn.ReLU()
enc_cnn_bn = True

# ---- Decoder hyperparameters ---------------------------------------------
dec_post_block_1_out_channels = 8
dec_post_block_1_kernel_size = 3
dec_post_block_1_bn = True
dec_post_block_1_dropout_rate = 0.1
dec_post_block_1_act_func = None

dec_post_block_2_out_channels = 8
dec_post_block_2_kernel_size = 17
dec_post_block2_act_func = None
dec_post_block_2_dropout_rate = 0.1
dec_post_block2_bn = True

dec_cnn_num_layers = 3
dec_cnn_kernel_size = 3
dec_cnn_dropout_rate = 0.1
dec_cnn_sub_sample_every = 4
dec_cnn_double_channel_every = 4
dec_cnn_act_func = nn.ReLU()
dec_cnn_bn = True

# Hyperparameters reported to the experiment logger.
# Fix: the original dict logged only the encoder-side settings; the dec_*
# values were defined but never logged. They are included here so the run
# is fully reproducible from its logged configuration.
hyperparameters = dict(
    lr=lr,
    latent_dim=latent_dim,
    in_sample_dim=in_sample_dim,
    out_sample_dim=out_sample_dim,
    sample_channels=sample_channels,
    out_sample_channels=out_sample_channels,

    pre_block_1_out_channels=enc_pre_block_1_out_channels,
    pre_block_1_kernel_size=enc_pre_block_1_kernel_size,
    pre_block_1_bn=enc_pre_block_1_bn,
    pre_block_1_dropout_rate=enc_pre_block_1_dropout_rate,
    pre_block_1_act_func=enc_pre_block_1_act_func,

    pre_block_2_out_channels=enc_pre_block_2_out_channels,
    pre_block_2_kernel_size=enc_pre_block_2_kernel_size,
    pre_block2_act_func=enc_pre_block2_act_func,
    pre_block_2_dropout_rate=enc_pre_block_2_dropout_rate,
    pre_block2_bn=enc_pre_block2_bn,

    enc_cnn_num_layers=enc_cnn_num_layers,
    enc_cnn_kernel_size=enc_cnn_kernel_size,
    enc_cnn_dropout_rate=enc_cnn_dropout_rate,
    enc_cnn_sub_sample_every=enc_cnn_sub_sample_every,
    enc_cnn_double_channel_every=enc_cnn_double_channel_every,
    enc_cnn_act_func=enc_cnn_act_func,
    enc_cnn_bn=enc_cnn_bn,

    post_block_1_out_channels=dec_post_block_1_out_channels,
    post_block_1_kernel_size=dec_post_block_1_kernel_size,
    post_block_1_bn=dec_post_block_1_bn,
    post_block_1_dropout_rate=dec_post_block_1_dropout_rate,
    post_block_1_act_func=dec_post_block_1_act_func,

    post_block_2_out_channels=dec_post_block_2_out_channels,
    post_block_2_kernel_size=dec_post_block_2_kernel_size,
    post_block2_act_func=dec_post_block2_act_func,
    post_block_2_dropout_rate=dec_post_block_2_dropout_rate,
    post_block2_bn=dec_post_block2_bn,

    dec_cnn_num_layers=dec_cnn_num_layers,
    dec_cnn_kernel_size=dec_cnn_kernel_size,
    dec_cnn_dropout_rate=dec_cnn_dropout_rate,
    dec_cnn_sub_sample_every=dec_cnn_sub_sample_every,
    dec_cnn_double_channel_every=dec_cnn_double_channel_every,
    dec_cnn_act_func=dec_cnn_act_func,
    dec_cnn_bn=dec_cnn_bn,
)

print(hyperparameters)

# NOTE(review): the `*_act_funct` spellings below look like typos but are
# kept as-is — they must match the keyword names declared by
# AEDoubleResidualEncoder; confirm against its signature before renaming.
encoder = AEDoubleResidualEncoder(
    latent_dim=latent_dim,
    in_sample_dim=in_sample_dim,
    out_sample_dim=out_sample_dim,
    sample_channels=sample_channels,
    out_sample_channels=out_sample_channels,

    pre_block_1_out_channels=enc_pre_block_1_out_channels,
    pre_block_1_kernel_size=enc_pre_block_1_kernel_size,
    pre_block_1_bn=enc_pre_block_1_bn,
    pre_block_1_dropout_rate=enc_pre_block_1_dropout_rate,
    pre_block_1_act_funct=enc_pre_block_1_act_func,

    pre_block_2_out_channels=enc_pre_block_2_out_channels,
    pre_block_2_kernel_size=enc_pre_block_2_kernel_size,
    pre_block2_act_funct=enc_pre_block2_act_func,
    pre_block_2_dropout_rate=enc_pre_block_2_dropout_rate,
    pre_block2_bn=enc_pre_block2_bn,

    cnn_num_layers=enc_cnn_num_layers,
    cnn_kernel_size=enc_cnn_kernel_size,
    cnn_dropout_rate=enc_cnn_dropout_rate,
    cnn_sub_sample_every=enc_cnn_sub_sample_every,
    cnn_double_channel_every=enc_cnn_double_channel_every,
    cnn_act_func=enc_cnn_act_func,
    cnn_bn=enc_cnn_bn
)

# NOTE(review): the decoder's `*_in_channels` parameters are fed from
# variables named `*_out_channels` — presumably the decoder's input channel
# counts mirror the post-block output widths; verify against
# DoubleResidualDecoder's documentation.
decoder = DoubleResidualDecoder(
    latent_dim=latent_dim,
    in_sample_dim=in_sample_dim,
    out_sample_dim=out_sample_dim,
    sample_channels=sample_channels,
    out_sample_channels=out_sample_channels,

    post_block_1_in_channels=dec_post_block_1_out_channels,
    post_block_1_kernel_size=dec_post_block_1_kernel_size,
    post_block_1_bn=dec_post_block_1_bn,
    post_block_1_dropout_rate=dec_post_block_1_dropout_rate,
    post_block_1_act_func=dec_post_block_1_act_func,

    post_block_2_in_channels=dec_post_block_2_out_channels,
    post_block_2_kernel_size=dec_post_block_2_kernel_size,
    post_block_2_act_func=dec_post_block2_act_func,
    post_block_2_dropout_rate=dec_post_block_2_dropout_rate,
    post_block_2_bn=dec_post_block2_bn,

    cnn_num_layers=dec_cnn_num_layers,
    cnn_kernel_size=dec_cnn_kernel_size,
    cnn_dropout_rate=dec_cnn_dropout_rate,
    cnn_sub_sample_every=dec_cnn_sub_sample_every,
    cnn_double_channel_every=dec_cnn_double_channel_every,
    cnn_act_func=dec_cnn_act_func,
    cnn_bn=dec_cnn_bn
)


# Mean-squared reconstruction loss wired through TW so it receives the
# system's 'x' and 'reconstruction' outputs as its two inputs.
model = AE(encoder, decoder, lr=lr, loss=TW(torch.nn.MSELoss(reduction='mean'), input_args=['x', 'reconstruction']))

# NOTE(review): `checkpoint_callback=False` disables the default checkpoint
# while an explicit ModelCheckpoint is still passed via `callbacks`; some
# Lightning versions reject this combination — confirm against the pinned
# pytorch_lightning version.
trainer = pl.Trainer(
    logger=neptune_logger,
    checkpoint_callback=False,
    gradient_clip_val=10,   # clip gradients to stabilise training
    max_epochs=50,
    gpus=1 if torch.cuda.is_available() else None,  # single GPU when available, else CPU
    callbacks=[
        ModelCheckpoint(
            save_last=True
        ),
    ],
)

trainer.logger.log_hyperparams(hyperparameters)
trainer.fit(model, trainLoader, testLoader)