Diff of /src/hs_dataset.py [000000] .. [3475df]

Switch to unified view

a b/src/hs_dataset.py
1
"""
2
Developed by: Daniel Crovo
3
Dataset class definition
4
5
"""
6
from torch.utils.data import Dataset
7
import os
8
from PIL import Image 
9
from pydicom import dcmread
10
from dicom_utils import apply_windowing
11
import numpy as np
12
13
class HSDataset(Dataset): 
14
    def __init__(self, image_dir, mask_dir, transform = None, w_level=35, w_width=350, normalized =False) -> None:
15
        """Initialises the Heart Segmentation Dataset class. The dataset asumes all the dicom files are in one single folder aswell as the masks
16
            both the images (dicom files) and masks should have the same name for each slice
17
        Args:
18
            image_dir (string): Path to the dicom directory
19
            mask_dir (string): Path to the masks directory 
20
            transform (_type_, optional):Aplied transformations Defaults to None.
21
            w_level (int, optional): The window level (center). Defaults to 35.
22
            w_width (int, optional): The width of the window Defaults to 400.
23
            normalized (bool, optional): wheter to normalized the windowed array after applying windoing
24
        """
25
        super().__init__()
26
        
27
        self.image_dir = image_dir
28
        self.mask_dir = mask_dir
29
        self.transform = transform
30
        self.w_level = w_level
31
        self.w_width = w_width
32
        self.normalized = normalized
33
        self.images = os.listdir(self.image_dir)
34
        self.masks = os.listdir(self.mask_dir)
35
36
    def __len__(self):
37
        return len(self.images)
38
    
39
    def __getitem__(self, idx):
40
        img_path = os.path.join(self.image_dir, self.images[idx])
41
        mask_path = os.path.join(self.mask_dir, self.masks[idx])
42
        ds = dcmread(img_path)
43
        dicom_img = apply_windowing(ds = ds,window_center = self.w_level, 
44
                                    window_width =self.w_width, normalized = self.normalized)
45
        image = Image.fromarray(dicom_img).convert('RGB')
46
        image = np.array(image)
47
        mask = Image.open(mask_path).convert('L')
48
        mask = np.array(mask)
49
        mask[mask == 255.0] = 1.0
50
51
52
        if self.transform is not None:
53
            transformations = self.transform(image = image, mask = mask)
54
            image = transformations['image']
55
            mask = transformations['mask']
56
            
57
        return image, mask
58