a b/.dev/md2yml.py
1
#!/usr/bin/env python
2
3
# Copyright (c) OpenMMLab. All rights reserved.
4
# This tool is used to update model-index.yml which is required by MIM, and
5
# will be automatically called as a pre-commit hook. The updating will be
6
# triggered if any change of model information (.md files in configs/) has been
7
# detected before a commit.
8
9
import glob
10
import os
11
import os.path as osp
12
import re
13
import sys
14
15
import mmcv
16
from lxml import etree
17
18
MMSEG_ROOT = osp.dirname(osp.dirname((osp.dirname(__file__))))
19
20
21
def dump_yaml_and_check_difference(obj, filename, sort_keys=False):
    """Dump object to a yaml file, and check if the file content is different
    from the original.

    Args:
        obj (any): The python object to be dumped.
        filename (str): YAML filename to dump the object to.
        sort_keys (bool): Whether to serialize dictionary keys in sorted
            order.

    Returns:
        bool: True if ``filename`` was (re)written because its content
        differed from what was on disk (or the file did not exist yet).
    """
    str_dump = mmcv.dump(obj, None, file_format='yaml', sort_keys=sort_keys)

    # Read the current on-disk content, if any, for comparison. ``None``
    # can never equal the dumped string, so a missing file counts as
    # "different" below.
    str_orig = None
    if osp.isfile(filename):
        with open(filename, 'r', encoding='utf-8') as f:
            str_orig = f.read()

    if str_orig == str_dump:
        return False

    # Content changed (or file absent): rewrite and report the difference.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(str_dump)
    return True
50
51
52
def parse_md(md_file):
    """Parse .md file and convert it to a .yml file which can be used for MIM.

    Walks a collection's README line by line, extracting the paper title,
    paper/code/repo URLs and every result-table row, then dumps the
    collected metadata next to the README as ``<collection>.yml``.

    Args:
        md_file (str): Path to .md file.
    Returns:
        Bool: If the target YAML file is different from the original.
    """
    # The collection is named after the directory containing the README,
    # e.g. configs/pspnet/README.md -> "pspnet".
    collection_name = osp.split(osp.dirname(md_file))[1]
    # All sibling files; config .py filenames are matched against table
    # cells further below.
    configs = os.listdir(osp.dirname(md_file))

    collection = dict(
        Name=collection_name,
        Metadata={'Training Data': []},
        Paper={
            'URL': '',
            'Title': ''
        },
        README=md_file,
        Code={
            'URL': '',
            'Version': ''
        })
    collection.update({'Converted From': {'Weights': '', 'Code': ''}})
    models = []
    datasets = []
    paper_url = None
    paper_title = None
    code_url = None
    code_version = None
    repo_url = None

    with open(md_file, 'r') as md:
        lines = md.readlines()
        i = 0
        current_dataset = ''
        # Manual index loop: table handling below consumes a variable
        # number of lines, so ``i`` is advanced explicitly per branch.
        while i < len(lines):
            line = lines[i].strip()
            if len(line) == 0:
                i += 1
                continue
            if line[:2] == '# ':
                # Top-level heading carries the paper title.
                paper_title = line.replace('# ', '')
                i += 1
            elif line[:3] == '<a ':
                # Inline <a> tags link either the code snippet or the
                # official repository; parse with lxml to read href/text.
                content = etree.HTML(line)
                node = content.xpath('//a')[0]
                if node.text == 'Code Snippet':
                    code_url = node.get('href', None)
                    assert code_url is not None, (
                        f'{collection_name} hasn\'t code snippet url.')
                    # version extraction
                    # e.g. .../blob/v0.17.0/mm... -> "v0.17.0"
                    filter_str = r'blob/(.*)/mm'
                    pattern = re.compile(filter_str)
                    code_version = pattern.findall(code_url)
                    assert len(code_version) == 1, (
                        f'false regular expression ({filter_str}) use.')
                    code_version = code_version[0]
                elif node.text == 'Official Repo':
                    repo_url = node.get('href', None)
                    assert repo_url is not None, (
                        f'{collection_name} hasn\'t official repo url.')
                i += 1
            elif line[:9] == '<summary ':
                # <summary> wraps a single link to the paper.
                content = etree.HTML(line)
                nodes = content.xpath('//a')
                assert len(nodes) == 1, (
                    'summary tag should only have single a tag.')
                paper_url = nodes[0].get('href', None)
                i += 1
            elif line[:4] == '### ':
                # "### <Dataset>" heading; rows of following tables are
                # attributed to this dataset.
                datasets.append(line[4:])
                current_dataset = line[4:]
                i += 2
            elif line[0] == '|' and (
                    i + 1) < len(lines) and lines[i + 1][:3] == '| -':
                # A markdown table: this line is the header, the next is
                # the "| ---" separator. Resolve column positions by name.
                cols = [col.strip() for col in line.split('|')]
                backbone_id = cols.index('Backbone')
                crop_size_id = cols.index('Crop Size')
                lr_schd_id = cols.index('Lr schd')
                mem_id = cols.index('Mem (GB)')
                fps_id = cols.index('Inf time (fps)')
                # Single-scale metric is either mIoU or Dice.
                try:
                    ss_id = cols.index('mIoU')
                except ValueError:
                    ss_id = cols.index('Dice')
                # Multi-scale column is optional; False marks "absent"
                # (index 0 can't occur since split('|') puts '' first).
                try:
                    ms_id = cols.index('mIoU(ms+flip)')
                except ValueError:
                    ms_id = False
                config_id = cols.index('config')
                download_id = cols.index('download')
                # Data rows start after header + separator.
                j = i + 2
                while j < len(lines) and lines[j][0] == '|':
                    els = [el.strip() for el in lines[j].split('|')]
                    config = ''
                    model_name = ''
                    weight = ''
                    # Match a known config filename inside the config cell
                    # and pull the checkpoint URL out of the download cell.
                    for fn in configs:
                        if fn in els[config_id]:
                            left = els[download_id].index(
                                'https://download.openmmlab.com')
                            right = els[download_id].index('.pth') + 4
                            weight = els[download_id][left:right]
                            config = f'configs/{collection_name}/{fn}'
                            model_name = fn[:-3]
                    # '-' or empty cells mean "not reported"; -1 sentinel.
                    fps = els[fps_id] if els[fps_id] != '-' and els[
                        fps_id] != '' else -1
                    mem = els[mem_id] if els[mem_id] != '-' and els[
                        mem_id] != '' else -1
                    crop_size = els[crop_size_id].split('x')
                    assert len(crop_size) == 2
                    model = {
                        'Name':
                        model_name,
                        'In Collection':
                        collection_name,
                        'Metadata': {
                            'backbone': els[backbone_id],
                            'crop size': f'({crop_size[0]},{crop_size[1]})',
                            'lr schd': int(els[lr_schd_id]),
                        },
                        'Results': [
                            {
                                'Task': 'Semantic Segmentation',
                                'Dataset': current_dataset,
                                'Metrics': {
                                    cols[ss_id]: float(els[ss_id]),
                                },
                            },
                        ],
                        'Config':
                        config,
                        'Weights':
                        weight,
                    }
                    if fps != -1:
                        # Unparseable fps -> skip this row entirely.
                        try:
                            fps = float(fps)
                        except Exception:
                            j += 1
                            continue
                        model['Metadata']['inference time (ms/im)'] = [{
                            'value':
                            round(1000 / float(fps), 2),
                            'hardware':
                            'V100',
                            'backend':
                            'PyTorch',
                            'batch size':
                            1,
                            'mode':
                            'FP32' if 'fp16' not in config else 'FP16',
                            'resolution':
                            f'({crop_size[0]},{crop_size[1]})'
                        }]
                    if mem != -1:
                        model['Metadata']['Training Memory (GB)'] = float(mem)
                    # Only have semantic segmentation now
                    if ms_id and els[ms_id] != '-' and els[ms_id] != '':
                        model['Results'][0]['Metrics'][
                            'mIoU(ms+flip)'] = float(els[ms_id])
                    models.append(model)
                    j += 1
                # Resume scanning after the table.
                i = j
            else:
                i += 1
    # All three URLs are mandatory for a valid README.
    flag = (code_url is not None) and (paper_url is not None) and (repo_url
                                                                   is not None)
    assert flag, f'{collection_name} readme error'
    collection['Metadata']['Training Data'] = datasets
    collection['Code']['URL'] = code_url
    collection['Code']['Version'] = code_version
    collection['Paper']['URL'] = paper_url
    collection['Paper']['Title'] = paper_title
    collection['Converted From']['Code'] = repo_url
    # ['Converted From']['Weights] miss
    # remove empty attribute
    # Drop empty-string fields; drop the whole sub-dict once its last
    # field goes (hence the membership re-check inside the loop).
    check_key_list = ['Code', 'Paper', 'Converted From']
    for check_key in check_key_list:
        key_list = list(collection[check_key].keys())
        for key in key_list:
            if check_key not in collection:
                break
            if collection[check_key][key] == '':
                if len(collection[check_key].keys()) == 1:
                    collection.pop(check_key)
                else:
                    collection[check_key].pop(key)

    result = {'Collections': [collection], 'Models': models}
    # md_file[:-9] strips the trailing 'README.md', keeping the directory.
    yml_file = f'{md_file[:-9]}{collection_name}.yml'
    return dump_yaml_and_check_difference(result, yml_file)
245
246
247
def update_model_index():
    """Update model-index.yml according to model .md files.

    Returns:
        bool: True if the regenerated model-index.yml differs from the
        version previously on disk.
    """
    # Collect every per-collection YAML under configs/, in a stable order.
    pattern = osp.join(MMSEG_ROOT, 'configs', '**', '*.yml')
    yml_paths = sorted(glob.glob(pattern, recursive=True))

    # model-index.yml just imports each collection YAML by relative path.
    index = {
        'Import': [osp.relpath(path, MMSEG_ROOT) for path in yml_paths]
    }
    target = osp.join(MMSEG_ROOT, 'model-index.yml')
    return dump_yaml_and_check_difference(index, target)
266
267
268
if __name__ == '__main__':
    # Pre-commit passes the changed files as argv; only README.md matters.
    readmes = [
        path for path in sys.argv[1:] if osp.basename(path) == 'README.md'
    ]
    if not readmes:
        sys.exit(0)
    # Regenerate each collection's YAML, then the top-level index, and
    # exit 1 if any output file changed (so pre-commit flags the commit).
    changed = False
    for readme in readmes:
        changed = parse_md(readme) or changed
    changed = update_model_index() or changed
    sys.exit(1 if changed else 0)