Data: Tabular Time Series Specialty: Endocrinology Laboratory: Blood Tests EHR: Demographics Diagnoses Medications Omics: Genomics Multi-omics Transcriptomics Wearable: Activity Clinical Purpose: Treatment Response Assessment Task: Biomarker Discovery
[c23b31]: / src / move / conf / schema.py

Download this file

272 lines (205 with data), 7.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
# Names exported by a star-import of this module. Every task schema that is
# registered in the config store further down is listed here, including the
# Kolmogorov-Smirnov variant, so that all association configs are exported
# consistently.
__all__ = [
    "MOVEConfig",
    "EncodeDataConfig",
    "AnalyzeLatentConfig",
    "TuneModelReconstructionConfig",
    "TuneModelStabilityConfig",
    "IdentifyAssociationsConfig",
    "IdentifyAssociationsBayesConfig",
    "IdentifyAssociationsTTestConfig",
    "IdentifyAssociationsKSConfig",
]
from dataclasses import dataclass, field
from typing import Any, Optional
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf
from move.models.vae import VAE
from move.training.training_loop import training_loop
def get_fully_qualname(sth: Any) -> str:
    """Return the dotted path ``"<module>.<qualname>"`` of an object.

    Args:
        sth: Any object exposing ``__module__`` and ``__qualname__``, such as
            a class or a function.

    Returns:
        The object's module name and qualified name joined by a dot.
    """
    module_name = sth.__module__
    qualified_name = sth.__qualname__
    return ".".join([module_name, qualified_name])
@dataclass
class InputConfig:
    """Schema for a single input entry.

    Attributes:
        name: Name of the input. Required (no default).
        weight: Integer weight attached to the input. Defaults to 1.
    """

    name: str
    weight: int = 1
@dataclass
class ContinuousInputConfig(InputConfig):
    """Schema for a continuous input entry; extends ``InputConfig``.

    Attributes:
        scale: Boolean flag, enabled by default.
        log2: Boolean flag, disabled by default.

    NOTE(review): presumably these flags toggle scaling and log2
    transformation of the data during preprocessing -- confirm against the
    code that consumes this config.
    """

    scale: bool = True
    log2: bool = False
@dataclass
class DataConfig:
    """Schema for the ``data`` section of the configuration.

    Every field defaults to OmegaConf's ``MISSING`` marker, i.e. no usable
    default is provided here and a value has to be supplied elsewhere.

    Attributes:
        raw_data_path: Path given as a string.
        interim_data_path: Path given as a string.
        results_path: Path given as a string.
        sample_names: String value. NOTE(review): presumably the name of the
            file listing sample identifiers -- confirm.
        categorical_inputs: List of ``InputConfig`` entries.
        continuous_inputs: List of ``ContinuousInputConfig`` entries.
        categorical_names: List of strings.
        continuous_names: List of strings.
        categorical_weights: List of integers.
        continuous_weights: List of integers.

    NOTE(review): the ``*_names`` and ``*_weights`` fields are presumably
    filled from the ``*_inputs`` lists through the ``names`` and ``weights``
    resolvers registered at the bottom of this module -- confirm in the YAML
    configuration files.
    """

    raw_data_path: str = MISSING
    interim_data_path: str = MISSING
    results_path: str = MISSING
    sample_names: str = MISSING
    categorical_inputs: list[InputConfig] = MISSING
    continuous_inputs: list[ContinuousInputConfig] = MISSING
    categorical_names: list[str] = MISSING
    continuous_names: list[str] = MISSING
    categorical_weights: list[int] = MISSING
    continuous_weights: list[int] = MISSING
@dataclass
class ModelConfig:
    """Base schema for a model configuration.

    Both fields default to ``MISSING``; subclasses in this module override
    them with concrete defaults.

    Attributes:
        _target_: String value. NOTE(review): presumably the dotted path of
            the class that Hydra instantiates -- confirm.
        cuda: Boolean flag. NOTE(review): presumably enables GPU usage --
            confirm.
    """

    _target_: str = MISSING
    cuda: bool = MISSING
@dataclass
class VAEConfig(ModelConfig):
    """Configuration for the VAE module.

    Attributes:
        _target_: Defaults to the fully qualified name of the ``VAE`` class,
            as computed by ``get_fully_qualname``.
        categorical_weights: List of integers (required).
        continuous_weights: List of integers (required).
        num_hidden: List of integers (required).
        num_latent: Integer (required).
        beta: Float (required).
        dropout: Float (required).
        cuda: Boolean flag; overrides the base-class default with ``False``.

    Fields marked "required" default to ``MISSING``.
    """

    _target_: str = get_fully_qualname(VAE)
    categorical_weights: list[int] = MISSING
    continuous_weights: list[int] = MISSING
    num_hidden: list[int] = MISSING
    num_latent: int = MISSING
    beta: float = MISSING
    dropout: float = MISSING
    cuda: bool = False
@dataclass
class TrainingLoopConfig:
    """Configuration for the training loop.

    Attributes:
        _target_: Defaults to the fully qualified name of the
            ``training_loop`` function, as computed by
            ``get_fully_qualname``.
        num_epochs: Integer (required).
        lr: Float (required).
        kld_warmup_steps: List of integers (required).
        batch_dilation_steps: List of integers (required).
        early_stopping: Boolean flag (required).
        patience: Integer (required).

    Fields marked "required" default to ``MISSING``.
    """

    _target_: str = get_fully_qualname(training_loop)
    num_epochs: int = MISSING
    lr: float = MISSING
    kld_warmup_steps: list[int] = MISSING
    batch_dilation_steps: list[int] = MISSING
    early_stopping: bool = MISSING
    patience: int = MISSING
@dataclass
class TaskConfig:
    """Configuration for a MOVE task.

    All three fields are annotated as optional (``None`` is an accepted
    value) but declare no default, so they are required constructor
    arguments of this base class.

    Attributes:
        batch_size: Number of samples in a training batch.
        model: Configuration for a model.
        training_loop: Configuration for a training loop.
    """

    batch_size: Optional[int]
    model: Optional[VAEConfig]
    training_loop: Optional[TrainingLoopConfig]
@dataclass
class EncodeDataConfig(TaskConfig):
    """Configuration for a data-encoding task.

    The three fields inherited from ``TaskConfig`` are re-declared with a
    ``None`` default. The annotations are required: a bare ``name = None``
    assignment in a dataclass body only creates a class attribute and does
    not give the inherited field a default, which would leave all three
    fields as mandatory constructor arguments.

    Attributes:
        batch_size: Defaults to ``None``.
        model: Defaults to ``None``.
        training_loop: Defaults to ``None``.
    """

    batch_size: Optional[int] = None
    model: Optional[VAEConfig] = None
    training_loop: Optional[TrainingLoopConfig] = None
@dataclass
class TuneModelConfig(TaskConfig):
    """Configure the "tune model" task.

    Common base class of the "tune model" schemas; it adds no fields to
    ``TaskConfig``.
    """

    ...
@dataclass
class TuneModelStabilityConfig(TuneModelConfig):
    """Configure the stability variant of the "tune model" task.

    Attributes:
        num_refits: Integer, required (defaults to ``MISSING``).
            NOTE(review): presumably the number of times the model is
            refitted -- confirm against the task implementation.
    """

    num_refits: int = MISSING
@dataclass
class TuneModelReconstructionConfig(TuneModelConfig):
    """Configure the reconstruction variant of the "tune model" task.

    Adds no fields to ``TuneModelConfig``.
    """

    ...
@dataclass
class AnalyzeLatentConfig(TaskConfig):
    """Configure the "analyze latents" task.

    Attributes:
        feature_names:
            Names of features to visualize. Defaults to an empty list.
        reducer:
            Mapping from string keys to arbitrary values; required (defaults
            to ``MISSING``). NOTE(review): presumably the settings of the
            dimensionality-reduction method applied to the latent space --
            confirm.
    """

    feature_names: list[str] = field(default_factory=list)
    reducer: dict[str, Any] = MISSING
@dataclass
class IdentifyAssociationsConfig(TaskConfig):
    """Configure the "identify associations" task.

    Attributes:
        target_dataset:
            Name of categorical dataset to perturb.
        target_value:
            The value to change to. It should be a category name.
        num_refits:
            Number of times to refit the model.
        sig_threshold:
            Threshold used to determine whether an association is significant.
            In the t-test approach, this is called significance level (alpha).
            In the probabilistic approach, significant associations are
            selected if their FDR is below this threshold.
            This value should be within the range [0, 1]. Defaults to 0.05.
        save_refits:
            Whether to save the weights of each refit. If weights are saved,
            rerunning the task will load them instead of training.
            Defaults to ``False``.
    """

    target_dataset: str = MISSING
    target_value: str = MISSING
    num_refits: int = MISSING
    sig_threshold: float = 0.05
    save_refits: bool = False
@dataclass
class IdentifyAssociationsBayesConfig(IdentifyAssociationsConfig):
    """Configure the probabilistic approach to identify associations.

    Adds no fields to ``IdentifyAssociationsConfig``.
    """

    ...
@dataclass
class IdentifyAssociationsTTestConfig(IdentifyAssociationsConfig):
    """Configure the t-test approach to identify associations.

    Attributes:
        num_latent:
            List of latent space dimensions to train. It should contain four
            elements. Required (defaults to ``MISSING``).
    """

    num_latent: list[int] = MISSING
@dataclass
class IdentifyAssociationsKSConfig(IdentifyAssociationsConfig):
    """Configure the Kolmogorov-Smirnov approach to identify associations.

    Attributes:
        perturbed_feature_names:
            Names of the perturbed features of interest. Defaults to an
            empty list.
        target_feature_names:
            Names of the target features of interest. Defaults to an empty
            list.

    Description:
        For each perturbed feature - target feature pair, we will plot:
        - Input vs. reconstruction correlation plot: to assess reconstruction
          quality of both target and perturbed features.
        - Distribution of reconstruction values for the target feature before
          and after the perturbation of the perturbed feature.
    """

    perturbed_feature_names: list[str] = field(default_factory=list)
    target_feature_names: list[str] = field(default_factory=list)
@dataclass
class MOVEConfig:
    """Root configuration schema; registered below as ``config_schema``.

    Attributes:
        defaults: List with a single entry mapping ``data`` to
            ``"base_data"``. A fresh list is built for every instance.
            NOTE(review): presumably read by Hydra as its defaults list --
            confirm.
        data: Data configuration; required (defaults to ``MISSING``).
        task: Task configuration; required (defaults to ``MISSING``).
        seed: Optional integer; defaults to ``None``.
    """

    defaults: list[Any] = field(default_factory=lambda: [dict(data="base_data")])
    data: DataConfig = MISSING
    task: TaskConfig = MISSING
    seed: Optional[int] = None
def extract_weights(configs: list[InputConfig]) -> list[int]:
    """Collect the weight of every input config, preserving order.

    Args:
        configs: Input configs to read.

    Returns:
        One weight per config; a config without a ``weight`` attribute
        contributes ``1``.
    """
    return [getattr(config, "weight", 1) for config in configs]
def extract_names(configs: list[InputConfig]) -> list[str]:
    """Collect the name of every input config, preserving order.

    Args:
        configs: Input configs to read.

    Returns:
        The ``name`` attribute of each config.
    """
    return [config.name for config in configs]
# Store config schema
cs = ConfigStore.instance()
cs.store(name="config_schema", node=MOVEConfig)

# Task schemas, in registration order. Each one is stored in the "task"
# group under the given name.
_task_schemas = (
    ("encode_data", EncodeDataConfig),
    ("tune_model_reconstruction_schema", TuneModelReconstructionConfig),
    ("tune_model_stability_schema", TuneModelStabilityConfig),
    ("analyze_latent_schema", AnalyzeLatentConfig),
    ("identify_associations_bayes_schema", IdentifyAssociationsBayesConfig),
    ("identify_associations_ttest_schema", IdentifyAssociationsTTestConfig),
    ("identify_associations_ks_schema", IdentifyAssociationsKSConfig),
)
for _schema_name, _schema_node in _task_schemas:
    cs.store(group="task", name=_schema_name, node=_schema_node)
# Register custom resolvers: the "weights" resolver is backed by
# extract_weights and the "names" resolver by extract_names, so both can be
# referenced from OmegaConf interpolations.
OmegaConf.register_new_resolver("weights", extract_weights)
OmegaConf.register_new_resolver("names", extract_names)