--- a +++ b/Utilities.py @@ -0,0 +1,89 @@ +from torch import Tensor, FloatTensor, LongTensor +from torch_geometric.data import Data +from torch.utils.data import Dataset +from torch import is_tensor +from numpy import load, array, int64 +from os import listdir +from Extracting_Planes import Extract_And_Convert +from time import time + +class PairData(Data): + r""" + PyTorch Geometric data class used for the ImageCHD dataset. + """ + + def __inc__(self, key, value, *args, **kwargs): + if key == 'edge_index': + return self.x.size(0) + return super().__inc__(key, value, *args, **kwargs) + + def __init__(self, + x: Tensor | None = None, + edge_index: Tensor | None = None, + edge_attr: Tensor | None = None, + y: Tensor | int | float | None = None, + pos: Tensor | None = None, + time: Tensor | None = None, + **kwargs): + r""" + Arguments for proper ImageCHD processing: + x (Tensor): Source coronary-CT image as graph. + y (Tensor): Ground truth segmentation as graph. + edge_index (Tensor): Adjacency matrix. + adj_count (int): Source image width. + """ + super().__init__(x, edge_index, edge_attr, y, pos, time, **kwargs) + +class CHD_Dataset(Dataset): + r""" + PyTorch dataset class used for the ImageCHD dataset. + """ + + def __init__(self, metadata, directory, adjacency): + r""" + Arguments: + metadata (DataFrame): Pandas DataFrame containing dataset information. + directory (string): Path to the directory of the dataset. + adjacency (Dictionary): Dictionary of adjacency matrices. + """ + self.metadata = metadata + self.directory = directory + self.image_dir = directory + 'IMAGES/' + self.label_dir = directory + 'LABELS/' + self.adjacency = adjacency + + def __len__(self): + return len(self.metadata) + + def __getitem__(self, idx): + if is_tensor(idx): + idx = idx.tolist() + + image, label = Extract_And_Convert(path_to_image = self.image_dir \ + + str(self.metadata['index'][idx]) + '.nii.gz', + path_to_label = self.label_dir \ + + str(self.metadata['index'][idx]) + '.nii.gz', + plane_type = self.metadata['Type'][idx], + plane_index = self.metadata['Indice'][idx]) + # start = time() + adj_matrix = self.adjacency[str(self.metadata['Adjacency_count'][idx])] + # print('Adj_matrix assign time: ', time() - start) + + # start = time() + sample = PairData(x = FloatTensor(image), + edge_index = LongTensor(adj_matrix), + y = LongTensor(label), + adj_count = self.metadata['Adjacency_count'][idx]) + # print('Sample preparation time: ', time() - start) + + return sample + + def get(self, idx): + return self.__getitem__(idx) + +def __Load_Adjacency__(path): + files = listdir(path) + adjacency = {} + for f in files: + adjacency[f.split('_')[2].split('.')[0]] = array(load(path + f), dtype = int64) + return adjacency \ No newline at end of file