# inference.py
import torch
import argparse
from datasets import QQRDataset, QQR_data, BertClassificationDataset
from tqdm import tqdm
from gensim.models import KeyedVectors
import time
from torch.utils.data import DataLoader
from models import SemNN, SemLSTM, SemAttention
from transformers import AutoTokenizer, BertForSequenceClassification
import os
import torch.nn as nn
import json

# Model families: models built on word2vec embeddings vs. models loaded from a pretrained BERT checkpoint.
model_type1_list = ['SemNN', 'SemAttention', 'SemLSTM']
model_type2_list = ['Bert']


def inference(args):
    batch_size = args.batch_size
    save_path = args.savepath
    data_dir = args.datadir
    w2v_path = args.w2v_path
    max_length = args.max_length
    model_path = args.model_path
    model_name = args.model_name
    in_feat = args.in_feat
    dropout_prob = args.dropout_prob

    # Load the pretrained word vectors for w2v-based models, or the tokenizer for BERT.
    if model_name in model_type1_list:
        begin_time = time.perf_counter()
        w2v_model = KeyedVectors.load_word2vec_format(w2v_path, binary=False)
        end_time = time.perf_counter()
        print("Load {} cost {:.2f}s".format(w2v_path, end_time - begin_time))
        w2v_map = w2v_model.key_to_index
    elif model_name in model_type2_list:
        tokenizer = AutoTokenizer.from_pretrained(w2v_path)

    device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else 'cpu')

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    data = QQR_data(data_dir)

    # Build the test dataset that matches the chosen model family.
    if model_name in model_type1_list:
        test_dataset = QQRDataset(data.get_test_data(), data.get_labels(), w2v_map=w2v_map, max_length=max_length)
    elif model_name in model_type2_list:
        test_dataset = BertClassificationDataset(data.get_test_data(), tokenizer=tokenizer, label_list=data.get_labels(), max_length=max_length)

    id2label = test_dataset.id2label

    dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    if model_name == "SemNN":
57
        model = SemNN(
58
            in_feat=100,
59
            num_labels=len(data.get_labels()),
60
            dropout_prob=dropout_prob,
61
            w2v_mapping=w2v_model
62
        )
63
    elif model_name == "SemLSTM":
64
        model = SemLSTM(
65
            in_feat=in_feat,
66
            num_labels=len(data.get_labels()),
67
            dropout_prob=dropout_prob,
68
            w2v_mapping=w2v_model
69
        )
70
    elif model_name == "SemAttention":
71
        model = SemAttention(
72
            in_feat=in_feat,
73
            num_labels = len(data.get_labels()),
74
            dropout_prob=dropout_prob,
75
            w2v_mapping=w2v_model
76
        )
77
    elif model_name == "Bert":
78
        model = BertForSequenceClassification.from_pretrained(w2v_path,num_labels=len(data.get_labels()))
79
    
80
        
81
    print(model)
    print('Model Name: ' + model_name)
    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Load the trained checkpoint onto CPU first, then move the model to the target device.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()

    json_results = []

    # Run inference batch by batch and collect the predicted labels.
    for text_example in tqdm(dataloader):
        text_a = text_example.get('text_a')
        text_b = text_example.get('text_b')
        idx = text_example.get('idx')
        if model_name in model_type1_list:
            text_a_inputs_id = text_example.get("text_a_inputs_id").to(device)
            text_b_inputs_id = text_example.get("text_b_inputs_id").to(device)
            text_a_attention_mask = text_example.get("text_a_attention_mask").to(device)
            text_b_attention_mask = text_example.get("text_b_attention_mask").to(device)
        elif model_name in model_type2_list:
            input_ids = text_example.get('input_ids').to(device)
            token_type_ids = text_example.get('token_type_ids').to(device)
            attention_mask = text_example.get('attention_mask').to(device)

        with torch.no_grad():
            if model_name in model_type1_list:
                outputs = model(text_a_inputs_id, text_b_inputs_id, text_a_attention_mask, text_b_attention_mask)
            elif model_name in model_type2_list:
                outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True).get('logits')

            probs = nn.Softmax(dim=1)(outputs)
            preds = torch.max(probs, 1)[1].data.cpu()

        for i in range(outputs.size(0)):
            json_results.append({
                "id": idx[i],
                "query1": text_a[i],
                "query2": text_b[i],
                "label": id2label[preds[i].item()]
            })

    # Write all predictions to disk once the whole test set has been processed.
    with open(os.path.join(save_path, 'results_test.json'), 'w', encoding='utf-8') as f:
        json.dump(json_results, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name', type=str, default="SemAttention", help="Model name for inference")
    parser.add_argument('--in_feat', type=int, default=100, help="Dimension of the word embedding features")
    parser.add_argument('--dropout_prob', type=float, default=0.1, help="Dropout ratio for dropout layers")
    parser.add_argument('--batch_size', type=int, default=128, help="Batch size for inference")
    parser.add_argument('--max_length', type=int, default=32, help="Maximum sentence length")
    parser.add_argument('--savepath', type=str, default="./results/SemAttention", help="Directory to save the prediction results")
    parser.add_argument('--datadir', type=str, default='./data', help="Data directory for train and test files")
    parser.add_argument('--model_path', type=str, default='./results/SemAttention/best_model.pth.tar', help="Path to the saved model checkpoint")
    parser.add_argument('--gpu', type=str, default='1', help="GPU id to run inference on")
    parser.add_argument('--w2v_path', type=str, default='./tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt', help="Path to the w2v model file (or the pretrained BERT directory when --model_name is Bert)")

    args = parser.parse_args()

    inference(args)
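
# Example invocation (a sketch based on the argparse defaults above; the exact
# paths depend on your local data/checkpoint layout and are assumptions):
#   python inference.py --model_name SemAttention --gpu 0 \
#       --model_path ./results/SemAttention/best_model.pth.tar \
#       --datadir ./data --savepath ./results/SemAttention \
#       --w2v_path ./tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt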