# Retrieval/utils/masking.py
1
import torch
2
3
4
class TriangularCausalMask():
    """Boolean causal mask of shape ``[B, 1, L, L]``.

    Entry ``[b, 0, i, j]`` is True exactly when ``j > i`` (strict upper
    triangle, ``diagonal=1``). Presumably callers treat True as "masked out",
    so query ``i`` cannot attend to later keys -- confirm against the consumer.

    Args:
        B: Batch size.
        L: Sequence length (both query and key axes).
        device: Device on which the mask is allocated.
    """

    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            # Allocate directly on the target device instead of building on
            # the CPU and copying over.
            self._mask = torch.triu(
                torch.ones(mask_shape, dtype=torch.bool, device=device),
                diagonal=1,
            )

    @property
    def mask(self):
        """Return the ``[B, 1, L, L]`` boolean mask."""
        return self._mask
13
14
15
class ProbMask():
    """Causal mask restricted to a selected subset of query rows.

    Builds the strict upper-triangular ``[L, L_K]`` mask (True where key
    index > query index). For every (batch, head) pair it then gathers the
    rows named by ``index`` and reshapes the result to ``scores.shape``.

    Args:
        B: Batch size.
        H: Number of heads.
        L: Number of query positions in the full causal mask.
        index: Integer tensor of selected query positions. It must broadcast
            against ``(B, H, 1)``; presumably shaped ``(B, H, u)`` -- confirm
            with caller.
        scores: Tensor whose last dim is the key length ``L_K``. Its shape is
            the output shape, so it must hold as many elements as the gathered
            mask (presumably ``(B, H, u, L_K)``).
        device: Device on which the base mask is allocated.
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        key_len = scores.shape[-1]
        # Strict upper triangle: True where key position j > query position i.
        causal = torch.ones(L, key_len, dtype=torch.bool, device=device).triu(1)
        # Broadcast view (no copy) over batch and head axes.
        causal_ex = causal[None, None, :].expand(B, H, L, key_len)
        # Batch/head selectors shaped to broadcast with `index` in advanced
        # indexing; create them alongside `index` to avoid cross-device indices.
        batch_sel = torch.arange(B, device=index.device)[:, None, None]
        head_sel = torch.arange(H, device=index.device)[None, :, None]
        # Advanced indexing returns a tensor on `causal_ex`'s device, so no
        # extra `.to(device)` is needed.
        indicator = causal_ex[batch_sel, head_sel, index, :]
        self._mask = indicator.view(scores.shape)

    @property
    def mask(self):
        """Return the gathered boolean mask, shaped like ``scores``."""
        return self._mask