Diff of /utils/image_list.py [000000] .. [98e649]

Switch to unified view

a b/utils/image_list.py
1
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
from __future__ import division
3
4
import torch
5
import numpy as np
6
7
class ImageList(object):
8
    """
9
    Structure that holds a list of images (of possibly
10
    varying sizes) as a single tensor.
11
    This works by padding the images to the same size,
12
    and storing in a field the original sizes of each image
13
    """
14
15
    def __init__(self, tensors, image_sizes):
16
        """
17
        Arguments:
18
            tensors (tensor)
19
            image_sizes (list[tuple[int, int]])
20
        """
21
        self.tensors = tensors
22
        self.image_sizes = image_sizes
23
24
    def to(self, *args, **kwargs):
25
        cast_tensor = self.tensors.to(*args, **kwargs)
26
        return ImageList(cast_tensor, self.image_sizes)
27
28
29
def to_image_list(tensors, size_divisible=0, return_size=False):
30
    """
31
    tensors can be an ImageList, a torch.Tensor or
32
    an iterable of Tensors. It can't be a numpy array.
33
    When tensors is an iterable of Tensors, it pads
34
    the Tensors with zeros so that they have the same
35
    shape
36
    """
37
    if isinstance(tensors, torch.Tensor) and size_divisible > 0:
38
        tensors = [tensors]
39
40
    if isinstance(tensors, ImageList):
41
        return tensors
42
    elif isinstance(tensors, torch.Tensor):
43
        # single tensor shape can be inferred
44
        if tensors.dim() == 3:
45
            tensors = tensors[None]
46
        assert tensors.dim() == 4
47
        image_sizes = [tensor.shape[-2:] for tensor in tensors]
48
        return ImageList(tensors, image_sizes)
49
    elif isinstance(tensors, (tuple, list, np.ndarray)):
50
        max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
51
52
        # TODO Ideally, just remove this and let me model handle arbitrary
53
        # input sizs
54
        if size_divisible > 0:
55
            import math
56
57
            stride = size_divisible
58
            max_size = list(max_size)
59
            max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
60
            max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
61
            max_size = tuple(max_size)
62
63
        batch_shape = (len(tensors),) + max_size
64
        batched_imgs = tensors[0].new(*batch_shape).zero_()
65
        for img, pad_img in zip(tensors, batched_imgs):
66
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
67
68
        image_sizes = [im.shape[-2:] for im in tensors]
69
70
        # return ImageList(batched_imgs, image_sizes)
71
        if return_size:
72
            return batched_imgs, image_sizes
73
        else:
74
            return batched_imgs
75
    else:
76
        raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))
77