[c23b31]: /src/move/data/perturbations.py


__all__ = ["perturb_categorical_data", "perturb_continuous_data"]
from pathlib import Path
from typing import Literal, Optional, cast
import numpy as np
import torch
from torch.utils.data import DataLoader
from move.data.dataloaders import MOVEDataset
from move.data.preprocessing import feature_stats
from move.visualization.dataset_distributions import plot_value_distributions
ContinuousPerturbationType = Literal["minimum", "maximum", "plus_std", "minus_std"]


def perturb_categorical_data(
    baseline_dataloader: DataLoader,
    cat_dataset_names: list[str],
    target_dataset_name: str,
    target_value: np.ndarray,
) -> list[DataLoader]:
    """Add perturbations to categorical data. For each feature in the target
    dataset, change its value in every sample to the target value.

    Args:
        baseline_dataloader: Baseline dataloader
        cat_dataset_names: List of categorical dataset names
        target_dataset_name: Target categorical dataset to perturb
        target_value: Target value (one-hot encoding of the target category)

    Returns:
        List of dataloaders containing all perturbed datasets
    """
    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.cat_shapes is not None
    assert baseline_dataset.cat_all is not None

    target_idx = cat_dataset_names.index(target_dataset_name)
    # Column boundaries of each one-hot-encoded dataset in the concatenated
    # matrix: each dataset spans (num_features * num_categories) columns.
    splits = np.cumsum(
        [0] + [shape[0] * shape[1] for shape in baseline_dataset.cat_shapes]
    )
    slice_ = slice(*splits[target_idx : target_idx + 2])

    target_shape = baseline_dataset.cat_shapes[target_idx]
    num_features = target_shape[0]

    dataloaders = []
    for i in range(num_features):
        perturbed_cat = baseline_dataset.cat_all.clone()
        target_dataset = perturbed_cat[:, slice_].view(
            baseline_dataset.num_samples, *target_shape
        )
        # Overwrite feature i of every sample with the target one-hot encoding.
        target_dataset[:, i, :] = torch.FloatTensor(target_value)
        perturbed_dataset = MOVEDataset(
            perturbed_cat,
            baseline_dataset.con_all,
            baseline_dataset.cat_shapes,
            baseline_dataset.con_shapes,
        )
        perturbed_dataloader = DataLoader(
            perturbed_dataset,
            shuffle=False,
            batch_size=baseline_dataloader.batch_size,
        )
        dataloaders.append(perturbed_dataloader)

    return dataloaders
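

# ---------------------------------------------------------------------------
# Usage sketch for `perturb_categorical_data` (editor's illustration, not part
# of the original module). The dataset name "drugs" and the 4-sample toy
# tensors are hypothetical; `MOVEDataset` and `DataLoader` are constructed the
# same way this file constructs them.
#
#     num_samples = 4
#     cat_shapes = [(3, 2)]            # one dataset: 3 features, 2 categories
#     cat_all = torch.zeros(num_samples, 3 * 2)
#     cat_all[:, 0::2] = 1.0           # every feature starts in category 0
#     baseline = DataLoader(
#         MOVEDataset(cat_all, None, cat_shapes, None),
#         shuffle=False,
#         batch_size=num_samples,
#     )
#     perturbed = perturb_categorical_data(
#         baseline, ["drugs"], "drugs", np.array([0.0, 1.0])
#     )
#     assert len(perturbed) == 3       # one dataloader per perturbed feature
# ---------------------------------------------------------------------------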


def perturb_continuous_data(
    baseline_dataloader: DataLoader,
    con_dataset_names: list[str],
    target_dataset_name: str,
    target_value: float,
) -> list[DataLoader]:
    """Add perturbations to continuous data. For each feature in the target
    dataset, change its value in every sample to the target value.

    Args:
        baseline_dataloader: Baseline dataloader
        con_dataset_names: List of continuous dataset names
        target_dataset_name: Target continuous dataset to perturb
        target_value: Target value

    Returns:
        List of dataloaders containing all perturbed datasets
    """
    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.con_shapes is not None
    assert baseline_dataset.con_all is not None

    target_idx = con_dataset_names.index(target_dataset_name)
    # Column boundaries of each continuous dataset in the concatenated matrix.
    splits = np.cumsum([0] + baseline_dataset.con_shapes)
    slice_ = slice(*splits[target_idx : target_idx + 2])
    num_features = baseline_dataset.con_shapes[target_idx]

    dataloaders = []
    for i in range(num_features):
        perturbed_con = baseline_dataset.con_all.clone()
        target_dataset = perturbed_con[:, slice_]
        # Overwrite feature i of every sample with the target value.
        target_dataset[:, i] = torch.FloatTensor([target_value])
        perturbed_dataset = MOVEDataset(
            baseline_dataset.cat_all,
            perturbed_con,
            baseline_dataset.cat_shapes,
            baseline_dataset.con_shapes,
        )
        perturbed_dataloader = DataLoader(
            perturbed_dataset,
            shuffle=False,
            batch_size=baseline_dataloader.batch_size,
        )
        dataloaders.append(perturbed_dataloader)

    return dataloaders
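

# ---------------------------------------------------------------------------
# Usage sketch for `perturb_continuous_data` (editor's illustration, not part
# of the original module). The dataset name "metabolomics" and the toy tensor
# are hypothetical.
#
#     con_all = torch.randn(10, 4)     # 10 samples, one dataset of 4 features
#     baseline = DataLoader(
#         MOVEDataset(None, con_all, None, [4]),
#         shuffle=False,
#         batch_size=10,
#     )
#     # Set each metabolomics feature, in turn, to zero in every sample:
#     perturbed = perturb_continuous_data(
#         baseline, ["metabolomics"], "metabolomics", 0.0
#     )
#     assert len(perturbed) == 4       # one dataloader per perturbed feature
# ---------------------------------------------------------------------------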


def perturb_continuous_data_extended(
    baseline_dataloader: DataLoader,
    con_dataset_names: list[str],
    target_dataset_name: str,
    perturbation_type: ContinuousPerturbationType,
    output_subpath: Optional[Path] = None,
) -> list[DataLoader]:
    """Add perturbations to continuous data. For each feature in the target
    dataset, change the feature's value in all samples (rows) by either:

    1. substituting it with the feature's minimum value,
    2. substituting it with the feature's maximum value,
    3. adding one standard deviation to it, or
    4. subtracting one standard deviation from it.

    Args:
        baseline_dataloader: Baseline dataloader
        con_dataset_names: List of continuous dataset names
        target_dataset_name: Target continuous dataset to perturb
        perturbation_type: 'minimum', 'maximum', 'plus_std', or 'minus_std'
        output_subpath: Path where a figure showing the perturbed feature
            value distributions will be saved. Note that all perturbations
            are collapsed into one single plot.

    Returns:
        List of dataloaders containing all perturbed datasets

    Note:
        This function was created so that it could generalize to
        non-normalized datasets. Scaling is done per dataset, not per
        feature, so standard deviations differ slightly from feature to
        feature.
    """
    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.con_shapes is not None
    assert baseline_dataset.con_all is not None

    target_idx = con_dataset_names.index(target_dataset_name)  # dataset index
    splits = np.cumsum([0] + baseline_dataset.con_shapes)
    slice_ = slice(*splits[target_idx : target_idx + 2])
    num_features = baseline_dataset.con_shapes[target_idx]

    dataloaders = []
    perturbations_list = []
    for i in range(num_features):
        perturbed_con = baseline_dataset.con_all.clone()
        target_dataset = perturbed_con[:, slice_]
        # Per-feature statistics of the (not yet perturbed) target dataset:
        min_feat_val_list, max_feat_val_list, std_feat_val_list = feature_stats(
            target_dataset
        )
        # Change the desired feature value by:
        if perturbation_type == "minimum":
            target_dataset[:, i] = torch.FloatTensor([min_feat_val_list[i]])
        elif perturbation_type == "maximum":
            target_dataset[:, i] = torch.FloatTensor([max_feat_val_list[i]])
        elif perturbation_type == "plus_std":
            target_dataset[:, i] += torch.FloatTensor([std_feat_val_list[i]])
        elif perturbation_type == "minus_std":
            target_dataset[:, i] -= torch.FloatTensor([std_feat_val_list[i]])
        perturbations_list.append(target_dataset[:, i].numpy())

        perturbed_dataset = MOVEDataset(
            baseline_dataset.cat_all,
            perturbed_con,
            baseline_dataset.cat_shapes,
            baseline_dataset.con_shapes,
        )
        perturbed_dataloader = DataLoader(
            perturbed_dataset,
            shuffle=False,
            batch_size=baseline_dataloader.batch_size,
        )
        dataloaders.append(perturbed_dataloader)

    # Plot the perturbed values for all features, collapsed into one plot:
    if output_subpath is not None:
        fig = plot_value_distributions(np.array(perturbations_list).transpose())
        fig_path = str(
            output_subpath / f"perturbation_distribution_{target_dataset_name}.png"
        )
        fig.savefig(fig_path)

    return dataloaders
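

# ---------------------------------------------------------------------------
# Usage sketch for `perturb_continuous_data_extended` (editor's illustration,
# not part of the original module). The dataset name "proteomics", the toy
# tensor, and the output directory are hypothetical; the directory is assumed
# to exist before the figure is saved into it.
#
#     con_all = torch.randn(8, 5)      # 8 samples, one dataset of 5 features
#     baseline = DataLoader(
#         MOVEDataset(None, con_all, None, [5]),
#         shuffle=False,
#         batch_size=8,
#     )
#     # Shift each proteomics feature, in turn, up by one standard deviation:
#     perturbed = perturb_continuous_data_extended(
#         baseline,
#         ["proteomics"],
#         "proteomics",
#         "plus_std",
#         output_subpath=Path("results"),  # saves one distribution plot here
#     )
#     assert len(perturbed) == 5       # one dataloader per perturbed feature
# ---------------------------------------------------------------------------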