Diff of /Utilities.py [000000] .. [b52eda]

Switch to side-by-side view

--- a
+++ b/Utilities.py
@@ -0,0 +1,89 @@
+from torch import Tensor, FloatTensor, LongTensor
+from torch_geometric.data import Data
+from torch.utils.data import Dataset
+from torch import is_tensor
+from numpy import load, array, int64
+from os import listdir
+from Extracting_Planes import Extract_And_Convert
+from time import time
+
+class PairData(Data):
+  r"""
+    PyTorch Geometric data class used for the ImageCHD dataset.
+  """
+
+  def __inc__(self, key, value, *args, **kwargs):
+      if key == 'edge_index':
+          return self.x.size(0)
+      return super().__inc__(key, value, *args, **kwargs)
+  
+  def __init__(self,
+               x: Tensor | None = None,
+               edge_index: Tensor | None = None,
+               edge_attr: Tensor | None = None,
+               y: Tensor | int | float | None = None,
+               pos: Tensor | None = None,
+               time: Tensor | None = None,
+               **kwargs):
+     r"""
+        Arguments for proper ImageCHD processing:
+            x (Tensor): Source coronary-CT image as graph.
+            y (Tensor): Ground truth segmentation as graph.
+            edge_index (Tensor): Adjacency matrix.
+            adj_count (int): Source image width.
+     """
+     super().__init__(x, edge_index, edge_attr, y, pos, time, **kwargs)
+
+class CHD_Dataset(Dataset):
+  r"""
+    PyTorch dataset class used for the ImageCHD dataset.
+  """
+
+  def __init__(self, metadata, directory, adjacency):
+    r"""
+    Arguments:
+        metadata (DataFrame): Pandas DataFrame containing dataset information.
+        directory (string):   Path to the directory of the dataset.
+        adjacency (Dictionary): Dictionary of adjacency matrices.
+    """
+    self.metadata = metadata
+    self.directory = directory
+    self.image_dir = directory + 'IMAGES/'
+    self.label_dir = directory + 'LABELS/'
+    self.adjacency = adjacency
+
+  def __len__(self):
+    return len(self.metadata)
+
+  def __getitem__(self, idx):
+    if is_tensor(idx):
+      idx = idx.tolist()
+    
+    image, label = Extract_And_Convert(path_to_image = self.image_dir \
+                                        + str(self.metadata['index'][idx]) + '.nii.gz',
+                                       path_to_label = self.label_dir \
+                                        + str(self.metadata['index'][idx]) + '.nii.gz',
+                                       plane_type = self.metadata['Type'][idx],
+                                       plane_index = self.metadata['Indice'][idx])
+    # start = time()
+    adj_matrix = self.adjacency[str(self.metadata['Adjacency_count'][idx])]
+    # print('Adj_matrix assign time: ', time() - start)
+
+    # start = time()
+    sample = PairData(x = FloatTensor(image),
+                      edge_index = LongTensor(adj_matrix),
+                      y = LongTensor(label),
+                      adj_count = self.metadata['Adjacency_count'][idx])
+    # print('Sample preparation time: ', time() - start)
+    
+    return sample
+
+  def get(self, idx):
+    return self.__getitem__(idx)
+  
+def __Load_Adjacency__(path):
+  files = listdir(path)
+  adjacency = {}
+  for f in files:
+    adjacency[f.split('_')[2].split('.')[0]] = array(load(path + f), dtype = int64)
+  return adjacency
\ No newline at end of file