--- /dev/null
+++ b/datacode/ultrasound_data.py
@@ -0,0 +1,201 @@
+"""Dataset classes for the MBZUAI-BiomedIA fetal ultrasound datasets."""
+
+import os
+import glob
+
+import PIL.Image
+import h5py
+import pandas as pd
+import numpy as np
+
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as torch_transforms
+from typing import List, Dict
+
+##---------------------- Generals -----------------------------------------------
+
+def filter_dataframe(df, filtering_dict):
+    """ Filter dataframe rows by blacklisting or whitelisting column values.
+    Usage:
+        {"blacklist": {"class": ["4ch"],
+                       "machine_type": ["Voluson E8", "Voluson S10 Expert", "V830"]}}
+    """
+
+    if "blacklist" in filtering_dict and "whitelist" in filtering_dict:
+        raise Exception("Decide between whitelisting or blacklisting; "
+                        "can't do both! Remove either one.")
+
+    if "blacklist" in filtering_dict:
+        print("blacklisting...")
+        blacklist_dict = filtering_dict["blacklist"]
+        new_df = df
+        # drop every row whose column value matches a blacklisted entry
+        for k in blacklist_dict.keys():
+            for val in blacklist_dict[k]:
+                new_df = new_df[new_df[k] != val]
+
+    elif "whitelist" in filtering_dict:
+        print("whitelisting...")
+        whitelist_dict = filtering_dict["whitelist"]
+        # keep only rows whose column value matches a whitelisted entry
+        new_df_list = []
+        for k in whitelist_dict.keys():
+            for val in whitelist_dict[k]:
+                new_df_list.append(df[df[k] == val])
+        new_df = pd.concat(new_df_list).drop_duplicates().reset_index(drop=True)
+
+    else:
+        print("No filtering of data done.")
+        new_df = df
+
+    return new_df
+
+
+def get_class_weights(targets, nclasses):
+    """ Sample-level weights for a balanced-loss strategy or balanced data sampling.
+    targets: iterable of integer class indices from the dataset.
+    """
+
+    n_target = len(targets)
+    count_per_class = np.zeros(nclasses, dtype=int)
+    for c in targets:
+        count_per_class[c] += 1
+    # avoid division by zero for classes absent from the dataset
+    count_per_class[count_per_class == 0] = n_target
+
+    # per-class weights, e.g. for passing to loss functions
+    weight_per_class = np.zeros(nclasses, dtype=float)
+    for i in range(nclasses):
+        weight_per_class[i] = float(n_target) / float(count_per_class[i])
+
+    # per-sample weights, e.g. for passing to a sampler
+    weight_samplewise = np.zeros(n_target, dtype=float)
+    for idx, tgt in enumerate(targets):
+        weight_samplewise[idx] = weight_per_class[tgt]
+
+    return weight_per_class, weight_samplewise
+
+
+## =============================================================================
+## Classification
+
+
+class ClassifyDataFromCSV(Dataset):
+    def __init__(self, root_folder, csv_path, transform=None,
+                 filtering_dict: Dict[str, Dict[str, List]] = None,
+                 ):
+        """ Image classification dataset read from a CSV file with
+        `image_path` and `class` columns.
+        """
+        self.root_folder = root_folder
+        self.df = pd.read_csv(csv_path)
+
+        ## Filter rows based on conditions over the dataframe columns
+        if filtering_dict:
+            self.df = filter_dataframe(self.df, filtering_dict)
+
+        self.class_to_idx = {c: i for i, c in enumerate(sorted(set(self.df["class"])))}
+        self.images_path = [os.path.join(root_folder, p)
+                            for p in self.df["image_path"]]
+        self.targets = [self.class_to_idx[c] for c in self.df["class"]]
+
+        self.transform = transform if transform else torch_transforms.ToTensor()
+
+        print("Class Indexing:", self.class_to_idx)
+
+    def __len__(self):
+        return len(self.images_path)
+
+    def __getitem__(self, index):
+        imgpath = self.images_path[index]
+        target = self.targets[index]
+        image = PIL.Image.open(imgpath).convert("RGB")
+        image = self.transform(image)
+        return image, target
+
+
+##================ US Video Frames Loader ======================================
+
+
+class FetalUSFramesDataset(torch.utils.data.Dataset):
+    """ Treats video frames as independent images for training purposes.
+    """
+    def __init__(self, images_folder=None, hdf5_file=None,
+                 transform=None,
+                 load2ram=False, frame_skip=None):
+        """ Load frames either from a folder of PNG images or from an HDF5 file.
+        """
+        self.load2ram = load2ram
+        self.frame_skip = frame_skip
+        # defined by the handlers below
+        self.image_paths = []
+        self.image_frames = []
+        self.get_image_func = None
+        ##-----
+
+        self.transform = transform if transform else torch_transforms.ToTensor()
+
+        if hdf5_file:
+            self._hdf5file_handler(hdf5_file)
+        elif images_folder:
+            self._imagefolder_handler(images_folder)
+        else:
+            raise Exception("No data source given; pass either images_folder or hdf5_file.")
+
+    # for image-folder handling
+    def _imagefolder_handler(self, images_folder):
+        def __get_image_lazy(index):
+            return PIL.Image.open(self.image_paths[index]).convert("RGB")
+
+        def __get_image_eager(index):
+            return self.image_frames[index]
+
+        self.image_paths = sorted(glob.glob(
+            os.path.join(images_folder, "**", "*.png"), recursive=True))
+
+        self.get_image_func = __get_image_lazy
+        if self.load2ram:
+            self.image_frames = [__get_image_lazy(i)
+                                 for i in range(len(self.image_paths))]
+            self.get_image_func = __get_image_eager
+
+        if self.frame_skip:
+            print("Frame skip is not implemented for image folders.")
+
+    # for HDF5 file handling
+    def _hdf5file_handler(self, hdf5_file):
+        def __get_image_lazy(index):
+            k, i = self.image_paths[index]
+            arr = self.hdfobj[k][i]
+            return PIL.Image.fromarray(arr).convert("RGB")
+
+        def __get_image_eager(index):
+            return self.image_frames[index]
+
+        self.hdfobj = h5py.File(hdf5_file, "r")
+        # collect (video key, frame index) pairs, optionally skipping frames
+        for k in self.hdfobj.keys():
+            for i in range(self.hdfobj[k].shape[0]):
+                if self.frame_skip and i % self.frame_skip:
+                    continue
+                self.image_paths.append([k, i])
+
+        self.get_image_func = __get_image_lazy
+        if self.load2ram:
+            self.image_frames = [__get_image_lazy(i)
+                                 for i in range(len(self.image_paths))]
+            self.get_image_func = __get_image_eager
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, index):
+        image = self.get_image_func(index)
+        image = self.transform(image)
+        return image
+
+    def get_info(self):
+        print(self.get_image_func)
+        return {
+            "DataSize": self.__len__(),
+            "Transforms": str(self.transform),
+        }
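For reference, a minimal usage sketch of the classes added above (file paths, the blacklisted machine type, batch size, and frame-skip value are hypothetical placeholders; it assumes the repo root is on PYTHONPATH and the CSV has the `image_path` and `class` columns this dataset expects):

# hypothetical usage sketch -- paths and filtering values are placeholders
from torch.utils.data import DataLoader, WeightedRandomSampler
from datacode.ultrasound_data import (ClassifyDataFromCSV, FetalUSFramesDataset,
                                      get_class_weights)

# classification dataset, blacklisting one machine type
dataset = ClassifyDataFromCSV(
    root_folder="data/us_images",              # placeholder image root
    csv_path="data/us_labels.csv",             # placeholder CSV path
    filtering_dict={"blacklist": {"machine_type": ["V830"]}},
)

# class-balanced sampling using the per-sample weights from get_class_weights
_, sample_weights = get_class_weights(dataset.targets, len(dataset.class_to_idx))
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset))
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

# frame-level dataset read from an HDF5 file, keeping every 5th frame
frames = FetalUSFramesDataset(hdf5_file="data/us_videos.h5", frame_skip=5)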