Diff of /code/compute_stats.py [000000] .. [594161]

Switch to side-by-side view

--- a
+++ b/code/compute_stats.py
@@ -0,0 +1,153 @@
+"""
+DeepSlide
+Computes the image statistics for normalization.
+
+Authors: Naofumi Tomita
+"""
+import argparse
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import (List, Tuple)
+
+
+import torch
+from PIL import Image
+from torchvision.transforms import ToTensor
+
+Image.MAX_IMAGE_PIXELS = None
+
+
+def compute_stats(folderpath: Path,
+                  image_ext: str) -> Tuple[List[float], List[float]]:
+    """
+    Compute the mean and standard deviation of the images found in folderpath.
+
+    Args:
+        folderpath: Path containing images.
+        image_ext: Extension of the image files.
+
+    Returns:
+        A tuple containing the mean and standard deviation for the images over the channel, height, and width axes.
+
+    This implementation is based on the discussion from: 
+        https://discuss.pytorch.org/t/about-normalization-using-pre-trained-vgg16-networks/23560/9
+    """
+    class MyDataset(torch.utils.data.Dataset):
+        """
+        Creates a dataset by reading images.
+
+        Attributes:
+            data: List of the string image filenames.
+        """
+        def __init__(self, folder: Path) -> None:
+            """
+            Create the MyDataset object.
+
+            Args:
+                folder: Path to the images.
+            """
+            self.data = []
+
+            for file in folder.rglob(f"*{image_ext}"):
+                if not file.name.startswith("."):
+                    self.data.append(file)
+
+        def __getitem__(self, index: int) -> torch.Tensor:
+            """
+            Finds the specified image and outputs in correct format.
+
+            Args:
+                index: Index of the desired image.
+
+            Returns:
+                A PyTorch Tensor in the correct color space.
+            """
+            return ToTensor()(Image.open(self.data[index]).convert("RGB"))
+
+        def __len__(self) -> int:
+            return len(self.data)
+
+    def online_mean_and_sd(
+        loader: torch.utils.data.DataLoader, report_interval: int=1000
+                           ) -> Tuple[List[float], List[float]]:
+        """
+        Computes the mean and standard deviation online.
+            Var[x] = E[X^2] - (E[X])^2
+
+        Args:
+            loader: The PyTorch DataLoader containing the images to iterate over.
+            report_interval: Report the intermediate results every N items. (N=0 to suppress reporting.)
+
+        Returns:
+            A tuple containing the mean and standard deviation for the images
+            over the channel, height, and width axes.
+        """
+        cnt = 0
+        fst_moment = torch.empty(3)
+        snd_moment = torch.empty(3)
+
+        for i, data in enumerate(loader, 1):
+            b, __, h, w = data.shape
+            nb_pixels = b * h * w
+            fst_moment = (cnt * fst_moment +
+                          torch.sum(data, dim=[0, 2, 3])) / (cnt + nb_pixels)
+            snd_moment = (cnt * snd_moment + torch.sum(
+                data**2, dim=[0, 2, 3])) / (cnt + nb_pixels)
+            cnt += nb_pixels
+            if report_interval != 0 and i % report_interval == 0:
+                temp_mean = fst_moment.tolist()
+                temp_std = torch.sqrt(snd_moment - fst_moment**2).tolist()
+                print(f"Mean: {temp_mean}; STD: {temp_std} at iter: {i}")
+        return fst_moment.tolist(), torch.sqrt(snd_moment -
+                                               fst_moment**2).tolist()
+
+    return online_mean_and_sd(
+        loader=torch.utils.data.DataLoader(
+            dataset=MyDataset(folder=folderpath),
+            batch_size=1,
+            num_workers=1,
+            shuffle=False))
+
+def save_stats(mean: List, std: List, datapath: str):
+    data = {
+        'mean': mean,
+        'std': std,
+        'datapath': datapath}
+    data = json.dumps(data, indent=4)
+    filename = f"stats_{datetime.now().strftime('%Y-%m-%d_%H:%M')}.json"
+    with open(filename, 'w') as outfile:
+        outfile.write(data)
+
+    print(f"Results are saved in {filename}.")
+
+def load_stats(jsonfile: str):
+    """ Load a stats file in json and return mean and std in lists.
+    """
+    with open(jsonfile, 'r') as infile:        
+        data = json.load(infile)
+
+    print(f"Stats of \'{data['datapath']}\' are loaded from {jsonfile}.")
+    return data['mean'], data['std']
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Compute channel-wise patch color mean and std.')
+    parser.add_argument('--datapath', '-i', type=str, required=True,
+        help='Path containing images.')
+    parser.add_argument('--image_ext', '-x', type=str, default='.png',
+        help='Specify file extension of images. Default: .png')
+    parser.add_argument('--report_interval', '-n', type=int, default=1000,
+        help='Report the intermediate results every N items. Default: 1000')
+    parser.add_argument('--save_results', '-d', action='store_true', default=False,
+        help='Set this flag to save results.')
+    args = parser.parse_args()
+
+    mean, std = compute_stats(Path(args.datapath), args.image_ext,)
+    print(f"Mean: {mean}; STD: {std}")
+
+    if args.save_results:
+        save_stats(mean=mean, std=std, datapath=args.datapath)
+
+