import torch
import numpy as np

class BatchDataloader:
    """Iterate over batches of the given tensors, skipping entries where ``mask`` is False."""

    def __init__(self, *tensors, bs=1, mask=None):
        # Default to keeping every entry when no mask is given.
        if mask is None:
            mask = np.ones(len(tensors[0]), dtype=bool)
        mask = np.asarray(mask, dtype=bool)
        nonzero_idx, = np.nonzero(mask)
        self.tensors = tensors
        self.batch_size = bs
        self.mask = mask
        # Restrict iteration to the index range that contains unmasked entries.
        if nonzero_idx.size > 0:
            self.start_idx = min(nonzero_idx)
            self.end_idx = max(nonzero_idx) + 1
        else:
            self.start_idx = 0
            self.end_idx = 0
    def __next__(self):
        if self.start == self.end_idx:
            raise StopIteration
        end = min(self.start + self.batch_size, self.end_idx)
        batch_mask = self.mask[self.start:end]
        # Skip batches that are entirely masked out.
        while sum(batch_mask) == 0:
            self.start = end
            end = min(self.start + self.batch_size, self.end_idx)
            batch_mask = self.mask[self.start:end]
        batch = [np.array(t[self.start:end]) for t in self.tensors]
        self.start = end
        self.sum += sum(batch_mask)
        # Keep only the unmasked rows of each tensor in the batch.
        return [torch.tensor(b[batch_mask], dtype=torch.float32) for b in batch]
    def __iter__(self):
        # Reset the cursor and the running count of yielded samples.
        self.start = self.start_idx
        self.sum = 0
        return self
    def __len__(self):
        # Count the non-empty batches between start_idx and end_idx.
        count = 0
        start = self.start_idx
        while start != self.end_idx:
            end = min(start + self.batch_size, self.end_idx)
            batch_mask = self.mask[start:end]
            if sum(batch_mask) != 0:
                count += 1
            start = end
        return count
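

# Usage sketch (not part of the original module): a minimal example showing how
# BatchDataloader is expected to be driven. The tensor shapes, the names x/y,
# and the particular mask below are illustrative assumptions, not values taken
# from the surrounding project.
if __name__ == "__main__":
    # Toy data: 10 samples with 4 features plus targets; only the last 6
    # samples are selected by the mask.
    x = np.random.randn(10, 4).astype(np.float32)
    y = np.random.randn(10, 1).astype(np.float32)
    mask = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=bool)

    loader = BatchDataloader(x, y, bs=4, mask=mask)
    print(f"number of non-empty batches: {len(loader)}")
    for xb, yb in loader:
        # Each element is a float32 torch tensor containing only unmasked rows.
        print(xb.shape, yb.shape)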