|
a |
|
b/custom_simclr_bolts.py |
|
|
1 |
import pytorch_lightning as pl |
|
|
2 |
# from pl_bolts.models.self_supervised import SimCLR |
|
|
3 |
from pl_bolts.optimizers.lars_scheduling import LARSWrapper |
|
|
4 |
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR |
|
|
5 |
from torch.optim import Adam |
|
|
6 |
import torch |
|
|
7 |
import re |
|
|
8 |
import pdb |
|
|
9 |
|
|
|
10 |
import math |
|
|
11 |
from argparse import ArgumentParser |
|
|
12 |
from typing import Callable, Optional |
|
|
13 |
|
|
|
14 |
import numpy as np |
|
|
15 |
import torch |
|
|
16 |
import torch.distributed as dist |
|
|
17 |
import torch.nn.functional as F |
|
|
18 |
from pytorch_lightning.utilities import AMPType |
|
|
19 |
from torch import nn |
|
|
20 |
from torch.optim.optimizer import Optimizer |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
from models.resnet_simclr import ResNetSimCLR |
|
|
24 |
import re |
|
|
25 |
|
|
|
26 |
import time |
|
|
27 |
import pickle |
|
|
28 |
import yaml |
|
|
29 |
import logging |
|
|
30 |
import os |
|
|
31 |
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper |
|
|
32 |
from clinical_ts.create_logger import create_logger |
|
|
33 |
import pickle |
|
|
34 |
from pytorch_lightning import Trainer, seed_everything |
|
|
35 |
|
|
|
36 |
from torch import nn |
|
|
37 |
from torch.nn import functional as F |
|
|
38 |
from online_evaluator import SSLOnlineEvaluator |
|
|
39 |
from ecg_datamodule import ECGDataModule |
|
|
40 |
from pytorch_lightning.loggers import TensorBoardLogger |
|
|
41 |
from pl_bolts.models.self_supervised.evaluator import Flatten |
|
|
42 |
import pdb |
|
|
43 |
# Identifier of the pre-training method; pretrain_routine embeds it in the run name.
method="simclr"
# Module-level logger obtained from the project's create_logger helper.
logger = create_logger(__name__)
|
|
45 |
def _accuracy(zis, zjs, batch_size): |
|
|
46 |
with torch.no_grad(): |
|
|
47 |
representations = torch.cat([zjs, zis], dim=0) |
|
|
48 |
similarity_matrix = torch.mm( |
|
|
49 |
representations, representations.t().contiguous()) |
|
|
50 |
corrected_similarity_matrix = similarity_matrix - \ |
|
|
51 |
torch.eye(2*batch_size).type_as(similarity_matrix) |
|
|
52 |
pred_similarities, pred_indices = torch.max( |
|
|
53 |
corrected_similarity_matrix[:batch_size], dim=1) |
|
|
54 |
correct_indices = torch.arange(batch_size)+batch_size |
|
|
55 |
correct_preds = ( |
|
|
56 |
pred_indices == correct_indices.type_as(pred_indices)).sum() |
|
|
57 |
return correct_preds.float()/batch_size |
|
|
58 |
|
|
|
59 |
def mean(res, key1, key2=None):
    """Average a tensor-valued entry over a list of step outputs.

    Args:
        res: iterable of per-step outputs, normally dicts.
        key1: key looked up in every output.
        key2: optional nested key. When given, ``x[key1][key2]`` is averaged
            over every element of ``res`` without any filtering.

    Returns:
        Scalar tensor holding the mean of the stacked values.

    Note:
        ``torch.stack`` raises if no value is collected (e.g. ``res`` is empty).
    """
    if key2 is not None:
        return torch.stack([x[key1][key2] for x in res]).mean()
    # Skip outputs that are not dicts or that lack the key (validation_step
    # returns an empty dict for secondary dataloaders). isinstance() also
    # accepts dict subclasses such as OrderedDict, which an exact type
    # comparison would silently drop from the average.
    values = [x[key1] for x in res if isinstance(x, dict) and key1 in x]
    return torch.stack(values).mean()
|
|
63 |
|
|
|
64 |
class Projection(nn.Module):
    """Two-layer MLP head producing L2-normalised embeddings.

    The input is passed through ``Flatten``, then ``Linear -> ReLU -> Linear``,
    and the result is normalised along dimension 1.

    Args:
        input_dim: number of input features of the first linear layer.
        hidden_dim: width of the hidden layer.
        output_dim: size of the produced embedding.
    """

    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        layers = [
            Flatten(),
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim, bias=True),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` and return the embedding normalised along dim 1."""
        projected = self.model(x)
        return F.normalize(projected, dim=1)
|
|
81 |
|
|
|
82 |
|
|
|
83 |
class SyncFunction(torch.autograd.Function):
    """``all_gather`` wrapped as an autograd function so gradients flow back."""

    @staticmethod
    def forward(ctx, tensor):
        """Gather ``tensor`` from every rank and concatenate along dim 0."""
        # Remember the local batch size so backward can slice out this rank's part.
        ctx.batch_size = tensor.shape[0]

        world_size = torch.distributed.get_world_size()
        buffers = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(buffers, tensor)

        return torch.cat(buffers, 0)

    @staticmethod
    def backward(ctx, grad_output):
        """Sum the gradient over all ranks and return the slice of this rank."""
        summed = grad_output.clone()
        torch.distributed.all_reduce(summed, op=torch.distributed.ReduceOp.SUM, async_op=False)

        rank = torch.distributed.get_rank()
        start = rank * ctx.batch_size
        stop = (rank + 1) * ctx.batch_size
        return summed[start:stop]
|
|
103 |
|
|
|
104 |
|
|
|
105 |
class CustomSimCLR(pl.LightningModule):
    """SimCLR pre-training module with an externally attached encoder and head.

    The constructor builds no network. ``self.encoder`` and
    ``self.projection`` must be assigned by the caller before training (as
    ``cli_main`` does); ``shared_forward`` and the device helpers read them.
    """

    def __init__(self,
                 batch_size,
                 num_samples,
                 warmup_epochs=10,
                 lr=1e-4,
                 opt_weight_decay=1e-6,
                 loss_temperature=0.5,
                 config=None,
                 transformations=None,
                 **kwargs):
        """
        Args:
            batch_size: the batch size
            num_samples: num samples in the dataset
            warmup_epochs: epochs to warmup the lr for
            lr: the optimizer learning rate
            opt_weight_decay: the optimizer weight decay
            loss_temperature: the loss temperature
            config: configuration object; only rendered to text in ``on_train_start``
            transformations: iterable of augmentation objects exposing
                ``get_params()``; only rendered to text in ``on_train_start``
            **kwargs: accepted for call-site compatibility, not read by this class
        """
        super(CustomSimCLR, self).__init__()
        self.config = config
        self.transformations = transformations
        self.epoch = 0
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.save_hyperparameters()

    def configure_optimizers(self):
        """Create Adam plus a per-step linear-warmup / cosine-annealing schedule.

        Returns:
            ``([optimizer], [scheduler_dict])`` where the scheduler dict asks
            for stepping at every training step.
        """
        global_batch_size = self.trainer.world_size * self.hparams.batch_size
        self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size

        # Biases and batch-norm parameters are kept out of weight decay.
        parameters = self.exclude_from_wt_decay(
            self.named_parameters(),
            weight_decay=self.hparams.opt_weight_decay
        )
        optimizer = Adam(parameters, lr=self.hparams.lr)

        # The scheduler advances once per optimizer step, so both horizons are
        # converted from epochs to steps. Locals are used on purpose: writing
        # the step count back into hparams.warmup_epochs would compound every
        # time this method runs and would store a step count under an
        # "epochs" name.
        warmup_steps = self.hparams.warmup_epochs * self.train_iters_per_epoch
        total_steps = self.trainer.max_epochs * self.train_iters_per_epoch

        linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=warmup_steps,
            max_epochs=total_steps,
            warmup_start_lr=0,
            eta_min=0
        )

        scheduler = {
            'scheduler': linear_warmup_cosine_decay,
            'interval': 'step',
            'frequency': 1
        }

        return [optimizer], [scheduler]

    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=('bias', 'bn')):
        """Split trainable parameters into decayed and non-decayed groups.

        Args:
            named_params: iterable of ``(name, parameter)`` pairs.
            weight_decay: decay applied to the regular group.
            skip_list: substrings; a parameter whose name contains any of
                them is placed in the zero-decay group.

        Returns:
            Two optimizer parameter groups: regular parameters with
            ``weight_decay`` and excluded parameters with decay ``0.``.
        """
        params = []
        excluded_params = []

        for name, param in named_params:
            if not param.requires_grad:
                continue
            if any(layer_name in name for layer_name in skip_list):
                excluded_params.append(param)
            else:
                params.append(param)

        return [
            {'params': params, 'weight_decay': weight_decay},
            {'params': excluded_params, 'weight_decay': 0.}
        ]

    def shared_forward(self, batch, batch_idx):
        """Encode and project both augmented views of a batch.

        Args:
            batch: ``((x1, y1), (x2, y2))``; the labels are ignored.
            batch_idx: unused.

        Returns:
            Tuple ``(z1, z2)`` with the projected embeddings of both views.
        """
        (x1, y1), (x2, y2) = batch
        x1 = self.to_device(x1)
        x2 = self.to_device(x2)

        h1 = self.encoder(x1)[0]
        h2 = self.encoder(x2)[0]

        # some encoders return a list of feature maps; use the last one
        if isinstance(h1, list):
            h1 = h1[-1]
            h2 = h2[-1]

        # Flatten everything after the batch axis. A bare squeeze() would
        # also remove the batch axis when the batch holds a single sample,
        # which then breaks the projection head.
        z1 = self.projection(h1.flatten(start_dim=1))
        z2 = self.projection(h2.flatten(start_dim=1))

        return z1, z2

    def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
        """Compute the NT-Xent contrastive loss.

        Assumes ``out_1`` and ``out_2`` are normalized.

        Args:
            out_1: [batch_size, dim]
            out_2: [batch_size, dim]
            temperature: softmax temperature dividing the similarities.
            eps: lower bound / additive term used for numerical stability.

        Returns:
            Scalar loss tensor.
        """
        # gather representations in case of distributed training
        # out_1_dist: [batch_size * world_size, dim]
        # out_2_dist: [batch_size * world_size, dim]
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            out_1_dist = SyncFunction.apply(out_1)
            out_2_dist = SyncFunction.apply(out_2)
        else:
            out_1_dist = out_1
            out_2_dist = out_2

        # out: [2 * batch_size, dim]
        # out_dist: [2 * batch_size * world_size, dim]
        out = torch.cat([out_1, out_2], dim=0)
        out_dist = torch.cat([out_1_dist, out_2_dist], dim=0)

        # cov and sim: [2 * batch_size, 2 * batch_size * world_size]
        # neg: [2 * batch_size]
        cov = torch.mm(out, out_dist.t().contiguous())
        sim = torch.exp(cov / temperature)
        neg = sim.sum(dim=-1)

        # Remove each row's similarity with itself. For normalized inputs the
        # self dot product is 1, so its contribution to the row sum is
        # exp(1 / temperature) -- not e, which is only right for temperature 1.
        row_sub = torch.full_like(neg, math.exp(1.0 / temperature))
        neg = torch.clamp(neg - row_sub, min=eps)  # clamp for numerical stability

        # Positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        pos = torch.cat([pos, pos], dim=0)

        loss = -torch.log(pos / (neg + eps)).mean()

        return loss

    def training_step(self, batch, batch_idx):
        """Return the NT-Xent loss of one training batch."""
        z1, z2 = self.shared_forward(batch, batch_idx)
        # Only the loss is returned; the contrastive accuracy and the metrics
        # dict that used to be built here were never consumed, so they are no
        # longer computed.
        return self.nt_xent_loss(z1, z2, self.hparams.loss_temperature)

    def validation_step(self, batch, batch_idx, dataloader_idx):
        """Compute loss and contrastive accuracy for the first val dataloader.

        Batches from any other dataloader yield an empty dict.
        """
        if dataloader_idx != 0:
            return {}
        z1, z2 = self.shared_forward(batch, batch_idx)
        loss = self.nt_xent_loss(z1, z2, self.hparams.loss_temperature)

        # _accuracy already returns a tensor computed under no_grad, so it is
        # stored as is instead of being re-wrapped with torch.tensor().
        acc = _accuracy(z1, z2, z1.shape[0])
        results = {
            'val_loss': loss,
            'val_acc': acc
        }
        return results

    def validation_epoch_end(self, outputs):
        """Average validation loss and accuracy over the first dataloader."""
        # outputs[0] because we are using multiple datasets!
        val_loss = mean(outputs[0], 'val_loss')
        val_acc = mean(outputs[0], 'val_acc')

        log = {
            'val/val_loss': val_loss,
            'val/val_acc': val_acc
        }
        return {'val_loss': val_loss, 'log': log, 'progress_bar': log}

    def on_train_start(self):
        """Write configuration and transformations as text to the experiment logger."""
        # log configuration
        config_str = re.sub(r"[,\}\{]", "<br/>", str(self.config))
        config_str = re.sub(r"[\[\]\']", "", config_str)
        # transformations defaults to None in __init__; treat that as "none".
        transformations = self.transformations or []
        transformation_str = re.sub(r"[\}]", "<br/>", str(["<br>" + str(
            t) + ":<br/>" + str(t.get_params()) for t in transformations]))
        transformation_str = re.sub(r"[,\"\{\'\[\]]", "", transformation_str)
        self.logger.experiment.add_text(
            "configuration", str(config_str), global_step=0)
        self.logger.experiment.add_text("transformations", str(
            transformation_str), global_step=0)
        self.epoch = 0

    def on_epoch_end(self):
        """Advance the module's own epoch counter."""
        self.epoch += 1

    def type(self, dst_type=None):
        """Return the tensor type string of the encoder's first weight.

        Called without an argument this keeps the behaviour existing callers
        rely on. When ``dst_type`` is given the call is forwarded to the
        parent class, so the usual ``module.type(dst_type)`` cast still works
        although this method shadows it.
        """
        if dst_type is not None:
            return super().type(dst_type)
        return self.encoder.features[0][0].weight.type()

    def get_representations(self, x):
        """Return the first element of the encoder output for ``x``."""
        return self.encoder(x)[0]

    def get_model(self):
        """Return the attached encoder."""
        return self.encoder

    def get_device(self):
        """Return the device holding the encoder's first weight."""
        return self.encoder.features[0][0].weight.device

    def to_device(self, x):
        """Cast ``x`` to the encoder's tensor type and move it to its device."""
        return x.type(self.type()).to(self.get_device())
|
|
320 |
|
|
|
321 |
|
|
|
322 |
def parse_args(parent_parser):
    """Build the command-line parser for SimCLR pre-training.

    Args:
        parent_parser: parser whose arguments are inherited via ``parents``.

    Returns:
        A new ``ArgumentParser`` (created with ``add_help=False``) holding the
        augmentation, hardware and training options defined below.
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
                        default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])
    # GaussianNoise
    parser.add_argument(
        '--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float)
    # RandomResizedCrop
    # nargs='+' keeps a value given on the command line a list of floats,
    # matching the list default and the other *_range options; without it a
    # CLI value would be parsed as a single float.
    parser.add_argument('--rr_crop_ratio_range', nargs='+',
                        help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
    parser.add_argument(
        '--output_size', help='output size for random resized crop transformation', default=250, type=int)
    # DynamicTimeWarp
    parser.add_argument(
        '--warps', help='number of warps for dynamic time warp transformation', default=3, type=int)
    parser.add_argument(
        '--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int)
    # TimeWarp
    parser.add_argument(
        '--epsilon', help='epsilon param for time warp', default=10, type=float)
    # ChannelResize
    parser.add_argument('--magnitude_range', nargs='+',
                        help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)
    # Downsample
    parser.add_argument(
        '--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float)
    # TimeOut
    parser.add_argument('--to_crop_ratio_range', nargs='+',
                        help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)
    # resume training
    parser.add_argument('--resume', action='store_true')
    parser.add_argument(
        '--gpus', help='number of gpus to use; use cpu if gpu=0', type=int, default=1)
    parser.add_argument(
        '--num_nodes', default=1, help='number of cluster nodes', type=int)
    parser.add_argument(
        '--distributed_backend', help='sets backend type')
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--epochs', type=int)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--warm_up', default=1, type=int, help="number of warm up epochs")
    parser.add_argument('--precision', type=int)
    parser.add_argument('--datasets', dest="target_folders",
                        nargs='+', help='used datasets for pretraining')
    parser.add_argument('--log_dir', default="./experiment_logs")
    parser.add_argument(
        '--percentage', help='determines how much of the dataset shall be used during the pretraining', type=float, default=1.0)
    parser.add_argument('--lr', type=float, help="learning rate")
    parser.add_argument('--out_dim', type=int, help="output dimension of model")
    parser.add_argument('--filter_cinc', default=False, action="store_true", help="only valid if cinc is selected: filter out the ptb data")
    parser.add_argument('--base_model')
    parser.add_argument('--widen', type=int, help="use wide xresnet1d50")
    parser.add_argument('--run_callbacks', default=False, action="store_true", help="run callbacks which asses linear evaluaton and finetuning metrics during pretraining")
    parser.add_argument('--checkpoint_path', default="")
    return parser
|
|
377 |
|
|
|
378 |
def init_logger(config):
    """Route root logging to ``<log_dir>/info.log`` and return this module's logger.

    Args:
        config: mapping providing ``'debug'`` (truthy selects DEBUG level,
            otherwise INFO) and ``'log_dir'`` (directory of the log file).

    Returns:
        ``logging.getLogger(__name__)``.
    """
    level = logging.DEBUG if config['debug'] else logging.INFO

    # remove all handlers to change basic configuration
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    # makedirs also creates missing parent directories and does not fail when
    # the directory already exists, unlike a guarded os.mkdir.
    os.makedirs(config['log_dir'], exist_ok=True)
    logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level,
                        format='%(asctime)s %(name)s:%(lineno)s %(levelname)s: %(message)s ')
    return logging.getLogger(__name__)
|
|
392 |
|
|
|
393 |
def pretrain_routine(args):
    """Prepare configuration, dataset wrapper and logger for pre-training.

    Loads the YAML configuration, overlays command-line values, initialises
    file logging, builds the dataset wrapper and derives the run name used
    for the TensorBoard logger.

    Args:
        args: parsed command-line namespace as produced by ``parse_args``.

    Returns:
        Tuple ``(config, dataset, date, transformations, t_params,
        ptb_num_classes, tb_logger)``.
    """
    t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius,
                "epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range,
                "bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1}
    transformations = args.trafos

    # When resuming, prefer the configuration stored next to the checkpoints.
    checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml")
    config_file = checkpoint_config if args.resume and os.path.isfile(
        checkpoint_config) else "bolts_config.yaml"
    # Context manager so the file handle is closed once the YAML is parsed.
    with open(config_file, "r") as config_handle:
        config = yaml.load(config_handle, Loader=yaml.FullLoader)

    # Overlay command-line values on the file configuration: a CLI value wins
    # unless it is None while the file already defines that key.
    args_dict = vars(args)
    for key, cli_value in args_dict.items():
        if cli_value is None and key in config:
            continue
        config[key] = cli_value

    if args.target_folders is not None:
        config["dataset"]["target_folders"] = args.target_folders
    # NOTE: --percentage and --filter_cinc have non-None argparse defaults,
    # so these two checks always pass and the CLI value (or its default)
    # replaces the value from the file.
    if args.percentage is not None:
        config["dataset"]["percentage"] = args.percentage
    if args.filter_cinc is not None:
        config["dataset"]["filter_cinc"] = args.filter_cinc
    if args.base_model is not None:
        config["model"]["base_model"] = args.base_model
    if args.widen is not None:
        config["model"]["widen"] = args.widen
    if args.out_dim is not None:
        config["model"]["out_dim"] = args.out_dim

    init_logger(config)
    dataset = SimCLRDataSetWrapper(
        config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params)
    for i, t in enumerate(dataset.transformations):
        logger.info(str(i) + ". Transformation: " +
                    str(t) + ": " + str(t.get_params()))

    date = time.asctime()
    label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
                            "label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
    ptb_num_classes = label_to_num_classes[config["eval_dataset"]
                                           ["ptb_xl_label"]]

    # Abbreviations of the transformations used to build the run name;
    # "TT" and "Tr" entries are left out of the name.
    abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN",
           "TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"}
    trs = re.sub(r"[,'\]\[]", "", str([abr[str(tr)] if abr[str(tr)] not in [
        "TT", "Tr"] else '' for tr in dataset.transformations]))
    name = str(date) + "_" + method + "_" + str(
        time.time_ns())[-3:] + "_" + trs[1:]
    tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
    config["log_dir"] = os.path.join(args.log_dir, name)
    print(config)
    return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger
|
|
435 |
|
|
|
436 |
def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
    """Persist evaluation scores, a checkpoint and the configuration.

    Writes ``results.pkl``, ``checkpoints/model.ckpt`` and ``config.txt``
    below ``config["log_dir"]``.

    Args:
        config: run configuration; must contain ``"log_dir"``.
        args: parsed command-line namespace; ``args.trafos`` is stored.
        trainer: trainer whose ``save_checkpoint`` writes the checkpoint.
        pl_model: unused.
        datamodule: unused.
        callbacks: callbacks of the run; the ``best_macro`` score of every
            ``SSLOnlineEvaluator`` among them is recorded.
    """
    scores = {
        str(callback): {"macro": callback.best_macro}
        for callback in callbacks
        if isinstance(callback, SSLOnlineEvaluator)
    }
    summary = {"config": config, "trafos": args.trafos, "scores": scores}

    results_path = os.path.join(config["log_dir"], "results.pkl")
    with open(results_path, 'wb') as handle:
        pickle.dump(summary, handle)

    checkpoint_path = os.path.join(config["log_dir"], "checkpoints", "model.ckpt")
    trainer.save_checkpoint(checkpoint_path)

    config_path = os.path.join(config["log_dir"], "config.txt")
    with open(config_path, "w") as text_file:
        print(config, file=text_file)
|
|
450 |
|
|
|
451 |
def cli_main():
    """Command-line entry point: parse arguments, build everything and train.

    Raises:
        FileNotFoundError: if ``--checkpoint_path`` is given but does not exist.
    """
    parser = ArgumentParser()
    parser = parse_args(parser)
    logger.info("parse arguments")
    args = parser.parse_args()
    config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)

    # data
    ecg_datamodule = ECGDataModule(config, transformations, t_params)

    callbacks = []
    if args.run_callbacks:
        # callbacks for online linear evaluation and fine-tuning; they differ
        # only in their mode
        for mode in ("linear_evaluation", "fine_tuning"):
            callbacks.append(SSLOnlineEvaluator(
                drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
                lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
                mode=mode, verbose=False))

    # configure trainer
    trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
                      distributed_backend=args.distributed_backend, auto_lr_find=False, num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks)

    # pytorch lightning module
    model = ResNetSimCLR(**config["model"])
    # The weight decay is passed as ``opt_weight_decay`` because that is the
    # hyper-parameter CustomSimCLR.configure_optimizers reads; under the name
    # ``weight_decay`` it only landed in **kwargs and the default was used.
    # NOTE(security): eval() executes the configured string as Python code;
    # keep the configuration file trusted.
    pl_model = CustomSimCLR(
        config["batch_size"], ecg_datamodule.num_samples, warmup_epochs=config["warm_up"], lr=config["lr"],
        out_dim=config["model"]["out_dim"], config=config,
        transformations=ecg_datamodule.transformations, loss_temperature=config["loss"]["temperature"],
        opt_weight_decay=eval(config["weight_decay"]))
    pl_model.encoder = model
    pl_model.projection = Projection(
        input_dim=model.l1.in_features, hidden_dim=512, output_dim=config["model"]["out_dim"])

    # load checkpoint
    if args.checkpoint_path != "":
        if not os.path.exists(args.checkpoint_path):
            raise FileNotFoundError(
                "checkpoint does not exist: " + args.checkpoint_path)
        logger.info("Retrieve checkpoint from " + args.checkpoint_path)
        # load_from_checkpoint is a classmethod that builds a new module and
        # leaves ``pl_model`` untouched, so the stored weights are restored
        # into the already assembled module instead.
        checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
        pl_model.load_state_dict(checkpoint["state_dict"])

    # start training
    trainer.fit(pl_model, ecg_datamodule)

    aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)
|
|
505 |
|
|
|
506 |
# Start pre-training only when this file is executed as a script.
if __name__ == "__main__":
    cli_main()