echonet/utils/segmentation.py
"""Functions for training and running segmentation."""
2
3
import math
4
import os
5
import time
6
7
import click
8
import matplotlib.pyplot as plt
9
import numpy as np
10
import scipy.signal
11
import skimage.draw
12
import torch
13
import torchvision
14
import tqdm
15
16
import echonet
17
18
19
@click.command("segmentation")
20
@click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
21
@click.option("--output", type=click.Path(file_okay=False), default=None)
22
@click.option("--model_name", type=click.Choice(
23
    sorted(name for name in torchvision.models.segmentation.__dict__
24
           if name.islower() and not name.startswith("__") and callable(torchvision.models.segmentation.__dict__[name]))),
25
    default="deeplabv3_resnet50")
26
@click.option("--pretrained/--random", default=False)
27
@click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
28
@click.option("--run_test/--skip_test", default=False)
29
@click.option("--save_video/--skip_video", default=False)
30
@click.option("--num_epochs", type=int, default=50)
31
@click.option("--lr", type=float, default=1e-5)
32
@click.option("--weight_decay", type=float, default=0)
33
@click.option("--lr_step_period", type=int, default=None)
34
@click.option("--num_train_patients", type=int, default=None)
35
@click.option("--num_workers", type=int, default=4)
36
@click.option("--batch_size", type=int, default=20)
37
@click.option("--device", type=str, default=None)
38
@click.option("--seed", type=int, default=0)
39
def run(
40
    data_dir=None,
41
    output=None,
42
43
    model_name="deeplabv3_resnet50",
44
    pretrained=False,
45
    weights=None,
46
47
    run_test=False,
48
    save_video=False,
49
    num_epochs=50,
50
    lr=1e-5,
    weight_decay=0,
    lr_step_period=None,
    num_train_patients=None,
    num_workers=4,
    batch_size=20,
    device=None,
    seed=0,
):
    """Trains/tests segmentation model.

    Args:
        data_dir (str, optional): Directory containing dataset. Defaults to
            `echonet.config.DATA_DIR`.
        output (str, optional): Directory to place outputs. Defaults to
            output/segmentation/<model_name>_<pretrained/random>/.
        model_name (str, optional): Name of segmentation model. One of ``deeplabv3_resnet50'',
            ``deeplabv3_resnet101'', ``fcn_resnet50'', or ``fcn_resnet101''
            (options are torchvision.models.segmentation.<model_name>).
            Defaults to ``deeplabv3_resnet50''.
        pretrained (bool, optional): Whether to use pretrained weights for model.
            Defaults to False.
        weights (str, optional): Path to checkpoint containing weights to
            initialize model. Defaults to None.
        run_test (bool, optional): Whether or not to run on test.
            Defaults to False.
        save_video (bool, optional): Whether to save videos with segmentations.
            Defaults to False.
        num_epochs (int, optional): Number of epochs during training.
            Defaults to 50.
        lr (float, optional): Learning rate for SGD.
            Defaults to 1e-5.
        weight_decay (float, optional): Weight decay for SGD.
            Defaults to 0.
        lr_step_period (int or None, optional): Period of learning rate decay
            (learning rate is decayed by a multiplicative factor of 0.1).
            Defaults to math.inf (never decay learning rate).
        num_train_patients (int or None, optional): Number of training patients
            for ablations. Defaults to all patients.
        num_workers (int, optional): Number of subprocesses to use for data
            loading. If 0, the data will be loaded in the main process.
            Defaults to 4.
        batch_size (int, optional): Number of samples to load per batch.
            Defaults to 20.
        device (str or None, optional): Name of device to run on. Options from
            https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
            Defaults to ``cuda'' if available, and ``cpu'' otherwise.
        seed (int, optional): Seed for random number generator. Defaults to 0.
    """

    # Seed RNGs
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Set default output directory
    if output is None:
        output = os.path.join("output", "segmentation", "{}_{}".format(model_name, "pretrained" if pretrained else "random"))
    os.makedirs(output, exist_ok=True)

    # Set device for computations
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set up model
    model = torchvision.models.segmentation.__dict__[model_name](pretrained=pretrained, aux_loss=False)

    model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)  # change number of outputs to 1
    if device.type == "cuda":
        model = torch.nn.DataParallel(model)
    model.to(device)

    if weights is not None:
        checkpoint = torch.load(weights)
        model.load_state_dict(checkpoint['state_dict'])

    # Set up optimizer
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if lr_step_period is None:
        lr_step_period = math.inf
    scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)

    # Compute mean and std
    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
    tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
    kwargs = {"target_type": tasks,
              "mean": mean,
              "std": std
              }

    # Set up datasets and dataloaders
    dataset = {}
    dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs)
    if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
        # Subsample patients (used for ablation experiment)
        indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
        dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
    dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)

    # Run training and testing loops
    with open(os.path.join(output, "log.csv"), "a") as f:
        epoch_resume = 0
        bestLoss = float("inf")
        try:
            # Attempt to load checkpoint
            checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            optim.load_state_dict(checkpoint['opt_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_dict'])
            epoch_resume = checkpoint["epoch"] + 1
            bestLoss = checkpoint["best_loss"]
            f.write("Resuming from epoch {}\n".format(epoch_resume))
        except FileNotFoundError:
            f.write("Starting run from scratch\n")

        for epoch in range(epoch_resume, num_epochs):
            print("Epoch #{}".format(epoch), flush=True)
            for phase in ['train', 'val']:
                start_time = time.time()
                for i in range(torch.cuda.device_count()):
                    torch.cuda.reset_peak_memory_stats(i)

                ds = dataset[phase]
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))

                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, phase == "train", optim, device)
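                # Dice similarity coefficient from pixel counts:
                # Dice = 2 * |A ∩ B| / (|A| + |B|); since |A| + |B| = |A ∪ B| + |A ∩ B|,
                # this equals 2 * inter / (union + inter) below.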
                overall_dice = 2 * (large_inter.sum() + small_inter.sum()) / (large_union.sum() + large_inter.sum() + small_union.sum() + small_inter.sum())
                large_dice = 2 * large_inter.sum() / (large_union.sum() + large_inter.sum())
                small_dice = 2 * small_inter.sum() / (small_union.sum() + small_inter.sum())
                f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch,
                                                                    phase,
                                                                    loss,
                                                                    overall_dice,
                                                                    large_dice,
                                                                    small_dice,
                                                                    time.time() - start_time,
                                                                    large_inter.size,
                                                                    sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
                                                                    sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
                                                                    batch_size))
                f.flush()
            scheduler.step()

            # Save checkpoint
            save = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_loss': bestLoss,
                'loss': loss,
                'opt_dict': optim.state_dict(),
                'scheduler_dict': scheduler.state_dict(),
            }
            torch.save(save, os.path.join(output, "checkpoint.pt"))
            if loss < bestLoss:
                torch.save(save, os.path.join(output, "best.pt"))
                bestLoss = loss

        # Load best weights
        if num_epochs != 0:
            checkpoint = torch.load(os.path.join(output, "best.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))

        if run_test:
            # Run on validation and test
            for split in ["val", "test"]:
                dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs)
                dataloader = torch.utils.data.DataLoader(dataset,
                                                         batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, False, None, device)

                overall_dice = 2 * (large_inter + small_inter) / (large_union + large_inter + small_union + small_inter)
                large_dice = 2 * large_inter / (large_union + large_inter)
                small_dice = 2 * small_inter / (small_union + small_inter)
                with open(os.path.join(output, "{}_dice.csv".format(split)), "w") as g:
                    g.write("Filename, Overall, Large, Small\n")
                    for (filename, overall, large, small) in zip(dataset.fnames, overall_dice, large_dice, small_dice):
                        g.write("{},{},{},{}\n".format(filename, overall, large, small))

                f.write("{} dice (overall): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(np.concatenate((large_inter, small_inter)), np.concatenate((large_union, small_union)), echonet.utils.dice_similarity_coefficient)))
                f.write("{} dice (large):   {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(large_inter, large_union, echonet.utils.dice_similarity_coefficient)))
                f.write("{} dice (small):   {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(small_inter, small_union, echonet.utils.dice_similarity_coefficient)))
                f.flush()

    # Saving videos with segmentations
    dataset = echonet.datasets.Echo(root=data_dir, split="test",
                                    target_type=["Filename", "LargeIndex", "SmallIndex"],  # Need filename for saving, and human-selected frames to annotate
                                    mean=mean, std=std,  # Normalization
                                    length=None, max_length=None, period=1  # Take all frames
                                    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=num_workers, shuffle=False, pin_memory=False, collate_fn=_video_collate_fn)

    # Save videos with segmentation
    if save_video and not all(os.path.isfile(os.path.join(output, "videos", f)) for f in dataloader.dataset.fnames):
        # Only run if missing videos

        model.eval()

        os.makedirs(os.path.join(output, "videos"), exist_ok=True)
        os.makedirs(os.path.join(output, "size"), exist_ok=True)
        echonet.utils.latexify()

        with torch.no_grad():
            with open(os.path.join(output, "size.csv"), "w") as g:
                g.write("Filename,Frame,Size,HumanLarge,HumanSmall,ComputerSmall\n")
                for (x, (filenames, large_index, small_index), length) in tqdm.tqdm(dataloader):
                    # Run segmentation model on blocks of frames one-by-one
                    # The whole concatenated video may be too long to run together
                    y = np.concatenate([model(x[i:(i + batch_size), :, :, :].to(device))["out"].detach().cpu().numpy() for i in range(0, x.shape[0], batch_size)])

                    start = 0
                    x = x.numpy()
                    for (i, (filename, offset)) in enumerate(zip(filenames, length)):
                        # Extract one video and segmentation predictions
                        video = x[start:(start + offset), ...]
                        logit = y[start:(start + offset), 0, :, :]

                        # Un-normalize video
                        video *= std.reshape(1, 3, 1, 1)
                        video += mean.reshape(1, 3, 1, 1)

                        # Get frames, channels, height, and width
                        f, c, h, w = video.shape  # pylint: disable=W0612
                        assert c == 3

                        # Put two copies of the video side by side
                        video = np.concatenate((video, video), 3)

                        # If a pixel is in the segmentation, saturate blue channel
                        # Leave alone otherwise
                        video[:, 0, :, w:] = np.maximum(255. * (logit > 0), video[:, 0, :, w:])  # pylint: disable=E1111

                        # Add blank canvas under pair of videos
                        video = np.concatenate((video, np.zeros_like(video)), 2)

                        # Compute size of segmentation per frame
                        size = (logit > 0).sum((1, 2))

                        # Identify systole frames with peak detection
                        trim_min = sorted(size)[round(len(size) * 0.05)]
                        trim_max = sorted(size)[round(len(size) * 0.95)]
                        trim_range = trim_max - trim_min
                        systole = set(scipy.signal.find_peaks(-size, distance=20, prominence=(0.50 * trim_range))[0])

                        # Write sizes and frames to file
                        for (frame, s) in enumerate(size):
                            g.write("{},{},{},{},{},{}\n".format(filename, frame, s, 1 if frame == large_index[i] else 0, 1 if frame == small_index[i] else 0, 1 if frame in systole else 0))

                        # Plot sizes
                        fig = plt.figure(figsize=(size.shape[0] / 50 * 1.5, 3))
                        plt.scatter(np.arange(size.shape[0]) / 50, size, s=1)
                        ylim = plt.ylim()
                        for s in systole:
                            plt.plot(np.array([s, s]) / 50, ylim, linewidth=1)
                        plt.ylim(ylim)
                        plt.title(os.path.splitext(filename)[0])
                        plt.xlabel("Seconds")
                        plt.ylabel("Size (pixels)")
                        plt.tight_layout()
                        plt.savefig(os.path.join(output, "size", os.path.splitext(filename)[0] + ".pdf"))
                        plt.close(fig)

                        # Normalize size to [0, 1]
                        size -= size.min()
                        size = size / size.max()
                        size = 1 - size

                        # Iterate the frames in this video
                        for (f, s) in enumerate(size):

                            # On all frames, mark a pixel for the size of the frame
                            video[:, :, int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))] = 255.

                            if f in systole:
                                # If frame is computer-selected systole, mark with a line
                                video[:, :, 115:224, int(round(f / len(size) * 200 + 10))] = 255.

                            def dash(start, stop, on=10, off=10):
                                buf = []
                                x = start
                                while x < stop:
                                    buf.extend(range(x, x + on))
                                    x += on
                                    x += off
                                buf = np.array(buf)
                                buf = buf[buf < stop]
                                return buf
                            d = dash(115, 224)

                            if f == large_index[i]:
                                # If frame is human-selected diastole, mark with green dashed line on all frames
                                video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 225, 0]).reshape((1, 3, 1))
                            if f == small_index[i]:
                                # If frame is human-selected systole, mark with red dashed line on all frames
                                video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 0, 225]).reshape((1, 3, 1))

                            # Get pixels for a circle centered on the pixel
                            r, c = skimage.draw.disk((int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))), 4.1)

                            # On the frame that's being shown, put a circle over the pixel
                            video[f, :, r, c] = 255.

                        # Rearrange dimensions and save
                        video = video.transpose(1, 0, 2, 3)
                        video = video.astype(np.uint8)
                        echonet.utils.savevideo(os.path.join(output, "videos", filename), video, 50)

                        # Move to next video
                        start += offset


def run_epoch(model, dataloader, train, optim, device):
    """Run one epoch of training/evaluation for segmentation.

    Args:
        model (torch.nn.Module): Model to train/evaluate.
        dataloader (torch.utils.data.DataLoader): Dataloader for dataset.
        train (bool): Whether or not to train model.
        optim (torch.optim.Optimizer): Optimizer.
        device (torch.device): Device to run on.

    Returns:
        Tuple of (average per-pixel loss, large trace intersections, large
        trace unions, small trace intersections, small trace unions), where
        the intersection/union entries are per-video pixel counts as numpy
        arrays.
    """

    total = 0.
    n = 0

    pos = 0
    neg = 0
    pos_pix = 0
    neg_pix = 0

    model.train(train)

    large_inter = 0
    large_union = 0
    small_inter = 0
    small_union = 0
    large_inter_list = []
    large_union_list = []
    small_inter_list = []
    small_union_list = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader:
                # Count number of pixels in/out of human segmentation
                pos += (large_trace == 1).sum().item()
                pos += (small_trace == 1).sum().item()
                neg += (large_trace == 0).sum().item()
                neg += (small_trace == 0).sum().item()

                # Count number of pixels in/out of human segmentation at each pixel location
                pos_pix += (large_trace == 1).sum(0).to("cpu").detach().numpy()
                pos_pix += (small_trace == 1).sum(0).to("cpu").detach().numpy()
                neg_pix += (large_trace == 0).sum(0).to("cpu").detach().numpy()
                neg_pix += (small_trace == 0).sum(0).to("cpu").detach().numpy()

                # Run prediction for diastolic frames and compute loss
                large_frame = large_frame.to(device)
                large_trace = large_trace.to(device)
                y_large = model(large_frame)["out"]
                loss_large = torch.nn.functional.binary_cross_entropy_with_logits(y_large[:, 0, :, :], large_trace, reduction="sum")
                # Compute pixel intersection and union between human and computer segmentations
                large_inter += np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_union += np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_inter_list.extend(np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                large_union_list.extend(np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Run prediction for systolic frames and compute loss
                small_frame = small_frame.to(device)
                small_trace = small_trace.to(device)
                y_small = model(small_frame)["out"]
                loss_small = torch.nn.functional.binary_cross_entropy_with_logits(y_small[:, 0, :, :], small_trace, reduction="sum")
                # Compute pixel intersection and union between human and computer segmentations
                small_inter += np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_union += np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_inter_list.extend(np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                small_union_list.extend(np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Take gradient step if training
                loss = (loss_large + loss_small) / 2
                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                # Accumulate losses and compute baselines
                total += loss.item()
                n += large_trace.size(0)
                p = pos / (pos + neg)
                p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)
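                # p is the overall fraction of foreground pixels and p_pix is the
                # per-pixel-location foreground frequency (with add-one smoothing);
                # the entropies -p*log(p) - (1-p)*log(1-p) displayed below are the
                # cross-entropy losses of constant baseline predictors.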

                # Show info on progress bar
                pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(
                    total / n / 112 / 112,
                    loss.item() / large_trace.size(0) / 112 / 112,
                    -p * math.log(p) - (1 - p) * math.log(1 - p),
                    (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(),
                    2 * large_inter / (large_union + large_inter),
                    2 * small_inter / (small_union + small_inter)))
                pbar.update()

    large_inter_list = np.array(large_inter_list)
    large_union_list = np.array(large_union_list)
    small_inter_list = np.array(small_inter_list)
    small_union_list = np.array(small_union_list)

    return (total / n / 112 / 112,
            large_inter_list,
            large_union_list,
            small_inter_list,
            small_union_list,
            )


def _video_collate_fn(x):
    """Collate function for PyTorch dataloader to merge multiple videos.

    This function should be used in a dataloader for a dataset that returns
    a video as the first element, along with some (non-zero) tuple of
    targets. Then, the input x is a list of tuples:
      - x[i][0] is the i-th video in the batch
      - x[i][1] are the targets for the i-th video

    This function returns a 3-tuple:
      - The first element is the videos concatenated along the frames
        dimension. This is done so that videos of different lengths can be
        processed together (tensors cannot be "jagged", so we cannot have
        a dimension for video, and another for frames).
      - The second element contains the targets with no modification.
      - The third element is a list of the lengths of the videos in frames.
    """
    video, target = zip(*x)  # Extract the videos and targets

    # ``video'' is a tuple of length ``batch_size''
    #   Each element has shape (channels=3, frames, height, width)
    #   height and width are expected to be the same across videos, but
    #   frames can be different.

    # ``target'' is also a tuple of length ``batch_size''
    # Each element is a tuple of the targets for the item.

    i = list(map(lambda t: t.shape[1], video))  # Extract lengths of videos in frames

    # This concatenates the videos along the frames dimension (basically
    # playing the videos one after another). The frames dimension is then
    # moved to be first.
    # Resulting shape is (total frames, channels=3, height, width)
    video = torch.as_tensor(np.swapaxes(np.concatenate(video, 1), 0, 1))

    # Swap dimensions (approximately a transpose)
    # Before: target[i][j] is the j-th target of element i
    # After:  target[i][j] is the i-th target of element j
    target = zip(*target)

    return video, target, i