a b/adpkd_segmentation/datasets/splits.py
1
from sklearn.model_selection import train_test_split
2
3
4
class GenSplit:
    """Callable that splits a collection of indices into train/val/test.

    The split is produced by successive calls to
    ``sklearn.model_selection.train_test_split``, each seeded with ``seed``.

    Args:
        train: Nominal train fraction. It is stored on the instance but is
            not passed to either split; the train subset is whatever remains
            after ``val + test`` has been held out.
        val: Fraction of the input to hold out for validation.
        test: Fraction of the input to hold out for testing.
        seed: Value passed as ``random_state`` to every split.
    """

    def __init__(self, train=0.7, val=0.15, test=0.15, seed=1):
        super().__init__()

        self.train = train
        self.val = val
        self.test = test
        self.seed = seed

    def __call__(self, all_idxs):
        """Split ``all_idxs`` and return the three subsets.

        The subsets are also stored on the instance as ``train_idxs``,
        ``val_idxs`` and ``test_idxs``, and their sizes are printed.

        Args:
            all_idxs: Indexable collection accepted by ``train_test_split``.

        Returns:
            Dict with keys ``"train"``, ``"val"`` and ``"test"``.

        Note:
            Errors raised by ``train_test_split`` for the first split (for
            example when ``val + test`` is not a valid ``test_size``)
            propagate unchanged.
        """
        holdout_fraction = self.val + self.test

        # split train from validation-test
        train_idxs, test_val_idxs = train_test_split(
            all_idxs, test_size=holdout_fraction, random_state=self.seed
        )

        # split validation from test
        if self.test == 0:
            # Nothing requested for test: a second split would be called with
            # test_size=0.0, which train_test_split may reject. The held-out
            # part is already shuffled, so it becomes the validation set.
            val_idxs, test_idxs = test_val_idxs, test_val_idxs[:0]
        elif self.val == 0:
            # Nothing requested for validation: the ratio below would be 1.0,
            # which is outside the float range train_test_split accepts.
            val_idxs, test_idxs = test_val_idxs[:0], test_val_idxs
        else:
            val_idxs, test_idxs = train_test_split(
                test_val_idxs,
                test_size=self.test / holdout_fraction,
                random_state=self.seed,
            )

        self.train_idxs = train_idxs
        self.val_idxs = val_idxs
        self.test_idxs = test_idxs

        for label, idxs in (
            ("train", self.train_idxs),
            ("validation", self.val_idxs),
            ("test", self.test_idxs),
        ):
            print(
                "The number of (filtered) {} patients: {}".format(
                    label, len(idxs)
                )
            )

        return {
            "train": self.train_idxs,
            "val": self.val_idxs,
            "test": self.test_idxs,
        }