# datasets/__init__.py

"""
This package about data loading and data preprocessing
"""
import os
import torch
import importlib
import numpy as np
import pandas as pd
from util import util
from datasets.basic_dataset import BasicDataset
from datasets.dataloader_prefetch import DataLoaderPrefetch
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split


def find_dataset_using_name(dataset_mode):
    """
    Get the dataset class of a certain mode.
    """
    dataset_filename = "datasets." + dataset_mode + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    # Find the dataset class
    dataset = None
    # Change the mode name to the corresponding class name
    target_dataset_name = dataset_mode.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
                and issubclass(cls, BasicDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError(
            "In %s.py, there should be a subclass of BasicDataset with a class "
            "name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset
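
# A hedged usage sketch of the naming convention above (the mode name 'abc'
# and the module it maps to are hypothetical, not taken from this package):
#
#     dataset_class = find_dataset_using_name('abc')
#     # looks in datasets/abc_dataset.py for a BasicDataset subclass whose
#     # name matches 'abcdataset' case-insensitively, e.g. ABCDataset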


def create_dataset(param):
    """
    Create a dataset given the parameters.
    """
    dataset_class = find_dataset_using_name(param.omics_mode)
    # Get an instance of this dataset class
    dataset = dataset_class(param)
    print("Dataset [%s] was created" % type(dataset).__name__)
    return dataset


class CustomDataLoader:
    """
    Create a dataloader for a certain dataset.
    """
    def __init__(self, dataset, param, shuffle=True, enable_drop_last=False):
        self.dataset = dataset
        self.param = param

        # Drop the last incomplete batch only when it is too small to be
        # shared across the available GPUs
        drop_last = False
        if enable_drop_last:
            if len(dataset) % param.batch_size < 3 * len(param.gpu_ids):
                drop_last = True
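        # A worked example of this heuristic (the numbers are illustrative):
        # with 515 samples, batch_size == 128 and two GPUs (gpu_ids of length
        # 2), the final batch would hold 515 % 128 == 3 samples, fewer than
        # 3 * 2 == 6, so it is dropped; a remainder of, say, 40 would be kept.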

        # Create the dataloader for this dataset
        self.dataloader = DataLoaderPrefetch(
            dataset,
            batch_size=param.batch_size,
            shuffle=shuffle,
            num_workers=int(param.num_threads),
            drop_last=drop_last,
            pin_memory=param.set_pin_memory
        )

    def __len__(self):
        """Return the number of samples in the dataset"""
        return len(self.dataset)

    def __iter__(self):
        """Yield batches of data"""
        for data in self.dataloader:
            yield data
    def get_A_dim(self):
        """Return the dimension of the first input omics data type"""
        return self.dataset.A_dim

    def get_B_dim(self):
        """Return the dimension of the second input omics data type"""
        return self.dataset.B_dim

    def get_omics_dims(self):
        """Return a list of omics dimensions"""
        return self.dataset.omics_dims

    def get_class_num(self):
        """Return the number of classes for the downstream classification task"""
        return self.dataset.class_num

    def get_values_max(self):
        """Return the maximum target value of the dataset"""
        return self.dataset.values_max

    def get_values_min(self):
        """Return the minimum target value of the dataset"""
        return self.dataset.values_min

    def get_survival_T_max(self):
        """Return the maximum survival time T of the dataset"""
        return self.dataset.survival_T_max

    def get_survival_T_min(self):
        """Return the minimum survival time T of the dataset"""
        return self.dataset.survival_T_min

    def get_sample_list(self):
        """Return the sample list of the dataset"""
        return self.dataset.sample_list
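
# A hedged usage sketch of CustomDataLoader (assuming `dataset` and `param`
# were prepared as in the helper functions below):
#
#     loader = CustomDataLoader(dataset, param, shuffle=True)
#     for batch in loader:   # iteration delegates to DataLoaderPrefetch
#         ...                # one mini-batch per step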


def create_single_dataloader(param, shuffle=True, enable_drop_last=False):
    """
    Create a single dataloader.
    """
    dataset = create_dataset(param)
    dataloader = CustomDataLoader(dataset, param, shuffle=shuffle, enable_drop_last=enable_drop_last)
    sample_list = dataset.sample_list
    return dataloader, sample_list


def create_separate_dataloader(param):
    """
    Create the set of dataloaders (full, train, val, test).
    """
    full_dataset = create_dataset(param)
    full_size = len(full_dataset)
    full_idx = np.arange(full_size)

    if param.not_stratified:
        train_idx, test_idx = train_test_split(full_idx,
                                               test_size=param.test_ratio,
                                               train_size=param.train_ratio,
                                               shuffle=True)
    else:
        # Choose the targets used for stratification
        if param.downstream_task == 'classification':
            targets = full_dataset.labels_array
        elif param.downstream_task == 'survival':
            targets = full_dataset.survival_E_array
            if param.stratify_label:
                targets = full_dataset.labels_array
        elif param.downstream_task == 'multitask':
            targets = full_dataset.labels_array
        elif param.downstream_task == 'alltask':
            targets = full_dataset.labels_array[0]
        train_idx, test_idx = train_test_split(full_idx,
                                               test_size=param.test_ratio,
                                               train_size=param.train_ratio,
                                               shuffle=True,
                                               stratify=targets)

    # Samples in neither the train nor the test split form the validation set
    val_idx = list(set(full_idx) - set(train_idx) - set(test_idx))

    train_dataset = Subset(full_dataset, train_idx)
    val_dataset = Subset(full_dataset, val_idx)
    test_dataset = Subset(full_dataset, test_idx)

    full_dataloader = CustomDataLoader(full_dataset, param)
    train_dataloader = CustomDataLoader(train_dataset, param, enable_drop_last=True)
    val_dataloader = CustomDataLoader(val_dataset, param, shuffle=False)
    test_dataloader = CustomDataLoader(test_dataset, param, shuffle=False)

    return full_dataloader, train_dataloader, val_dataloader, test_dataloader
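
# A worked example of the split above (the ratios are illustrative, not
# defaults of this package): with 1000 samples, train_ratio == 0.8 and
# test_ratio == 0.1, train_test_split returns 800 training and 100 testing
# indices, and the remaining 100 indices form the validation set.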


def load_file(param, file_name):
    """
    Load data according to the file format.
    """
    if param.file_format == 'tsv':
        file_path = os.path.join(param.data_root, file_name + '.tsv')
        print('Loading data from ' + file_path)
        df = pd.read_csv(file_path, sep='\t', header=0, index_col=0, na_filter=param.detect_na)
    elif param.file_format == 'csv':
        file_path = os.path.join(param.data_root, file_name + '.csv')
        print('Loading data from ' + file_path)
        df = pd.read_csv(file_path, header=0, index_col=0, na_filter=param.detect_na)
    elif param.file_format == 'hdf':
        file_path = os.path.join(param.data_root, file_name + '.h5')
        print('Loading data from ' + file_path)
        # read_hdf restores the stored DataFrame directly, so the CSV-style
        # header/index_col arguments do not apply here
        df = pd.read_hdf(file_path)
    else:
        raise NotImplementedError('File format %s is not supported' % param.file_format)
    return df
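
# A hedged sketch of the layout these readers expect (the names are
# illustrative, and which axis holds samples depends on the dataset
# convention): with header=0 and index_col=0, the first row provides the
# column labels and the first column provides the row index, e.g. for a TSV:
#
#     feature   sample_1   sample_2
#     gene_A    0.12       3.40
#     gene_B    1.05       0.00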


def get_survival_y_true(param, T, E):
    """
    Get y_true for survival prediction based on the survival time T and the event indicator E.
    """
    # Get T_max
    if param.survival_T_max == -1:
        T_max = T.max()
    else:
        T_max = param.survival_T_max

    # Get time points
    time_points = util.get_time_points(T_max, param.time_num)

    # Get the y_true
    y_true = []
    for t, e in zip(T, E):
        y_true_i = np.zeros(param.time_num + 1)
        # Index of the time point closest to t
        dist_to_time_points = [abs(t - point) for point in time_points[:-1]]
        time_index = np.argmin(dist_to_time_points)
        if e == 1:
            # Uncensored data point: the event happened at this time point
            y_true_i[time_index] = 1
        else:
            # Censored data point: the event may happen at any later time point
            y_true_i[time_index:] = 1
        y_true.append(y_true_i)
    # Stack into one array first to avoid the slow list-of-arrays conversion
    y_true = torch.Tensor(np.stack(y_true))

    return y_true
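
# A worked example of the encoding above (all numbers are illustrative, and
# evenly spaced time points are an assumption about util.get_time_points):
# suppose param.time_num == 4 and time_points[:-1] == [0, 25, 50, 75]. A
# sample with t == 30 is closest to the point at index 1, so an uncensored
# sample (e == 1) yields [0, 1, 0, 0, 0], while a censored one (e == 0)
# yields [0, 1, 1, 1, 1] (the event can only occur at or after that time).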