a b/src/move/data/dataloaders.py
1
__all__ = ["MOVEDataset", "make_dataset", "make_dataloader", "split_samples"]
2
3
from typing import Optional
4
5
import numpy as np
6
import torch
7
from torch.utils.data import DataLoader, TensorDataset
8
9
from move.core.typing import BoolArray, FloatArray
10
11
12
class MOVEDataset(TensorDataset):
    """Dataset pairing categorical and continuous data for PyTorch.

    Args:
        cat_all:
            Categorical input matrix (N_samples, N_variables x
            N_max-classes).
        con_all:
            Normalized continuous input matrix (N_samples, N_variables).
        cat_shapes:
            List of tuples corresponding to the number of features
            (N_variables, N_max-classes) of each categorical class.
        con_shapes:
            List of the number of features (N_variables) of each
            continuous class.

    Raises:
        ValueError:
            Number of samples between categorical and continuous datasets
            must match.
        ValueError:
            Categorical and continuous data cannot be both empty.
    """

    def __init__(
        self,
        cat_all: Optional[torch.Tensor] = None,
        con_all: Optional[torch.Tensor] = None,
        cat_shapes: Optional[list[tuple[int, ...]]] = None,
        con_shapes: Optional[list[int]] = None,
    ) -> None:
        # At least one data modality must be provided.
        if cat_all is None and con_all is None:
            raise ValueError("Categorical and continuous data cannot be both empty.")
        # When both are present, their sample (first) dimensions must agree.
        if (
            cat_all is not None
            and con_all is not None
            and cat_all.shape[0] != con_all.shape[0]
        ):
            raise ValueError(
                "Number of samples between categorical and continuous "
                "datasets must match."
            )
        self.num_samples = (
            cat_all.shape[0] if cat_all is not None else con_all.shape[0]
        )
        self.cat_all = cat_all
        self.cat_shapes = cat_shapes
        self.con_all = con_all
        self.con_shapes = con_shapes

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        # A missing modality is returned as an empty tensor so callers
        # always receive a (categorical, continuous) pair.
        if self.cat_all is None:
            cat_slice = torch.empty(0)
        else:
            cat_slice = self.cat_all[idx]
        if self.con_all is None:
            con_slice = torch.empty(0)
        else:
            con_slice = self.con_all[idx]
        return cat_slice, con_slice
68
69
70
def concat_cat_list(
    cat_list: list[FloatArray],
) -> tuple[list[tuple[int, ...]], FloatArray]:
    """Concatenate a list of categorical datasets into one 2D matrix.

    Args:
        cat_list: list with each categorical class data

    Returns:
        (tuple): a tuple containing:
            cat_shapes:
                list of categorical data classes shapes (N_variables,
                N_max-classes)
            cat_all (FloatArray):
                2D array of concatenated patients categorical data
    """
    # Record each class's (N_variables, N_max-classes) shape, then flatten
    # every 3D array to (N_samples, -1) before joining along features.
    cat_shapes = [(cat.shape[1], cat.shape[2]) for cat in cat_list]
    flat = [cat.reshape(cat.shape[0], -1) for cat in cat_list]
    return cat_shapes, np.concatenate(flat, axis=1)
94
95
96
def concat_con_list(
    con_list: list[FloatArray],
) -> tuple[list[int], FloatArray]:
    """Concatenate a list of continuous datasets into one 2D matrix.

    Args:
        con_list: list with each continuous class data

    Returns:
        (tuple): a tuple containing:
            n_con_shapes:
                list of continuous data classes shapes (in 1D) (N_variables)
            con_all:
                2D array of concatenated patients continuous data
    """
    # Track the feature count of each class so the concatenated matrix can
    # later be split back into its constituent datasets.
    con_shapes = []
    for con in con_list:
        con_shapes.append(con.shape[1])
    joined: FloatArray = np.concatenate(con_list, axis=1)
    return con_shapes, joined
113
114
115
def make_dataset(
    cat_list: Optional[list[FloatArray]] = None,
    con_list: Optional[list[FloatArray]] = None,
    mask: Optional[BoolArray] = None,
) -> MOVEDataset:
    """Creates a dataset that combines categorical and continuous datasets.

    Args:
        cat_list:
            List of categorical datasets (`num_samples` x `num_features`
            x `num_categories`). Defaults to None.
        con_list:
            List of continuous datasets (`num_samples` x `num_features`).
            Defaults to None.
        mask:
            Boolean array to mask samples. Defaults to None.

    Raises:
        ValueError: If both inputs are None

    Returns:
        MOVEDataset
    """
    if not (cat_list or con_list):
        raise ValueError("At least one type of data must be in the input")

    def as_masked_tensor(array):
        # Convert a NumPy array to a tensor, keeping only masked samples
        # when a mask was given.
        tensor = torch.from_numpy(array)
        return tensor if mask is None else tensor[mask]

    cat_shapes = []
    con_shapes = []
    cat_all = con_all = None

    if cat_list:
        cat_shapes, cat_all = concat_cat_list(cat_list)
        cat_all = as_masked_tensor(cat_all)

    if con_list:
        con_shapes, con_all = concat_con_list(con_list)
        con_all = as_masked_tensor(con_all)

    return MOVEDataset(cat_all, con_all, cat_shapes, con_shapes)
160
161
162
def make_dataloader(
    cat_list: Optional[list[FloatArray]] = None,
    con_list: Optional[list[FloatArray]] = None,
    mask: Optional[BoolArray] = None,
    **kwargs
) -> DataLoader:
    """Creates a DataLoader that combines categorical and continuous datasets.

    Args:
        cat_list:
            List of categorical datasets (`num_samples` x `num_features`
            x `num_categories`). Defaults to None.
        con_list:
            List of continuous datasets (`num_samples` x `num_features`).
            Defaults to None.
        mask:
            Boolean array to mask samples. Defaults to None.
        **kwargs:
            Arguments to pass to the DataLoader (e.g., batch size)

    Raises:
        ValueError: If both inputs are None

    Returns:
        DataLoader
    """
    # Delegate dataset assembly (including masking) to make_dataset.
    combined = make_dataset(cat_list, con_list, mask)
    return DataLoader(combined, **kwargs)
190
191
192
def split_samples(
    num_samples: int,
    train_frac: float,
) -> BoolArray:
    """Generate mask to randomly split samples into training and test sets.

    Args:
        num_samples: Number of samples to split.
        train_frac: Fraction of samples corresponding to training set.

    Returns:
        Boolean array of length ``num_samples`` where ``True`` marks a
        training sample and ``False`` a test sample.
    """
    # Number of training samples, rounded down.
    train_size = int(train_frac * num_samples)

    rng = np.random.default_rng()
    # Build the mask directly from the drawn indices — O(n) instead of the
    # O(n log n) membership scan that np.isin performs.
    mask = np.zeros(num_samples, dtype=bool)
    mask[rng.permutation(num_samples)[:train_size]] = True
    return mask