Data: Tabular Time Series Specialty: Endocrinology Laboratory: Blood Tests EHR: Demographics Diagnoses Medications Omics: Genomics Multi-omics Transcriptomics Wearable: Activity Clinical Purpose: Treatment Response Assessment Task: Biomarker Discovery
[c23b31]: / src / move / training / training_loop.py

Download this file

102 lines (80 with data), 3.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
from typing import Optional
from torch.utils.data import DataLoader
from move.models.vae import VAE
# Return type of `training_loop`: four per-epoch loss histories
# (total loss, BCE, SSE, KLD) followed by the final KLD weight.
TrainingLoopOutput = tuple[list[float], list[float], list[float], list[float], float]
def dilate_batch(dataloader: DataLoader, factor: float = 1.5) -> DataLoader:
    """
    Create a new dataloader over the same dataset with an enlarged batch size.

    Args:
        dataloader (DataLoader): An object feeding data to the VAE
        factor (float, optional): multiplier applied to the current batch
            size; the result is truncated to an int. Defaults to 1.5.

    Returns:
        DataLoader: An object feeding data to the VAE, with
            ``batch_size = int(old_batch_size * factor)``, shuffling enabled,
            and incomplete final batches dropped.
    """
    # A batch_size of None means the dataloader yields the whole dataset;
    # dilation is only meaningful for a concrete batch size.
    assert dataloader.batch_size is not None
    new_batch_size = int(dataloader.batch_size * factor)
    return DataLoader(dataloader.dataset, new_batch_size, shuffle=True, drop_last=True)
# Default (empty) schedules: no batch dilation and no KLD warm-up occur
# unless the caller passes explicit epoch lists to `training_loop`.
BATCH_DILATION_STEPS = []
KLD_WARMUP_STEPS = []
def training_loop(
    model: VAE,
    train_dataloader: DataLoader,
    valid_dataloader: Optional[DataLoader] = None,
    lr: float = 1e-4,
    num_epochs: int = 100,
    batch_dilation_steps: Optional[list[int]] = None,
    kld_warmup_steps: Optional[list[int]] = None,
    early_stopping: bool = False,
    patience: int = 0,
) -> TrainingLoopOutput:
    """
    Trains a VAE model with batch dilation and KLD warm-up. Optionally,
    enforce early stopping.

    Args:
        model (VAE): VAE model object to train
        train_dataloader (DataLoader): An object feeding data to the VAE
            with training data
        valid_dataloader (Optional[DataLoader], optional): An object feeding data to the
            VAE with validation data. Defaults to None.
        lr (float, optional): learning rate. Defaults to 1e-4.
        num_epochs (int, optional): number of epochs. Defaults to 100.
        batch_dilation_steps (Optional[list[int]], optional): a list with integers
            corresponding to epochs when batch size is increased. Defaults to
            None (no dilation).
        kld_warmup_steps (Optional[list[int]], optional): a list with integers
            corresponding to epochs when kld weight is increased by the
            selected rate. Defaults to None (no warm-up).
        early_stopping (bool, optional): boolean if use early stopping.
            Defaults to False.
        patience (int, optional): number of epochs to wait before early stop
            if no progress on the validation set. Defaults to 0.

    Returns:
        (tuple): a tuple containing:
            *outputs (*list): lists containing information of epoch loss, BCE loss,
                SSE loss, KLD loss
            kld_weight (float): final KLD weight after warm-up during the training
    """
    # Avoid mutable module-level lists as defaults; an empty schedule means
    # the corresponding feature is disabled.
    if batch_dilation_steps is None:
        batch_dilation_steps = []
    if kld_warmup_steps is None:
        kld_warmup_steps = []

    outputs: list[list[float]] = [[] for _ in range(4)]
    min_likelihood = float("inf")
    counter = 0  # epochs elapsed without validation improvement
    kld_weight = 0.0
    for epoch in range(1, num_epochs + 1):
        # KLD warm-up: ramp the weight in equal increments, one per scheduled
        # epoch, reaching 1.0 once all warm-up steps have passed.
        if epoch in kld_warmup_steps:
            kld_weight += 1 / len(kld_warmup_steps)
        if epoch in batch_dilation_steps:
            train_dataloader = dilate_batch(train_dataloader)

        # model.encoding yields one value per tracked metric
        # (epoch loss, BCE, SSE, KLD); collect each into its own history.
        for i, output in enumerate(
            model.encoding(train_dataloader, epoch, lr, kld_weight)
        ):
            outputs[i].append(output)

        if early_stopping and valid_dataloader is not None:
            output = model.latent(valid_dataloader, kld_weight)
            valid_likelihood = output[-1]
            if valid_likelihood > min_likelihood:
                # No improvement this epoch. (Previously, reaching
                # counter == patience broke out of training even when the
                # validation likelihood had just improved — and patience=0
                # stopped training after the first epoch unconditionally.
                # Now we only stop after patience is exhausted with no
                # improvement.)
                if counter < patience:
                    counter += 1
                    # Decay the learning rate every 5 stagnant epochs.
                    if counter % 5 == 0:
                        lr *= 0.9
                else:
                    break
            else:
                min_likelihood = valid_likelihood
                counter = 0
    return *outputs, kld_weight