Switch to side-by-side view

--- a
+++ b/misc/download_pretrained_model.py
@@ -0,0 +1,133 @@
+# Modified from https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.3/paddleseg/utils/download.py
+
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import requests
+import time
+import sys
+import zipfile
+lasttime = time.time()
+FLUSH_INTERVAL = 0.1
+
+
+def progress(str, end=False):
+    global lasttime
+    if end:
+        str += "\n"
+        lasttime = 0
+    if time.time() - lasttime >= FLUSH_INTERVAL:
+        sys.stdout.write("\r%s" % str)
+        lasttime = time.time()
+        sys.stdout.flush()
+
+
+def _download_file(url, savepath, print_progress):
+    if print_progress:
+        print("Connecting to {}".format(url))
+    r = requests.get(url, stream=True, timeout=15)
+    total_length = r.headers.get('content-length')
+
+    if total_length is None:
+        with open(savepath, 'wb') as f:
+            shutil.copyfileobj(r.raw, f)
+    else:
+        with open(savepath, 'wb') as f:
+            dl = 0
+            total_length = int(total_length)
+            if print_progress:
+                print("Downloading %s" % os.path.basename(savepath))
+            for data in r.iter_content(chunk_size=4096):
+                dl += len(data)
+                f.write(data)
+                if print_progress:
+                    done = int(50 * dl / total_length)
+                    progress("[%-50s] %.2f%%" %
+                             ('=' * done, float(100 * dl) / total_length))
+        if print_progress:
+            progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
+
+
+def _uncompress_file_zip(filepath, extrapath):
+    files = zipfile.ZipFile(filepath, 'r')
+    filelist = files.namelist()
+    rootpath = filelist[0]
+    total_num = len(filelist)
+    for index, file in enumerate(filelist):
+        files.extract(file, extrapath)
+        yield total_num, index, rootpath
+    files.close()
+    yield total_num, index, rootpath
+
+
+def download_file_and_uncompress(url,
+                                 savepath=None,
+                                 print_progress=True,
+                                 replace=False,
+                                 extrapath=None,
+                                 delete_file=True):
+    if savepath is None:
+        savepath = "."
+    if extrapath is None:
+        extrapath = "."
+    savename = url.split("/")[-1]
+    if not savename.endswith("zip"):
+        raise NotImplementedError(
+            "Only support zip file, but got {}!".format(savename))
+    if not os.path.exists(savepath):
+        os.makedirs(savepath)
+
+    savepath = os.path.join(savepath, savename)
+    savename = ".".join(savename.split(".")[:-1])
+
+    if replace:
+        if os.path.exists(savepath):
+            shutil.rmtree(savepath)
+
+    if not os.path.exists(savename):
+        if not os.path.exists(savepath):
+            _download_file(url, savepath, print_progress)
+
+        if print_progress:
+            print("Uncompress %s" % os.path.basename(savepath))
+        for total_num, index, rootpath in _uncompress_file_zip(savepath, extrapath):
+            if print_progress:
+                done = int(50 * float(index) / total_num)
+                progress(
+                    "[%-50s] %.2f%%" % ('=' * done, float(100 * index) / total_num))
+        if print_progress:
+            progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
+
+        if delete_file:
+            os.remove(savepath)
+
+    return rootpath
+
+
+if __name__ == "__main__":
+    urls = [
+        "https://github.com/ShiqiYu/OpenGait/releases/download/v1.0/pretrained_casiab_model.zip",
+        "https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_oumvlp_model.zip",
+        "https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_grew_model.zip"]
+    for url in urls:
+        download_file_and_uncompress(
+            url=url, extrapath='output')
+    gaitgl_grew = ['https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_grew_gaitgl.zip',
+                   'https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_grew_gaitgl_bnneck.zip']
+    for gaitgl in gaitgl_grew:
+        download_file_and_uncompress(
+                url=gaitgl, extrapath='output/GREW/GaitGL')
+    print("Pretrained model download success!")