[8956d4]: / unimol / data / coord_pad_dataset.py

Download this file

83 lines (67 with data), 2.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from unicore.data import BaseWrapperDataset
def collate_tokens_coords(
    values,
    pad_idx,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Pad a list of coordinate tensors into one ``(batch, length, 3)`` tensor.

    Args:
        values: Non-empty sequence of tensors. Each is expected to hold
            ``n_i`` rows of 3 values (shape ``(n_i, 3)``); the element-count
            check below rejects anything else.
        pad_idx: Fill value written to every padded position.
        left_pad: If true, each sample is placed at the end of its row
            (padding in front); otherwise at the start (padding behind).
        pad_to_length: Optional minimum for the padded length. The result is
            never shorter than the longest sample.
        pad_to_multiple: Accepted but currently unused -- the padded length
            is NOT rounded up to a multiple of this value.

    Returns:
        A new tensor with the dtype and device of ``values[0]`` and shape
        ``(len(values), length, 3)``.

    Raises:
        ValueError: If ``values`` is empty (raised by ``max``).
        AssertionError: If a sample does not contain exactly ``3 * n_i``
            elements.
    """
    size = max(v.size(0) for v in values)
    if pad_to_length is not None:
        size = max(size, pad_to_length)

    # Allocate like values[0] (same dtype/device), then overwrite everything
    # with the padding value before copying the real rows in.
    res = values[0].new_empty((len(values), size, 3)).fill_(pad_idx)

    for i, v in enumerate(values):
        n = len(v)
        dst = res[i, size - n:, :] if left_pad else res[i, :n, :]
        # Without this check copy_ would silently broadcast e.g. (n, 1) inputs.
        assert dst.numel() == v.numel(), "each sample must have shape (n, 3)"
        dst.copy_(v)
    return res
class RightPadDatasetCoord(BaseWrapperDataset):
    """Wrapper dataset whose collater pads samples via ``collate_tokens_coords``.

    The padding side is chosen by ``left_pad``; with the default ``False``
    the padding is appended after each sample.
    """

    def __init__(self, dataset, pad_idx, left_pad=False):
        """Store the padding configuration and wrap ``dataset``.

        Args:
            dataset: Underlying dataset handed to ``BaseWrapperDataset``.
            pad_idx: Fill value used for padded positions.
            left_pad: Whether padding goes in front of each sample.
        """
        super().__init__(dataset)
        self.left_pad = left_pad
        self.pad_idx = pad_idx

    def collater(self, samples):
        """Collate ``samples`` into a single padded batch tensor."""
        padding_options = {"left_pad": self.left_pad, "pad_to_multiple": 8}
        return collate_tokens_coords(samples, self.pad_idx, **padding_options)
def collate_cross_2d(
    values,
    pad_idx,
    left_pad=False,
    pad_to_length=None,
    pad_to_multiple=1,
):
    """Pad a list of 2-D tensors into one ``(batch, height, width)`` tensor.

    Args:
        values: Non-empty sequence of 2-D tensors; heights and widths may
            differ between samples.
        pad_idx: Fill value written to every padded position.
        left_pad: If true, each sample is placed in the bottom-right corner
            of its slot (padding before it on both axes); otherwise in the
            top-left corner.
        pad_to_length: Optional minimum size applied to both padded axes.
            ``None`` (the default) leaves the sizes at the per-axis maxima.
        pad_to_multiple: If not 1, each padded axis is rounded up to the next
            multiple of this value.

    Returns:
        A new tensor with the dtype and device of ``values[0]`` and shape
        ``(len(values), height, width)``.

    Raises:
        ValueError: If ``values`` is empty (raised by ``max``).
        AssertionError: If a sample has more elements than its 2-D slot,
            i.e. it is not 2-D.
    """

    def _round_up(size):
        # Exact integer rounding to the next multiple (no float arithmetic).
        if pad_to_multiple != 1 and size % pad_to_multiple != 0:
            return (size // pad_to_multiple + 1) * pad_to_multiple
        return size

    size_h = max(v.size(0) for v in values)
    size_w = max(v.size(1) for v in values)
    if pad_to_length is not None:
        size_h = max(size_h, pad_to_length)
        size_w = max(size_w, pad_to_length)
    size_h = _round_up(size_h)
    size_w = _round_up(size_w)

    # Allocate like values[0] (same dtype/device), pre-filled with padding.
    res = values[0].new_empty((len(values), size_h, size_w)).fill_(pad_idx)

    for i, v in enumerate(values):
        h, w = v.size(0), v.size(1)
        if left_pad:
            dst = res[i, size_h - h:, size_w - w:]
        else:
            dst = res[i, :h, :w]
        assert dst.numel() == v.numel(), "each sample must be a 2-D tensor"
        dst.copy_(v)
    return res
class RightPadDatasetCross2D(BaseWrapperDataset):
    """Wrapper dataset whose collater pads 2-D samples via ``collate_cross_2d``.

    The padding side is chosen by ``left_pad``; with the default ``False``
    the padding is appended after each sample on both axes.
    """

    def __init__(self, dataset, pad_idx, left_pad=False):
        """Store the padding configuration and wrap ``dataset``.

        Args:
            dataset: Underlying dataset handed to ``BaseWrapperDataset``.
            pad_idx: Fill value used for padded positions.
            left_pad: Whether padding goes in front of each sample.
        """
        super().__init__(dataset)
        self.left_pad = left_pad
        self.pad_idx = pad_idx

    def collater(self, samples):
        """Collate ``samples`` into a single batch padded to a multiple of 8."""
        padding_options = {"left_pad": self.left_pad, "pad_to_multiple": 8}
        return collate_cross_2d(samples, self.pad_idx, **padding_options)