Diff of /dl/utils/sampler.py [000000] .. [4807fa]

--- a
+++ b/dl/utils/sampler.py
@@ -0,0 +1,216 @@
+import collections
+
+import numpy as np
+
+try:
+    import torch
+except ImportError:
+    torch = None  # torch is optional; only needed when the data are torch.Tensor
+
+__all__ = ['BatchSequentialSampler', 'RepeatedBatchSampler', 'balanced_sampler', 'BatchLoader']
+
+if torch is not None and torch.cuda.is_available():  # pylint: disable=no-member
+    dtype = {'float': torch.cuda.FloatTensor, 'long': torch.cuda.LongTensor,
+             'byte': torch.cuda.ByteTensor}
+elif torch is not None:
+    dtype = {'float': torch.FloatTensor, 'long': torch.LongTensor,
+             'byte': torch.ByteTensor}
+
+
+class BatchSequentialSampler(object):
+    """return a list of batches (same implementation with torch.utils.data.sampler.BatchSampler)
+    Args:
+        sampler: an iterator, eg: range(100)
+        batch_size: int
+        drop_last: bool
+    Return:
+        an iterator, each iter returns a batch of batch_size from sampler 
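+    Example (illustrative; drop_last=False keeps the short final batch):
+        >>> list(BatchSequentialSampler(range(5), batch_size=2))
+        [[0, 1], [2, 3], [4]]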
+    """
+    def __init__(self, sampler, batch_size=1, drop_last=False):
+        self.sampler = sampler    
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+    
+    def __iter__(self):
+        batch = []
+        for i in self.sampler:
+            batch.append(i)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+        if len(batch) > 0 and not self.drop_last:
+            yield batch
+    
+    def __len__(self):
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        else:
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
+
+        
+class RepeatedBatchSampler(object):
+    """Generate num_iter of batches with batch_size
+    Args:
+        sampler: an iterator, that will be converted to list
+        batch_size: int
+        num_iter: int, default: None
+        shuffle: bool, default: True
+        allow_duplicate: bool, default: True
+    Return:
+        an iterator of length num_iter
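+    Example (a sketch; shuffle=False keeps the order deterministic):
+        >>> list(RepeatedBatchSampler([0, 1, 2], batch_size=2, num_iter=3, shuffle=False))
+        [[0, 1], [2, 0], [1, 2]]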
+    """
+    def __init__(self, sampler, batch_size=1, num_iter=None, shuffle=True, allow_duplicate=True, seed=None):
+        assert len(sampler) > 0
+        self.sampler = sampler  # kept only for __next__()
+        if len(sampler) < batch_size and not allow_duplicate:
+            batch_size = len(sampler)
+        assert batch_size > 0
+        self.batch_size = batch_size
+        if num_iter is None:
+            num_iter = (len(sampler) + batch_size - 1) // batch_size
+        assert num_iter > 0
+        self.num_iter = num_iter
+        num_repeats = (num_iter * batch_size + len(sampler) - 1) // len(sampler)
+        self.sampler_ext = []
+        if seed is not None:  # `if seed:` would be buggy (it skips seed=0)
+            np.random.seed(seed)
+        for _ in range(num_repeats):
+            if shuffle:
+                idx = np.random.permutation(len(sampler))
+            else:
+                idx = range(len(sampler))
+            self.sampler_ext += [sampler[i] for i in idx]
+            
+    def __iter__(self):
+        cnt = 0
+        for i in range(self.num_iter):
+            yield self.sampler_ext[cnt:(cnt + self.batch_size)]
+            cnt += self.batch_size
+            
+    def __len__(self):
+        return self.num_iter
+    
+    def __next__(self):
+        # Draw an independent random batch on each call (never raises StopIteration)
+        indices = np.random.permutation(len(self.sampler))[:self.batch_size]
+        sampler = list(self.sampler)
+        return [sampler[i] for i in indices]
+
+def balanced_sampler(y, batch_size=10, num_iter=None, allow_duplicate=False, 
+                     max_redundancy=3, shuffle=True, seed=None):
+    """Given class labels y, return a balanced batch sampler, i.e., 
+       each class appears the same number of times in each batch
+    Args:
+        y: list, tuple, 1-d numpy array, or 1-d torch.Tensor of class labels
+        batch_size: int; how many instances of each class should be included in a batch. 
+                    Thus the real batch size = batch_size * num_classes in most cases
+        num_iter: number of batches. If None, calculate from y, batch_size, etc.
+        allow_duplicate: if False and batch_size exceeds the smallest class size,
+                         batch_size is reduced to that smallest class size
+        max_redundancy: default 3; when num_iter is None, the computed num_iter would exceed
+                a 'traditional' epoch by a factor of num_classes; max_redundancy caps
+                this factor
+        shuffle: default True; shuffle instances within each class and within each batch
+        seed: if not None, np.random.seed(seed) is called (useful for unit tests)
+    Return:
+        a numpy array of shape (num_iter, real_batch_size)
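+    Example (a sketch; only shapes are shown, to avoid pinning down the RNG stream):
+        >>> idx = balanced_sampler([0, 0, 1, 1, 1, 1], batch_size=2, seed=0)
+        >>> idx.shape
+        (3, 4)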
+    """
+    z = collections.defaultdict(list)
+    for i, e in enumerate(y):
+        # Normalize keys to plain Python scalars: torch.Tensor hashes by identity,
+        # so equal-valued tensor elements would otherwise land in separate bins
+        key = e.item() if hasattr(e, 'item') else e
+        z[key].append(i)
+    least_size = min(len(v) for v in z.values())
+    if least_size < batch_size and not allow_duplicate:
+        batch_size = least_size
+    if num_iter is None:
+        num_iter = (len(y) + batch_size - 1) // batch_size
+        if len(z) > max_redundancy:
+            num_iter = (num_iter * max_redundancy + len(z) - 1) // len(z)
+    
+    bs = [RepeatedBatchSampler(v, batch_size=batch_size, num_iter=num_iter, shuffle=shuffle,
+                               allow_duplicate=allow_duplicate, seed=seed)
+          for k, v in z.items()]
+    bs = [list(s) for s in bs]  # materialize each per-class sampler into its batches
+    indices = np.array(bs).transpose(1, 0, 2).reshape(num_iter, -1)
+    # In each batch, shuffle instances so that instances of the same class won't cluster together
+    # may not be necessary
+    if shuffle:
+        if seed is not None:
+            np.random.seed(seed)
+        for v in indices:
+            np.random.shuffle(v)
+    return indices
+
+
+class BatchLoader(object):
+    """Return an iterator of data batches
+    Args:
+        data: a single or a list/tuple of np.array/torch.Tensor
+        labels: class labels, e.g., a list of int, used for balanced_sampler
+        batch_size: int
+        balanced: if true, used balanced_sampler, else use BatchSequentialSampler
+        The rest of parameters are to be passed to balanced_sampler
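+    Example (a sketch; x is a toy feature array, y its labels):
+        >>> x = np.arange(12).reshape(6, 2)
+        >>> y = [0, 0, 1, 1, 1, 1]
+        >>> loader = BatchLoader(x, batch_size=2, labels=y, seed=0)
+        >>> len(loader)
+        3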
+    """
+    def __init__(self, data, batch_size=10, labels=None, balanced=True, num_iter=None, 
+                 allow_duplicate=False, max_redundancy=3, shuffle=True, seed=None):
+        assert labels is not None or (isinstance(data, (tuple, list)) and len(data) > 1), \
+            'labels must be given, or data must be a list/tuple with labels as its last element'
+        if labels is None:
+            labels = data[-1]
+        if not isinstance(data, (tuple, list)):
+            data = [data]
+        assert len(data) > 0 and len(data[0]) == len(labels)
+        self.data = data
+        N = len(labels)
+        if balanced:
+            self.indices = balanced_sampler(labels, batch_size=batch_size, num_iter=num_iter, 
+                                            allow_duplicate=allow_duplicate, max_redundancy=max_redundancy,
+                                            shuffle=shuffle, seed=seed)
+        else:
+            idx = range(N)
+            if shuffle:
+                if seed is not None:  # seed here too, so the initial permutation is reproducible
+                    np.random.seed(seed)
+                idx = np.random.permutation(N).tolist()
+            self.indices = RepeatedBatchSampler(idx, batch_size=batch_size,
+                num_iter=num_iter, shuffle=shuffle, allow_duplicate=allow_duplicate, seed=seed)
+    
+    def __iter__(self):
+        for idx in self.indices:
+            batch = []
+            for data in self.data:
+                if torch is not None and isinstance(data, torch.Tensor):
+                    idx = torch.LongTensor(idx)
+                batch.append(data[idx])
+            yield batch
+                
+    def __len__(self):
+        return len(self.indices)
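+
+
+if __name__ == '__main__':
+    # Smoke test (illustrative only): one balanced pass over toy data;
+    # each batch holds 2 instances of each of the 2 classes, so 4 rows per batch.
+    x = np.arange(12).reshape(6, 2)
+    y = np.array([0, 0, 1, 1, 1, 1])
+    for xb, yb in BatchLoader((x, y), batch_size=2, seed=0):
+        print(xb.shape, yb)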