Switch to unified view

a b/rocaseg/datasets/dataset_maknee.py
1
import os
2
import glob
3
import logging
4
from collections import defaultdict
5
6
import numpy as np
7
import pandas as pd
8
9
import cv2
10
from torch.utils.data.dataset import Dataset
11
12
13
logging.basicConfig()
14
logger = logging.getLogger('dataset')
15
logger.setLevel(logging.DEBUG)
16
17
18
def index_from_path_maknee(path_root, force=False):
19
    fname_meta_dyn = os.path.join(path_root, 'meta_dynamic.csv')
20
    fname_meta_base = os.path.join(path_root, 'meta_base.csv')
21
22
    if not os.path.exists(fname_meta_dyn) or force:
23
        fnames_image = glob.glob(
24
            os.path.join(path_root, '**', 'images', '*.png'), recursive=True)
25
        logger.info('{} images found'.format(len(fnames_image)))
26
        fnames_mask = glob.glob(
27
            os.path.join(path_root, '**', 'masks', '*.png'), recursive=True)
28
        logger.info('{} masks found'.format(len(fnames_mask)))
29
30
        df_meta = pd.read_csv(fname_meta_base,
31
                              dtype={'patient': str,
32
                                     'release': str,
33
                                     'sequence': str,
34
                                     'side': str,
35
                                     'slice_idx': int,
36
                                     'pixel_spacing_0': float,
37
                                     'pixel_spacing_1': float,
38
                                     'slice_thickness': float,
39
                                     'KL': int},
40
                              index_col=False)
41
42
        if len(fnames_image) != len(df_meta):
43
            raise ValueError("Number of images doesn't match with the metadata")
44
45
        df_meta['path_image'] = [os.path.join(path_root, e)
46
                                 for e in df_meta['path_rel_image']]
47
48
        # Sort the records
49
        df_meta_sorted = (df_meta
50
                          .sort_values(['patient', 'sequence', 'slice_idx'])
51
                          .reset_index()
52
                          .drop('index', axis=1))
53
54
        df_meta_sorted.to_csv(fname_meta_dyn, index=False)
55
    else:
56
        df_meta_sorted = pd.read_csv(fname_meta_dyn,
57
                                     dtype={'patient': str,
58
                                            'release': str,
59
                                            'sequence': str,
60
                                            'side': str,
61
                                            'slice_idx': int,
62
                                            'pixel_spacing_0': float,
63
                                            'pixel_spacing_1': float,
64
                                            'slice_thickness': float,
65
                                            'KL': int},
66
                                     index_col=False)
67
68
    return df_meta_sorted
69
70
71
def read_image(path_file):
72
    image = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE)
73
    return image.reshape((1, *image.shape))
74
75
76
class DatasetMAKNEESagittal2d(Dataset):
77
    def __init__(self, df_meta, mask_mode=None, name=None, transforms=None,
78
                 sample_mode='x_y', **kwargs):
79
        logger.warning('Redundant dataset init arguments:\n{}'
80
                       .format(repr(kwargs)))
81
82
        self.df_meta = df_meta
83
        self.mask_mode = mask_mode
84
        self.name = name
85
        self.transforms = transforms
86
        self.sample_mode = sample_mode
87
88
    def __len__(self):
89
        return len(self.df_meta)
90
91
    def _getitem_x_y(self, idx):
92
        image = read_image(self.df_meta['path_image'].iloc[idx])
93
        mask = np.zeros_like(image)
94
95
        # Apply transformations
96
        if self.transforms is not None:
97
            for t in self.transforms:
98
                if hasattr(t, 'randomize'):
99
                    t.randomize()
100
                image, mask = t(image, mask)
101
102
        tmp = dict(self.df_meta.iloc[idx])
103
        tmp['image'] = image
104
        tmp['mask'] = mask
105
106
        tmp['xs'] = tmp['image']
107
        tmp['ys'] = tmp['mask']
108
        return tmp
109
110
    def __getitem__(self, idx):
111
        if self.sample_mode == 'x_y':
112
            return self._getitem_x_y(idx)
113
        else:
114
            raise ValueError('Invalid `sample_mode`')
115
116
    def describe(self):
117
        summary = defaultdict(float)
118
        for i in range(len(self)):
119
            if self.sample_mode == 'x_y':
120
                _, mask = self.__getitem__(i)
121
            else:
122
                mask = self.__getitem__(i)['mask']
123
            summary['num_class_pixels'] += mask.numpy().sum(axis=(1, 2))
124
        summary['class_importance'] = \
125
            np.sum(summary['num_class_pixels']) / summary['num_class_pixels']
126
        summary['class_importance'] /= np.sum(summary['class_importance'])
127
        logger.info('Dataset statistics:')
128
        logger.info(sorted(summary.items()))