code/compute_stats.py

"""
DeepSlide
Computes the image statistics for normalization.

Authors: Naofumi Tomita
"""
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List, Tuple

import torch
from PIL import Image
from torchvision.transforms import ToTensor

# Disable PIL's decompression-bomb limit; the image patches can be very large.
Image.MAX_IMAGE_PIXELS = None


def compute_stats(folderpath: Path,
                  image_ext: str,
                  report_interval: int = 1000) -> Tuple[List[float], List[float]]:
    """
    Compute the mean and standard deviation of the images found in folderpath.

    Args:
        folderpath: Path containing images.
        image_ext: Extension of the image files.
        report_interval: Report intermediate results every N images. (N=0 to suppress reporting.)

    Returns:
        A tuple containing the mean and standard deviation for the images
        over the channel, height, and width axes.

    This implementation is based on the discussion from:
    https://discuss.pytorch.org/t/about-normalization-using-pre-trained-vgg16-networks/23560/9
    """
    class MyDataset(torch.utils.data.Dataset):
        """
        Creates a dataset by reading images.

        Attributes:
            data: List of the image file paths.
        """
        def __init__(self, folder: Path) -> None:
            """
            Create the MyDataset object.

            Args:
                folder: Path to the images.
            """
            self.data = []

            # Collect every image with the given extension, skipping hidden files.
            for file in folder.rglob(f"*{image_ext}"):
                if not file.name.startswith("."):
                    self.data.append(file)

        def __getitem__(self, index: int) -> torch.Tensor:
            """
            Finds the specified image and outputs it in the correct format.

            Args:
                index: Index of the desired image.

            Returns:
                A PyTorch Tensor in the correct color space.
            """
            # ToTensor() yields a CHW float tensor scaled to [0, 1], so the
            # statistics are computed on that scale.
            return ToTensor()(Image.open(self.data[index]).convert("RGB"))

        def __len__(self) -> int:
            return len(self.data)

    def online_mean_and_sd(
            loader: torch.utils.data.DataLoader, report_interval: int = 1000
    ) -> Tuple[List[float], List[float]]:
        """
        Computes the mean and standard deviation online.
        Var[X] = E[X^2] - (E[X])^2

        Args:
            loader: The PyTorch DataLoader containing the images to iterate over.
            report_interval: Report the intermediate results every N images. (N=0 to suppress reporting.)

        Returns:
            A tuple containing the mean and standard deviation for the images
            over the channel, height, and width axes.
        """
        cnt = 0
        # The running moments must start at zero: torch.empty() returns
        # uninitialized memory, which could poison the first weighted update.
        fst_moment = torch.zeros(3)
        snd_moment = torch.zeros(3)

        for i, data in enumerate(loader, 1):
            b, __, h, w = data.shape
            nb_pixels = b * h * w
            fst_moment = (cnt * fst_moment +
                          torch.sum(data, dim=[0, 2, 3])) / (cnt + nb_pixels)
            snd_moment = (cnt * snd_moment + torch.sum(
                data**2, dim=[0, 2, 3])) / (cnt + nb_pixels)
            cnt += nb_pixels
            if report_interval != 0 and i % report_interval == 0:
                temp_mean = fst_moment.tolist()
                temp_std = torch.sqrt(snd_moment - fst_moment**2).tolist()
                print(f"Mean: {temp_mean}; STD: {temp_std} at iter: {i}")
        return fst_moment.tolist(), torch.sqrt(snd_moment -
                                               fst_moment**2).tolist()

    return online_mean_and_sd(
        loader=torch.utils.data.DataLoader(
            dataset=MyDataset(folder=folderpath),
            # batch_size=1 allows images of different sizes (no collation across images).
            batch_size=1,
            num_workers=1,
            shuffle=False),
        report_interval=report_interval)


def save_stats(mean: List, std: List, datapath: str) -> None:
    """Save the mean, std, and source data path to a timestamped JSON file."""
    data = {
        'mean': mean,
        'std': std,
        'datapath': datapath}
    data = json.dumps(data, indent=4)
    filename = f"stats_{datetime.now().strftime('%Y-%m-%d_%H:%M')}.json"
    with open(filename, 'w') as outfile:
        outfile.write(data)

    print(f"Results are saved in {filename}.")


def load_stats(jsonfile: str):
    """Load a stats file in JSON and return the mean and std as lists."""
    with open(jsonfile, 'r') as infile:
        data = json.load(infile)

    print(f"Stats of '{data['datapath']}' are loaded from {jsonfile}.")
    return data['mean'], data['std']


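# Illustrative downstream use (an assumption, not part of this script): the
# computed statistics are typically fed to torchvision's Normalize transform.
#
#   from torchvision import transforms
#   mean, std = load_stats("stats_YYYY-MM-DD_HH:MM.json")  # hypothetical filename
#   preprocess = transforms.Compose([transforms.ToTensor(),
#                                    transforms.Normalize(mean=mean, std=std)])
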
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compute channel-wise patch color mean and std.')
    parser.add_argument('--datapath', '-i', type=str, required=True,
                        help='Path containing images.')
    parser.add_argument('--image_ext', '-x', type=str, default='.png',
                        help='Specify file extension of images. Default: .png')
    parser.add_argument('--report_interval', '-n', type=int, default=1000,
                        help='Report the intermediate results every N images. Default: 1000')
    parser.add_argument('--save_results', '-d', action='store_true', default=False,
                        help='Set this flag to save results.')
    args = parser.parse_args()

    mean, std = compute_stats(Path(args.datapath), args.image_ext,
                              report_interval=args.report_interval)
    print(f"Mean: {mean}; STD: {std}")

    if args.save_results:
        save_stats(mean=mean, std=std, datapath=args.datapath)