Diff of /online_evaluator.py [000000] .. [134fd7]

import math
import pdb
import pytorch_lightning as pl
import torch
from pytorch_lightning.metrics.functional import accuracy
from torch.nn import functional as F
from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import normalize
from torch.nn.modules.linear import Linear
from copy import deepcopy
from clinical_ts.create_logger import create_logger
from tqdm import tqdm

logger = create_logger(__name__)


class SSLOnlineEvaluator(pl.Callback):  # pragma: no-cover

    def __init__(self, drop_p: float = 0.0, hidden_dim: int = 1024, z_dim: int = None, num_classes: int = None,
                 lin_eval_epochs=5, eval_every=10, mode="linear_evaluation", discriminative=True, verbose=False):
        """
        Attaches an MLP for fine-tuning using the standard self-supervised protocol.

        Example::

            from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator

            # your model must have 2 attributes
            model = Model()
            model.z_dim = ...  # the representation dim
            model.num_classes = ...  # the number of classes in the model

        Args:
            drop_p: (0.0) dropout probability
            hidden_dim: (1024) the hidden dimension for the finetune MLP
            z_dim: the representation dimension
            num_classes: the number of classes of the downstream task
            lin_eval_epochs: number of epochs the head is trained at each evaluation
            eval_every: run the online evaluation every `eval_every` epochs
            mode: either "linear_evaluation" (frozen encoder) or "fine_tuning"
            discriminative: use layer-wise decayed learning rates when fine-tuning
            verbose: log intermediate metrics after every head epoch
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.drop_p = drop_p
        self.optimizer = None
        self.z_dim = z_dim
        self.num_classes = num_classes
        self.macro = 0
        self.best_macro = 0
        self.lin_eval_epochs = lin_eval_epochs
        self.eval_every = eval_every
        self.discriminative = discriminative
        self.verbose = verbose
        if mode in ("linear_evaluation", "fine_tuning"):
            self.mode = mode
        else:
            raise ValueError("mode " + str(mode) + " unknown")

    def get_representations(self, features, x):
        """
        Override this to customize for the particular model.

        Args:
            features: the (copied) encoder used to compute the representations
            x: the input batch
        """
        if len(x) == 2 and isinstance(x, list):
            x = x[0]

        representations = features(x)

        if isinstance(representations, (list, tuple)):
            representations = representations[0]

        representations = representations.reshape(representations.size(0), -1)
        return representations

    def to_device(self, batch, device):
        # note: this variant does not move the batch; put_on_device below
        # handles dtype and device placement
        x, y = batch
        return x, y

    def put_on_device(self, batch, device, new_type):
        x, y = batch
        x = x.type(new_type).to(device)
        y = y.type(new_type).to(device)
        return x, y

    def on_sanity_check_start(self, trainer, pl_module):
        self.val_ds_size = len(trainer.val_dataloaders[0].dataset)
        self.last_batch_id = len(trainer.val_dataloaders[0]) - 1

    def on_sanity_check_end(self, trainer, pl_module):
        self.macro = 0

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # retrain the head from scratch every `eval_every` epochs to get fresh
        # linear-evaluation / fine-tuning values
        if pl_module.epoch % self.eval_every == 0 and batch_idx == 0 and dataloader_idx == 0:
            new_type, device, valid_loader, features, linear_head, optimizer = self.online_train_setup(
                pl_module, trainer)

            loss_per_epoch = []
            macro_per_epoch = []
            for epoch in tqdm(range(self.lin_eval_epochs)):

                total_loss_one_epoch, linear_head = self.train_one_epoch(
                    valid_loader, features, linear_head, optimizer, device, new_type)

                if self.verbose:
                    loss_per_epoch.append(total_loss_one_epoch)
                    macro, total_loss = self.eval_model(
                        trainer, features, linear_head, device, new_type)
                    macro_per_epoch.append(macro)
                    logger.info("macro at epoch " + str(epoch) + ": " + str(macro))
                    logger.info("train loss at epoch " + str(epoch) + ": " + str(total_loss_one_epoch))
                    logger.info("test loss at epoch " + str(epoch) + ": " + str(total_loss))

            macro, total_loss = self.eval_model(trainer, features, linear_head, device, new_type)
            self.log_values(trainer, pl_module, macro, total_loss)

    def online_train_setup(self, pl_module, trainer):
        """Copy the encoder, build a linear head and the matching optimizer."""
        new_type = pl_module.type()
        device = pl_module.get_device()
        valid_loader = trainer.val_dataloaders[1]
        if self.mode == "linear_evaluation":
            lr = 8e-3 * (valid_loader.batch_size / 256)
        else:
            lr = 8e-5 * (valid_loader.batch_size / 256)
        # print("using lr:", lr)
        # print("using batch size: ", valid_loader.batch_size)
        wd = 1e-1
        features = deepcopy(pl_module.get_model())
        linear_head = Linear(
            features.l1.in_features, self.num_classes, bias=True).type(new_type)
        if self.mode == "linear_evaluation":
            optimizer = torch.optim.AdamW(
                linear_head.parameters(), lr=lr, weight_decay=wd)
        else:
            if not self.discriminative:
                optimizer = torch.optim.AdamW([
                    {"params": features.parameters()}, {"params": linear_head.parameters()}], lr=lr, weight_decay=wd)
            else:
                # discriminative fine-tuning: layer-wise decayed learning rates
                lr = 8e-3 * (valid_loader.batch_size / 256)
                param_dict = dict(features.named_parameters())
                keys = param_dict.keys()
                weight_layer_nrs = set()
                for key in keys:
                    if "features" in key:
                        # parameter names have the form features.x
                        weight_layer_nrs.add(key[9])
                weight_layer_nrs = sorted(weight_layer_nrs, reverse=True)
                features_groups = []
                while len(weight_layer_nrs) > 0:
                    if len(weight_layer_nrs) > 1:
                        features_groups.append(list(filter(
                            lambda x: "features." + weight_layer_nrs[0] in x or "features." + weight_layer_nrs[1] in x, keys)))
                        del weight_layer_nrs[:2]
                    else:
                        features_groups.append(
                            list(filter(lambda x: "features." + weight_layer_nrs[0] in x, keys)))
                        del weight_layer_nrs[0]
                # linears = list(filter(lambda x: "l" in x, keys))  # filter linear layers
                # groups = [linears] + features_groups
                optimizer_param_list = []
                tmp_lr = lr
                optimizer_param_list.append(
                    {"params": linear_head.parameters(), "lr": tmp_lr})
                tmp_lr /= 4
                for layers in features_groups:
                    layer_params = [param_dict[param_name]
                                    for param_name in layers]
                    optimizer_param_list.append(
                        {"params": layer_params, "lr": tmp_lr})
                    tmp_lr /= 4
                optimizer = torch.optim.AdamW(optimizer_param_list, lr=lr, weight_decay=wd)

        return new_type, device, valid_loader, features, linear_head, optimizer
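
    # Example (hypothetical layer names): for an encoder whose parameters are
    # named features.0 ... features.5, the discriminative branch above yields
    # param groups with learning rates
    #   linear_head            -> lr
    #   features.5, features.4 -> lr / 4
    #   features.3, features.2 -> lr / 16
    #   features.1, features.0 -> lr / 64
    # i.e. the rate decays by a factor of 4 for every pair of layers, going
    # from the head towards the input.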

    def train_one_epoch(self, valid_loader, features, linear_head, optimizer, device, new_type):
        linear_head.train()
        if self.mode == "linear_evaluation":
            # we don't want to update things like batchnorm statistics during linear evaluation
            features.eval()
        else:
            features.train()
        total_loss_one_epoch = 0
        for cur_batch in valid_loader:
            x, y = self.put_on_device(
                cur_batch, device, new_type)
            if self.mode == "linear_evaluation":
                with torch.no_grad():
                    representations = self.get_representations(
                        features, x)
            else:
                with torch.enable_grad():
                    representations = self.get_representations(
                        features, x)
            # forward pass
            with torch.enable_grad():
                mlp_preds = linear_head(representations)
                mlp_loss = F.binary_cross_entropy_with_logits(
                    mlp_preds, y)
                # update finetune weights
                optimizer.zero_grad()
                mlp_loss.backward()
                optimizer.step()
            total_loss_one_epoch += mlp_loss.item()
        return total_loss_one_epoch, linear_head
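
    # Note: binary_cross_entropy_with_logits above implies a multi-label
    # setup -- the targets `y` delivered by the dataloaders are expected to be
    # float multi-hot tensors of shape (batch, num_classes), matching the
    # linear head's output.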

    def eval_model(self, trainer, features, linear_head, device, new_type):
        features.eval()
        preds = []
        labels = []
        total_loss = 0
        test_loader = trainer.val_dataloaders[2]
        for cur_batch in test_loader:
            x, y = self.put_on_device(
                cur_batch, device, new_type)
            with torch.no_grad():
                representations = self.get_representations(features, x)
                logits = linear_head(representations)
                mlp_preds = torch.sigmoid(logits)
                preds.append(mlp_preds.cpu())
                labels.append(y.cpu())
                # compute the loss on the raw logits, not on the sigmoid outputs
                total_loss += F.binary_cross_entropy_with_logits(
                    logits, y).item()
        preds = torch.cat(preds).numpy()
        labels = torch.cat(labels).numpy()
        # macro-averaged ROC AUC over all label columns
        macro = roc_auc_score(labels, preds)
        return macro, total_loss

    def log_values(self, trainer, pl_module, macro, total_loss):
        self.best_macro = max(macro, self.best_macro)
        if self.mode == "linear_evaluation":
            log_key = "le"
        else:
            log_key = "ft"
        metrics = {log_key + '_mlp/loss': total_loss,
                   log_key + '_mlp/macro': macro,
                   log_key + '_mlp/best_macro': self.best_macro}
        pl_module.logger.log_metrics(metrics, step=trainer.global_step)

    def __str__(self):
        return self.mode + "_callback"
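

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the callback itself; names such as
# MyPretrainingModule and the loaders are placeholders). The callback expects
# the LightningModule to expose get_model(), get_device(), type() and an
# `epoch` attribute, and it expects three validation dataloaders:
#   [0] the regular SSL validation set (only triggers the callback),
#   [1] a labelled set on which the linear head is trained / fine-tuned,
#   [2] a labelled test set on which the macro AUC is computed.
#
#   online_evaluator = SSLOnlineEvaluator(
#       num_classes=num_classes, lin_eval_epochs=5, eval_every=10,
#       mode="linear_evaluation")
#   trainer = pl.Trainer(callbacks=[online_evaluator])
#   trainer.fit(MyPretrainingModule(),
#               train_dataloader=pretrain_loader,
#               val_dataloaders=[ssl_val_loader, labelled_train_loader,
#                                labelled_test_loader])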