--- a +++ b/dataloader.py @@ -0,0 +1,47 @@ +import math +import torch +import numpy as np + + +class BatchDataloader: + def __init__(self, *tensors, bs=1, mask=None): + nonzero_idx, = np.nonzero(mask) + self.tensors = tensors + self.batch_size = bs + self.mask = mask + if nonzero_idx.size > 0: + self.start_idx = min(nonzero_idx) + self.end_idx = max(nonzero_idx)+1 + else: + self.start_idx = 0 + self.end_idx = 0 + + def __next__(self): + if self.start == self.end_idx: + raise StopIteration + end = min(self.start + self.batch_size, self.end_idx) + batch_mask = self.mask[self.start:end] + while sum(batch_mask) == 0: + self.start = end + end = min(self.start + self.batch_size, self.end_idx) + batch_mask = self.mask[self.start:end] + batch = [np.array(t[self.start:end]) for t in self.tensors] + self.start = end + self.sum += sum(batch_mask) + return [torch.tensor(b[batch_mask], dtype=torch.float32) for b in batch] + + def __iter__(self): + self.start = self.start_idx + self.sum = 0 + return self + + def __len__(self): + count = 0 + start = self.start_idx + while start != self.end_idx: + end = min(start + self.batch_size, self.end_idx) + batch_mask = self.mask[start:end] + if sum(batch_mask) != 0: + count += 1 + start = end + return count \ No newline at end of file