|
a |
|
b/rocaseg/datasets/dataset_maknee.py |
|
|
1 |
import os |
|
|
2 |
import glob |
|
|
3 |
import logging |
|
|
4 |
from collections import defaultdict |
|
|
5 |
|
|
|
6 |
import numpy as np |
|
|
7 |
import pandas as pd |
|
|
8 |
|
|
|
9 |
import cv2 |
|
|
10 |
from torch.utils.data.dataset import Dataset |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
logging.basicConfig() |
|
|
14 |
logger = logging.getLogger('dataset') |
|
|
15 |
logger.setLevel(logging.DEBUG) |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def index_from_path_maknee(path_root, force=False): |
|
|
19 |
fname_meta_dyn = os.path.join(path_root, 'meta_dynamic.csv') |
|
|
20 |
fname_meta_base = os.path.join(path_root, 'meta_base.csv') |
|
|
21 |
|
|
|
22 |
if not os.path.exists(fname_meta_dyn) or force: |
|
|
23 |
fnames_image = glob.glob( |
|
|
24 |
os.path.join(path_root, '**', 'images', '*.png'), recursive=True) |
|
|
25 |
logger.info('{} images found'.format(len(fnames_image))) |
|
|
26 |
fnames_mask = glob.glob( |
|
|
27 |
os.path.join(path_root, '**', 'masks', '*.png'), recursive=True) |
|
|
28 |
logger.info('{} masks found'.format(len(fnames_mask))) |
|
|
29 |
|
|
|
30 |
df_meta = pd.read_csv(fname_meta_base, |
|
|
31 |
dtype={'patient': str, |
|
|
32 |
'release': str, |
|
|
33 |
'sequence': str, |
|
|
34 |
'side': str, |
|
|
35 |
'slice_idx': int, |
|
|
36 |
'pixel_spacing_0': float, |
|
|
37 |
'pixel_spacing_1': float, |
|
|
38 |
'slice_thickness': float, |
|
|
39 |
'KL': int}, |
|
|
40 |
index_col=False) |
|
|
41 |
|
|
|
42 |
if len(fnames_image) != len(df_meta): |
|
|
43 |
raise ValueError("Number of images doesn't match with the metadata") |
|
|
44 |
|
|
|
45 |
df_meta['path_image'] = [os.path.join(path_root, e) |
|
|
46 |
for e in df_meta['path_rel_image']] |
|
|
47 |
|
|
|
48 |
# Sort the records |
|
|
49 |
df_meta_sorted = (df_meta |
|
|
50 |
.sort_values(['patient', 'sequence', 'slice_idx']) |
|
|
51 |
.reset_index() |
|
|
52 |
.drop('index', axis=1)) |
|
|
53 |
|
|
|
54 |
df_meta_sorted.to_csv(fname_meta_dyn, index=False) |
|
|
55 |
else: |
|
|
56 |
df_meta_sorted = pd.read_csv(fname_meta_dyn, |
|
|
57 |
dtype={'patient': str, |
|
|
58 |
'release': str, |
|
|
59 |
'sequence': str, |
|
|
60 |
'side': str, |
|
|
61 |
'slice_idx': int, |
|
|
62 |
'pixel_spacing_0': float, |
|
|
63 |
'pixel_spacing_1': float, |
|
|
64 |
'slice_thickness': float, |
|
|
65 |
'KL': int}, |
|
|
66 |
index_col=False) |
|
|
67 |
|
|
|
68 |
return df_meta_sorted |
|
|
69 |
|
|
|
70 |
|
|
|
71 |
def read_image(path_file): |
|
|
72 |
image = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE) |
|
|
73 |
return image.reshape((1, *image.shape)) |
|
|
74 |
|
|
|
75 |
|
|
|
76 |
class DatasetMAKNEESagittal2d(Dataset): |
|
|
77 |
def __init__(self, df_meta, mask_mode=None, name=None, transforms=None, |
|
|
78 |
sample_mode='x_y', **kwargs): |
|
|
79 |
logger.warning('Redundant dataset init arguments:\n{}' |
|
|
80 |
.format(repr(kwargs))) |
|
|
81 |
|
|
|
82 |
self.df_meta = df_meta |
|
|
83 |
self.mask_mode = mask_mode |
|
|
84 |
self.name = name |
|
|
85 |
self.transforms = transforms |
|
|
86 |
self.sample_mode = sample_mode |
|
|
87 |
|
|
|
88 |
def __len__(self): |
|
|
89 |
return len(self.df_meta) |
|
|
90 |
|
|
|
91 |
def _getitem_x_y(self, idx): |
|
|
92 |
image = read_image(self.df_meta['path_image'].iloc[idx]) |
|
|
93 |
mask = np.zeros_like(image) |
|
|
94 |
|
|
|
95 |
# Apply transformations |
|
|
96 |
if self.transforms is not None: |
|
|
97 |
for t in self.transforms: |
|
|
98 |
if hasattr(t, 'randomize'): |
|
|
99 |
t.randomize() |
|
|
100 |
image, mask = t(image, mask) |
|
|
101 |
|
|
|
102 |
tmp = dict(self.df_meta.iloc[idx]) |
|
|
103 |
tmp['image'] = image |
|
|
104 |
tmp['mask'] = mask |
|
|
105 |
|
|
|
106 |
tmp['xs'] = tmp['image'] |
|
|
107 |
tmp['ys'] = tmp['mask'] |
|
|
108 |
return tmp |
|
|
109 |
|
|
|
110 |
def __getitem__(self, idx): |
|
|
111 |
if self.sample_mode == 'x_y': |
|
|
112 |
return self._getitem_x_y(idx) |
|
|
113 |
else: |
|
|
114 |
raise ValueError('Invalid `sample_mode`') |
|
|
115 |
|
|
|
116 |
def describe(self): |
|
|
117 |
summary = defaultdict(float) |
|
|
118 |
for i in range(len(self)): |
|
|
119 |
if self.sample_mode == 'x_y': |
|
|
120 |
_, mask = self.__getitem__(i) |
|
|
121 |
else: |
|
|
122 |
mask = self.__getitem__(i)['mask'] |
|
|
123 |
summary['num_class_pixels'] += mask.numpy().sum(axis=(1, 2)) |
|
|
124 |
summary['class_importance'] = \ |
|
|
125 |
np.sum(summary['num_class_pixels']) / summary['num_class_pixels'] |
|
|
126 |
summary['class_importance'] /= np.sum(summary['class_importance']) |
|
|
127 |
logger.info('Dataset statistics:') |
|
|
128 |
logger.info(sorted(summary.items())) |