|
a |
|
b/echonet/utils/__init__.py |
|
|
1 |
"""Utility functions for videos, plotting and computing performance metrics.""" |
|
|
2 |
|
|
|
3 |
import os |
|
|
4 |
import typing |
|
|
5 |
|
|
|
6 |
import cv2 # pytype: disable=attribute-error |
|
|
7 |
import matplotlib |
|
|
8 |
import numpy as np |
|
|
9 |
import torch |
|
|
10 |
import tqdm |
|
|
11 |
|
|
|
12 |
from . import video |
|
|
13 |
from . import segmentation |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def loadvideo(filename: str) -> np.ndarray:
    """Loads a video from a file.

    Every frame reported by the container is decoded with OpenCV and
    converted from BGR to RGB. The decoder handle is released before
    returning, including when decoding fails part-way through.

    Args:
        filename (str): filename of video

    Returns:
        A np.ndarray with dimensions (channels=3, frames, height, width). The
        values will be uint8's ranging from 0 to 255.

    Raises:
        FileNotFoundError: Could not find `filename`
        ValueError: An error occurred while reading the video
    """

    if not os.path.exists(filename):
        raise FileNotFoundError(filename)
    capture = cv2.VideoCapture(filename)

    try:
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Frames are collected as (frames, height, width, channels) and
        # reordered to channels-first once decoding is complete.
        v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)

        for count in range(frame_count):
            ret, frame = capture.read()
            if not ret:
                raise ValueError("Failed to load frame #{} of {}.".format(count, filename))

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            v[count, :, :] = frame
    finally:
        # Free the underlying decoder/file handle on both the success and
        # the error path instead of leaving it to garbage collection.
        capture.release()

    v = v.transpose((3, 0, 1, 2))

    return v
|
|
52 |
|
|
|
53 |
|
|
|
54 |
def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1):
    """Saves a video to a file.

    Frames are written with the MJPG codec. The writer is released before
    returning (and when writing a frame fails) so the output file is
    finalized deterministically.

    Args:
        filename (str): filename of video
        array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width)
        fps (float or int): frames per second

    Returns:
        None

    Raises:
        ValueError: `array` does not have 3 channels in its first dimension
    """

    c, _, height, width = array.shape

    if c != 3:
        raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape))))
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(filename, fourcc, fps, (width, height))

    try:
        # Iterate frame by frame in (height, width, channels) layout.
        for frame in array.transpose((1, 2, 3, 0)):
            # The array holds RGB frames; convert to BGR before writing.
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            out.write(frame)
    finally:
        # Flush and close the output file explicitly rather than relying on
        # the writer object being garbage collected.
        out.release()
|
|
76 |
|
|
|
77 |
|
|
|
78 |
def get_mean_and_std(dataset: torch.utils.data.Dataset,
                     samples: typing.Optional[int] = 128,
                     batch_size: int = 8,
                     num_workers: int = 4):
    """Computes mean and std from samples from a Pytorch dataset.

    Args:
        dataset (torch.utils.data.Dataset): A Pytorch dataset.
            ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
            should be a ``torch.Tensor'' of dimensions (channels, frames, height, width)
        samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and
            standard deviation are computed over all elements.
            Defaults to 128.
        batch_size (int, optional): how many samples per batch to load
            Defaults to 8.
        num_workers (int, optional): how many subprocesses to use for data
            loading. If 0, the data will be loaded in the main process.
            Defaults to 4.

    Returns:
        A tuple of the mean and standard deviation. Both are represented as np.array's of dimension (channels,)
        and dtype float32.

    Raises:
        ValueError: No elements were loaded from `dataset`.
    """

    if samples is not None and len(dataset) > samples:
        # Random subset drawn without replacement.
        indices = np.random.choice(len(dataset), samples, replace=False)
        dataset = torch.utils.data.Subset(dataset, indices)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)

    n = 0   # number of elements accumulated per channel
    s1 = 0.  # sum of elements along channels (ends up as np.array of dimension (channels,))
    s2 = 0.  # sum of squares of elements along channels (ends up as np.array of dimension (channels,))
    for (x, *_) in tqdm.tqdm(dataloader):
        # Move channels to the front and flatten everything else. The channel
        # count is taken from the tensor instead of being hard-coded.
        x = x.transpose(0, 1).contiguous()
        # Accumulate in float64 so the running sums neither lose precision
        # nor wrap around when the input has an integer dtype.
        x = x.view(x.shape[0], -1).double()
        n += x.shape[1]
        s1 += torch.sum(x, dim=1).numpy()
        s2 += torch.sum(x ** 2, dim=1).numpy()

    if n == 0:
        raise ValueError("Cannot compute mean and std: no elements were loaded from the dataset.")

    mean = s1 / n  # type: np.ndarray
    # E[x^2] - E[x]^2 can come out marginally negative through rounding;
    # clamp at zero so the square root never produces NaN.
    std = np.sqrt(np.maximum(s2 / n - mean ** 2, 0))  # type: np.ndarray

    mean = mean.astype(np.float32)
    std = std.astype(np.float32)

    return mean, std
|
|
122 |
|
|
|
123 |
|
|
|
124 |
def bootstrap(a, b, func, samples=10000):
    """Computes a bootstrapped confidence intervals for ``func(a, b)''.

    `a` and `b` are resampled together (same indices, with replacement)
    `samples` times, and `func` is evaluated on every resample.

    Args:
        a (array_like): first argument to `func`.
        b (array_like): second argument to `func`.
        func (callable): Function to compute confidence intervals for.
            It is called as ``func(a, b)'' and its return values must be
            sortable.
        samples (int, optional): Number of samples to compute.
            Defaults to 10000.

    Returns:
        A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile).

    Raises:
        ValueError: `samples` is smaller than 1.
    """
    if samples < 1:
        raise ValueError("samples must be a positive integer, got {}".format(samples))

    a = np.array(a)
    b = np.array(b)

    bootstraps = []
    for _ in range(samples):
        # Same indices for both arrays so that pairs stay aligned.
        ind = np.random.choice(len(a), len(a))
        bootstraps.append(func(a[ind], b[ind]))
    bootstraps = sorted(bootstraps)

    # round(0.95 * len) can equal len for small sample counts (e.g. 10);
    # clamp so the lookup always stays inside the list.
    last = len(bootstraps) - 1
    lower = min(round(0.05 * len(bootstraps)), last)
    upper = min(round(0.95 * len(bootstraps)), last)

    return func(a, b), bootstraps[lower], bootstraps[upper]
|
|
149 |
|
|
|
150 |
|
|
|
151 |
def latexify():
    """Sets matplotlib params to appear more like LaTeX.

    Updates ``matplotlib.rcParams`` in place: selects the ``pdf`` backend,
    sets the title, label, tick, legend and base font sizes to 8, and
    configures the serif font settings.

    Based on https://nipunbatra.github.io/blog/2014/latexify.html
    """
    # All of these rcParams share the same point size.
    font_size_keys = (
        'axes.titlesize',
        'axes.labelsize',
        'font.size',
        'legend.fontsize',
        'xtick.labelsize',
        'ytick.labelsize',
    )

    settings = {'backend': 'pdf'}
    settings.update(dict.fromkeys(font_size_keys, 8))
    settings['font.family'] = 'DejaVu Serif'
    settings['font.serif'] = 'Computer Modern'

    matplotlib.rcParams.update(settings)
|
|
167 |
|
|
|
168 |
|
|
|
169 |
def dice_similarity_coefficient(inter, union):
    """Computes the dice similarity coefficient.

    The coefficient is ``2 * I / (U + I)'' where ``I'' is the total of
    `inter` and ``U'' is the total of `union`.

    Args:
        inter (iterable): iterable of the intersections
        union (iterable): iterable of the unions

    Returns:
        The dice similarity coefficient over all elements.
    """
    # Sum the intersections exactly once: a one-shot iterator (e.g. a
    # generator) would be exhausted by a second pass and contribute 0.
    total_inter = sum(inter)
    return 2 * total_inter / (sum(union) + total_inter)
|
|
177 |
|
|
|
178 |
|
|
|
179 |
# Public API of the package: the two submodules followed by the helper
# functions defined in this file.
__all__ = [
    "video",
    "segmentation",
    "loadvideo",
    "savevideo",
    "get_mean_and_std",
    "bootstrap",
    "latexify",
    "dice_similarity_coefficient",
]