Diff of /echonet/utils/video.py [000000] .. [aeb6cc]

Switch to unified view

a b/echonet/utils/video.py
1
"""Functions for training and running EF prediction."""
2
3
import math
4
import os
5
import time
6
7
import click
8
import matplotlib.pyplot as plt
9
import numpy as np
10
import sklearn.metrics
11
import torch
12
import torchvision
13
import tqdm
14
15
import echonet
16
17
18
@click.command("video")
@click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
@click.option("--output", type=click.Path(file_okay=False), default=None)
@click.option("--task", type=str, default="EF")
@click.option("--model_name", type=click.Choice(
    sorted(name for name in torchvision.models.video.__dict__
           if name.islower() and not name.startswith("__") and callable(torchvision.models.video.__dict__[name]))),
    default="r2plus1d_18")
@click.option("--pretrained/--random", default=True)
@click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
@click.option("--run_test/--skip_test", default=False)
@click.option("--num_epochs", type=int, default=45)
@click.option("--lr", type=float, default=1e-4)
@click.option("--weight_decay", type=float, default=1e-4)
@click.option("--lr_step_period", type=int, default=15)
@click.option("--frames", type=int, default=32)
@click.option("--period", type=int, default=2)
@click.option("--num_train_patients", type=int, default=None)
@click.option("--num_workers", type=int, default=4)
@click.option("--batch_size", type=int, default=20)
@click.option("--device", type=str, default=None)
@click.option("--seed", type=int, default=0)
def run(
    data_dir=None,
    output=None,
    task="EF",

    model_name="r2plus1d_18",
    pretrained=True,
    weights=None,

    run_test=False,
    num_epochs=45,
    lr=1e-4,
    weight_decay=1e-4,
    lr_step_period=15,
    frames=32,
    period=2,
    num_train_patients=None,
    num_workers=4,
    batch_size=20,
    device=None,
    seed=0,
):
    """Trains/tests EF prediction model.

    \b
    Args:
        data_dir (str, optional): Directory containing dataset. Defaults to
            `echonet.config.DATA_DIR`.
        output (str, optional): Directory to place outputs. Defaults to
            output/video/<model_name>_<frames>_<period>_<pretrained/random>/.
        task (str, optional): Name of task to predict. Options are the headers
            of FileList.csv. Defaults to ``EF''.
        model_name (str, optional): Name of model. One of ``mc3_18'',
            ``r2plus1d_18'', or ``r3d_18''
            (options are torchvision.models.video.<model_name>)
            Defaults to ``r2plus1d_18''.
        pretrained (bool, optional): Whether to use pretrained weights for model
            Defaults to True.
        weights (str, optional): Path to checkpoint containing weights to
            initialize model. Defaults to None.
        run_test (bool, optional): Whether or not to run on test.
            Defaults to False.
        num_epochs (int, optional): Number of epochs during training.
            Defaults to 45.
        lr (float, optional): Learning rate for SGD
            Defaults to 1e-4.
        weight_decay (float, optional): Weight decay for SGD
            Defaults to 1e-4.
        lr_step_period (int or None, optional): Period of learning rate decay
            (learning rate is decayed by a multiplicative factor of 0.1)
            Defaults to 15.
        frames (int, optional): Number of frames to use in clip
            Defaults to 32.
        period (int, optional): Sampling period for frames
            Defaults to 2.
        num_train_patients (int or None, optional): Number of training
            patients for ablations. Defaults to all patients.
        num_workers (int, optional): Number of subprocesses to use for data
            loading. If 0, the data will be loaded in the main process.
            Defaults to 4.
        device (str, torch.device, or None, optional): Device to run on.
            Options from
            https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
            Defaults to ``cuda'' if available, and ``cpu'' otherwise.
        batch_size (int, optional): Number of samples to load per batch
            Defaults to 20.
        seed (int, optional): Seed for random number generator. Defaults to 0.
    """

    # Seed RNGs
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Set default output directory
    if output is None:
        output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random"))
    os.makedirs(output, exist_ok=True)

    # Set device for computations
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif not isinstance(device, torch.device):
        # The command line supplies the device as a plain string, but the
        # code below relies on the ``.type`` attribute of torch.device.
        device = torch.device(device)
    use_cuda = (device.type == "cuda")

    # Set up model
    model = torchvision.models.video.__dict__[model_name](pretrained=pretrained)

    # Replace the classification head with a single regression output.
    model.fc = torch.nn.Linear(model.fc.in_features, 1)
    # NOTE(review): 55.6 is presumably close to the mean training EF - confirm.
    model.fc.bias.data[0] = 55.6
    if use_cuda:
        model = torch.nn.DataParallel(model)
    model.to(device)

    if weights is not None:
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        checkpoint = torch.load(weights, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])

    # Set up optimizer
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if lr_step_period is None:
        # An infinite period means the learning rate is never decayed.
        lr_step_period = math.inf
    scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)

    # Compute mean and std
    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
    kwargs = {"target_type": task,
              "mean": mean,
              "std": std,
              "length": frames,
              "period": period,
              }

    # Set up datasets and dataloaders
    dataset = {}
    dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12)
    if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
        # Subsample patients (used for ablation experiment)
        indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
        dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
    dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)

    # Run training and testing loops
    with open(os.path.join(output, "log.csv"), "a") as f:
        epoch_resume = 0
        bestLoss = float("inf")
        try:
            # Attempt to load checkpoint
            checkpoint = torch.load(os.path.join(output, "checkpoint.pt"), map_location=device)
            model.load_state_dict(checkpoint['state_dict'])
            optim.load_state_dict(checkpoint['opt_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_dict'])
            epoch_resume = checkpoint["epoch"] + 1
            bestLoss = checkpoint["best_loss"]
            f.write("Resuming from epoch {}\n".format(epoch_resume))
        except FileNotFoundError:
            f.write("Starting run from scratch\n")

        for epoch in range(epoch_resume, num_epochs):
            print("Epoch #{}".format(epoch), flush=True)
            for phase in ['train', 'val']:
                start_time = time.time()
                n_devices = torch.cuda.device_count()
                for i in range(n_devices):
                    torch.cuda.reset_peak_memory_stats(i)

                ds = dataset[phase]
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=use_cuda, drop_last=(phase == "train"))

                loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, phase == "train", optim, device)
                r2 = sklearn.metrics.r2_score(y, yhat)
                # Query every device explicitly; without the index the current
                # device would be counted once per visible GPU.
                f.write("{},{},{},{},{},{},{},{},{}\n".format(epoch,
                                                              phase,
                                                              loss,
                                                              r2,
                                                              time.time() - start_time,
                                                              y.size,
                                                              sum(torch.cuda.max_memory_allocated(i) for i in range(n_devices)),
                                                              sum(torch.cuda.max_memory_reserved(i) for i in range(n_devices)),
                                                              batch_size))
                f.flush()
            scheduler.step()

            # Save checkpoint (loss and r2 are those of the validation phase,
            # which is the last phase run in the loop above)
            save = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'period': period,
                'frames': frames,
                'best_loss': bestLoss,
                'loss': loss,
                'r2': r2,
                'opt_dict': optim.state_dict(),
                'scheduler_dict': scheduler.state_dict(),
            }
            torch.save(save, os.path.join(output, "checkpoint.pt"))
            if loss < bestLoss:
                torch.save(save, os.path.join(output, "best.pt"))
                bestLoss = loss

        # Load best weights
        if num_epochs != 0:
            checkpoint = torch.load(os.path.join(output, "best.pt"), map_location=device)
            model.load_state_dict(checkpoint['state_dict'])
            f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))
            f.flush()

        if run_test:
            for split in ["val", "test"]:
                # Performance without test-time augmentation
                dataloader = torch.utils.data.DataLoader(
                    echonet.datasets.Echo(root=data_dir, split=split, **kwargs),
                    batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=use_cuda)
                loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device)
                f.write("{} (one clip) R2:   {:.3f} ({:.3f} - {:.3f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
                f.write("{} (one clip) MAE:  {:.2f} ({:.2f} - {:.2f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
                f.write("{} (one clip) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error)))))
                f.flush()

                # Performance with test-time augmentation
                ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all")
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=use_cuda)
                loss, clip_preds, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device, save_all=True, block_size=batch_size)
                # One prediction per video: the mean over all of its clips.
                yhat = np.array([pred.mean() for pred in clip_preds])
                f.write("{} (all clips) R2:   {:.3f} ({:.3f} - {:.3f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
                f.write("{} (all clips) MAE:  {:.2f} ({:.2f} - {:.2f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
                f.write("{} (all clips) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error)))))
                f.flush()

                # Write full performance to file (one row per clip)
                with open(os.path.join(output, "{}_predictions.csv".format(split)), "w") as g:
                    for (filename, pred) in zip(ds.fnames, clip_preds):
                        for (i, p) in enumerate(pred):
                            g.write("{},{},{:.4f}\n".format(filename, i, p))
                echonet.utils.latexify()

                # Plot actual and predicted EF
                fig = plt.figure(figsize=(3, 3))
                lower = min(y.min(), yhat.min())
                upper = max(y.max(), yhat.max())
                plt.scatter(y, yhat, color="k", s=1, edgecolor=None, zorder=2)
                plt.plot([0, 100], [0, 100], linewidth=1, zorder=3)
                plt.axis([lower - 3, upper + 3, lower - 3, upper + 3])
                plt.gca().set_aspect("equal", "box")
                plt.xlabel("Actual EF (%)")
                plt.ylabel("Predicted EF (%)")
                plt.xticks([10, 20, 30, 40, 50, 60, 70, 80])
                plt.yticks([10, 20, 30, 40, 50, 60, 70, 80])
                plt.grid(color="gainsboro", linestyle="--", linewidth=1, zorder=1)
                plt.tight_layout()
                plt.savefig(os.path.join(output, "{}_scatter.pdf".format(split)))
                plt.close(fig)

                # Plot AUROC
                fig = plt.figure(figsize=(3, 3))
                plt.plot([0, 1], [0, 1], linewidth=1, color="k", linestyle="--")
                for thresh in [35, 40, 45, 50]:
                    fpr, tpr, _ = sklearn.metrics.roc_curve(y > thresh, yhat)
                    print(thresh, sklearn.metrics.roc_auc_score(y > thresh, yhat))
                    plt.plot(fpr, tpr)

                plt.axis([-0.01, 1.01, -0.01, 1.01])
                plt.xlabel("False Positive Rate")
                plt.ylabel("True Positive Rate")
                plt.tight_layout()
                plt.savefig(os.path.join(output, "{}_roc.pdf".format(split)))
                plt.close(fig)
283
284
285
def run_epoch(model, dataloader, train, optim, device, save_all=False, block_size=None):
    """Run one epoch of training/evaluation for video-based prediction.

    Args:
        model (torch.nn.Module): Model to train/evaluate.
        dataloader (torch.utils.data.DataLoader): Dataloader for dataset.
            Each batch is a pair ``(X, outcome)``; a 6-dimensional ``X`` is
            treated as (batch, clips, channels, frames, height, width) and
            the predictions are averaged over the clips of each video.
        train (bool): Whether or not to train model.
        optim (torch.optim.Optimizer): Optimizer. Only used when ``train``
            is True.
        device (torch.device): Device to run on
        save_all (bool, optional): If True, return predictions for all
            test-time augmentations separately. If False, return only
            the mean prediction.
            Defaults to False.
        block_size (int or None, optional): Maximum number of augmentations
            to run on at the same time. Use to limit the amount of memory
            used. If None, always run on all augmentations simultaneously.
            Default is None.

    Returns:
        tuple: ``(loss, yhat, y)`` where ``loss`` is the mean squared error
        averaged over all videos, ``y`` is a numpy array of the targets, and
        ``yhat`` holds the predictions: a single concatenated numpy array if
        ``save_all`` is False, otherwise a list with one array of per-clip
        predictions for each batch.
    """

    model.train(train)

    total = 0  # running sum of squared error (per-batch loss * batch size)
    n = 0      # number of videos processed
    s1 = 0     # sum of ground truth values
    s2 = 0     # sum of squared ground truth values

    yhat = []
    y = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (X, outcome) in dataloader:

                y.append(outcome.numpy())
                X = X.to(device)
                outcome = outcome.to(device)

                # Six dimensions means several clips per video: fold the clip
                # axis into the batch axis for the forward pass.
                average = (X.dim() == 6)
                if average:
                    batch, n_clips, c, f, h, w = X.shape
                    X = X.view(-1, c, f, h, w)

                s1 += outcome.sum().item()
                s2 += (outcome ** 2).sum().item()

                if block_size is None:
                    outputs = model(X)
                else:
                    # Run the model in chunks to bound peak memory use.
                    outputs = torch.cat([model(X[j:(j + block_size), ...]) for j in range(0, X.shape[0], block_size)])

                if save_all:
                    # Keep every per-clip prediction, before averaging.
                    yhat.append(outputs.view(-1).to("cpu").detach().numpy())

                if average:
                    outputs = outputs.view(batch, n_clips, -1).mean(1)

                if not save_all:
                    yhat.append(outputs.view(-1).to("cpu").detach().numpy())

                loss = torch.nn.functional.mse_loss(outputs.view(-1), outcome)

                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                # The loss is a mean over videos, so weight it by the number
                # of videos. X.size(0) would count clips once the clip axis
                # has been folded into the batch axis above.
                n_videos = outcome.size(0)
                total += loss.item() * n_videos
                n += n_videos

                # Shown as: running loss (batch loss) / variance of targets
                pbar.set_postfix_str("{:.2f} ({:.2f}) / {:.2f}".format(total / n, loss.item(), s2 / n - (s1 / n) ** 2))
                pbar.update()

    if not save_all:
        yhat = np.concatenate(yhat)
    y = np.concatenate(y)

    return total / n, yhat, y