Diff of /.dev/md2yml.py [000000] .. [4e96d3]

--- a
+++ b/.dev/md2yml.py
@@ -0,0 +1,278 @@
+#!/usr/bin/env python
+
+# Copyright (c) OpenMMLab. All rights reserved.
+# This tool updates model-index.yml, which is required by MIM, and is called
+# automatically as a pre-commit hook. The update is triggered whenever a
+# change to the model information (.md files in configs/) is detected before
+# a commit.
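+#
+# The parser expects each configs/<collection>/README.md to roughly follow
+# this layout:
+#
+#   # <Paper Title>
+#   <a href="...">Official Repo</a> ... <a href="...">Code Snippet</a>
+#   <summary ...><a href="...">paper link</a></summary>
+#   ### <Dataset>
+#   a results table whose header row contains Backbone, Crop Size, Lr schd,
+#   Mem (GB), Inf time (fps), mIoU (or Dice), optionally mIoU(ms+flip),
+#   config and download columns.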
+
+import glob
+import os
+import os.path as osp
+import re
+import sys
+
+import mmcv
+from lxml import etree
+
+MMSEG_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__)))
+
+
+def dump_yaml_and_check_difference(obj, filename, sort_keys=False):
+    """Dump object to a yaml file, and check if the file content is different
+    from the original.
+
+    Args:
+        obj (any): The python object to be dumped.
+        filename (str): YAML filename to dump the object to.
+        sort_keys (bool): Whether to sort keys in dictionary order.
+    Returns:
+        bool: True if the target YAML file differs from the original.
+    """
+
+    str_dump = mmcv.dump(obj, None, file_format='yaml', sort_keys=sort_keys)
+    if osp.isfile(filename):
+        file_exists = True
+        with open(filename, 'r', encoding='utf-8') as f:
+            str_orig = f.read()
+    else:
+        file_exists = False
+        str_orig = None
+
+    if file_exists and str_orig == str_dump:
+        is_different = False
+    else:
+        is_different = True
+        with open(filename, 'w', encoding='utf-8') as f:
+            f.write(str_dump)
+
+    return is_different
+
+
+def parse_md(md_file):
+    """Parse .md file and convert it to a .yml file which can be used for MIM.
+
+    Args:
+        md_file (str): Path to .md file.
+    Returns:
+        bool: True if the target YAML file differs from the original.
+    """
+    collection_name = osp.split(osp.dirname(md_file))[1]
+    configs = os.listdir(osp.dirname(md_file))
+
+    collection = dict(
+        Name=collection_name,
+        Metadata={'Training Data': []},
+        Paper={
+            'URL': '',
+            'Title': ''
+        },
+        README=md_file,
+        Code={
+            'URL': '',
+            'Version': ''
+        })
+    collection.update({'Converted From': {'Weights': '', 'Code': ''}})
+    models = []
+    datasets = []
+    paper_url = None
+    paper_title = None
+    code_url = None
+    code_version = None
+    repo_url = None
+
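+    # Walk the README line by line, collecting the paper title, the code
+    # snippet / official repo links, the paper link, the dataset headings and
+    # the benchmark tables.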
+    with open(md_file, 'r') as md:
+        lines = md.readlines()
+        i = 0
+        current_dataset = ''
+        while i < len(lines):
+            line = lines[i].strip()
+            if len(line) == 0:
+                i += 1
+                continue
+            if line[:2] == '# ':
+                paper_title = line.replace('# ', '')
+                i += 1
+            elif line[:3] == '<a ':
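+                # Standalone <a> tags carry the 'Code Snippet' and
+                # 'Official Repo' links.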
+                content = etree.HTML(line)
+                node = content.xpath('//a')[0]
+                if node.text == 'Code Snippet':
+                    code_url = node.get('href', None)
+                    assert code_url is not None, (
+                        f'{collection_name} README lacks a code snippet URL.')
+                    # version extraction
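+                    # e.g. '.../blob/v0.x.y/mmseg/...' yields 'v0.x.y'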
+                    filter_str = r'blob/(.*)/mm'
+                    pattern = re.compile(filter_str)
+                    code_version = pattern.findall(code_url)
+                    assert len(code_version) == 1, (
+                        f'regex ({filter_str}) should match exactly one '
+                        'version string.')
+                    code_version = code_version[0]
+                elif node.text == 'Official Repo':
+                    repo_url = node.get('href', None)
+                    assert repo_url is not None, (
+                        f'{collection_name} README lacks an official repo '
+                        'URL.')
+                i += 1
+            elif line[:9] == '<summary ':
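+                # The <summary ...> line wraps an <a> tag pointing to the
+                # paper.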
+                content = etree.HTML(line)
+                nodes = content.xpath('//a')
+                assert len(nodes) == 1, (
+                    'summary tag should contain exactly one <a> tag.')
+                paper_url = nodes[0].get('href', None)
+                i += 1
+            elif line[:4] == '### ':
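+                # A '### ' heading names the dataset for the tables that
+                # follow; the line right after the heading is skipped too.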
+                datasets.append(line[4:])
+                current_dataset = line[4:]
+                i += 2
+            elif line[0] == '|' and (
+                    i + 1) < len(lines) and lines[i + 1][:3] == '| -':
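+                # A '|' row followed by a '| -' separator row marks a results
+                # table; resolve each column index from the header text.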
+                cols = [col.strip() for col in line.split('|')]
+                backbone_id = cols.index('Backbone')
+                crop_size_id = cols.index('Crop Size')
+                lr_schd_id = cols.index('Lr schd')
+                mem_id = cols.index('Mem (GB)')
+                fps_id = cols.index('Inf time (fps)')
+                try:
+                    ss_id = cols.index('mIoU')
+                except ValueError:
+                    ss_id = cols.index('Dice')
+                try:
+                    ms_id = cols.index('mIoU(ms+flip)')
+                except ValueError:
+                    ms_id = False
+                config_id = cols.index('config')
+                download_id = cols.index('download')
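+                # Data rows start two lines below the header (after the
+                # separator row).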
+                j = i + 2
+                while j < len(lines) and lines[j][0] == '|':
+                    els = [el.strip() for el in lines[j].split('|')]
+                    config = ''
+                    model_name = ''
+                    weight = ''
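+                    # Match a config filename from this directory against the
+                    # 'config' cell and cut the checkpoint URL out of the
+                    # 'download' cell.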
+                    for fn in configs:
+                        if fn in els[config_id]:
+                            left = els[download_id].index(
+                                'https://download.openmmlab.com')
+                            right = els[download_id].index('.pth') + 4
+                            weight = els[download_id][left:right]
+                            config = f'configs/{collection_name}/{fn}'
+                            model_name = fn[:-3]
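+                    # '-' or empty cells mean the value was not reported;
+                    # -1 is used as a sentinel below.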
+                    fps = els[fps_id] if els[fps_id] != '-' and els[
+                        fps_id] != '' else -1
+                    mem = els[mem_id] if els[mem_id] != '-' and els[
+                        mem_id] != '' else -1
+                    crop_size = els[crop_size_id].split('x')
+                    assert len(crop_size) == 2
+                    model = {
+                        'Name':
+                        model_name,
+                        'In Collection':
+                        collection_name,
+                        'Metadata': {
+                            'backbone': els[backbone_id],
+                            'crop size': f'({crop_size[0]},{crop_size[1]})',
+                            'lr schd': int(els[lr_schd_id]),
+                        },
+                        'Results': [
+                            {
+                                'Task': 'Semantic Segmentation',
+                                'Dataset': current_dataset,
+                                'Metrics': {
+                                    cols[ss_id]: float(els[ss_id]),
+                                },
+                            },
+                        ],
+                        'Config':
+                        config,
+                        'Weights':
+                        weight,
+                    }
+                    if fps != -1:
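+                        # Convert the reported fps to ms/im; if the value is
+                        # not numeric, skip this row entirely.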
+                        try:
+                            fps = float(fps)
+                        except Exception:
+                            j += 1
+                            continue
+                        model['Metadata']['inference time (ms/im)'] = [{
+                            'value':
+                            round(1000 / float(fps), 2),
+                            'hardware':
+                            'V100',
+                            'backend':
+                            'PyTorch',
+                            'batch size':
+                            1,
+                            'mode':
+                            'FP32' if 'fp16' not in config else 'FP16',
+                            'resolution':
+                            f'({crop_size[0]},{crop_size[1]})'
+                        }]
+                    if mem != -1:
+                        model['Metadata']['Training Memory (GB)'] = float(mem)
+                    # Only semantic segmentation results are supported for now
+                    if ms_id and els[ms_id] != '-' and els[ms_id] != '':
+                        model['Results'][0]['Metrics'][
+                            'mIoU(ms+flip)'] = float(els[ms_id])
+                    models.append(model)
+                    j += 1
+                i = j
+            else:
+                i += 1
+    flag = (code_url is not None) and (paper_url is not None) and (
+        repo_url is not None)
+    assert flag, (f'{collection_name} README is missing the paper, code '
+                  'snippet or official repo URL.')
+    collection['Metadata']['Training Data'] = datasets
+    collection['Code']['URL'] = code_url
+    collection['Code']['Version'] = code_version
+    collection['Paper']['URL'] = paper_url
+    collection['Paper']['Title'] = paper_title
+    collection['Converted From']['Code'] = repo_url
+    # ['Converted From']['Weights'] is not filled in here.
+    # Remove empty attributes so they do not end up in the YAML.
+    check_key_list = ['Code', 'Paper', 'Converted From']
+    for check_key in check_key_list:
+        key_list = list(collection[check_key].keys())
+        for key in key_list:
+            if check_key not in collection:
+                break
+            if collection[check_key][key] == '':
+                if len(collection[check_key].keys()) == 1:
+                    collection.pop(check_key)
+                else:
+                    collection[check_key].pop(key)
+
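+    # The output .yml sits next to the README, i.e.
+    # configs/<collection_name>/<collection_name>.yml.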
+    result = {'Collections': [collection], 'Models': models}
+    yml_file = f'{md_file[:-9]}{collection_name}.yml'
+    return dump_yaml_and_check_difference(result, yml_file)
+
+
+def update_model_index():
+    """Update model-index.yml according to model .md files.
+
+    Returns:
+        Bool: If the updated model-index.yml is different from the original.
+    """
+    configs_dir = osp.join(MMSEG_ROOT, 'configs')
+    yml_files = glob.glob(osp.join(configs_dir, '**', '*.yml'), recursive=True)
+    yml_files.sort()
+
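+    # model-index.yml simply imports every per-collection .yml file found
+    # under configs/.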
+    model_index = {
+        'Import':
+        [osp.relpath(yml_file, MMSEG_ROOT) for yml_file in yml_files]
+    }
+    model_index_file = osp.join(MMSEG_ROOT, 'model-index.yml')
+    is_different = dump_yaml_and_check_difference(model_index,
+                                                  model_index_file)
+
+    return is_different
+
+
+if __name__ == '__main__':
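+    # pre-commit passes the staged file paths; only README.md files are
+    # processed here.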
+    file_list = [fn for fn in sys.argv[1:] if osp.basename(fn) == 'README.md']
+    if not file_list:
+        sys.exit(0)
+    file_modified = False
+    for fn in file_list:
+        file_modified |= parse_md(fn)
+
+    file_modified |= update_model_index()
+
+    sys.exit(1 if file_modified else 0)