--- a +++ b/clinical_ts/timeseries_utils.py @@ -0,0 +1,828 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/A_timeseries_utils.ipynb (unless otherwise specified). + +__all__ = ['butter_filter', 'butter_filter_frequency_response', 'apply_butter_filter', 'save_dataset', 'load_dataset', + 'dataset_add_chunk_col', 'dataset_add_length_col', 'dataset_add_labels_col', 'dataset_add_mean_col', + 'dataset_add_median_col', 'dataset_add_std_col', 'dataset_add_iqr_col', 'dataset_get_stats', + 'npys_to_memmap_batched', 'npys_to_memmap', 'reformat_as_memmap', 'TimeseriesDatasetCrops', 'RandomCrop', + 'CenterCrop', 'GaussianNoise', 'Rescale', 'ToTensor', 'Normalize', 'NormalizeBatch', 'ButterFilter', + 'ChannelFilter', 'Transform', 'TupleTransform', 'aggregate_predictions'] + +# Cell +import numpy as np +import pandas as pd +import torch +import torch.utils.data +from torch import nn +from pathlib import Path +from scipy.stats import iqr + +try: + import pickle5 as pickle +except ImportError as e: + import pickle + +#Note: due to issues with the numpy rng for multiprocessing (https://github.com/pytorch/pytorch/issues/5059) that could be fixed by a custom worker_init_fn we use random throught for convenience +import random + +#Note: multiprocessing issues with python lists and dicts (https://github.com/pytorch/pytorch/issues/13246) and pandas dfs (https://github.com/pytorch/pytorch/issues/5902) +import multiprocessing as mp + +from skimage import transform + +import warnings +warnings.filterwarnings("ignore", category=UserWarning) + +from scipy.signal import butter, sosfilt, sosfiltfilt, sosfreqz + +from tqdm.auto import tqdm + +# Cell +#https://stackoverflow.com/questions/12093594/how-to-implement-band-pass-butterworth-filter-with-scipy-signal-butter +def butter_filter(lowcut=10, highcut=20, fs=50, order=5, btype='band'): + '''returns butterworth filter with given specifications''' + nyq = 0.5 * fs + low = lowcut / nyq + high = highcut / nyq + + sos = butter(order, [low, high] if btype=="band" else (low if btype=="low" else high), analog=False, btype=btype, output='sos') + return sos + +def butter_filter_frequency_response(filter): + '''returns frequency response of a given filter (result of call of butter_filter)''' + w, h = sosfreqz(filter) + #gain vs. freq(Hz) + #plt.plot((fs * 0.5 / np.pi) * w, abs(h)) + return w,h + +def apply_butter_filter(data, filter, forwardbackward=True): + '''pass filter from call of butter_filter to data (assuming time axis at dimension 0)''' + if(forwardbackward): + return sosfiltfilt(filter, data, axis=0) + else: + data = sosfilt(filter, data, axis=0) + +# Cell +def save_dataset(df,lbl_itos,mean,std,target_root,filename_postfix="",protocol=4): + target_root = Path(target_root) + df.to_pickle(target_root/("df"+filename_postfix+".pkl"), protocol=protocol) + + if(isinstance(lbl_itos,dict)):#dict as pickle + outfile = open(target_root/("lbl_itos"+filename_postfix+".pkl"), "wb") + pickle.dump(lbl_itos, outfile, protocol=protocol) + outfile.close() + else:#array + np.save(target_root/("lbl_itos"+filename_postfix+".npy"),lbl_itos) + + np.save(target_root/("mean"+filename_postfix+".npy"),mean) + np.save(target_root/("std"+filename_postfix+".npy"),std) + +def load_dataset(target_root,filename_postfix="",df_mapped=True): + target_root = Path(target_root) + # if(df_mapped): + # df = pd.read_pickle(target_root/("df_memmap"+filename_postfix+".pkl")) + # else: + # df = pd.read_pickle(target_root/("df"+filename_postfix+".pkl") + + ### due to pickle 5 protocol error + + if(df_mapped): + df = pickle.load(open(target_root/("df_memmap"+filename_postfix+".pkl"), "rb")) + else: + df = pickle.load(open(target_root/("df"+filename_postfix+".pkl"), "rb")) + + + if((target_root/("lbl_itos"+filename_postfix+".pkl")).exists()):#dict as pickle + infile = open(target_root/("lbl_itos"+filename_postfix+".pkl"), "rb") + lbl_itos=pickle.load(infile) + infile.close() + else:#array + lbl_itos = np.load(target_root/("lbl_itos"+filename_postfix+".npy")) + + + mean = np.load(target_root/("mean"+filename_postfix+".npy")) + std = np.load(target_root/("std"+filename_postfix+".npy")) + return df, lbl_itos, mean, std + +# Cell +def dataset_add_chunk_col(df, col="data"): + '''add a chunk column to the dataset df''' + df["chunk"]=df.groupby(col).cumcount() + +def dataset_add_length_col(df, col="data", data_folder=None): + '''add a length column to the dataset df''' + df[col+"_length"]=df[col].apply(lambda x: len(np.load(x if data_folder is None else data_folder/x, allow_pickle=True))) + +def dataset_add_labels_col(df, col="label", data_folder=None): + '''add a column with unique labels in column col''' + df[col+"_labels"]=df[col].apply(lambda x: list(np.unique(np.load(x if data_folder is None else data_folder/x, allow_pickle=True)))) + +def dataset_add_mean_col(df, col="data", axis=(0), data_folder=None): + '''adds a column with mean''' + df[col+"_mean"]=df[col].apply(lambda x: np.mean(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis)) + +def dataset_add_median_col(df, col="data", axis=(0), data_folder=None): + '''adds a column with median''' + df[col+"_median"]=df[col].apply(lambda x: np.median(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis)) + +def dataset_add_std_col(df, col="data", axis=(0), data_folder=None): + '''adds a column with mean''' + df[col+"_std"]=df[col].apply(lambda x: np.std(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis)) + +def dataset_add_iqr_col(df, col="data", axis=(0), data_folder=None): + '''adds a column with mean''' + df[col+"_iqr"]=df[col].apply(lambda x: iqr(np.load(x if data_folder is None else data_folder/x, allow_pickle=True),axis=axis)) + +def dataset_get_stats(df, col="data", simple=True): + '''creates (weighted) means and stds from mean, std and length cols of the df''' + if(simple): + return df[col+"_mean"].mean(), df[col+"_std"].mean() + else: + #https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html + #or https://gist.github.com/thomasbrandon/ad5b1218fc573c10ea4e1f0c63658469 + def combine_two_means_vars(x1,x2): + (mean1,var1,n1) = x1 + (mean2,var2,n2) = x2 + mean = mean1*n1/(n1+n2)+ mean2*n2/(n1+n2) + var = var1*n1/(n1+n2)+ var2*n2/(n1+n2)+n1*n2/(n1+n2)/(n1+n2)*np.power(mean1-mean2,2) + return (mean, var, (n1+n2)) + + def combine_all_means_vars(means,vars,lengths): + inputs = list(zip(means,vars,lengths)) + result = inputs[0] + + for inputs2 in inputs[1:]: + result= combine_two_means_vars(result,inputs2) + return result + + means = list(df[col+"_mean"]) + vars = np.power(list(df[col+"_std"]),2) + lengths = list(df[col+"_length"]) + mean,var,length = combine_all_means_vars(means,vars,lengths) + return mean, np.sqrt(var) + +# Cell +def npys_to_memmap_batched(npys, target_filename, max_len=0, delete_npys=True, batch_length=900000): + memmap = None + start = np.array([0])#start_idx in current memmap file (always already the next start- delete last token in the end) + length = []#length of segment + filenames= []#memmap files + file_idx=[]#corresponding memmap file for sample + shape=[]#shapes of all memmap files + + data = [] + data_lengths=[] + dtype = None + + for idx,npy in tqdm(list(enumerate(npys))): + + data.append(np.load(npy, allow_pickle=True)) + data_lengths.append(len(data[-1])) + + if(idx==len(npys)-1 or np.sum(data_lengths)>batch_length):#flush + data = np.concatenate(data) + if(memmap is None or (max_len>0 and start[-1]>max_len)):#new memmap file has to be created + if(max_len>0): + filenames.append(target_filename.parent/(target_filename.stem+"_"+str(len(filenames))+".npy")) + else: + filenames.append(target_filename) + + shape.append([np.sum(data_lengths)]+[l for l in data.shape[1:]])#insert present shape + + if(memmap is not None):#an existing memmap exceeded max_len + del memmap + #create new memmap + start[-1] = 0 + start = np.concatenate([start,np.cumsum(data_lengths)]) + length = np.concatenate([length,data_lengths]) + + memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='w+', shape=data.shape) + else: + #append to existing memmap + start = np.concatenate([start,start[-1]+np.cumsum(data_lengths)]) + length = np.concatenate([length,data_lengths]) + shape[-1] = [start[-1]]+[l for l in data.shape[1:]] + memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='r+', shape=tuple(shape[-1])) + + #store mapping memmap_id to memmap_file_id + file_idx=np.concatenate([file_idx,[(len(filenames)-1)]*len(data_lengths)]) + #insert the actual data + memmap[start[-len(data_lengths)-1]:start[-len(data_lengths)-1]+len(data)]=data[:] + memmap.flush() + dtype = data.dtype + data = []#reset data storage + data_lengths = [] + + start= start[:-1]#remove the last element + #cleanup + for npy in npys: + if(delete_npys is True): + npy.unlink() + del memmap + + #convert everything to relative paths + filenames= [f.name for f in filenames] + #save metadata + np.savez(target_filename.parent/(target_filename.stem+"_meta.npz"),start=start,length=length,shape=shape,file_idx=file_idx,dtype=dtype,filenames=filenames) + + +def npys_to_memmap(npys, target_filename, max_len=0, delete_npys=True): + memmap = None + start = []#start_idx in current memmap file + length = []#length of segment + filenames= []#memmap files + file_idx=[]#corresponding memmap file for sample + shape=[] + + for idx,npy in tqdm(list(enumerate(npys))): + data = np.load(npy, allow_pickle=True) + if(memmap is None or (max_len>0 and start[-1]+length[-1]>max_len)): + if(max_len>0): + filenames.append(target_filename.parent/(target_filename.stem+"_"+str(len(filenames)+".npy"))) + else: + filenames.append(target_filename) + + if(memmap is not None):#an existing memmap exceeded max_len + shape.append([start[-1]+length[-1]]+[l for l in data.shape[1:]]) + del memmap + #create new memmap + start.append(0) + length.append(data.shape[0]) + memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='w+', shape=data.shape) + else: + #append to existing memmap + start.append(start[-1]+length[-1]) + length.append(data.shape[0]) + memmap = np.memmap(filenames[-1], dtype=data.dtype, mode='r+', shape=tuple([start[-1]+length[-1]]+[l for l in data.shape[1:]])) + + #store mapping memmap_id to memmap_file_id + file_idx.append(len(filenames)-1) + #insert the actual data + memmap[start[-1]:start[-1]+length[-1]]=data[:] + memmap.flush() + if(delete_npys is True): + npy.unlink() + del memmap + + #append final shape if necessary + if(len(shape)<len(filenames)): + shape.append([start[-1]+length[-1]]+[l for l in data.shape[1:]]) + #convert everything to relative paths + filenames= [f.name for f in filenames] + #save metadata + np.savez(target_filename.parent/(target_filename.stem+"_meta.npz"),start=start,length=length,shape=shape,file_idx=file_idx,dtype=data.dtype,filenames=filenames) + +def reformat_as_memmap(df, target_filename, data_folder=None, annotation=False, max_len=0, delete_npys=True,col_data="data",col_label="label", batch_length=0): + npys_data = [] + npys_label = [] + + for id,row in df.iterrows(): + npys_data.append(data_folder/row[col_data] if data_folder is not None else row[col_data]) + if(annotation): + npys_label.append(data_folder/row[col_label] if data_folder is not None else row[col_label]) + if(batch_length==0): + npys_to_memmap(npys_data, target_filename, max_len=max_len, delete_npys=delete_npys) + else: + npys_to_memmap_batched(npys_data, target_filename, max_len=max_len, delete_npys=delete_npys,batch_length=batch_length) + if(annotation): + if(batch_length==0): + npys_to_memmap(npys_label, target_filename.parent/(target_filename.stem+"_label.npy"), max_len=max_len, delete_npys=delete_npys) + else: + npys_to_memmap_batched(npys_label, target_filename.parent/(target_filename.stem+"_label.npy"), max_len=max_len, delete_npys=delete_npys, batch_length=batch_length) + + #replace data(filename) by integer + df_mapped = df.copy() + df_mapped["data_original"]=df_mapped.data + df_mapped["data"]=np.arange(len(df_mapped)) + + df_mapped.to_pickle(target_filename.parent/("df_"+target_filename.stem+".pkl")) + return df_mapped + +# Cell +class TimeseriesDatasetCrops(torch.utils.data.Dataset): + """timeseries dataset with partial crops.""" + + def __init__(self, df, output_size, chunk_length, min_chunk_length, memmap_filename=None, npy_data=None, random_crop=True, data_folder=None, num_classes=2, copies=0, col_lbl="label", stride=None, start_idx=0, annotation=False, transforms=None, sample_items_per_record=1): + """ + accepts three kinds of input: + 1) filenames pointing to aligned numpy arrays [timesteps,channels,...] for data and either integer labels or filename pointing to numpy arrays[timesteps,...] e.g. for annotations + 2) memmap_filename to memmap file (same argument that was passed to reformat_as_memmap) for data [concatenated,...] and labels- label column in df corresponds to index in this memmap + 3) npy_data [samples,ts,...] (either path or np.array directly- also supporting variable length input) - label column in df corresponds to sampleid + + transforms: list of callables (transformations) or (preferred) single instance e.g. from torchvision.transforms.Compose (applied in the specified order i.e. leftmost element first) + + col_lbl = None: return dummy label 0 (e.g. for unsupervised pretraining) + """ + assert not((memmap_filename is not None) and (npy_data is not None)) + # require integer entries if using memmap or npy + assert (memmap_filename is None and npy_data is None) or df.data.dtype==np.int64 + + self.timeseries_df_data = np.array(df["data"]) + if(self.timeseries_df_data.dtype not in [np.int16, np.int32, np.int64]): + assert(memmap_filename is None and npy_data is None) #only for filenames in mode files + self.timeseries_df_data = np.array(df["data"].astype(str)).astype(np.string_) + + if(col_lbl is None):# use dummy labels + self.timeseries_df_label = np.zeros(len(df)) + else: # use actual labels + if(isinstance(df[col_lbl].iloc[0],list) or isinstance(df[col_lbl].iloc[0],np.ndarray)):#stack arrays/lists for proper batching + self.timeseries_df_label = np.stack(df[col_lbl]) + else: # single integers/floats + self.timeseries_df_label = np.array(df[col_lbl]) + + if(self.timeseries_df_label.dtype not in [np.int16, np.int32, np.int64, np.float32, np.float64]): #everything else cannot be batched anyway mp.Manager().list(self.timeseries_df_label) + assert(annotation and memmap_filename is None and npy_data is None)#only for filenames in mode files + self.timeseries_df_label = np.array(df[col_lbl].apply(lambda x:str(x))).astype(np.string_) + + self.output_size = output_size + self.data_folder = data_folder + self.transforms = transforms + if(isinstance(self.transforms,list) or isinstance(self.transforms,np.ndarray)): + print("Warning: the use of list as arguments for transforms is dicouraged") + self.annotation = annotation + self.col_lbl = col_lbl + + self.c = num_classes + + self.mode="files" + + if(memmap_filename is not None): + self.memmap_meta_filename = memmap_filename.parent/(memmap_filename.stem+"_meta.npz") + self.mode="memmap" + memmap_meta = np.load(self.memmap_meta_filename, allow_pickle=True) + self.memmap_start = memmap_meta["start"] + self.memmap_shape = memmap_meta["shape"] + self.memmap_length = memmap_meta["length"] + self.memmap_file_idx = memmap_meta["file_idx"] + self.memmap_dtype = np.dtype(str(memmap_meta["dtype"])) + self.memmap_filenames = np.array(memmap_meta["filenames"]).astype(np.string_)#save as byte to avoid issue with mp + if(annotation): + memmap_meta_label = np.load(self.memmap_meta_filename.parent/("_".join(self.memmap_meta_filename.stem.split("_")[:-1])+"_label_meta.npz"), allow_pickle=True) + self.memmap_shape_label = memmap_meta_label["shape"] + self.memmap_filenames_label = np.array(memmap_meta_label["filenames"]).astype(np.string_) + self.memmap_dtype_label = np.dtype(str(memmap_meta_label["dtype"])) + elif(npy_data is not None): + self.mode="npy" + if(isinstance(npy_data,np.ndarray) or isinstance(npy_data,list)): + self.npy_data = np.array(npy_data) + assert(annotation is False) + else: + self.npy_data = np.load(npy_data, allow_pickle=True) + if(annotation): + self.npy_data_label = np.load(npy_data.parent/(npy_data.stem+"_label.npy"), allow_pickle=True) + + self.random_crop = random_crop + self.sample_items_per_record = sample_items_per_record + + self.df_idx_mapping=[] + self.start_idx_mapping=[] + self.end_idx_mapping=[] + + for df_idx,(id,row) in enumerate(df.iterrows()): + if(self.mode=="files"): + data_length = row["data_length"] + elif(self.mode=="memmap"): + data_length= self.memmap_length[row["data"]] + else: #npy + data_length = len(self.npy_data[row["data"]]) + + if(chunk_length == 0):#do not split + idx_start = [start_idx] + idx_end = [data_length] + else: + idx_start = list(range(start_idx,data_length,chunk_length if stride is None else stride)) + idx_end = [min(l+chunk_length, data_length) for l in idx_start] + + #remove final chunk(s) if too short + for i in range(len(idx_start)): + if(idx_end[i]-idx_start[i]< min_chunk_length): + del idx_start[i:] + del idx_end[i:] + break + #append to lists + for _ in range(copies+1): + for i_s,i_e in zip(idx_start,idx_end): + self.df_idx_mapping.append(df_idx) + self.start_idx_mapping.append(i_s) + self.end_idx_mapping.append(i_e) + #convert to np.array to avoid mp issues with python lists + self.df_idx_mapping = np.array(self.df_idx_mapping) + self.start_idx_mapping = np.array(self.start_idx_mapping) + self.end_idx_mapping = np.array(self.end_idx_mapping) + + def __len__(self): + return len(self.df_idx_mapping) + + @property + def is_empty(self): + return len(self.df_idx_mapping)==0 + + def __getitem__(self, idx): + lst=[] + for _ in range(self.sample_items_per_record): + #determine crop idxs + timesteps= self.get_sample_length(idx) + + if(self.random_crop):#random crop + if(timesteps==self.output_size): + start_idx_rel = 0 + else: + start_idx_rel = random.randint(0, timesteps - self.output_size -1)#np.random.randint(0, timesteps - self.output_size) + else: + start_idx_rel = (timesteps - self.output_size)//2 + if(self.sample_items_per_record==1): + return self._getitem(idx,start_idx_rel) + else: + lst.append(self._getitem(idx,start_idx_rel)) + return tuple(lst) + + def _getitem(self, idx,start_idx_rel): + #low-level function that actually fetches the data + df_idx = self.df_idx_mapping[idx] + start_idx = self.start_idx_mapping[idx] + end_idx = self.end_idx_mapping[idx] + #determine crop idxs + timesteps= end_idx - start_idx + assert(timesteps>=self.output_size) + start_idx_crop = start_idx + start_idx_rel + end_idx_crop = start_idx_crop+self.output_size + + #print(idx,start_idx,end_idx,start_idx_crop,end_idx_crop) + #load the actual data + if(self.mode=="files"):#from separate files + data_filename = str(self.timeseries_df_data[df_idx],encoding='utf-8') #todo: fix potential issues here + if self.data_folder is not None: + data_filename = self.data_folder/data_filename + data = np.load(data_filename, allow_pickle=True)[start_idx_crop:end_idx_crop] #data type has to be adjusted when saving to npy + + ID = data_filename.stem + + if(self.annotation is True): + label_filename = str(self.timeseries_df_label[df_idx],encoding='utf-8') + if self.data_folder is not None: + label_filename = self.data_folder/label_filename + label = np.load(label_filename, allow_pickle=True)[start_idx_crop:end_idx_crop] #data type has to be adjusted when saving to npy + else: + label = self.timeseries_df_label[df_idx] #input type has to be adjusted in the dataframe + elif(self.mode=="memmap"): #from one memmap file + memmap_idx = self.timeseries_df_data[df_idx] #grab the actual index (Note the df to create the ds might be a subset of the original df used to create the memmap) + memmap_file_idx = self.memmap_file_idx[memmap_idx] + idx_offset = self.memmap_start[memmap_idx] + + #wi = torch.utils.data.get_worker_info() + #pid = 0 if wi is None else wi.id#os.getpid() + #print("idx",idx,"ID",ID,"idx_offset",idx_offset,"start_idx_crop",start_idx_crop,"df_idx", self.df_idx_mapping[idx],"pid",pid) + mem_filename = str(self.memmap_filenames[memmap_file_idx],encoding='utf-8') + mem_file = np.memmap(self.memmap_meta_filename.parent/mem_filename, self.memmap_dtype, mode='r', shape=tuple(self.memmap_shape[memmap_file_idx])) + data = np.copy(mem_file[idx_offset + start_idx_crop: idx_offset + end_idx_crop]) + del mem_file + #print(mem_file[idx_offset + start_idx_crop: idx_offset + end_idx_crop]) + if(self.annotation): + mem_filename_label = str(self.memmap_filenames_label[memmap_file_idx],encoding='utf-8') + mem_file_label = np.memmap(self.memmap_meta_filename.parent/mem_filename_label, self.memmap_dtype_label, mode='r', shape=tuple(self.memmap_shape_label[memmap_file_idx])) + + label = np.copy(mem_file_label[idx_offset + start_idx_crop: idx_offset + end_idx_crop]) + del mem_file_label + else: + label = self.timeseries_df_label[df_idx] + else:#single npy array + ID = self.timeseries_df_data[df_idx] + + data = self.npy_data[ID][start_idx_crop:end_idx_crop] + + if(self.annotation): + label = self.npy_data_label[ID][start_idx_crop:end_idx_crop] + else: + label = self.timeseries_df_label[df_idx] + + sample = (data,label) + + if(isinstance(self.transforms,list)):#transforms passed as list + for t in self.transforms: + sample = t(sample) + elif(self.transforms is not None):#single transform e.g. from torchvision.transforms.Compose + sample = self.transforms(sample) + + return sample + + def get_sampling_weights(self, class_weight_dict,length_weighting=False, timeseries_df_group_by_col=None): + ''' + class_weight_dict: dictionary of class weights + length_weighting: weigh samples by length + timeseries_df_group_by_col: column of the pandas df used to create the object''' + assert(self.annotation is False) + assert(length_weighting is False or timeseries_df_group_by_col is None) + weights = np.zeros(len(self.df_idx_mapping),dtype=np.float32) + length_per_class = {} + length_per_group = {} + for iw,(i,s,e) in enumerate(zip(self.df_idx_mapping,self.start_idx_mapping,self.end_idx_mapping)): + label = self.timeseries_df_label[i] + weight = class_weight_dict[label] + if(length_weighting): + if label in length_per_class.keys(): + length_per_class[label] += e-s + else: + length_per_class[label] = e-s + if(timeseries_df_group_by_col is not None): + group = timeseries_df_group_by_col[i] + if group in length_per_group.keys(): + length_per_group[group] += e-s + else: + length_per_group[group] = e-s + weights[iw] = weight + + if(length_weighting):#need second pass to properly take into account the total length per class + for iw,(i,s,e) in enumerate(zip(self.df_idx_mapping,self.start_idx_mapping,self.end_idx_mapping)): + label = self.timeseries_df_label[i] + weights[iw]= (e-s)/length_per_class[label]*weights[iw] + if(timeseries_df_group_by_col is not None): + for iw,(i,s,e) in enumerate(zip(self.df_idx_mapping,self.start_idx_mapping,self.end_idx_mapping)): + group = timeseries_df_group_by_col[i] + weights[iw]= (e-s)/length_per_group[group]*weights[iw] + + weights = weights/np.min(weights)#normalize smallest weight to 1 + return weights + + def get_id_mapping(self): + return self.df_idx_mapping + + def get_sample_id(self,idx): + return self.df_idx_mapping[idx] + + def get_sample_length(self,idx): + return self.end_idx_mapping[idx]-self.start_idx_mapping[idx] + + def get_sample_start(self,idx): + return self.start_idx_mapping[idx] + +# Cell +class RandomCrop(object): + """Crop randomly the image in a sample. + """ + + def __init__(self, output_size,annotation=False): + self.output_size = output_size + self.annotation = annotation + + def __call__(self, sample): + data, label = sample + + timesteps= len(data) + assert(timesteps>=self.output_size) + if(timesteps==self.output_size): + start=0 + else: + start = random.randint(0, timesteps - self.output_size-1) #np.random.randint(0, timesteps - self.output_size) + + data = data[start: start + self.output_size] + if(self.annotation): + label = label[start: start + self.output_size] + + return data, label + +# Cell +class CenterCrop(object): + """Center crop the image in a sample. + """ + + def __init__(self, output_size, annotation=False): + self.output_size = output_size + self.annotation = annotation + + def __call__(self, sample): + data, label = sample + + timesteps= len(data) + start = (timesteps - self.output_size)//2 + + data = data[start: start + self.output_size] + if(self.annotation): + label = label[start: start + self.output_size] + + return data, label + +# Cell +class GaussianNoise(object): + """Add gaussian noise to sample. + """ + + def __init__(self, scale=0.1): + self.scale = scale + + def __call__(self, sample): + if self.scale ==0: + return sample + else: + data, label = sample + data = data + np.reshape(np.array([random.gauss(0,self.scale) for _ in range(np.prod(data.shape))]),data.shape)#np.random.normal(scale=self.scale,size=data.shape).astype(np.float32) + return data, label + +# Cell +class Rescale(object): + """Rescale by factor. + """ + + def __init__(self, scale=0.5,interpolation_order=3): + self.scale = scale + self.interpolation_order = interpolation_order + + def __call__(self, sample): + if self.scale ==1: + return sample + else: + data, label = sample + timesteps_new = int(self.scale * len(data)) + data = transform.resize(data,(timesteps_new,data.shape[1]),order=interpolation_order).astype(np.float32) + return data,label + +# Cell +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + def __init__(self, transpose_data=True, transpose_label=False): + #swap channel and time axis for direct application of pytorch's convs + self.transpose_data=transpose_data + self.transpose_label=transpose_label + + def __call__(self, sample): + + def _to_tensor(data,transpose=False): + if(isinstance(data,np.ndarray)): + if(transpose):#seq,[x,y,]ch + return torch.from_numpy(np.moveaxis(data,-1,0)) + else: + return torch.from_numpy(data) + else:#default_collate will take care of it + return data + + data, label = sample + + if not isinstance(data,tuple): + data = _to_tensor(data,self.transpose_data) + else: + data = tuple(_to_tensor(x,self.transpose_data) for x in data) + + if not isinstance(label,tuple): + label = _to_tensor(label,self.transpose_label) + else: + label = tuple(_to_tensor(x,self.transpose_label) for x in label) + + return data,label #returning as a tuple (potentially of lists) + +# Cell +class Normalize(object): + """Normalize using given stats. + """ + def __init__(self, stats_mean, stats_std, input=True, channels=[]): + self.stats_mean=stats_mean.astype(np.float32) if stats_mean is not None else None + self.stats_std=stats_std.astype(np.float32)+1e-8 if stats_std is not None else None + self.input = input + if(len(channels)>0): + for i in range(len(stats_mean)): + if(not(i in channels)): + self.stats_mean[:,i]=0 + self.stats_std[:,i]=1 + + def __call__(self, sample): + datax, labelx = sample + data = datax if self.input else labelx + #assuming channel last + if(self.stats_mean is not None): + data = data - self.stats_mean + if(self.stats_std is not None): + data = data/self.stats_std + + if(self.input): + return (data, labelx) + else: + return (datax, data) + +# Cell +class NormalizeBatch(object): + """Normalize using batch statistics. + axis: tuple of integers of axis numbers to be normalized over (by default everything but the last) + """ + def __init__(self, input=True, channels=[],axis=None): + self.channels = channels + self.channels_keep = None + self.input = input + self.axis = axis + + def __call__(self, sample): + datax, labelx = sample + data = datax if self.input else labelx + #assuming channel last + #batch_mean = np.mean(data,axis=tuple(range(0,len(data)-1))) + #batch_std = np.std(data,axis=tuple(range(0,len(data)-1)))+1e-8 + batch_mean = np.mean(data,axis=self.axis if self.axis is not None else tuple(range(0,len(data.shape)-1))) + batch_std = np.std(data,axis=self.axis if self.axis is not None else tuple(range(0,len(data.shape)-1)))+1e-8 + + if(len(self.channels)>0): + if(self.channels_keep is None): + self.channels_keep = np.setdiff(range(data.shape[-1]),self.channels) + + batch_mean[self.channels_keep]=0 + batch_std[self.channels_keep]=1 + + data = (data - batch_mean)/batch_std + + if(self.input): + return (data, labelx) + else: + return (datax, data) + +# Cell +class ButterFilter(object): + """Apply filter + """ + + def __init__(self, lowcut=50, highcut=50, fs=100, order=5, btype='band', forwardbackward=True, input=True): + self.filter = butter_filter(lowcut,highcut,fs,order,btype) + self.input = input + self.forwardbackward = forwardbackward + + def __call__(self, sample): + datax, labelx = sample + data = datax if self.input else labelx + + if(self.forwardbackward): + data = sosfiltfilt(self.filter, data, axis=0) + else: + data = sosfilt(self.filter, data, axis=0) + + if(self.input): + return (data, labelx) + else: + return (datax, data) + +# Cell +class ChannelFilter(object): + """Select certain channels. + """ + + def __init__(self, channels=[0], input=True): + self.channels = channels + self.input = input + + def __call__(self, sample): + data,label = sample + if(self.input): + return (data[...,self.channels], label) + else: + return (data, label[...,self.channels]) + +# Cell +class Transform(object): + """Transforms data using a given function i.e. data_new = func(data) for input is True else label_new = func(label) + """ + + def __init__(self, func, input=False): + self.func = func + self.input = input + + def __call__(self, sample): + data,label = sample + if(self.input): + return (self.func(data), label) + else: + return (data, self.func(label)) + +# Cell +class TupleTransform(object): + """Transforms data using a given function (operating on both data and label and return a tuple) i.e. data_new, label_new = func(data_old, label_old) + """ + + def __init__(self, func, input=False): + self.func = func + + def __call__(self, sample): + data,label = sample + return self.func(data,label) + +# Cell +def aggregate_predictions(preds,targs=None,idmap=None,aggregate_fn = np.mean,verbose=False): + ''' + aggregates potentially multiple predictions per sample (can also pass targs for convenience) + idmap: idmap as returned by TimeSeriesCropsDataset's get_id_mapping + preds: ordered predictions as returned by learn.get_preds() + aggregate_fn: function that is used to aggregate multiple predictions per sample (most commonly np.amax or np.mean) + ''' + if(idmap is not None and len(idmap)!=len(np.unique(idmap))): + if(verbose): + print("aggregating predictions...") + preds_aggregated = [] + targs_aggregated = [] + for i in np.unique(idmap): + preds_local = preds[np.where(idmap==i)[0]] + preds_aggregated.append(aggregate_fn(preds_local,axis=0)) + if targs is not None: + targs_local = targs[np.where(idmap==i)[0]] + assert(np.all(targs_local==targs_local[0])) #all labels have to agree + targs_aggregated.append(targs_local[0]) + if(targs is None): + return np.array(preds_aggregated) + else: + return np.array(preds_aggregated),np.array(targs_aggregated) + else: + if(targs is None): + return preds + else: + return preds,targs