a b/misc/download_pretrained_model.py
1
# Modified from https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.3/paddleseg/utils/download.py
2
3
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#    http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
17
import os
18
import shutil
19
import requests
20
import time
21
import sys
22
import zipfile
23
# time.monotonic() timestamp of the last progress line written; used to
# throttle redraws.
lasttime = time.monotonic()
# Minimum number of seconds between two intermediate progress redraws.
FLUSH_INTERVAL = 0.1


def progress(str, end=False):
    """Redraw a single-line progress message on stdout.

    Intermediate updates are throttled to at most one every
    ``FLUSH_INTERVAL`` seconds; the final update (``end=True``) is always
    written and terminated with a newline.

    Args:
        str: Text to display. The name shadows the builtin but is kept so
            existing keyword callers keep working.
        end: If true, always write the message and finish the line.
    """
    global lasttime
    # A monotonic clock is immune to wall-clock adjustments, which could
    # otherwise stall (backward jump) or un-throttle the progress output.
    now = time.monotonic()
    if end:
        str += "\n"
    if end or now - lasttime >= FLUSH_INTERVAL:
        # "\r" returns the cursor so the next update overwrites this line.
        sys.stdout.write("\r%s" % str)
        sys.stdout.flush()
        lasttime = time.monotonic()


def _download_file(url, savepath, print_progress):
    """Stream ``url`` to the file ``savepath``, optionally showing progress.

    The payload is written to ``savepath + ".part"``. That file is moved to
    ``savepath`` only after the transfer finished, so an interrupted download
    never leaves a truncated file at ``savepath``. This matters because
    callers skip the download when ``savepath`` already exists.

    Args:
        url: HTTP(S) URL to fetch.
        savepath: Destination file path.
        print_progress: If true, print connection/progress messages to stdout.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    if print_progress:
        print("Connecting to {}".format(url))
    partpath = savepath + ".part"
    try:
        # Using the response as a context manager releases the connection.
        with requests.get(url, stream=True, timeout=15) as r:
            # Don't save an HTTP error page (404, 5xx, ...) as the archive.
            r.raise_for_status()
            total_length = r.headers.get('content-length')
            with open(partpath, 'wb') as f:
                if total_length is None:
                    # Unknown size: no meaningful percentage to display.
                    # iter_content (unlike r.raw) undoes any Content-Encoding.
                    for data in r.iter_content(chunk_size=4096):
                        f.write(data)
                else:
                    total_length = int(total_length)
                    if print_progress:
                        print("Downloading %s" % os.path.basename(savepath))
                    dl = 0
                    for data in r.iter_content(chunk_size=4096):
                        dl += len(data)
                        f.write(data)
                        if print_progress and total_length > 0:
                            # Clamp: decoded bytes can exceed Content-Length
                            # when the server compresses the transfer.
                            frac = min(1.0, float(dl) / total_length)
                            progress("[%-50s] %.2f%%" %
                                     ('=' * int(50 * frac), 100 * frac))
                    if print_progress:
                        progress("[%-50s] %.2f%%" % ('=' * 50, 100),
                                 end=True)
        os.replace(partpath, savepath)
    finally:
        # Only still present if something above failed.
        if os.path.exists(partpath):
            os.remove(partpath)


def _uncompress_file_zip(filepath, extrapath):
    """Extract the zip archive ``filepath`` into ``extrapath`` member by member.

    After each member is extracted, yields ``(total_num, index, rootpath)``:
    - ``total_num`` is the number of members.
    - ``index`` is the 0-based index of the member just extracted.
    - ``rootpath`` is the name of the archive's first member, which callers
      presumably treat as its top-level directory.

    After the archive is closed, one more tuple repeating the last one is
    yielded.

    Raises:
        IndexError: If the archive contains no members.
    """
    # ``with`` closes the archive even if extraction raises or the caller
    # abandons the generator before exhausting it.
    with zipfile.ZipFile(filepath, 'r') as archive:
        members = archive.namelist()
        rootpath = members[0]
        total_num = len(members)
        for index, member in enumerate(members):
            archive.extract(member, extrapath)
            yield total_num, index, rootpath
    yield total_num, index, rootpath


def download_file_and_uncompress(url,
                                 savepath=None,
                                 print_progress=True,
                                 replace=False,
                                 extrapath=None,
                                 delete_file=True):
    """Download a ``.zip`` archive from ``url`` and extract it.

    Args:
        url: URL whose last path component is the archive file name; it must
            end with ".zip".
        savepath: Directory to store the downloaded archive in (created if
            missing). Defaults to the current directory.
        print_progress: Print download/extraction progress to stdout.
        replace: First delete a previously downloaded archive and a
            previously extracted ``<extrapath>/<archive name without .zip>``
            directory. This forces a fresh download and extraction.
        extrapath: Directory to extract into. Defaults to the current
            directory.
        delete_file: Remove the downloaded archive after extraction.

    Returns:
        The name of the archive's first member if the archive was extracted.
        If ``<extrapath>/<archive name without .zip>`` already existed and
        the work was skipped, the archive name without ".zip" instead.

    Raises:
        NotImplementedError: If ``url`` does not name a ``.zip`` file.
    """
    if savepath is None:
        savepath = "."
    if extrapath is None:
        extrapath = "."
    savename = url.split("/")[-1]
    if not savename.endswith(".zip"):
        raise NotImplementedError(
            "Only support zip file, but got {}!".format(savename))
    os.makedirs(savepath, exist_ok=True)

    savepath = os.path.join(savepath, savename)
    savename = os.path.splitext(savename)[0]
    # A previous extraction lives under ``extrapath``, not under the current
    # working directory.
    extractdir = os.path.join(extrapath, savename)

    if replace:
        # ``savepath`` is the archive *file*; shutil.rmtree would raise on it.
        if os.path.isfile(savepath):
            os.remove(savepath)
        if os.path.isdir(extractdir):
            shutil.rmtree(extractdir)

    # Returned as-is when the download/extraction is skipped.
    rootpath = savename
    if not os.path.exists(extractdir):
        if not os.path.exists(savepath):
            _download_file(url, savepath, print_progress)

        if print_progress:
            print("Uncompress %s" % os.path.basename(savepath))
        for total_num, index, rootpath in _uncompress_file_zip(
                savepath, extrapath):
            if print_progress:
                done = int(50 * float(index) / total_num)
                progress("[%-50s] %.2f%%" %
                         ('=' * done, float(100 * index) / total_num))
        if print_progress:
            progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)

        if delete_file:
            os.remove(savepath)

    return rootpath


if __name__ == "__main__":
    # Each entry: (archive URL, directory the archive is extracted into).
    # Entries are processed strictly in this order.
    downloads = [
        ("https://github.com/ShiqiYu/OpenGait/releases/download/v1.0/pretrained_casiab_model.zip",
         'output'),
        ("https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_oumvlp_model.zip",
         'output'),
        ("https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_grew_model.zip",
         'output'),
        # GaitGL checkpoints for GREW are extracted into their own folder.
        ('https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_grew_gaitgl.zip',
         'output/GREW/GaitGL'),
        ('https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_grew_gaitgl_bnneck.zip',
         'output/GREW/GaitGL'),
    ]
    for archive_url, target_dir in downloads:
        download_file_and_uncompress(url=archive_url, extrapath=target_dir)
    print("Pretrained model download success!")