Diff of /encoders.py [000000] .. [352cae]

# From the EsViT repo
from esvit.models import build_model
from esvit.config import config, update_config, save_config

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# From the EsViT repo. The full repo needs to be downloaded: https://github.com/microsoft/esvit
def load_encoder_esVIT(args, device):
    # ============ building network ... ============
    num_features = []
    # if the network is a 4-stage vision transformer (e.g. Swin)
    if 'swin' in args.arch:
        update_config(config, args)
        model = build_model(config, is_teacher=True)

        swin_spec = config.MODEL.SPEC
        embed_dim = swin_spec['DIM_EMBED']
        depths = swin_spec['DEPTHS']
        num_heads = swin_spec['NUM_HEADS']

        # Each stage i stacks depths[i] blocks (d below).
        # The blocks of the first stage work on embeddings of dimension embed_dim,
        # and the dimension doubles at each stage, so stage i uses embed_dim * 2**i.
        for i, d in enumerate(depths):
            num_features += [int(embed_dim * 2 ** i)] * d
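        # Example (assuming the swin_tiny spec: DIM_EMBED=96, DEPTHS=[2, 2, 6, 2]):
        # num_features = [96, 96, 192, 192] + [384] * 6 + [768, 768], i.e. one entry per block.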

    # if the network is a 4-stage vision transformer (e.g. ViL / Vision Longformer)
    elif 'vil' in args.arch:
        update_config(config, args)
        model = build_model(config, is_teacher=True)

        msvit_spec = config.MODEL.SPEC
        arch = msvit_spec.MSVIT.ARCH

        layer_cfgs = model.layer_cfgs
        num_stages = len(model.layer_cfgs)
        depths = [cfg['n'] for cfg in model.layer_cfgs]
        dims = [cfg['d'] for cfg in model.layer_cfgs]
        out_planes = model.layer_cfgs[-1]['d']
        Nglos = [cfg['g'] for cfg in model.layer_cfgs]

        print(dims)

        for i, d in enumerate(depths):
            num_features += [dims[i]] * d

    # if the network is a 4-stage vision transformer (e.g. CvT)
    elif 'cvt' in args.arch:
        update_config(config, args)
        model = build_model(config, is_teacher=True)

        cvt_spec = config.MODEL.SPEC
        embed_dim = cvt_spec['DIM_EMBED']
        depths = cvt_spec['DEPTH']
        num_heads = cvt_spec['NUM_HEADS']

        print(f'embed_dim {embed_dim} depths {depths}')

        # In CvT, DIM_EMBED and DEPTH are per-stage lists, so the embedding dimension is indexed per stage.
        for i, d in enumerate(depths):
            num_features += [int(embed_dim[i])] * int(d)
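        # Example (assuming a CvT-13-style spec: DIM_EMBED=[64, 192, 384], DEPTH=[1, 2, 10]):
        # num_features = [64] + [192] * 2 + [384] * 10.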

    # vanilla vision transformers (e.g. deit_tiny, deit_small, vit_base) are not handled here
    else:
        raise ValueError(f'{args.arch} not supported yet.')

    model.to(device)

    # load weights to evaluate
    state_dict = torch.load(args.checkpoint, map_location=device)
    # Technically we could also load the student's weights, but in knowledge distillation it is
    # more common to take the teacher, and the DINO paper shows that the teacher learns better.
    state_dict = state_dict['teacher']
    # The line below was initially in the code, but I think it is useless in our case (swin-t).
    #state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

    # The trained model probably has the dense DINO head while the model built here has a regular head, so those keys won't match:
    # IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['head_dense.mlp.0.weight', 'head_dense.mlp.0.bias', 'head_dense.mlp.2.weight', 'head_dense.mlp.2.bias', 'head_dense.mlp.4.weight', 'head_dense.mlp.4.bias', 'head_dense.last_layer.weight_g', 'head_dense.last_layer.weight_v', 'head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
    # In any case, we do not use the heads but the output features of each stage.
    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)
    model.eval()
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built with pretrained weights {args.checkpoint}.")

    # n_last_blocks is a choice: 4, for instance, takes features from the last 2 stages (for swin-t).
    # If n > 1, the block features are simply stacked (concatenated).
    # The paper says: "For all transformer architectures, we use the concatenation of view-level features
    # in the last layers (results are similar to the use of 3 or 5 layers in our initial experiments)."

    num_features_linear = sum(num_features[-args.n_last_blocks:])
    print(f'num_features_linear {num_features_linear}')
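    # Example (assuming swin-t with n_last_blocks=4): the last 4 entries of num_features are
    # [384, 384, 768, 768], so num_features_linear = 2304.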

    return model, num_features_linear, depths
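
# Minimal usage sketch (not part of the original code): how load_encoder_esVIT is expected to
# be called. The fields below mirror what the function reads (arch, checkpoint, patch_size,
# n_last_blocks); cfg/opts are assumed to be whatever esvit.config.update_config needs, and all
# concrete values/paths are illustrative placeholders.
def _example_load_esvit_encoder():
    from argparse import Namespace
    args = Namespace(
        arch='swin_tiny',                           # assumed architecture name
        cfg='path/to/swin_tiny_config.yaml',        # assumed EsViT yaml for update_config
        opts=[],                                    # assumed extra config overrides
        checkpoint='path/to/esvit_checkpoint.pth',  # assumed pretrained checkpoint
        patch_size=4,
        n_last_blocks=4,
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, num_features_linear, depths = load_encoder_esVIT(args, device)
    print(num_features_linear, depths)
    return model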

# Regular resnet encoder.
def load_encoder_resnet(backbone, checkpoint_file, use_imagenet_weights, device):
    import torch.nn as nn
    import torchvision.models as models

    class DecapitatedResnet(nn.Module):
        def __init__(self, base_encoder, pretrained):
            super(DecapitatedResnet, self).__init__()
            self.encoder = base_encoder(pretrained=pretrained)

        def forward(self, x):
            # Same forward pass function as used in the torchvision 'stock' ResNet code
            # but with the final FC layer removed.
            x = self.encoder.conv1(x)
            x = self.encoder.bn1(x)
            x = self.encoder.relu(x)
            x = self.encoder.maxpool(x)

            x = self.encoder.layer1(x)
            x = self.encoder.layer2(x)
            x = self.encoder.layer3(x)
            x = self.encoder.layer4(x)

            x = self.encoder.avgpool(x)
            x = torch.flatten(x, 1)

            return x

    model = DecapitatedResnet(models.__dict__[backbone], use_imagenet_weights)

    if use_imagenet_weights:
        if checkpoint_file is not None:
            raise Exception(
                "Either provide a weights checkpoint or the --imagenet flag, not both."
            )
        print("Created encoder with ImageNet weights")
    else:
        checkpoint = torch.load(checkpoint_file, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith("module.encoder_q") and not k.startswith(
                "module.encoder_q.fc"
            ):
                # remove prefix from key names
                state_dict[k[len("module.encoder_q.") :]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]

        # Verify that the checkpoint did not contain data for the final FC layer
        msg = model.encoder.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        print(f"Loaded checkpoint {checkpoint_file}")

    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.eval()

    return model
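
# Minimal usage sketch (not part of the original code): load an ImageNet-pretrained ResNet-18
# with the helper above and embed a dummy batch. The 512-d output is a property of resnet18's
# pooled features; the batch/input size below is an illustrative assumption.
def _example_extract_features_resnet():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_encoder_resnet('resnet18', checkpoint_file=None,
                                use_imagenet_weights=True, device=device)
    dummy = torch.randn(2, 3, 224, 224, device=device)
    with torch.no_grad():
        feats = model(dummy)
    print(feats.shape)  # expected: torch.Size([2, 512])
    return feats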