Diff of /clinical_ts/stratify.py [000000] .. [134fd7]

--- a
+++ b/clinical_ts/stratify.py
@@ -0,0 +1,109 @@
+import numpy as np
+
+def stratify(data, classes, ratios, samples_per_group=None):
+    """Stratifying procedure. Modified from https://vict0rs.ch/2018/05/24/sample-multilabel-dataset/ (based on Sechidis 2011)
+
+    data is a list of lists: a list of labels, for each sample (possibly containing duplicates not multi-hot encoded).
+    
+    classes is the list of classes each label can take
+
+    ratios is a list, summing to 1, of how the dataset should be split
+
+    samples_per_group: list with number of samples per patient/group
+
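+    Example (illustrative; the label names are made up):
+
+        folds = stratify([["NORM"], ["MI"], ["NORM", "MI"]],
+                         classes=["NORM", "MI"], ratios=[0.5, 0.5])
+        # folds: a list of two sorted lists of patient ids, one per ratio
+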
+    """
+    np.random.seed(0) # fix the random seed
+
+    # data is a list of lists: len(data) is the number of patients and data[i] is the list of all labels for patient i (possibly with repeated entries)
+
+    if samples_per_group is None:
+        samples_per_group = np.ones(len(data))
+
+    # size is the total number of samples (ECGs)
+    size = np.sum(samples_per_group)
+
+    # Organize data per label: for each label l, per_label_data[l] contains the list of patients
+    # in data which have this label (potentially multiple identical entries)
+    per_label_data = {c: [] for c in classes}
+    for i, d in enumerate(data):
+        for l in d:
+            per_label_data[l].append(i)
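+    # e.g. data = [["NORM"], ["NORM", "MI"]] (illustrative label names) yields
+    # per_label_data = {"NORM": [0, 1], "MI": [1]}; a patient id is repeated
+    # once per occurrence of the label in its list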
+
+    # In order not to compute lengths each time, they are tracked here.
+    subset_sizes = [r * size for r in ratios]  # target subset sizes in terms of samples (ECGs)
+    per_label_subset_sizes = {c: [r * len(per_label_data[c]) for r in ratios] for c in classes}  # label -> target subset sizes in terms of label occurrences
+
+    # For each subset we want, the set of sample-ids which should end up in it
+    stratified_data_ids = [set() for _ in range(len(ratios))] #initialize empty
+
+    # For each sample in the data set
+    print("Starting fold distribution...")
+    size_prev = size + 1  # only used for progress output
+    while size > 0:
+        if int(size_prev / 1000) > int(size / 1000):  # report progress every 1000 entries
+            print("Remaining entries to distribute:", size, "non-empty labels:",
+                  sum(1 for label_data in per_label_data.values() if len(label_data) > 0))
+        size_prev = size
+        # Compute |Di| 
+        lengths = {
+            l: len(label_data)
+            for l, label_data in per_label_data.items()
+        }  # label -> number of occurrences of this label that have not been assigned to a fold yet
+        try:
+            # Find label of smallest |Di|
+            label = min({k: v for k, v in lengths.items() if v > 0}, key=lengths.get)
+        except ValueError:
+            # If the dictionary passed to `min` is empty we get a ValueError.
+            # This can happen if there are unlabeled samples.
+            # In this case, `size` would be > 0 but only samples without label would remain.
+            # "No label" could be a class in itself: it's up to you to format your data accordingly.
+            break
+        # Get the ids of all patients with label `label` and the corresponding occurrence counts
+        unique_samples, unique_counts = np.unique(per_label_data[label], return_counts=True)
+        idxs_sorted = np.argsort(unique_counts, kind='stable')[::-1]
+        unique_samples = unique_samples[idxs_sorted]  # patient ids with this label, sorted by count, descending
+        unique_counts = unique_counts[idxs_sorted]  # the corresponding counts
+        
+        # loop through all patient ids with this label
+        for current_id, current_count in zip(unique_samples, unique_counts):
+            
+            subset_sizes_for_label = per_label_subset_sizes[label] #current subset sizes for the chosen label
+
+            # Find argmax clj i.e. subset in greatest need of the current label
+            largest_subsets = np.argwhere(subset_sizes_for_label == np.amax(subset_sizes_for_label)).flatten()
+            
+            # if there is a single best choice: assign it
+            if len(largest_subsets) == 1:
+                subset = largest_subsets[0]
+            # If there is more than one such subset, take the one with the largest
+            # remaining overall size; remaining ties are broken at random
+            else:
+                largest_subsets2 = np.argwhere(np.array(subset_sizes)[largest_subsets] == np.amax(np.array(subset_sizes)[largest_subsets])).flatten()
+                subset = largest_subsets[np.random.choice(largest_subsets2)]
+
+            # Store the sample's id in the selected subset
+            stratified_data_ids[subset].add(current_id)
+
+            # There are samples_per_group[current_id] fewer samples to distribute
+            size -= samples_per_group[current_id]
+            # The selected subset needs samples_per_group[current_id] fewer samples
+            subset_sizes[subset] -= samples_per_group[current_id]
+
+            # The selected subset now needs one fewer occurrence of each label
+            # the current patient has
+            for l in data[current_id]:
+                per_label_subset_sizes[l][subset] -= 1
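+            # (duplicates count: e.g. an illustrative patient labeled ["MI", "MI", "NORM"]
+            # decrements per_label_subset_sizes["MI"][subset] twice)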
+
+            # Remove the patient from the dataset, i.e. from all the per-label lists created above
+            for x in per_label_data.keys():
+                per_label_data[x] = [y for y in per_label_data[x] if y != current_id]
+
+    # The stratified dataset: one sorted list of patient ids per subset
+    stratified_data_ids = [sorted(strat) for strat in stratified_data_ids]
+    # Optionally also materialize the original labels per subset:
+    # stratified_data = [
+    #     [data[i] for i in strat] for strat in stratified_data_ids
+    # ]
+
+    # Return the stratified indexes, to be used to sample the `features` associated
+    # with the labels (uncomment above/below to also return the stratified labels)
+
+    # return stratified_data_ids, stratified_data
+    return stratified_data_ids
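+
+
+# Minimal usage sketch (illustrative only: the labels, classes, and counts below are made up)
+if __name__ == "__main__":
+    labels = [["NORM"], ["MI"], ["NORM", "MI"], ["MI"]]  # label lists per patient
+    n_samples = [2, 1, 3, 1]                             # e.g. number of ECGs per patient
+    folds = stratify(labels, classes=["NORM", "MI"], ratios=[0.5, 0.25, 0.25],
+                     samples_per_group=n_samples)
+    print(folds)  # three sorted lists of patient ids, one per ratio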