--- a +++ b/code/compute_stats.py @@ -0,0 +1,153 @@ +""" +DeepSlide +Computes the image statistics for normalization. + +Authors: Naofumi Tomita +""" +import argparse +import json +from datetime import datetime +from pathlib import Path +from typing import (List, Tuple) + + +import torch +from PIL import Image +from torchvision.transforms import ToTensor + +Image.MAX_IMAGE_PIXELS = None + + +def compute_stats(folderpath: Path, + image_ext: str) -> Tuple[List[float], List[float]]: + """ + Compute the mean and standard deviation of the images found in folderpath. + + Args: + folderpath: Path containing images. + image_ext: Extension of the image files. + + Returns: + A tuple containing the mean and standard deviation for the images over the channel, height, and width axes. + + This implementation is based on the discussion from: + https://discuss.pytorch.org/t/about-normalization-using-pre-trained-vgg16-networks/23560/9 + """ + class MyDataset(torch.utils.data.Dataset): + """ + Creates a dataset by reading images. + + Attributes: + data: List of the string image filenames. + """ + def __init__(self, folder: Path) -> None: + """ + Create the MyDataset object. + + Args: + folder: Path to the images. + """ + self.data = [] + + for file in folder.rglob(f"*{image_ext}"): + if not file.name.startswith("."): + self.data.append(file) + + def __getitem__(self, index: int) -> torch.Tensor: + """ + Finds the specified image and outputs in correct format. + + Args: + index: Index of the desired image. + + Returns: + A PyTorch Tensor in the correct color space. + """ + return ToTensor()(Image.open(self.data[index]).convert("RGB")) + + def __len__(self) -> int: + return len(self.data) + + def online_mean_and_sd( + loader: torch.utils.data.DataLoader, report_interval: int=1000 + ) -> Tuple[List[float], List[float]]: + """ + Computes the mean and standard deviation online. + Var[x] = E[X^2] - (E[X])^2 + + Args: + loader: The PyTorch DataLoader containing the images to iterate over. + report_interval: Report the intermediate results every N items. (N=0 to suppress reporting.) + + Returns: + A tuple containing the mean and standard deviation for the images + over the channel, height, and width axes. + """ + cnt = 0 + fst_moment = torch.empty(3) + snd_moment = torch.empty(3) + + for i, data in enumerate(loader, 1): + b, __, h, w = data.shape + nb_pixels = b * h * w + fst_moment = (cnt * fst_moment + + torch.sum(data, dim=[0, 2, 3])) / (cnt + nb_pixels) + snd_moment = (cnt * snd_moment + torch.sum( + data**2, dim=[0, 2, 3])) / (cnt + nb_pixels) + cnt += nb_pixels + if report_interval != 0 and i % report_interval == 0: + temp_mean = fst_moment.tolist() + temp_std = torch.sqrt(snd_moment - fst_moment**2).tolist() + print(f"Mean: {temp_mean}; STD: {temp_std} at iter: {i}") + return fst_moment.tolist(), torch.sqrt(snd_moment - + fst_moment**2).tolist() + + return online_mean_and_sd( + loader=torch.utils.data.DataLoader( + dataset=MyDataset(folder=folderpath), + batch_size=1, + num_workers=1, + shuffle=False)) + +def save_stats(mean: List, std: List, datapath: str): + data = { + 'mean': mean, + 'std': std, + 'datapath': datapath} + data = json.dumps(data, indent=4) + filename = f"stats_{datetime.now().strftime('%Y-%m-%d_%H:%M')}.json" + with open(filename, 'w') as outfile: + outfile.write(data) + + print(f"Results are saved in {filename}.") + +def load_stats(jsonfile: str): + """ Load a stats file in json and return mean and std in lists. + """ + with open(jsonfile, 'r') as infile: + data = json.load(infile) + + print(f"Stats of \'{data['datapath']}\' are loaded from {jsonfile}.") + return data['mean'], data['std'] + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Compute channel-wise patch color mean and std.') + parser.add_argument('--datapath', '-i', type=str, required=True, + help='Path containing images.') + parser.add_argument('--image_ext', '-x', type=str, default='.png', + help='Specify file extension of images. Default: .png') + parser.add_argument('--report_interval', '-n', type=int, default=1000, + help='Report the intermediate results every N items. Default: 1000') + parser.add_argument('--save_results', '-d', action='store_true', default=False, + help='Set this flag to save results.') + args = parser.parse_args() + + mean, std = compute_stats(Path(args.datapath), args.image_ext,) + print(f"Mean: {mean}; STD: {std}") + + if args.save_results: + save_stats(mean=mean, std=std, datapath=args.datapath) + +