
datacode/ultrasound_data.py
""" Dataset classes for MBZUAI-BiomedIA Fetal Ultrasound datasets
"""

import os, sys
import json, glob
import random
import PIL.Image
import h5py
import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset, WeightedRandomSampler
import torchvision.transforms as torch_transforms
from typing import List, Dict

##---------------------- Generals -----------------------------------------------

def filter_dataframe(df, filtering_dict):
    """ Usage:
    {"blacklist":{'class':["4ch"],"machine_type":["Voluson E8","Voluson S10 Expert","V830"]}}
    """

    if "blacklist" in filtering_dict and "whitelist" in filtering_dict:
        raise Exception("Hey, decide between whitelisting or blacklisting, "
                        "can't do both! Remove either one.")

    if "blacklist" in filtering_dict:
        print("blacklisting...")
        blacklist_dict = filtering_dict["blacklist"]
        new_df = df
        for k in blacklist_dict.keys():
            for val in blacklist_dict[k]:
                new_df = new_df[new_df[k] != val]

    elif "whitelist" in filtering_dict:
        print("whitelisting...")
        whitelist_dict = filtering_dict["whitelist"]
        new_df_list = []
        for k in whitelist_dict.keys():
            for val in whitelist_dict[k]:
                new_df_list.append(df[df[k] == val])
        new_df = pd.concat(new_df_list).drop_duplicates().reset_index(drop=True)

    else:
        print("No filtering of data done, Peace!")
        new_df = df

    return new_df
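
## Hedged usage sketch (illustration, not part of the original module): how
## filter_dataframe might be called on a metadata CSV. The file name and the
## blacklisted value are placeholder assumptions taken from the docstring above.
def _example_filter_dataframe(csv_path="metadata.csv"):
    df = pd.read_csv(csv_path)
    # Drop every row whose "machine_type" column equals "V830"
    filtered = filter_dataframe(df, {"blacklist": {"machine_type": ["V830"]}})
    print("rows before:", len(df), "after:", len(filtered))
    return filtered
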
def get_class_weights(targets, nclasses):
    """
    Class-level weights (for a balanced loss) and sample-level weights
    (for balanced data sampling).
    targets: assumed to be long ints representing the class of each sample.
    """

    n_target = len(targets)
    count_per_class = np.zeros(nclasses, dtype=int)
    for c in targets:
        count_per_class[c] += 1
    count_per_class[count_per_class == 0] = n_target

    # for passing to Loss funcs
    weight_per_class = np.zeros(nclasses, dtype=float)
    for i in range(nclasses):
        weight_per_class[i] = float(n_target) / float(count_per_class[i])

    # for passing to sampler
    weight_samplewise = np.zeros(n_target, dtype=float)
    for idx, tgt in enumerate(targets):
        weight_samplewise[idx] = weight_per_class[tgt]

    return weight_per_class, weight_samplewise
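
## Hedged usage sketch (illustration, not part of the original module): the two
## returned arrays feed a weighted loss and a WeightedRandomSampler respectively.
## `dataset` is assumed to expose a `targets` list, as ClassifyDataFromCSV below does.
def _example_balanced_sampling(dataset, nclasses):
    weight_per_class, weight_samplewise = get_class_weights(dataset.targets, nclasses)
    # Class weights for the loss function
    criterion = torch.nn.CrossEntropyLoss(
        weight=torch.tensor(weight_per_class, dtype=torch.float))
    # Sample weights for drawing balanced mini-batches
    sampler = WeightedRandomSampler(weight_samplewise,
                                    num_samples=len(weight_samplewise),
                                    replacement=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
    return criterion, loader
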
## =============================================================================
## Classification

class ClassifyDataFromCSV(Dataset):
    def __init__(self, root_folder, csv_path, transform = None,
                        filtering_dict: Dict[str, Dict[str, List]] = None,
                        ):
        """
        Classification dataset read from a CSV with `image_path` and `class`
        columns; rows can optionally be filtered with `filtering_dict`.
        """
        self.root_folder = root_folder
        self.df = pd.read_csv(csv_path)

        ## Filter based on some condition in dataframes
        if filtering_dict: self.df = filter_dataframe(self.df, filtering_dict)

        self.class_to_idx = {c: i for i, c in enumerate(sorted(set(
                                    self.df["class"])))}
        self.images_path = [ os.path.join(root_folder, p)
                             for p in self.df["image_path"] ]
        self.targets = list(map(lambda x: self.class_to_idx[x],
                                    list(self.df["class"]) ))

        if transform: self.transform = transform
        else: self.transform = torch_transforms.ToTensor()

        print("Class Indexing:", self.class_to_idx)

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, index):
        imgpath = self.images_path[index]
        target = self.targets[index]
        image = PIL.Image.open(imgpath).convert("RGB")
        image = self.transform(image)
        return image, target
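
## Hedged usage sketch (illustration, not part of the original module): building
## the CSV-backed dataset and a DataLoader. The folder/CSV paths, image size,
## and the blacklisted "4ch" class are placeholder assumptions.
def _example_classify_loader(root_folder="data/images", csv_path="data/labels.csv"):
    transform = torch_transforms.Compose([
        torch_transforms.Resize((224, 224)),
        torch_transforms.ToTensor(),
    ])
    dataset = ClassifyDataFromCSV(root_folder, csv_path, transform=transform,
                                  filtering_dict={"blacklist": {"class": ["4ch"]}})
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    return dataset, loader
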
##================ US Video Frames Loader ======================================

class FetalUSFramesDataset(torch.utils.data.Dataset):
    """ Treats video frames as independent images for training purposes
    """
    def __init__(self, images_folder=None, hdf5_file=None,
                        transform = None,
                        load2ram = False, frame_skip=None):
        """
        Loads frames either from a folder of PNG images or from an HDF5 file;
        load2ram keeps all frames in memory, frame_skip subsamples HDF5 frames.
        """
        self.load2ram = load2ram
        self.frame_skip = frame_skip
        # to be defined by the handlers below
        self.image_paths = []
        self.image_frames = []
        self.get_image_func = None
        ##-----

        if transform: self.transform = transform
        else: self.transform = torch_transforms.ToTensor()

        if hdf5_file:        self._hdf5file_handler(hdf5_file)
        elif images_folder:  self._imagefolder_handler(images_folder)
        else: raise Exception("No data source given; provide images_folder or hdf5_file")


    # for image folder handling
    def _imagefolder_handler(self, images_folder):
        def __get_image_lazy(index):
            return PIL.Image.open(self.image_paths[index]).convert("RGB")
        def __get_image_eager(index):
            return self.image_frames[index]

        self.image_paths = sorted(glob.glob(images_folder + "/**/*.png",
                                            recursive=True))

        self.get_image_func = __get_image_lazy
        if self.load2ram:
            self.image_frames = [ __get_image_lazy(i)
                                    for i in range(len(self.image_paths))]
            self.get_image_func = __get_image_eager

        print("Frame skip is not implemented for image folders")

    # for hdf5 file handling
    def _hdf5file_handler(self, hdf5_file):
        def __get_image_lazy(index):
            k, i = self.image_paths[index]
            arr = self.hdfobj[k][i]
            return PIL.Image.fromarray(arr).convert("RGB")

        def __get_image_eager(index):
            return self.image_frames[index]

        self.hdfobj = h5py.File(hdf5_file, 'r')
        for k in self.hdfobj.keys():
            for i in range(self.hdfobj[k].shape[0]):
                # keep every frame when frame_skip is unset, else every frame_skip-th frame
                if self.frame_skip and i % self.frame_skip: continue
                self.image_paths.append([k, i])

        self.get_image_func = __get_image_lazy
        if self.load2ram:
            self.image_frames = [ __get_image_lazy(i)
                                    for i in range(len(self.image_paths))]
            self.get_image_func = __get_image_eager


    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image = self.get_image_func(index)
        image = self.transform(image)
        return image

    def get_info(self):
        print(self.get_image_func)
        return {
            "DataSize": self.__len__(),
            "Transforms": str(self.transform),
        }
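

## Hedged usage sketch (illustration, not part of the original module): loading
## frames from an HDF5 file of stacked video frames. The file path and the
## frame_skip value are placeholder assumptions.
def _example_frames_loader(hdf5_file="us_videos.hdf5"):
    transform = torch_transforms.Compose([
        torch_transforms.Resize((256, 256)),
        torch_transforms.ToTensor(),
    ])
    dataset = FetalUSFramesDataset(hdf5_file=hdf5_file, transform=transform,
                                   load2ram=False, frame_skip=4)
    print(dataset.get_info())
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    return dataset, loader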