Diff of /echonet/utils/__init__.py [000000] .. [aeb6cc]

Switch to unified view

a b/echonet/utils/__init__.py
1
"""Utility functions for videos, plotting and computing performance metrics."""
2
3
import os
4
import typing
5
6
import cv2  # pytype: disable=attribute-error
7
import matplotlib
8
import numpy as np
9
import torch
10
import tqdm
11
12
from . import video
13
from . import segmentation
14
15
16
def loadvideo(filename: str) -> np.ndarray:
17
    """Loads a video from a file.
18
19
    Args:
20
        filename (str): filename of video
21
22
    Returns:
23
        A np.ndarray with dimensions (channels=3, frames, height, width). The
24
        values will be uint8's ranging from 0 to 255.
25
26
    Raises:
27
        FileNotFoundError: Could not find `filename`
28
        ValueError: An error occurred while reading the video
29
    """
30
31
    if not os.path.exists(filename):
32
        raise FileNotFoundError(filename)
33
    capture = cv2.VideoCapture(filename)
34
35
    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
36
    frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
37
    frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
38
39
    v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)
40
41
    for count in range(frame_count):
42
        ret, frame = capture.read()
43
        if not ret:
44
            raise ValueError("Failed to load frame #{} of {}.".format(count, filename))
45
46
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
47
        v[count, :, :] = frame
48
49
    v = v.transpose((3, 0, 1, 2))
50
51
    return v
52
53
54
def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1):
55
    """Saves a video to a file.
56
57
    Args:
58
        filename (str): filename of video
59
        array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width)
60
        fps (float or int): frames per second
61
62
    Returns:
63
        None
64
    """
65
66
    c, _, height, width = array.shape
67
68
    if c != 3:
69
        raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape))))
70
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
71
    out = cv2.VideoWriter(filename, fourcc, fps, (width, height))
72
73
    for frame in array.transpose((1, 2, 3, 0)):
74
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
75
        out.write(frame)
76
77
78
def get_mean_and_std(dataset: torch.utils.data.Dataset,
79
                     samples: int = 128,
80
                     batch_size: int = 8,
81
                     num_workers: int = 4):
82
    """Computes mean and std from samples from a Pytorch dataset.
83
84
    Args:
85
        dataset (torch.utils.data.Dataset): A Pytorch dataset.
86
            ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
87
            should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width)
88
        samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and
89
            standard deviation are computed over all elements.
90
            Defaults to 128.
91
        batch_size (int, optional): how many samples per batch to load
92
            Defaults to 8.
93
        num_workers (int, optional): how many subprocesses to use for data
94
            loading. If 0, the data will be loaded in the main process.
95
            Defaults to 4.
96
97
    Returns:
98
       A tuple of the mean and standard deviation. Both are represented as np.array's of dimension (channels,).
99
    """
100
101
    if samples is not None and len(dataset) > samples:
102
        indices = np.random.choice(len(dataset), samples, replace=False)
103
        dataset = torch.utils.data.Subset(dataset, indices)
104
    dataloader = torch.utils.data.DataLoader(
105
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
106
107
    n = 0  # number of elements taken (should be equal to samples by end of for loop)
108
    s1 = 0.  # sum of elements along channels (ends up as np.array of dimension (channels,))
109
    s2 = 0.  # sum of squares of elements along channels (ends up as np.array of dimension (channels,))
110
    for (x, *_) in tqdm.tqdm(dataloader):
111
        x = x.transpose(0, 1).contiguous().view(3, -1)
112
        n += x.shape[1]
113
        s1 += torch.sum(x, dim=1).numpy()
114
        s2 += torch.sum(x ** 2, dim=1).numpy()
115
    mean = s1 / n  # type: np.ndarray
116
    std = np.sqrt(s2 / n - mean ** 2)  # type: np.ndarray
117
118
    mean = mean.astype(np.float32)
119
    std = std.astype(np.float32)
120
121
    return mean, std
122
123
124
def bootstrap(a, b, func, samples=10000):
125
    """Computes a bootstrapped confidence intervals for ``func(a, b)''.
126
127
    Args:
128
        a (array_like): first argument to `func`.
129
        b (array_like): second argument to `func`.
130
        func (callable): Function to compute confidence intervals for.
131
            ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
132
            should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width)
133
        samples (int, optional): Number of samples to compute.
134
            Defaults to 10000.
135
136
    Returns:
137
       A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile).
138
    """
139
    a = np.array(a)
140
    b = np.array(b)
141
142
    bootstraps = []
143
    for _ in range(samples):
144
        ind = np.random.choice(len(a), len(a))
145
        bootstraps.append(func(a[ind], b[ind]))
146
    bootstraps = sorted(bootstraps)
147
148
    return func(a, b), bootstraps[round(0.05 * len(bootstraps))], bootstraps[round(0.95 * len(bootstraps))]
149
150
151
def latexify():
152
    """Sets matplotlib params to appear more like LaTeX.
153
154
    Based on https://nipunbatra.github.io/blog/2014/latexify.html
155
    """
156
    params = {'backend': 'pdf',
157
              'axes.titlesize': 8,
158
              'axes.labelsize': 8,
159
              'font.size': 8,
160
              'legend.fontsize': 8,
161
              'xtick.labelsize': 8,
162
              'ytick.labelsize': 8,
163
              'font.family': 'DejaVu Serif',
164
              'font.serif': 'Computer Modern',
165
              }
166
    matplotlib.rcParams.update(params)
167
168
169
def dice_similarity_coefficient(inter, union):
170
    """Computes the dice similarity coefficient.
171
172
    Args:
173
        inter (iterable): iterable of the intersections
174
        union (iterable): iterable of the unions
175
    """
176
    return 2 * sum(inter) / (sum(union) + sum(inter))
177
178
179
__all__ = ["video", "segmentation", "loadvideo", "savevideo", "get_mean_and_std", "bootstrap", "latexify", "dice_similarity_coefficient"]