|
a |
|
b/Utilities.py |
|
|
1 |
from torch import Tensor, FloatTensor, LongTensor |
|
|
2 |
from torch_geometric.data import Data |
|
|
3 |
from torch.utils.data import Dataset |
|
|
4 |
from torch import is_tensor |
|
|
5 |
from numpy import load, array, int64 |
|
|
6 |
from os import listdir |
|
|
7 |
from Extracting_Planes import Extract_And_Convert |
|
|
8 |
from time import time |
|
|
9 |
|
|
|
10 |
class PairData(Data): |
|
|
11 |
r""" |
|
|
12 |
PyTorch Geometric data class used for the ImageCHD dataset. |
|
|
13 |
""" |
|
|
14 |
|
|
|
15 |
def __inc__(self, key, value, *args, **kwargs): |
|
|
16 |
if key == 'edge_index': |
|
|
17 |
return self.x.size(0) |
|
|
18 |
return super().__inc__(key, value, *args, **kwargs) |
|
|
19 |
|
|
|
20 |
def __init__(self, |
|
|
21 |
x: Tensor | None = None, |
|
|
22 |
edge_index: Tensor | None = None, |
|
|
23 |
edge_attr: Tensor | None = None, |
|
|
24 |
y: Tensor | int | float | None = None, |
|
|
25 |
pos: Tensor | None = None, |
|
|
26 |
time: Tensor | None = None, |
|
|
27 |
**kwargs): |
|
|
28 |
r""" |
|
|
29 |
Arguments for proper ImageCHD processing: |
|
|
30 |
x (Tensor): Source coronary-CT image as graph. |
|
|
31 |
y (Tensor): Ground truth segmentation as graph. |
|
|
32 |
edge_index (Tensor): Adjacency matrix. |
|
|
33 |
adj_count (int): Source image width. |
|
|
34 |
""" |
|
|
35 |
super().__init__(x, edge_index, edge_attr, y, pos, time, **kwargs) |
|
|
36 |
|
|
|
37 |
class CHD_Dataset(Dataset): |
|
|
38 |
r""" |
|
|
39 |
PyTorch dataset class used for the ImageCHD dataset. |
|
|
40 |
""" |
|
|
41 |
|
|
|
42 |
def __init__(self, metadata, directory, adjacency): |
|
|
43 |
r""" |
|
|
44 |
Arguments: |
|
|
45 |
metadata (DataFrame): Pandas DataFrame containing dataset information. |
|
|
46 |
directory (string): Path to the directory of the dataset. |
|
|
47 |
adjacency (Dictionary): Dictionary of adjacency matrices. |
|
|
48 |
""" |
|
|
49 |
self.metadata = metadata |
|
|
50 |
self.directory = directory |
|
|
51 |
self.image_dir = directory + 'IMAGES/' |
|
|
52 |
self.label_dir = directory + 'LABELS/' |
|
|
53 |
self.adjacency = adjacency |
|
|
54 |
|
|
|
55 |
def __len__(self): |
|
|
56 |
return len(self.metadata) |
|
|
57 |
|
|
|
58 |
def __getitem__(self, idx): |
|
|
59 |
if is_tensor(idx): |
|
|
60 |
idx = idx.tolist() |
|
|
61 |
|
|
|
62 |
image, label = Extract_And_Convert(path_to_image = self.image_dir \ |
|
|
63 |
+ str(self.metadata['index'][idx]) + '.nii.gz', |
|
|
64 |
path_to_label = self.label_dir \ |
|
|
65 |
+ str(self.metadata['index'][idx]) + '.nii.gz', |
|
|
66 |
plane_type = self.metadata['Type'][idx], |
|
|
67 |
plane_index = self.metadata['Indice'][idx]) |
|
|
68 |
# start = time() |
|
|
69 |
adj_matrix = self.adjacency[str(self.metadata['Adjacency_count'][idx])] |
|
|
70 |
# print('Adj_matrix assign time: ', time() - start) |
|
|
71 |
|
|
|
72 |
# start = time() |
|
|
73 |
sample = PairData(x = FloatTensor(image), |
|
|
74 |
edge_index = LongTensor(adj_matrix), |
|
|
75 |
y = LongTensor(label), |
|
|
76 |
adj_count = self.metadata['Adjacency_count'][idx]) |
|
|
77 |
# print('Sample preparation time: ', time() - start) |
|
|
78 |
|
|
|
79 |
return sample |
|
|
80 |
|
|
|
81 |
def get(self, idx): |
|
|
82 |
return self.__getitem__(idx) |
|
|
83 |
|
|
|
84 |
def __Load_Adjacency__(path): |
|
|
85 |
files = listdir(path) |
|
|
86 |
adjacency = {} |
|
|
87 |
for f in files: |
|
|
88 |
adjacency[f.split('_')[2].split('.')[0]] = array(load(path + f), dtype = int64) |
|
|
89 |
return adjacency |