Switch to unified view

a b/tools/model_converters/vit2mmseg.py
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import argparse
3
import os.path as osp
4
from collections import OrderedDict
5
6
import mmcv
7
import torch
8
from mmcv.runner import CheckpointLoader
9
10
11
def convert_vit(ckpt):
12
13
    new_ckpt = OrderedDict()
14
15
    for k, v in ckpt.items():
16
        if k.startswith('head'):
17
            continue
18
        if k.startswith('norm'):
19
            new_k = k.replace('norm.', 'ln1.')
20
        elif k.startswith('patch_embed'):
21
            if 'proj' in k:
22
                new_k = k.replace('proj', 'projection')
23
            else:
24
                new_k = k
25
        elif k.startswith('blocks'):
26
            if 'norm' in k:
27
                new_k = k.replace('norm', 'ln')
28
            elif 'mlp.fc1' in k:
29
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
30
            elif 'mlp.fc2' in k:
31
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
32
            elif 'attn.qkv' in k:
33
                new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
34
            elif 'attn.proj' in k:
35
                new_k = k.replace('attn.proj', 'attn.attn.out_proj')
36
            else:
37
                new_k = k
38
            new_k = new_k.replace('blocks.', 'layers.')
39
        else:
40
            new_k = k
41
        new_ckpt[new_k] = v
42
43
    return new_ckpt
44
45
46
def main():
47
    parser = argparse.ArgumentParser(
48
        description='Convert keys in timm pretrained vit models to '
49
        'MMSegmentation style.')
50
    parser.add_argument('src', help='src model path or url')
51
    # The dst path must be a full path of the new checkpoint.
52
    parser.add_argument('dst', help='save path')
53
    args = parser.parse_args()
54
55
    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
56
    if 'state_dict' in checkpoint:
57
        # timm checkpoint
58
        state_dict = checkpoint['state_dict']
59
    elif 'model' in checkpoint:
60
        # deit checkpoint
61
        state_dict = checkpoint['model']
62
    else:
63
        state_dict = checkpoint
64
    weight = convert_vit(state_dict)
65
    mmcv.mkdir_or_exist(osp.dirname(args.dst))
66
    torch.save(weight, args.dst)
67
68
69
if __name__ == '__main__':
70
    main()