"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
8
import warnings
9
10
import torch
11
12
13
def _is_tensor_video_clip(clip):
14
    if not torch.is_tensor(clip):
15
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
16
17
    if not clip.ndimension() == 4:
18
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
19
20
    return True
21
22
23
def crop(clip, i, j, h, w):
    """
    Crop an (h, w) spatial window from a video clip at offset (i, j).

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): row offset of the window's upper-left corner.
        j (int): column offset of the window's upper-left corner.
        h (int): window height.
        w (int): window width.
    Returns:
        cropped clip (torch.tensor): Size is (C, T, h, w)
    """
    if clip.dim() != 4:
        raise ValueError("clip should be a 4D tensor")
    rows = slice(i, i + h)
    cols = slice(j, j + w)
    return clip[..., rows, cols]
def resize(clip, target_size, interpolation_mode):
    """
    Resize a video clip via torch.nn.functional.interpolate.

    Args:
        clip (torch.tensor): clip to resize; a 4D tensor whose last two
            dimensions are the spatial (H, W) axes.
        target_size (tuple(int, int)): output (height, width).
        interpolation_mode (str): mode passed to interpolate (e.g. "bilinear").
    Returns:
        resized clip (torch.tensor)
    """
    if len(target_size) != 2:
        raise ValueError(
            f"target size should be tuple (height, width), instead got {target_size}"
        )
    interpolate = torch.nn.functional.interpolate
    return interpolate(
        clip, size=target_size, mode=interpolation_mode, align_corners=False
    )
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Crop a spatial region out of a video clip, then resize it.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): row of the upper-left corner of the crop region.
        j (int): column of the upper-left corner of the crop region.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
        interpolation_mode (str): mode forwarded to ``resize``.
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    cropped = crop(clip, i, j, h, w)
    return resize(cropped, size, interpolation_mode)
def center_crop(clip, crop_size):
    """
    Crop the spatial center of a video clip.

    Args:
        clip (torch.tensor): Video clip to crop. Size is (C, T, H, W)
        crop_size (tuple(int, int)): (height, width) of the region to keep.
    Returns:
        cropped clip (torch.tensor): Size is (C, T, th, tw)
    Raises:
        ValueError: if the clip is smaller than the requested crop.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    th, tw = crop_size
    height, width = clip.size(-2), clip.size(-1)
    if height < th or width < tw:
        raise ValueError("height and width must be no smaller than crop_size")

    # Center the window; round splits any odd margin.
    top = int(round((height - th) / 2.0))
    left = int(round((width - tw) / 2.0))
    return crop(clip, top, left, th, tw)
def to_tensor(clip):
    """
    Convert a uint8 video tensor to float in [0, 1] and reorder dimensions.

    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    Raises:
        TypeError: if ``clip`` is not a tensor or not uint8.
        ValueError: if ``clip`` is not 4-dimensional.
    """
    # Validate: must be a 4D tensor (same checks as _is_tensor_video_clip).
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if clip.ndimension() != 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    if clip.dtype != torch.uint8:
        raise TypeError(
            "clip tensor should have data type uint8. Got %s" % str(clip.dtype)
        )
    # (T, H, W, C) -> (C, T, H, W), then rescale from [0, 255] to [0, 1].
    reordered = clip.permute(3, 0, 1, 2)
    return reordered.float() / 255.0
def normalize(clip, mean, std, inplace=False):
    """
    Normalize a video clip channel-wise: (clip - mean) / std.

    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
        inplace (bool): if True, modify ``clip`` itself instead of a clone.
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    # Validate: must be a 4D tensor (same checks as _is_tensor_video_clip).
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if clip.ndimension() != 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    target = clip if inplace else clip.clone()
    mean_t = torch.as_tensor(mean, dtype=target.dtype, device=target.device)
    std_t = torch.as_tensor(std, dtype=target.dtype, device=target.device)
    # Broadcast the (C,) statistics over the (C, T, H, W) clip.
    target.sub_(mean_t[:, None, None, None])
    target.div_(std_t[:, None, None, None])
    return target
def hflip(clip):
    """
    Horizontally flip a video clip (mirror along the width axis).

    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    # Validate: must be a 4D tensor (same checks as _is_tensor_video_clip).
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))
    if clip.ndimension() != 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    return clip.flip(-1)