Diff of /encoders.py [000000] .. [352cae]

--- a
+++ b/encoders.py
@@ -0,0 +1,157 @@
+# From the EsViT repo
+from esvit.models import build_model
+from esvit.config import config, update_config, save_config
+
+import torch
+
+# Adapted from the EsViT repo; the full repo must be downloaded and importable: https://github.com/microsoft/esvit
+def load_encoder_esVIT(args, device):
+    # ============ building network ... ============
+    num_features = []
+    # if the network is a 4-stage vision transformer (e.g. Swin)
+    if 'swin' in args.arch:
+        update_config(config, args)
+        model = build_model(config, is_teacher=True)
+
+        swin_spec = config.MODEL.SPEC
+        embed_dim = swin_spec['DIM_EMBED']
+        depths = swin_spec['DEPTHS']
+        num_heads = swin_spec['NUM_HEADS']
+
+        # Each stage i stacks depths[i] transformer blocks.
+        # The feature dimension starts at embed_dim for stage 0 and doubles at every stage,
+        # so the blocks of stage i output features of dimension embed_dim * 2**i.
+        for i, d in enumerate(depths):
+            num_features += [int(embed_dim * 2 ** i)] * d 
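+        # Illustration (not from the original code): for a swin-tiny spec with
+        # DIM_EMBED=96 and DEPTHS=[2, 2, 6, 2], this yields
+        # num_features = [96, 96, 192, 192] + [384] * 6 + [768, 768].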
+
+    # if the network is a 4-stage vision transformer (e.g. ViL, the Vision Longformer)
+    elif 'vil' in args.arch:
+        update_config(config, args)
+        model = build_model(config, is_teacher=True)
+
+        msvit_spec = config.MODEL.SPEC
+        arch = msvit_spec.MSVIT.ARCH
+
+        layer_cfgs = model.layer_cfgs
+        num_stages = len(model.layer_cfgs)
+        depths = [cfg['n'] for cfg in model.layer_cfgs]
+        dims = [cfg['d'] for cfg in model.layer_cfgs]
+        out_planes = model.layer_cfgs[-1]['d']
+        Nglos = [cfg['g'] for cfg in model.layer_cfgs]
+
+        print(dims)
+
+        for i, d in enumerate(depths):
+            num_features += [dims[i]] * d
+
+    # if the network is a multi-stage convolutional vision transformer (e.g. CvT)
+    elif 'cvt' in args.arch:
+        update_config(config, args)
+        model = build_model(config, is_teacher=True)
+
+        cvt_spec = config.MODEL.SPEC
+        embed_dim = cvt_spec['DIM_EMBED']
+        depths = cvt_spec['DEPTH']
+        num_heads = cvt_spec['NUM_HEADS']
+
+        print(f'embed_dim {embed_dim} depths {depths}')
+
+        for i, d in enumerate(depths):
+            num_features += [int(embed_dim[i])] * int(d) 
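+        # Illustration (not from the original code): for a CvT-13 spec with
+        # DIM_EMBED=[64, 192, 384] and DEPTH=[1, 2, 10], this yields
+        # num_features = [64, 192, 192] + [384] * 10.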
+
+    # vanilla vision transformers (e.g. deit_tiny, deit_small, vit_base) are not supported here
+    else:
+        raise ValueError(f'{args.arch} not supported yet.')
+
+    model.to(device)
+
+    # load weights to evaluate
+    state_dict = torch.load(args.checkpoint, map_location=device)
+    # We could also load the student weights, but in knowledge distillation it is more common to take the teacher,
+    # and the DINO paper shows that the teacher performs better.
+    state_dict = state_dict['teacher']
+    # The line below was initially in the code, but it is not needed in our case (swin-t).
+    #state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+    
+    # The trained checkpoint contains the dense DINO head while the model built here has a regular head, so those keys will not match:
+    #IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['head_dense.mlp.0.weight', 'head_dense.mlp.0.bias', 'head_dense.mlp.2.weight', 'head_dense.mlp.2.bias', 'head_dense.mlp.4.weight', 'head_dense.mlp.4.bias', 'head_dense.last_layer.weight_g', 'head_dense.last_layer.weight_v', 'head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
+    # In any case, we do not use the heads but the output features of each stage.
+    msg = model.load_state_dict(state_dict, strict=False)
+    print(msg)
+    model.eval()
+    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built with pretrained weigths {args.checkpoint}.")
+    
+    # args.n_last_blocks is a choice: e.g. 4 takes the features of the last 4 blocks (which span the last 2 stages for swin-t).
+    # If n_last_blocks > 1, the per-block features are simply concatenated.
+    # The paper says: "For all transformer architectures, we use the concatenation of view-level features
+    # in the last layers (results are similar to the use of 3 or 5 layers in our initial experiments)."
+    
+    num_features_linear = sum(num_features[-args.n_last_blocks:])
+    print(f'num_features_linear {num_features_linear}')
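+    # Illustration (not from the original code): with the swin-tiny features above and
+    # args.n_last_blocks = 4, num_features_linear = 384 + 384 + 768 + 768 = 2304.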
+
+    return model, num_features_linear, depths
+
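+# Minimal usage sketch (an illustration, not part of the original file). `args` is assumed to be
+# an argparse namespace carrying at least the fields this function reads (arch, checkpoint,
+# patch_size, n_last_blocks) plus whatever EsViT's update_config() expects (typically a YAML
+# config path and optional overrides):
+#
+#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+#   model, num_features_linear, depths = load_encoder_esVIT(args, device)
+#   # `model` is the frozen teacher backbone; `num_features_linear` is the input size of a
+#   # linear probe built on the concatenated features of the last `args.n_last_blocks` blocks.
+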
+# Regular ResNet encoder (headless), loaded either with ImageNet weights or from a MoCo-style checkpoint.
+def load_encoder_resnet(backbone, checkpoint_file, use_imagenet_weights, device):
+    import torch.nn as nn
+    import torchvision.models as models
+
+    class DecapitatedResnet(nn.Module):
+        def __init__(self, base_encoder, pretrained):
+            super(DecapitatedResnet, self).__init__()
+            self.encoder = base_encoder(pretrained=pretrained)
+
+        def forward(self, x):
+            # Same forward pass function as used in the torchvision 'stock' ResNet code
+            # but with the final FC layer removed.
+            x = self.encoder.conv1(x)
+            x = self.encoder.bn1(x)
+            x = self.encoder.relu(x)
+            x = self.encoder.maxpool(x)
+
+            x = self.encoder.layer1(x)
+            x = self.encoder.layer2(x)
+            x = self.encoder.layer3(x)
+            x = self.encoder.layer4(x)
+
+            x = self.encoder.avgpool(x)
+            x = torch.flatten(x, 1)
+
+            return x
+
+    model = DecapitatedResnet(models.__dict__[backbone], use_imagenet_weights)
+
+    if use_imagenet_weights:
+        if checkpoint_file is not None:
+            raise Exception(
+                "Either provide a weights checkpoint or the --imagenet flag, not both."
+            )
+        print(f"Created encoder with Imagenet weights")
+    else:
+        checkpoint = torch.load(checkpoint_file, map_location="cpu")
+        state_dict = checkpoint["state_dict"]
+        for k in list(state_dict.keys()):
+            # retain only encoder_q up to before the embedding layer
+            if k.startswith("module.encoder_q") and not k.startswith(
+                "module.encoder_q.fc"
+            ):
+                # remove prefix from key names
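+                # e.g. 'module.encoder_q.layer1.0.conv1.weight' -> 'layer1.0.conv1.weight'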
+                state_dict[k[len("module.encoder_q.") :]] = state_dict[k]
+            # delete renamed or unused k
+            del state_dict[k]
+
+        # Verify that the checkpoint did not contain data for the final FC layer
+        msg = model.encoder.load_state_dict(state_dict, strict=False)
+        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
+        print(f"Loaded checkpoint {checkpoint_file}")
+
+    model = model.to(device)
+    if torch.cuda.device_count() > 1:
+        model = torch.nn.DataParallel(model)
+    model.eval()
+
+    return model
+
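+# Minimal usage sketch (an illustration, not part of the original file; the checkpoint path is
+# hypothetical and must point to a MoCo-style checkpoint whose 'state_dict' uses the
+# 'module.encoder_q.' prefix):
+#
+#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+#   model = load_encoder_resnet('resnet50', 'moco_checkpoint.pth.tar',
+#                               use_imagenet_weights=False, device=device)
+#   with torch.no_grad():
+#       feats = model(images)  # shape (batch_size, 2048) for resnet50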