|
a |
|
b/custom_byol_bolts.py |
|
|
1 |
import math |
|
|
2 |
from argparse import ArgumentParser |
|
|
3 |
from copy import deepcopy |
|
|
4 |
from typing import Any |
|
|
5 |
|
|
|
6 |
import pytorch_lightning as pl |
|
|
7 |
import torch |
|
|
8 |
import torch.nn as nn |
|
|
9 |
import torch.nn.functional as F |
|
|
10 |
from pytorch_lightning import seed_everything |
|
|
11 |
from torch.optim import Adam |
|
|
12 |
|
|
|
13 |
from pl_bolts.models.self_supervised import BYOL |
|
|
14 |
# from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate |
|
|
15 |
from pl_bolts.optimizers.lars_scheduling import LARSWrapper |
|
|
16 |
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR |
|
|
17 |
|
|
|
18 |
from models.resnet_simclr import ResNetSimCLR |
|
|
19 |
import re |
|
|
20 |
|
|
|
21 |
import time |
|
|
22 |
|
|
|
23 |
import yaml |
|
|
24 |
import logging |
|
|
25 |
import os |
|
|
26 |
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper |
|
|
27 |
from clinical_ts.create_logger import create_logger |
|
|
28 |
import pickle |
|
|
29 |
from pytorch_lightning import Trainer, seed_everything |
|
|
30 |
|
|
|
31 |
from torch import nn |
|
|
32 |
from torch.nn import functional as F |
|
|
33 |
from online_evaluator import SSLOnlineEvaluator |
|
|
34 |
from ecg_datamodule import ECGDataModule |
|
|
35 |
from pytorch_lightning.loggers import TensorBoardLogger |
|
|
36 |
import pdb |
|
|
37 |
|
|
|
38 |
logger = create_logger(__name__) |
|
|
39 |
method="byol" |
|
|
40 |
def mean(res, key1, key2=None): |
|
|
41 |
if key2 is not None: |
|
|
42 |
return torch.stack([x[key1][key2] for x in res]).mean() |
|
|
43 |
return torch.stack([x[key1] for x in res if type(x) == dict and key1 in x.keys()]).mean() |
|
|
44 |
|
|
|
45 |
class MLP(nn.Module): |
|
|
46 |
def __init__(self, input_dim=512, hidden_size=4096, output_dim=256): |
|
|
47 |
super().__init__() |
|
|
48 |
self.output_dim = output_dim |
|
|
49 |
self.input_dim = input_dim |
|
|
50 |
self.model = nn.Sequential( |
|
|
51 |
nn.Linear(input_dim, hidden_size, bias=False), |
|
|
52 |
nn.BatchNorm1d(hidden_size), |
|
|
53 |
nn.ReLU(inplace=True), |
|
|
54 |
nn.Linear(hidden_size, output_dim, bias=True)) |
|
|
55 |
|
|
|
56 |
def forward(self, x): |
|
|
57 |
x = self.model(x) |
|
|
58 |
return x |
|
|
59 |
|
|
|
60 |
|
|
|
61 |
class SiameseArm(nn.Module): |
|
|
62 |
def __init__(self, encoder=None, out_dim=128, hidden_size=512, projector_dim=512): |
|
|
63 |
super().__init__() |
|
|
64 |
|
|
|
65 |
if encoder is None: |
|
|
66 |
encoder = torchvision_ssl_encoder('resnet50') |
|
|
67 |
# Encoder |
|
|
68 |
self.encoder = encoder |
|
|
69 |
# Pooler |
|
|
70 |
self.pooler = nn.AdaptiveAvgPool2d((1, 1)) |
|
|
71 |
# Projector |
|
|
72 |
projector_dim = encoder.l1.in_features |
|
|
73 |
self.projector = MLP( |
|
|
74 |
input_dim=projector_dim, hidden_size=hidden_size, output_dim=out_dim) |
|
|
75 |
# Predictor |
|
|
76 |
self.predictor = MLP( |
|
|
77 |
input_dim=out_dim, hidden_size=hidden_size, output_dim=out_dim) |
|
|
78 |
|
|
|
79 |
def forward(self, x): |
|
|
80 |
y = self.encoder(x)[0] |
|
|
81 |
y = y.view(y.size(0), -1) |
|
|
82 |
z = self.projector(y) |
|
|
83 |
h = self.predictor(z) |
|
|
84 |
return y, z, h |
|
|
85 |
|
|
|
86 |
|
|
|
87 |
class BYOLMAWeightUpdate(pl.Callback): |
|
|
88 |
def __init__(self, initial_tau=0.996): |
|
|
89 |
""" |
|
|
90 |
Args: |
|
|
91 |
initial_tau: starting tau. Auto-updates with every training step |
|
|
92 |
""" |
|
|
93 |
super().__init__() |
|
|
94 |
self.initial_tau = initial_tau |
|
|
95 |
self.current_tau = initial_tau |
|
|
96 |
|
|
|
97 |
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): |
|
|
98 |
# get networks |
|
|
99 |
online_net = pl_module.online_network |
|
|
100 |
target_net = pl_module.target_network |
|
|
101 |
|
|
|
102 |
# update weights |
|
|
103 |
self.update_weights(online_net, target_net) |
|
|
104 |
|
|
|
105 |
# update tau after |
|
|
106 |
self.current_tau = self.update_tau(pl_module, trainer) |
|
|
107 |
|
|
|
108 |
def update_tau(self, pl_module, trainer): |
|
|
109 |
max_steps = len(trainer.train_dataloader) * trainer.max_epochs |
|
|
110 |
tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * |
|
|
111 |
pl_module.global_step / max_steps) + 1) / 2 |
|
|
112 |
return tau |
|
|
113 |
|
|
|
114 |
def update_weights(self, online_net, target_net): |
|
|
115 |
# apply MA weight update |
|
|
116 |
for (name, online_p), (_, target_p) in zip(online_net.named_parameters(), target_net.named_parameters()): |
|
|
117 |
if 'weight' in name: |
|
|
118 |
target_p.data = self.current_tau * target_p.data + \ |
|
|
119 |
(1 - self.current_tau) * online_p.data |
|
|
120 |
|
|
|
121 |
|
|
|
122 |
class CustomBYOL(pl.LightningModule): |
|
|
123 |
def __init__(self, |
|
|
124 |
num_classes=5, |
|
|
125 |
learning_rate: float = 0.2, |
|
|
126 |
weight_decay: float = 1.5e-6, |
|
|
127 |
input_height: int = 32, |
|
|
128 |
batch_size: int = 32, |
|
|
129 |
num_workers: int = 0, |
|
|
130 |
warmup_epochs: int = 10, |
|
|
131 |
max_epochs: int = 1000, |
|
|
132 |
config=None, |
|
|
133 |
transformations=None, |
|
|
134 |
**kwargs): |
|
|
135 |
""" |
|
|
136 |
Args: |
|
|
137 |
datamodule: The datamodule |
|
|
138 |
learning_rate: the learning rate |
|
|
139 |
weight_decay: optimizer weight decay |
|
|
140 |
input_height: image input height |
|
|
141 |
batch_size: the batch size |
|
|
142 |
num_workers: number of workers |
|
|
143 |
warmup_epochs: num of epochs for scheduler warm up |
|
|
144 |
max_epochs: max epochs for scheduler |
|
|
145 |
""" |
|
|
146 |
super().__init__() |
|
|
147 |
self.save_hyperparameters() |
|
|
148 |
|
|
|
149 |
self.config = config |
|
|
150 |
self.transformations = transformations |
|
|
151 |
self.online_network = SiameseArm( |
|
|
152 |
encoder=self.init_model(), out_dim=config["model"]["out_dim"]) |
|
|
153 |
self.target_network = deepcopy(self.online_network) |
|
|
154 |
self.weight_callback = BYOLMAWeightUpdate() |
|
|
155 |
self.log_dict = {} |
|
|
156 |
self.epoch = 0 |
|
|
157 |
# self.model_device = self.online_network.encoder.features[0][0].weight.device |
|
|
158 |
|
|
|
159 |
def init_model(self): |
|
|
160 |
model = ResNetSimCLR(**self.config["model"]) |
|
|
161 |
# return model.features |
|
|
162 |
return model |
|
|
163 |
|
|
|
164 |
# def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None: |
|
|
165 |
def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None: |
|
|
166 |
# Add callback for user automatically since it's key to BYOL weight update |
|
|
167 |
self.weight_callback.on_train_batch_end( |
|
|
168 |
self.trainer, self, outputs, batch, batch_idx, 0) |
|
|
169 |
|
|
|
170 |
def forward(self, x): |
|
|
171 |
y, _, _ = self.online_network(x) |
|
|
172 |
return y |
|
|
173 |
|
|
|
174 |
def cosine_similarity(self, a, b): |
|
|
175 |
a = F.normalize(a, dim=-1) |
|
|
176 |
b = F.normalize(b, dim=-1) |
|
|
177 |
sim = (a * b).sum(-1).mean() |
|
|
178 |
return sim |
|
|
179 |
|
|
|
180 |
def shared_step(self, batch, batch_idx): |
|
|
181 |
# (img_1, img_2), y = batch |
|
|
182 |
(img_1, y1), (img_2, y2) = batch |
|
|
183 |
|
|
|
184 |
img_1 = self.to_device(img_1) |
|
|
185 |
img_2 = self.to_device(img_2) |
|
|
186 |
|
|
|
187 |
# Image 1 to image 2 loss |
|
|
188 |
y1, z1, h1 = self.online_network(img_1) |
|
|
189 |
with torch.no_grad(): |
|
|
190 |
y2, z2, h2 = self.target_network(img_2) |
|
|
191 |
loss_a = - 2 * self.cosine_similarity(h1, z2) |
|
|
192 |
|
|
|
193 |
# Image 2 to image 1 loss |
|
|
194 |
y1, z1, h1 = self.online_network(img_2) |
|
|
195 |
with torch.no_grad(): |
|
|
196 |
y2, z2, h2 = self.target_network(img_1) |
|
|
197 |
# L2 normalize |
|
|
198 |
loss_b = - 2 * self.cosine_similarity(h1, z2) |
|
|
199 |
|
|
|
200 |
# Final loss |
|
|
201 |
total_loss = loss_a + loss_b |
|
|
202 |
|
|
|
203 |
return loss_a, loss_b, total_loss |
|
|
204 |
|
|
|
205 |
def training_step(self, batch, batch_idx): |
|
|
206 |
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) |
|
|
207 |
|
|
|
208 |
# log results |
|
|
209 |
# result = pl.TrainResult(minimize=total_loss) |
|
|
210 |
# result.log('train_loss/1_2_loss', loss_a, on_epoch=True) |
|
|
211 |
# result.log('train_loss/2_1_loss', loss_b, on_epoch=True) |
|
|
212 |
# result.log('train_loss/total_loss', total_loss, on_epoch=True) |
|
|
213 |
|
|
|
214 |
# # log results |
|
|
215 |
# self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, |
|
|
216 |
# 'train_loss': total_loss}) |
|
|
217 |
|
|
|
218 |
return total_loss |
|
|
219 |
|
|
|
220 |
def validation_step(self, batch, batch_idx, dataloader_idx): |
|
|
221 |
if dataloader_idx != 0: |
|
|
222 |
return {} |
|
|
223 |
|
|
|
224 |
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) |
|
|
225 |
|
|
|
226 |
# # log results |
|
|
227 |
# result = pl.EvalResult() |
|
|
228 |
# result.log('val_loss/1_2_loss', loss_a, on_epoch=True) |
|
|
229 |
# result.log('val_loss/2_1_loss', loss_b, on_epoch=True) |
|
|
230 |
# result.log('val_loss/total_loss', total_loss, on_epoch=True) |
|
|
231 |
|
|
|
232 |
# self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, |
|
|
233 |
# 'train_loss': total_loss}) |
|
|
234 |
results = { |
|
|
235 |
'val_loss': total_loss, |
|
|
236 |
'val_1_2_loss' : loss_a, |
|
|
237 |
'val_2_1_loss': loss_b |
|
|
238 |
} |
|
|
239 |
return results |
|
|
240 |
|
|
|
241 |
def validation_epoch_end(self, outputs): |
|
|
242 |
# outputs[0] because we are using multiple datasets! |
|
|
243 |
val_loss = mean(outputs[0], 'val_loss') |
|
|
244 |
loss_a = mean(outputs[0], 'val_1_2_loss') |
|
|
245 |
loss_b = mean(outputs[0], 'val_2_1_loss') |
|
|
246 |
|
|
|
247 |
log = { |
|
|
248 |
'val_loss': val_loss, |
|
|
249 |
'val_1_2_loss' : loss_a, |
|
|
250 |
'val_2_1_loss': loss_b |
|
|
251 |
} |
|
|
252 |
|
|
|
253 |
return {'val_loss': val_loss, 'log': log, 'progress_bar': log} |
|
|
254 |
|
|
|
255 |
def configure_optimizers(self): |
|
|
256 |
optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, |
|
|
257 |
weight_decay=self.hparams.weight_decay) |
|
|
258 |
# optimizer = LARSWrapper(optimizer) |
|
|
259 |
optimizer = optimizer |
|
|
260 |
scheduler = LinearWarmupCosineAnnealingLR( |
|
|
261 |
optimizer, |
|
|
262 |
warmup_epochs=self.hparams.warmup_epochs, |
|
|
263 |
max_epochs=self.hparams.max_epochs |
|
|
264 |
) |
|
|
265 |
return [optimizer], [scheduler] |
|
|
266 |
|
|
|
267 |
def on_train_start(self): |
|
|
268 |
# log configuration |
|
|
269 |
config_str = re.sub(r"[,\}\{]", "<br/>", str(self.config)) |
|
|
270 |
config_str = re.sub(r"[\[\]\']", "", config_str) |
|
|
271 |
transformation_str = re.sub(r"[\}]", "<br/>", str(["<br>" + str( |
|
|
272 |
t) + ":<br/>" + str(t.get_params()) for t in self.transformations])) |
|
|
273 |
transformation_str = re.sub(r"[,\"\{\'\[\]]", "", transformation_str) |
|
|
274 |
self.logger.experiment.add_text( |
|
|
275 |
"configuration", str(config_str), global_step=0) |
|
|
276 |
self.logger.experiment.add_text("transformations", str( |
|
|
277 |
transformation_str), global_step=0) |
|
|
278 |
self.epoch = 0 |
|
|
279 |
|
|
|
280 |
def on_epoch_end(self): |
|
|
281 |
self.epoch += 1 |
|
|
282 |
|
|
|
283 |
def get_representations(self, x): |
|
|
284 |
return self.online_network(x)[0] |
|
|
285 |
|
|
|
286 |
def get_model(self): |
|
|
287 |
return self.online_network.encoder |
|
|
288 |
|
|
|
289 |
def get_device(self): |
|
|
290 |
return self.online_network.encoder.features[0][0].weight.device |
|
|
291 |
|
|
|
292 |
def to_device(self, x): |
|
|
293 |
return x.type(self.type()).to(self.get_device()) |
|
|
294 |
|
|
|
295 |
def type(self): |
|
|
296 |
return self.online_network.encoder.features[0][0].weight.type() |
|
|
297 |
|
|
|
298 |
def parse_args(parent_parser): |
|
|
299 |
parser = ArgumentParser(parents=[parent_parser], add_help=False) |
|
|
300 |
parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline', |
|
|
301 |
default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"]) |
|
|
302 |
# GaussianNoise |
|
|
303 |
parser.add_argument( |
|
|
304 |
'--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float) |
|
|
305 |
# RandomResizedCrop |
|
|
306 |
parser.add_argument('--rr_crop_ratio_range', |
|
|
307 |
help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float) |
|
|
308 |
parser.add_argument( |
|
|
309 |
'--output_size', help='output size for random resized crop transformation', default=250, type=int) |
|
|
310 |
# DynamicTimeWarp |
|
|
311 |
parser.add_argument( |
|
|
312 |
'--warps', help='number of warps for dynamic time warp transformation', default=3, type=int) |
|
|
313 |
parser.add_argument( |
|
|
314 |
'--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int) |
|
|
315 |
# TimeWarp |
|
|
316 |
parser.add_argument( |
|
|
317 |
'--epsilon', help='epsilon param for time warp', default=10, type=float) |
|
|
318 |
# ChannelResize |
|
|
319 |
parser.add_argument('--magnitude_range', nargs='+', |
|
|
320 |
help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float) |
|
|
321 |
# Downsample |
|
|
322 |
parser.add_argument( |
|
|
323 |
'--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float) |
|
|
324 |
# TimeOut |
|
|
325 |
parser.add_argument('--to_crop_ratio_range', nargs='+', |
|
|
326 |
help='ratio range for timeout transformation', default=[0.2, 0.4], type=float) |
|
|
327 |
# resume training |
|
|
328 |
parser.add_argument('--resume', action='store_true') |
|
|
329 |
parser.add_argument( |
|
|
330 |
'--gpus', help='number of gpus to use; use cpu if gpu=0', type=int, default=1) |
|
|
331 |
parser.add_argument( |
|
|
332 |
'--num_nodes', default=1, help='number of cluster nodes', type=int) |
|
|
333 |
parser.add_argument( |
|
|
334 |
'--distributed_backend', help='sets backend type') |
|
|
335 |
parser.add_argument('--batch_size', type=int) |
|
|
336 |
parser.add_argument('--epochs', type=int) |
|
|
337 |
parser.add_argument('--debug', action='store_true') |
|
|
338 |
parser.add_argument('--warm_up', default=1, type=int) |
|
|
339 |
parser.add_argument('--precision', type=int) |
|
|
340 |
parser.add_argument('--datasets', dest="target_folders", |
|
|
341 |
nargs='+', help='used datasets for pretraining') |
|
|
342 |
parser.add_argument('--log_dir', default="./experiment_logs") |
|
|
343 |
parser.add_argument( |
|
|
344 |
'--percentage', help='determines how much of the dataset shall be used during the pretraining', type=float, default=1.0) |
|
|
345 |
parser.add_argument('--lr', type=float, help="learning rate") |
|
|
346 |
parser.add_argument('--out_dim', type=int, help="output dimension of model") |
|
|
347 |
parser.add_argument('--filter_cinc', default=False, action="store_true", help="only valid if cinc is selected: filter out the ptb data") |
|
|
348 |
parser.add_argument('--base_model') |
|
|
349 |
parser.add_argument('--widen',type=int, help="use wide xresnet1d50") |
|
|
350 |
parser.add_argument('--run_callbacks', default=False, action="store_true", help="run callbacks which asses linear evaluaton and finetuning metrics during pretraining") |
|
|
351 |
|
|
|
352 |
parser.add_argument('--checkpoint_path', default="") |
|
|
353 |
return parser |
|
|
354 |
|
|
|
355 |
def init_logger(config): |
|
|
356 |
level = logging.INFO |
|
|
357 |
|
|
|
358 |
if config['debug']: |
|
|
359 |
level = logging.DEBUG |
|
|
360 |
|
|
|
361 |
# remove all handlers to change basic configuration |
|
|
362 |
for handler in logging.root.handlers[:]: |
|
|
363 |
logging.root.removeHandler(handler) |
|
|
364 |
if not os.path.isdir(config['log_dir']): |
|
|
365 |
os.mkdir(config['log_dir']) |
|
|
366 |
logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level, |
|
|
367 |
format='%(asctime)s %(name)s:%(lineno)s %(levelname)s: %(message)s ') |
|
|
368 |
return logging.getLogger(__name__) |
|
|
369 |
|
|
|
370 |
def pretrain_routine(args): |
|
|
371 |
t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius, |
|
|
372 |
"epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range, |
|
|
373 |
"bw_cmax":0.1, "em_cmax":0.5, "pl_cmax":0.2, "bs_cmax":1} |
|
|
374 |
transformations = args.trafos |
|
|
375 |
checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml") |
|
|
376 |
config_file = checkpoint_config if args.resume and os.path.isfile( |
|
|
377 |
checkpoint_config) else "bolts_config.yaml" |
|
|
378 |
config = yaml.load(open(config_file, "r"), Loader=yaml.FullLoader) |
|
|
379 |
args_dict = vars(args) |
|
|
380 |
for key in set(config.keys()).union(set(args_dict.keys())): |
|
|
381 |
config[key] = config[key] if (key not in args_dict.keys() or key in args_dict.keys( |
|
|
382 |
) and key in config.keys() and args_dict[key] is None) else args_dict[key] |
|
|
383 |
if args.target_folders is not None: |
|
|
384 |
config["dataset"]["target_folders"] = args.target_folders |
|
|
385 |
config["dataset"]["percentage"] = args.percentage if args.percentage is not None else config["dataset"]["percentage"] |
|
|
386 |
config["dataset"]["filter_cinc"] = args.filter_cinc if args.filter_cinc is not None else config["dataset"]["filter_cinc"] |
|
|
387 |
config["model"]["base_model"] = args.base_model if args.base_model is not None else config["model"]["base_model"] |
|
|
388 |
config["model"]["widen"] = args.widen if args.widen is not None else config["model"]["widen"] |
|
|
389 |
if args.out_dim is not None: |
|
|
390 |
config["model"]["out_dim"] = args.out_dim |
|
|
391 |
init_logger(config) |
|
|
392 |
dataset = SimCLRDataSetWrapper( |
|
|
393 |
config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params) |
|
|
394 |
for i, t in enumerate(dataset.transformations): |
|
|
395 |
logger.info(str(i) + ". Transformation: " + |
|
|
396 |
str(t) + ": " + str(t.get_params())) |
|
|
397 |
date = time.asctime() |
|
|
398 |
label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19, |
|
|
399 |
"label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5} |
|
|
400 |
ptb_num_classes = label_to_num_classes[config["eval_dataset"] |
|
|
401 |
["ptb_xl_label"]] |
|
|
402 |
abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN", |
|
|
403 |
"TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"} |
|
|
404 |
trs = re.sub(r"[,'\]\[]", "", str([abr[str(tr)] if abr[str(tr)] not in [ |
|
|
405 |
"TT", "Tr"] else '' for tr in dataset.transformations])) |
|
|
406 |
name = str(date) + "_" + method + "_" + str( |
|
|
407 |
time.time_ns())[-3:] + "_" + trs[1:] |
|
|
408 |
tb_logger = TensorBoardLogger(args.log_dir, name=name, version='') |
|
|
409 |
config["log_dir"] = os.path.join(args.log_dir, name) |
|
|
410 |
print(config) |
|
|
411 |
return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger |
|
|
412 |
|
|
|
413 |
def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks): |
|
|
414 |
scores = {} |
|
|
415 |
for ca in callbacks: |
|
|
416 |
if isinstance(ca, SSLOnlineEvaluator): |
|
|
417 |
scores[str(ca)] = {"macro": ca.best_macro} |
|
|
418 |
|
|
|
419 |
results = {"config": config, "trafos": args.trafos, "scores": scores} |
|
|
420 |
|
|
|
421 |
with open(os.path.join(config["log_dir"], "results.pkl"), 'wb') as handle: |
|
|
422 |
pickle.dump(results, handle) |
|
|
423 |
|
|
|
424 |
trainer.save_checkpoint(os.path.join(config["log_dir"], "checkpoints", "model.ckpt")) |
|
|
425 |
with open(os.path.join(config["log_dir"], "config.txt"), "w") as text_file: |
|
|
426 |
print(config, file=text_file) |
|
|
427 |
|
|
|
428 |
def cli_main(): |
|
|
429 |
from pytorch_lightning import Trainer |
|
|
430 |
from online_evaluator import SSLOnlineEvaluator |
|
|
431 |
from ecg_datamodule import ECGDataModule |
|
|
432 |
from clinical_ts.create_logger import create_logger |
|
|
433 |
from os.path import exists |
|
|
434 |
|
|
|
435 |
parser = ArgumentParser() |
|
|
436 |
parser = parse_args(parser) |
|
|
437 |
logger.info("parse arguments") |
|
|
438 |
args = parser.parse_args() |
|
|
439 |
config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args) |
|
|
440 |
|
|
|
441 |
# data |
|
|
442 |
ecg_datamodule = ECGDataModule(config, transformations, t_params) |
|
|
443 |
|
|
|
444 |
callbacks = [] |
|
|
445 |
if args.run_callbacks: |
|
|
446 |
# callback for online linear evaluation/fine-tuning |
|
|
447 |
linear_evaluator = SSLOnlineEvaluator(drop_p=0, |
|
|
448 |
z_dim=512, num_classes=ptb_num_classes, hidden_dim=None, lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"], mode="linear_evaluation", verbose=False) |
|
|
449 |
|
|
|
450 |
fine_tuner = SSLOnlineEvaluator(drop_p=0, |
|
|
451 |
z_dim=512, num_classes=ptb_num_classes, hidden_dim=None, lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"], mode="fine_tuning", verbose=False) |
|
|
452 |
|
|
|
453 |
callbacks.append(linear_evaluator) |
|
|
454 |
callbacks.append(fine_tuner) |
|
|
455 |
|
|
|
456 |
# configure trainer |
|
|
457 |
trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus, |
|
|
458 |
distributed_backend=args.distributed_backend, auto_lr_find=False, num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks) |
|
|
459 |
|
|
|
460 |
# pytorch lightning module |
|
|
461 |
pl_model = CustomBYOL(5, learning_rate=config["lr"], weight_decay=eval(config["weight_decay"]), |
|
|
462 |
warm_up_epochs=config["warm_up"], max_epochs=config[ |
|
|
463 |
"epochs"], num_workers=config["dataset"]["num_workers"], |
|
|
464 |
batch_size=config["batch_size"], config=config, transformations=ecg_datamodule.transformations) |
|
|
465 |
|
|
|
466 |
|
|
|
467 |
# load checkpoint |
|
|
468 |
if args.checkpoint_path != "": |
|
|
469 |
if exists(args.checkpoint_path): |
|
|
470 |
logger.info("Retrieve checkpoint from " + args.checkpoint_path) |
|
|
471 |
pl_model.load_from_checkpoint(args.checkpoint_path) |
|
|
472 |
else: |
|
|
473 |
raise("checkpoint does not exist") |
|
|
474 |
|
|
|
475 |
# start training |
|
|
476 |
trainer.fit(pl_model, ecg_datamodule) |
|
|
477 |
|
|
|
478 |
aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks) |
|
|
479 |
|
|
|
480 |
if __name__ == "__main__": |
|
|
481 |
cli_main() |