|
a |
|
b/src/move/training/training_loop.py |
|
|
1 |
from typing import Optional |
|
|
2 |
|
|
|
3 |
from torch.utils.data import DataLoader |
|
|
4 |
|
|
|
5 |
from move.models.vae import VAE |
|
|
6 |
|
|
|
7 |
# Return shape of `training_loop`: four per-epoch metric lists (epoch loss,
# BCE loss, SSE loss, KLD loss) followed by the final KLD weight.
TrainingLoopOutput = tuple[list[float], list[float], list[float], list[float], float]
|
|
8 |
|
|
|
9 |
|
|
|
10 |
def dilate_batch(dataloader: DataLoader, factor: float = 1.5) -> DataLoader:
    """
    Increase the batch size of a dataloader.

    Builds a new DataLoader over the same dataset whose batch size is the
    current batch size scaled by ``factor`` (truncated to an int). The new
    loader shuffles and drops the last incomplete batch.

    Args:
        dataloader (DataLoader): An object feeding data to the VAE
        factor (float, optional): multiplier applied to the current batch
            size. Defaults to 1.5 (the original hard-coded growth rate).

    Returns:
        DataLoader: An object feeding data to the VAE

    Raises:
        ValueError: If the dataloader has no fixed batch size (e.g. it was
            built with a batch sampler). A real exception is used instead of
            ``assert``, which is stripped when Python runs with ``-O``.
    """
    if dataloader.batch_size is None:
        raise ValueError("cannot dilate a dataloader without a fixed batch size")
    dataset = dataloader.dataset
    batch_size = int(dataloader.batch_size * factor)
    return DataLoader(dataset, batch_size, shuffle=True, drop_last=True)
|
|
24 |
|
|
|
25 |
|
|
|
26 |
# Default (empty) schedules: no batch dilation and no KLD warm-up steps.
# NOTE(review): these are shared, mutable module-level lists used as default
# arguments in `training_loop`; never mutate them in place, or the mutation
# leaks into every subsequent call that relies on the defaults.
BATCH_DILATION_STEPS = []
KLD_WARMUP_STEPS = []
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def training_loop(
    model: VAE,
    train_dataloader: DataLoader,
    valid_dataloader: Optional[DataLoader] = None,
    lr: float = 1e-4,
    num_epochs: int = 100,
    batch_dilation_steps: Optional[list[int]] = None,
    kld_warmup_steps: Optional[list[int]] = None,
    early_stopping: bool = False,
    patience: int = 0,
) -> TrainingLoopOutput:
    """
    Trains a VAE model with batch dilation and KLD warm-up. Optionally,
    enforce early stopping.

    Args:
        model (VAE): trained VAE model object
        train_dataloader (DataLoader): An object feeding data to the VAE
            with training data
        valid_dataloader (Optional[DataLoader], optional): An object feeding data to the
            VAE with validation data. Defaults to None.
        lr (float, optional): learning rate. Defaults to 1e-4.
        num_epochs (int, optional): number of epochs. Defaults to 100.
        batch_dilation_steps (Optional[list[int]], optional): a list with integers
            corresponding to epochs when batch size is increased. Defaults to
            None, which is treated as an empty schedule (no dilation).
        kld_warmup_steps (Optional[list[int]], optional): a list with integers
            corresponding to epochs when the KLD weight is increased by an
            equal fraction. Defaults to None, which is treated as an empty
            schedule (no warm-up; the KLD weight stays 0.0).
        early_stopping (bool, optional): boolean if use early stopping.
            Defaults to False.
        patience (int, optional): number of epochs to wait before early stop
            if no progress on the validation set. Defaults to 0.

    Returns:
        (tuple): a tuple containing:
            *outputs (*list): lists containing information of epoch loss, BCE loss,
                SSE loss, KLD loss
            kld_weight (float): final KLD after dilations during the training
    """
    # `None` sentinel instead of mutable module-level list defaults: a caller
    # (or this function) mutating a shared default list would silently change
    # the schedule for every later call.
    if batch_dilation_steps is None:
        batch_dilation_steps = []
    if kld_warmup_steps is None:
        kld_warmup_steps = []

    # One list per tracked metric (epoch loss, BCE, SSE, KLD), appended each epoch.
    outputs: list[list[float]] = [[] for _ in range(4)]
    min_likelihood = float("inf")  # best (lowest) validation likelihood so far
    counter = 0  # consecutive epochs without validation improvement

    kld_weight = 0.0

    for epoch in range(1, num_epochs + 1):
        # KLD warm-up: raise the weight in equal fractions so it reaches 1.0
        # once every scheduled warm-up epoch has passed.
        if epoch in kld_warmup_steps:
            kld_weight += 1 / len(kld_warmup_steps)

        # Batch dilation: replace the dataloader with one using a 1.5x batch.
        if epoch in batch_dilation_steps:
            train_dataloader = dilate_batch(train_dataloader)

        for i, output in enumerate(
            model.encoding(train_dataloader, epoch, lr, kld_weight)
        ):
            outputs[i].append(output)

        if early_stopping and valid_dataloader is not None:
            output = model.latent(valid_dataloader, kld_weight)
            # Convention here: the last element of `model.latent`'s output is
            # the validation likelihood (lower is better).
            valid_likelihood = output[-1]
            if valid_likelihood > min_likelihood and counter < patience:
                counter += 1
                # Decay the learning rate every 5 stagnant epochs.
                if counter % 5 == 0:
                    lr *= 0.9
            elif counter == patience:
                # NOTE(review): once `counter` has reached `patience`, training
                # stops even if this epoch's likelihood improved — confirm this
                # is the intended semantics of "patience".
                break
            else:
                min_likelihood = valid_likelihood
                counter = 0

    return *outputs, kld_weight