datasets.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File         :datasets.py
@Description  :Datasets for NLP_query
@Time         :2023/01/17 15:21:45
@Author       :KangQing
@Version      :1.0
'''

import sys

sys.path.append('./')

import json
import logging
import os

import jieba
import numpy as np
import torch
from torch.utils.data import Dataset

# Silence jieba's verbose initialization logging.
jieba.setLogLevel(logging.INFO)

class InputExample:
    '''A single query pair with an optional label.'''

    def __init__(self, id: str, text_a: str, text_b: str = None, label: str = None) -> None:
        self.id = id
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

    def __str__(self) -> str:
        return json.dumps(
            {'id': self.id, 'text_a': self.text_a, 'text_b': self.text_b, 'label': self.label},
            indent=2,
            ensure_ascii=False,
        ) + '\n'
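
# Usage sketch (illustrative, not part of the original module): the id, query
# strings, and label below are made-up values.
if __name__ == '__main__':
    demo_example = InputExample('s1', '感冒的症状', '感冒吃什么药', '1')
    print(demo_example)  # pretty-printed JSON; ensure_ascii=False keeps the Chinese readable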

class QQR_data:
    '''Reader for the KUAKE-QQR query-pair JSON files.'''

    def __init__(self, data_path='data') -> None:
        self.data_path = data_path

    def get_data(self, json_data_path):
        # json.load takes no encoding argument (it was removed in Python 3.9);
        # the file itself is already opened as UTF-8.
        with open(json_data_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        examples = []
        for example in data:
            # Test-set entries carry an empty-string label; store None instead.
            label = example['label'] if example['label'] != "" else None
            examples.append(InputExample(example['id'], example['query1'], example['query2'], label))
        return examples

    def get_labels(self):
        return ['0', '1', '2']

    def get_train_data(self):
        path = os.path.join(self.data_path, 'KUAKE-QQR_{}.json'.format('train'))
        return self.get_data(path)

    def get_dev_data(self):
        path = os.path.join(self.data_path, 'KUAKE-QQR_{}.json'.format('dev'))
        return self.get_data(path)

    def get_test_data(self):
        path = os.path.join(self.data_path, 'KUAKE-QQR_{}.json'.format('test'))
        return self.get_data(path)
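
# Usage sketch (illustrative): this assumes the KUAKE-QQR JSON files sit under
# ./data, which is only the class default, not something this module ships.
if __name__ == '__main__':
    reader = QQR_data(data_path='data')
    train_path = os.path.join('data', 'KUAKE-QQR_train.json')
    if os.path.exists(train_path):
        train_examples = reader.get_train_data()
        print(len(train_examples), train_examples[0])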
class QQRDataset(Dataset):
    '''Word2vec-style dataset: jieba-tokenizes each query and maps tokens to ids.'''

    def __init__(self, examples_list, labels_list, w2v_map, max_length):
        self.examples_list = examples_list
        self.label2id = {label: idx for idx, label in enumerate(labels_list)}
        self.id2label = {idx: label for idx, label in enumerate(labels_list)}
        self.w2v_map = w2v_map
        self.max_length = max_length

    def __len__(self):
        return len(self.examples_list)

    def _tokenize(self, text):
        token_list = list(jieba.cut(text))
        token_ids = []
        for token in token_list:
            if token in self.w2v_map:
                token_ids.append(self.w2v_map.get(token))
            elif len(token) > 1:
                # OOV multi-character word: back off to its characters, using a
                # random id for any character that is OOV as well.
                for character in token:
                    character_id = self.w2v_map.get(character)
                    if character_id is None:
                        character_id = np.random.choice(len(self.w2v_map), 1).item()
                    token_ids.append(character_id)
            else:
                # OOV single character: use a random id.
                token_ids.append(np.random.choice(len(self.w2v_map), 1).item())

        token_ids, attention_mask = self._pad_and_cut(token_ids)
        return token_ids, attention_mask

    def _pad_and_cut(self, token_ids):
        # Truncate to max_length or right-pad with zeros, and build the matching
        # attention mask (1 for real tokens, 0 for padding).
        if len(token_ids) > self.max_length:
            token_ids = token_ids[:self.max_length]
            attention_mask = [1] * self.max_length
        else:
            attention_mask = [1] * len(token_ids)
            diff = self.max_length - len(token_ids)
            token_ids.extend([0] * diff)
            attention_mask.extend([0] * diff)

        return torch.tensor(token_ids, dtype=torch.long), torch.tensor(attention_mask, dtype=torch.long)
    def __getitem__(self, index):
        example = self.examples_list[index]
        idx = example.id
        text_a = example.text_a
        text_b = example.text_b
        if example.label in self.label2id:
            label = self.label2id[example.label]
        else:
            # Unlabeled (test) examples get the out-of-range sentinel id 3.
            label = 3

        text_a_inputs_id, text_a_attention_mask = self._tokenize(text_a)
        text_b_inputs_id, text_b_attention_mask = self._tokenize(text_b)

        label = torch.tensor(label, dtype=torch.long)

        return {
            'text_a_inputs_id': text_a_inputs_id,
            'text_b_inputs_id': text_b_inputs_id,
            'text_a_attention_mask': text_a_attention_mask,
            'text_b_attention_mask': text_b_attention_mask,
            'labels': label,
            'text_a': text_a,
            'text_b': text_b,
            'idx': idx,
        }
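
# Usage sketch (illustrative): the tiny word-to-id map below is made up; a real
# w2v_map would come from a pretrained word2vec vocabulary.
if __name__ == '__main__':
    demo_examples = [InputExample('s1', '感冒的症状', '感冒吃什么药', '1')]
    demo_w2v_map = {'感冒': 0, '症状': 1, '的': 2, '药': 3}
    demo_dataset = QQRDataset(demo_examples, ['0', '1', '2'], demo_w2v_map, max_length=16)
    sample = demo_dataset[0]
    print(sample['text_a_inputs_id'].shape, sample['labels'])  # torch.Size([16]) tensor(1)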

class BertClassificationDataset(Dataset):
    '''Dataset that encodes query pairs with a HuggingFace tokenizer for BERT.'''

    def __init__(
        self,
        examples,
        tokenizer,
        label_list,
        max_length,
        processor=None,
    ):
        super().__init__()

        self.examples = examples
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.processor = processor

        self.label2id = {label: idx for idx, label in enumerate(label_list)}
        self.id2label = {idx: label for idx, label in enumerate(label_list)}

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        example = self.examples[index]
        if example.label in self.label2id:
            label = self.label2id[example.label]
        else:
            # Unlabeled (test) examples get the out-of-range sentinel id 3.
            label = 3

        inputs = self.tokenizer(
            text=example.text_a,
            text_pair=example.text_b,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )

        input_ids = torch.tensor(inputs.get('input_ids'), dtype=torch.long)
        attention_mask = torch.tensor(inputs.get('attention_mask'), dtype=torch.long)
        token_type_ids = torch.tensor(inputs.get('token_type_ids'), dtype=torch.long)
        label = torch.tensor(label, dtype=torch.long)

        return {
            'labels': label,
            'text_a': example.text_a,
            'text_b': example.text_b,
            'idx': example.id,
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
        }
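
# Usage sketch (illustrative): 'bert-base-chinese' is one possible checkpoint,
# not something this module prescribes; loading it downloads weights on first use.
if __name__ == '__main__':
    from transformers import BertTokenizer

    demo_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    demo_examples = [InputExample('s1', '感冒的症状', '感冒吃什么药', '1')]
    bert_dataset = BertClassificationDataset(demo_examples, demo_tokenizer, ['0', '1', '2'], max_length=32)
    item = bert_dataset[0]
    print(item['input_ids'].shape, item['labels'])  # torch.Size([32]) tensor(1)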