# main_cpc_lightning.py

###############
# generic
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
import torch.nn.functional as F

import torchvision
import os
import argparse
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import copy

#################
# specific
from clinical_ts.timeseries_utils import *
from clinical_ts.ecg_utils import *

from functools import partial
from pathlib import Path
import pandas as pd
import numpy as np

from clinical_ts.xresnet1d import xresnet1d50, xresnet1d101
from clinical_ts.basic_conv1d import weight_init
from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
from clinical_ts.cpc import *

def _freeze_bn_stats(model, freeze=True):
    for m in model.modules():
        if(isinstance(m, nn.BatchNorm1d)):
            if(freeze):
                m.eval()
            else:
                m.train()

def sanity_check(model, state_dict_pre):
    """
    Linear classifier training should not change any weights other than the linear layer.
    This sanity check asserts that nothing else changed (e.g. BN statistics being updated).
    """
    print("=> loading state dict for sanity check")
    state_dict = model.state_dict()

    for k in list(state_dict.keys()):
        print(k)
        # only ignore the final classification (fc) layer
        if 'head.1.weight' in k or 'head.1.bias' in k:
            continue

        assert ((state_dict[k].cpu() == state_dict_pre[k].cpu()).all()), \
            '{} is changed in linear classifier training.'.format(k)

    print("=> sanity check passed.")

class LightningCPC(pl.LightningModule):

    def __init__(self, hparams):
        super(LightningCPC, self).__init__()

        self.hparams = hparams
        self.lr = self.hparams.lr

        # these coincide with the adapted wav2vec2 params
        if(self.hparams.fc_encoder):
            strides = [1]*4
            kss = [1]*4
            features = [512]*4
        else:  # strided conv encoder
            strides = [2, 2, 2, 2]  # original wav2vec2 [5,2,2,2,2,2]; original cpc [5,4,2,2,2]
            kss = [10, 4, 4, 4]     # original wav2vec2 [10,3,3,3,3,2]; original cpc [18,8,4,4,4]
            features = [512]*4      # wav2vec2 [512]*6; original cpc [512]*5

        if(self.hparams.finetune):
            self.criterion = F.cross_entropy if self.hparams.finetune_dataset == "thew" else F.binary_cross_entropy_with_logits
            if(self.hparams.finetune_dataset == "thew"):
                num_classes = 5
            elif(self.hparams.finetune_dataset == "ptbxl_super"):
                num_classes = 5
            elif(self.hparams.finetune_dataset == "ptbxl_all"):
                num_classes = 71
        else:
            num_classes = None

        self.model_cpc = CPCModel(
            input_channels=self.hparams.input_channels,
            strides=strides,
            kss=kss,
            features=features,
            n_hidden=self.hparams.n_hidden,
            n_layers=self.hparams.n_layers,
            mlp=self.hparams.mlp,
            lstm=not(self.hparams.gru),
            bias_proj=self.hparams.bias,
            num_classes=num_classes,
            skip_encoder=self.hparams.skip_encoder,
            bn_encoder=not(self.hparams.no_bn_encoder),
            lin_ftrs_head=[] if self.hparams.linear_eval else eval(self.hparams.lin_ftrs_head),
            ps_head=0 if self.hparams.linear_eval else self.hparams.dropout_head,
            bn_head=False if self.hparams.linear_eval else not(self.hparams.no_bn_head))

        target_fs = 100
        if(not(self.hparams.finetune)):
            print("CPC pretraining:\ndownsampling factor:", self.model_cpc.encoder_downsampling_factor,
                  "\nchunk length (s):", self.model_cpc.encoder_downsampling_factor/target_fs,
                  "\nsamples predicted ahead:", self.model_cpc.encoder_downsampling_factor*self.hparams.steps_predicted,
                  "\nseconds predicted ahead:", self.model_cpc.encoder_downsampling_factor*self.hparams.steps_predicted/target_fs,
                  "\nRNN input size:", self.hparams.input_size//self.model_cpc.encoder_downsampling_factor)

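    # Worked example for the printout above (assuming the defaults; the downsampling
    # factor is taken to be the product of the encoder strides): strides [2,2,2,2]
    # give a factor of 2*2*2*2 = 16, so at target_fs = 100 Hz each encoded step covers
    # 16/100 = 0.16 s; with the default steps_predicted = 12 the model predicts
    # 12*16 = 192 samples, i.e. 1.92 s, ahead.
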
    def forward(self, x):
        return self.model_cpc(x)

    def _step(self, data_batch, batch_idx, train):
        if(self.hparams.finetune):
            preds = self.forward(data_batch[0])
            loss = self.criterion(preds, data_batch[1])
            self.log("train_loss" if train else "val_loss", loss)
            return {'loss': loss, "preds": preds.detach(), "targs": data_batch[1]}
        else:
            loss, acc = self.model_cpc.cpc_loss(
                data_batch[0],
                steps_predicted=self.hparams.steps_predicted,
                n_false_negatives=self.hparams.n_false_negatives,
                negatives_from_same_seq_only=self.hparams.negatives_from_same_seq_only,
                eval_acc=True)
            self.log("loss" if train else "val_loss", loss)
            self.log("acc" if train else "val_acc", acc)
            return loss

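    # For orientation only -- a rough sketch of the InfoNCE-style step that
    # CPCModel.cpc_loss (implemented in clinical_ts.cpc) is expected to perform;
    # the names below are illustrative, not the actual API:
    #
    #   pred   = W_k(c_t)                                  # predict the encoding k steps ahead
    #   logits = scores of pred against [z_positive, *z_negatives]
    #   loss   = F.cross_entropy(logits, index_of_positive)
    #   acc    = fraction of steps where the true future encoding ranks first
    #
    # n_false_negatives controls how many distractors are drawn and
    # negatives_from_same_seq_only restricts them to the same recording.
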
    def training_step(self, train_batch, batch_idx):
        if(self.hparams.linear_eval):
            _freeze_bn_stats(self)
        return self._step(train_batch, batch_idx, True)

    def validation_step(self, val_batch, batch_idx, dataloader_idx=0):
        return self._step(val_batch, batch_idx, False)

    def validation_epoch_end(self, outputs_all):
        if(self.hparams.finetune):
            for dataloader_idx, outputs in enumerate(outputs_all):  # multiple val dataloaders (val and test)
                preds_all = torch.cat([x['preds'] for x in outputs])
                targs_all = torch.cat([x['targs'] for x in outputs])
                if(self.hparams.finetune_dataset == "thew"):
                    preds_all = F.softmax(preds_all, dim=-1)
                    targs_all = torch.eye(len(self.lbl_itos))[targs_all].to(preds_all.device)
                else:
                    preds_all = torch.sigmoid(preds_all)
                preds_all = preds_all.cpu().numpy()
                targs_all = targs_all.cpu().numpy()
                # instance-level score
                res = eval_scores(targs_all, preds_all, classes=self.lbl_itos)

                # aggregate crops belonging to the same recording (the second dataloader is the test set)
                dataset = self.val_dataset if dataloader_idx == 0 else self.test_dataset
                idmap = dataset.get_id_mapping()
                preds_all_agg, targs_all_agg = aggregate_predictions(preds_all, targs_all, idmap, aggregate_fn=np.mean)
                res_agg = eval_scores(targs_all_agg, preds_all_agg, classes=self.lbl_itos)
                self.log_dict({"macro_auc_agg"+str(dataloader_idx): res_agg["label_AUC"]["macro"],
                               "macro_auc_noagg"+str(dataloader_idx): res["label_AUC"]["macro"]})
                print("epoch", self.current_epoch,
                      "macro_auc_agg"+str(dataloader_idx)+":", res_agg["label_AUC"]["macro"],
                      "macro_auc_noagg"+str(dataloader_idx)+":", res["label_AUC"]["macro"])

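    # Note on the aggregation above (explanatory, not from the original file): during
    # validation each recording is cut into overlapping crops (stride input_size//2 in
    # setup() below), so e.g. a record split into four crops contributes four sigmoid
    # outputs that are averaged into a single per-record prediction before the macro
    # AUC is computed; "noagg" is the corresponding crop-level score.
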
    def on_fit_start(self):
        if(self.hparams.linear_eval):
            print("copying state dict before training for sanity check after training")
            self.state_dict_pre = copy.deepcopy(self.state_dict())

    def on_fit_end(self):
        if(self.hparams.linear_eval):
            sanity_check(self, self.state_dict_pre)

    def setup(self, stage):
        # configure dataset params
        chunkify_train = False
        chunk_length_train = self.hparams.input_size if chunkify_train else 0
        stride_train = self.hparams.input_size

        chunkify_valtest = True
        chunk_length_valtest = self.hparams.input_size if chunkify_valtest else 0
        stride_valtest = self.hparams.input_size//2

        train_datasets = []
        val_datasets = []
        test_datasets = []

        for i, target_folder in enumerate(self.hparams.data):
            target_folder = Path(target_folder)

            df_mapped, lbl_itos, mean, std = load_dataset(target_folder)
            # always use PTB-XL stats
            mean = np.array([-0.00184586, -0.00130277, 0.00017031, -0.00091313, -0.00148835, -0.00174687, -0.00077071, -0.00207407, 0.00054329, 0.00155546, -0.00114379, -0.00035649])
            std = np.array([0.16401004, 0.1647168, 0.23374124, 0.33767231, 0.33362807, 0.30583013, 0.2731171, 0.27554379, 0.17128962, 0.14030828, 0.14606956, 0.14656108])

            # specific for PTB-XL
            if(self.hparams.finetune and self.hparams.finetune_dataset.startswith("ptbxl")):
                if(self.hparams.finetune_dataset == "ptbxl_super"):
                    ptb_xl_label = "label_diag_superclass"
                elif(self.hparams.finetune_dataset == "ptbxl_all"):
                    ptb_xl_label = "label_all"

                lbl_itos = np.array(lbl_itos[ptb_xl_label])

                def multihot_encode(x, num_classes):
                    res = np.zeros(num_classes, dtype=np.float32)
                    for y in x:
                        res[y] = 1
                    return res

                df_mapped["label"] = df_mapped[ptb_xl_label+"_filtered_numeric"].apply(lambda x: multihot_encode(x, len(lbl_itos)))

            self.lbl_itos = lbl_itos
            tfms_ptb_xl_cpc = ToTensor() if self.hparams.normalize is False else transforms.Compose([Normalize(mean, std), ToTensor()])

            max_fold_id = df_mapped.strat_fold.max()  # unfortunately 1-based for PTB-XL; sometimes 100 (Ribeiro)
            df_train = df_mapped[df_mapped.strat_fold < (max_fold_id-1 if self.hparams.finetune else max_fold_id)]
            df_val = df_mapped[df_mapped.strat_fold == (max_fold_id-1 if self.hparams.finetune else max_fold_id)]
            if(self.hparams.finetune):
                df_test = df_mapped[df_mapped.strat_fold == max_fold_id]

            train_datasets.append(TimeseriesDatasetCrops(df_train, self.hparams.input_size, num_classes=len(lbl_itos), data_folder=target_folder, chunk_length=chunk_length_train, min_chunk_length=self.hparams.input_size, stride=stride_train, transforms=tfms_ptb_xl_cpc, annotation=False, col_lbl="label" if self.hparams.finetune else None, memmap_filename=target_folder/("memmap.npy")))
            val_datasets.append(TimeseriesDatasetCrops(df_val, self.hparams.input_size, num_classes=len(lbl_itos), data_folder=target_folder, chunk_length=chunk_length_valtest, min_chunk_length=self.hparams.input_size, stride=stride_valtest, transforms=tfms_ptb_xl_cpc, annotation=False, col_lbl="label" if self.hparams.finetune else None, memmap_filename=target_folder/("memmap.npy")))
            if(self.hparams.finetune):
                test_datasets.append(TimeseriesDatasetCrops(df_test, self.hparams.input_size, num_classes=len(lbl_itos), data_folder=target_folder, chunk_length=chunk_length_valtest, min_chunk_length=self.hparams.input_size, stride=stride_valtest, transforms=tfms_ptb_xl_cpc, annotation=False, col_lbl="label", memmap_filename=target_folder/("memmap.npy")))

            print("\n", target_folder)
            print("train dataset:", len(train_datasets[-1]), "samples")
            print("val dataset:", len(val_datasets[-1]), "samples")
            if(self.hparams.finetune):
                print("test dataset:", len(test_datasets[-1]), "samples")

        if(len(train_datasets) > 1):  # multiple data folders
            print("\nCombined:")
            self.train_dataset = ConcatDataset(train_datasets)
            self.val_dataset = ConcatDataset(val_datasets)
            print("train dataset:", len(self.train_dataset), "samples")
            print("val dataset:", len(self.val_dataset), "samples")
            if(self.hparams.finetune):
                self.test_dataset = ConcatDataset(test_datasets)
                print("test dataset:", len(self.test_dataset), "samples")
        else:  # just a single data folder
            self.train_dataset = train_datasets[0]
            self.val_dataset = val_datasets[0]
            if(self.hparams.finetune):
                self.test_dataset = test_datasets[0]

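    # Fold handling above, spelled out for PTB-XL (10 stratified folds, 1-based):
    # pretraining trains on folds 1-9 and validates on fold 10; finetuning trains on
    # folds 1-8, validates on fold 9 and keeps fold 10 as the test set, which is
    # attached as a second validation dataloader below.
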
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, num_workers=4, shuffle=True, drop_last=True)

    def val_dataloader(self):
        if(self.hparams.finetune):  # multiple val dataloaders (val and test)
            return [DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4),
                    DataLoader(self.test_dataset, batch_size=self.hparams.batch_size, num_workers=4)]
        else:
            return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4)

    def configure_optimizers(self):
        if(self.hparams.optimizer == "sgd"):
            opt = torch.optim.SGD
        elif(self.hparams.optimizer == "adam"):
            opt = torch.optim.AdamW
        else:
            raise NotImplementedError("Unknown Optimizer.")

        if(self.hparams.finetune and (self.hparams.linear_eval or self.hparams.train_head_only)):
            optimizer = opt(self.model_cpc.head.parameters(), self.lr, weight_decay=self.hparams.weight_decay)
        elif(self.hparams.finetune and self.hparams.discriminative_lr_factor != 1.):  # discriminative lrs
            optimizer = opt([
                {"params": self.model_cpc.encoder.parameters(), "lr": self.lr*self.hparams.discriminative_lr_factor*self.hparams.discriminative_lr_factor},
                {"params": self.model_cpc.rnn.parameters(), "lr": self.lr*self.hparams.discriminative_lr_factor},
                {"params": self.model_cpc.head.parameters(), "lr": self.lr}],
                self.hparams.lr, weight_decay=self.hparams.weight_decay)
        else:
            optimizer = opt(self.parameters(), self.lr, weight_decay=self.hparams.weight_decay)

        return optimizer

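    # Example with the defaults (--lr 1e-3, --discriminative-lr-factor 0.1): the head
    # trains at 1e-3, the RNN at 1e-4 and the convolutional encoder at 1e-5, i.e. the
    # layer groups closest to the input are updated most cautiously during finetuning.
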
    def load_weights_from_checkpoint(self, checkpoint):
        """ Loads the weights from a given checkpoint file into the current model.
        based on https://github.com/PyTorchLightning/pytorch-lightning/issues/525
        """
        checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
        pretrained_dict = checkpoint["state_dict"]
        model_dict = self.state_dict()

        # keep only parameters whose names also exist in the current model
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

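# Typical use when finetuning from a pretrained CPC run (paths are illustrative,
# not part of the original script):
#   model = LightningCPC(hparams)
#   model.load_weights_from_checkpoint("output/version_0/best_model.ckpt")
# Only parameters whose names match the current model are copied; the freshly
# initialized classification head keeps its random initialization.
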
#####################################################################################################
# ARGPARSERS
#####################################################################################################
def add_model_specific_args(parser):
    parser.add_argument("--input-channels", type=int, default=12)
    parser.add_argument("--normalize", action='store_true', help='normalize input using PTB-XL stats')
    parser.add_argument('--mlp', action='store_true', help="False: original CPC; True: MLP head as in SimCLR")
    parser.add_argument('--bias', action='store_true', help="original CPC: no bias")
    parser.add_argument("--n-hidden", type=int, default=512)
    parser.add_argument("--gru", action="store_true")
    parser.add_argument("--n-layers", type=int, default=2)
    parser.add_argument("--steps-predicted", dest="steps_predicted", type=int, default=12)
    parser.add_argument("--n-false-negatives", dest="n_false_negatives", type=int, default=128)
    parser.add_argument("--skip-encoder", action="store_true", help="disable the convolutional encoder, i.e. just the RNN; for testing")
    parser.add_argument("--fc-encoder", action="store_true", help="use a fully connected encoder (as opposed to an encoder with strided convs)")
    parser.add_argument("--negatives-from-same-seq-only", action="store_true", help="only draw false negatives from the same sequence (as opposed to drawing from everywhere)")
    parser.add_argument("--no-bn-encoder", action="store_true", help="switch off batch normalization in the encoder")
    parser.add_argument("--dropout-head", type=float, default=0.5)
    parser.add_argument("--train-head-only", action="store_true", help="freeze everything except the classification head (note: --linear-eval defaults to no hidden layer in the classification head)")
    parser.add_argument("--lin-ftrs-head", type=str, default="[512]", help="hidden layers in the classification head")
    parser.add_argument('--no-bn-head', action='store_true', help="use no batch normalization in the classification head")
    return parser

def add_default_args():
    parser = argparse.ArgumentParser(description='PyTorch Lightning CPC Training')
    parser.add_argument('--data', metavar='DIR', type=str,
                        help='path(s) to dataset', action='append')
    parser.add_argument('--epochs', default=30, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64); this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')

    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    parser.add_argument('--pretrained', default='', type=str, metavar='PATH',
                        help='path to pretrained checkpoint (default: none)')
    parser.add_argument('--optimizer', default='adam', help='sgd/adam')  # was sgd
    parser.add_argument('--output-path', default='.', type=str, dest="output_path",
                        help='output path')
    parser.add_argument('--metadata', default='', type=str,
                        help='metadata for output')

    parser.add_argument("--gpus", type=int, default=1, help="number of gpus")
    parser.add_argument("--num-nodes", dest="num_nodes", type=int, default=1, help="number of compute nodes")
    parser.add_argument("--precision", type=int, default=16, help="16/32")
    parser.add_argument("--distributed-backend", dest="distributed_backend", type=str, default=None, help="None/ddp")
    parser.add_argument("--accumulate", type=int, default=1, help="accumulate grad batches (total-bs=accumulate-batches*bs)")

    parser.add_argument("--input-size", dest="input_size", type=int, default=16000)

    parser.add_argument("--finetune", action="store_true", help="finetuning (downstream classification task)", default=False)
    parser.add_argument("--linear-eval", action="store_true", help="linear evaluation instead of full finetuning", default=False)

    parser.add_argument(
        "--finetune-dataset",
        type=str,
        help="thew/ptbxl_super/ptbxl_all",
        default="thew")

    parser.add_argument(
        "--discriminative-lr-factor",
        type=float,
        help="factor by which the lr decreases per layer group during finetuning",
        default=0.1)

    parser.add_argument(
        "--lr-find",
        action="store_true",
        help="run the lr finder before the training run",
        default=False)

    return parser

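# Example invocations (illustrative values and paths, not from the original file):
#   CPC pretraining:
#     python main_cpc_lightning.py --data PATH_TO_DATA --normalize --epochs 100
#   finetuning on PTB-XL superclasses from a pretrained checkpoint:
#     python main_cpc_lightning.py --data PATH_TO_PTBXL --normalize --finetune \
#         --finetune-dataset ptbxl_super --pretrained PATH_TO_RUN/best_model.ckpt
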
###################################################################################################
# MAIN
###################################################################################################
if __name__ == '__main__':
    parser = add_default_args()
    parser = add_model_specific_args(parser)
    hparams = parser.parse_args()
    hparams.executable = "cpc"

    if not os.path.exists(hparams.output_path):
        os.makedirs(hparams.output_path)

    model = LightningCPC(hparams)

    if(hparams.pretrained != ""):
        print("Loading pretrained weights from", hparams.pretrained)
        model.load_weights_from_checkpoint(hparams.pretrained)

    logger = TensorBoardLogger(
        save_dir=hparams.output_path,
        #version="",#hparams.metadata.split(":")[0],
        name="")
    print("Output directory:", logger.log_dir)

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(logger.log_dir, "best_model"),  # hparams.output_path
        save_top_k=1,
        save_last=True,
        verbose=True,
        monitor='macro_auc_agg0' if hparams.finetune else 'val_loss',  # val_loss/dataloader_idx_0
        mode='max' if hparams.finetune else 'min',
        prefix='')
    lr_monitor = LearningRateMonitor()

    trainer = pl.Trainer(
        #overfit_batches=0.01,
        auto_lr_find=hparams.lr_find,
        accumulate_grad_batches=hparams.accumulate,
        max_epochs=hparams.epochs,
        min_epochs=hparams.epochs,

        default_root_dir=hparams.output_path,

        num_sanity_val_steps=0,

        logger=logger,
        checkpoint_callback=checkpoint_callback,
        callbacks=[],  # [lr_monitor]
        benchmark=True,

        gpus=hparams.gpus,
        num_nodes=hparams.num_nodes,
        precision=hparams.precision,
        distributed_backend=hparams.distributed_backend,

        progress_bar_refresh_rate=0,
        weights_summary='top',
        resume_from_checkpoint=None if hparams.resume == "" else hparams.resume)

    if(hparams.lr_find):  # run lr finder
        trainer.tune(model)

    trainer.fit(model)