|
a |
|
b/scripts/plot_complexity.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
|
|
|
3 |
"""Code to generate plots for Extended Data Fig. 4.""" |
|
|
4 |
|
|
|
5 |
import os |
|
|
6 |
|
|
|
7 |
import matplotlib |
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
import numpy as np |
|
|
10 |
|
|
|
11 |
import echonet |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def main(root=os.path.join("timing", "video"), |
|
|
15 |
fig_root=os.path.join("figure", "complexity"), |
|
|
16 |
FRAMES=(1, 8, 16, 32, 64, 96), |
|
|
17 |
pretrained=True): |
|
|
18 |
"""Generate plots for Extended Data Fig. 4.""" |
|
|
19 |
|
|
|
20 |
echonet.utils.latexify() |
|
|
21 |
|
|
|
22 |
os.makedirs(fig_root, exist_ok=True) |
|
|
23 |
fig = plt.figure(figsize=(6.50, 2.50)) |
|
|
24 |
gs = matplotlib.gridspec.GridSpec(1, 3, width_ratios=[2.5, 2.5, 1.50]) |
|
|
25 |
ax = (plt.subplot(gs[0]), plt.subplot(gs[1]), plt.subplot(gs[2])) |
|
|
26 |
|
|
|
27 |
# Create legend |
|
|
28 |
for (model, color) in zip(["EchoNet-Dynamic (EF)", "R3D", "MC3"], matplotlib.colors.TABLEAU_COLORS): |
|
|
29 |
ax[2].plot([float("nan")], [float("nan")], "-", color=color, label=model) |
|
|
30 |
ax[2].set_title("") |
|
|
31 |
ax[2].axis("off") |
|
|
32 |
ax[2].legend(loc="center") |
|
|
33 |
|
|
|
34 |
for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"], matplotlib.colors.TABLEAU_COLORS): |
|
|
35 |
for split in ["val"]: # ["val", "train"]: |
|
|
36 |
print(model, split) |
|
|
37 |
data = [load(root, model, frames, 1, pretrained, split) for frames in FRAMES] |
|
|
38 |
time = np.array(list(map(lambda x: x[0], data))) |
|
|
39 |
n = np.array(list(map(lambda x: x[1], data))) |
|
|
40 |
mem_allocated = np.array(list(map(lambda x: x[2], data))) |
|
|
41 |
# mem_cached = np.array(list(map(lambda x: x[3], data))) |
|
|
42 |
batch_size = np.array(list(map(lambda x: x[4], data))) |
|
|
43 |
|
|
|
44 |
# Plot Time (panel a) |
|
|
45 |
ax[0].plot(FRAMES, time / n, "-" if pretrained else "--", marker=".", color=color, linewidth=(1 if split == "train" else None)) |
|
|
46 |
print("Time:\n" + "\n".join(map(lambda x: "{:8d}: {:f}".format(*x), zip(FRAMES, time / n)))) |
|
|
47 |
|
|
|
48 |
# Plot Memory (panel b) |
|
|
49 |
ax[1].plot(FRAMES, mem_allocated / batch_size / 1e9, "-" if pretrained else "--", marker=".", color=color, linewidth=(1 if split == "train" else None)) |
|
|
50 |
print("Memory:\n" + "\n".join(map(lambda x: "{:8d}: {:f}".format(*x), zip(FRAMES, mem_allocated / batch_size / 1e9)))) |
|
|
51 |
print() |
|
|
52 |
|
|
|
53 |
# Labels for panel a |
|
|
54 |
ax[0].set_xticks(FRAMES) |
|
|
55 |
ax[0].text(-0.05, 1.10, "(a)", transform=ax[0].transAxes) |
|
|
56 |
ax[0].set_xlabel("Clip length (frames)") |
|
|
57 |
ax[0].set_ylabel("Time Per Clip (seconds)") |
|
|
58 |
|
|
|
59 |
# Labels for panel b |
|
|
60 |
ax[1].set_xticks(FRAMES) |
|
|
61 |
ax[1].text(-0.05, 1.10, "(b)", transform=ax[1].transAxes) |
|
|
62 |
ax[1].set_xlabel("Clip length (frames)") |
|
|
63 |
ax[1].set_ylabel("Memory Per Clip (GB)") |
|
|
64 |
|
|
|
65 |
# Save figure |
|
|
66 |
plt.tight_layout() |
|
|
67 |
plt.savefig(os.path.join(fig_root, "complexity.pdf")) |
|
|
68 |
plt.savefig(os.path.join(fig_root, "complexity.eps")) |
|
|
69 |
plt.close(fig) |
|
|
70 |
|
|
|
71 |
|
|
|
72 |
def load(root, model, frames, period, pretrained, split): |
|
|
73 |
"""Loads runtime and memory usage for specified hyperparameter choice.""" |
|
|
74 |
with open(os.path.join(root, "{}_{}_{}_{}".format(model, frames, period, "pretrained" if pretrained else "random"), "log.csv"), "r") as f: |
|
|
75 |
for line in f: |
|
|
76 |
line = line.split(",") |
|
|
77 |
if len(line) < 4: |
|
|
78 |
# Skip lines that are not csv (these lines log information) |
|
|
79 |
continue |
|
|
80 |
if line[1] == split: |
|
|
81 |
*_, time, n, mem_allocated, mem_cached, batch_size = line |
|
|
82 |
time = float(time) |
|
|
83 |
n = int(n) |
|
|
84 |
mem_allocated = int(mem_allocated) |
|
|
85 |
mem_cached = int(mem_cached) |
|
|
86 |
batch_size = int(batch_size) |
|
|
87 |
return time, n, mem_allocated, mem_cached, batch_size |
|
|
88 |
raise ValueError("File missing information.") |
|
|
89 |
|
|
|
90 |
|
|
|
91 |
if __name__ == "__main__": |
|
|
92 |
main() |