--- a
+++ b/clinical_ts/simclr_dataset_wrapper.py
@@ -0,0 +1,298 @@
+from .create_logger import create_logger
+import numpy as np
+from torch.utils.data import DataLoader
+# from .customDataLoader import DataLoader
+from torch.utils.data.sampler import SubsetRandomSampler
+import torchvision.transforms as transforms
+from torch.utils.data import ConcatDataset
+from torchvision import datasets
+from functools import partial
+from pathlib import Path
+import pandas as pd
+import pdb
+try:
+    import pickle5 as pickle
+except ImportError as e:
+    import pickle
+from .timeseries_utils import TimeseriesDatasetCrops, reformat_as_memmap, load_dataset
+from .ecg_utils import *
+from .timeseries_transformations import GaussianNoise, RandomResizedCrop, ChannelResize, Negation, DynamicTimeWarp, DownSample, TimeWarp, TimeOut, ToTensor, BaselineWander, PowerlineNoise, EMNoise, BaselineShift, TGaussianNoise, TRandomResizedCrop, TChannelResize, TNegation, TDynamicTimeWarp, TDownSample, TTimeOut, TBaselineWander, TPowerlineNoise, TEMNoise, TBaselineShift, TGaussianBlur1d, TNormalize, Transpose
+
+
+logger = create_logger(__name__)
+
+
+def transformations_from_strings(transformations, t_params):
+    if transformations is None:
+        return [ToTensor()]
+
+    def str_to_trafo(trafo):
+        if trafo == "RandomResizedCrop":
+            return TRandomResizedCrop(crop_ratio_range=t_params["rr_crop_ratio_range"], output_size=t_params["output_size"])
+        elif trafo == "ChannelResize":
+            return TChannelResize(magnitude_range=t_params["magnitude_range"])
+        elif trafo == "Negation":
+            return TNegation()
+        elif trafo == "DynamicTimeWarp":
+            return TDynamicTimeWarp(warps=t_params["warps"], radius=t_params["radius"])
+        elif trafo == "DownSample":
+            return TDownSample(downsample_ratio=t_params["downsample_ratio"])
+        elif trafo == "TimeWarp":
+            return TimeWarp(epsilon=t_params["epsilon"])
+        elif trafo == "TimeOut":
+            return TTimeOut(crop_ratio_range=t_params["to_crop_ratio_range"])
+        elif trafo == "GaussianNoise":
+            return TGaussianNoise(scale=t_params["gaussian_scale"])
+        elif trafo == "BaselineWander":
+            return TBaselineWander(Cmax=t_params["bw_cmax"])
+        elif trafo == "PowerlineNoise":
+            return TPowerlineNoise(Cmax=t_params["pl_cmax"])
+        elif trafo == "EMNoise":
+            return TEMNoise(Cmax=t_params["em_cmax"])
+        elif trafo == "BaselineShift":
+            return TBaselineShift(Cmax=t_params["bs_cmax"])
+        elif trafo == "GaussianBlur":
+            return TGaussianBlur1d()
+        elif trafo == "Normalize":
+            return TNormalize()
+        else:
+            raise Exception(str(trafo) + " is not a valid transformation")
+
+    # for numpy transformations
+    # trafo_list = [str_to_trafo(trafo)
+    #               for trafo in transformations] + [ToTensor()]
+
+    # for torch transformations
+    trafo_list = [ToTensor(transpose_data=False)] + [str_to_trafo(trafo)
+                                                     for trafo in transformations] + [Transpose()]
+    return trafo_list
+
+
+class SimCLRDataSetWrapper(object):
+
+    def __init__(self, batch_size, num_workers, valid_size, input_shape, s, data_folder, target_folders, target_fs, recreate_data_ptb_xl,
+                 mode="pretraining", transformations=None, t_params=None, ptb_xl_label="label_diag_superclass", filter_cinc=False,
+                 percentage=1.0, swav=False, nmb_crops=7, folds=8, test=False):
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.valid_size = valid_size
+        self.s = s
+        self.input_shape = eval(input_shape)
+        self.data_folder = Path(data_folder)
+        # Path(target_folder+str(target_fs))
+        self.target_folders = [Path(target_folder)
+                               for target_folder in target_folders]
+        self.target_fs = target_fs
+        self.recreate_data_ptb_xl = recreate_data_ptb_xl
+        self.val_ds_idmap = None
+        self.lbl_itos = None
+        self.transformations = transformations_from_strings(
+            transformations, t_params)
+        self.train_ds_size = 0
+        self.val_ds_size = 0
+        self.ptb_xl_label = ptb_xl_label
+        self.filter_cinc = filter_cinc
+        self.percentage = percentage
+        self.swav = swav
+        self.nmb_crops = nmb_crops
+        self.folds = folds
+        self.test = test
+        if mode in ["linear_evaluation", "pretraining"]:
+            self.mode = mode
+        else:
+            raise ValueError("unknown mode: " + str(mode))
+
+    def get_data_loaders(self):
+        data_augment = self._get_simclr_pipeline_transform()
+
+        # train_dataset = datasets.STL10('./data', split='train+unlabeled', download=True,
+        #                                transform=SimCLRDataTransform(data_augment))
+
+        if self.mode == "linear_evaluation":
+            # transformations = transforms.Compose([RandomResizedCrop(crop_ratio_range=[0.5, 1.0]),
+            #                                       ToTensor()])
+            # transformations = data_augment
+            # transformations = ToTensor()
+            train_ds, val_ds = self._get_datasets(
+                self.target_folders[0], transforms=data_augment)
+            self.val_ds_idmap = val_ds.get_id_mapping()
+        else:
+            wrapper_transform = SwAVDataTransform(data_augment, num_crops=self.nmb_crops) if self.swav else SimCLRDataTransform(data_augment)
+            datasets = [self._get_datasets(target_folder, transforms=wrapper_transform) for target_folder in self.target_folders]
+            train_datasets, valid_datasets = list(zip(*datasets))
+
+            train_ds = ConcatDataset(list(train_datasets))
+            val_ds = ConcatDataset(list(valid_datasets))
+
+        train_loader, valid_loader = self.get_train_validation_data_loaders(
+            train_ds, val_ds)
+
+        self.train_ds_size = len(train_ds)
+        self.val_ds_size = len(val_ds)
+        return train_loader, valid_loader
+
+    def _get_datasets(self, target_folder, transforms=None):
+        logger.info("get dataset from " + str(target_folder))
+        # Dataset parameters
+        input_channels = 12
+        target_fs = 100
+        # Training setting
+        input_size = 250  # originally 600
+        chunkify_train = False
+        chunkify_valid = self.mode != "pretraining"
+        chunk_length_train = input_size  # target_fs*6
+        chunk_length_valid = input_size
+        min_chunk_length = input_size  # chunk_length
+        stride_length_train = chunk_length_train//4  # chunk_length_train//8
+        stride_length_valid = input_size//2  # chunk_length_valid
+
+        copies_valid = 0  # >0 should only be used with chunkify_valid=False
+        if self.test:
+            valid_fold = 10
+            test_fold = 9
+        else:
+            valid_fold = 9
+            test_fold = 10
+
+        train_folds = list(range(1, 11))
+        train_folds.remove(test_fold)
+        train_folds.remove(valid_fold)
+        train_folds = np.array(train_folds)
+
+        df_memmap_filename = "df_memmap.pkl"
+        memmap_filename = "memmap.npy"
+
+        # df, lbl_itos, mean, std = prepare_data_ptb_xl(self.data_folder, min_cnt=50, target_fs=self.target_fs,
+        #                                               channels=input_channels, channel_stoi=channel_stoi_default, target_folder=self.target_folder, recreate_data=self.recreate_data_ptb_xl)
+        df_mapped, lbl_itos, mean, std = load_dataset(target_folder)
+
+        if self.recreate_data_ptb_xl:
+            # note: recreating the data relies on the df returned by the
+            # prepare_data_ptb_xl call above, which is currently commented out
+            df_mapped = reformat_as_memmap(
+                df, target_folder/(memmap_filename), data_folder=target_folder)
+        else:
+            # df_mapped = pd.read_pickle(
+            #     target_folder/(df_memmap_filename))
+            df_mapped = pickle.load(open(target_folder/(df_memmap_filename), "rb"))
+            # self.lbl_itos = np.array(lbl_itos[label])
+
+        self.lbl_itos = lbl_itos
+        self.num_classes = len(lbl_itos)
+        # print("num classes:", self.num_classes)
+
+        if "ptb" in str(target_folder):
+            label = self.ptb_xl_label  # only possible for ptb xl
+            self.lbl_itos = np.array(lbl_itos[label])
+            label = label + "_filtered_numeric"
+        else:
+            label = "label"
+            self.lbl_itos = lbl_itos
+
+        df_mapped["diag_label"] = df_mapped[label].copy()
+        if "ptb" in str(target_folder) or self.mode == "linear_evaluation":
+            logger.debug("get labels for linear evaluation on ptb")
+            df_mapped["label"] = df_mapped[label].apply(
+                lambda x: multihot_encode(x, len(self.lbl_itos)))
+        else:
+            logger.debug("insert artificial labels into non-ptb dataset")
+            df_mapped["label"] = df_mapped[label].apply(
+                lambda x: np.array([1, 0, 0, 0, 0]))
+
+        # logger.info("labels: " + str(self.lbl_itos))
+        # df_mapped["label"] = df_mapped["label"].apply(lambda x: onehot_encode(x, len(self.lbl_itos)))
+
+        if self.mode == "pretraining":
+            valid_fold = test_fold = 9
+            if self.percentage < 1.0:
+                logger.info("reduce dataset to {}%".format(self.percentage*100))
+                total_samples = len(df_mapped)
+                num_samples = int(self.percentage*total_samples)
+                sample_indices = np.sort(np.random.choice(np.arange(total_samples), size=num_samples, replace=False))
+                df_mapped = df_mapped.loc[sample_indices]
+
+            df_train = df_mapped[(df_mapped.strat_fold != test_fold) & (
+                df_mapped.strat_fold != valid_fold) & (df_mapped.label.apply(lambda x: np.sum(x) > 0))]
+        else:
+            assert self.folds < 9
+            df_train = df_mapped[(df_mapped.strat_fold.apply(lambda x: x in train_folds[range(self.folds)]) & (df_mapped.label.apply(lambda x: np.sum(x) > 0)))]
+
+        df_valid = df_mapped[(df_mapped.strat_fold == valid_fold) & (
+            df_mapped.label.apply(lambda x: np.sum(x) > 0))]
+        df_test = df_mapped[(df_mapped.strat_fold == test_fold) & (
+            df_mapped.label.apply(lambda x: np.sum(x) > 0))]
+
+        if self.filter_cinc and "cinc" in str(target_folder):
+            df_train = filter_out_datasets(df_train)
+            df_valid = filter_out_datasets(df_valid)
+            df_test = filter_out_datasets(df_test)
+
+        train_ds = TimeseriesDatasetCrops(df_train, input_size, num_classes=len(self.lbl_itos), data_folder=target_folder, chunk_length=chunk_length_train if chunkify_train else 0,
+                                          min_chunk_length=min_chunk_length, stride=stride_length_train, transforms=transforms, annotation=False, col_lbl="label", memmap_filename=target_folder/(memmap_filename))
+        val_ds = TimeseriesDatasetCrops(df_valid, input_size, num_classes=len(self.lbl_itos), data_folder=target_folder, chunk_length=chunk_length_valid if chunkify_valid else 0,
+                                        min_chunk_length=min_chunk_length, stride=stride_length_valid, transforms=transforms, annotation=False, col_lbl="label", memmap_filename=target_folder/(memmap_filename))
+        self.df_train = df_train
+        self.df_valid = df_valid
+        self.df_test = df_test
+        return train_ds, val_ds
+
+    def _get_simclr_pipeline_transform(self):
+        # get a set of data augmentation transformations as described in the SimCLR paper.
+        # the transformations are defined in timeseries_transformations.py (imported above)
+        # data_transforms = transforms.Compose([RandomResizedCrop(crop_ratio_range=[0.5, 1.0]),
+        #                                       ChannelResize(magnitude_range=[0.33, 3]),
+        #                                       DynamicTimeWarp(),
+        #                                       ToTensor()])
+        # data_transforms = [RandomResizedCrop(), ChannelResize(), ToTensor()]
+        data_transforms = transforms.Compose(self.transformations)
+        return data_transforms
+
+    def get_train_validation_data_loaders(self, train_ds, val_ds):
+
+        train_loader = DataLoader(train_ds, batch_size=self.batch_size,
+                                  num_workers=self.num_workers, pin_memory=True, shuffle=True, drop_last=True)
+        val_loader = DataLoader(val_ds, batch_size=self.batch_size,
+                                shuffle=False, num_workers=self.num_workers, pin_memory=True)
+
+        return train_loader, val_loader
+
+
+class SimCLRDataTransform(object):
+    def __init__(self, transform):
+        # fall back to the identity if no transform is given
+        if transform is None:
+            self.transform = lambda x: x
+        else:
+            self.transform = transform
+
+    def __call__(self, sample):
+        xi = self.transform(sample)
+        xj = self.transform(sample)
+        return xi, xj
+
+
+class SwAVDataTransform(object):
+    def __init__(self, transform, num_crops=7):
+        # fall back to the identity if no transform is given
+        if transform is None:
+            self.transform = lambda x: x
+        else:
+            self.transform = transform
+        self.num_crops = num_crops
+
+    def __call__(self, sample):
+        transformed = []
+        for _ in range(self.num_crops):
+            transformed.append(self.transform(sample)[0])
+        return transformed, sample[1]
+
+
+def multihot_encode(x, num_classes):
+    res = np.zeros(num_classes, dtype=np.float32)
+    res[x] = 1
+    return res
+
+
+def filter_out_datasets(df, negative_datasets={"PTB", "PTB-XL"}):
+    datasets = set(df["dataset"])
+    positive_datasets = [
+        dataset for dataset in datasets if dataset not in negative_datasets]
+    positive_df_ids = [row in positive_datasets for row in df["dataset"]]
+    filtered_df = df.loc[positive_df_ids]
+    return filtered_df
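Usage note: below is a minimal sketch of how the wrapper introduced by this patch could be driven for pretraining. The constructor arguments, transformation names, and t_params keys are taken from the code above; the folder paths, batch size, and concrete augmentation parameter values are illustrative placeholders only and assume a preprocessed dataset folder containing memmap.npy and df_memmap.pkl.

from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper

# illustrative augmentation parameters; keys must match what str_to_trafo expects
t_params = {"rr_crop_ratio_range": [0.5, 1.0], "output_size": 250,
            "to_crop_ratio_range": [0.0, 0.5], "gaussian_scale": 0.005}

# paths and hyperparameters below are assumptions, not values prescribed by this patch
wrapper = SimCLRDataSetWrapper(
    batch_size=128, num_workers=4, valid_size=0.1, input_shape="(12, 250)", s=1,
    data_folder="./data/ptb_xl", target_folders=["./data/ptb_xl_fs100"], target_fs=100,
    recreate_data_ptb_xl=False, mode="pretraining",
    transformations=["RandomResizedCrop", "TimeOut", "GaussianNoise"], t_params=t_params)

# each training batch yields two augmented views per sample (SimCLRDataTransform)
train_loader, valid_loader = wrapper.get_data_loaders()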