|
a |
|
b/custom_swav_bolts.py |
|
|
1 |
""" |
|
|
2 |
Adapted from official swav implementation: https://github.com/facebookresearch/swav |
|
|
3 |
""" |
|
|
4 |
import math |
|
|
5 |
import os |
|
|
6 |
import re |
|
|
7 |
from argparse import ArgumentParser |
|
|
8 |
from typing import Callable, Optional |
|
|
9 |
import pdb |
|
|
10 |
import numpy as np |
|
|
11 |
import pytorch_lightning as pl |
|
|
12 |
import torch |
|
|
13 |
import torch.distributed as dist |
|
|
14 |
from pytorch_lightning.utilities import AMPType |
|
|
15 |
from torch import nn |
|
|
16 |
from pytorch_lightning.core.optimizer import LightningOptimizer |
|
|
17 |
from torch.optim.optimizer import Optimizer |
|
|
18 |
|
|
|
19 |
import yaml |
|
|
20 |
import time |
|
|
21 |
import logging |
|
|
22 |
import pickle |
|
|
23 |
# from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 |
|
|
24 |
from pl_bolts.optimizers.lars_scheduling import LARSWrapper |
|
|
25 |
from pl_bolts.transforms.dataset_normalizations import ( |
|
|
26 |
cifar10_normalization, |
|
|
27 |
imagenet_normalization, |
|
|
28 |
stl10_normalization, |
|
|
29 |
) |
|
|
30 |
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper |
|
|
31 |
from clinical_ts.create_logger import create_logger |
|
|
32 |
from torchvision.models.resnet import Bottleneck, BasicBlock |
|
|
33 |
from online_evaluator import SSLOnlineEvaluator |
|
|
34 |
from ecg_datamodule import ECGDataModule |
|
|
35 |
from pytorch_lightning.loggers import TensorBoardLogger |
|
|
36 |
from models.resnet_simclr import ResNetSimCLR |
|
|
37 |
import torchvision.transforms as transforms |
|
|
38 |
|
|
|
39 |
_TORCHVISION_AVAILABLE = True |
|
|
40 |
|
|
|
41 |
# import cv2 |
|
|
42 |
from typing import List |
|
|
43 |
logger = create_logger(__name__) |
|
|
44 |
method = "swav" |
|
|
45 |
class SwAVTrainDataTransform(object):
    """Multi-crop augmentation pipeline for SwAV training.

    Calling an instance on a sample returns a list of augmented views: one
    view per configured crop (``sum(nmb_crops)`` in total), followed by one
    extra "online" view produced at the size of the first crop group.
    """

    def __init__(
        self,
        normalize=None,
        size_crops: List[int] = [96, 36],
        nmb_crops: List[int] = [2, 4],
        min_scale_crops: List[float] = [0.33, 0.10],
        max_scale_crops: List[float] = [1, 0.33],
        gaussian_blur: bool = True,
        jitter_strength: float = 1.
    ):
        """
        Args:
            normalize: optional transform applied after ``ToTensor``.
            size_crops: output size of each crop group.
            nmb_crops: how many views to produce for each crop group.
            min_scale_crops: lower bound of the ``RandomResizedCrop`` scale
                for each crop group.
            max_scale_crops: upper bound of the ``RandomResizedCrop`` scale
                for each crop group.
            gaussian_blur: append a blur step to the colour augmentations.
            jitter_strength: multiplier applied to the colour-jitter ranges.
        """
        # Every per-group list has to describe the same number of groups.
        assert len(size_crops) == len(nmb_crops)
        assert len(min_scale_crops) == len(nmb_crops)
        assert len(max_scale_crops) == len(nmb_crops)

        self.jitter_strength = jitter_strength
        self.gaussian_blur = gaussian_blur
        self.size_crops = size_crops
        self.nmb_crops = nmb_crops
        self.min_scale_crops = min_scale_crops
        self.max_scale_crops = max_scale_crops

        strength = self.jitter_strength
        self.color_jitter = transforms.ColorJitter(
            0.8 * strength,
            0.8 * strength,
            0.8 * strength,
            0.2 * strength
        )

        # Colour augmentations shared by every crop pipeline.
        color_steps = [
            transforms.RandomApply([self.color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2)
        ]
        if self.gaussian_blur:
            # Kernel is ~10% of the first crop size, bumped to the next odd value.
            blur_kernel = int(0.1 * self.size_crops[0])
            if blur_kernel % 2 == 0:
                blur_kernel += 1
            # NOTE(review): GaussianBlur is not defined or imported in the
            # visible part of this file -- confirm it is provided elsewhere.
            color_steps.append(GaussianBlur(kernel_size=blur_kernel, p=0.5))
        self.color_transform = transforms.Compose(color_steps)

        to_tensor = transforms.ToTensor()
        if normalize is None:
            self.final_transform = to_tensor
        else:
            self.final_transform = transforms.Compose([to_tensor, normalize])

        # One Compose per crop group; the same object is repeated nmb_crops times.
        views = []
        for group in range(len(self.size_crops)):
            group_pipeline = transforms.Compose([
                transforms.RandomResizedCrop(
                    self.size_crops[group],
                    scale=(self.min_scale_crops[group],
                           self.max_scale_crops[group]),
                ),
                transforms.RandomHorizontalFlip(p=0.5),
                self.color_transform,
                self.final_transform,
            ])
            views.extend([group_pipeline] * self.nmb_crops[group])

        # Extra view for the online evaluator, at the size of the first group.
        views.append(transforms.Compose([
            transforms.RandomResizedCrop(self.size_crops[0]),
            transforms.RandomHorizontalFlip(),
            self.final_transform
        ]))
        self.transform = views

    def __call__(self, sample):
        """Apply every configured pipeline to ``sample`` and return the views as a list."""
        return [view(sample) for view in self.transform]
|
|
127 |
|
|
|
128 |
|
|
|
129 |
class SwAVEvalDataTransform(SwAVTrainDataTransform):
    """Evaluation variant of the SwAV multi-crop transform.

    Identical to the training transform except that the last entry of
    ``self.transform`` (the online view) is replaced by a deterministic
    resize + centre-crop pipeline.
    """

    def __init__(
        self,
        normalize=None,
        size_crops: List[int] = [96, 36],
        nmb_crops: List[int] = [2, 4],
        min_scale_crops: List[float] = [0.33, 0.10],
        max_scale_crops: List[float] = [1, 0.33],
        gaussian_blur: bool = True,
        jitter_strength: float = 1.
    ):
        """Build the training pipelines, then swap in the evaluation view.

        All arguments are forwarded unchanged to ``SwAVTrainDataTransform``.
        """
        parent_kwargs = dict(
            normalize=normalize,
            size_crops=size_crops,
            nmb_crops=nmb_crops,
            min_scale_crops=min_scale_crops,
            max_scale_crops=max_scale_crops,
            gaussian_blur=gaussian_blur,
            jitter_strength=jitter_strength,
        )
        super().__init__(**parent_kwargs)

        # The evaluation view uses the size of the first crop group:
        # resize to 110% of it, then centre-crop back down.
        global_size = self.size_crops[0]
        enlarged = int(global_size + 0.1 * global_size)
        eval_view = transforms.Compose([
            transforms.Resize(enlarged),
            transforms.CenterCrop(global_size),
            self.final_transform,
        ])

        # Overwrite the online *train* view appended by the parent class.
        self.transform[-1] = eval_view
|
|
159 |
|
|
|
160 |
|
|
|
161 |
class SwAVFinetuneTransform(object):
    """Single-view transform used for fine-tuning.

    In training mode the sample is randomly cropped, flipped and colour
    augmented; in evaluation mode it is resized and centre-cropped. Both
    modes end with ``ToTensor`` and the optional ``normalize`` transform.
    """

    def __init__(
        self,
        input_height: int = 224,
        jitter_strength: float = 1.,
        normalize=None,
        eval_transform: bool = False
    ) -> None:
        """
        Args:
            input_height: target size of the produced view.
            jitter_strength: multiplier applied to the colour-jitter ranges.
            normalize: optional transform applied after ``ToTensor``.
            eval_transform: build the deterministic evaluation pipeline
                instead of the random training pipeline.
        """
        self.jitter_strength = jitter_strength
        self.input_height = input_height
        self.normalize = normalize

        strength = self.jitter_strength
        self.color_jitter = transforms.ColorJitter(
            0.8 * strength,
            0.8 * strength,
            0.8 * strength,
            0.2 * strength
        )

        size = self.input_height
        if eval_transform:
            # Deterministic: resize to 110% of the target, then centre-crop.
            steps = [
                transforms.Resize(int(size + 0.1 * size)),
                transforms.CenterCrop(size)
            ]
        else:
            steps = [
                transforms.RandomResizedCrop(size=size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([self.color_jitter], p=0.8),
                transforms.RandomGrayscale(p=0.2)
            ]

        to_tensor = transforms.ToTensor()
        if normalize is None:
            steps.append(to_tensor)
        else:
            steps.append(transforms.Compose([to_tensor, normalize]))

        self.transform = transforms.Compose(steps)

    def __call__(self, sample):
        """Run the composed pipeline on ``sample``."""
        return self.transform(sample)
|
|
206 |
|
|
|
207 |
|
|
|
208 |
class CustomResNet(nn.Module):
    """Backbone wrapper adding a SwAV projection head and prototype layer.

    The wrapped ``model`` must expose a ``features`` module whose first
    element is itself indexable and starts with a layer that has a ``weight``
    (``features[0][0].weight`` is used to infer dtype and device).
    """

    def __init__(
        self,
        model,
        zero_init_residual=False,
        output_dim=16,
        hidden_mlp=512,
        nmb_prototypes=8,
        eval_mode=False,
        first_conv=True,
        maxpool1=True,
        l2norm=True
    ):
        """
        Args:
            model: backbone exposing a ``features`` attribute.
            zero_init_residual: zero the last BN weight of every
                ``Bottleneck`` / ``BasicBlock`` found among the sub-modules.
            output_dim: output width of the projection head.
            hidden_mlp: hidden width of the projection head.
            nmb_prototypes: number of prototypes; a list builds one prototype
                head per entry, a value <= 0 disables the prototype layer.
            eval_mode: accepted for interface compatibility; not used here.
            first_conv: accepted for interface compatibility; not used here.
            maxpool1: accepted for interface compatibility; not used here.
            l2norm: L2-normalise the projected embedding.
        """
        super(CustomResNet, self).__init__()
        self.l2norm = l2norm
        self.model = model
        self.features = self.model.features

        # NOTE(review): the head input width is hard-coded to 512, so the
        # backbone is assumed to emit 512 features per sample -- confirm.
        self.projection_head = nn.Sequential(
            nn.Linear(512, hidden_mlp),
            nn.BatchNorm1d(hidden_mlp),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_mlp, output_dim),
        )

        # prototype layer
        self.prototypes = None
        if isinstance(nmb_prototypes, list):
            self.prototypes = MultiPrototypes(output_dim, nmb_prototypes)
        elif nmb_prototypes > 0:
            self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def forward_backbone(self, x):
        """Run the backbone and drop singleton non-batch dimensions.

        The input is first cast to the tensor type of the first backbone
        weight. Unlike a bare ``squeeze()``, the batch dimension is always
        kept, so a batch of one sample still yields a ``(1, ...)`` tensor.
        """
        x = x.type(self.features[0][0].weight.type())
        h = self.features(x)
        # Walk from the last dim down to dim 1; squeeze(dim) is a no-op when
        # that dim is not of size 1, and dim 0 (batch) is never touched.
        for dim in range(h.dim() - 1, 0, -1):
            h = h.squeeze(dim)
        return h

    def forward_head(self, x):
        """Project backbone features.

        Returns:
            ``(embedding, prototype_scores)`` when a prototype layer exists,
            otherwise just the embedding.
        """
        if self.projection_head is not None:
            x = self.projection_head(x)

        if self.l2norm:
            x = nn.functional.normalize(x, dim=1, p=2)

        if self.prototypes is not None:
            return x, self.prototypes(x)
        return x

    def forward(self, inputs):
        """Multi-resolution forward pass.

        ``inputs`` is a tensor or a list of tensors. Consecutive tensors that
        share the same last-dimension size are concatenated and pushed through
        the backbone together; all backbone outputs are then concatenated and
        passed through the head.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]

        # End index (exclusive) of each run of equally sized inputs.
        group_ends = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in inputs]),
            return_counts=True,
        )[1], 0).tolist()

        # Move data to wherever the backbone actually lives instead of the
        # process-wide default CUDA device.
        device = self.features[0][0].weight.device

        backbone_outputs = []
        start_idx = 0
        for end_idx in group_ends:
            group = torch.cat(inputs[start_idx:end_idx])
            group = group.to(device, non_blocking=True)
            backbone_outputs.append(self.forward_backbone(group))
            start_idx = end_idx

        if len(backbone_outputs) == 1:
            output = backbone_outputs[0]
        else:
            output = torch.cat(backbone_outputs)
        return self.forward_head(output)
|
|
296 |
|
|
|
297 |
|
|
|
298 |
class MultiPrototypes(nn.Module):
    """A bank of bias-free linear prototype heads applied to the same input.

    Head ``i`` is registered as the sub-module ``prototypes<i>`` and maps
    ``output_dim`` features to ``nmb_prototypes[i]`` scores.
    """

    def __init__(self, output_dim, nmb_prototypes):
        """
        Args:
            output_dim: input width shared by every head.
            nmb_prototypes: sequence with the output width of each head.
        """
        super(MultiPrototypes, self).__init__()
        self.nmb_heads = len(nmb_prototypes)
        for head_idx, head_width in enumerate(nmb_prototypes):
            head = nn.Linear(output_dim, head_width, bias=False)
            self.add_module("prototypes" + str(head_idx), head)

    def forward(self, x):
        """Return a list with the output of every head, in registration order."""
        return [
            getattr(self, "prototypes" + str(head_idx))(x)
            for head_idx in range(self.nmb_heads)
        ]
|
|
311 |
|
|
|
312 |
|
|
|
313 |
class CustomSwAV(pl.LightningModule):
    """SwAV pre-training module wrapping a user-supplied backbone.

    The backbone is wrapped in ``CustomResNet`` (projection head + prototype
    layer). Training minimises the swapped-prediction loss computed in
    ``shared_step``; the learning rate follows a linear warm-up followed by a
    cosine decay, applied manually in ``optimizer_step``.
    """

    def __init__(
        self,
        model,
        gpus: int,
        num_samples: int,
        batch_size: int,
        config=None,
        transformations=None,
        nodes: int = 1,
        arch: str = 'resnet50',
        hidden_mlp: int = 2048,
        feat_dim: int = 128,
        warmup_epochs: int = 10,
        max_epochs: int = 100,
        nmb_prototypes: int = 3000,
        freeze_prototypes_epochs: int = 1,
        temperature: float = 0.1,
        sinkhorn_iterations: int = 3,
        # queue_length: int = 512,  # must be divisible by total batch-size
        queue_path: str = "queue",
        epoch_queue_starts: int = 15,
        crops_for_assign: list = [0, 1],
        nmb_crops: list = [2, 6],
        first_conv: bool = True,
        maxpool1: bool = True,
        optimizer: str = 'adam',
        lars_wrapper: bool = False,
        exclude_bn_bias: bool = False,
        start_lr: float = 0.,
        learning_rate: float = 1e-3,
        final_lr: float = 0.,
        weight_decay: float = 1e-6,
        epsilon: float = 0.05,
        **kwargs
    ):
        """
        Args:
            model: backbone passed to ``CustomResNet``; must expose ``features``
            gpus: number of gpus per node used in training, passed to SwAV module
                to manage the queue and select distributed sinkhorn
            num_samples: number of samples used for training
            batch_size: batch size per GPU in ddp
            config: mapping with run settings; ``config["epochs"]`` (when
                present) overrides ``max_epochs`` and ``config["log_dir"]`` is
                required by ``setup``
            transformations: stored on the module; not used by this class
            nodes: number of nodes to train on
            arch: encoder architecture name (stored only)
            hidden_mlp: hidden layer width of the projection head
            feat_dim: output dim of the projection head
            warmup_epochs: apply linear warmup for this many epochs
            max_epochs: epoch count for pre-training, used when ``config``
                does not provide ``"epochs"``
            nmb_prototypes: count of prototype vectors
            freeze_prototypes_epochs: epoch till which gradients of prototype layer
                are dropped
            temperature: loss temperature
            sinkhorn_iterations: iterations for sinkhorn normalization
            queue_path: folder within the log directory where the queue is stored
            epoch_queue_starts: start using the queue from this epoch
            crops_for_assign: list of crop ids for computing assignment
            nmb_crops: number of global and local crops, ex: [2, 6]
            first_conv: forwarded to ``CustomResNet``
            maxpool1: forwarded to ``CustomResNet``
            optimizer: optimizer to use, ``'adam'`` or ``'sgd'``
            lars_wrapper: use LARS wrapper over the optimizer
            exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers
            start_lr: starting lr for linear warmup
            learning_rate: learning rate
            final_lr: final learning rate of the cosine schedule
            weight_decay: weight decay for optimizer
            epsilon: epsilon val for swav assignments
            **kwargs: ignored. In particular a ``queue_length`` passed here has
                no effect: the queue length is fixed to ``8 * batch_size``.
        """
        super().__init__()
        # self.save_hyperparameters()

        self.epoch = 0
        self.config = config
        self.transformations = transformations
        self.gpus = gpus
        self.nodes = nodes
        self.arch = arch
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.queue_length = 8 * batch_size

        self.hidden_mlp = hidden_mlp
        self.feat_dim = feat_dim
        self.nmb_prototypes = nmb_prototypes
        self.freeze_prototypes_epochs = freeze_prototypes_epochs
        self.sinkhorn_iterations = sinkhorn_iterations

        # queue_path starts as a folder name and is replaced by the resolved
        # per-rank file path in setup(); the two private attributes let
        # setup() be called more than once without nesting paths.
        self.queue_path = queue_path
        self._queue_folder = queue_path
        self._queue_file = None
        self.epoch_queue_starts = epoch_queue_starts
        self.crops_for_assign = crops_for_assign
        self.nmb_crops = nmb_crops

        self.first_conv = first_conv
        self.maxpool1 = maxpool1

        self.optim = optimizer
        self.lars_wrapper = lars_wrapper
        self.exclude_bn_bias = exclude_bn_bias
        self.weight_decay = weight_decay
        self.epsilon = epsilon
        self.temperature = temperature

        self.start_lr = start_lr
        self.final_lr = final_lr
        self.learning_rate = learning_rate
        self.warmup_epochs = warmup_epochs
        # config["epochs"] wins when available; otherwise honour the
        # max_epochs argument instead of crashing on config=None.
        if config is not None and "epochs" in config:
            self.max_epochs = config["epochs"]
        else:
            self.max_epochs = max_epochs

        if self.gpus * self.nodes > 1:
            self.get_assignments = self.distributed_sinkhorn
        else:
            self.get_assignments = self.sinkhorn

        # compute iters per epoch
        global_batch_size = self.nodes * self.gpus * \
            self.batch_size if self.gpus > 0 else self.batch_size
        self.train_iters_per_epoch = (self.num_samples // global_batch_size) + 1

        # define LR schedule: linear warm-up followed by cosine decay
        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
        cosine_steps = self.train_iters_per_epoch * \
            (self.max_epochs - self.warmup_epochs)
        warmup_lr_schedule = np.linspace(
            self.start_lr, self.learning_rate, warmup_steps
        )
        iters = np.arange(cosine_steps)
        cosine_lr_schedule = np.array([
            self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * (
                1 + math.cos(math.pi * t / cosine_steps)
            ) for t in iters
        ])

        self.lr_schedule = np.concatenate(
            (warmup_lr_schedule, cosine_lr_schedule))
        self.queue = None
        # Defined up-front so shared_step never hits a missing attribute when
        # a restored queue is used before the first training epoch starts.
        self.use_the_queue = False
        self.model = self.init_model(model)
        self.softmax = nn.Softmax(dim=1)

    def setup(self, stage):
        """Resolve the per-rank queue file and restore the queue if it exists.

        Requires ``self.config["log_dir"]``.
        """
        if self.queue_path != self._queue_file:
            # queue_path still names a folder (first call, or it was
            # reassigned from outside after a previous setup()).
            self._queue_folder = self.queue_path

        queue_folder = os.path.join(self.config["log_dir"], self._queue_folder)
        # exist_ok avoids the check-then-create race between ranks.
        os.makedirs(queue_folder, exist_ok=True)

        self.queue_path = os.path.join(
            queue_folder,
            "queue" + str(self.trainer.global_rank) + ".pth"
        )
        self._queue_file = self.queue_path

        if os.path.isfile(self.queue_path):
            # NOTE: torch.load unpickles the file; only point log_dir at
            # locations whose contents are trusted.
            self.queue = torch.load(self.queue_path)["queue"]

    def init_model(self, model):
        """Wrap ``model`` in ``CustomResNet`` using this module's head settings."""
        return CustomResNet(model, hidden_mlp=self.hidden_mlp,
                            output_dim=self.feat_dim,
                            nmb_prototypes=self.nmb_prototypes,
                            first_conv=self.first_conv,
                            maxpool1=self.maxpool1)

    def forward(self, x):
        # pass single batch from the resnet backbone
        return self.model.forward_backbone(x)

    def on_train_start(self):
        """Reset the manual epoch counter at the start of training."""
        self.epoch = 0

    def on_train_epoch_start(self):
        """Allocate the feature queue once ``epoch_queue_starts`` is reached."""
        if self.queue_length > 0:
            if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None:
                # max(..., 1) keeps CPU runs (gpus == 0) from dividing by zero.
                # change to nodes * gpus once multi-node
                per_device_length = self.queue_length // max(self.gpus, 1)
                self.queue = torch.zeros(
                    len(self.crops_for_assign),
                    per_device_length,
                    self.feat_dim,
                    device=self.get_device(),
                )

        self.use_the_queue = False

    def on_train_epoch_end(self, outputs) -> None:
        """Persist the queue (if any) to the per-rank queue file."""
        if self.queue is not None:
            torch.save({"queue": self.queue}, self.queue_path)

    def on_epoch_end(self):
        self.epoch += 1

    def on_after_backward(self):
        """Drop prototype gradients while the prototypes are frozen."""
        if self.current_epoch < self.freeze_prototypes_epochs:
            for name, p in self.model.named_parameters():
                if "prototypes" in name:
                    p.grad = None

    def shared_step(self, batch):
        """Compute the SwAV swapped-prediction loss for one batch.

        ``batch`` is ``(inputs, y)`` where ``inputs`` is a list of views whose
        last element (the online-evaluation view) is discarded here.
        Requires a single ``nn.Linear`` prototype layer.
        """
        inputs, y = batch
        # remove online train/eval transforms at this point
        inputs = inputs[:-1]

        # 1. normalize the prototypes
        with torch.no_grad():
            w = self.model.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            self.model.prototypes.weight.copy_(w)

        # 2. multi-res forward passes
        embedding, output = self.model(inputs)
        embedding = embedding.detach()
        bs = inputs[0].size(0)

        # Total number of views. The previous expression
        # ``np.sum(self.nmb_crops-1)`` raised TypeError for list-valued
        # nmb_crops and disagreed with the (total - 1) normaliser below.
        # NOTE(review): assumes the model receives sum(nmb_crops) views, as in
        # the reference SwAV loss -- confirm against the data pipeline.
        total_views = int(np.sum(self.nmb_crops))

        # 3. swav loss computation
        loss = 0
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                out = output[bs * crop_id: bs * (crop_id + 1)]

                # 4. time to use the queue
                if self.queue is not None:
                    if self.use_the_queue or not torch.all(self.queue[i, -1, :] == 0):
                        self.use_the_queue = True
                        out = torch.cat((torch.mm(
                            self.queue[i],
                            self.model.prototypes.weight.t()
                        ), out))
                    # fill the queue: shift old entries back, newest first
                    self.queue[i, bs:] = self.queue[i, :-bs].clone()
                    self.queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs]

                # 5. get assignments (only the rows of the current batch)
                q = torch.exp(out / self.epsilon).t()
                q = self.get_assignments(q, self.sinkhorn_iterations)[-bs:]

            # cluster assignment prediction from every other view
            subloss = 0
            for v in np.delete(np.arange(total_views), crop_id):
                p = self.softmax(
                    output[bs * v: bs * (v + 1)] / self.temperature)
                loss_value = q * torch.log(p)
                subloss -= torch.mean(torch.sum(loss_value, dim=1))
            loss += subloss / (total_views - 1)
        loss /= len(self.crops_for_assign)

        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)

        # self.log('train_loss', loss, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx):
        """Compute the validation loss on the first dataloader only."""
        if dataloader_idx != 0:
            return {}
        loss = self.shared_step(batch)

        # self.log('val_loss', loss, on_step=False, on_epoch=True)
        results = {
            'val_loss': loss,
        }
        return results

    def validation_epoch_end(self, outputs):
        # outputs[0] because we are using multiple datasets!
        val_loss = mean(outputs[0], 'val_loss')

        log = {
            'val/val_loss': val_loss,
        }
        return {'val_loss': val_loss, 'log': log, 'progress_bar': log}

    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']):
        """Split trainable parameters into decayed and non-decayed groups.

        A parameter goes to the zero-weight-decay group when any entry of
        ``skip_list`` occurs in its name. Parameters that do not require
        gradients are left out entirely.
        """
        params = []
        excluded_params = []

        for name, param in named_params:
            if not param.requires_grad:
                continue
            if any(layer_name in name for layer_name in skip_list):
                excluded_params.append(param)
            else:
                params.append(param)

        return [
            {'params': params, 'weight_decay': weight_decay},
            {'params': excluded_params, 'weight_decay': 0.}
        ]

    def configure_optimizers(self):
        """Build the optimizer selected by ``self.optim``.

        Raises:
            ValueError: if ``self.optim`` is neither ``'sgd'`` nor ``'adam'``.
        """
        if self.exclude_bn_bias:
            params = self.exclude_from_wt_decay(
                self.named_parameters(),
                weight_decay=self.weight_decay
            )
        else:
            params = self.parameters()

        if self.optim == 'sgd':
            optimizer = torch.optim.SGD(
                params,
                lr=self.learning_rate,
                momentum=0.9,
                weight_decay=self.weight_decay
            )
        elif self.optim == 'adam':
            optimizer = torch.optim.Adam(
                params,
                lr=self.learning_rate,
                weight_decay=self.weight_decay
            )
        else:
            # Previously fell through to an UnboundLocalError on `optimizer`.
            raise ValueError(
                "unsupported optimizer %r, expected 'sgd' or 'adam'" % (self.optim,))

        if self.lars_wrapper:
            optimizer = LARSWrapper(
                optimizer,
                eta=0.001,  # trust coefficient
                clip=False
            )

        return optimizer

    def optimizer_step(
        self,
        epoch: int = None,
        batch_idx: int = None,
        optimizer: Optimizer = None,
        optimizer_idx: int = None,
        optimizer_closure: Optional[Callable] = None,
        on_tpu: bool = None,
        using_native_amp: bool = None,
        using_lbfgs: bool = None,
    ) -> None:
        """Set the scheduled learning rate, then run the optimizer step."""
        # warm-up + decay schedule placed here since LARSWrapper is not optimizer class
        # adjust LR of optim contained within LARSWrapper
        # Clamp so that steps beyond the precomputed schedule keep the last LR
        # instead of raising IndexError.
        step = min(self.trainer.global_step, len(self.lr_schedule) - 1)
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.lr_schedule[step]

        # from lightning
        if not isinstance(optimizer, LightningOptimizer):
            # wraps into LightingOptimizer only for running step
            optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
        optimizer.step(closure=optimizer_closure)

    def sinkhorn(self, Q, nmb_iters):
        """Single-process Sinkhorn-Knopp normalisation.

        ``Q`` is a (prototypes x samples) matrix and is modified in place.
        Returns the transposed matrix as float, each row summing to one.
        """
        with torch.no_grad():
            Q /= torch.sum(Q)

            K, B = Q.shape

            # Target marginals, created on Q's own device.
            r = torch.ones(K, device=Q.device) / K
            c = torch.ones(B, device=Q.device) / B

            for _ in range(nmb_iters):
                u = torch.sum(Q, dim=1)

                Q *= (r / u).unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)

            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

    def distributed_sinkhorn(self, Q, nmb_iters):
        """Sinkhorn-Knopp normalisation with row sums reduced across processes.

        ``Q`` is a (prototypes x samples) matrix and is modified in place.
        Requires an initialised ``torch.distributed`` process group.
        """
        with torch.no_grad():
            sum_Q = torch.sum(Q)
            dist.all_reduce(sum_Q)
            Q /= sum_Q

            # Target marginals, created on Q's own device.
            # NOTE(review): the column marginal is scaled by self.gpus only,
            # not nodes * gpus -- confirm for multi-node runs.
            r = torch.ones(Q.shape[0], device=Q.device) / Q.shape[0]
            c = torch.ones(Q.shape[1], device=Q.device) / (self.gpus * Q.shape[1])

            curr_sum = torch.sum(Q, dim=1)
            dist.all_reduce(curr_sum)

            for it in range(nmb_iters):
                u = curr_sum
                Q *= (r / u).unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
                curr_sum = torch.sum(Q, dim=1)
                dist.all_reduce(curr_sum)
            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

    def type(self, dst_type=None):
        """Return the tensor type of the first backbone weight.

        When ``dst_type`` is given the call is delegated to the parent
        implementation, so the standard ``Module.type(dst_type)`` cast keeps
        working despite this override.
        """
        if dst_type is None:
            return self.model.features[0][0].weight.type()
        return super().type(dst_type)

    def get_representations(self, x):
        """Return the raw output of the backbone ``features`` module."""
        return self.model.features(x)

    def get_model(self):
        """Return the backbone that was passed to the constructor."""
        return self.model.model

    def get_device(self):
        """Return the device of the first backbone weight."""
        return self.model.features[0][0].weight.device

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the SwAV options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # model params
        parser.add_argument("--arch", default="resnet50",
                            type=str, help="convnet architecture")
        # specify flags to store false
        parser.add_argument("--first_conv", action='store_false')
        parser.add_argument("--maxpool1", action='store_false')
        parser.add_argument("--hidden_mlp", default=2048, type=int,
                            help="hidden layer dimension in projection head")
        parser.add_argument("--feat_dim", default=128,
                            type=int, help="feature dimension")
        parser.add_argument("--online_ft", action='store_true')
        parser.add_argument("--fp32", action='store_true')

        # transform params
        parser.add_argument("--gaussian_blur",
                            action="store_true", help="add gaussian blur")
        parser.add_argument("--jitter_strength", type=float,
                            default=1.0, help="jitter strength")
        parser.add_argument("--dataset", type=str,
                            default="stl10", help="stl10, cifar10")
        parser.add_argument("--data_dir", type=str,
                            default=".", help="path to download data")
        parser.add_argument("--queue_path", type=str,
                            default="queue", help="path for queue")

        parser.add_argument("--nmb_crops", type=int, default=[2, 4], nargs="+",
                            help="list of number of crops (example: [2, 6])")
        parser.add_argument("--size_crops", type=int, default=[96, 36], nargs="+",
                            help="crops resolutions (example: [224, 96])")
        parser.add_argument("--min_scale_crops", type=float, default=[0.33, 0.10], nargs="+",
                            help="argument in RandomResizedCrop (example: [0.14, 0.05])")
        parser.add_argument("--max_scale_crops", type=float, default=[1, 0.33], nargs="+",
                            help="argument in RandomResizedCrop (example: [1., 0.14])")

        # training params
        parser.add_argument("--fast_dev_run", action='store_true')
        parser.add_argument("--nodes", default=1, type=int,
                            help="number of nodes for training")
        parser.add_argument("--gpus", default=1, type=int,
                            help="number of gpus to train on")
        parser.add_argument("--num_workers", default=8,
                            type=int, help="num of workers per GPU")
        parser.add_argument("--optimizer", default="adam",
                            type=str, help="choose between adam/sgd")
        parser.add_argument("--lars_wrapper", action='store_true',
                            help="apple lars wrapper over optimizer used")
        parser.add_argument('--exclude_bn_bias', action='store_true',
                            help="exclude bn/bias from weight decay")
        parser.add_argument("--max_epochs", default=100,
                            type=int, help="number of total epochs to run")
        parser.add_argument("--max_steps", default=-1,
                            type=int, help="max steps")
        parser.add_argument("--warmup_epochs", default=10,
                            type=int, help="number of warmup epochs")
        parser.add_argument("--batch_size", default=128,
                            type=int, help="batch size per gpu")

        parser.add_argument("--weight_decay", default=1e-6,
                            type=float, help="weight decay")
        parser.add_argument("--learning_rate", default=1e-3,
                            type=float, help="base learning rate")
        parser.add_argument("--start_lr", default=0, type=float,
                            help="initial warmup learning rate")
        parser.add_argument("--final_lr", type=float,
                            default=1e-6, help="final learning rate")

        # swav params
        parser.add_argument("--crops_for_assign", type=int, nargs="+", default=[0, 1],
                            help="list of crops id used for computing assignments")
        parser.add_argument("--temperature", default=0.1, type=float,
                            help="temperature parameter in training loss")
        parser.add_argument("--epsilon", default=0.05, type=float,
                            help="regularization parameter for Sinkhorn-Knopp algorithm")
        parser.add_argument("--sinkhorn_iterations", default=3, type=int,
                            help="number of iterations in Sinkhorn-Knopp algorithm")
        parser.add_argument("--nmb_prototypes", default=512,
                            type=int, help="number of prototypes")
        parser.add_argument("--queue_length", type=int, default=0,
                            help="length of the queue (0 for no queue); must be divisible by total batch size")
        parser.add_argument("--epoch_queue_starts", type=int, default=15,
                            help="from this epoch, we start using a queue")
        parser.add_argument("--freeze_prototypes_epochs", default=1, type=int,
                            help="freeze the prototypes during this many epochs from the start")

        return parser
|
|
830 |
|
|
|
831 |
|
|
|
832 |
def mean(res, key1, key2=None):
    """Average the tensors stored under ``key1`` (or ``key1 -> key2``) over ``res``.

    Args:
        res: iterable of per-step outputs.
        key1: outer key to read from every element.
        key2: optional inner key; when given, each element is indexed as
            ``x[key1][key2]`` without any filtering.

    Returns:
        The mean of the stacked values. When ``key2`` is omitted, only
        elements that are plain ``dict`` objects containing ``key1``
        contribute; all other elements are skipped.
    """
    if key2 is None:
        picked = []
        for entry in res:
            # Exact type match on purpose: dict subclasses are skipped too.
            if type(entry) is dict and key1 in entry:
                picked.append(entry[key1])
        return torch.stack(picked).mean()

    nested = [entry[key1][key2] for entry in res]
    return torch.stack(nested).mean()
|
|
836 |
|
|
|
837 |
def parse_args(parent_parser):
    """Build the command-line parser for the pretraining script.

    Args:
        parent_parser: an ``ArgumentParser`` whose arguments are inherited
            through ``parents=[parent_parser]``.

    Returns:
        A new ``ArgumentParser`` (created with ``add_help=False``) carrying
        the augmentation, training and dataset options defined below.
        Options without an explicit default parse to ``None``.
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument('-t', '--trafos', nargs='+', help='add transformation to data augmentation pipeline',
                        default=["GaussianNoise", "ChannelResize", "RandomResizedCrop"])
    # GaussianNoise
    parser.add_argument(
        '--gaussian_scale', help='std param for gaussian noise transformation', default=0.005, type=float)
    # RandomResizedCrop
    # nargs='+' so a value given on the command line is a list of floats like
    # the default (and like the other range options below), not a bare float.
    parser.add_argument('--rr_crop_ratio_range', nargs='+',
                        help='ratio range for random resized crop transformation', default=[0.5, 1.0], type=float)
    parser.add_argument(
        '--output_size', help='output size for random resized crop transformation', default=250, type=int)
    # DynamicTimeWarp
    parser.add_argument(
        '--warps', help='number of warps for dynamic time warp transformation', default=3, type=int)
    parser.add_argument(
        '--radius', help='radius of warps of dynamic time warp transformation', default=10, type=int)
    # TimeWarp
    parser.add_argument(
        '--epsilon', help='epsilon param for time warp', default=10, type=float)
    # ChannelResize
    parser.add_argument('--magnitude_range', nargs='+',
                        help='range for scale param for ChannelResize transformation', default=[0.5, 2], type=float)
    # Downsample
    parser.add_argument(
        '--downsample_ratio', help='downsample ratio for Downsample transformation', default=0.2, type=float)
    # TimeOut
    parser.add_argument('--to_crop_ratio_range', nargs='+',
                        help='ratio range for timeout transformation', default=[0.2, 0.4], type=float)
    # resume training
    parser.add_argument('--resume', action='store_true')
    parser.add_argument(
        '--gpus', help='number of gpus to use; use cpu if gpu=0', type=int, default=1)
    parser.add_argument(
        '--num_nodes', default=1, help='number of cluster nodes', type=int)
    parser.add_argument(
        '--distributed_backend', help='sets backend type')
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--epochs', type=int)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--warm_up', default=1, type=int)
    parser.add_argument('--precision', type=int)
    # Stored under ``target_folders`` rather than ``datasets``.
    parser.add_argument('--datasets', dest="target_folders",
                        nargs='+', help='used datasets for pretraining')
    parser.add_argument('--log_dir', default="./experiment_logs")
    parser.add_argument(
        '--percentage', help='determines how much of the dataset shall be used during the pretraining', type=float, default=1.0)
    parser.add_argument('--lr', type=float, help="learning rate")
    parser.add_argument('--out_dim', type=int, help="output dimension of model")
    parser.add_argument('--filter_cinc', default=False, action="store_true", help="only valid if cinc is selected: filter out the ptb data")
    parser.add_argument('--base_model')
    parser.add_argument('--widen', type=int, help="use wide xresnet1d50")
    parser.add_argument('--run_callbacks', default=False, action="store_true", help="run callbacks which asses linear evaluaton and finetuning metrics during pretraining")

    parser.add_argument('--checkpoint_path', default="")
    return parser
|
|
893 |
|
|
|
894 |
def init_logger(config):
    """Reconfigure root logging to write into ``<log_dir>/info.log``.

    Args:
        config: mapping providing ``'debug'`` (truthy selects DEBUG level,
            otherwise INFO) and ``'log_dir'`` (directory for the log file).

    Returns:
        The logger named after this module.
    """
    level = logging.DEBUG if config['debug'] else logging.INFO

    # remove all handlers to change basic configuration
    # (basicConfig is a no-op while the root logger still has handlers)
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    # makedirs creates missing parent directories and tolerates an already
    # existing directory, unlike the isdir()/mkdir() pair.
    os.makedirs(config['log_dir'], exist_ok=True)

    logging.basicConfig(filename=os.path.join(config['log_dir'], 'info.log'), level=level,
                        format='%(asctime)s %(name)s:%(lineno)s %(levelname)s: %(message)s ')
    return logging.getLogger(__name__)
|
|
908 |
|
|
|
909 |
def pretrain_routine(args):
    """Assemble configuration, dataset wrapper and logger for a pretraining run.

    Loads the YAML config (the copy under ``checkpoints/`` when resuming and
    present, otherwise ``bolts_config.yaml`` in the working directory),
    overlays the parsed command-line arguments, configures file logging,
    builds the dataset wrapper and derives a run name for TensorBoard.

    Args:
        args: parsed command-line namespace, as produced by the parser from
            ``parse_args``.

    Returns:
        Tuple ``(config, dataset, date, transformations, t_params,
        ptb_num_classes, tb_logger)``.
    """
    t_params = {"gaussian_scale": args.gaussian_scale, "rr_crop_ratio_range": args.rr_crop_ratio_range, "output_size": args.output_size, "warps": args.warps, "radius": args.radius,
                "epsilon": args.epsilon, "magnitude_range": args.magnitude_range, "downsample_ratio": args.downsample_ratio, "to_crop_ratio_range": args.to_crop_ratio_range,
                "bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1}
    transformations = args.trafos

    checkpoint_config = os.path.join("checkpoints", "bolts_config.yaml")
    config_file = checkpoint_config if args.resume and os.path.isfile(
        checkpoint_config) else "bolts_config.yaml"
    # Context manager so the config file handle is closed after parsing.
    # NOTE(review): FullLoader is kept to preserve behaviour; consider
    # yaml.safe_load if the config never relies on Python-specific tags.
    with open(config_file, "r") as config_handle:
        config = yaml.load(config_handle, Loader=yaml.FullLoader)

    # Command-line values win over the file, except that an unset (None)
    # argument never overwrites a key already present in the config.
    for key, value in vars(args).items():
        if value is None and key in config:
            continue
        config[key] = value

    if args.target_folders is not None:
        config["dataset"]["target_folders"] = args.target_folders
    config["dataset"]["percentage"] = args.percentage if args.percentage is not None else config["dataset"]["percentage"]
    config["dataset"]["filter_cinc"] = args.filter_cinc if args.filter_cinc is not None else config["dataset"]["filter_cinc"]
    config["model"]["base_model"] = args.base_model if args.base_model is not None else config["model"]["base_model"]
    config["model"]["widen"] = args.widen if args.widen is not None else config["model"]["widen"]
    # SwAV-specific dataset settings are forced for train and eval datasets.
    config["dataset"]["swav"] = True
    config["dataset"]["nmb_crops"] = 7
    config["eval_dataset"]["swav"] = True
    config["eval_dataset"]["nmb_crops"] = 7
    if args.out_dim is not None:
        config["model"]["out_dim"] = args.out_dim

    init_logger(config)
    dataset = SimCLRDataSetWrapper(
        config['batch_size'], **config['dataset'], transformations=transformations, t_params=t_params)
    for i, t in enumerate(dataset.transformations):
        logger.info(str(i) + ". Transformation: " +
                    str(t) + ": " + str(t.get_params()))

    date = time.asctime()
    label_to_num_classes = {"label_all": 71, "label_diag": 44, "label_form": 19,
                            "label_rhythm": 12, "label_diag_subclass": 23, "label_diag_superclass": 5}
    ptb_num_classes = label_to_num_classes[config["eval_dataset"]
                                           ["ptb_xl_label"]]

    # Abbreviations used to encode the augmentation pipeline in the run name;
    # "TT" and "Tr" entries are replaced by empty strings.
    abr = {"Transpose": "Tr", "TimeOut": "TO", "DynamicTimeWarp": "DTW", "RandomResizedCrop": "RRC", "ChannelResize": "ChR", "GaussianNoise": "GN",
           "TimeWarp": "TW", "ToTensor": "TT", "GaussianBlur": "GB", "BaselineWander": "BlW", "PowerlineNoise": "PlN", "EMNoise": "EM", "BaselineShift": "BlS"}
    trs = re.sub(r"[,'\]\[]", "", str([abr[str(tr)] if abr[str(tr)] not in [
        "TT", "Tr"] else '' for tr in dataset.transformations]))
    # Run name: timestamp, method, last three digits of the ns clock, then the
    # transformation string with its first character dropped.
    name = str(date) + "_" + method + "_" + str(
        time.time_ns())[-3:] + "_" + trs[1:]
    tb_logger = TensorBoardLogger(args.log_dir, name=name, version='')
    config["log_dir"] = os.path.join(args.log_dir, name)
    print(config)
    return config, dataset, date, transformations, t_params, ptb_num_classes, tb_logger
|
|
955 |
|
|
|
956 |
def aftertrain_routine(config, args, trainer, pl_model, datamodule, callbacks):
    """Persist the artefacts of a finished run into ``config["log_dir"]``.

    Writes ``results.pkl`` (config, transformation names and the best macro
    score of each ``SSLOnlineEvaluator`` callback), saves the trainer
    checkpoint to ``checkpoints/model.ckpt`` and dumps the config as text to
    ``config.txt``. ``pl_model`` and ``datamodule`` are accepted but not used.
    """
    scores = {
        str(evaluator): {"macro": evaluator.best_macro}
        for evaluator in callbacks
        if isinstance(evaluator, SSLOnlineEvaluator)
    }
    results = {"config": config, "trafos": args.trafos, "scores": scores}

    results_path = os.path.join(config["log_dir"], "results.pkl")
    with open(results_path, 'wb') as handle:
        pickle.dump(results, handle)

    checkpoint_path = os.path.join(config["log_dir"], "checkpoints", "model.ckpt")
    trainer.save_checkpoint(checkpoint_path)

    config_path = os.path.join(config["log_dir"], "config.txt")
    with open(config_path, "w") as text_file:
        text_file.write(str(config) + "\n")
|
|
970 |
|
|
|
971 |
def cli_main():
    """Command-line entry point: parse arguments, pretrain, store results.

    Builds the configuration and data module, optionally attaches the online
    linear-evaluation and fine-tuning callbacks, restores weights from
    ``--checkpoint_path`` when given, runs ``trainer.fit`` and finally calls
    ``aftertrain_routine``.

    Raises:
        FileNotFoundError: if ``--checkpoint_path`` is set but does not exist.
    """
    from pytorch_lightning import Trainer

    parser = ArgumentParser()
    parser = parse_args(parser)
    logger.info("parse arguments")
    args = parser.parse_args()
    config, _dataset, _date, transformations, t_params, ptb_num_classes, tb_logger = pretrain_routine(args)

    # data
    ecg_datamodule = ECGDataModule(config, transformations, t_params)

    callbacks = []
    if args.run_callbacks:
        # callbacks for online linear evaluation / fine-tuning; they differ
        # only in ``mode`` and are appended in this order.
        for mode in ("linear_evaluation", "fine_tuning"):
            callbacks.append(SSLOnlineEvaluator(
                drop_p=0, z_dim=512, num_classes=ptb_num_classes, hidden_dim=None,
                lin_eval_epochs=config["eval_epochs"], eval_every=config["eval_every"],
                mode=mode, verbose=False))

    # configure trainer
    trainer = Trainer(logger=tb_logger, max_epochs=config["epochs"], gpus=args.gpus,
                      distributed_backend=args.distributed_backend, auto_lr_find=False, num_nodes=args.num_nodes, precision=config["precision"], callbacks=callbacks)

    # pytorch lightning module
    model = ResNetSimCLR(**config["model"])
    pl_model = CustomSwAV(model, config["gpus"], ecg_datamodule.num_samples, config["batch_size"], config=config,
                          transformations=ecg_datamodule.transformations, nmb_crops=config["dataset"]["nmb_crops"])

    # load checkpoint
    if args.checkpoint_path != "":
        if not os.path.exists(args.checkpoint_path):
            # Raising a bare string is itself a TypeError; use a real exception.
            raise FileNotFoundError(
                "checkpoint does not exist: " + args.checkpoint_path)
        logger.info("Retrieve checkpoint from " + args.checkpoint_path)
        # ``load_from_checkpoint`` returns a new module, so calling it and
        # discarding the result left ``pl_model`` untouched. Load the stored
        # weights into the existing module instead.
        # NOTE(review): assumes the file is a Lightning checkpoint with a
        # "state_dict" entry matching this module - confirm.
        checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
        pl_model.load_state_dict(checkpoint["state_dict"])

    # start training
    trainer.fit(pl_model, ecg_datamodule)

    aftertrain_routine(config, args, trainer, pl_model, ecg_datamodule, callbacks)
|
|
1019 |
|
|
|
1020 |
# Run the pretraining CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli_main()