|
a |
|
b/misc/download_pretrained_model.py |
|
|
1 |
# Modified from https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.3/paddleseg/utils/download.py |
|
|
2 |
|
|
|
3 |
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
|
|
4 |
# |
|
|
5 |
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
6 |
# you may not use this file except in compliance with the License. |
|
|
7 |
# You may obtain a copy of the License at |
|
|
8 |
# |
|
|
9 |
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
10 |
# |
|
|
11 |
# Unless required by applicable law or agreed to in writing, software |
|
|
12 |
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
13 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
14 |
# See the License for the specific language governing permissions and |
|
|
15 |
# limitations under the License. |
|
|
16 |
|
|
|
17 |
import os |
|
|
18 |
import shutil |
|
|
19 |
import requests |
|
|
20 |
import time |
|
|
21 |
import sys |
|
|
22 |
import zipfile |
|
|
23 |
# Wall-clock time of the most recent write made by progress(); used to throttle redraws.
lasttime = time.time()
# Minimum number of seconds between two consecutive progress redraws.
FLUSH_INTERVAL = 0.1


def progress(str, end=False):
    """Redraw a one-line status message on stdout, at most once per FLUSH_INTERVAL.

    Args:
        str: Text to show. A carriage return is written first so the message
            overwrites the current terminal line. (The name shadows the builtin
            but is kept so existing keyword callers keep working.)
        end: When True, a newline is appended and the throttle is bypassed so
            the final message is always written.
    """
    global lasttime
    message = str
    if end:
        message = message + "\n"
        # Reset the timestamp so the throttle check below cannot suppress
        # the closing line.
        lasttime = 0
    now = time.time()
    if now - lasttime < FLUSH_INTERVAL:
        return
    sys.stdout.write("\r%s" % message)
    lasttime = time.time()
    sys.stdout.flush()
|
|
36 |
|
|
|
37 |
|
|
|
38 |
def _download_file(url, savepath, print_progress):
    """Stream ``url`` to ``savepath``, optionally drawing a progress bar.

    Data is first written to ``savepath + ".part"``. It is moved to
    ``savepath`` with ``os.replace`` only after the transfer finishes, so an
    interrupted or failed download never leaves a truncated file at
    ``savepath``. Callers that skip downloading when ``savepath`` exists
    would otherwise reuse such a file.

    Args:
        url: HTTP(S) URL to fetch.
        savepath: Destination file path.
        print_progress: If True, print the URL being fetched and, when the
            server reports a Content-Length, a progress bar.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
    """
    if print_progress:
        print("Connecting to {}".format(url))
    partial_path = savepath + ".part"
    # Using the response as a context manager releases the connection even
    # if reading or writing fails part-way.
    with requests.get(url, stream=True, timeout=15) as r:
        # Without this check, an HTTP error page (e.g. a 404 HTML body)
        # would be saved as if it were the requested archive.
        r.raise_for_status()
        total_length = r.headers.get('content-length')
        try:
            with open(partial_path, 'wb') as f:
                if total_length is None:
                    # Unknown size: no progress bar. iter_content (unlike
                    # r.raw) applies Content-Encoding decoding.
                    for data in r.iter_content(chunk_size=4096):
                        f.write(data)
                else:
                    dl = 0
                    total_length = int(total_length)
                    if print_progress:
                        print("Downloading %s" % os.path.basename(savepath))
                    for data in r.iter_content(chunk_size=4096):
                        dl += len(data)
                        f.write(data)
                        if print_progress:
                            done = int(50 * dl / total_length)
                            progress("[%-50s] %.2f%%" %
                                     ('=' * done, float(100 * dl) / total_length))
                    if print_progress:
                        progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
        except BaseException:
            # Drop the incomplete file so the next call starts a fresh
            # download; the original exception is re-raised unchanged.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
    # Only a fully written download ever appears under `savepath`.
    os.replace(partial_path, savepath)
|
|
62 |
|
|
|
63 |
|
|
|
64 |
def _uncompress_file_zip(filepath, extrapath): |
|
|
65 |
files = zipfile.ZipFile(filepath, 'r') |
|
|
66 |
filelist = files.namelist() |
|
|
67 |
rootpath = filelist[0] |
|
|
68 |
total_num = len(filelist) |
|
|
69 |
for index, file in enumerate(filelist): |
|
|
70 |
files.extract(file, extrapath) |
|
|
71 |
yield total_num, index, rootpath |
|
|
72 |
files.close() |
|
|
73 |
yield total_num, index, rootpath |
|
|
74 |
|
|
|
75 |
|
|
|
76 |
def download_file_and_uncompress(url,
                                 savepath=None,
                                 print_progress=True,
                                 replace=False,
                                 extrapath=None,
                                 delete_file=True):
    """Download a zip archive (unless already present) and extract it.

    Args:
        url: URL of a ``.zip`` file. Its last path component is used as the
            local archive name.
        savepath: Directory the archive is downloaded into (default ``"."``).
            Created if it does not exist.
        print_progress: Print download and extraction progress to stdout.
        replace: First delete any previously downloaded archive and any
            existing ``<extrapath>/<archive stem>`` directory.
        extrapath: Directory the archive is extracted into (default ``"."``).
        delete_file: Remove the downloaded archive after extraction.

    Returns:
        The first entry of the archive's member list when the archive was
        extracted. Returns ``None`` when ``<extrapath>/<archive stem>``
        already existed, in which case nothing is downloaded or extracted.

    Raises:
        NotImplementedError: If ``url`` does not end in ``.zip``.
    """
    if savepath is None:
        savepath = "."
    if extrapath is None:
        extrapath = "."
    archive_name = url.split("/")[-1]
    if not archive_name.endswith(".zip"):
        raise NotImplementedError(
            "Only support zip file, but got {}!".format(archive_name))
    os.makedirs(savepath, exist_ok=True)

    archive_path = os.path.join(savepath, archive_name)
    # Marker for "already extracted". It must be resolved under `extrapath`
    # (where extraction happens), not relative to the current directory.
    extracted_dir = os.path.join(extrapath, os.path.splitext(archive_name)[0])

    if replace:
        # The archive is a regular file; shutil.rmtree() would fail on it.
        if os.path.isfile(archive_path):
            os.remove(archive_path)
        if os.path.isdir(extracted_dir):
            shutil.rmtree(extracted_dir)

    if os.path.exists(extracted_dir):
        return None

    if not os.path.exists(archive_path):
        _download_file(url, archive_path, print_progress)

    if print_progress:
        print("Uncompress %s" % os.path.basename(archive_path))
    rootpath = None
    for total_num, index, rootpath in _uncompress_file_zip(archive_path, extrapath):
        if print_progress:
            done = int(50 * float(index) / total_num)
            progress(
                "[%-50s] %.2f%%" % ('=' * done, float(100 * index) / total_num))
    if print_progress:
        progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)

    if delete_file:
        os.remove(archive_path)

    return rootpath
|
|
118 |
|
|
|
119 |
|
|
|
120 |
if __name__ == "__main__":
    # Each entry is (release archive URL, directory it is extracted into),
    # processed in order.
    downloads = [
        ("https://github.com/ShiqiYu/OpenGait/releases/download/v1.0/pretrained_casiab_model.zip",
         'output'),
        ("https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_oumvlp_model.zip",
         'output'),
        ("https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_grew_model.zip",
         'output'),
        ('https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_grew_gaitgl.zip',
         'output/GREW/GaitGL'),
        ('https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_grew_gaitgl_bnneck.zip',
         'output/GREW/GaitGL'),
    ]
    for archive_url, target_dir in downloads:
        download_file_and_uncompress(url=archive_url, extrapath=target_dir)
    print("Pretrained model download success!")