|
# bert_mixup/late_mixup/train.py
import argparse
import csv
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from tqdm import tqdm

from models.text_bert import TextBERT
from data_loader import MoleculeDataLoader


def parse_args():
    parser = argparse.ArgumentParser(description="Mixup for text classification")
    parser.add_argument(
        "--name", default="cnn-text-fine-tune", type=str, help="name of the experiment"
    )
    parser.add_argument(
        "--num-labels",
        type=int,
        default=2,
        metavar="L",
        help="number of labels in the train dataset (default: 2)",
    )
    parser.add_argument(
        "--model-name-or-path",
        type=str,
        default="shahrukhx01/smole-bert",
        metavar="M",
        help="name of the pre-trained transformer model on the HF hub",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="bace",
        metavar="D",
        help="name of the MoleculeNet dataset (default: bace); options: bace, bbbp",
    )
    parser.add_argument(
        "--cuda",
        default=True,
        type=lambda x: (str(x).lower() == "true"),
        help="use cuda if available",
    )
    parser.add_argument("--lr", default=0.001, type=float, help="learning rate")
    parser.add_argument("--dropout", default=0.5, type=float, help="dropout rate")
    parser.add_argument("--decay", default=0.0, type=float, help="weight decay")
    parser.add_argument(
        "--model",
        default="TextCNN",
        type=str,
        help="model type (unused; TextBERT is always instantiated)",
    )
    parser.add_argument("--seed", default=1, type=int, help="random seed")
    parser.add_argument(
        "--batch-size", default=50, type=int, help="batch size (default: 50)"
    )
    parser.add_argument(
        "--epoch", default=50, type=int, help="total epochs (default: 50)"
    )
    parser.add_argument(
        "--fine-tune",
        default=True,
        type=lambda x: (str(x).lower() == "true"),
        help="whether to fine-tune the embeddings or not",
    )
    parser.add_argument(
        "--method",
        default="embed",
        type=str,
        help="mixing method: embed, sent, encoder, or none (default: embed)",
    )
    parser.add_argument(
        "--alpha",
        default=1.0,
        type=float,
        help="mixup interpolation coefficient (default: 1.0)",
    )
    parser.add_argument(
        "--save-path", default="out", type=str, help="output log/result directory"
    )
    parser.add_argument("--num-runs", default=10, type=int, help="number of runs")
    # NOTE: --patience is an assumed addition; the early-stopping check in
    # train_mixup needs a patience value, and 30 is only a placeholder default.
    parser.add_argument(
        "--patience",
        default=30,
        type=int,
        help="early-stopping patience in evaluation steps (default: 30)",
    )
    parser.add_argument(
        "--debug",
        type=int,
        default=0,
        metavar="DB",
        help="flag to enable debug mode for dev (default: 0)",
    )

    parser.add_argument(
        "--samples-per-class",
        type=int,
        default=-1,
        metavar="SPC",
        help="no. of samples per class label to sample for SSL; -1 uses all (default: -1)",
    )
    parser.add_argument(
        "--n-augment",
        type=int,
        default=0,
        metavar="NAUG",
        help="number of enumeration augmentations (default: 0)",
    )
    parser.add_argument(
        "--eval-after",
        type=int,
        default=10,
        metavar="EA",
        help="number of training iterations between validation evaluations (default: 10)",
    )
    args = parser.parse_args()
    return args


def mixup_criterion_cross_entropy(criterion, pred, y_a, y_b, lam):
    """Mixup loss: interpolate the criterion between the two label sets."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
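
# Quick sanity check (a sketch, not executed during training): when
# y_a == y_b, the interpolated loss collapses to the plain cross entropy.
#   crit = nn.CrossEntropyLoss()
#   pred, y = torch.randn(4, 2), torch.randint(0, 2, (4,))
#   assert torch.allclose(
#       mixup_criterion_cross_entropy(crit, pred, y, y, 0.3), crit(pred, y)
#   )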
|
|
|
class Classification:
    def __init__(self, args):
        self.args = args

        self.use_cuda = args.cuda and torch.cuda.is_available()

        # for reproducibility
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        np.random.seed(args.seed)
        random.seed(args.seed)

        # data loaders
        data_loaders = MoleculeDataLoader(
            dataset_name=args.dataset_name,
            batch_size=args.batch_size,
            debug=args.debug,
            n_augment=args.n_augment,
            samples_per_class=args.samples_per_class,
            model_name_or_path=args.model_name_or_path,
        )
        # assumption: create_supervised_loaders returns the (train, val, test)
        # iterators used throughout this class
        (
            self.train_iterator,
            self.val_iterator,
            self.test_iterator,
        ) = data_loaders.create_supervised_loaders(
            samples_per_class=args.samples_per_class
        )

        # model
        self.model = TextBERT(
            pretrained_model=args.model_name_or_path,
            num_class=args.num_labels,
            fine_tune=args.fine_tune,
            dropout=args.dropout,
        )
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        self.model.to(self.device)

        # logs
        os.makedirs(args.save_path, exist_ok=True)
        self.model_save_path = os.path.join(args.save_path, args.name + "_weights.pt")
        self.log_path = os.path.join(args.save_path, args.name + "_logs.csv")
        print(str(args))
        with open(self.log_path, "a", newline="") as out:
            out.write(str(args) + "\n")
            writer = csv.writer(out)
            writer.writerow(["mode", "epoch", "step", "loss", "acc"])

        # loss and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=args.lr, weight_decay=args.decay
        )

        # for early stopping
        self.best_val_acc = 0
        self.early_stop = False
        self.val_patience = 0  # successive evaluations without val-acc improvement

        self.iteration_number = 0

    def get_perm(self, x):
        """Get a random permutation of the batch indices of x."""
        batch_size = x.size(0)
        if self.use_cuda:
            index = torch.randperm(batch_size).cuda()
        else:
            index = torch.randperm(batch_size)
        return index

    def test(self, iterator):
        """Evaluate on an iterator; return (average loss, accuracy in %)."""
        self.model.eval()
        test_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            # for _, batch in tqdm(enumerate(iterator), total=len(iterator), desc="test"):
            for _, batch in enumerate(iterator):
                batch = tuple(t.to(self.device) for t in batch)
                b_input_ids, b_input_mask, b_labels = batch
                y_pred = self.model(b_input_ids, b_input_mask)
                loss = self.criterion(y_pred, b_labels)
                test_loss += loss.item() * b_labels.shape[0]
                total += b_labels.shape[0]
                correct += torch.sum(torch.argmax(y_pred, dim=1) == b_labels).item()

        avg_loss = test_loss / total
        acc = 100.0 * correct / total
        return avg_loss, acc

    def train_mixup(self, epoch):
        """Train for one epoch, mixing each example with a random batch partner."""
        self.model.train()
        train_loss = 0
        total = 0
        correct = 0
        for _, batch in enumerate(self.train_iterator):
            batch = tuple(t.to(self.device) for t in batch)
            lam = np.random.beta(self.args.alpha, self.args.alpha)
            b_input_ids, b_input_mask, b_labels = batch
            # pair each example with a randomly permuted partner from the batch
            index = self.get_perm(b_input_ids)
            b_input_ids1 = b_input_ids[index]
            b_input_mask1 = b_input_mask[index]
            b_labels1 = b_labels[index]

            # interpolate at the depth selected by --method; the forward_mix_*
            # variants are implemented in models/text_bert.py
            if self.args.method == "embed":
                y_pred = self.model.forward_mix_embed(
                    b_input_ids, b_input_mask, b_input_ids1, b_input_mask1, lam
                )
            elif self.args.method == "sent":
                y_pred = self.model.forward_mix_sent(
                    b_input_ids, b_input_mask, b_input_ids1, b_input_mask1, lam
                )
            elif self.args.method == "encoder":
                y_pred = self.model.forward_mix_encoder(
                    b_input_ids, b_input_mask, b_input_ids1, b_input_mask1, lam
                )
            else:
                raise ValueError("invalid method name")

            loss = mixup_criterion_cross_entropy(
                self.criterion, y_pred, b_labels, b_labels1, lam
            )
            train_loss += loss.item() * b_labels.shape[0]
            total += b_labels.shape[0]
            _, predicted = torch.max(y_pred.data, 1)
            # mixup accuracy: credit agreement with both labels, weighted by lam
            correct += (
                lam * predicted.eq(b_labels.data).cpu().sum().float()
                + (1 - lam) * predicted.eq(b_labels1.data).cpu().sum().float()
            ).item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # periodic evaluation and early stopping
            self.iteration_number += 1
            if self.iteration_number % self.args.eval_after == 0:
                avg_loss = train_loss / total
                acc = 100.0 * correct / total
                # print("Train loss: {}, Train acc: {}".format(avg_loss, acc))
                train_loss = 0
                total = 0
                correct = 0

                val_loss, val_acc = self.test(iterator=self.val_iterator)
                # print("Val loss: {}, Val acc: {}".format(val_loss, val_acc))
                if val_acc > self.best_val_acc:
                    torch.save(self.model.state_dict(), self.model_save_path)
                    self.best_val_acc = val_acc
                    self.val_patience = 0
                else:
                    self.val_patience += 1
                    if self.val_patience == self.args.patience:
                        self.early_stop = True
                        return
                with open(self.log_path, "a", newline="") as out:
                    writer = csv.writer(out)
                    writer.writerow(
                        ["train", epoch, self.iteration_number, avg_loss, acc]
                    )
                    writer.writerow(
                        ["val", epoch, self.iteration_number, val_loss, val_acc]
                    )
                self.model.train()
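
    def train(self, epoch):
        """Baseline loop for --method none: a minimal sketch that assumes the
        plain forward pass used in test(), mirroring train_mixup without the
        pairing and interpolation."""
        self.model.train()
        for _, batch in enumerate(self.train_iterator):
            batch = tuple(t.to(self.device) for t in batch)
            b_input_ids, b_input_mask, b_labels = batch
            y_pred = self.model(b_input_ids, b_input_mask)
            loss = self.criterion(y_pred, b_labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # same evaluation / early-stopping cadence as train_mixup
            self.iteration_number += 1
            if self.iteration_number % self.args.eval_after == 0:
                val_loss, val_acc = self.test(iterator=self.val_iterator)
                if val_acc > self.best_val_acc:
                    torch.save(self.model.state_dict(), self.model_save_path)
                    self.best_val_acc = val_acc
                    self.val_patience = 0
                else:
                    self.val_patience += 1
                    if self.val_patience == self.args.patience:
                        self.early_stop = True
                        return
                self.model.train()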
|
|
|
    def run(self):
        """Run the full training loop, then evaluate the best checkpoint."""
        for epoch in range(self.args.epoch):
            print(
                "------------------------------------- Epoch {} -------------------------------------".format(
                    epoch
                )
            )
            if self.args.method == "none":
                self.train(epoch)
            else:
                self.train_mixup(epoch)
            if self.early_stop:
                break
        print("Training complete!")
        print("Best Validation Acc: ", self.best_val_acc)

        # restore the best checkpoint before the final evaluation
        self.model.load_state_dict(torch.load(self.model_save_path))
        # train_loss, train_acc = self.test(self.train_iterator)
        val_loss, val_acc = self.test(self.val_iterator)
        test_loss, test_acc = self.test(self.test_iterator)

        with open(self.log_path, "a", newline="") as out:
            writer = csv.writer(out)
            # writer.writerow(["train", -1, -1, train_loss, train_acc])
            writer.writerow(["val", -1, -1, val_loss, val_acc])
            writer.writerow(["test", -1, -1, test_loss, test_acc])

        # print("Train loss: {}, Train acc: {}".format(train_loss, train_acc))
        print("Val loss: {}, Val acc: {}".format(val_loss, val_acc))
        print("Test loss: {}, Test acc: {}".format(test_loss, test_acc))

        return val_acc, test_acc


if __name__ == "__main__":
    args = parse_args()
    num_runs = args.num_runs

    test_acc = []
    val_acc = []

    # repeat the experiment with incremented seeds and aggregate accuracies
    for i in range(num_runs):
        cls = Classification(args)
        val, test = cls.run()
        val_acc.append(val)
        test_acc.append(test)
        args.seed += 1

    with open(os.path.join(args.save_path, args.name + "_result.txt"), "a") as f:
        f.write(str(args) + "\n")
        f.write("val acc:" + str(val_acc) + "\n")
        f.write("test acc:" + str(test_acc) + "\n")
        f.write("mean val acc:" + str(np.mean(val_acc)) + "\n")
        f.write("std val acc:" + str(np.std(val_acc, ddof=1)) + "\n")
        f.write("mean test acc:" + str(np.mean(test_acc)) + "\n")
        f.write("std test acc:" + str(np.std(test_acc, ddof=1)) + "\n\n\n")