Diff of /dataLoader/utils.py [000000] .. [bad60c]

Switch to unified view

a b/dataLoader/utils.py
1
import random
2
3
4
def code2index(tokens, token2idx, mask_token=None):
    """Convert a sequence of tokens into vocabulary indices.

    Args:
        tokens: Iterable of tokens to convert.
        token2idx: Mapping from token to integer index. It must contain an
            'UNK' entry, which is looked up for every token processed.
        mask_token: Any token equal to this value is mapped to the 'UNK'
            index, even if it is present in ``token2idx``.

    Returns:
        A tuple ``(tokens, output_tokens)``: the input object unchanged and
        the list of indices, one per token.

    Raises:
        KeyError: If ``tokens`` is non-empty and ``token2idx`` has no 'UNK'.
    """
    output_tokens = []
    for token in tokens:
        if token == mask_token:
            # The designated token is deliberately hidden behind UNK.
            output_tokens.append(token2idx['UNK'])
        else:
            # Out-of-vocabulary tokens fall back to UNK.
            output_tokens.append(token2idx.get(token, token2idx['UNK']))
    return tokens, output_tokens
12
13
14
def random_mask(tokens, token2idx):
    """Randomly corrupt a token sequence for masked-token prediction.

    Each token is selected with 15% probability. A selected token is
    replaced by the 'MASK' index 80% of the time, by a random index drawn
    from ``token2idx.values()`` 10% of the time, and kept as its own index
    the remaining 10% of the time.

    Args:
        tokens: Iterable of tokens.
        token2idx: Mapping from token to integer index. Must contain 'UNK';
            'MASK' is required whenever a token is chosen for masking.

    Returns:
        A tuple ``(tokens, output_token, output_label)``:
        ``tokens`` is the input unchanged; ``output_token`` holds the
        (possibly corrupted) index for every position; ``output_label``
        holds the true index at selected positions and -1 elsewhere
        (the -1 entries are meant to be ignored by the loss function).
        Both output lists always have one entry per input token.
    """
    output_label = []
    output_token = []
    # The replacement pool does not change during the loop, so build it once.
    candidates = list(token2idx.values())

    for token in tokens:
        token_idx = token2idx.get(token, token2idx['UNK'])
        prob = random.random()

        # Select the token with 15% probability.
        if prob < 0.15:
            # Rescale to [0, 1) to split the selected tokens 80/10/10.
            prob /= 0.15

            if prob < 0.8:
                # 80%: replace with the mask token.
                output_token.append(token2idx["MASK"])
            elif prob < 0.9:
                # 10%: replace with a random vocabulary index.
                output_token.append(random.choice(candidates))
            else:
                # 10%: keep the current token. Without this append the
                # output would be one element short and every following
                # position would be misaligned with its label.
                output_token.append(token_idx)

            # Selected positions carry the true index as prediction target.
            output_label.append(token_idx)
        else:
            # Not selected: no target, token passed through unchanged.
            output_label.append(-1)
            output_token.append(token_idx)

    return tokens, output_token, output_label
41
42
43
def index_seg(tokens, symbol='SEP'):
    """Assign an alternating 0/1 segment id to every token.

    The id starts at 0 and flips after each occurrence of ``symbol``; the
    separator itself receives the id of the segment it closes.

    Args:
        tokens: Iterable of tokens.
        symbol: Separator token that ends a segment.

    Returns:
        A list of 0/1 segment ids, one per token.
    """
    seg = []
    flag = 0
    for token in tokens:
        seg.append(flag)
        if token == symbol:
            # Toggle only after recording the separator's own segment id.
            flag = 1 - flag
    return seg
57
58
59
def position_idx(tokens, symbol='SEP'):
    """Assign a segment-level position index to every token.

    The index starts at 0 and increases by one after each occurrence of
    ``symbol``; the separator itself receives the index of the segment it
    closes.

    Args:
        tokens: Iterable of tokens.
        symbol: Separator token that ends a segment.

    Returns:
        A list of non-decreasing integer positions, one per token.
    """
    pos = []
    flag = 0
    for token in tokens:
        pos.append(flag)
        if token == symbol:
            # Advance only after recording the separator's own position.
            flag += 1
    return pos
70
71
72
def seq_padding(tokens, max_len, token2idx=None, symbol=None, unkown=True):
    """Truncate or pad a token sequence to exactly ``max_len`` entries.

    Args:
        tokens: Indexable, sized sequence of tokens.
        max_len: Length of the returned list. Longer inputs are truncated;
            shorter inputs are padded at the end.
        token2idx: Optional mapping from token to index. When given, the
            result contains indices instead of raw tokens.
        symbol: Padding token; defaults to 'PAD' when None.
        unkown: Only used with ``token2idx``. When true, tokens missing from
            the mapping become ``token2idx['UNK']``; when false they become
            None (the result of a plain ``dict.get``).

    Returns:
        A list of length ``max_len`` (empty if ``max_len`` is not positive).
        Without ``token2idx`` it holds tokens padded with ``symbol``; with
        ``token2idx`` it holds indices padded with ``token2idx.get(symbol)``,
        which is None if the padding symbol is not in the mapping.
    """
    if symbol is None:
        symbol = 'PAD'

    n_kept = min(len(tokens), max_len)
    n_pad = max_len - n_kept
    # Index explicitly rather than slice so any indexable sequence works.
    kept = [tokens[i] for i in range(n_kept)]

    if token2idx is None:
        return kept + [symbol] * n_pad

    if unkown:
        seq = [token2idx.get(tok, token2idx['UNK']) for tok in kept]
    else:
        seq = [token2idx.get(tok) for tok in kept]
    seq.extend(token2idx.get(symbol) for _ in range(n_pad))
    return seq