bert_mixup/late_mixup/train.py

import argparse
import csv
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from tqdm import tqdm

from models.text_bert import TextBERT
from data_loader import MoleculeDataLoader


def parse_args():
    parser = argparse.ArgumentParser(description="Mixup for text classification")
    parser.add_argument(
        "--name", default="cnn-text-fine-tune", type=str, help="name of the experiment"
    )
    parser.add_argument(
        "--num-labels",
        type=int,
        default=2,
        metavar="L",
        help="number of labels of the train dataset (default: 2)",
    )
    parser.add_argument(
        "--model-name-or-path",
        type=str,
        default="shahrukhx01/smole-bert",
        metavar="M",
        help="name of the pre-trained transformer model on the HF hub",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="bace",
        metavar="D",
        help="name of the MoleculeNet dataset (default: bace); options: bace, bbbp",
    )
    parser.add_argument(
        "--cuda",
        default=True,
        type=lambda x: (str(x).lower() == "true"),
        help="use cuda if available",
    )
    parser.add_argument("--lr", default=0.001, type=float, help="learning rate")
    parser.add_argument("--dropout", default=0.5, type=float, help="dropout rate")
    parser.add_argument("--decay", default=0.0, type=float, help="weight decay")
    parser.add_argument(
        "--model", default="TextCNN", type=str, help="model type (default: TextCNN)"
    )
    parser.add_argument("--seed", default=1, type=int, help="random seed")
    parser.add_argument(
        "--batch-size", default=50, type=int, help="batch size (default: 50)"
    )
    parser.add_argument(
        "--epoch", default=50, type=int, help="total epochs (default: 50)"
    )
    parser.add_argument(
        "--fine-tune",
        default=True,
        type=lambda x: (str(x).lower() == "true"),
        help="whether to fine-tune the embedding or not",
    )
    parser.add_argument(
        "--method",
        default="embed",
        type=str,
        help="which mixing method to use: embed, sent, encoder (default: embed)",
    )
    parser.add_argument(
        "--alpha",
        default=1.0,
        type=float,
        help="mixup interpolation coefficient (default: 1.0)",
    )
    parser.add_argument(
        "--save-path", default="out", type=str, help="output log/result directory"
    )
    parser.add_argument("--num-runs", default=10, type=int, help="number of runs")
    parser.add_argument(
        "--debug",
        type=int,
        default=0,
        metavar="DB",
        help="flag to enable debug mode for dev (default: 0)",
    )

    parser.add_argument(
        "--samples-per-class",
        type=int,
        default=-1,
        metavar="SPC",
        help="no. of samples per class label to sample for SSL (default: -1)",
    )
    parser.add_argument(
        "--n-augment",
        type=int,
        default=0,
        metavar="NAUG",
        help="number of enumeration augmentations",
    )
    parser.add_argument(
        "--eval-after",
        type=int,
        default=10,
        metavar="EA",
        help="number of iterations after which the model is evaluated on the validation set (default: 10)",
    )
    # early stopping patience read by train_mixup (the default value here
    # is an assumption)
    parser.add_argument(
        "--patience",
        type=int,
        default=30,
        metavar="P",
        help="evaluation rounds without validation improvement before early stopping (default: 30)",
    )
    args = parser.parse_args()
    return args


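# mixup loss: a convex combination of the cross-entropy against both labels
# of a mixed pair, weighted by the interpolation coefficient lam
# (Zhang et al., 2018, "mixup: Beyond Empirical Risk Minimization").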
def mixup_criterion_cross_entropy(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


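# Classification wraps a single training run: seeding, data loading, model
# setup, mixup training with early stopping, and final evaluation of the
# best checkpoint on the validation and test sets.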
class Classification:
    def __init__(self, args):
        self.args = args

        self.use_cuda = args.cuda and torch.cuda.is_available()

        # for reproducibility
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        np.random.seed(args.seed)
        random.seed(args.seed)

        # data loaders
        data_loaders = MoleculeDataLoader(
            dataset_name=args.dataset_name,
            batch_size=args.batch_size,
            debug=args.debug,
            n_augment=args.n_augment,
            samples_per_class=args.samples_per_class,
            model_name_or_path=args.model_name_or_path,
        )
        # NOTE: assumed API -- create_supervised_loaders is taken to return
        # the (train, val, test) iterators used throughout this class
        (
            self.train_iterator,
            self.val_iterator,
            self.test_iterator,
        ) = data_loaders.create_supervised_loaders(
            samples_per_class=args.samples_per_class
        )
        # model
        self.model = TextBERT(
            pretrained_model=args.model_name_or_path,
            num_class=args.num_labels,
            fine_tune=args.fine_tune,
            dropout=args.dropout,
        )
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        self.model.to(self.device)

        # logs
        os.makedirs(args.save_path, exist_ok=True)
        self.model_save_path = os.path.join(args.save_path, args.name + "_weights.pt")
        self.log_path = os.path.join(args.save_path, args.name + "_logs.csv")
        print(str(args))
        with open(self.log_path, "a", newline="") as out:
            out.write(str(args) + "\n")
            writer = csv.writer(out)
            writer.writerow(["mode", "epoch", "step", "loss", "acc"])

        # optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=args.lr, weight_decay=args.decay
        )

        # for early stopping
        self.best_val_acc = 0
        self.early_stop = False
        self.val_patience = 0  # successive evaluations without validation improvement

        self.iteration_number = 0

    def get_perm(self, x):
        """Return a random permutation of the batch indices of x."""
        # generate on the same device as x (CPU or CUDA)
        return torch.randperm(x.size(0), device=x.device)

    def test(self, iterator):
        self.model.eval()
        test_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            # for batch in tqdm(iterator, desc="test"):  # progress bar for dev
            for batch in iterator:
                batch = tuple(t.to(self.device) for t in batch)
                b_input_ids, b_input_mask, b_labels = batch
                y_pred = self.model(b_input_ids, b_input_mask)
                loss = self.criterion(y_pred, b_labels)
                test_loss += loss.item() * b_labels.shape[0]
                total += b_labels.shape[0]
                correct += torch.sum(torch.argmax(y_pred, dim=1) == b_labels).item()

        avg_loss = test_loss / total
        acc = 100.0 * correct / total
        return avg_loss, acc

    def train_mixup(self, epoch):
        self.model.train()
        train_loss = 0
        total = 0
        correct = 0
        for batch in self.train_iterator:
            batch = tuple(t.to(self.device) for t in batch)
            # sample the mixup coefficient from Beta(alpha, alpha)
            lam = np.random.beta(self.args.alpha, self.args.alpha)
            b_input_ids, b_input_mask, b_labels = batch
            # build each example's mixing partner by permuting the batch dimension
            index = self.get_perm(b_input_ids)
            b_input_ids1 = b_input_ids[index]
            b_input_mask1 = b_input_mask[index]
            b_labels1 = b_labels[index]

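            # dispatch to one of the three late-mixup variants exposed by
            # TextBERT; judging by the method names, "embed" mixes token
            # embeddings, "sent" mixes final sentence representations, and
            # "encoder" mixes intermediate encoder states, each as a
            # lam-weighted interpolation of the two examples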
            if self.args.method == "embed":
                y_pred = self.model.forward_mix_embed(
                    b_input_ids, b_input_mask, b_input_ids1, b_input_mask1, lam
                )
            elif self.args.method == "sent":
                y_pred = self.model.forward_mix_sent(
                    b_input_ids, b_input_mask, b_input_ids1, b_input_mask1, lam
                )
            elif self.args.method == "encoder":
                y_pred = self.model.forward_mix_encoder(
                    b_input_ids, b_input_mask, b_input_ids1, b_input_mask1, lam
                )
            else:
                raise ValueError("invalid method name")

            loss = mixup_criterion_cross_entropy(
                self.criterion, y_pred, b_labels, b_labels1, lam
            )
            train_loss += loss.item() * b_labels.shape[0]
            total += b_labels.shape[0]
            _, predicted = torch.max(y_pred, 1)
            # mixup accuracy: credit matches against both label sets,
            # weighted by lam and (1 - lam) respectively
            correct += (
                lam * predicted.eq(b_labels).cpu().sum().float()
                + (1 - lam) * predicted.eq(b_labels1).cpu().sum().float()
            ).item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # eval
            self.iteration_number += 1
            if self.iteration_number % self.args.eval_after == 0:
                avg_loss = train_loss / total
                acc = 100.0 * correct / total
                # print("Train loss: {}, Train acc: {}".format(avg_loss, acc))
                train_loss = 0
                total = 0
                correct = 0

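                # early stopping: checkpoint on validation improvement,
                # otherwise count a strike; stop once --patience successive
                # evaluations fail to improve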
                val_loss, val_acc = self.test(iterator=self.val_iterator)
                # print("Val loss: {}, Val acc: {}".format(val_loss, val_acc))
                if val_acc > self.best_val_acc:
                    torch.save(self.model.state_dict(), self.model_save_path)
                    self.best_val_acc = val_acc
                    self.val_patience = 0
                else:
                    self.val_patience += 1
                    if self.val_patience == self.args.patience:
                        self.early_stop = True
                        return
                with open(self.log_path, "a", newline="") as out:
                    writer = csv.writer(out)
                    writer.writerow(
                        ["train", epoch, self.iteration_number, avg_loss, acc]
                    )
                    writer.writerow(
                        ["val", epoch, self.iteration_number, val_loss, val_acc]
                    )
                self.model.train()

    def run(self):
        for epoch in range(self.args.epoch):
            print(
                "------------------------------------- Epoch {} -------------------------------------".format(
                    epoch
                )
            )
            # only mixup training is implemented in this script;
            # --method must be one of: embed, sent, encoder
            self.train_mixup(epoch)
            if self.early_stop:
                break
        print("Training complete!")
        print("Best Validation Acc: ", self.best_val_acc)

        # reload the best checkpoint before the final evaluation
        self.model.load_state_dict(torch.load(self.model_save_path))
        # train_loss, train_acc = self.test(self.train_iterator)
        val_loss, val_acc = self.test(self.val_iterator)
        test_loss, test_acc = self.test(self.test_iterator)

        with open(self.log_path, "a", newline="") as out:
            writer = csv.writer(out)
            # writer.writerow(["train", -1, -1, train_loss, train_acc])
            writer.writerow(["val", -1, -1, val_loss, val_acc])
            writer.writerow(["test", -1, -1, test_loss, test_acc])

        # print("Train loss: {}, Train acc: {}".format(train_loss, train_acc))
        print("Val loss: {}, Val acc: {}".format(val_loss, val_acc))
        print("Test loss: {}, Test acc: {}".format(test_loss, test_acc))

        return val_acc, test_acc


if __name__ == "__main__":
    args = parse_args()
    num_runs = args.num_runs

    test_acc = []
    val_acc = []

    for i in range(num_runs):
        cls = Classification(args)
        val, test = cls.run()
        val_acc.append(val)
        test_acc.append(test)
        args.seed += 1  # new seed so each run is an independent replicate

    with open(os.path.join(args.save_path, args.name + "_result.txt"), "a") as f:
        f.write(str(args) + "\n")
        f.write("val acc:" + str(val_acc) + "\n")
        f.write("test acc:" + str(test_acc) + "\n")
        f.write("mean val acc:" + str(np.mean(val_acc)) + "\n")
        f.write("std val acc:" + str(np.std(val_acc, ddof=1)) + "\n")
        f.write("mean test acc:" + str(np.mean(test_acc)) + "\n")
        f.write("std test acc:" + str(np.std(test_acc, ddof=1)) + "\n\n\n")