Switch to unified view

a b/scripts/plot_complexity.py
1
#!/usr/bin/env python3
2
3
"""Code to generate plots for Extended Data Fig. 4."""
4
5
import os
6
7
import matplotlib
8
import matplotlib.pyplot as plt
9
import numpy as np
10
11
import echonet
12
13
14
def main(root=os.path.join("timing", "video"),
15
         fig_root=os.path.join("figure", "complexity"),
16
         FRAMES=(1, 8, 16, 32, 64, 96),
17
         pretrained=True):
18
    """Generate plots for Extended Data Fig. 4."""
19
20
    echonet.utils.latexify()
21
22
    os.makedirs(fig_root, exist_ok=True)
23
    fig = plt.figure(figsize=(6.50, 2.50))
24
    gs = matplotlib.gridspec.GridSpec(1, 3, width_ratios=[2.5, 2.5, 1.50])
25
    ax = (plt.subplot(gs[0]), plt.subplot(gs[1]), plt.subplot(gs[2]))
26
27
    # Create legend
28
    for (model, color) in zip(["EchoNet-Dynamic (EF)", "R3D", "MC3"], matplotlib.colors.TABLEAU_COLORS):
29
        ax[2].plot([float("nan")], [float("nan")], "-", color=color, label=model)
30
    ax[2].set_title("")
31
    ax[2].axis("off")
32
    ax[2].legend(loc="center")
33
34
    for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"], matplotlib.colors.TABLEAU_COLORS):
35
        for split in ["val"]:  # ["val", "train"]:
36
            print(model, split)
37
            data = [load(root, model, frames, 1, pretrained, split) for frames in FRAMES]
38
            time = np.array(list(map(lambda x: x[0], data)))
39
            n = np.array(list(map(lambda x: x[1], data)))
40
            mem_allocated = np.array(list(map(lambda x: x[2], data)))
41
            # mem_cached = np.array(list(map(lambda x: x[3], data)))
42
            batch_size = np.array(list(map(lambda x: x[4], data)))
43
44
            # Plot Time (panel a)
45
            ax[0].plot(FRAMES, time / n, "-" if pretrained else "--", marker=".", color=color, linewidth=(1 if split == "train" else None))
46
            print("Time:\n" + "\n".join(map(lambda x: "{:8d}: {:f}".format(*x), zip(FRAMES, time / n))))
47
48
            # Plot Memory (panel b)
49
            ax[1].plot(FRAMES, mem_allocated / batch_size / 1e9, "-" if pretrained else "--", marker=".", color=color, linewidth=(1 if split == "train" else None))
50
            print("Memory:\n" + "\n".join(map(lambda x: "{:8d}: {:f}".format(*x), zip(FRAMES, mem_allocated / batch_size / 1e9))))
51
            print()
52
53
    # Labels for panel a
54
    ax[0].set_xticks(FRAMES)
55
    ax[0].text(-0.05, 1.10, "(a)", transform=ax[0].transAxes)
56
    ax[0].set_xlabel("Clip length (frames)")
57
    ax[0].set_ylabel("Time Per Clip (seconds)")
58
59
    # Labels for panel b
60
    ax[1].set_xticks(FRAMES)
61
    ax[1].text(-0.05, 1.10, "(b)", transform=ax[1].transAxes)
62
    ax[1].set_xlabel("Clip length (frames)")
63
    ax[1].set_ylabel("Memory Per Clip (GB)")
64
65
    # Save figure
66
    plt.tight_layout()
67
    plt.savefig(os.path.join(fig_root, "complexity.pdf"))
68
    plt.savefig(os.path.join(fig_root, "complexity.eps"))
69
    plt.close(fig)
70
71
72
def load(root, model, frames, period, pretrained, split):
73
    """Loads runtime and memory usage for specified hyperparameter choice."""
74
    with open(os.path.join(root, "{}_{}_{}_{}".format(model, frames, period, "pretrained" if pretrained else "random"), "log.csv"), "r") as f:
75
        for line in f:
76
            line = line.split(",")
77
            if len(line) < 4:
78
                # Skip lines that are not csv (these lines log information)
79
                continue
80
            if line[1] == split:
81
                *_, time, n, mem_allocated, mem_cached, batch_size = line
82
                time = float(time)
83
                n = int(n)
84
                mem_allocated = int(mem_allocated)
85
                mem_cached = int(mem_cached)
86
                batch_size = int(batch_size)
87
                return time, n, mem_allocated, mem_cached, batch_size
88
    raise ValueError("File missing information.")
89
90
91
if __name__ == "__main__":
92
    main()