# eval.py
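"""Evaluate a saved question-matching classifier (SemNN, SemLSTM, SemAttention, or Bert)
on the train and dev splits of the QQR data and print the accuracy of each split.

A minimal example invocation, assuming the default checkpoint and data layout used by
the arguments defined below (the paths are illustrative; adjust them to your own setup):

    python eval.py --model_name SemAttention \
        --model_path ./results/SemAttention/best_model.pth.tar \
        --datadir ./data \
        --w2v_path ./tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt
"""
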
import argparse
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from gensim.models import KeyedVectors
from transformers import AutoTokenizer, BertForSequenceClassification

from datasets import QQRDataset, QQR_data, BertClassificationDataset
from models import SemNN, SemLSTM, SemAttention

# Models that consume the word2vec embeddings directly vs. the pretrained BERT model.
model_type1_list = ['SemNN', 'SemAttention', 'SemLSTM']
model_type2_list = ['Bert']


def evaluate(args):
    batch_size = args.batch_size
    data_dir = args.datadir
    w2v_path = args.w2v_path
    max_length = args.max_length
    model_name = args.model_name
    in_feat = args.in_feat
    dropout_prob = args.dropout_prob
    model_path = args.model_path

    # For the word2vec-based models, load the pretrained embeddings;
    # for Bert, w2v_path points at the pretrained model directory instead.
    if model_name in model_type1_list:
        begin_time = time.perf_counter()
        w2v_model = KeyedVectors.load_word2vec_format(w2v_path, binary=False)
        end_time = time.perf_counter()
        print("Loading {} took {:.2f}s".format(w2v_path, end_time - begin_time))
        w2v_map = w2v_model.key_to_index
    elif model_name in model_type2_list:
        tokenizer = AutoTokenizer.from_pretrained(w2v_path)
    else:
        raise ValueError("Unknown model name: {}".format(model_name))

    device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else 'cpu')

    data = QQR_data(data_dir)

    # Build the train and dev datasets with the preprocessing that matches the model type.
    if model_name in model_type1_list:
        train_dataset = QQRDataset(data.get_train_data(), data.get_labels(),
                                   w2v_map=w2v_map, max_length=max_length)
        val_dataset = QQRDataset(data.get_dev_data(), data.get_labels(),
                                 w2v_map=w2v_map, max_length=max_length)
    elif model_name in model_type2_list:
        train_dataset = BertClassificationDataset(data.get_train_data(), tokenizer=tokenizer,
                                                  label_list=data.get_labels(), max_length=max_length)
        val_dataset = BertClassificationDataset(data.get_dev_data(), tokenizer=tokenizer,
                                                label_list=data.get_labels(), max_length=max_length)

    train_examples_num = len(train_dataset)
    val_examples_num = len(val_dataset)

    dataset = {'train': train_dataset, 'val': val_dataset}
    len_dataset = {'train': train_examples_num, 'val': val_examples_num}

    # Instantiate the model architecture; the weights are restored from the checkpoint below.
    if model_name == "SemNN":
        model = SemNN(in_feat=in_feat,
                      num_labels=len(data.get_labels()),
                      dropout_prob=dropout_prob,
                      w2v_mapping=w2v_model)
    elif model_name == "SemLSTM":
        model = SemLSTM(in_feat=in_feat,
                        num_labels=len(data.get_labels()),
                        dropout_prob=dropout_prob,
                        w2v_mapping=w2v_model)
    elif model_name == "SemAttention":
        model = SemAttention(in_feat=in_feat,
                             num_labels=len(data.get_labels()),
                             dropout_prob=dropout_prob,
                             w2v_mapping=w2v_model)
    elif model_name == "Bert":
        model = BertForSequenceClassification.from_pretrained(w2v_path, num_labels=len(data.get_labels()))

    print(model)

    # Restore the trained weights from the saved checkpoint.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)

    print('Model Name: ' + model_name)
    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Put the model in inference mode and report accuracy on both splits.
    model.eval()

    for phase in ['train', 'val']:
        running_corrects = 0

        dataloader = DataLoader(dataset[phase], batch_size=batch_size, shuffle=False, num_workers=4)
        for text_example in tqdm(dataloader):
            if model_name in model_type1_list:
                text_a_inputs_id = text_example["text_a_inputs_id"].to(device)
                text_b_inputs_id = text_example["text_b_inputs_id"].to(device)
                text_a_attention_mask = text_example["text_a_attention_mask"].to(device)
                text_b_attention_mask = text_example["text_b_attention_mask"].to(device)
            elif model_name in model_type2_list:
                input_ids = text_example['input_ids'].to(device)
                token_type_ids = text_example['token_type_ids'].to(device)
                attention_mask = text_example['attention_mask'].to(device)

            labels = text_example['labels'].to(device)

            with torch.no_grad():
                if model_name in model_type1_list:
                    outputs = model(text_a_inputs_id, text_b_inputs_id,
                                    text_a_attention_mask, text_b_attention_mask)
                elif model_name in model_type2_list:
                    outputs = model(input_ids=input_ids, token_type_ids=token_type_ids,
                                    attention_mask=attention_mask, return_dict=True).logits

            # Predicted class = index of the highest softmax probability.
            probs = nn.Softmax(dim=1)(outputs)
            preds = torch.max(probs, 1)[1]

            running_corrects += torch.sum(preds == labels)

        epoch_acc = float(running_corrects) / len_dataset[phase]
        print("[{}]  Acc: {:.4f}".format(phase, epoch_acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name', type=str, default="SemAttention",
                        help="Model to evaluate: one of [SemNN, SemLSTM, SemAttention, Bert]")
    parser.add_argument('--batch_size', type=int, default=8,
                        help="Batch size for evaluation")
    parser.add_argument('--in_feat', type=int, default=100,
                        help="Dimension of the word embedding features")
    parser.add_argument('--model_path', type=str, default='./results/SemAttention/best_model.pth.tar',
                        help="Path to the saved model checkpoint")
    parser.add_argument('--max_length', type=int, default=32,
                        help="Maximum sentence length")
    parser.add_argument('--dropout_prob', type=float, default=0.1,
                        help="Dropout ratio for dropout layers")
    parser.add_argument('--datadir', type=str, default='./data',
                        help="Directory containing the dataset")
    parser.add_argument('--gpu', type=str, default='1',
                        help="GPU id to use")
    parser.add_argument('--w2v_path', type=str, default='./tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt',
                        help="Path to the word2vec file (or the pretrained BERT model when --model_name is Bert)")

    args = parser.parse_args()

    evaluate(args)