--- /dev/null
+++ b/echonet/utils/segmentation.py
@@ -0,0 +1,498 @@
+"""Functions for training and running segmentation."""
+
+import math
+import os
+import time
+
+import click
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy.signal
+import skimage.draw
+import torch
+import torchvision
+import tqdm
+
+import echonet
+
+
+@click.command("segmentation")
+@click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
+@click.option("--output", type=click.Path(file_okay=False), default=None)
+@click.option("--model_name", type=click.Choice(
+    sorted(name for name in torchvision.models.segmentation.__dict__
+           if name.islower() and not name.startswith("__") and callable(torchvision.models.segmentation.__dict__[name]))),
+    default="deeplabv3_resnet50")
+@click.option("--pretrained/--random", default=False)
+@click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
+@click.option("--run_test/--skip_test", default=False)
+@click.option("--save_video/--skip_video", default=False)
+@click.option("--num_epochs", type=int, default=50)
+@click.option("--lr", type=float, default=1e-5)
+@click.option("--weight_decay", type=float, default=0)
+@click.option("--lr_step_period", type=int, default=None)
+@click.option("--num_train_patients", type=int, default=None)
+@click.option("--num_workers", type=int, default=4)
+@click.option("--batch_size", type=int, default=20)
+@click.option("--device", type=str, default=None)
+@click.option("--seed", type=int, default=0)
+def run(
+    data_dir=None,
+    output=None,
+
+    model_name="deeplabv3_resnet50",
+    pretrained=False,
+    weights=None,
+
+    run_test=False,
+    save_video=False,
+    num_epochs=50,
+    lr=1e-5,
+    weight_decay=0,
+    lr_step_period=None,
+    num_train_patients=None,
+    num_workers=4,
+    batch_size=20,
+    device=None,
+    seed=0,
+):
+    """Trains/tests segmentation model.
+
+    Args:
+        data_dir (str, optional): Directory containing dataset. Defaults to
+            `echonet.config.DATA_DIR`.
+        output (str, optional): Directory to place outputs. Defaults to
+            output/segmentation/<model_name>_<pretrained/random>/.
+        model_name (str, optional): Name of segmentation model. One of ``deeplabv3_resnet50'',
+            ``deeplabv3_resnet101'', ``fcn_resnet50'', or ``fcn_resnet101''
+            (options are torchvision.models.segmentation.<model_name>)
+            Defaults to ``deeplabv3_resnet50''.
+        pretrained (bool, optional): Whether to use pretrained weights for the
+            model. Defaults to False.
+        weights (str, optional): Path to checkpoint containing weights to
+            initialize model. Defaults to None.
+        run_test (bool, optional): Whether or not to run evaluation on the
+            val and test splits. Defaults to False.
+        save_video (bool, optional): Whether to save videos with segmentations.
+            Defaults to False.
+        num_epochs (int, optional): Number of epochs during training.
+            Defaults to 50.
+        lr (float, optional): Learning rate for SGD.
+            Defaults to 1e-5.
+        weight_decay (float, optional): Weight decay for SGD.
+            Defaults to 0.
+        lr_step_period (int or None, optional): Period of learning rate decay
+            (learning rate is decayed by a multiplicative factor of 0.1).
+            Defaults to math.inf (never decay learning rate).
+        num_train_patients (int or None, optional): Number of training patients
+            for ablations. Defaults to all patients.
+        num_workers (int, optional): Number of subprocesses to use for data
+            loading. If 0, the data will be loaded in the main process.
+            Defaults to 4.
+        batch_size (int, optional): Number of samples to load per batch.
+            Defaults to 20.
+        device (str or None, optional): Name of device to run on. Options from
+            https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
+            Defaults to ``cuda'' if available, and ``cpu'' otherwise.
+        seed (int, optional): Seed for random number generator. Defaults to 0.
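+
+    Example:
+        A typical invocation, assuming the package's ``echonet'' console
+        entry point is installed:
+
+            echonet segmentation --num_epochs 50 --run_test --save_video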
+    """
+
+    # Seed RNGs
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    # Set default output directory
+    if output is None:
+        output = os.path.join("output", "segmentation", "{}_{}".format(model_name, "pretrained" if pretrained else "random"))
+    os.makedirs(output, exist_ok=True)
+
+    # Set device for computations
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Set up model
+    model = torchvision.models.segmentation.__dict__[model_name](pretrained=pretrained, aux_loss=False)
+
+    model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)  # change number of outputs to 1
+    if device.type == "cuda":
+        model = torch.nn.DataParallel(model)
+    model.to(device)
+
+    if weights is not None:
+        checkpoint = torch.load(weights)
+        model.load_state_dict(checkpoint['state_dict'])
+
+    # Set up optimizer
+    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
+    if lr_step_period is None:
+        lr_step_period = math.inf
+    scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)
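+    # (a step period of math.inf means the scheduler never decays the learning rate)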
+
+    # Compute mean and std
+    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
+    tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
+    kwargs = {"target_type": tasks,
+              "mean": mean,
+              "std": std
+              }
+
+    # Set up datasets and dataloaders
+    dataset = {}
+    dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs)
+    if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
+        # Subsample patients (used for ablation experiment)
+        indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
+        dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
+    dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)
+
+    # Run training and testing loops
+    with open(os.path.join(output, "log.csv"), "a") as f:
+        epoch_resume = 0
+        bestLoss = float("inf")
+        try:
+            # Attempt to load checkpoint
+            checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
+            model.load_state_dict(checkpoint['state_dict'])
+            optim.load_state_dict(checkpoint['opt_dict'])
+            scheduler.load_state_dict(checkpoint['scheduler_dict'])
+            epoch_resume = checkpoint["epoch"] + 1
+            bestLoss = checkpoint["best_loss"]
+            f.write("Resuming from epoch {}\n".format(epoch_resume))
+        except FileNotFoundError:
+            f.write("Starting run from scratch\n")
+
+        for epoch in range(epoch_resume, num_epochs):
+            print("Epoch #{}".format(epoch), flush=True)
+            for phase in ['train', 'val']:
+                start_time = time.time()
+                for i in range(torch.cuda.device_count()):
+                    torch.cuda.reset_peak_memory_stats(i)
+
+                ds = dataset[phase]
+                dataloader = torch.utils.data.DataLoader(
+                    ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
+
+                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, phase == "train", optim, device)
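+                # Dice = 2|A∩B| / (|A| + |B|); since |A| + |B| = |A∪B| + |A∩B|,
+                # 2 * inter / (union + inter) below is exactly the Dice coefficient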
+                overall_dice = 2 * (large_inter.sum() + small_inter.sum()) / (large_union.sum() + large_inter.sum() + small_union.sum() + small_inter.sum())
+                large_dice = 2 * large_inter.sum() / (large_union.sum() + large_inter.sum())
+                small_dice = 2 * small_inter.sum() / (small_union.sum() + small_inter.sum())
+                f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch,
+                                                                    phase,
+                                                                    loss,
+                                                                    overall_dice,
+                                                                    large_dice,
+                                                                    small_dice,
+                                                                    time.time() - start_time,
+                                                                    large_inter.size,
+                                                                    sum(torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())),
+                                                                    sum(torch.cuda.max_memory_reserved(i) for i in range(torch.cuda.device_count())),
+                                                                    batch_size))
+                f.flush()
+            scheduler.step()
+
+            # Save checkpoint
+            save = {
+                'epoch': epoch,
+                'state_dict': model.state_dict(),
+                'best_loss': bestLoss,
+                'loss': loss,
+                'opt_dict': optim.state_dict(),
+                'scheduler_dict': scheduler.state_dict(),
+            }
+            torch.save(save, os.path.join(output, "checkpoint.pt"))
+            if loss < bestLoss:
+                torch.save(save, os.path.join(output, "best.pt"))
+                bestLoss = loss
+
+        # Load best weights
+        if num_epochs != 0:
+            checkpoint = torch.load(os.path.join(output, "best.pt"))
+            model.load_state_dict(checkpoint['state_dict'])
+            f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))
+
+        if run_test:
+            # Run on validation and test
+            for split in ["val", "test"]:
+                dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs)
+                dataloader = torch.utils.data.DataLoader(dataset,
+                                                         batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
+                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, False, None, device)
+
+                overall_dice = 2 * (large_inter + small_inter) / (large_union + large_inter + small_union + small_inter)
+                large_dice = 2 * large_inter / (large_union + large_inter)
+                small_dice = 2 * small_inter / (small_union + small_inter)
+                with open(os.path.join(output, "{}_dice.csv".format(split)), "w") as g:
+                    g.write("Filename, Overall, Large, Small\n")
+                    for (filename, overall, large, small) in zip(dataset.fnames, overall_dice, large_dice, small_dice):
+                        g.write("{},{},{},{}\n".format(filename, overall, large, small))
+
+                f.write("{} dice (overall): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(np.concatenate((large_inter, small_inter)), np.concatenate((large_union, small_union)), echonet.utils.dice_similarity_coefficient)))
+                f.write("{} dice (large):   {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(large_inter, large_union, echonet.utils.dice_similarity_coefficient)))
+                f.write("{} dice (small):   {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(small_inter, small_union, echonet.utils.dice_similarity_coefficient)))
+                f.flush()
+
+    # Saving videos with segmentations
+    dataset = echonet.datasets.Echo(root=data_dir, split="test",
+                                    target_type=["Filename", "LargeIndex", "SmallIndex"],  # Need filename for saving, and human-selected frames to annotate
+                                    mean=mean, std=std,  # Normalization
+                                    length=None, max_length=None, period=1  # Take all frames
+                                    )
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=num_workers, shuffle=False, pin_memory=False, collate_fn=_video_collate_fn)
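+    # (_video_collate_fn, defined at the bottom of this file, concatenates the
+    # variable-length videos along the frames dimension and returns per-video
+    # lengths so the output can be split apart again)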
+
+    # Save videos with segmentation
+    if save_video and not all(os.path.isfile(os.path.join(output, "videos", f)) for f in dataloader.dataset.fnames):
+        # Only run if missing videos
+
+        model.eval()
+
+        os.makedirs(os.path.join(output, "videos"), exist_ok=True)
+        os.makedirs(os.path.join(output, "size"), exist_ok=True)
+        echonet.utils.latexify()
+
+        with torch.no_grad():
+            with open(os.path.join(output, "size.csv"), "w") as g:
+                g.write("Filename,Frame,Size,HumanLarge,HumanSmall,ComputerSmall\n")
+                for (x, (filenames, large_index, small_index), length) in tqdm.tqdm(dataloader):
+                    # Run segmentation model on blocks of frames one-by-one
+                    # The whole concatenated video may be too long to run together
+                    y = np.concatenate([model(x[i:(i + batch_size), :, :, :].to(device))["out"].detach().cpu().numpy() for i in range(0, x.shape[0], batch_size)])
+
+                    start = 0
+                    x = x.numpy()
+                    for (i, (filename, offset)) in enumerate(zip(filenames, length)):
+                        # Extract one video and segmentation predictions
+                        video = x[start:(start + offset), ...]
+                        logit = y[start:(start + offset), 0, :, :]
+
+                        # Un-normalize video
+                        video *= std.reshape(1, 3, 1, 1)
+                        video += mean.reshape(1, 3, 1, 1)
+
+                        # Get frames, channels, height, and width
+                        f, c, h, w = video.shape  # pylint: disable=W0612
+                        assert c == 3
+
+                        # Put two copies of the video side by side
+                        video = np.concatenate((video, video), 3)
+
+                        # If a pixel is in the segmentation, saturate blue channel
+                        # Leave alone otherwise
+                        video[:, 0, :, w:] = np.maximum(255. * (logit > 0), video[:, 0, :, w:])  # pylint: disable=E1111
+
+                        # Add blank canvas under pair of videos
+                        video = np.concatenate((video, np.zeros_like(video)), 2)
+
+                        # Compute size of segmentation per frame
+                        size = (logit > 0).sum((1, 2))
+
+                        # Identify systole frames with peak detection
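+                        # Trim the size signal to its approximate 5th-95th
+                        # percentile range so the prominence threshold is
+                        # robust to outlier frames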
+                        trim_min = sorted(size)[round(len(size) * 0.05)]
+                        trim_max = sorted(size)[round(len(size) * 0.95)]
+                        trim_range = trim_max - trim_min
+                        systole = set(scipy.signal.find_peaks(-size, distance=20, prominence=(0.50 * trim_range))[0])
+
+                        # Write sizes and frames to file
+                        for (frame, s) in enumerate(size):
+                            g.write("{},{},{},{},{},{}\n".format(filename, frame, s, 1 if frame == large_index[i] else 0, 1 if frame == small_index[i] else 0, 1 if frame in systole else 0))
+
+                        # Plot sizes
+                        fig = plt.figure(figsize=(size.shape[0] / 50 * 1.5, 3))
+                        plt.scatter(np.arange(size.shape[0]) / 50, size, s=1)
+                        ylim = plt.ylim()
+                        for s in systole:
+                            plt.plot(np.array([s, s]) / 50, ylim, linewidth=1)
+                        plt.ylim(ylim)
+                        plt.title(os.path.splitext(filename)[0])
+                        plt.xlabel("Seconds")
+                        plt.ylabel("Size (pixels)")
+                        plt.tight_layout()
+                        plt.savefig(os.path.join(output, "size", os.path.splitext(filename)[0] + ".pdf"))
+                        plt.close(fig)
+
+                        # Normalize size to [0, 1]
+                        size -= size.min()
+                        size = size / size.max()
+                        size = 1 - size
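+                        # (after inversion, s is 0 for the largest area, so
+                        # larger ventricle areas plot higher on the canvas;
+                        # image row indices increase downward)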
+
+                        # Iterate the frames in this video
+                        for (f, s) in enumerate(size):
+
+                            # On all frames, mark a pixel for the size of the frame
+                            video[:, :, int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))] = 255.
+
+                            if f in systole:
+                                # If frame is computer-selected systole, mark with a line
+                                video[:, :, 115:224, int(round(f / len(size) * 200 + 10))] = 255.
+
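+                            # dash(start, stop) builds row indices for a dashed
+                            # vertical line: runs of ``on'' drawn pixels
+                            # separated by ``off'' blank pixels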
+                            def dash(start, stop, on=10, off=10):
+                                buf = []
+                                x = start
+                                while x < stop:
+                                    buf.extend(range(x, x + on))
+                                    x += on
+                                    x += off
+                                buf = np.array(buf)
+                                buf = buf[buf < stop]
+                                return buf
+                            d = dash(115, 224)
+
+                            if f == large_index[i]:
+                                # If frame is human-selected diastole, mark with green dashed line on all frames
+                                video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 225, 0]).reshape((1, 3, 1))
+                            if f == small_index[i]:
+                                # If frame is human-selected systole, mark with red dashed line on all frames
+                                video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 0, 225]).reshape((1, 3, 1))
+
+                            # Get pixels for a circle centered on the pixel
+                            r, c = skimage.draw.disk((int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))), 4.1)
+
+                            # On the frame that's being shown, put a circle over the pixel
+                            video[f, :, r, c] = 255.
+
+                        # Rearrange dimensions and save
+                        video = video.transpose(1, 0, 2, 3)
+                        video = video.astype(np.uint8)
+                        echonet.utils.savevideo(os.path.join(output, "videos", filename), video, 50)
+
+                        # Move to next video
+                        start += offset
+
+
+def run_epoch(model, dataloader, train, optim, device):
+    """Run one epoch of training/evaluation for segmentation.
+
+    Args:
+        model (torch.nn.Module): Model to train/evaluate.
+        dataloader (torch.utils.data.DataLoader): Dataloader for dataset.
+        train (bool): Whether or not to train model.
+        optim (torch.optim.Optimizer): Optimizer.
+        device (torch.device): Device to run on.
+
+    Returns:
+        Tuple of average loss per pixel, and per-example arrays of
+        intersection and union pixel counts for the large (diastolic) and
+        small (systolic) traces.
+    """
+
+    total = 0.
+    n = 0
+
+    pos = 0
+    neg = 0
+    pos_pix = 0
+    neg_pix = 0
+
+    model.train(train)
+
+    large_inter = 0
+    large_union = 0
+    small_inter = 0
+    small_union = 0
+    large_inter_list = []
+    large_union_list = []
+    small_inter_list = []
+    small_union_list = []
+
+    with torch.set_grad_enabled(train):
+        with tqdm.tqdm(total=len(dataloader)) as pbar:
+            for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader:
+                # Count number of pixels in/out of human segmentation
+                pos += (large_trace == 1).sum().item()
+                pos += (small_trace == 1).sum().item()
+                neg += (large_trace == 0).sum().item()
+                neg += (small_trace == 0).sum().item()
+
+                # Count number of pixels in/out of human segmentation at each pixel position
+                pos_pix += (large_trace == 1).sum(0).to("cpu").detach().numpy()
+                pos_pix += (small_trace == 1).sum(0).to("cpu").detach().numpy()
+                neg_pix += (large_trace == 0).sum(0).to("cpu").detach().numpy()
+                neg_pix += (small_trace == 0).sum(0).to("cpu").detach().numpy()
+
+                # Run prediction for diastolic frames and compute loss
+                large_frame = large_frame.to(device)
+                large_trace = large_trace.to(device)
+                y_large = model(large_frame)["out"]
+                loss_large = torch.nn.functional.binary_cross_entropy_with_logits(y_large[:, 0, :, :], large_trace, reduction="sum")
+                # Compute pixel intersection and union between human and computer segmentations
+                pred_large = y_large[:, 0, :, :].detach().cpu().numpy() > 0.
+                true_large = large_trace.detach().cpu().numpy() > 0.
+                large_inter += np.logical_and(pred_large, true_large).sum()
+                large_union += np.logical_or(pred_large, true_large).sum()
+                large_inter_list.extend(np.logical_and(pred_large, true_large).sum((1, 2)))
+                large_union_list.extend(np.logical_or(pred_large, true_large).sum((1, 2)))
+
+                # Run prediction for systolic frames and compute loss
+                small_frame = small_frame.to(device)
+                small_trace = small_trace.to(device)
+                y_small = model(small_frame)["out"]
+                loss_small = torch.nn.functional.binary_cross_entropy_with_logits(y_small[:, 0, :, :], small_trace, reduction="sum")
+                # Compute pixel intersection and union between human and computer segmentations
+                pred_small = y_small[:, 0, :, :].detach().cpu().numpy() > 0.
+                true_small = small_trace.detach().cpu().numpy() > 0.
+                small_inter += np.logical_and(pred_small, true_small).sum()
+                small_union += np.logical_or(pred_small, true_small).sum()
+                small_inter_list.extend(np.logical_and(pred_small, true_small).sum((1, 2)))
+                small_union_list.extend(np.logical_or(pred_small, true_small).sum((1, 2)))
+
+                # Take gradient step if training
+                loss = (loss_large + loss_small) / 2
+                if train:
+                    optim.zero_grad()
+                    loss.backward()
+                    optim.step()
+
+                # Accumulate losses and compute baselines
+                total += loss.item()
+                n += large_trace.size(0)
+                p = pos / (pos + neg)
+                p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)
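+                # p is the marginal foreground probability; its binary entropy
+                # (shown below) is the loss of a constant predictor and serves
+                # as a baseline for the per-pixel cross entropy. p_pix is the
+                # Laplace-smoothed foreground frequency at each pixel position.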
+
+                # Show info on process bar
+                pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(
+                    total / n / 112 / 112,
+                    loss.item() / large_trace.size(0) / 112 / 112,
+                    -p * math.log(p) - (1 - p) * math.log(1 - p),
+                    (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(),
+                    2 * large_inter / (large_union + large_inter),
+                    2 * small_inter / (small_union + small_inter)))
+                pbar.update()
+
+    large_inter_list = np.array(large_inter_list)
+    large_union_list = np.array(large_union_list)
+    small_inter_list = np.array(small_inter_list)
+    small_union_list = np.array(small_union_list)
+
+    return (total / n / 112 / 112,
+            large_inter_list,
+            large_union_list,
+            small_inter_list,
+            small_union_list,
+            )
+
+
+def _video_collate_fn(x):
+    """Collate function for Pytorch dataloader to merge multiple videos.
+
+    This function should be used in a dataloader for a dataset that returns
+    a video as the first element, along with some (non-empty) tuple of
+    targets. Then, the input x is a list of tuples:
+      - x[i][0] is the i-th video in the batch
+      - x[i][1] are the targets for the i-th video
+
+    This function returns a 3-tuple:
+      - The first element is the videos concatenated along the frames
+        dimension. This is done so that videos of different lengths can be
+        processed together (tensors cannot be "jagged", so we cannot have
+        a dimension for video, and another for frames).
+      - The second element contains the targets with no modification.
+      - The third element is a list of the lengths of the videos in frames.
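+
+    For example, two videos of shapes (3, 30, 112, 112) and (3, 40, 112, 112)
+    are merged into one tensor of shape (70, 3, 112, 112), and the returned
+    lengths [30, 40] allow the concatenated output to be split per video again.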
+    """
+    video, target = zip(*x)  # Extract the videos and targets
+
+    # ``video'' is a tuple of length ``batch_size''
+    #   Each element has shape (channels=3, frames, height, width)
+    #   height and width are expected to be the same across videos, but
+    #   frames can be different.
+
+    # ``target'' is also a tuple of length ``batch_size''
+    # Each element is a tuple of the targets for the item.
+
+    i = list(map(lambda t: t.shape[1], video))  # Extract lengths of videos in frames
+
+    # This concatenates the videos along the frames dimension (basically
+    # playing the videos one after another). The frames dimension is then
+    # moved to be first.
+    # Resulting shape is (total frames, channels=3, height, width)
+    video = torch.as_tensor(np.swapaxes(np.concatenate(video, 1), 0, 1))
+
+    # Swap dimensions (approximately a transpose)
+    # Before: target[i][j] is the j-th target of element i
+    # After:  target[i][j] is the i-th target of element j
+    target = zip(*target)
+
+    return video, target, i