Diff of /dataloader.py [000000] .. [a8f942]

Switch to side-by-side view

--- a
+++ b/dataloader.py
@@ -0,0 +1,87 @@
+import math
+import torch
+import numpy as np
+
+
+class BatchDataloader:
+    """Yield masked mini-batches from aligned, sliceable arrays.
+
+    The index range between the first and last selected entry of ``mask``
+    is cut into consecutive windows of ``bs`` positions. A window with no
+    selected entry is skipped; every other window produces one batch that
+    holds only the selected rows, so a batch may have fewer than ``bs`` rows.
+
+    Args:
+        *tensors: Containers that support ``t[a:b]`` slicing along their
+            first axis. All are sliced with the same window, so each is
+            expected to be at least as long as ``mask``.
+        bs: Number of consecutive positions per window; must be >= 1.
+        mask: 1-D selector with one entry per row; truthy entries are
+            kept. ``None`` (the default) keeps every row of the first
+            tensor.
+
+    Raises:
+        ValueError: If ``bs`` is smaller than 1 or ``mask`` is not 1-D.
+
+    While iterating, ``start`` is the cursor into the index range and
+    ``sum`` counts the rows yielded since the last ``iter()`` call.
+    """
+
+    def __init__(self, *tensors, bs=1, mask=None):
+        if bs < 1:
+            raise ValueError(f"bs must be a positive integer, got {bs!r}")
+        if mask is None:
+            # No mask means "keep everything"; size it from the first tensor.
+            n_rows = len(tensors[0]) if tensors else 0
+            mask = np.ones(n_rows, dtype=bool)
+        else:
+            # Force a boolean array: an integer 0/1 mask would otherwise be
+            # used as positional indices when a batch is filtered.
+            mask = np.asarray(mask, dtype=bool)
+        if mask.ndim != 1:
+            raise ValueError(f"mask must be 1-D, got {mask.ndim} dimension(s)")
+        self.tensors = tensors
+        self.batch_size = bs
+        self.mask = mask
+        nonzero_idx = np.flatnonzero(mask)
+        if nonzero_idx.size > 0:
+            # Indices come back in ascending order: ends are the extremes.
+            self.start_idx = int(nonzero_idx[0])
+            self.end_idx = int(nonzero_idx[-1]) + 1
+        else:
+            self.start_idx = 0
+            self.end_idx = 0
+        # Prime the cursor so next() also works without a prior iter().
+        self.start = self.start_idx
+        self.sum = 0
+
+    def __next__(self):
+        """Return the next non-empty batch as a list of float32 tensors.
+
+        Raises:
+            StopIteration: When the cursor has reached the end of the range.
+        """
+        while self.start < self.end_idx:
+            begin = self.start
+            end = min(begin + self.batch_size, self.end_idx)
+            self.start = end
+            batch_mask = self.mask[begin:end]
+            n_selected = int(np.count_nonzero(batch_mask))
+            if n_selected == 0:
+                continue  # nothing selected in this window; try the next
+            self.sum += n_selected
+            rows = (np.asarray(t[begin:end])[batch_mask] for t in self.tensors)
+            return [torch.tensor(r, dtype=torch.float32) for r in rows]
+        raise StopIteration
+
+    def __iter__(self):
+        """Reset the cursor and the yielded-row counter, then return ``self``."""
+        self.start = self.start_idx
+        self.sum = 0
+        return self
+
+    def __len__(self):
+        """Return how many batches a full pass yields (non-empty windows)."""
+        selected = np.flatnonzero(self.mask[self.start_idx:self.end_idx])
+        # Floor-dividing an offset by the window size gives its window id.
+        return int(np.unique(selected // self.batch_size).size)
\ No newline at end of file