Data: Tabular Time Series Specialty: Endocrinology Laboratory: Blood Tests EHR: Demographics Diagnoses Medications Omics: Genomics Multi-omics Transcriptomics Wearable: Activity Clinical Purpose: Treatment Response Assessment Task: Biomarker Discovery

Switch to unified view

a b/src/move/training/training_loop.py
1
from typing import Optional
2
3
from torch.utils.data import DataLoader
4
5
from move.models.vae import VAE
6
7
# Return type of ``training_loop``: four lists of per-epoch floats followed by
# a single float (the KLD weight reached when training ends).
TrainingLoopOutput = tuple[list[float], list[float], list[float], list[float], float]
8
9
10
def dilate_batch(dataloader: DataLoader, factor: float = 1.5) -> DataLoader:
    """
    Build a new dataloader over the same dataset with a larger batch size.

    The new batch size is ``int(dataloader.batch_size * factor)``. The returned
    loader always shuffles and drops the last incomplete batch; no other
    setting of the original loader is carried over.

    Args:
        dataloader (DataLoader): An object feeding data to the VAE. It must
            use automatic batching, i.e., its ``batch_size`` is not None.
        factor (float, optional): multiplier applied to the current batch
            size. Defaults to 1.5.

    Returns:
        DataLoader: A new object feeding the same dataset to the VAE

    Raises:
        ValueError: if the dataloader has no batch size
    """
    # Explicit check instead of `assert`, which is stripped when running
    # Python with -O and would turn this into an obscure TypeError below.
    if dataloader.batch_size is None:
        raise ValueError("Cannot dilate a dataloader whose batch_size is None")
    batch_size = int(dataloader.batch_size * factor)
    return DataLoader(dataloader.dataset, batch_size, shuffle=True, drop_last=True)
24
25
26
# Default value of the ``batch_dilation_steps`` parameter of ``training_loop``
# (empty: the batch size is never increased).
# NOTE: this list object is shared by every call relying on the default, so it
# must not be mutated in place.
BATCH_DILATION_STEPS: list[int] = []
27
# Default value of the ``kld_warmup_steps`` parameter of ``training_loop``
# (empty: the KLD weight is never raised from its initial value).
# NOTE: this list object is shared by every call relying on the default, so it
# must not be mutated in place.
KLD_WARMUP_STEPS: list[int] = []
28
29
30
def training_loop(
    model: VAE,
    train_dataloader: DataLoader,
    valid_dataloader: Optional[DataLoader] = None,
    lr: float = 1e-4,
    num_epochs: int = 100,
    batch_dilation_steps: list[int] = BATCH_DILATION_STEPS,
    kld_warmup_steps: list[int] = KLD_WARMUP_STEPS,
    early_stopping: bool = False,
    patience: int = 0,
) -> TrainingLoopOutput:
    """
    Trains a VAE model with batch dilation and KLD warm-up. Optionally,
    enforce early stopping.

    Args:
        model (VAE): VAE model object to train. Its ``encoding`` method is
            called once per epoch with the training data and, when early
            stopping is active, its ``latent`` method is called with the
            validation data.
        train_dataloader (DataLoader): An object feeding data to the VAE
            with training data
        valid_dataloader (Optional[DataLoader], optional): An object feeding
            data to the VAE with validation data. Defaults to None.
        lr (float, optional): learning rate. Defaults to 1e-4.
        num_epochs (int, optional): number of epochs. Defaults to 100.
        batch_dilation_steps (list[int], optional): a list with integers
            corresponding to epochs (1-based) when batch size is increased.
            Defaults to [].
        kld_warmup_steps (list[int], optional): a list with integers
            corresponding to epochs (1-based) when the KLD weight is
            increased by ``1 / len(kld_warmup_steps)``. Defaults to [].
        early_stopping (bool, optional): boolean if use early stopping. Only
            takes effect when ``valid_dataloader`` is given. Defaults to
            False.
        patience (int, optional): number of consecutive epochs without
            progress on the validation set that are tolerated; training
            stops at the next epoch without progress. Defaults to 0.

    Returns:
        (tuple): a tuple containing:
            *outputs (*list): lists containing information of epoch loss, BCE
                              loss, SSE loss, KLD loss
            kld_weight (float): final KLD weight after the warm-up steps
                                reached during the training
    """
    outputs: list[list[float]] = [[] for _ in range(4)]
    # The last element returned by `model.latent` is treated as a validation
    # score where lower is better.
    min_likelihood = float("inf")
    counter = 0  # consecutive epochs without validation improvement

    kld_weight = 0.0

    for epoch in range(1, num_epochs + 1):
        # Equal increments: the weight approaches 1.0 once every warm-up
        # epoch has been visited.
        if epoch in kld_warmup_steps:
            kld_weight += 1 / len(kld_warmup_steps)

        if epoch in batch_dilation_steps:
            train_dataloader = dilate_batch(train_dataloader)

        for i, output in enumerate(
            model.encoding(train_dataloader, epoch, lr, kld_weight)
        ):
            outputs[i].append(output)

        if early_stopping and valid_dataloader is not None:
            valid_likelihood = model.latent(valid_dataloader, kld_weight)[-1]
            if valid_likelihood > min_likelihood:
                # No progress: stop only once the patience budget is spent.
                # Checking improvement first ensures that an improving epoch
                # is never discarded (and that patience=0 does not abort
                # training right after the first epoch).
                if counter >= patience:
                    break
                counter += 1
                # Decay the learning rate on every fifth stale epoch.
                if counter % 5 == 0:
                    lr *= 0.9
            else:
                min_likelihood = valid_likelihood
                counter = 0

    return *outputs, kld_weight