"""Datasets that serve whole-slide-image patches described by HDF5 files.

Recovered from the added side of a unified diff for
``datasets/dataset_h5.py``.
"""
from __future__ import print_function, division

import math
import os
import pdb
import pickle
import re
from random import randrange

import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader, sampler
from torchvision import transforms, utils, models


def eval_transforms(pretrained=False):
    """Build the default evaluation transform pipeline.

    Args:
        pretrained (bool): If True, normalize with the ImageNet statistics
            ``mean=(0.485, 0.456, 0.406)`` / ``std=(0.229, 0.224, 0.225)``;
            otherwise use 0.5 for every channel's mean and std.

    Returns:
        A ``transforms.Compose`` applying ``ToTensor`` then ``Normalize``.
    """
    if pretrained:
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    return transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )


class Whole_Slide_Bag(Dataset):
    """Dataset over patch images stored in the ``imgs`` dataset of an .h5 file."""

    def __init__(self,
                 file_path,
                 pretrained=False,
                 custom_transforms=None,
                 target_patch_size=-1,
                 ):
        """
        Args:
            file_path (string): Path to the .h5 file containing patched data.
            pretrained (bool): Use ImageNet transforms
            custom_transforms (callable, optional): Optional transform to be
                applied on a sample; when falsy, ``eval_transforms`` is used.
            target_patch_size (int): If > 0, every patch is resized to a
                square of this side length before the transforms run.
        """
        self.pretrained = pretrained
        if target_patch_size > 0:
            self.target_patch_size = (target_patch_size, target_patch_size)
        else:
            self.target_patch_size = None

        if not custom_transforms:
            self.roi_transforms = eval_transforms(pretrained=pretrained)
        else:
            self.roi_transforms = custom_transforms

        self.file_path = file_path

        # Only the length is cached; the file is reopened on every access so
        # no open HDF5 handle is stored on the dataset object.
        with h5py.File(self.file_path, "r") as f:
            self.length = len(f['imgs'])

        self.summary()

    def __len__(self):
        """Return the number of entries in the ``imgs`` dataset."""
        return self.length

    def summary(self):
        """Print the attributes of ``imgs`` and the current settings."""
        # Use a context manager so the file handle is released once the
        # attributes have been printed (it was previously left open).
        with h5py.File(self.file_path, "r") as hdf5_file:
            for name, value in hdf5_file['imgs'].attrs.items():
                print(name, value)

        print('pretrained:', self.pretrained)
        print('transformations:', self.roi_transforms)
        if self.target_patch_size is not None:
            print('target_size: ', self.target_patch_size)

    def __getitem__(self, idx):
        """Return ``(img, coord)`` for entry ``idx``.

        ``img`` is the transformed patch with ``unsqueeze(0)`` applied to the
        transform output; ``coord`` is the matching entry of the ``coords``
        dataset.
        """
        with h5py.File(self.file_path, 'r') as hdf5_file:
            img = hdf5_file['imgs'][idx]
            coord = hdf5_file['coords'][idx]

        img = Image.fromarray(img)
        if self.target_patch_size is not None:
            img = img.resize(self.target_patch_size)
        img = self.roi_transforms(img).unsqueeze(0)
        return img, coord


class Whole_Slide_Bag_FP(Dataset):
    """Dataset that reads patches from a slide object at coordinates stored in an .h5 file."""

    def __init__(self,
                 file_path,
                 wsi,
                 pretrained=False,
                 custom_transforms=None,
                 custom_downsample=1,
                 target_patch_size=-1
                 ):
        """
        Args:
            file_path (string): Path to the .h5 file containing patched data.
            wsi: Slide object; must provide ``read_region(location, level,
                size)`` whose result supports ``.convert('RGB')``
                (presumably an OpenSlide handle -- confirm against caller).
            pretrained (bool): Use ImageNet transforms
            custom_transforms (callable, optional): Optional transform to be applied on a sample
            custom_downsample (int): Custom defined downscale factor (overruled by target_patch_size)
            target_patch_size (int): Custom defined image size before embedding
        """
        self.pretrained = pretrained
        self.wsi = wsi
        if not custom_transforms:
            self.roi_transforms = eval_transforms(pretrained=pretrained)
        else:
            self.roi_transforms = custom_transforms

        self.file_path = file_path

        # Patch geometry comes from the attributes of the 'coords' dataset.
        with h5py.File(self.file_path, "r") as f:
            dset = f['coords']
            self.patch_level = dset.attrs['patch_level']
            self.patch_size = dset.attrs['patch_size']
            self.length = len(dset)

        # An explicit target size wins over the downsample factor.
        if target_patch_size > 0:
            self.target_patch_size = (target_patch_size, ) * 2
        elif custom_downsample > 1:
            self.target_patch_size = (self.patch_size // custom_downsample, ) * 2
        else:
            self.target_patch_size = None
        self.summary()

    def __len__(self):
        """Return the number of entries in the ``coords`` dataset."""
        return self.length

    def summary(self):
        """Print the attributes of ``coords`` and the extraction settings."""
        # Use a context manager so the file handle is released once the
        # attributes have been printed (it was previously left open).
        with h5py.File(self.file_path, "r") as hdf5_file:
            for name, value in hdf5_file['coords'].attrs.items():
                print(name, value)

        print('\nfeature extraction settings')
        print('target patch size: ', self.target_patch_size)
        print('pretrained: ', self.pretrained)
        print('transformations: ', self.roi_transforms)

    def __getitem__(self, idx):
        """Return ``(img, coord)`` for entry ``idx``.

        The patch is read from ``self.wsi`` at ``coord`` using the stored
        patch level and a square of ``patch_size``, converted to RGB,
        optionally resized, transformed, and given ``unsqueeze(0)``.
        """
        with h5py.File(self.file_path, 'r') as hdf5_file:
            coord = hdf5_file['coords'][idx]
        img = self.wsi.read_region(coord, self.patch_level, (self.patch_size, self.patch_size)).convert('RGB')

        if self.target_patch_size is not None:
            img = img.resize(self.target_patch_size)
        img = self.roi_transforms(img).unsqueeze(0)
        return img, coord


class Dataset_All_Bags(Dataset):
    """Dataset over the ``slide_id`` column of a CSV file."""

    def __init__(self, csv_path):
        """Load the CSV at ``csv_path`` into a DataFrame with ``pd.read_csv``."""
        self.df = pd.read_csv(csv_path)

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``slide_id`` value at index label ``idx``."""
        return self.df['slide_id'][idx]