|
a |
|
b/src/hs_dataset.py |
|
|
1 |
""" |
|
|
2 |
Developed by: Daniel Crovo |
|
|
3 |
Dataset class definition |
|
|
4 |
|
|
|
5 |
""" |
|
|
6 |
from torch.utils.data import Dataset |
|
|
7 |
import os |
|
|
8 |
from PIL import Image |
|
|
9 |
from pydicom import dcmread |
|
|
10 |
from dicom_utils import apply_windowing |
|
|
11 |
import numpy as np |
|
|
12 |
|
|
|
13 |
class HSDataset(Dataset): |
|
|
14 |
def __init__(self, image_dir, mask_dir, transform = None, w_level=35, w_width=350, normalized =False) -> None: |
|
|
15 |
"""Initialises the Heart Segmentation Dataset class. The dataset asumes all the dicom files are in one single folder aswell as the masks |
|
|
16 |
both the images (dicom files) and masks should have the same name for each slice |
|
|
17 |
Args: |
|
|
18 |
image_dir (string): Path to the dicom directory |
|
|
19 |
mask_dir (string): Path to the masks directory |
|
|
20 |
transform (_type_, optional):Aplied transformations Defaults to None. |
|
|
21 |
w_level (int, optional): The window level (center). Defaults to 35. |
|
|
22 |
w_width (int, optional): The width of the window Defaults to 400. |
|
|
23 |
normalized (bool, optional): wheter to normalized the windowed array after applying windoing |
|
|
24 |
""" |
|
|
25 |
super().__init__() |
|
|
26 |
|
|
|
27 |
self.image_dir = image_dir |
|
|
28 |
self.mask_dir = mask_dir |
|
|
29 |
self.transform = transform |
|
|
30 |
self.w_level = w_level |
|
|
31 |
self.w_width = w_width |
|
|
32 |
self.normalized = normalized |
|
|
33 |
self.images = os.listdir(self.image_dir) |
|
|
34 |
self.masks = os.listdir(self.mask_dir) |
|
|
35 |
|
|
|
36 |
def __len__(self): |
|
|
37 |
return len(self.images) |
|
|
38 |
|
|
|
39 |
def __getitem__(self, idx): |
|
|
40 |
img_path = os.path.join(self.image_dir, self.images[idx]) |
|
|
41 |
mask_path = os.path.join(self.mask_dir, self.masks[idx]) |
|
|
42 |
ds = dcmread(img_path) |
|
|
43 |
dicom_img = apply_windowing(ds = ds,window_center = self.w_level, |
|
|
44 |
window_width =self.w_width, normalized = self.normalized) |
|
|
45 |
image = Image.fromarray(dicom_img).convert('RGB') |
|
|
46 |
image = np.array(image) |
|
|
47 |
mask = Image.open(mask_path).convert('L') |
|
|
48 |
mask = np.array(mask) |
|
|
49 |
mask[mask == 255.0] = 1.0 |
|
|
50 |
|
|
|
51 |
|
|
|
52 |
if self.transform is not None: |
|
|
53 |
transformations = self.transform(image = image, mask = mask) |
|
|
54 |
image = transformations['image'] |
|
|
55 |
mask = transformations['mask'] |
|
|
56 |
|
|
|
57 |
return image, mask |
|
|
58 |
|