|
a |
|
b/scripts/plot_loss.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
|
|
|
3 |
"""Code to generate plots for Extended Data Fig. 3.""" |
|
|
4 |
|
|
|
5 |
import argparse |
|
|
6 |
import os |
|
|
7 |
import matplotlib |
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
|
|
|
10 |
import echonet |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
def main(): |
|
|
14 |
"""Generate plots for Extended Data Fig. 3.""" |
|
|
15 |
|
|
|
16 |
# Select paths and hyperparameter to plot |
|
|
17 |
parser = argparse.ArgumentParser() |
|
|
18 |
parser.add_argument("dir", nargs="?", default="output") |
|
|
19 |
parser.add_argument("fig", nargs="?", default=os.path.join("figure", "loss")) |
|
|
20 |
parser.add_argument("--frames", type=int, default=32) |
|
|
21 |
parser.add_argument("--period", type=int, default=2) |
|
|
22 |
args = parser.parse_args() |
|
|
23 |
|
|
|
24 |
# Set up figure |
|
|
25 |
echonet.utils.latexify() |
|
|
26 |
os.makedirs(args.fig, exist_ok=True) |
|
|
27 |
fig = plt.figure(figsize=(7, 5)) |
|
|
28 |
gs = matplotlib.gridspec.GridSpec(ncols=3, nrows=2, figure=fig, width_ratios=[2.75, 2.75, 1.50]) |
|
|
29 |
|
|
|
30 |
# Plot EF loss curve |
|
|
31 |
ax0 = fig.add_subplot(gs[0, 0]) |
|
|
32 |
ax1 = fig.add_subplot(gs[0, 1], sharey=ax0) |
|
|
33 |
for pretrained in [True]: |
|
|
34 |
for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"], matplotlib.colors.TABLEAU_COLORS): |
|
|
35 |
loss = load(os.path.join(args.dir, "video", "{}_{}_{}_{}".format(model, args.frames, args.period, "pretrained" if pretrained else "random"), "log.csv")) |
|
|
36 |
ax0.plot(range(1, 1 + len(loss["train"])), loss["train"], "-" if pretrained else "--", color=color) |
|
|
37 |
ax1.plot(range(1, 1 + len(loss["val"])), loss["val"], "-" if pretrained else "--", color=color) |
|
|
38 |
|
|
|
39 |
plt.axis([0, max(len(loss["train"]), len(loss["val"])), 0, max(max(loss["train"]), max(loss["val"]))]) |
|
|
40 |
ax0.text(-0.25, 1.00, "(a)", transform=ax0.transAxes) |
|
|
41 |
ax1.text(-0.25, 1.00, "(b)", transform=ax1.transAxes) |
|
|
42 |
ax0.set_xlabel("Epochs") |
|
|
43 |
ax1.set_xlabel("Epochs") |
|
|
44 |
ax0.set_xticks([0, 15, 30, 45]) |
|
|
45 |
ax1.set_xticks([0, 15, 30, 45]) |
|
|
46 |
ax0.set_ylabel("Training MSE Loss") |
|
|
47 |
ax1.set_ylabel("Validation MSE Loss") |
|
|
48 |
|
|
|
49 |
# Plot segmentation loss curve |
|
|
50 |
ax0 = fig.add_subplot(gs[1, 0]) |
|
|
51 |
ax1 = fig.add_subplot(gs[1, 1], sharey=ax0) |
|
|
52 |
pretrained = False |
|
|
53 |
for (model, color) in zip(["deeplabv3_resnet50"], list(matplotlib.colors.TABLEAU_COLORS)[3:]): |
|
|
54 |
loss = load(os.path.join(args.dir, "segmentation", "{}_{}".format(model, "pretrained" if pretrained else "random"), "log.csv")) |
|
|
55 |
ax0.plot(range(1, 1 + len(loss["train"])), loss["train"], "--", color=color) |
|
|
56 |
ax1.plot(range(1, 1 + len(loss["val"])), loss["val"], "--", color=color) |
|
|
57 |
|
|
|
58 |
ax0.text(-0.25, 1.00, "(c)", transform=ax0.transAxes) |
|
|
59 |
ax1.text(-0.25, 1.00, "(d)", transform=ax1.transAxes) |
|
|
60 |
ax0.set_ylim([0, 0.13]) |
|
|
61 |
ax0.set_xlabel("Epochs") |
|
|
62 |
ax1.set_xlabel("Epochs") |
|
|
63 |
ax0.set_xticks([0, 25, 50]) |
|
|
64 |
ax1.set_xticks([0, 25, 50]) |
|
|
65 |
ax0.set_ylabel("Training Cross Entropy Loss") |
|
|
66 |
ax1.set_ylabel("Validation Cross Entropy Loss") |
|
|
67 |
|
|
|
68 |
# Legend |
|
|
69 |
ax = fig.add_subplot(gs[:, 2]) |
|
|
70 |
for (model, color) in zip(["EchoNet-Dynamic (EF)", "R3D", "MC3", "EchoNet-Dynamic (Seg)"], matplotlib.colors.TABLEAU_COLORS): |
|
|
71 |
ax.plot([float("nan")], [float("nan")], "-", color=color, label=model) |
|
|
72 |
ax.set_title("") |
|
|
73 |
ax.axis("off") |
|
|
74 |
ax.legend(loc="center") |
|
|
75 |
|
|
|
76 |
plt.tight_layout() |
|
|
77 |
plt.savefig(os.path.join(args.fig, "loss.pdf")) |
|
|
78 |
plt.savefig(os.path.join(args.fig, "loss.eps")) |
|
|
79 |
plt.savefig(os.path.join(args.fig, "loss.png")) |
|
|
80 |
plt.close(fig) |
|
|
81 |
|
|
|
82 |
|
|
|
83 |
def load(filename): |
|
|
84 |
"""Loads losses from specified file.""" |
|
|
85 |
|
|
|
86 |
losses = {"train": [], "val": []} |
|
|
87 |
with open(filename, "r") as f: |
|
|
88 |
for line in f: |
|
|
89 |
line = line.split(",") |
|
|
90 |
if len(line) < 4: |
|
|
91 |
continue |
|
|
92 |
epoch, split, loss, *_ = line |
|
|
93 |
epoch = int(epoch) |
|
|
94 |
loss = float(loss) |
|
|
95 |
assert(split in ["train", "val"]) |
|
|
96 |
if epoch == len(losses[split]): |
|
|
97 |
losses[split].append(loss) |
|
|
98 |
elif epoch == len(losses[split]) - 1: |
|
|
99 |
losses[split][-1] = loss |
|
|
100 |
else: |
|
|
101 |
raise ValueError("File has uninterpretable formatting.") |
|
|
102 |
return losses |
|
|
103 |
|
|
|
104 |
|
|
|
105 |
if __name__ == "__main__": |
|
|
106 |
main() |