src/move/data/perturbations.py
__all__ = [
    "perturb_categorical_data",
    "perturb_continuous_data",
    "perturb_continuous_data_extended",
]

from pathlib import Path
from typing import Literal, Optional, cast

import numpy as np
import torch
from torch.utils.data import DataLoader

from move.data.dataloaders import MOVEDataset
from move.data.preprocessing import feature_stats
from move.visualization.dataset_distributions import plot_value_distributions

ContinuousPerturbationType = Literal["minimum", "maximum", "plus_std", "minus_std"]


def perturb_categorical_data(
    baseline_dataloader: DataLoader,
    cat_dataset_names: list[str],
    target_dataset_name: str,
    target_value: np.ndarray,
) -> list[DataLoader]:
    """Add perturbations to categorical data. For each feature in the target
    dataset, change its value to the given target value.

    Args:
        baseline_dataloader: Baseline dataloader
        cat_dataset_names: List of categorical dataset names
        target_dataset_name: Name of the categorical dataset to perturb
        target_value: One-hot encoded target value

    Returns:
        List of dataloaders, one per perturbed feature, each containing the
        full dataset with that feature set to the target value
    """

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.cat_shapes is not None
    assert baseline_dataset.cat_all is not None

    target_idx = cat_dataset_names.index(target_dataset_name)
    # Compute the boundaries of each one-hot encoded dataset inside the
    # concatenated tensor, then slice out the target dataset
    splits = np.cumsum(
        [0] + [int(np.prod(shape)) for shape in baseline_dataset.cat_shapes]
    )
    slice_ = slice(*splits[target_idx : target_idx + 2])

    target_shape = baseline_dataset.cat_shapes[target_idx]
    num_features = target_shape[0]  # number of features in the target dataset

    dataloaders = []
    for i in range(num_features):
        perturbed_cat = baseline_dataset.cat_all.clone()
        target_dataset = perturbed_cat[:, slice_].view(
            baseline_dataset.num_samples, *target_shape
        )
        # `view` returns a view, so writing here modifies `perturbed_cat`
        target_dataset[:, i, :] = torch.FloatTensor(target_value)
        perturbed_dataset = MOVEDataset(
            perturbed_cat,
            baseline_dataset.con_all,
            baseline_dataset.cat_shapes,
            baseline_dataset.con_shapes,
        )
        perturbed_dataloader = DataLoader(
            perturbed_dataset,
            shuffle=False,
            batch_size=baseline_dataloader.batch_size,
        )
        dataloaders.append(perturbed_dataloader)
    return dataloaders
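

# Usage sketch for `perturb_categorical_data` (the dataset names and one-hot
# target value below are hypothetical; it assumes `baseline_dataloader` wraps
# a MOVEDataset with a one-hot encoded "genotype" dataset of shape
# (num_features, 2)):
#
#     perturbed = perturb_categorical_data(
#         baseline_dataloader,
#         cat_dataset_names=["genotype", "drugs"],
#         target_dataset_name="genotype",
#         target_value=np.array([0.0, 1.0]),
#     )
#     # one dataloader per feature, each with that feature forced to [0, 1]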


def perturb_continuous_data(
    baseline_dataloader: DataLoader,
    con_dataset_names: list[str],
    target_dataset_name: str,
    target_value: float,
) -> list[DataLoader]:
    """Add perturbations to continuous data. For each feature in the target
    dataset, change its value to the given target value.

    Args:
        baseline_dataloader: Baseline dataloader
        con_dataset_names: List of continuous dataset names
        target_dataset_name: Name of the continuous dataset to perturb
        target_value: Target value

    Returns:
        List of dataloaders, one per perturbed feature, each containing the
        full dataset with that feature set to the target value
    """

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.con_shapes is not None
    assert baseline_dataset.con_all is not None

    target_idx = con_dataset_names.index(target_dataset_name)
    # Boundaries of each continuous dataset inside the concatenated tensor
    splits = np.cumsum([0] + baseline_dataset.con_shapes)
    slice_ = slice(*splits[target_idx : target_idx + 2])

    num_features = baseline_dataset.con_shapes[target_idx]

    dataloaders = []
    for i in range(num_features):
        perturbed_con = baseline_dataset.con_all.clone()
        # Slicing returns a view, so writing here modifies `perturbed_con`
        target_dataset = perturbed_con[:, slice_]
        target_dataset[:, i] = torch.FloatTensor([target_value])
        perturbed_dataset = MOVEDataset(
            baseline_dataset.cat_all,
            perturbed_con,
            baseline_dataset.cat_shapes,
            baseline_dataset.con_shapes,
        )
        perturbed_dataloader = DataLoader(
            perturbed_dataset,
            shuffle=False,
            batch_size=baseline_dataloader.batch_size,
        )
        dataloaders.append(perturbed_dataloader)

    return dataloaders
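

# Usage sketch for `perturb_continuous_data` (hypothetical dataset names;
# assumes `baseline_dataloader` wraps a MOVEDataset with a continuous
# "metabolomics" dataset):
#
#     perturbed = perturb_continuous_data(
#         baseline_dataloader,
#         con_dataset_names=["metabolomics", "proteomics"],
#         target_dataset_name="metabolomics",
#         target_value=0.0,
#     )
#     # one dataloader per feature, each with that feature zeroed out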


def perturb_continuous_data_extended(
    baseline_dataloader: DataLoader,
    con_dataset_names: list[str],
    target_dataset_name: str,
    perturbation_type: ContinuousPerturbationType,
    output_subpath: Optional[Path] = None,
) -> list[DataLoader]:
    """Add perturbations to continuous data. For each feature in the target
    dataset, change the feature's value in all samples (rows) by either:
    1,2) substituting the feature's value with its minimum/maximum value, or
    3,4) adding/subtracting one standard deviation to/from it.

    Args:
        baseline_dataloader: Baseline dataloader
        con_dataset_names: List of continuous dataset names
        target_dataset_name: Name of the continuous dataset to perturb
        perturbation_type: 'minimum', 'maximum', 'plus_std' or 'minus_std'
        output_subpath: Path where a figure showing the perturbations will be
            saved (optional)

    Returns:
        List of dataloaders, one per perturbed feature. If output_subpath is
        given, a plot of the perturbed feature value distributions is also
        saved as a side effect; note that all perturbations are collapsed
        into one single plot.

    Note:
        This function was created so that it can generalize to non-normalized
        datasets. Scaling is done per dataset, not per feature, so standard
        deviations differ slightly from feature to feature.
    """

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.con_shapes is not None
    assert baseline_dataset.con_all is not None

    target_idx = con_dataset_names.index(target_dataset_name)  # dataset index
    splits = np.cumsum([0] + baseline_dataset.con_shapes)
    slice_ = slice(*splits[target_idx : target_idx + 2])

    num_features = baseline_dataset.con_shapes[target_idx]
    dataloaders = []
    perturbations_list = []

    # Feature statistics are computed on the unperturbed values; since every
    # iteration below starts from a fresh clone of the baseline, computing
    # them once up front is equivalent to recomputing them per feature.
    min_feat_val_list, max_feat_val_list, std_feat_val_list = feature_stats(
        baseline_dataset.con_all[:, slice_]
    )

    for i in range(num_features):
        perturbed_con = baseline_dataset.con_all.clone()
        target_dataset = perturbed_con[:, slice_]
        # Change the desired feature value by:
        if perturbation_type == "minimum":
            target_dataset[:, i] = torch.FloatTensor([min_feat_val_list[i]])
        elif perturbation_type == "maximum":
            target_dataset[:, i] = torch.FloatTensor([max_feat_val_list[i]])
        elif perturbation_type == "plus_std":
            target_dataset[:, i] += torch.FloatTensor([std_feat_val_list[i]])
        elif perturbation_type == "minus_std":
            target_dataset[:, i] -= torch.FloatTensor([std_feat_val_list[i]])
        else:
            raise ValueError(f"Unknown perturbation type: {perturbation_type}")

        perturbations_list.append(target_dataset[:, i].numpy())

        perturbed_dataset = MOVEDataset(
            baseline_dataset.cat_all,
            perturbed_con,
            baseline_dataset.cat_shapes,
            baseline_dataset.con_shapes,
        )

        perturbed_dataloader = DataLoader(
            perturbed_dataset,
            shuffle=False,
            batch_size=baseline_dataloader.batch_size,
        )
        dataloaders.append(perturbed_dataloader)

    # Plot the perturbations for all features, collapsed into one plot
    if output_subpath is not None:
        fig = plot_value_distributions(np.array(perturbations_list).transpose())
        fig_path = str(
            output_subpath / f"perturbation_distribution_{target_dataset_name}.png"
        )
        fig.savefig(fig_path)

    return dataloaders
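

# Usage sketch for `perturb_continuous_data_extended` (hypothetical dataset
# names and output path; assumes `baseline_dataloader` wraps a MOVEDataset
# with a continuous "metabolomics" dataset):
#
#     perturbed = perturb_continuous_data_extended(
#         baseline_dataloader,
#         con_dataset_names=["metabolomics", "proteomics"],
#         target_dataset_name="metabolomics",
#         perturbation_type="plus_std",
#         output_subpath=Path("results/perturbations"),
#     )
#     # one dataloader per feature, each with one std added to that feature;
#     # the collapsed distribution plot is saved under results/perturbations/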