|
a |
|
b/utils/image_list.py |
|
|
1 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. |
|
|
2 |
from __future__ import division |
|
|
3 |
|
|
|
4 |
import torch |
|
|
5 |
import numpy as np |
|
|
6 |
|
|
|
7 |
class ImageList(object): |
|
|
8 |
""" |
|
|
9 |
Structure that holds a list of images (of possibly |
|
|
10 |
varying sizes) as a single tensor. |
|
|
11 |
This works by padding the images to the same size, |
|
|
12 |
and storing in a field the original sizes of each image |
|
|
13 |
""" |
|
|
14 |
|
|
|
15 |
def __init__(self, tensors, image_sizes): |
|
|
16 |
""" |
|
|
17 |
Arguments: |
|
|
18 |
tensors (tensor) |
|
|
19 |
image_sizes (list[tuple[int, int]]) |
|
|
20 |
""" |
|
|
21 |
self.tensors = tensors |
|
|
22 |
self.image_sizes = image_sizes |
|
|
23 |
|
|
|
24 |
def to(self, *args, **kwargs): |
|
|
25 |
cast_tensor = self.tensors.to(*args, **kwargs) |
|
|
26 |
return ImageList(cast_tensor, self.image_sizes) |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
def to_image_list(tensors, size_divisible=0, return_size=False): |
|
|
30 |
""" |
|
|
31 |
tensors can be an ImageList, a torch.Tensor or |
|
|
32 |
an iterable of Tensors. It can't be a numpy array. |
|
|
33 |
When tensors is an iterable of Tensors, it pads |
|
|
34 |
the Tensors with zeros so that they have the same |
|
|
35 |
shape |
|
|
36 |
""" |
|
|
37 |
if isinstance(tensors, torch.Tensor) and size_divisible > 0: |
|
|
38 |
tensors = [tensors] |
|
|
39 |
|
|
|
40 |
if isinstance(tensors, ImageList): |
|
|
41 |
return tensors |
|
|
42 |
elif isinstance(tensors, torch.Tensor): |
|
|
43 |
# single tensor shape can be inferred |
|
|
44 |
if tensors.dim() == 3: |
|
|
45 |
tensors = tensors[None] |
|
|
46 |
assert tensors.dim() == 4 |
|
|
47 |
image_sizes = [tensor.shape[-2:] for tensor in tensors] |
|
|
48 |
return ImageList(tensors, image_sizes) |
|
|
49 |
elif isinstance(tensors, (tuple, list, np.ndarray)): |
|
|
50 |
max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) |
|
|
51 |
|
|
|
52 |
# TODO Ideally, just remove this and let me model handle arbitrary |
|
|
53 |
# input sizs |
|
|
54 |
if size_divisible > 0: |
|
|
55 |
import math |
|
|
56 |
|
|
|
57 |
stride = size_divisible |
|
|
58 |
max_size = list(max_size) |
|
|
59 |
max_size[1] = int(math.ceil(max_size[1] / stride) * stride) |
|
|
60 |
max_size[2] = int(math.ceil(max_size[2] / stride) * stride) |
|
|
61 |
max_size = tuple(max_size) |
|
|
62 |
|
|
|
63 |
batch_shape = (len(tensors),) + max_size |
|
|
64 |
batched_imgs = tensors[0].new(*batch_shape).zero_() |
|
|
65 |
for img, pad_img in zip(tensors, batched_imgs): |
|
|
66 |
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) |
|
|
67 |
|
|
|
68 |
image_sizes = [im.shape[-2:] for im in tensors] |
|
|
69 |
|
|
|
70 |
# return ImageList(batched_imgs, image_sizes) |
|
|
71 |
if return_size: |
|
|
72 |
return batched_imgs, image_sizes |
|
|
73 |
else: |
|
|
74 |
return batched_imgs |
|
|
75 |
else: |
|
|
76 |
raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) |
|
|
77 |
|