--- a
+++ b/datasets/__init__.py
@@ -0,0 +1,224 @@
+"""
+This package handles data loading and data preprocessing.
+"""
+import os
+import torch
+import importlib
+import numpy as np
+import pandas as pd
+from util import util
+from datasets.basic_dataset import BasicDataset
+from datasets.dataloader_prefetch import DataLoaderPrefetch
+from torch.utils.data import Subset
+from sklearn.model_selection import train_test_split
+
+
+def find_dataset_using_name(dataset_mode):
+    """
+    Return the dataset class that matches the given dataset mode.
+    """
+    dataset_filename = "datasets." + dataset_mode + "_dataset"
+    datasetlib = importlib.import_module(dataset_filename)
+
+    # Locate the dataset class in the imported module
+    dataset = None
+    # Convert the mode name to the corresponding class name, e.g. 'a_b' -> 'abdataset'
+    target_dataset_name = dataset_mode.replace('_', '') + 'dataset'
+    for name, cls in datasetlib.__dict__.items():
+        if name.lower() == target_dataset_name.lower() \
+                and issubclass(cls, BasicDataset):
+            dataset = cls
+
+    if dataset is None:
+        raise NotImplementedError(
+            "In %s.py, there should be a subclass of BasicDataset whose class name "
+            "matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+    return dataset
+
+
+def create_dataset(param):
+    """
+    Create a dataset given the parameters.
+    """
+    dataset_class = find_dataset_using_name(param.omics_mode)
+    # Get an instance of this dataset class
+    dataset = dataset_class(param)
+    print("Dataset [%s] was created" % type(dataset).__name__)
+
+    return dataset
+
+
+class CustomDataLoader:
+    """
+    Create a dataloader for a given dataset.
+    """
+    def __init__(self, dataset, param, shuffle=True, enable_drop_last=False):
+        self.dataset = dataset
+        self.param = param
+
+        # Drop the last incomplete batch if it is too small to be spread
+        # across the available GPUs (fewer than 3 samples per GPU)
+        drop_last = False
+        if enable_drop_last:
+            if len(dataset) % param.batch_size < 3 * len(param.gpu_ids):
+                drop_last = True
+
+        # Create the dataloader for this dataset
+        self.dataloader = DataLoaderPrefetch(
+            dataset,
+            batch_size=param.batch_size,
+            shuffle=shuffle,
+            num_workers=int(param.num_threads),
+            drop_last=drop_last,
+            pin_memory=param.set_pin_memory
+        )
+
+    def __len__(self):
+        """Return the number of samples in the dataset"""
+        return len(self.dataset)
+
+    def __iter__(self):
+        """Yield a batch of data"""
+        yield from self.dataloader
+
+    def get_A_dim(self):
+        """Return the dimension of the first input omics data type"""
+        return self.dataset.A_dim
+
+    def get_B_dim(self):
+        """Return the dimension of the second input omics data type"""
+        return self.dataset.B_dim
+
+    def get_omics_dims(self):
+        """Return a list of omics dimensions"""
+        return self.dataset.omics_dims
+
+    def get_class_num(self):
+        """Return the number of classes for the downstream classification task"""
+        return self.dataset.class_num
+
+    def get_values_max(self):
+        """Return the maximum target value of the dataset"""
+        return self.dataset.values_max
+
+    def get_values_min(self):
+        """Return the minimum target value of the dataset"""
+        return self.dataset.values_min
+
+    def get_survival_T_max(self):
+        """Return the maximum survival time T of the dataset"""
+        return self.dataset.survival_T_max
+
+    def get_survival_T_min(self):
+        """Return the minimum survival time T of the dataset"""
+        return self.dataset.survival_T_min
+
+    def get_sample_list(self):
+        """Return the sample list of the dataset"""
+        return self.dataset.sample_list
+
+
+def create_single_dataloader(param, shuffle=True, enable_drop_last=False):
+    """
+    Create a single dataloader over the full dataset.
+    """
+    dataset = create_dataset(param)
+    dataloader = CustomDataLoader(dataset, param, shuffle=shuffle,
+                                  enable_drop_last=enable_drop_last)
+    sample_list = dataset.sample_list
+
+    return dataloader, sample_list
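+
+
+# Usage sketch (illustrative only): assuming a module ``datasets/a_b_dataset.py``
+# that defines ``class ABDataset(BasicDataset)`` and a ``param`` object with
+# ``omics_mode == 'a_b'``, the helpers above resolve and wrap it like so:
+#
+#     dataset_class = find_dataset_using_name('a_b')    # -> ABDataset
+#     dataloader, sample_list = create_single_dataloader(param)
+#     for data in dataloader:
+#         ...    # one batch of param.batch_size samples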
+
+
+def create_separate_dataloader(param):
+    """
+    Create the full, train, validation and test dataloaders.
+    """
+    full_dataset = create_dataset(param)
+    full_size = len(full_dataset)
+    full_idx = np.arange(full_size)
+
+    if param.not_stratified:
+        train_idx, test_idx = train_test_split(full_idx,
+                                               test_size=param.test_ratio,
+                                               train_size=param.train_ratio,
+                                               shuffle=True)
+    else:
+        # Pick the targets used for the stratified split
+        if param.downstream_task == 'classification':
+            targets = full_dataset.labels_array
+        elif param.downstream_task == 'survival':
+            targets = full_dataset.survival_E_array
+            if param.stratify_label:
+                targets = full_dataset.labels_array
+        elif param.downstream_task == 'multitask':
+            targets = full_dataset.labels_array
+        elif param.downstream_task == 'alltask':
+            targets = full_dataset.labels_array[0]
+        else:
+            raise NotImplementedError('Stratified splitting is not supported for the downstream task %s' % param.downstream_task)
+        train_idx, test_idx = train_test_split(full_idx,
+                                               test_size=param.test_ratio,
+                                               train_size=param.train_ratio,
+                                               shuffle=True,
+                                               stratify=targets)
+
+    # The remaining samples form the validation set
+    val_idx = list(set(full_idx) - set(train_idx) - set(test_idx))
+
+    train_dataset = Subset(full_dataset, train_idx)
+    val_dataset = Subset(full_dataset, val_idx)
+    test_dataset = Subset(full_dataset, test_idx)
+
+    full_dataloader = CustomDataLoader(full_dataset, param)
+    train_dataloader = CustomDataLoader(train_dataset, param, enable_drop_last=True)
+    val_dataloader = CustomDataLoader(val_dataset, param, shuffle=False)
+    test_dataloader = CustomDataLoader(test_dataset, param, shuffle=False)
+
+    return full_dataloader, train_dataloader, val_dataloader, test_dataloader
+
+
+def load_file(param, file_name):
+    """
+    Load data according to the file format.
+    """
+    if param.file_format == 'tsv':
+        file_path = os.path.join(param.data_root, file_name + '.tsv')
+        print('Loading data from ' + file_path)
+        df = pd.read_csv(file_path, sep='\t', header=0, index_col=0, na_filter=param.detect_na)
+    elif param.file_format == 'csv':
+        file_path = os.path.join(param.data_root, file_name + '.csv')
+        print('Loading data from ' + file_path)
+        df = pd.read_csv(file_path, header=0, index_col=0, na_filter=param.detect_na)
+    elif param.file_format == 'hdf':
+        file_path = os.path.join(param.data_root, file_name + '.h5')
+        print('Loading data from ' + file_path)
+        # read_hdf takes no header/index_col arguments; the stored frame keeps its own index
+        df = pd.read_hdf(file_path)
+    else:
+        raise NotImplementedError('File format %s is not supported' % param.file_format)
+    return df
+
+
+def get_survival_y_true(param, T, E):
+    """
+    Get y_true for survival prediction based on the event time T and the event indicator E.
+    """
+    # Get T_max
+    if param.survival_T_max == -1:
+        T_max = T.max()
+    else:
+        T_max = param.survival_T_max
+
+    # Get the discrete time points
+    time_points = util.get_time_points(T_max, param.time_num)
+
+    # Build y_true
+    y_true = []
+    for t, e in zip(T, E):
+        y_true_i = np.zeros(param.time_num + 1)
+        dist_to_time_points = [abs(t - point) for point in time_points[:-1]]
+        time_index = np.argmin(dist_to_time_points)
+        if e == 1:
+            # Uncensored data point: mark only the interval containing the event
+            y_true_i[time_index] = 1
+        else:
+            # Censored data point: mark the censoring interval and all later intervals
+            y_true_i[time_index:] = 1
+        y_true.append(y_true_i)
+    y_true = torch.Tensor(y_true)
+
+    return y_true
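+
+
+# Worked example (illustrative; assumes util.get_time_points returns evenly
+# spaced points including T_max): with param.time_num == 4 and T_max == 100,
+# suppose time_points == [0, 25, 50, 75, 100]. A sample with T == 30 is
+# closest to time_points[1], so time_index == 1 and:
+#     E == 1 (uncensored) -> y_true_i == [0, 1, 0, 0, 0]
+#     E == 0 (censored)   -> y_true_i == [0, 1, 1, 1, 1]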