# rocaseg/datasets/sources.py
1
import os
2
import logging
3
4
from sklearn.model_selection import GroupShuffleSplit, GroupKFold
5
6
from rocaseg.datasets import (index_from_path_oai_imo,
7
                              index_from_path_okoa,
8
                              index_from_path_maknee)
9
10
11
# Module-wide logging setup: make sure the root logger has a handler
# (`basicConfig` is a no-op if one is already installed) and expose a
# DEBUG-level logger named 'datasets' for the functions in this module.
logging.basicConfig()
logger = logging.getLogger('datasets')
logger.setLevel(logging.DEBUG)
14
15
16
def _patient_groups(df):
    """Return the ``patient`` column of ``df`` as an array of group labels."""
    return df.loc[:, 'patient'].values


def _log_count(label, df):
    """Log the number of rows in ``df`` under the given ``label``."""
    logger.info(f"{label} number of samples: "
                f"{len(df)}")


def _log_split(trainval_df, test_df):
    """Log the sizes of the trainval and test subsets."""
    logger.info(f"Made trainval-test split, number of samples: "
                f"{len(trainval_df)}, "
                f"{len(test_df)}")


def _split_by_patient(df, seed, y=None):
    """Split ``df`` into (trainval, test) with patients kept on one side.

    Uses a single `GroupShuffleSplit` draw with ``test_size=0.2`` and
    ``random_state=seed``; ``y`` is forwarded to the splitter unchanged.
    """
    splitter = GroupShuffleSplit(n_splits=1, test_size=0.2,
                                 random_state=seed)
    idcs_trainval, idcs_test = next(splitter.split(X=df, y=y,
                                                   groups=_patient_groups(df)))
    return df.iloc[idcs_trainval], df.iloc[idcs_test]


def _patient_folds(df, fold_num, y=None):
    """Return the lazy `GroupKFold.split` iterator over ``df``.

    The iterator is not materialized on purpose: callers receive exactly
    the object `GroupKFold.split` produces, which can be consumed once.
    """
    return GroupKFold(n_splits=fold_num).split(X=df, y=y,
                                               groups=_patient_groups(df))


def _build_oai_imo(path_root, with_folds, fold_num, seed):
    """Index the OAI iMorphics dataset and optionally split it."""
    entry = {'path_root': path_root}
    entry['full_df'] = index_from_path_oai_imo(path_root)
    _log_count('Total', entry['full_df'])

    # Select the specific subset
    # Remove two series from the dataset as they are completely missing
    # information on patellar cartilage:
    #        /0.C.2/9674570/20040913/10699609/
    #        /1.C.2/9674570/20050829/10488714/
    entry['sel_df'] = entry['full_df'][entry['full_df']['patient'] != '9674570']
    _log_count('Selected', entry['sel_df'])

    if with_folds:
        # KL grades are forwarded as `y` to both splitters.
        # NOTE(review): the group-based splitters presumably do not stratify
        # on `y` - confirm against the scikit-learn documentation.
        grades = entry['sel_df'].loc[:, 'KL'].values
        entry['trainval_df'], entry['test_df'] = _split_by_patient(
            entry['sel_df'], seed, y=grades)
        _log_split(entry['trainval_df'], entry['test_df'])

        grades = entry['trainval_df'].loc[:, 'KL'].values
        entry['trainval_folds'] = _patient_folds(entry['trainval_df'],
                                                 fold_num, y=grades)
    return entry


def _build_okoa(path_root, with_folds, fold_num, seed):
    """Index the OKOA dataset and optionally split it.

    The trainval/test split is read from the ``subset`` column
    ('training' / 'evaluation'), so ``seed`` is not used here.
    """
    entry = {'path_root': path_root}
    entry['full_df'] = index_from_path_okoa(path_root)
    _log_count('Total', entry['full_df'])

    # Select the specific subset (currently: everything)
    entry['sel_df'] = entry['full_df']
    _log_count('Selected', entry['sel_df'])

    if with_folds:
        sel_df = entry['sel_df']
        entry['trainval_df'] = sel_df[sel_df['subset'] == 'training']
        entry['test_df'] = sel_df[sel_df['subset'] == 'evaluation']
        _log_split(entry['trainval_df'], entry['test_df'])

        entry['trainval_folds'] = _patient_folds(entry['trainval_df'],
                                                 fold_num)
    return entry


def _build_maknee(path_root, with_folds, fold_num, seed):
    """Index the MAKNEE dataset, split it, and optionally make folds."""
    entry = {'path_root': path_root}
    entry['full_df'] = index_from_path_maknee(path_root)
    _log_count('Total', entry['full_df'])

    # Select the specific subset (currently: everything)
    entry['sel_df'] = entry['full_df']
    _log_count('Selected', entry['sel_df'])

    # NOTE(review): unlike the other datasets, the trainval/test split is
    # made even when `with_folds` is False. Kept as is, since callers may
    # rely on `trainval_df` / `test_df` being present - confirm intent.
    entry['trainval_df'], entry['test_df'] = _split_by_patient(
        entry['sel_df'], seed)
    _log_split(entry['trainval_df'], entry['test_df'])

    if with_folds:
        entry['trainval_folds'] = _patient_folds(entry['trainval_df'],
                                                 fold_num)
    return entry


# Dataset name -> (log banner, directory under the data root, builder).
_DATASETS = {
    'oai_imo': ('--- OAI iMorphics dataset ---',
                '91_OAI_iMorphics_full_meta',
                _build_oai_imo),
    'okoa': ('--- OKOA dataset ---',
             '32_OKOA_full_meta_rescaled',
             _build_okoa),
    'maknee': ('--- MAKNEE dataset ---',
               '42_MAKNEE_full_meta_rescaled',
               _build_maknee),
}


def sources_from_path(path_data_root,
                      selection=None,
                      with_folds=False,
                      fold_num=5,
                      seed_trainval_test=0):
    """Index the selected datasets found under a common root directory.

    Args:
        path_data_root: str
            Directory that contains the per-dataset subdirectories.
        selection: iterable or str or None
            Dataset name(s) to load: 'oai_imo', 'okoa', 'maknee'. A single
            string selects one dataset, None selects all three.
        with_folds: bool
            Whether to split trainval subset into the folds.
        fold_num: int
            Number of folds.
        seed_trainval_test: int
            Random state for the trainval/test splitting.

    Returns:
        dict
            Maps each dataset name that was found on disk to a dict with:
              - 'path_root': path of the dataset directory;
              - 'full_df': table returned by the dataset indexing function
                (presumably a pandas DataFrame - confirm against
                `rocaseg.datasets`);
              - 'sel_df': the selected subset of 'full_df';
              - 'trainval_df', 'test_df': present when `with_folds` is True
                (always present for 'maknee');
              - 'trainval_folds': present when `with_folds` is True; the
                single-pass iterator returned by `GroupKFold.split` over
                'trainval_df'.
            Datasets whose directory does not exist are skipped with a
            warning.

    Raises:
        ValueError: if `selection` contains an unknown dataset name.
    """
    if selection is None:
        selection = ('oai_imo', 'okoa', 'maknee')
    elif isinstance(selection, str):
        # A bare string is one dataset name, not an iterable of characters.
        selection = (selection, )

    sources = dict()

    for name in selection:
        try:
            banner, dir_name, build = _DATASETS[name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which are unknown as well.
            raise ValueError(f'Unknown dataset `{name}`') from None

        logger.info(banner)
        path_root = os.path.join(path_data_root, dir_name)

        if not os.path.exists(path_root):
            logger.warning(f"Dataset {name} is not found in {path_root}")
            continue

        sources[name] = build(path_root, with_folds, fold_num,
                              seed_trainval_test)

    return sources