# dataset_dataloader.py
1
import os
2
3
import numpy as np
4
import pandas as pd
5
import cv2
6
7
from sklearn.model_selection import train_test_split
8
9
import torch
10
import torch.nn as nn
11
from torch.utils.data import Dataset, DataLoader
12
13
from albumentations.pytorch import ToTensor, ToTensorV2 
14
from albumentations import (HorizontalFlip,
15
                            VerticalFlip,
16
                            Normalize,
17
                            Compose)
18
19
20
class LungsDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs listed in a DataFrame.

    Each row of ``df`` names an image file (column ``"ImageId"``) and its mask
    file (column ``"MaskId"``). The files are resolved relative to
    ``imgs_dir`` and ``masks_dir`` respectively and read with OpenCV.
    """

    def __init__(self,
                 imgs_dir: str,
                 masks_dir: str,
                 df: pd.DataFrame,
                 phase: str):
        """Store the data locations and build the transform pipeline.

        Args:
            imgs_dir: Directory containing the image files.
            masks_dir: Directory containing the mask files.
            df: Table with ``"ImageId"`` and ``"MaskId"`` columns.
            phase: Forwarded to ``get_augmentations`` to select the transforms.
        """
        self.root_imgs_dir = imgs_dir
        self.root_masks_dir = masks_dir
        self.df = df
        self.augmentations = get_augmentations(phase)

    def __len__(self):
        """Return the number of rows (samples) in the DataFrame."""
        return len(self.df)

    @staticmethod
    def _read_image(path: str) -> np.ndarray:
        """Read ``path`` with ``cv2.imread`` and fail loudly if that is impossible.

        Raises:
            FileNotFoundError: If ``path`` is not an existing file.
            OSError: If the file exists but OpenCV could not decode it.
        """
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Image file not found: {path}")
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError(f"Could not decode image file: {path}")
        return image

    @staticmethod
    def _binarize_mask(mask: np.ndarray, threshold: int = 240) -> np.ndarray:
        """Return a float32 mask: 1.0 where ``mask >= threshold``, else 0.0.

        Values below the threshold are treated as artifacts and removed.
        """
        return (mask >= threshold).astype(np.float32)

    def __getitem__(self, idx):
        """Load, binarize and transform the sample at position ``idx``.

        Returns:
            Tuple ``(img, mask)`` taken from the augmentation output; the mask
            has its last axis moved to the front.

        Raises:
            FileNotFoundError: If the image or mask file does not exist.
            OSError: If the image or mask file cannot be decoded.
        """
        # Positional lookup: ``idx`` ranges over 0..len(self)-1, which only
        # coincides with index labels when the DataFrame index was reset.
        img_name = self.df["ImageId"].iloc[idx]
        mask_name = self.df["MaskId"].iloc[idx]

        # NOTE(review): cv2.imread yields BGR channel order while the default
        # mean/std in get_augmentations look like RGB statistics -- confirm
        # whether a colour conversion is wanted for non-grayscale inputs.
        img = self._read_image(os.path.join(self.root_imgs_dir, img_name))
        mask = self._read_image(os.path.join(self.root_masks_dir, mask_name))

        augmented = self.augmentations(image=img,
                                       mask=self._binarize_mask(mask))
        img = augmented['image']
        # Move the last (channel) axis of the mask tensor to the front.
        mask = augmented['mask'].permute(2, 0, 1)

        return img, mask
51
52
53
def get_augmentations(phase,
                   mean: tuple = (0.485, 0.456, 0.406),
                   std: tuple = (0.229, 0.224, 0.225),):
    """Build the albumentations pipeline used for the given phase.

    Args:
        phase: ``"train"`` adds a random vertical flip (p=0.5) in front of the
            common steps; any other value yields only the common steps.
        mean: Per-channel mean handed to ``Normalize``.
        std: Per-channel standard deviation handed to ``Normalize``.

    Returns:
        A ``Compose`` of the selected transforms, ending with ``Normalize``
        followed by ``ToTensorV2``.
    """
    # Random augmentation is applied during training only.
    train_only = [VerticalFlip(p=0.5)] if phase == "train" else []
    # Steps shared by every phase: normalise, then convert to tensors.
    shared = [
        Normalize(mean=mean, std=std, p=1),
        ToTensorV2(),
    ]
    return Compose(train_only + shared)
72
73
74
def get_dataloader(
    imgs_dir: str,
    masks_dir: str,
    path_to_csv: str,
    phase: str,
    batch_size: int = 8,
    num_workers: int = 6,
    test_size: float = 0.2,
    random_state: int = 69,
) -> DataLoader:
    """Build a DataLoader over the train or validation split of a CSV listing.

    The CSV is read into a DataFrame and split with ``train_test_split``.
    Because the split is seeded, calling this function once per phase with the
    same ``test_size`` and ``random_state`` produces complementary splits.

    Args:
        imgs_dir: Directory containing the image files.
        masks_dir: Directory containing the mask files.
        path_to_csv: CSV file describing the samples; it is handed to
            ``LungsDataset``, which reads its ``"ImageId"`` and ``"MaskId"``
            columns.
        phase: ``"train"`` selects the training split; any other value selects
            the validation split. Also forwarded to ``LungsDataset``.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        test_size: Fraction (or count) of samples placed in the validation split.
        random_state: Seed for the train/validation split.

    Returns:
        A ``DataLoader`` over the selected split. Batches are shuffled only
        for the training phase so validation order stays reproducible.
    """
    df = pd.read_csv(path_to_csv)

    train_df, val_df = train_test_split(df,
                                        test_size=test_size,
                                        random_state=random_state)

    is_train = phase == "train"
    # Reset the index so the split is labelled 0..n-1 for the dataset.
    split_df = (train_df if is_train else val_df).reset_index(drop=True)

    image_dataset = LungsDataset(imgs_dir, masks_dir, split_df, phase)
    return DataLoader(
        image_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=is_train,
    )