|
a |
|
b/pretrain/sampler_util.py |
|
|
1 |
import torch |
|
|
2 |
from torch.utils.data import Dataset, DataLoader |
|
|
3 |
from torch.utils.data.sampler import Sampler, RandomSampler |
|
|
4 |
|
|
|
5 |
""" |
|
|
6 |
class TmpDataset(Dataset): |
|
|
7 |
def __init__(self, m=10): |
|
|
8 |
self.len = m |
|
|
9 |
|
|
|
10 |
def __getitem__(self, index): |
|
|
11 |
return (list(range(10)) * index, [0] * index) |
|
|
12 |
|
|
|
13 |
def __len__(self): |
|
|
14 |
return self.len |
|
|
15 |
""" |
|
|
16 |
|
|
|
17 |
class FixedLengthBatchSampler(Sampler):
    """Batch sampler that packs indices until a per-batch length budget is hit.

    Indices are drawn from ``sampler`` in order; each sample's cost is
    ``len(sampler.data_source[idx][-1]) * 3`` (three aligned sequences per
    target entry — assumed from the paired collate function; TODO confirm
    against the dataset). A batch is emitted just before adding a sample
    would push the running cost past ``fixed_length``.

    Args:
        sampler: index sampler exposing a ``data_source`` attribute
            (e.g. ``RandomSampler``).
        fixed_length: maximum summed sample cost per batch.
        drop_last: if True, discard the final partial batch.
    """

    def __init__(self, sampler, fixed_length, drop_last):
        self.sampler = sampler
        self.fixed_length = fixed_length
        self.drop_last = drop_last
        # Running count of samples drawn across all iterations (diagnostic;
        # intentionally not reset per epoch).
        self.rel_sampler_count = 0

    def __iter__(self):
        batch = []
        now_length = 0
        for idx in self.sampler:
            # Cost of this sample: 3x the length of its last field.
            sample_length = len(self.sampler.data_source[idx][-1]) * 3
            if now_length + sample_length > self.fixed_length:
                # Bug fix: only yield non-empty batches. Previously, a
                # sample whose cost alone exceeded fixed_length caused an
                # empty list to be yielded, which breaks collation
                # (my_collate_fn indexes batch[0]).
                if batch:
                    yield batch
                batch = []
                now_length = 0
            batch.append(idx)
            now_length += sample_length
            self.rel_sampler_count += 1
        if len(batch) > 0 and not self.drop_last:
            yield batch
|
|
40 |
|
|
|
41 |
def my_collate_fn(batch):
    """Collate tuple samples into a tuple of LongTensors, one per field.

    Each sample is a tuple of equal arity; field values are flattened
    across the batch. ``total_rows`` is the summed length of every
    sample's last field. A flattened field longer than ``total_rows`` is
    reshaped to ``(total_rows, -1)``; otherwise it stays one-dimensional.
    """
    n_fields = len(batch[0])
    total_rows = sum(len(sample[-1]) for sample in batch)
    collated = ()
    for field in range(n_fields):
        flat = []
        for sample in batch:
            flat.extend(sample[field])
        tensor = torch.LongTensor(flat)
        if len(flat) > total_rows:
            # Wider-than-target fields fold into (total_rows, width) rows.
            tensor = tensor.reshape(total_rows, -1)
        collated += (tensor,)
    return collated