b/online_evaluator.py
import math
import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics.functional import accuracy
from torch.nn import functional as F
from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import normalize
from torch.nn.modules.linear import Linear
from copy import deepcopy
from clinical_ts.create_logger import create_logger
from tqdm import tqdm

logger = create_logger(__name__)


class SSLOnlineEvaluator(pl.Callback):  # pragma: no cover

    def __init__(self, drop_p: float = 0.0, hidden_dim: int = 1024, z_dim: int = None, num_classes: int = None, lin_eval_epochs=5, eval_every=10, mode="linear_evaluation", discriminative=True, verbose=False):
        """
        Attaches an MLP for finetuning using the standard self-supervised protocol.

        Example::

            from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator

            # your model must have 2 attributes
            model = Model()
            model.z_dim = ...  # the representation dim
            model.num_classes = ...  # the num of classes in the model

        Args:
            drop_p: (0.0) dropout probability
            hidden_dim: (1024) the hidden dimension for the finetune MLP
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.drop_p = drop_p
        self.optimizer = None
        self.z_dim = z_dim
        self.num_classes = num_classes
        self.macro = 0
        self.best_macro = 0
        self.lin_eval_epochs = lin_eval_epochs
        self.eval_every = eval_every
        self.discriminative = discriminative
        self.verbose = verbose
        if mode in ("linear_evaluation", "fine_tuning"):
            self.mode = mode
        else:
            raise ValueError("mode " + str(mode) + " unknown")

    def get_representations(self, features, x):
        """
        Override this to customize for the particular model

        Args:
            features: the encoder network used to compute the representations
            x: the input batch (for [x, ...]-style lists only the first element is used)
        """
        if isinstance(x, list) and len(x) == 2:
            x = x[0]

        representations = features(x)

        if isinstance(representations, (list, tuple)):
            representations = representations[0]

        # flatten to (batch_size, -1) so the linear head can consume it
        representations = representations.reshape(representations.size(0), -1)
        return representations
    def to_device(self, batch, device):
        # kept for API compatibility; put_on_device below is what the loops actually use
        x, y = batch
        return x, y

    def put_on_device(self, batch, device, new_type):
        x, y = batch
        x = x.type(new_type).to(device)
        y = y.type(new_type).to(device)
        return x, y

    def on_sanity_check_start(self, trainer, pl_module):
        self.val_ds_size = len(trainer.val_dataloaders[0].dataset)
        self.last_batch_id = len(trainer.val_dataloaders[0]) - 1

    def on_sanity_check_end(self, trainer, pl_module):
        self.macro = 0

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # re-train the head from scratch every eval_every epochs to get fresh
        # linear-evaluation (or finetuning) scores
        if pl_module.epoch % self.eval_every == 0 and batch_idx == 0 and dataloader_idx == 0:
            new_type, device, valid_loader, features, linear_head, optimizer = self.online_train_setup(
                pl_module, trainer)

            loss_per_epoch = []
            macro_per_epoch = []
            for epoch in tqdm(range(self.lin_eval_epochs)):

                total_loss_one_epoch, linear_head = self.train_one_epoch(
                    valid_loader, features, linear_head, optimizer, device, new_type)

                if self.verbose:
                    loss_per_epoch.append(total_loss_one_epoch)
                    macro, total_loss = self.eval_model(
                        trainer, features, linear_head, device, new_type)
                    macro_per_epoch.append(macro)
                    logger.info("macro at epoch " + str(epoch) + ": " + str(macro))
                    logger.info("train loss at epoch " + str(epoch) + ": " + str(total_loss_one_epoch))
                    logger.info("test loss at epoch " + str(epoch) + ": " + str(total_loss))

            macro, total_loss = self.eval_model(trainer, features, linear_head, device, new_type)
            self.log_values(trainer, pl_module, macro, total_loss)
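
    # Data-flow note: the head is trained on trainer.val_dataloaders[1] (see
    # online_train_setup) and evaluated on trainer.val_dataloaders[2] (see eval_model);
    # val_dataloaders[0] is the ordinary validation loader and is only used for
    # bookkeeping in on_sanity_check_start. The LightningModule is expected to provide
    # epoch, type(), get_device() and get_model(), as used above and below.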

    def online_train_setup(self, pl_module, trainer):
        new_type = pl_module.type()
        device = pl_module.get_device()
        valid_loader = trainer.val_dataloaders[1]
        if self.mode == "linear_evaluation":
            lr = 8e-3 * (valid_loader.batch_size / 256)
        else:
            lr = 8e-5 * (valid_loader.batch_size / 256)
        # print("using lr:", lr)
        # print("using batch size: ", valid_loader.batch_size)
        wd = 1e-1
        features = deepcopy(pl_module.get_model())
        linear_head = Linear(
            features.l1.in_features, self.num_classes, bias=True).type(new_type)
        if self.mode == "linear_evaluation":
            optimizer = torch.optim.AdamW(
                linear_head.parameters(), lr=lr, weight_decay=wd)
        else:
            if not self.discriminative:
                optimizer = torch.optim.AdamW([
                    {"params": features.parameters()}, {"params": linear_head.parameters()}], lr=lr, weight_decay=wd)
            else:
                lr = (8e-3 * (valid_loader.batch_size / 256))
                param_dict = dict(features.named_parameters())
                keys = param_dict.keys()
                weight_layer_nrs = set()
                for key in keys:
                    if "features" in key:
                        # parameter names have the form features.x (single-digit layer index)
                        weight_layer_nrs.add(key[9])
                weight_layer_nrs = sorted(weight_layer_nrs, reverse=True)
                features_groups = []
                while len(weight_layer_nrs) > 0:
                    if len(weight_layer_nrs) > 1:
                        features_groups.append(list(filter(
                            lambda x: "features." + weight_layer_nrs[0] in x or "features." + weight_layer_nrs[1] in x, keys)))
                        del weight_layer_nrs[:2]
                    else:
                        features_groups.append(
                            list(filter(lambda x: "features." + weight_layer_nrs[0] in x, keys)))
                        del weight_layer_nrs[0]
                # linears = list(filter(lambda x: "l" in x, keys))  # filter linear layers
                # groups = [linears] + features_groups
                optimizer_param_list = []
                tmp_lr = lr
                optimizer_param_list.append(
                    {"params": linear_head.parameters(), "lr": tmp_lr})
                tmp_lr /= 4
                for layers in features_groups:
                    layer_params = [param_dict[param_name]
                                    for param_name in layers]
                    optimizer_param_list.append(
                        {"params": layer_params, "lr": tmp_lr})
                    tmp_lr /= 4
                optimizer = torch.optim.AdamW(optimizer_param_list, lr=lr, weight_decay=wd)

        return new_type, device, valid_loader, features, linear_head, optimizer
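
    # Worked example of the discriminative learning rates above (illustrative numbers,
    # assuming a hypothetical valid_loader.batch_size of 128): the linear head gets
    # lr = 8e-3 * 128 / 256 = 4e-3, the deepest pair of feature layers 1e-3, the next
    # pair 2.5e-4, and so on, each group's lr being a quarter of the previous one.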

    def train_one_epoch(self, valid_loader, features, linear_head, optimizer, device, new_type):
        linear_head.train()
        if self.mode == "linear_evaluation":
            # we don't want to update things like batchnorm statistics in linear evaluation
            features.eval()
        else:
            features.train()
        total_loss_one_epoch = 0
        for cur_batch in valid_loader:
            x, y = self.put_on_device(
                cur_batch, device, new_type)
            if self.mode == "linear_evaluation":
                with torch.no_grad():
                    representations = self.get_representations(
                        features, x)
            else:
                with torch.enable_grad():
                    representations = self.get_representations(
                        features, x)
            # forward pass
            with torch.enable_grad():
                mlp_preds = linear_head(representations)
                mlp_loss = F.binary_cross_entropy_with_logits(
                    mlp_preds, y)
                # update finetune weights
                optimizer.zero_grad()
                mlp_loss.backward()
                optimizer.step()
            total_loss_one_epoch += mlp_loss.item()
        return total_loss_one_epoch, linear_head
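
    # Note: the head is optimized with binary_cross_entropy_with_logits, so the labels
    # y are expected to be multi-hot float vectors (a multi-label setup), which is why
    # put_on_device casts them to the model's floating-point type.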

    def eval_model(self, trainer, features, linear_head, device, new_type):
        features.eval()
        preds = []
        labels = []
        total_loss = 0
        test_loader = trainer.val_dataloaders[2]
        for cur_batch in test_loader:
            x, y = self.put_on_device(
                cur_batch, device, new_type)
            with torch.no_grad():
                representations = self.get_representations(features, x)
                logits = linear_head(representations)
                mlp_preds = torch.sigmoid(logits)
                preds.append(mlp_preds.cpu())
                labels.append(y.cpu())
                # compute the loss on the raw logits (mlp_preds are already probabilities)
                total_loss += F.binary_cross_entropy_with_logits(
                    logits, y).item()
        preds = torch.cat(preds).numpy()
        labels = torch.cat(labels).numpy()
        macro = roc_auc_score(labels, preds)
        return macro, total_loss
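
    # sklearn's roc_auc_score averages per-label AUCs with average="macro" by default,
    # hence the returned value is logged as "macro" in log_values below.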

    def log_values(self, trainer, pl_module, macro, total_loss):
        self.best_macro = macro if macro > self.best_macro else self.best_macro
        if self.mode == "linear_evaluation":
            log_key = "le"
        else:
            log_key = "ft"
        metrics = {log_key + '_mlp/loss': total_loss,
                   log_key + '_mlp/macro': macro,
                   log_key + '_mlp/best_macro': self.best_macro}
        pl_module.logger.log_metrics(metrics, step=trainer.global_step)

    def __str__(self):
        return self.mode + "_callback"
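

# Usage sketch (hypothetical, for illustration only): the module and dataloader names
# below are placeholders, and the exact trainer.fit signature depends on the
# pytorch_lightning version; the callback itself only relies on the interfaces noted above.
#
#     import pytorch_lightning as pl
#
#     model = MyPretrainingModule(...)  # must provide epoch, type(), get_device(), get_model()
#     online_eval = SSLOnlineEvaluator(num_classes=..., lin_eval_epochs=5,
#                                      eval_every=10, mode="linear_evaluation")
#     trainer = pl.Trainer(callbacks=[online_eval])
#     # val_dataloaders: [plain validation, head-training split, head-evaluation split]
#     trainer.fit(model, train_loader,
#                 val_dataloaders=[val_loader, head_train_loader, head_test_loader])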