Diff of /online_evaluator.py [000000] .. [134fd7]

--- a
+++ b/online_evaluator.py
@@ -0,0 +1,236 @@
+import math
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.metrics.functional import accuracy
+from torch.nn import functional as F
+from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
+from sklearn.metrics import roc_auc_score
+from sklearn.preprocessing import normalize
+from torch.nn.modules.linear import Linear
+from copy import deepcopy
+from clinical_ts.create_logger import create_logger
+from tqdm import tqdm
+
+logger = create_logger(__name__)
+
+
+class SSLOnlineEvaluator(pl.Callback):  # pragma: no cover
+
+    def __init__(self, drop_p: float = 0.0, hidden_dim: int = 1024, z_dim: int = None, num_classes: int = None, lin_eval_epochs=5, eval_every=10, mode="linear_evaluation", discriminative=True, verbose=False):
+        """
+        Attaches an MLP head for online linear evaluation or fine-tuning, following the standard self-supervised evaluation protocol.
+        Example::
+            # the LightningModule is expected to provide get_model(), get_device() and an `epoch` attribute
+            online_evaluator = SSLOnlineEvaluator(z_dim=..., num_classes=...)
+            trainer = pl.Trainer(callbacks=[online_evaluator])
+        Args:
+            drop_p: (0.0) dropout probability
+            hidden_dim: (1024) the hidden dimension for the finetune MLP
+            z_dim: dimension of the representation produced by the encoder
+            num_classes: number of target classes for the evaluation head
+            lin_eval_epochs: (5) number of epochs the head is trained per evaluation
+            eval_every: (10) run the online evaluation every this many training epochs
+            mode: "linear_evaluation" (frozen encoder) or "fine_tuning" (encoder is updated)
+            discriminative: (True) use layer-wise learning rates in fine-tuning mode
+            verbose: (False) log per-epoch training loss and macro AUC during evaluation
+        """
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.drop_p = drop_p
+        self.optimizer = None
+        self.z_dim = z_dim
+        self.num_classes = num_classes
+        self.macro = 0
+        self.best_macro = 0
+        self.lin_eval_epochs = lin_eval_epochs
+        self.eval_every = eval_every
+        self.discriminative = discriminative
+        self.verbose = verbose
+        if mode == "linear_evaluation":
+            self.mode = mode
+        elif mode == "fine_tuning":
+            self.mode = mode
+        else:
+            raise ValueError("mode " + str(mode) + " unknown")
+
+    def get_representations(self, features, x):
+        """
+        Computes flattened representations of x with the given encoder.
+        Override this to customize for the particular model.
+        Args:
+            features: the (copied) feature encoder
+            x: input batch; if given as a list [x, ...] of length 2, only the first element is used
+        """
+        if isinstance(x, list) and len(x) == 2:
+            x = x[0]
+
+        representations = features(x)
+
+        if isinstance(representations, (list, tuple)):
+            representations = representations[0]
+
+        representations = representations.reshape(representations.size(0), -1)
+        return representations
+
+    def to_device(self, batch, device):
+        # note: this only unpacks the batch; put_on_device below actually moves tensors to the device
+        x, y = batch
+        return x, y
+
+    def put_on_device(self, batch, device, new_type):
+        x, y = batch
+        x = x.type(new_type).to(device)
+        y = y.type(new_type).to(device)
+        return x, y
+
+    def on_sanity_check_start(self, trainer, pl_module):
+        self.val_ds_size = len(trainer.val_dataloaders[0].dataset)
+        self.last_batch_id = len(trainer.val_dataloaders[0])-1
+
+    def on_sanity_check_end(self, trainer, pl_module):
+        self.macro = 0
+
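+    # Online evaluation protocol: every `eval_every` epochs (triggered on the first
+    # batch of the first validation dataloader) the encoder is copied, a fresh linear
+    # head is trained for `lin_eval_epochs` epochs on trainer.val_dataloaders[1], and
+    # macro AUC / loss are computed on trainer.val_dataloaders[2] and logged.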
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        # re-create the linear head every `eval_every` epochs so each online evaluation starts from fresh weights
+        if pl_module.epoch % self.eval_every == 0 and batch_idx == 0 and dataloader_idx == 0:
+            new_type, device, valid_loader, features, linear_head, optimizer = self.online_train_setup(
+                pl_module, trainer)
+
+            loss_per_epoch = []
+            macro_per_epoch = []
+            for epoch in tqdm(range(self.lin_eval_epochs)):
+
+                total_loss_one_epoch, linear_head = self.train_one_epoch(
+                    valid_loader, features, linear_head, optimizer, device, new_type)
+                
+                if self.verbose:
+                    loss_per_epoch.append(total_loss_one_epoch)
+                    macro, total_loss = self.eval_model(
+                        trainer, features, linear_head, device, new_type)
+                    macro_per_epoch.append(macro)
+                    logger.info("macro at epoch "+str(epoch) + ": " + str(macro))
+                    logger.info("train loss at epoch "+str(epoch) + ": " + str(total_loss_one_epoch))
+                    logger.info("test loss at epoch "+str(epoch) + ": " + str(total_loss))
+
+            macro, total_loss = self.eval_model(trainer, features, linear_head, device, new_type)
+            self.log_values(trainer, pl_module, macro, total_loss)
+           
+    def online_train_setup(self, pl_module, trainer):
+        new_type = pl_module.type()
+        device = pl_module.get_device()
+        valid_loader = trainer.val_dataloaders[1]
+        if self.mode == "linear_evaluation":
+            lr = 8e-3 *(valid_loader.batch_size/256)
+        else: 
+            lr = 8e-5 *(valid_loader.batch_size/256)
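+        # worked example of the linear lr scaling above (relative to a reference batch
+        # size of 256): batch_size=128 gives 8e-3 * 0.5 = 4e-3 for linear evaluation
+        # and 8e-5 * 0.5 = 4e-5 for fine-tuning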
+        logger.debug("using lr: " + str(lr))
+        logger.debug("using batch size: " + str(valid_loader.batch_size))
+        wd = 1e-1
+        features = deepcopy(pl_module.get_model())
+        linear_head = Linear(
+            features.l1.in_features, self.num_classes, bias=True).type(new_type)
+        if self.mode == "linear_evaluation":
+            optimizer = torch.optim.AdamW(
+                linear_head.parameters(), lr=lr, weight_decay=wd)
+        else:
+            if not self.discriminative:
+                optimizer = torch.optim.AdamW([
+                    {"params": features.parameters()}, {"params": linear_head.parameters()}], lr=lr, weight_decay=wd)
+            else:
+                lr = (8e-3*(valid_loader.batch_size/256))
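+                # discriminative (layer-wise) learning rates: the linear head trains at
+                # the base lr, feature layers are grouped in pairs from the last layer
+                # down, and each successive group gets its lr divided by 4
+                # (lr/4, lr/16, ...), so the earliest layers change the least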
+                param_dict = dict(features.named_parameters())
+                keys = param_dict.keys()
+                weight_layer_nrs = set()
+                for key in keys:
+                    if "features" in key:
+                        # parameter names have the form "features.<n>."; key[9] is the single
+                        # character right after "features.", i.e. the (single-digit) layer number
+                        weight_layer_nrs.add(key[9])
+                weight_layer_nrs = sorted(weight_layer_nrs, reverse=True)
+                features_groups = []
+                while len(weight_layer_nrs) > 0:
+                    if len(weight_layer_nrs) > 1:
+                        features_groups.append(list(filter(
+                            lambda x: "features." + weight_layer_nrs[0] in x or "features." + weight_layer_nrs[1] in x,  keys)))
+                        del weight_layer_nrs[:2]
+                    else:
+                        features_groups.append(
+                            list(filter(lambda x: "features." + weight_layer_nrs[0] in x,  keys)))
+                        del weight_layer_nrs[0]
+                # linears = list(filter(lambda x: "l" in x, keys)) # filter linear layers
+                # groups = [linears] + features_groups
+                optimizer_param_list = []
+                tmp_lr = lr
+                optimizer_param_list.append(
+                        {"params": linear_head.parameters(), "lr": tmp_lr})
+                tmp_lr /= 4
+                for layers in features_groups:
+                    layer_params = [param_dict[param_name]
+                                    for param_name in layers]
+                    optimizer_param_list.append(
+                        {"params": layer_params, "lr": tmp_lr})
+                    tmp_lr /= 4
+                optimizer = torch.optim.AdamW(optimizer_param_list, lr=lr, weight_decay=wd)
+            
+        return new_type, device, valid_loader, features, linear_head, optimizer
+
+    def train_one_epoch(self, valid_loader, features, linear_head, optimizer, device, new_type):
+        linear_head.train()
+        if self.mode == "linear_evaluation":
+            # we don't want to update things like batch norm statistics during linear evaluation
+            features.eval()
+        else:
+            features.train()
+        total_loss_one_epoch = 0
+        for cur_batch in valid_loader:
+            x, y = self.put_on_device(
+                cur_batch, device, new_type)
+            if self.mode == "linear_evaluation":
+                with torch.no_grad():
+                    representations = self.get_representations(
+                        features, x)
+            else:
+                with torch.enable_grad():
+                    representations = self.get_representations(
+                        features, x)
+            # forward pass
+            with torch.enable_grad():
+                mlp_preds = linear_head(representations)
+                mlp_loss = F.binary_cross_entropy_with_logits(
+                    mlp_preds, y)
+                # update finetune weights
+                optimizer.zero_grad()
+                mlp_loss.backward()
+                optimizer.step()
+            total_loss_one_epoch += mlp_loss.item()
+        return total_loss_one_epoch, linear_head
+
+    def eval_model(self, trainer, features, linear_head, device, new_type):
+        features.eval()
+        linear_head.eval()
+        preds = []
+        labels = []
+        total_loss = 0
+        test_loader = trainer.val_dataloaders[2]
+        for cur_batch in test_loader:
+            x, y = self.put_on_device(
+                cur_batch, device, new_type)
+            with torch.no_grad():
+                representations = self.get_representations(features, x)
+                logits = linear_head(representations)
+                mlp_preds = torch.sigmoid(logits)
+                preds.append(mlp_preds.cpu())
+                labels.append(y.cpu())
+                # compute the loss on the raw logits, not on the sigmoid outputs
+                total_loss += F.binary_cross_entropy_with_logits(
+                    logits, y).item()
+        preds = torch.cat(preds).numpy()
+        labels = torch.cat(labels).numpy()
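+        # label-wise macro-averaged AUC (sklearn's default average="macro" over the
+        # multi-label indicator matrix)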
+        macro = roc_auc_score(labels, preds)
+        return macro, total_loss
+
+    def log_values(self, trainer, pl_module, macro, total_loss):
+        self.best_macro = max(macro, self.best_macro)
+        if self.mode == "linear_evaluation":
+            log_key = "le"
+        else:
+            log_key = "ft"
+        metrics = {log_key + '_mlp/loss': total_loss,
+                   log_key + '_mlp/macro': macro, log_key + '_mlp/best_macro': self.best_macro}
+        pl_module.logger.log_metrics(metrics, step=trainer.global_step)
+
+    def __str__(self):
+        return self.mode+"_callback"
\ No newline at end of file