|
a |
|
b/mmseg/utils/checkpoint.py |
|
|
1 |
# Copyright (c) Open-MMLab. All rights reserved. |
|
|
2 |
import io |
|
|
3 |
import os |
|
|
4 |
import os.path as osp |
|
|
5 |
import pkgutil |
|
|
6 |
import time |
|
|
7 |
import warnings |
|
|
8 |
from collections import OrderedDict |
|
|
9 |
from importlib import import_module |
|
|
10 |
from tempfile import TemporaryDirectory |
|
|
11 |
|
|
|
12 |
import torch |
|
|
13 |
import torchvision |
|
|
14 |
from torch.optim import Optimizer |
|
|
15 |
from torch.utils import model_zoo |
|
|
16 |
from torch.nn import functional as F |
|
|
17 |
|
|
|
18 |
import mmcv |
|
|
19 |
from mmcv.fileio import FileClient |
|
|
20 |
from mmcv.fileio import load as load_file |
|
|
21 |
from mmcv.parallel import is_module_wrapper |
|
|
22 |
from mmcv.utils import mkdir_or_exist |
|
|
23 |
from mmcv.runner import get_dist_info |
|
|
24 |
|
|
|
25 |
# Name of the environment variable that, when set, gives the mmcv home
# directory directly.
ENV_MMCV_HOME = 'MMCV_HOME'
# Name of the environment variable holding the base cache directory that is
# used when ``MMCV_HOME`` is not set.
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
# Base cache directory used when neither environment variable is set.
DEFAULT_CACHE_DIR = '~/.cache'
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def _get_mmcv_home():
    """Return the mmcv home directory, creating it when necessary.

    The directory is read from the ``MMCV_HOME`` environment variable. When
    that variable is unset, ``<cache>/mmcv`` is used instead, where
    ``<cache>`` is ``XDG_CACHE_HOME`` if set and ``~/.cache`` otherwise. A
    leading ``~`` is expanded to the user's home directory.

    Returns:
        str: Path of the mmcv home directory.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(cache_root, 'mmcv')
    home_dir = os.path.expanduser(os.getenv(ENV_MMCV_HOME, fallback_home))

    mkdir_or_exist(home_dir)
    return home_dir
|
|
39 |
|
|
|
40 |
|
|
|
41 |
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.

    Raises:
        RuntimeError: If ``strict`` is True, problems were collected and the
            current rank is 0.
    """
    unexpected = []
    missing_all = []
    messages = []

    # Work on a shallow copy so the caller's mapping is left alone, but carry
    # the ``_metadata`` attribute over to the copy.
    meta = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if meta is not None:
        state_dict._metadata = meta

    # use _load_from_state_dict to enable checkpoint version control
    def _load_recursive(mod, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(mod):
            mod = mod.module
        if meta is None:
            local_meta = {}
        else:
            local_meta = meta.get(prefix[:-1], {})
        mod._load_from_state_dict(state_dict, prefix, local_meta, True,
                                  missing_all, unexpected, messages)
        for child_name, child in mod._modules.items():
            if child is None:
                continue
            _load_recursive(child, prefix + child_name + '.')

    _load_recursive(module)
    _load_recursive = None  # break the closure's self-reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing = [
        name for name in missing_all if 'num_batches_tracked' not in name
    ]

    if unexpected:
        joined = ', '.join(unexpected)
        messages.append(f'unexpected key in source state_dict: {joined}\n')
    if missing:
        joined = ', '.join(missing)
        messages.append(f'missing keys in source state_dict: {joined}\n')

    rank, _ = get_dist_info()
    if not messages or rank != 0:
        # Nothing to report, or reporting is left to rank 0.
        return

    messages.insert(
        0, 'The model and loaded state dict do not match exactly\n')
    report = '\n'.join(messages)
    if strict:
        raise RuntimeError(report)
    if logger is not None:
        logger.warning(report)
    else:
        print(report)
|
|
107 |
|
|
|
108 |
|
|
|
109 |
def load_url_dist(url, model_dir=None):
    """Load a checkpoint from ``url``, letting local rank 0 go first.

    Local rank 0 calls ``model_zoo.load_url`` before anyone else. When more
    than one process is running, all processes then meet at a barrier, and
    afterwards the remaining ranks call ``model_zoo.load_url`` themselves
    (presumably served from the cache filled by rank 0 -- confirm against
    the ``torch.utils.model_zoo`` behaviour).

    Args:
        url (str): URL of the checkpoint.
        model_dir (str, optional): Forwarded to ``model_zoo.load_url``.

    Returns:
        The object returned by ``model_zoo.load_url``.
    """
    global_rank, world_size = get_dist_info()
    # LOCAL_RANK from the environment takes precedence over the global rank.
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    if local_rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint
|
|
121 |
|
|
|
122 |
|
|
|
123 |
def load_pavimodel_dist(model_path, map_location=None):
    """Load a checkpoint from pavi modelcloud, letting local rank 0 go first.

    Local rank 0 downloads and loads the checkpoint first. When more than one
    process is running, all processes meet at a barrier, after which the
    remaining ranks perform the same download and load.

    Args:
        model_path (str): Path of the model in modelcloud.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        The object returned by :func:`torch.load`.

    Raises:
        ImportError: If ``pavi`` is not installed.
    """
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')

    def _download_and_load():
        # The file only lives inside the temporary directory, so it has to be
        # deserialized before the ``with`` block exits.
        cloud_model = modelcloud.get(model_path)
        with TemporaryDirectory() as work_dir:
            local_file = osp.join(work_dir, cloud_model.name)
            cloud_model.download(local_file)
            return torch.load(local_file, map_location=map_location)

    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    if local_rank == 0:
        checkpoint = _download_and_load()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _download_and_load()
    return checkpoint
|
|
149 |
|
|
|
150 |
|
|
|
151 |
def load_fileclient_dist(filename, backend, map_location):
    """Load a checkpoint through ``FileClient``, local rank 0 going first.

    Local rank 0 fetches and loads the checkpoint first. When more than one
    process is running, all processes meet at a barrier, after which the
    remaining ranks fetch and load it as well.

    Args:
        filename (str): Location of the checkpoint, passed to
            ``FileClient.get``.
        backend (str): ``FileClient`` backend name. Only ``'ceph'`` is
            accepted.
        map_location (str | None): Same as :func:`torch.load`.

    Returns:
        The object returned by :func:`torch.load`.

    Raises:
        ValueError: If ``backend`` is not an accepted backend.
    """
    global_rank, world_size = get_dist_info()
    local_rank = int(os.environ.get('LOCAL_RANK', global_rank))

    supported = ('ceph', )
    if backend not in supported:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    def _fetch_and_load():
        # Read the raw bytes through the backend and deserialize in memory.
        client = FileClient(backend=backend)
        stream = io.BytesIO(client.get(filename))
        return torch.load(stream, map_location=map_location)

    if local_rank == 0:
        checkpoint = _fetch_and_load()
    if world_size > 1:
        torch.distributed.barrier()
    if local_rank > 0:
        checkpoint = _fetch_and_load()
    return checkpoint
|
|
170 |
|
|
|
171 |
|
|
|
172 |
def get_torchvision_models():
    """Collect the ``model_urls`` tables of all torchvision model modules.

    Every non-package module found under ``torchvision.models`` is imported;
    if it defines a ``model_urls`` attribute, its entries are merged into the
    result (later modules overwrite earlier ones on key clashes).

    Returns:
        dict: Merged ``model_urls`` entries.
    """
    urls = {}
    search_path = torchvision.models.__path__
    for _, mod_name, is_package in pkgutil.walk_packages(search_path):
        if not is_package:
            zoo_module = import_module(f'torchvision.models.{mod_name}')
            urls.update(getattr(zoo_module, 'model_urls', {}))
    return urls
|
|
182 |
|
|
|
183 |
|
|
|
184 |
def get_external_models():
    """Load the open-mmlab model table, with optional user overrides.

    The default table is read from ``model_zoo/open_mmlab.json`` inside the
    installed ``mmcv`` package. If ``open_mmlab.json`` exists in the mmcv
    home directory, its entries are merged on top of the defaults.

    Returns:
        dict: The merged model table.
    """
    mmcv_home = _get_mmcv_home()

    packaged_json = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    urls = load_file(packaged_json)
    assert isinstance(urls, dict)

    user_json = osp.join(mmcv_home, 'open_mmlab.json')
    if not osp.exists(user_json):
        return urls

    overrides = load_file(user_json)
    assert isinstance(overrides, dict)
    urls.update(overrides)
    return urls
|
|
196 |
|
|
|
197 |
|
|
|
198 |
def get_mmcls_models():
    """Load ``model_zoo/mmcls.json`` from the installed ``mmcv`` package.

    Returns:
        The parsed content of the JSON file.
    """
    package_dir = mmcv.__path__[0]
    return load_file(osp.join(package_dir, 'model_zoo/mmcls.json'))
|
|
203 |
|
|
|
204 |
|
|
|
205 |
def get_deprecated_model_names():
    """Load ``model_zoo/deprecated.json`` from the installed ``mmcv`` package.

    Returns:
        dict: The parsed content of the JSON file.
    """
    json_path = osp.join(mmcv.__path__[0], 'model_zoo/deprecated.json')
    name_table = load_file(json_path)
    assert isinstance(name_table, dict)
    return name_table
|
|
212 |
|
|
|
213 |
|
|
|
214 |
def _process_mmcls_checkpoint(checkpoint): |
|
|
215 |
state_dict = checkpoint['state_dict'] |
|
|
216 |
new_state_dict = OrderedDict() |
|
|
217 |
for k, v in state_dict.items(): |
|
|
218 |
if k.startswith('backbone.'): |
|
|
219 |
new_state_dict[k[9:]] = v |
|
|
220 |
new_checkpoint = dict(state_dict=new_state_dict) |
|
|
221 |
|
|
|
222 |
return new_checkpoint |
|
|
223 |
|
|
|
224 |
|
|
|
225 |
def _load_checkpoint(filename, map_location=None): |
|
|
226 |
"""Load checkpoint from somewhere (modelzoo, file, url). |
|
|
227 |
|
|
|
228 |
Args: |
|
|
229 |
filename (str): Accept local filepath, URL, ``torchvision://xxx``, |
|
|
230 |
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for |
|
|
231 |
details. |
|
|
232 |
map_location (str | None): Same as :func:`torch.load`. Default: None. |
|
|
233 |
|
|
|
234 |
Returns: |
|
|
235 |
dict | OrderedDict: The loaded checkpoint. It can be either an |
|
|
236 |
OrderedDict storing model weights or a dict containing other |
|
|
237 |
information, which depends on the checkpoint. |
|
|
238 |
""" |
|
|
239 |
if filename.startswith('modelzoo://'): |
|
|
240 |
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' |
|
|
241 |
'use "torchvision://" instead') |
|
|
242 |
model_urls = get_torchvision_models() |
|
|
243 |
model_name = filename[11:] |
|
|
244 |
checkpoint = load_url_dist(model_urls[model_name]) |
|
|
245 |
elif filename.startswith('torchvision://'): |
|
|
246 |
model_urls = get_torchvision_models() |
|
|
247 |
model_name = filename[14:] |
|
|
248 |
checkpoint = load_url_dist(model_urls[model_name]) |
|
|
249 |
elif filename.startswith('open-mmlab://'): |
|
|
250 |
model_urls = get_external_models() |
|
|
251 |
model_name = filename[13:] |
|
|
252 |
deprecated_urls = get_deprecated_model_names() |
|
|
253 |
if model_name in deprecated_urls: |
|
|
254 |
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' |
|
|
255 |
f'of open-mmlab://{deprecated_urls[model_name]}') |
|
|
256 |
model_name = deprecated_urls[model_name] |
|
|
257 |
model_url = model_urls[model_name] |
|
|
258 |
# check if is url |
|
|
259 |
if model_url.startswith(('http://', 'https://')): |
|
|
260 |
checkpoint = load_url_dist(model_url) |
|
|
261 |
else: |
|
|
262 |
filename = osp.join(_get_mmcv_home(), model_url) |
|
|
263 |
if not osp.isfile(filename): |
|
|
264 |
raise IOError(f'{filename} is not a checkpoint file') |
|
|
265 |
checkpoint = torch.load(filename, map_location=map_location) |
|
|
266 |
elif filename.startswith('mmcls://'): |
|
|
267 |
model_urls = get_mmcls_models() |
|
|
268 |
model_name = filename[8:] |
|
|
269 |
checkpoint = load_url_dist(model_urls[model_name]) |
|
|
270 |
checkpoint = _process_mmcls_checkpoint(checkpoint) |
|
|
271 |
elif filename.startswith(('http://', 'https://')): |
|
|
272 |
checkpoint = load_url_dist(filename) |
|
|
273 |
elif filename.startswith('pavi://'): |
|
|
274 |
model_path = filename[7:] |
|
|
275 |
checkpoint = load_pavimodel_dist(model_path, map_location=map_location) |
|
|
276 |
elif filename.startswith('s3://'): |
|
|
277 |
checkpoint = load_fileclient_dist( |
|
|
278 |
filename, backend='ceph', map_location=map_location) |
|
|
279 |
else: |
|
|
280 |
if not osp.isfile(filename): |
|
|
281 |
raise IOError(f'{filename} is not a checkpoint file') |
|
|
282 |
checkpoint = torch.load(filename, map_location=map_location) |
|
|
283 |
return checkpoint |
|
|
284 |
|
|
|
285 |
|
|
|
286 |
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Before the weights are handed to :func:`load_state_dict` they are
    adapted to ``model``:

    * a leading ``module.`` is stripped from the keys;
    * for MoBY checkpoints only the ``encoder.`` (online) branch is kept;
    * ``absolute_pos_embed`` is reshaped from ``(N, L, C)`` to
      ``(N, C, H, W)`` when the sizes agree with the model;
    * ``relative_position_bias_table`` entries are resized with bicubic
      interpolation when their length differs from the model's;
    * weights with 3 input channels are widened by repeating the first
      channel when ``model.in_chans`` is larger than 3.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.
            If None, messages are printed instead.

    Returns:
        dict or OrderedDict: The loaded checkpoint. The adaptations above
        may have been written into the state dict it holds.

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.
    """

    def _warn(message):
        # ``logger`` defaults to None, so fall back to print instead of
        # failing with AttributeError while reporting a mismatch.
        if logger is not None:
            logger.warning(message)
        else:
            print(message)

    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    # get state_dict from checkpoint; the first matching field wins
    for field in ('state_dict', 'state_dict_ema', 'model'):
        if field in checkpoint:
            state_dict = checkpoint[field]
            break
    else:
        state_dict = checkpoint

    # strip prefix of state_dict (only from keys that really carry it)
    wrapper_prefix = 'module.'
    keys = list(state_dict.keys())
    if keys and keys[0].startswith(wrapper_prefix):
        state_dict = {
            (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k):
            v
            for k, v in state_dict.items()
        }

    # for MoBY, load model of online branch
    moby_prefix = 'encoder.'
    if state_dict and min(state_dict).startswith('encoder'):
        # Slice the prefix off instead of str.replace so that a later
        # 'encoder.' inside the key is left intact.
        state_dict = {
            k[len(moby_prefix):]: v
            for k, v in state_dict.items() if k.startswith(moby_prefix)
        }

    model_state = model.state_dict()

    # reshape absolute position embedding
    pretrained_ape = state_dict.get('absolute_pos_embed')
    if pretrained_ape is not None:
        model_ape = getattr(model, 'absolute_pos_embed', None)
        if model_ape is None:
            _warn('Error in loading absolute_pos_embed, pass')
        elif pretrained_ape.dim() == 3 and model_ape.dim() == 4:
            N1, L, C1 = pretrained_ape.size()
            N2, C2, H, W = model_ape.size()
            if N1 != N2 or C1 != C2 or L != H * W:
                _warn('Error in loading absolute_pos_embed, pass')
            else:
                state_dict['absolute_pos_embed'] = pretrained_ape.view(
                    N2, H, W, C2).permute(0, 3, 1, 2)
        # Any other layout is left untouched; load_state_dict reports a
        # size mismatch if there is one.

    # interpolate position bias table if needed
    table_keys = [
        k for k in state_dict if 'relative_position_bias_table' in k
    ]
    for table_key in table_keys:
        table_current = model_state.get(table_key)
        if table_current is None:
            # Not part of the model; load_state_dict reports it as an
            # unexpected key.
            continue
        table_pretrained = state_dict[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            _warn(f'Error in loading {table_key}, pass')
            continue
        if L1 == L2:
            continue
        S1 = int(round(L1**0.5))
        S2 = int(round(L2**0.5))
        if S1 * S1 != L1 or S2 * S2 != L2:
            # The table cannot be viewed as a square grid.
            _warn(f'Error in loading {table_key}, pass')
            continue
        table_resized = F.interpolate(
            table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
            size=(S2, S2),
            mode='bicubic')
        state_dict[table_key] = table_resized.view(nH2, L2).permute(1, 0)

    # widen 3-channel input weights for models with more input channels
    in_chans = getattr(model, 'in_chans', 3)
    extra_chans = in_chans - 3
    if extra_chans > 0:
        for k in list(state_dict.keys()):
            target = model_state.get(k)
            if target is None:
                continue
            weight = state_dict[k]
            if (weight.dim() > 1 and target.dim() > 1
                    and weight.shape[1] == 3
                    and target.shape[1] == in_chans):
                state_dict[k] = torch.cat(
                    [weight] + [weight[:, 0:1]] * extra_chans, 1)

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
|
|
366 |
|
|
|
367 |
|
|
|
368 |
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU, in the same key order. If
        ``state_dict`` has a ``_metadata`` attribute, the returned dict
        carries the same object.
    """
    state_dict_cpu = OrderedDict(
        (key, val.cpu()) for key, val in state_dict.items())
    # Keep the metadata written by ``get_state_dict``; ``load_state_dict``
    # reads it back from ``_metadata`` when loading.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        state_dict_cpu._metadata = metadata
    return state_dict_cpu
|
|
381 |
|
|
|
382 |
|
|
|
383 |
def _save_to_state_dict(module, destination, prefix, keep_vars): |
|
|
384 |
"""Saves module state to `destination` dictionary. |
|
|
385 |
|
|
|
386 |
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. |
|
|
387 |
|
|
|
388 |
Args: |
|
|
389 |
module (nn.Module): The module to generate state_dict. |
|
|
390 |
destination (dict): A dict where state will be stored. |
|
|
391 |
prefix (str): The prefix for parameters and buffers used in this |
|
|
392 |
module. |
|
|
393 |
""" |
|
|
394 |
for name, param in module._parameters.items(): |
|
|
395 |
if param is not None: |
|
|
396 |
destination[prefix + name] = param if keep_vars else param.detach() |
|
|
397 |
for name, buf in module._buffers.items(): |
|
|
398 |
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d |
|
|
399 |
if buf is not None: |
|
|
400 |
destination[prefix + name] = buf if keep_vars else buf.detach() |
|
|
401 |
|
|
|
402 |
|
|
|
403 |
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    # Record this module's version under its prefix without the trailing dot.
    module_meta = dict(version=module._version)
    destination._metadata[prefix[:-1]] = module_meta

    _save_to_state_dict(module, destination, prefix, keep_vars)
    for child_name, child in module._modules.items():
        if child is None:
            continue
        get_state_dict(
            child, destination, prefix + child_name + '.',
            keep_vars=keep_vars)

    # A hook may return a replacement for the destination dict.
    for hook in module._state_dict_hooks.values():
        replacement = hook(module, destination, prefix, module_meta)
        if replacement is not None:
            destination = replacement
    return destination
|
|
445 |
|
|
|
446 |
|
|
|
447 |
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.
    A ``filename`` starting with ``pavi://`` is uploaded to modelcloud;
    anything else is written to the local filesystem.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. A dict
            of optimizers is also accepted; each one's state is stored under
            its key.
        meta (dict, optional): Metadata to be saved in checkpoint. The dict
            is updated in place.

    Raises:
        TypeError: If ``meta`` is neither a dict nor None.
        ImportError: If a ``pavi://`` target is used without ``pavi``
            installed.
    """
    if meta is not None and not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    if meta is None:
        meta = {}
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    class_names = getattr(model, 'CLASSES', None)
    if class_names is not None:
        # save class name to the meta
        meta.update(CLASSES=class_names)

    checkpoint = dict(
        meta=meta, state_dict=weights_to_cpu(get_state_dict(model)))
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict()
            for name, optim in optimizer.items()
        }

    if not filename.startswith('pavi://'):
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
        return

    try:
        from pavi import modelcloud
        from pavi.exception import NodeNotFoundError
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    cloud_path = filename[len('pavi://'):]
    root = modelcloud.Folder()
    cloud_dir, cloud_name = osp.split(cloud_path)
    try:
        cloud_model = modelcloud.get(cloud_dir)
    except NodeNotFoundError:
        cloud_model = root.create_training_model(cloud_dir)
    # Serialize into a temporary file first, then upload that file.
    with TemporaryDirectory() as tmp_dir:
        local_file = osp.join(tmp_dir, cloud_name)
        with open(local_file, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
        cloud_model.create_file(local_file, name=cloud_name)