Diff of /clinical_ts/stratify.py [000000] .. [134fd7]

Switch to unified view

a b/clinical_ts/stratify.py
1
import numpy as np
2
3
def stratify(data, classes, ratios, samples_per_group=None):
    """Iterative multi-label stratification (based on Sechidis 2011).

    Modified from https://vict0rs.ch/2018/05/24/sample-multilabel-dataset/

    data is a list of lists: a list of labels, for each sample (possibly
    containing duplicates, not multi-hot encoded). len(data) is the number of
    patients/groups; data[i] is the list of all labels for group i.

    classes is the list of classes each label can take.

    ratios is a list, summing to 1, of how the dataset should be split.

    samples_per_group: list with number of samples per patient/group
    (defaults to one sample per group).

    Returns a list (one entry per ratio) of sorted lists of group indices.
    Groups with no label at all may remain unassigned (see the `break` below).
    """
    np.random.seed(0)  # fix the random seed so folds are reproducible

    if samples_per_group is None:
        samples_per_group = np.ones(len(data))

    # size is the number of samples (e.g. ecgs) still left to distribute
    size = np.sum(samples_per_group)

    # Organize data per label: for each label l, per_label_data[l] contains the
    # list of groups in data which have this label, repeated once per
    # occurrence of the label (so duplicates carry multiplicity).
    per_label_data = {c: [] for c in classes}
    for i, d in enumerate(data):
        for l in d:
            per_label_data[l].append(i)

    # In order not to compute lengths each time, they are tracked here.
    subset_sizes = [r * size for r in ratios]  # desired fold sizes in samples
    per_label_subset_sizes = {
        c: [r * len(per_label_data[c]) for r in ratios] for c in classes
    }  # per label: desired number of occurrences per fold

    # For each subset we want, the set of group ids which should end up in it
    stratified_data_ids = [set() for _ in range(len(ratios))]

    print("Starting fold distribution...")
    size_prev = size + 1  # just for progress output
    while size > 0:
        if int(size_prev / 1000) > int(size / 1000):
            print("Remaining entries to distribute:", size,
                  "non-empty labels:",
                  np.sum([1 for l, label_data in per_label_data.items() if len(label_data) > 0]))
        size_prev = size

        # Compute |Di|: remaining (unassigned) occurrences of each label
        lengths = {
            l: len(label_data)
            for l, label_data in per_label_data.items()
        }
        try:
            # Find the non-empty label with the smallest |Di|
            label = min({k: v for k, v in lengths.items() if v > 0}, key=lengths.get)
        except ValueError:
            # If the dictionary in `min` is empty we get a ValueError.
            # This can happen if there are unlabeled samples: `size` is > 0
            # but only samples without any label remain.
            # "No label" could be a class in itself: it's up to you to format
            # your data accordingly.
            break

        # For each group with label `label` get group id and occurrence count,
        # processed in order of descending count (stable sort for ties).
        unique_samples, unique_counts = np.unique(per_label_data[label], return_counts=True)
        idxs_sorted = np.argsort(unique_counts, kind='stable')[::-1]
        unique_samples = unique_samples[idxs_sorted]
        unique_counts = unique_counts[idxs_sorted]

        # loop through all group ids with this label
        for current_id, current_count in zip(unique_samples, unique_counts):

            subset_sizes_for_label = per_label_subset_sizes[label]

            # Find argmax c_lj i.e. subset in greatest need of the current label
            largest_subsets = np.argwhere(
                subset_sizes_for_label == np.amax(subset_sizes_for_label)).flatten()

            # if there is a single best choice: assign it
            if len(largest_subsets) == 1:
                subset = largest_subsets[0]
            # If there is more than one such subset, find the one in greatest
            # need of any label (remaining ties broken at random)
            else:
                largest_subsets2 = np.argwhere(
                    np.array(subset_sizes)[largest_subsets]
                    == np.amax(np.array(subset_sizes)[largest_subsets])).flatten()
                subset = largest_subsets[np.random.choice(largest_subsets2)]

            # Store the group's id in the selected subset
            stratified_data_ids[subset].add(current_id)

            # There are samples_per_group[current_id] fewer samples to distribute
            size -= samples_per_group[current_id]
            # The selected subset needs that many fewer samples
            subset_sizes[subset] -= samples_per_group[current_id]

            # In the selected subset, there is one more occurrence of each
            # label the current group has
            for l in data[current_id]:
                per_label_subset_sizes[l][subset] -= 1

        # Remove every group assigned this round from all per-label lists.
        # Batched here (one filtering pass per round) instead of once per
        # assigned group as before: nothing inside the loop above reads
        # per_label_data, so the result is identical while avoiding a full
        # rebuild of every label list per assigned group.
        assigned = set(unique_samples.tolist())
        for x in per_label_data.keys():
            per_label_data[x] = [y for y in per_label_data[x] if y not in assigned]

    # Return the stratified indexes, sorted per fold, to be used to sample the
    # `features` associated with your labels
    stratified_data_ids = [sorted(strat) for strat in stratified_data_ids]
    return stratified_data_ids