Diff of /mmseg/utils/checkpoint.py [000000] .. [4e96d3]

Switch to unified view

a b/mmseg/utils/checkpoint.py
1
# Copyright (c) Open-MMLab. All rights reserved.
2
import io
3
import os
4
import os.path as osp
5
import pkgutil
6
import time
7
import warnings
8
from collections import OrderedDict
9
from importlib import import_module
10
from tempfile import TemporaryDirectory
11
12
import torch
13
import torchvision
14
from torch.optim import Optimizer
15
from torch.utils import model_zoo
16
from torch.nn import functional as F
17
18
import mmcv
19
from mmcv.fileio import FileClient
20
from mmcv.fileio import load as load_file
21
from mmcv.parallel import is_module_wrapper
22
from mmcv.utils import mkdir_or_exist
23
from mmcv.runner import get_dist_info
24
25
ENV_MMCV_HOME = 'MMCV_HOME'
26
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
27
DEFAULT_CACHE_DIR = '~/.cache'
28
29
30
def _get_mmcv_home():
31
    mmcv_home = os.path.expanduser(
32
        os.getenv(
33
            ENV_MMCV_HOME,
34
            os.path.join(
35
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
36
37
    mkdir_or_exist(mmcv_home)
38
    return mmcv_home
39
40
41
def load_state_dict(module, state_dict, strict=False, logger=None):
42
    """Load state_dict to a module.
43
44
    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
45
    Default value for ``strict`` is set to ``False`` and the message for
46
    param mismatch will be shown even if strict is False.
47
48
    Args:
49
        module (Module): Module that receives the state_dict.
50
        state_dict (OrderedDict): Weights.
51
        strict (bool): whether to strictly enforce that the keys
52
            in :attr:`state_dict` match the keys returned by this module's
53
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
54
        logger (:obj:`logging.Logger`, optional): Logger to log the error
55
            message. If not specified, print function will be used.
56
    """
57
    unexpected_keys = []
58
    all_missing_keys = []
59
    err_msg = []
60
61
    metadata = getattr(state_dict, '_metadata', None)
62
    state_dict = state_dict.copy()
63
    if metadata is not None:
64
        state_dict._metadata = metadata
65
66
    # use _load_from_state_dict to enable checkpoint version control
67
    def load(module, prefix=''):
68
        # recursively check parallel module in case that the model has a
69
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
70
        if is_module_wrapper(module):
71
            module = module.module
72
        local_metadata = {} if metadata is None else metadata.get(
73
            prefix[:-1], {})
74
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
75
                                     all_missing_keys, unexpected_keys,
76
                                     err_msg)
77
        for name, child in module._modules.items():
78
            if child is not None:
79
                load(child, prefix + name + '.')
80
81
    load(module)
82
    load = None  # break load->load reference cycle
83
84
    # ignore "num_batches_tracked" of BN layers
85
    missing_keys = [
86
        key for key in all_missing_keys if 'num_batches_tracked' not in key
87
    ]
88
89
    if unexpected_keys:
90
        err_msg.append('unexpected key in source '
91
                       f'state_dict: {", ".join(unexpected_keys)}\n')
92
    if missing_keys:
93
        err_msg.append(
94
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
95
96
    rank, _ = get_dist_info()
97
    if len(err_msg) > 0 and rank == 0:
98
        err_msg.insert(
99
            0, 'The model and loaded state dict do not match exactly\n')
100
        err_msg = '\n'.join(err_msg)
101
        if strict:
102
            raise RuntimeError(err_msg)
103
        elif logger is not None:
104
            logger.warning(err_msg)
105
        else:
106
            print(err_msg)
107
108
109
def load_url_dist(url, model_dir=None):
110
    """In distributed setting, this function only download checkpoint at local
111
    rank 0."""
112
    rank, world_size = get_dist_info()
113
    rank = int(os.environ.get('LOCAL_RANK', rank))
114
    if rank == 0:
115
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
116
    if world_size > 1:
117
        torch.distributed.barrier()
118
        if rank > 0:
119
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
120
    return checkpoint
121
122
123
def load_pavimodel_dist(model_path, map_location=None):
124
    """In distributed setting, this function only download checkpoint at local
125
    rank 0."""
126
    try:
127
        from pavi import modelcloud
128
    except ImportError:
129
        raise ImportError(
130
            'Please install pavi to load checkpoint from modelcloud.')
131
    rank, world_size = get_dist_info()
132
    rank = int(os.environ.get('LOCAL_RANK', rank))
133
    if rank == 0:
134
        model = modelcloud.get(model_path)
135
        with TemporaryDirectory() as tmp_dir:
136
            downloaded_file = osp.join(tmp_dir, model.name)
137
            model.download(downloaded_file)
138
            checkpoint = torch.load(downloaded_file, map_location=map_location)
139
    if world_size > 1:
140
        torch.distributed.barrier()
141
        if rank > 0:
142
            model = modelcloud.get(model_path)
143
            with TemporaryDirectory() as tmp_dir:
144
                downloaded_file = osp.join(tmp_dir, model.name)
145
                model.download(downloaded_file)
146
                checkpoint = torch.load(
147
                    downloaded_file, map_location=map_location)
148
    return checkpoint
149
150
151
def load_fileclient_dist(filename, backend, map_location):
152
    """In distributed setting, this function only download checkpoint at local
153
    rank 0."""
154
    rank, world_size = get_dist_info()
155
    rank = int(os.environ.get('LOCAL_RANK', rank))
156
    allowed_backends = ['ceph']
157
    if backend not in allowed_backends:
158
        raise ValueError(f'Load from Backend {backend} is not supported.')
159
    if rank == 0:
160
        fileclient = FileClient(backend=backend)
161
        buffer = io.BytesIO(fileclient.get(filename))
162
        checkpoint = torch.load(buffer, map_location=map_location)
163
    if world_size > 1:
164
        torch.distributed.barrier()
165
        if rank > 0:
166
            fileclient = FileClient(backend=backend)
167
            buffer = io.BytesIO(fileclient.get(filename))
168
            checkpoint = torch.load(buffer, map_location=map_location)
169
    return checkpoint
170
171
172
def get_torchvision_models():
173
    model_urls = dict()
174
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
175
        if ispkg:
176
            continue
177
        _zoo = import_module(f'torchvision.models.{name}')
178
        if hasattr(_zoo, 'model_urls'):
179
            _urls = getattr(_zoo, 'model_urls')
180
            model_urls.update(_urls)
181
    return model_urls
182
183
184
def get_external_models():
185
    mmcv_home = _get_mmcv_home()
186
    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
187
    default_urls = load_file(default_json_path)
188
    assert isinstance(default_urls, dict)
189
    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
190
    if osp.exists(external_json_path):
191
        external_urls = load_file(external_json_path)
192
        assert isinstance(external_urls, dict)
193
        default_urls.update(external_urls)
194
195
    return default_urls
196
197
198
def get_mmcls_models():
199
    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
200
    mmcls_urls = load_file(mmcls_json_path)
201
202
    return mmcls_urls
203
204
205
def get_deprecated_model_names():
206
    deprecate_json_path = osp.join(mmcv.__path__[0],
207
                                   'model_zoo/deprecated.json')
208
    deprecate_urls = load_file(deprecate_json_path)
209
    assert isinstance(deprecate_urls, dict)
210
211
    return deprecate_urls
212
213
214
def _process_mmcls_checkpoint(checkpoint):
215
    state_dict = checkpoint['state_dict']
216
    new_state_dict = OrderedDict()
217
    for k, v in state_dict.items():
218
        if k.startswith('backbone.'):
219
            new_state_dict[k[9:]] = v
220
    new_checkpoint = dict(state_dict=new_state_dict)
221
222
    return new_checkpoint
223
224
225
def _load_checkpoint(filename, map_location=None):
226
    """Load checkpoint from somewhere (modelzoo, file, url).
227
228
    Args:
229
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
230
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
231
            details.
232
        map_location (str | None): Same as :func:`torch.load`. Default: None.
233
234
    Returns:
235
        dict | OrderedDict: The loaded checkpoint. It can be either an
236
            OrderedDict storing model weights or a dict containing other
237
            information, which depends on the checkpoint.
238
    """
239
    if filename.startswith('modelzoo://'):
240
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
241
                      'use "torchvision://" instead')
242
        model_urls = get_torchvision_models()
243
        model_name = filename[11:]
244
        checkpoint = load_url_dist(model_urls[model_name])
245
    elif filename.startswith('torchvision://'):
246
        model_urls = get_torchvision_models()
247
        model_name = filename[14:]
248
        checkpoint = load_url_dist(model_urls[model_name])
249
    elif filename.startswith('open-mmlab://'):
250
        model_urls = get_external_models()
251
        model_name = filename[13:]
252
        deprecated_urls = get_deprecated_model_names()
253
        if model_name in deprecated_urls:
254
            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
255
                          f'of open-mmlab://{deprecated_urls[model_name]}')
256
            model_name = deprecated_urls[model_name]
257
        model_url = model_urls[model_name]
258
        # check if is url
259
        if model_url.startswith(('http://', 'https://')):
260
            checkpoint = load_url_dist(model_url)
261
        else:
262
            filename = osp.join(_get_mmcv_home(), model_url)
263
            if not osp.isfile(filename):
264
                raise IOError(f'{filename} is not a checkpoint file')
265
            checkpoint = torch.load(filename, map_location=map_location)
266
    elif filename.startswith('mmcls://'):
267
        model_urls = get_mmcls_models()
268
        model_name = filename[8:]
269
        checkpoint = load_url_dist(model_urls[model_name])
270
        checkpoint = _process_mmcls_checkpoint(checkpoint)
271
    elif filename.startswith(('http://', 'https://')):
272
        checkpoint = load_url_dist(filename)
273
    elif filename.startswith('pavi://'):
274
        model_path = filename[7:]
275
        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
276
    elif filename.startswith('s3://'):
277
        checkpoint = load_fileclient_dist(
278
            filename, backend='ceph', map_location=map_location)
279
    else:
280
        if not osp.isfile(filename):
281
            raise IOError(f'{filename} is not a checkpoint file')
282
        checkpoint = torch.load(filename, map_location=map_location)
283
    return checkpoint
284
285
286
def load_checkpoint(model,
287
                    filename,
288
                    map_location='cpu',
289
                    strict=False,
290
                    logger=None):
291
    """Load checkpoint from a file or URI.
292
293
    Args:
294
        model (Module): Module to load checkpoint.
295
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
296
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
297
            details.
298
        map_location (str): Same as :func:`torch.load`.
299
        strict (bool): Whether to allow different params for the model and
300
            checkpoint.
301
        logger (:mod:`logging.Logger` or None): The logger for error message.
302
303
    Returns:
304
        dict or OrderedDict: The loaded checkpoint.
305
    """
306
    checkpoint = _load_checkpoint(filename, map_location)
307
    # OrderedDict is a subclass of dict
308
    if not isinstance(checkpoint, dict):
309
        raise RuntimeError(
310
            f'No state_dict found in checkpoint file {filename}')
311
    # get state_dict from checkpoint
312
    if 'state_dict' in checkpoint:
313
        state_dict = checkpoint['state_dict']
314
    elif 'state_dict_ema' in checkpoint:
315
        state_dict = checkpoint['state_dict_ema']
316
    elif 'model' in checkpoint:
317
        state_dict = checkpoint['model']
318
    else:
319
        state_dict = checkpoint
320
    # strip prefix of state_dict
321
    if list(state_dict.keys())[0].startswith('module.'):
322
        state_dict = {k[7:]: v for k, v in state_dict.items()}
323
324
    # for MoBY, load model of online branch
325
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
326
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
327
328
    # reshape absolute position embedding
329
    if state_dict.get('absolute_pos_embed') is not None:
330
        absolute_pos_embed = state_dict['absolute_pos_embed']
331
        N1, L, C1 = absolute_pos_embed.size()
332
        N2, C2, H, W = model.absolute_pos_embed.size()
333
        if N1 != N2 or C1 != C2 or L != H*W:
334
            logger.warning("Error in loading absolute_pos_embed, pass")
335
        else:
336
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
337
338
    # interpolate position bias table if needed
339
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
340
    for table_key in relative_position_bias_table_keys:
341
        table_pretrained = state_dict[table_key]
342
        table_current = model.state_dict()[table_key]
343
        L1, nH1 = table_pretrained.size()
344
        L2, nH2 = table_current.size()
345
        if nH1 != nH2:
346
            logger.warning(f"Error in loading {table_key}, pass")
347
        else:
348
            if L1 != L2:
349
                S1 = int(L1 ** 0.5)
350
                S2 = int(L2 ** 0.5)
351
                table_pretrained_resized = F.interpolate(
352
                     table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
353
                     size=(S2, S2), mode='bicubic')
354
                state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
355
356
    # load state_dict
357
    if model.in_chans != 3:
358
        cur_state_dict = model.state_dict()
359
        for k in state_dict.keys():
360
            if k not in cur_state_dict:
361
                continue
362
            if len(state_dict[k].shape) > 1 and state_dict[k].shape[1] == 3 and cur_state_dict[k].shape[1] == model.in_chans:
363
                state_dict[k] = torch.cat([state_dict[k], *[state_dict[k][:,0:1] for _ in range(model.in_chans - 3)]], 1)
364
    load_state_dict(model, state_dict, strict, logger)
365
    return checkpoint
366
367
368
def weights_to_cpu(state_dict):
369
    """Copy a model state_dict to cpu.
370
371
    Args:
372
        state_dict (OrderedDict): Model weights on GPU.
373
374
    Returns:
375
        OrderedDict: Model weights on GPU.
376
    """
377
    state_dict_cpu = OrderedDict()
378
    for key, val in state_dict.items():
379
        state_dict_cpu[key] = val.cpu()
380
    return state_dict_cpu
381
382
383
def _save_to_state_dict(module, destination, prefix, keep_vars):
384
    """Saves module state to `destination` dictionary.
385
386
    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
387
388
    Args:
389
        module (nn.Module): The module to generate state_dict.
390
        destination (dict): A dict where state will be stored.
391
        prefix (str): The prefix for parameters and buffers used in this
392
            module.
393
    """
394
    for name, param in module._parameters.items():
395
        if param is not None:
396
            destination[prefix + name] = param if keep_vars else param.detach()
397
    for name, buf in module._buffers.items():
398
        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
399
        if buf is not None:
400
            destination[prefix + name] = buf if keep_vars else buf.detach()
401
402
403
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
404
    """Returns a dictionary containing a whole state of the module.
405
406
    Both parameters and persistent buffers (e.g. running averages) are
407
    included. Keys are corresponding parameter and buffer names.
408
409
    This method is modified from :meth:`torch.nn.Module.state_dict` to
410
    recursively check parallel module in case that the model has a complicated
411
    structure, e.g., nn.Module(nn.Module(DDP)).
412
413
    Args:
414
        module (nn.Module): The module to generate state_dict.
415
        destination (OrderedDict): Returned dict for the state of the
416
            module.
417
        prefix (str): Prefix of the key.
418
        keep_vars (bool): Whether to keep the variable property of the
419
            parameters. Default: False.
420
421
    Returns:
422
        dict: A dictionary containing a whole state of the module.
423
    """
424
    # recursively check parallel module in case that the model has a
425
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
426
    if is_module_wrapper(module):
427
        module = module.module
428
429
    # below is the same as torch.nn.Module.state_dict()
430
    if destination is None:
431
        destination = OrderedDict()
432
        destination._metadata = OrderedDict()
433
    destination._metadata[prefix[:-1]] = local_metadata = dict(
434
        version=module._version)
435
    _save_to_state_dict(module, destination, prefix, keep_vars)
436
    for name, child in module._modules.items():
437
        if child is not None:
438
            get_state_dict(
439
                child, destination, prefix + name + '.', keep_vars=keep_vars)
440
    for hook in module._state_dict_hooks.values():
441
        hook_result = hook(module, destination, prefix, local_metadata)
442
        if hook_result is not None:
443
            destination = hook_result
444
    return destination
445
446
447
def save_checkpoint(model, filename, optimizer=None, meta=None):
448
    """Save checkpoint to file.
449
450
    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
451
    ``optimizer``. By default ``meta`` will contain version and time info.
452
453
    Args:
454
        model (Module): Module whose params are to be saved.
455
        filename (str): Checkpoint filename.
456
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
457
        meta (dict, optional): Metadata to be saved in checkpoint.
458
    """
459
    if meta is None:
460
        meta = {}
461
    elif not isinstance(meta, dict):
462
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
463
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
464
465
    if is_module_wrapper(model):
466
        model = model.module
467
468
    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
469
        # save class name to the meta
470
        meta.update(CLASSES=model.CLASSES)
471
472
    checkpoint = {
473
        'meta': meta,
474
        'state_dict': weights_to_cpu(get_state_dict(model))
475
    }
476
    # save optimizer state dict in the checkpoint
477
    if isinstance(optimizer, Optimizer):
478
        checkpoint['optimizer'] = optimizer.state_dict()
479
    elif isinstance(optimizer, dict):
480
        checkpoint['optimizer'] = {}
481
        for name, optim in optimizer.items():
482
            checkpoint['optimizer'][name] = optim.state_dict()
483
484
    if filename.startswith('pavi://'):
485
        try:
486
            from pavi import modelcloud
487
            from pavi.exception import NodeNotFoundError
488
        except ImportError:
489
            raise ImportError(
490
                'Please install pavi to load checkpoint from modelcloud.')
491
        model_path = filename[7:]
492
        root = modelcloud.Folder()
493
        model_dir, model_name = osp.split(model_path)
494
        try:
495
            model = modelcloud.get(model_dir)
496
        except NodeNotFoundError:
497
            model = root.create_training_model(model_dir)
498
        with TemporaryDirectory() as tmp_dir:
499
            checkpoint_file = osp.join(tmp_dir, model_name)
500
            with open(checkpoint_file, 'wb') as f:
501
                torch.save(checkpoint, f)
502
                f.flush()
503
            model.create_file(checkpoint_file, name=model_name)
504
    else:
505
        mmcv.mkdir_or_exist(osp.dirname(filename))
506
        # immediately flush buffer
507
        with open(filename, 'wb') as f:
508
            torch.save(checkpoint, f)
509
            f.flush()