# encoders.py
# From EsVIT repo
from esvit.models import build_model
from esvit.config import config, update_config, save_config

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# From EsVIT repo; requires the full repo: https://github.com/microsoft/esvit
def load_encoder_esVIT(args, device):
    # ============ building network ... ============
    num_features = []

    # if the network is a 4-stage vision transformer (i.e. swin)
    if 'swin' in args.arch:
        update_config(config, args)
        model = build_model(config, is_teacher=True)

        swin_spec = config.MODEL.SPEC
        embed_dim = swin_spec['DIM_EMBED']
        depths = swin_spec['DEPTHS']
        num_heads = swin_spec['NUM_HEADS']

        # Each stage stacks d blocks, and every block in stage i outputs embeddings
        # of dimension embed_dim * 2**i: the embedding dimension starts at embed_dim
        # and doubles from one stage to the next.
        for i, d in enumerate(depths):
            num_features += [int(embed_dim * 2 ** i)] * d

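        # Worked example (illustrative, assuming the standard Swin-T spec used by EsViT,
        # i.e. embed_dim = 96 and depths = [2, 2, 6, 2]; check your config if it differs):
        #   num_features == [96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 768, 768]
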
    # if the network is a 4-stage vision transformer (i.e. longformer)
    elif 'vil' in args.arch:
        update_config(config, args)
        model = build_model(config, is_teacher=True)

        msvit_spec = config.MODEL.SPEC
        arch = msvit_spec.MSVIT.ARCH

        layer_cfgs = model.layer_cfgs
        num_stages = len(model.layer_cfgs)
        depths = [cfg['n'] for cfg in model.layer_cfgs]
        dims = [cfg['d'] for cfg in model.layer_cfgs]
        out_planes = model.layer_cfgs[-1]['d']
        Nglos = [cfg['g'] for cfg in model.layer_cfgs]

        print(dims)

        for i, d in enumerate(depths):
            num_features += [dims[i]] * d

    # if the network is a multi-stage vision transformer (i.e. CvT)
    elif 'cvt' in args.arch:
        update_config(config, args)
        model = build_model(config, is_teacher=True)

        cvt_spec = config.MODEL.SPEC
        embed_dim = cvt_spec['DIM_EMBED']
        depths = cvt_spec['DEPTH']
        num_heads = cvt_spec['NUM_HEADS']

        print(f'embed_dim {embed_dim} depths {depths}')

        # here embed_dim and depths are per-stage lists
        for i, d in enumerate(depths):
            num_features += [int(embed_dim[i])] * int(d)

    # if the network is a vanilla vision transformer (i.e. deit_tiny, deit_small, vit_base)
    else:
        raise ValueError(f'{args.arch} not supported yet.')

    model.to(device)

    # load weights to evaluate
    state_dict = torch.load(args.checkpoint, map_location=device)
    # We could also load the student weights, but in knowledge distillation it is more common
    # to take the teacher, and the DINO paper shows that the teacher learns better.
    state_dict = state_dict['teacher']
    # The line below was initially in the code, but it seems unnecessary in our case (Swin-T):
    # state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

    # The trained checkpoint contains the dense DINO head while the freshly built model has a
    # regular head, so those keys will not match:
    # IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['head_dense.mlp.0.weight', 'head_dense.mlp.0.bias', 'head_dense.mlp.2.weight', 'head_dense.mlp.2.bias', 'head_dense.mlp.4.weight', 'head_dense.mlp.4.bias', 'head_dense.last_layer.weight_g', 'head_dense.last_layer.weight_v', 'head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])
    # In any case, we do not use the heads but the output features of each stage.
    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)
    model.eval()
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built with pretrained weights {args.checkpoint}.")

    # n_last_blocks is a choice: 4 takes blocks from the last two stages of Swin-T, for instance.
    # If n_last_blocks > 1, the selected block features are simply concatenated.
    # The paper says: "For all transformer architectures, we use the concatenation of view-level
    # features in the last layers (results are similar to the use of 3 or 5 layers in our initial experiments)."
    num_features_linear = sum(num_features[-args.n_last_blocks:])
    print(f'num_features_linear {num_features_linear}')

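    # Worked example (illustrative, assuming the Swin-T num_features list above and
    # n_last_blocks = 4): sum([384, 384, 768, 768]) = 2304, so the linear probe
    # would take 2304-dimensional inputs.
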
    return model, num_features_linear, depths

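# A minimal usage sketch for load_encoder_esVIT, assuming an EsViT Swin-T checkpoint. The
# config path, checkpoint path, and the exact Namespace fields expected by update_config
# (here assumed to be args.cfg and args.opts) are assumptions; check esvit/config.py and
# adapt to your setup.
#
#   from argparse import Namespace
#
#   args = Namespace(
#       arch='swin_tiny',                                   # picks the 'swin' branch above
#       cfg='experiments/imagenet/swin/swin_tiny_patch4_window7_224.yaml',  # hypothetical path
#       opts=[],                                            # extra config overrides, if any
#       checkpoint='checkpoints/esvit_swin_tiny.pth',       # hypothetical path
#       patch_size=4,
#       n_last_blocks=4,
#   )
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model, num_features_linear, depths = load_encoder_esVIT(args, device)
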
# Regular ResNet encoder.
def load_encoder_resnet(backbone, checkpoint_file, use_imagenet_weights, device):
    import torch.nn as nn
    import torchvision.models as models

    class DecapitatedResnet(nn.Module):
        def __init__(self, base_encoder, pretrained):
            super(DecapitatedResnet, self).__init__()
            self.encoder = base_encoder(pretrained=pretrained)

        def forward(self, x):
            # Same forward pass as the torchvision 'stock' ResNet code,
            # but with the final FC layer removed.
            x = self.encoder.conv1(x)
            x = self.encoder.bn1(x)
            x = self.encoder.relu(x)
            x = self.encoder.maxpool(x)

            x = self.encoder.layer1(x)
            x = self.encoder.layer2(x)
            x = self.encoder.layer3(x)
            x = self.encoder.layer4(x)

            x = self.encoder.avgpool(x)
            x = torch.flatten(x, 1)

            return x

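    # After avgpool and flatten, the feature dimension depends on the backbone:
    # 512 for resnet18/34, 2048 for resnet50/101/152.
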
    model = DecapitatedResnet(models.__dict__[backbone], use_imagenet_weights)

    if use_imagenet_weights:
        if checkpoint_file is not None:
            raise Exception(
                "Either provide a weights checkpoint or the --imagenet flag, not both."
            )
        print("Created encoder with ImageNet weights")
    else:
        checkpoint = torch.load(checkpoint_file, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        for k in list(state_dict.keys()):
            # retain only encoder_q weights, up to before the embedding (fc) layer
            if k.startswith("module.encoder_q") and not k.startswith(
                "module.encoder_q.fc"
            ):
                # remove the prefix from the key names
                state_dict[k[len("module.encoder_q.") :]] = state_dict[k]
            # delete the renamed or unused key
            del state_dict[k]

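        # For example (illustrative MoCo-style keys): "module.encoder_q.layer1.0.conv1.weight"
        # becomes "layer1.0.conv1.weight", while "module.encoder_q.fc.*" and any
        # "module.encoder_k.*" entries are simply dropped.
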
        # Verify that the checkpoint did not contain data for the final FC layer
        msg = model.encoder.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        print(f"Loaded checkpoint {checkpoint_file}")

    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    model.eval()

    return model
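

if __name__ == "__main__":
    # Minimal usage sketch for load_encoder_resnet, assuming a torchvision backbone name and
    # ImageNet weights (a MoCo-style checkpoint would be passed via checkpoint_file instead).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = load_encoder_resnet("resnet50", checkpoint_file=None,
                                  use_imagenet_weights=True, device=device)
    with torch.no_grad():
        features = encoder(torch.randn(1, 3, 224, 224).to(device))
    print(features.shape)  # torch.Size([1, 2048]) for resnet50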