|
a |
|
b/src/move/data/dataloaders.py |
|
|
1 |
__all__ = ["MOVEDataset", "make_dataset", "make_dataloader", "split_samples"] |
|
|
2 |
|
|
|
3 |
from typing import Optional |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
import torch |
|
|
7 |
from torch.utils.data import DataLoader, TensorDataset |
|
|
8 |
|
|
|
9 |
from move.core.typing import BoolArray, FloatArray |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
class MOVEDataset(TensorDataset):
    """
    Characterizes a dataset for PyTorch

    Args:
        cat_all:
            categorical input matrix (N_samples, N_variables x N_max-classes.
        con_all:
            normalized continuous input matrix (N_samples, N_variables).
        cat_shapes:
            list of tuples corresponding to number of features (N_variables,
            N_max-classes) of each categorical class.
        con_shapes:
            list of tuples corresponding to number of features
            (N_variables) of each continuous class.

    Raises:
        ValueError:
            Number of samples between categorical and continuous datasets must
            match.
        ValueError:
            Categorical and continuous data cannot be both empty.
    """

    def __init__(
        self,
        cat_all: Optional[torch.Tensor] = None,
        con_all: Optional[torch.Tensor] = None,
        cat_shapes: Optional[list[tuple[int, ...]]] = None,
        con_shapes: Optional[list[int]] = None,
    ) -> None:
        # Collect the sample count of every modality that was provided and
        # validate that they agree before storing anything.
        sizes = [data.shape[0] for data in (cat_all, con_all) if data is not None]
        if not sizes:
            raise ValueError("Categorical and continuous data cannot be both empty.")
        if len(set(sizes)) > 1:
            raise ValueError(
                "Number of samples between categorical and continuous "
                "datasets must match."
            )
        self.num_samples = sizes[0]
        self.cat_all = cat_all
        self.cat_shapes = cat_shapes
        self.con_all = con_all
        self.con_shapes = con_shapes

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        # A missing modality is represented by an empty tensor so callers
        # always receive a (categorical, continuous) pair.
        if self.cat_all is None:
            cat_slice = torch.empty(0)
        else:
            cat_slice = self.cat_all[idx]
        if self.con_all is None:
            con_slice = torch.empty(0)
        else:
            con_slice = self.con_all[idx]
        return cat_slice, con_slice
|
|
68 |
|
|
|
69 |
|
|
|
70 |
def concat_cat_list(
    cat_list: list[FloatArray],
) -> tuple[list[tuple[int, ...]], FloatArray]:
    """
    Concatenate a list of categorical data
    Args:
        cat_list: list with each categorical class data
    Returns:
        (tuple): a tuple containing:
            cat_shapes:
                list of categorical data classes shapes (N_variables,
                N_max-classes)
            cat_all (FloatArray):
                2D array of concatenated patients categorical data
    """
    # Remember each dataset's (num_features, num_classes) shape, then flatten
    # every 3D array to 2D so all datasets can be joined column-wise.
    cat_shapes = [(cat.shape[1], cat.shape[2]) for cat in cat_list]
    flattened = [cat.reshape(len(cat), -1) for cat in cat_list]
    cat_all = np.concatenate(flattened, axis=1)
    return cat_shapes, cat_all
|
|
94 |
|
|
|
95 |
|
|
|
96 |
def concat_con_list(
    con_list: list[FloatArray],
) -> tuple[list[int], FloatArray]:
    """
    Concatenate a list of continuous data
    Args:
        con_list: list with each continuous class data
    Returns:
        (tuple): a tuple containing:
            n_con_shapes:
                list of continuous data classes shapes (in 1D) (N_variables)
            con_all:
                2D array of concatenated patients continuous data
    """
    # Track the feature count of every dataset before joining them column-wise.
    con_shapes = []
    for con in con_list:
        con_shapes.append(con.shape[1])
    con_all: FloatArray = np.concatenate(con_list, axis=1)
    return con_shapes, con_all
|
|
113 |
|
|
|
114 |
|
|
|
115 |
def make_dataset(
    cat_list: Optional[list[FloatArray]] = None,
    con_list: Optional[list[FloatArray]] = None,
    mask: Optional[BoolArray] = None,
) -> MOVEDataset:
    """Creates a dataset that combines categorical and continuous datasets.

    Args:
        cat_list:
            List of categorical datasets (`num_samples` x `num_features`
            x `num_categories`). Defaults to None.
        con_list:
            List of continuous datasets (`num_samples` x `num_features`).
            Defaults to None.
        mask:
            Boolean array to mask samples. Defaults to None.

    Raises:
        ValueError: If both inputs are None

    Returns:
        MOVEDataset
    """
    if not (cat_list or con_list):
        raise ValueError("At least one type of data must be in the input")

    # Each modality follows the same pipeline: concatenate the numpy arrays,
    # convert to a tensor, and optionally subset the samples with the mask.
    cat_shapes, cat_all = [], None
    if cat_list:
        cat_shapes, cat_arr = concat_cat_list(cat_list)
        cat_all = torch.from_numpy(cat_arr)
        if mask is not None:
            cat_all = cat_all[mask]

    con_shapes, con_all = [], None
    if con_list:
        con_shapes, con_arr = concat_con_list(con_list)
        con_all = torch.from_numpy(con_arr)
        if mask is not None:
            con_all = con_all[mask]

    return MOVEDataset(cat_all, con_all, cat_shapes, con_shapes)
|
|
160 |
|
|
|
161 |
|
|
|
162 |
def make_dataloader(
    cat_list: Optional[list[FloatArray]] = None,
    con_list: Optional[list[FloatArray]] = None,
    mask: Optional[BoolArray] = None,
    **kwargs
) -> DataLoader:
    """Creates a DataLoader that combines categorical and continuous datasets.

    Args:
        cat_list:
            List of categorical datasets (`num_samples` x `num_features`
            x `num_categories`). Defaults to None.
        con_list:
            List of continuous datasets (`num_samples` x `num_features`).
            Defaults to None.
        mask:
            Boolean array to mask samples. Defaults to None.
        **kwargs:
            Arguments to pass to the DataLoader (e.g., batch size)

    Raises:
        ValueError: If both inputs are None

    Returns:
        DataLoader
    """
    # Delegate dataset assembly, then wrap it with the user-configured loader.
    return DataLoader(make_dataset(cat_list, con_list, mask), **kwargs)
|
|
190 |
|
|
|
191 |
|
|
|
192 |
def split_samples(
    num_samples: int,
    train_frac: float,
    seed: Optional[int] = None,
) -> BoolArray:
    """Generate mask to randomly split samples into training and test sets.

    Args:
        num_samples: Number of samples to split.
        train_frac: Fraction of samples corresponding to training set.
        seed: Optional seed for the random number generator, making the
            split reproducible. Defaults to None (non-deterministic split,
            matching the previous behavior).

    Returns:
        Boolean array that is True for training samples and False for test
        samples.
    """
    sample_ids = np.arange(num_samples)
    # floor(train_frac * num_samples) samples are assigned to training
    train_size = int(train_frac * num_samples)

    rng = np.random.default_rng(seed)
    train_ids = rng.permutation(sample_ids)[:train_size]

    return np.isin(sample_ids, train_ids)