Diff of /Utilities.py [000000] .. [b52eda]

Switch to unified view

a b/Utilities.py
1
from torch import Tensor, FloatTensor, LongTensor
2
from torch_geometric.data import Data
3
from torch.utils.data import Dataset
4
from torch import is_tensor
5
from numpy import load, array, int64
6
from os import listdir
7
from Extracting_Planes import Extract_And_Convert
8
from time import time
9
10
class PairData(Data):
11
  r"""
12
    PyTorch Geometric data class used for the ImageCHD dataset.
13
  """
14
15
  def __inc__(self, key, value, *args, **kwargs):
16
      if key == 'edge_index':
17
          return self.x.size(0)
18
      return super().__inc__(key, value, *args, **kwargs)
19
  
20
  def __init__(self,
21
               x: Tensor | None = None,
22
               edge_index: Tensor | None = None,
23
               edge_attr: Tensor | None = None,
24
               y: Tensor | int | float | None = None,
25
               pos: Tensor | None = None,
26
               time: Tensor | None = None,
27
               **kwargs):
28
     r"""
29
        Arguments for proper ImageCHD processing:
30
            x (Tensor): Source coronary-CT image as graph.
31
            y (Tensor): Ground truth segmentation as graph.
32
            edge_index (Tensor): Adjacency matrix.
33
            adj_count (int): Source image width.
34
     """
35
     super().__init__(x, edge_index, edge_attr, y, pos, time, **kwargs)
36
37
class CHD_Dataset(Dataset):
38
  r"""
39
    PyTorch dataset class used for the ImageCHD dataset.
40
  """
41
42
  def __init__(self, metadata, directory, adjacency):
43
    r"""
44
    Arguments:
45
        metadata (DataFrame): Pandas DataFrame containing dataset information.
46
        directory (string):   Path to the directory of the dataset.
47
        adjacency (Dictionary): Dictionary of adjacency matrices.
48
    """
49
    self.metadata = metadata
50
    self.directory = directory
51
    self.image_dir = directory + 'IMAGES/'
52
    self.label_dir = directory + 'LABELS/'
53
    self.adjacency = adjacency
54
55
  def __len__(self):
56
    return len(self.metadata)
57
58
  def __getitem__(self, idx):
59
    if is_tensor(idx):
60
      idx = idx.tolist()
61
    
62
    image, label = Extract_And_Convert(path_to_image = self.image_dir \
63
                                        + str(self.metadata['index'][idx]) + '.nii.gz',
64
                                       path_to_label = self.label_dir \
65
                                        + str(self.metadata['index'][idx]) + '.nii.gz',
66
                                       plane_type = self.metadata['Type'][idx],
67
                                       plane_index = self.metadata['Indice'][idx])
68
    # start = time()
69
    adj_matrix = self.adjacency[str(self.metadata['Adjacency_count'][idx])]
70
    # print('Adj_matrix assign time: ', time() - start)
71
72
    # start = time()
73
    sample = PairData(x = FloatTensor(image),
74
                      edge_index = LongTensor(adj_matrix),
75
                      y = LongTensor(label),
76
                      adj_count = self.metadata['Adjacency_count'][idx])
77
    # print('Sample preparation time: ', time() - start)
78
    
79
    return sample
80
81
  def get(self, idx):
82
    return self.__getitem__(idx)
83
  
84
def __Load_Adjacency__(path):
85
  files = listdir(path)
86
  adjacency = {}
87
  for f in files:
88
    adjacency[f.split('_')[2].split('.')[0]] = array(load(path + f), dtype = int64)
89
  return adjacency