|
b/echonet/utils/video.py

"""Functions for training and running EF prediction."""

import math
import os
import time

import click
import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics
import torch
import torchvision
import tqdm

import echonet


@click.command("video")
@click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
@click.option("--output", type=click.Path(file_okay=False), default=None)
@click.option("--task", type=str, default="EF")
@click.option("--model_name", type=click.Choice(
    sorted(name for name in torchvision.models.video.__dict__
           if name.islower() and not name.startswith("__") and callable(torchvision.models.video.__dict__[name]))),
    default="r2plus1d_18")
@click.option("--pretrained/--random", default=True)
@click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
@click.option("--run_test/--skip_test", default=False)
@click.option("--num_epochs", type=int, default=45)
@click.option("--lr", type=float, default=1e-4)
@click.option("--weight_decay", type=float, default=1e-4)
@click.option("--lr_step_period", type=int, default=15)
@click.option("--frames", type=int, default=32)
@click.option("--period", type=int, default=2)
@click.option("--num_train_patients", type=int, default=None)
@click.option("--num_workers", type=int, default=4)
@click.option("--batch_size", type=int, default=20)
@click.option("--device", type=str, default=None)
@click.option("--seed", type=int, default=0)
def run(
    data_dir=None,
    output=None,
    task="EF",

    model_name="r2plus1d_18",
    pretrained=True,
    weights=None,

    run_test=False,
    num_epochs=45,
    lr=1e-4,
    weight_decay=1e-4,
    lr_step_period=15,
    frames=32,
    period=2,
    num_train_patients=None,
    num_workers=4,
    batch_size=20,
    device=None,
    seed=0,
):
    """Trains/tests EF prediction model.

    \b
    Args:
        data_dir (str, optional): Directory containing dataset. Defaults to
            `echonet.config.DATA_DIR`.
        output (str, optional): Directory to place outputs. Defaults to
            output/video/<model_name>_<pretrained/random>/.
        task (str, optional): Name of task to predict. Options are the headers
            of FileList.csv. Defaults to ``EF''.
        model_name (str, optional): Name of model. One of ``mc3_18'',
            ``r2plus1d_18'', or ``r3d_18''
            (options are torchvision.models.video.<model_name>).
            Defaults to ``r2plus1d_18''.
        pretrained (bool, optional): Whether to use pretrained weights for
            the model. Defaults to True.
        weights (str, optional): Path to checkpoint containing weights to
            initialize model. Defaults to None.
        run_test (bool, optional): Whether or not to run on the val and test
            splits. Defaults to False.
        num_epochs (int, optional): Number of epochs during training.
            Defaults to 45.
        lr (float, optional): Learning rate for SGD.
            Defaults to 1e-4.
        weight_decay (float, optional): Weight decay for SGD.
            Defaults to 1e-4.
        lr_step_period (int or None, optional): Period of learning rate decay
            (learning rate is decayed by a multiplicative factor of 0.1).
            Defaults to 15.
        frames (int, optional): Number of frames to use in each clip.
            Defaults to 32.
        period (int, optional): Sampling period for frames.
            Defaults to 2.
        num_train_patients (int or None, optional): Number of training patients
            for ablations. Defaults to all patients.
        num_workers (int, optional): Number of subprocesses to use for data
            loading. If 0, the data will be loaded in the main process.
            Defaults to 4.
        device (str or None, optional): Name of device to run on. Options from
            https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
            Defaults to ``cuda'' if available, and ``cpu'' otherwise.
        batch_size (int, optional): Number of samples to load per batch.
            Defaults to 20.
        seed (int, optional): Seed for random number generator. Defaults to 0.
    """

    # Seed RNGs
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Set default output directory
    if output is None:
        output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random"))
    os.makedirs(output, exist_ok=True)

    # Set device for computations (wrap CLI-provided name strings in
    # torch.device so that device.type works below)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Set up model
    model = torchvision.models.video.__dict__[model_name](pretrained=pretrained)

    model.fc = torch.nn.Linear(model.fc.in_features, 1)
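    # Initialize the output bias near a typical EF value (55.6 presumably
    # approximates the training-set mean EF) so initial predictions are sensible.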
    model.fc.bias.data[0] = 55.6
    if device.type == "cuda":
        model = torch.nn.DataParallel(model)
    model.to(device)

    if weights is not None:
        checkpoint = torch.load(weights)
        model.load_state_dict(checkpoint['state_dict'])

    # Set up optimizer
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if lr_step_period is None:
        lr_step_period = math.inf
    scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)

    # Compute mean and std
    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
    kwargs = {"target_type": task,
              "mean": mean,
              "std": std,
              "length": frames,
              "period": period,
              }

    # Set up datasets and dataloaders
    # (pad=12 pads each frame and takes a random crop, a translation augmentation)
    dataset = {}
    dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12)
    if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
        # Subsample patients (used for ablation experiment)
        indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
        dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
    dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)

    # Run training and testing loops
    with open(os.path.join(output, "log.csv"), "a") as f:
        epoch_resume = 0
        bestLoss = float("inf")
        try:
            # Attempt to load checkpoint
            checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            optim.load_state_dict(checkpoint['opt_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_dict'])
            epoch_resume = checkpoint["epoch"] + 1
            bestLoss = checkpoint["best_loss"]
            f.write("Resuming from epoch {}\n".format(epoch_resume))
        except FileNotFoundError:
            f.write("Starting run from scratch\n")

        for epoch in range(epoch_resume, num_epochs):
            print("Epoch #{}".format(epoch), flush=True)
            for phase in ['train', 'val']:
                start_time = time.time()
                for i in range(torch.cuda.device_count()):
                    torch.cuda.reset_peak_memory_stats(i)

                ds = dataset[phase]
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))

                loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, phase == "train", optim, device)
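                # log.csv columns: epoch, phase, loss, R2, elapsed time (s),
                # number of examples, peak GPU memory allocated and reserved, batch size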
                f.write("{},{},{},{},{},{},{},{},{}\n".format(
                    epoch,
                    phase,
                    loss,
                    sklearn.metrics.r2_score(y, yhat),
                    time.time() - start_time,
                    y.size,
                    sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
                    sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
                    batch_size))
                f.flush()
            scheduler.step()

            # Save checkpoint
            save = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'period': period,
                'frames': frames,
                'best_loss': bestLoss,
                'loss': loss,
                'r2': sklearn.metrics.r2_score(y, yhat),
                'opt_dict': optim.state_dict(),
                'scheduler_dict': scheduler.state_dict(),
            }
            torch.save(save, os.path.join(output, "checkpoint.pt"))
            if loss < bestLoss:
                torch.save(save, os.path.join(output, "best.pt"))
                bestLoss = loss

        # Load best weights
        if num_epochs != 0:
            checkpoint = torch.load(os.path.join(output, "best.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))
            f.flush()

        if run_test:
            for split in ["val", "test"]:
                # Performance without test-time augmentation
                dataloader = torch.utils.data.DataLoader(
                    echonet.datasets.Echo(root=data_dir, split=split, **kwargs),
                    batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"))
                loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device)
                f.write("{} (one clip) R2: {:.3f} ({:.3f} - {:.3f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
                f.write("{} (one clip) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
                f.write("{} (one clip) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error)))))
                f.flush()
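
                # clips="all" makes the dataset return every clip in each video;
                # the per-clip predictions are averaged below for the final estimate.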
                # Performance with test-time augmentation
                ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all")
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
                loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device, save_all=True, block_size=batch_size)
                f.write("{} (all clips) R2: {:.3f} ({:.3f} - {:.3f})\n".format(split, *echonet.utils.bootstrap(y, np.array(list(map(lambda x: x.mean(), yhat))), sklearn.metrics.r2_score)))
                f.write("{} (all clips) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *echonet.utils.bootstrap(y, np.array(list(map(lambda x: x.mean(), yhat))), sklearn.metrics.mean_absolute_error)))
                f.write("{} (all clips) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, np.array(list(map(lambda x: x.mean(), yhat))), sklearn.metrics.mean_squared_error)))))
                f.flush()

                # Write per-clip predictions to file
                with open(os.path.join(output, "{}_predictions.csv".format(split)), "w") as g:
                    for (filename, pred) in zip(ds.fnames, yhat):
                        for (i, p) in enumerate(pred):
                            g.write("{},{},{:.4f}\n".format(filename, i, p))
                echonet.utils.latexify()
                yhat = np.array(list(map(lambda x: x.mean(), yhat)))

                # Plot actual and predicted EF
                fig = plt.figure(figsize=(3, 3))
                lower = min(y.min(), yhat.min())
                upper = max(y.max(), yhat.max())
                plt.scatter(y, yhat, color="k", s=1, edgecolor=None, zorder=2)
                plt.plot([0, 100], [0, 100], linewidth=1, zorder=3)
                plt.axis([lower - 3, upper + 3, lower - 3, upper + 3])
                plt.gca().set_aspect("equal", "box")
                plt.xlabel("Actual EF (%)")
                plt.ylabel("Predicted EF (%)")
                plt.xticks([10, 20, 30, 40, 50, 60, 70, 80])
                plt.yticks([10, 20, 30, 40, 50, 60, 70, 80])
                plt.grid(color="gainsboro", linestyle="--", linewidth=1, zorder=1)
                plt.tight_layout()
                plt.savefig(os.path.join(output, "{}_scatter.pdf".format(split)))
                plt.close(fig)

                # Plot AUROC
                fig = plt.figure(figsize=(3, 3))
                plt.plot([0, 1], [0, 1], linewidth=1, color="k", linestyle="--")
                for thresh in [35, 40, 45, 50]:
                    fpr, tpr, _ = sklearn.metrics.roc_curve(y > thresh, yhat)
                    print(thresh, sklearn.metrics.roc_auc_score(y > thresh, yhat))
                    plt.plot(fpr, tpr)

                plt.axis([-0.01, 1.01, -0.01, 1.01])
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.tight_layout()
                plt.savefig(os.path.join(output, "{}_roc.pdf".format(split)))
                plt.close(fig)
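
# Example invocation (a sketch, assuming the package's click CLI registers this
# command as ``echonet video''; the paths and flag values are illustrative):
#
#     echonet video --num_epochs 45 --frames 32 --period 2 --batch_size 20 --run_test
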
def run_epoch(model, dataloader, train, optim, device, save_all=False, block_size=None):
    """Run one epoch of training/evaluation for EF prediction.

    Args:
        model (torch.nn.Module): Model to train/evaluate.
        dataloader (torch.utils.data.DataLoader): Dataloader for dataset.
        train (bool): Whether or not to train model.
        optim (torch.optim.Optimizer): Optimizer.
        device (torch.device): Device to run on.
        save_all (bool, optional): If True, return predictions for all
            test-time augmentations separately. If False, return only
            the mean prediction.
            Defaults to False.
        block_size (int or None, optional): Maximum number of augmentations
            to run on at the same time. Use to limit the amount of memory
            used. If None, always run on all augmentations simultaneously.
            Default is None.

    Returns:
        Tuple of (average loss, predictions, ground-truth values).
    """

    model.train(train)

    total = 0  # total training loss
    n = 0      # number of videos processed
    s1 = 0     # sum of ground truth EF
    s2 = 0     # sum of ground truth EF squared

    yhat = []
    y = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (X, outcome) in dataloader:

                y.append(outcome.numpy())
                X = X.to(device)
                outcome = outcome.to(device)
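
                # A 6-dim input means the dataset returned multiple clips per video
                # (test-time augmentation); fold the clips into the batch dimension.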
                average = (len(X.shape) == 6)
                if average:
                    batch, n_clips, c, f, h, w = X.shape
                    X = X.view(-1, c, f, h, w)

                s1 += outcome.sum()
                s2 += (outcome ** 2).sum()
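
                # Run the model on block_size clips at a time to bound peak memory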
                if block_size is None:
                    outputs = model(X)
                else:
                    outputs = torch.cat([model(X[j:(j + block_size), ...]) for j in range(0, X.shape[0], block_size)])
                if save_all:
                    yhat.append(outputs.view(-1).to("cpu").detach().numpy())

                if average:
                    outputs = outputs.view(batch, n_clips, -1).mean(1)

                if not save_all:
                    yhat.append(outputs.view(-1).to("cpu").detach().numpy())

                loss = torch.nn.functional.mse_loss(outputs.view(-1), outcome)

                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                total += loss.item() * X.size(0)
                n += X.size(0)
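
                # Progress bar shows: mean loss so far (current batch loss) / variance of labels seen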
                pbar.set_postfix_str("{:.2f} ({:.2f}) / {:.2f}".format(total / n, loss.item(), s2 / n - (s1 / n) ** 2))
                pbar.update()

    if not save_all:
        yhat = np.concatenate(yhat)
    y = np.concatenate(y)

    return total / n, yhat, y
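
# run_epoch is called above both for training (train=True, optim supplied) and for
# evaluation with test-time augmentation, e.g.:
#     loss, yhat, y = run_epoch(model, dataloader, False, None, device, save_all=True, block_size=batch_size)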