|
# clinical_ts/stratify.py
import numpy as np

def stratify(data, classes, ratios, samples_per_group=None):
    """Stratifying procedure, modified from https://vict0rs.ch/2018/05/24/sample-multilabel-dataset/ (based on Sechidis 2011).

    data: list of lists; one list of labels per sample (possibly containing duplicate entries; not multi-hot encoded)

    classes: list of classes each label can take

    ratios: list, summing to 1, of how the dataset should be split

    samples_per_group: optional list with the number of samples per patient/group
    """
|
|
    np.random.seed(0)  # fix the random seed for reproducibility

    # data is always a list of lists: len(data) is the number of patients and data[i]
    # is the list of all labels for patient i (possibly with multiple identical entries)

    if samples_per_group is None:
        samples_per_group = np.ones(len(data))

    # size is the total number of ECGs
    size = np.sum(samples_per_group)

    # Organize data per label: for each label l, per_label_data[l] contains the list of
    # patients in data which have this label (potentially with multiple identical entries)
    per_label_data = {c: [] for c in classes}
    for i, d in enumerate(data):
        for l in d:
            per_label_data[l].append(i)
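    # For the hypothetical example input above this would yield
    #   per_label_data = {"NORM": [0], "MI": [1, 2], "STTC": [1]}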
|
|
    # In order not to compute lengths each time, they are tracked here.
    subset_sizes = [r * size for r in ratios]  # list of desired subset sizes in terms of ECGs
    per_label_subset_sizes = {c: [r * len(per_label_data[c]) for r in ratios] for c in classes}  # per label: list of desired subset sizes in terms of label occurrences

    # For each subset we want, the set of patient ids that should end up in it
    stratified_data_ids = [set() for _ in range(len(ratios))]  # initialize empty
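    # Continuing the hypothetical example (size = 3, three patients with one sample each):
    #   subset_sizes = [1.5, 0.75, 0.75] and per_label_subset_sizes["MI"] = [1.0, 0.5, 0.5]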
|
|
    # Iteratively distribute all samples across the subsets
    print("Starting fold distribution...")
    size_prev = size + 1  # just for progress output
    while size > 0:
        if int(size_prev / 1000) > int(size / 1000):
            print("Remaining entries to distribute:", size, "non-empty labels:",
                  np.sum([1 for l, label_data in per_label_data.items() if len(label_data) > 0]))
        size_prev = size
        # Compute |Di|
        lengths = {
            l: len(label_data)
            for l, label_data in per_label_data.items()
        }  # per label: number of ECGs with this label that have not been assigned to a fold yet
        try:
            # Find the label with the smallest non-zero |Di| (the rarest label is the
            # hardest to distribute evenly, so it is handled first)
            label = min({k: v for k, v in lengths.items() if v > 0}, key=lengths.get)
        except ValueError:
            # If the dictionary passed to `min` is empty, we get a ValueError.
            # This can happen if there are unlabeled samples: `size` would still be > 0,
            # but only samples without any label would remain.
            # "No label" could be a class in itself: it's up to you to format your data accordingly.
            break
        # For each patient with label `label`, get the patient id and the corresponding count
        unique_samples, unique_counts = np.unique(per_label_data[label], return_counts=True)
        idxs_sorted = np.argsort(unique_counts, kind='stable')[::-1]
        unique_samples = unique_samples[idxs_sorted]  # all patient ids with this label, sorted by count in descending order
        unique_counts = unique_counts[idxs_sorted]  # the corresponding counts

        # loop through all patient ids with this label
        for current_id, current_count in zip(unique_samples, unique_counts):

            subset_sizes_for_label = per_label_subset_sizes[label]  # current desired subset sizes for the chosen label

            # Find argmax c_lj, i.e. the subset(s) in greatest need of the current label
            largest_subsets = np.argwhere(subset_sizes_for_label == np.amax(subset_sizes_for_label)).flatten()
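            # E.g. with hypothetical per-label needs [0.5, 0.5, 0.2], folds 0 and 1 would
            # tie here; the overall needs in subset_sizes break the tie below, randomly if necessary.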
|
|
            # If there is a single best choice, assign the patient to it
            if len(largest_subsets) == 1:
                subset = largest_subsets[0]
            # If there is more than one such subset, pick the one in greatest need of any label
            else:
                largest_subsets2 = np.argwhere(np.array(subset_sizes)[largest_subsets] == np.amax(np.array(subset_sizes)[largest_subsets])).flatten()
                subset = largest_subsets[np.random.choice(largest_subsets2)]

            # Store the patient id in the selected subset
            stratified_data_ids[subset].add(current_id)

            # There are samples_per_group[current_id] fewer samples to distribute
            size -= samples_per_group[current_id]
            # The selected subset needs samples_per_group[current_id] fewer samples
            subset_sizes[subset] -= samples_per_group[current_id]

            # The selected subset now holds one more example of each label the current
            # patient has, so reduce the per-label needs accordingly
            for l in data[current_id]:
                per_label_subset_sizes[l][subset] -= 1

            # Remove the patient from the dataset, i.e. from all per-label lists created above
            for x in per_label_data.keys():
                per_label_data[x] = [y for y in per_label_data[x] if y != current_id]

    # Create the stratified dataset as a list of subsets of sorted patient ids
    stratified_data_ids = [sorted(strat) for strat in stratified_data_ids]
    #stratified_data = [
    #    [data[i] for i in strat] for strat in stratified_data_ids
    #]

    # Return the stratified indexes, to be used to sample the `features` associated with
    # your labels (uncomment above to also return the stratified labels dataset)
    #return stratified_data_ids, stratified_data
    return stratified_data_ids
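

if __name__ == "__main__":
    # Minimal usage sketch: the labels, classes and ratios below are made up
    # purely for illustration.
    data = [["NORM"], ["MI", "STTC"], ["MI"], ["NORM", "NORM"], ["STTC"], ["MI"]]
    classes = ["NORM", "MI", "STTC"]
    ratios = [0.5, 0.25, 0.25]  # e.g. train/val/test

    folds = stratify(data, classes, ratios)
    print("Patient ids per subset:", folds)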