# train.py
import argparse
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from gensim.models import KeyedVectors
from transformers import AutoTokenizer, BertForSequenceClassification

from datasets import QQRDataset, QQR_data, BertClassificationDataset
from models import SemNN, SemLSTM, SemAttention

# Models that consume pretrained word vectors vs. models loaded from a
# Hugging Face checkpoint.
model_type1_list = ['SemNN', 'SemAttention', 'SemLSTM']
model_type2_list = ['Bert']


def train(args):
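    """Train the selected semantic-matching model on the QQR data.

    ``args`` is the argparse namespace built under ``__main__`` below;
    the best and final checkpoints are written under ``args.savepath``.
    """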
    # Unpack hyperparameters and paths from the command line.
    batch_size = args.batch_size
    lr = args.lr
    save_path = args.savepath
    data_dir = args.datadir
    w2v_path = args.w2v_path
    max_length = args.max_length
    epochs = args.epochs
    model_name = args.model_name
    dropout_prob = args.dropout_prob
    in_feat = args.in_feat

    if model_name in model_type1_list:
        # The shallow models rely on pretrained word vectors in word2vec text format.
        begin_time = time.perf_counter()
        w2v_model = KeyedVectors.load_word2vec_format(w2v_path, binary=False)
        end_time = time.perf_counter()
        print("Loading {} took {:.2f}s".format(w2v_path, end_time - begin_time))
        w2v_map = w2v_model.key_to_index
    elif model_name in model_type2_list:
        # For Bert, w2v_path is reused as a Hugging Face checkpoint name or directory.
        tokenizer = AutoTokenizer.from_pretrained(w2v_path)
    else:
        raise ValueError("Unknown model name: {}".format(model_name))

    device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")

    # Each model gets its own checkpoint directory under save_path.
    save_path = os.path.join(save_path, model_name)
    os.makedirs(save_path, exist_ok=True)

    data = QQR_data(data_dir)

    if model_name in model_type1_list:
        train_dataset = QQRDataset(data.get_train_data(), data.get_labels(), w2v_map=w2v_map, max_length=max_length)
        val_dataset = QQRDataset(data.get_dev_data(), data.get_labels(), w2v_map=w2v_map, max_length=max_length)
    elif model_name in model_type2_list:
        train_dataset = BertClassificationDataset(data.get_train_data(), tokenizer=tokenizer, label_list=data.get_labels(), max_length=max_length)
        val_dataset = BertClassificationDataset(data.get_dev_data(), tokenizer=tokenizer, label_list=data.get_labels(), max_length=max_length)

    dataset = {'train': train_dataset, 'val': val_dataset}
    len_dataset = {'train': len(train_dataset), 'val': len(val_dataset)}

    if model_name == "SemNN":
        model = SemNN(
            in_feat=in_feat,
            num_labels=len(data.get_labels()),
            dropout_prob=dropout_prob,
            w2v_mapping=w2v_model,
        )
    elif model_name == "SemLSTM":
        model = SemLSTM(
            in_feat=in_feat,
            num_labels=len(data.get_labels()),
            dropout_prob=dropout_prob,
            w2v_mapping=w2v_model,
        )
    elif model_name == "SemAttention":
        model = SemAttention(
            in_feat=in_feat,
            num_labels=len(data.get_labels()),
            dropout_prob=dropout_prob,
            w2v_mapping=w2v_model,
        )
    elif model_name == "Bert":
        model = BertForSequenceClassification.from_pretrained(w2v_path, num_labels=len(data.get_labels()))

    print(model)

    model.to(device)
    criterion = nn.CrossEntropyLoss()

    # SGD with momentum; StepLR decays the learning rate by 10x every 10 epochs.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    print('Model Name: ' + model_name)
    print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1e6))

    best_val_acc = 0.0
    for epoch in range(epochs):
        for phase in ['train', 'val']:
            running_loss = 0.0
            running_corrects = 0.0

            # Switch between training and evaluation mode.
            if phase == 'train':
                model.train()
            else:
                model.eval()

            # Only shuffle the training split.
            dataloader = DataLoader(dataset[phase], batch_size=batch_size, shuffle=(phase == 'train'), num_workers=4)
            for text_example in tqdm(dataloader):
                if model_name in model_type1_list:
                    text_a_inputs_id = text_example["text_a_inputs_id"].to(device)
                    text_b_inputs_id = text_example["text_b_inputs_id"].to(device)
                    text_a_attention_mask = text_example["text_a_attention_mask"].to(device)
                    text_b_attention_mask = text_example["text_b_attention_mask"].to(device)
                elif model_name in model_type2_list:
                    input_ids = text_example['input_ids'].to(device)
                    token_type_ids = text_example['token_type_ids'].to(device)
                    attention_mask = text_example['attention_mask'].to(device)

                labels = text_example['labels'].to(device)

                optimizer.zero_grad()

                # Track gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    if model_name in model_type1_list:
                        outputs = model(text_a_inputs_id, text_b_inputs_id, text_a_attention_mask, text_b_attention_mask)
                    elif model_name in model_type2_list:
                        outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True).logits

                    # argmax over the logits gives the same predictions as argmax
                    # over the softmax probabilities.
                    preds = torch.argmax(outputs, dim=1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * labels.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len_dataset[phase]
            epoch_acc = (running_corrects.double() / len_dataset[phase]).item()

            # Keep the checkpoint with the best validation accuracy seen so far.
            if phase == 'val' and epoch_acc > best_val_acc:
                best_val_acc = epoch_acc
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'opt_dict': optimizer.state_dict(),
                }, os.path.join(save_path, 'best_model.pth.tar'))

            if epoch == epochs - 1:
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'opt_dict': optimizer.state_dict(),
                }, os.path.join(save_path, 'latest_model.pth.tar'))

            print("[{}] Epoch: {}/{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch + 1, epochs, epoch_loss, epoch_acc))

        # Step the learning-rate schedule once per epoch, after the optimizer updates.
        scheduler.step()

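
# A saved checkpoint can later be restored with, e.g.:
#   ckpt = torch.load(os.path.join(save_path, 'best_model.pth.tar'))
#   model.load_state_dict(ckpt['state_dict'])
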
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name', type=str, default="SemAttention", help="Model to train [SemNN, SemLSTM, SemAttention, Bert]")
    parser.add_argument('--batch_size', type=int, default=8, help="Batch size for training")
    parser.add_argument('--in_feat', type=int, default=100, help="Dimension of the word-embedding features")
    parser.add_argument('--max_length', type=int, default=32, help="Maximum sentence length")
    parser.add_argument('--epochs', type=int, default=50, help="Number of training epochs")
    parser.add_argument('--lr', type=float, default=1e-3, help="Learning rate")
    parser.add_argument('--dropout_prob', type=float, default=0.1, help="Dropout ratio for dropout layers")
    parser.add_argument('--savepath', type=str, default="./results", help="Directory for saving trained models")
    parser.add_argument('--datadir', type=str, default='./data', help="Directory containing the training data")
    parser.add_argument('--gpu', type=str, default='1', help="GPU id to train on")
    parser.add_argument('--w2v_path', type=str, default='./tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt', help="Path to the word2vec file (or a Hugging Face checkpoint for Bert)")

    args = parser.parse_args()
    train(args)
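
# Example invocations (a sketch: these assume the default ./data directory and
# the Tencent AI Lab embedding file exist locally; 'bert-base-chinese' is just
# one plausible Hugging Face checkpoint, not one pinned by this repo):
#   python train.py --model_name SemLSTM --batch_size 16
#   python train.py --model_name Bert --w2v_path bert-base-chinese --lr 2e-5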