Diff of /code/compute_stats.py [000000] .. [594161]

Switch to unified view

a b/code/compute_stats.py
1
"""
2
DeepSlide
3
Computes the image statistics for normalization.
4
5
Authors: Naofumi Tomita
6
"""
7
import argparse
8
import json
9
from datetime import datetime
10
from pathlib import Path
11
from typing import (List, Tuple)
12
13
14
import torch
15
from PIL import Image
16
from torchvision.transforms import ToTensor
17
18
Image.MAX_IMAGE_PIXELS = None
19
20
21
def compute_stats(folderpath: Path,
22
                  image_ext: str) -> Tuple[List[float], List[float]]:
23
    """
24
    Compute the mean and standard deviation of the images found in folderpath.
25
26
    Args:
27
        folderpath: Path containing images.
28
        image_ext: Extension of the image files.
29
30
    Returns:
31
        A tuple containing the mean and standard deviation for the images over the channel, height, and width axes.
32
33
    This implementation is based on the discussion from: 
34
        https://discuss.pytorch.org/t/about-normalization-using-pre-trained-vgg16-networks/23560/9
35
    """
36
    class MyDataset(torch.utils.data.Dataset):
37
        """
38
        Creates a dataset by reading images.
39
40
        Attributes:
41
            data: List of the string image filenames.
42
        """
43
        def __init__(self, folder: Path) -> None:
44
            """
45
            Create the MyDataset object.
46
47
            Args:
48
                folder: Path to the images.
49
            """
50
            self.data = []
51
52
            for file in folder.rglob(f"*{image_ext}"):
53
                if not file.name.startswith("."):
54
                    self.data.append(file)
55
56
        def __getitem__(self, index: int) -> torch.Tensor:
57
            """
58
            Finds the specified image and outputs in correct format.
59
60
            Args:
61
                index: Index of the desired image.
62
63
            Returns:
64
                A PyTorch Tensor in the correct color space.
65
            """
66
            return ToTensor()(Image.open(self.data[index]).convert("RGB"))
67
68
        def __len__(self) -> int:
69
            return len(self.data)
70
71
    def online_mean_and_sd(
72
        loader: torch.utils.data.DataLoader, report_interval: int=1000
73
                           ) -> Tuple[List[float], List[float]]:
74
        """
75
        Computes the mean and standard deviation online.
76
            Var[x] = E[X^2] - (E[X])^2
77
78
        Args:
79
            loader: The PyTorch DataLoader containing the images to iterate over.
80
            report_interval: Report the intermediate results every N items. (N=0 to suppress reporting.)
81
82
        Returns:
83
            A tuple containing the mean and standard deviation for the images
84
            over the channel, height, and width axes.
85
        """
86
        cnt = 0
87
        fst_moment = torch.empty(3)
88
        snd_moment = torch.empty(3)
89
90
        for i, data in enumerate(loader, 1):
91
            b, __, h, w = data.shape
92
            nb_pixels = b * h * w
93
            fst_moment = (cnt * fst_moment +
94
                          torch.sum(data, dim=[0, 2, 3])) / (cnt + nb_pixels)
95
            snd_moment = (cnt * snd_moment + torch.sum(
96
                data**2, dim=[0, 2, 3])) / (cnt + nb_pixels)
97
            cnt += nb_pixels
98
            if report_interval != 0 and i % report_interval == 0:
99
                temp_mean = fst_moment.tolist()
100
                temp_std = torch.sqrt(snd_moment - fst_moment**2).tolist()
101
                print(f"Mean: {temp_mean}; STD: {temp_std} at iter: {i}")
102
        return fst_moment.tolist(), torch.sqrt(snd_moment -
103
                                               fst_moment**2).tolist()
104
105
    return online_mean_and_sd(
106
        loader=torch.utils.data.DataLoader(
107
            dataset=MyDataset(folder=folderpath),
108
            batch_size=1,
109
            num_workers=1,
110
            shuffle=False))
111
112
def save_stats(mean: List, std: List, datapath: str):
113
    data = {
114
        'mean': mean,
115
        'std': std,
116
        'datapath': datapath}
117
    data = json.dumps(data, indent=4)
118
    filename = f"stats_{datetime.now().strftime('%Y-%m-%d_%H:%M')}.json"
119
    with open(filename, 'w') as outfile:
120
        outfile.write(data)
121
122
    print(f"Results are saved in {filename}.")
123
124
def load_stats(jsonfile: str):
125
    """ Load a stats file in json and return mean and std in lists.
126
    """
127
    with open(jsonfile, 'r') as infile:        
128
        data = json.load(infile)
129
130
    print(f"Stats of \'{data['datapath']}\' are loaded from {jsonfile}.")
131
    return data['mean'], data['std']
132
133
134
if __name__ == '__main__':
135
    parser = argparse.ArgumentParser(
136
        description='Compute channel-wise patch color mean and std.')
137
    parser.add_argument('--datapath', '-i', type=str, required=True,
138
        help='Path containing images.')
139
    parser.add_argument('--image_ext', '-x', type=str, default='.png',
140
        help='Specify file extension of images. Default: .png')
141
    parser.add_argument('--report_interval', '-n', type=int, default=1000,
142
        help='Report the intermediate results every N items. Default: 1000')
143
    parser.add_argument('--save_results', '-d', action='store_true', default=False,
144
        help='Set this flag to save results.')
145
    args = parser.parse_args()
146
147
    mean, std = compute_stats(Path(args.datapath), args.image_ext,)
148
    print(f"Mean: {mean}; STD: {std}")
149
150
    if args.save_results:
151
        save_stats(mean=mean, std=std, datapath=args.datapath)
152
153