--- /dev/null
+++ b/bert_mixup/late_mixup/train.py
@@ -0,0 +1,343 @@
+import argparse
+import csv
+import os
+import random
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn as nn
+from tqdm import tqdm
+
+from models.text_bert import TextBERT
+from data_loader import MoleculeDataLoader
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Mixup for text classification")
+    parser.add_argument(
+        "--name", default="cnn-text-fine-tune", type=str, help="name of the experiment"
+    )
+    parser.add_argument(
+        "--num-labels",
+        type=int,
+        default=2,
+        metavar="L",
+        help="number of labels of the train dataset (default: 2)",
+    )
+    parser.add_argument(
+        "--model-name-or-path",
+        type=str,
+        default="shahrukhx01/smole-bert",
+        metavar="M",
+        help="name of the pre-trained transformer model from hf hub",
+    )
+    parser.add_argument(
+        "--dataset-name",
+        type=str,
+        default="bace",
+        metavar="D",
+        help="name of the molecule net dataset (default: bace) all: bace, bbbp",
+    )
+    parser.add_argument(
+        "--cuda",
+        default=True,
+        type=lambda x: (str(x).lower() == "true"),
+        help="use cuda if available",
+    )
+    parser.add_argument("--lr", default=0.001, type=float, help="learning rate")
+    parser.add_argument("--dropout", default=0.5, type=float, help="dropout rate")
+    parser.add_argument("--decay", default=0.0, type=float, help="weight decay")
+    parser.add_argument(
+        "--model", default="TextBERT", type=str, help="model type (default: TextBERT)"
+    )
+    parser.add_argument("--seed", default=1, type=int, help="random seed")
+    parser.add_argument(
+        "--batch-size", default=50, type=int, help="batch size (default: 128)"
+    )
+    parser.add_argument(
+        "--epoch", default=50, type=int, help="total epochs (default: 200)"
+    )
+    parser.add_argument(
+        "--fine-tune",
+        default=True,
+        type=lambda x: (str(x).lower() == "true"),
+        help="whether to fine-tune embedding or not",
+    )
+    parser.add_argument(
+        "--method",
+        default="embed",
+        type=str,
+        help="which mixing method to use (default: none)",
+    )
+    parser.add_argument(
+        "--alpha",
+        default=1.0,
+        type=float,
+        help="mixup interpolation coefficient (default: 1)",
+    )
+    parser.add_argument(
+        "--save-path", default="out", type=str, help="output log/result directory"
+    )
+    parser.add_argument("--num-runs", default=10, type=int, help="number of runs")
+    parser.add_argument(
+        "--debug",
+        type=int,
+        default=0,
+        metavar="DB",
+        help="flag to enable debug mode for dev (default: 0)",
+    )
+
+    parser.add_argument(
+        "--samples-per-class",
+        type=int,
+        default=-1,
+        metavar="SPC",
+        help="no. of samples per class label to sample for SSL (default: 250)",
+    )
+    parser.add_argument(
+        "--n-augment",
+        type=int,
+        default=0,
+        metavar="NAUG",
+        help="number of enumeration augmentations",
+    )
+    parser.add_argument(
+        "--eval-after",
+        type=int,
+        default=10,
+        metavar="EA",
+        help="number of epochs after which model is evaluated on test set (default: 10)",
+    )
+    args = parser.parse_args()
+    return args
+
+
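+# mixup loss (Zhang et al., 2018): interpolate the cross-entropy against both
+# source labels with the same coefficient lam used to mix the inputs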
+def mixup_criterion_cross_entropy(criterion, pred, y_a, y_b, lam):
+    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
+
+
+class Classification:
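+    """Training driver: fine-tunes TextBERT on a MoleculeNet task with optional
+    mixup, early-stops on validation accuracy, and appends metrics to CSV logs."""
+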
+    def __init__(self, args):
+        self.args = args
+
+        self.use_cuda = args.cuda and torch.cuda.is_available()
+
+        # for reproducibility
+        random.seed(args.seed)
+        np.random.seed(args.seed)
+        torch.manual_seed(args.seed)
+        torch.cuda.manual_seed_all(args.seed)
+        cudnn.deterministic = True
+        cudnn.benchmark = False
+
+        # data loaders
+        data_loaders = MoleculeDataLoader(
+            dataset_name=args.dataset_name,
+            batch_size=args.batch_size,
+            debug=args.debug,
+            n_augment=args.n_augment,
+            samples_per_class=args.samples_per_class,
+            model_name_or_path=args.model_name_or_path,
+        )
+        # assumes create_supervised_loaders returns (train, val, test) iterators;
+        # they are referenced later as self.train_iterator etc.
+        (
+            self.train_iterator,
+            self.val_iterator,
+            self.test_iterator,
+        ) = data_loaders.create_supervised_loaders(
+            samples_per_class=args.samples_per_class
+        )
+        # model
+
+        self.model = TextBERT(
+            pretrained_model=args.model_name_or_path,
+            num_class=args.num_labels,
+            fine_tune=args.fine_tune,
+            dropout=args.dropout,
+        )
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+        self.model.to(self.device)
+
+        # logs
+        os.makedirs(args.save_path, exist_ok=True)
+        self.model_save_path = os.path.join(args.save_path, args.name + "_weights.pt")
+        self.log_path = os.path.join(args.save_path, args.name + "_logs.csv")
+        print(str(args))
+        with open(self.log_path, "a") as f:
+            f.write(str(args) + "\n")
+        with open(self.log_path, "a", newline="") as out:
+            writer = csv.writer(out)
+            writer.writerow(["mode", "epoch", "step", "loss", "acc"])
+
+        # optimizer
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.Adam(
+            self.model.parameters(), lr=args.lr, weight_decay=args.decay
+        )
+
+        # for early stopping
+        self.best_val_acc = 0
+        self.early_stop = False
+        # successive evaluations in which validation accuracy did not improve
+        self.val_patience = 0
+
+        self.iteration_number = 0
+
+    def get_perm(self, x):
+        """Return a random permutation of batch indices used to pick mixup partners."""
+        return torch.randperm(x.size(0), device=x.device)
+
+    def test(self, iterator):
+        self.model.eval()
+        test_loss = 0
+        total = 0
+        correct = 0
+        with torch.no_grad():
+            # for _, batch in tqdm(enumerate(iterator), total=len(iterator), desc='test'):
+            for _, batch in enumerate(iterator):
+                batch = tuple(t.to(self.device) for t in batch)
+                b_input_ids, b_input_mask, b_labels = batch
+                y_pred = self.model(b_input_ids, b_input_mask)
+                loss = self.criterion(y_pred, b_labels)
+                test_loss += loss.item() * b_labels.shape[0]
+                total += b_labels.shape[0]
+                correct += torch.sum(torch.argmax(y_pred, dim=1) == b_labels).item()
+
+        avg_loss = test_loss / total
+        acc = 100.0 * correct / total
+        return avg_loss, acc
+
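+    def train(self, epoch):
+        # plain supervised baseline used when --method is "none"; a minimal
+        # sketch mirroring train_mixup with the standard cross-entropy loss
+        self.model.train()
+        train_loss, total, correct = 0, 0, 0
+        for _, batch in enumerate(self.train_iterator):
+            batch = tuple(t.to(self.device) for t in batch)
+            b_input_ids, b_input_mask, b_labels = batch
+            y_pred = self.model(b_input_ids, b_input_mask)
+            loss = self.criterion(y_pred, b_labels)
+            train_loss += loss.item() * b_labels.shape[0]
+            total += b_labels.shape[0]
+            correct += torch.sum(torch.argmax(y_pred, dim=1) == b_labels).item()
+
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+            self.iteration_number += 1
+            if self.iteration_number % self.args.eval_after == 0:
+                avg_loss = train_loss / total
+                acc = 100.0 * correct / total
+                train_loss, total, correct = 0, 0, 0
+                val_loss, val_acc = self.test(iterator=self.val_iterator)
+                if val_acc > self.best_val_acc:
+                    torch.save(self.model.state_dict(), self.model_save_path)
+                    self.best_val_acc = val_acc
+                    self.val_patience = 0
+                else:
+                    self.val_patience += 1
+                    if self.val_patience == self.args.patience:
+                        self.early_stop = True
+                        return
+                with open(self.log_path, "a", newline="") as out:
+                    writer = csv.writer(out)
+                    writer.writerow(
+                        ["train", epoch, self.iteration_number, avg_loss, acc]
+                    )
+                    writer.writerow(
+                        ["val", epoch, self.iteration_number, val_loss, val_acc]
+                    )
+                self.model.train()
+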
+    def train_mixup(self, epoch):
+        self.model.train()
+        train_loss = 0
+        total = 0
+        correct = 0
+        for _, batch in enumerate(self.train_iterator):
+            batch = tuple(t.to(self.device) for t in batch)
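+            # sample the mixing coefficient lam ~ Beta(alpha, alpha);
+            # alpha = 1.0 makes lam uniform on [0, 1]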
+            lam = np.random.beta(self.args.alpha, self.args.alpha)
+            b_input_ids, b_input_mask, b_labels = batch
+            # permute the batch to obtain a mixing partner for each example
+            index = self.get_perm(b_input_ids)
+            b_input_ids1 = b_input_ids[index]
+            b_input_mask1 = b_input_mask[index]
+            b_labels1 = b_labels[index]
+
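+            # the three variants mix at different depths: token embeddings
+            # ("embed"), the pooled sentence representation ("sent"), or an
+            # intermediate encoder layer ("encoder"); see TextBERT.forward_mix_*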
+            if self.args.method == "embed":
+                y_pred = self.model.forward_mix_embed(
+                    b_input_ids, b_input_mask, b_input_ids1, b_input_mask1, lam
+                )
+            elif self.args.method == "sent":
+                y_pred = self.model.forward_mix_sent(
+                    b_input_ids, b_input_mask, b_input_ids1, b_input_mask1, lam
+                )
+            elif self.args.method == "encoder":
+                y_pred = self.model.forward_mix_encoder(
+                    b_input_ids, b_input_mask, b_input_ids1, b_input_mask1, lam
+                )
+            else:
+                raise ValueError("invalid method name")
+
+            loss = mixup_criterion_cross_entropy(
+                self.criterion, y_pred, b_labels, b_labels1, lam
+            )
+            train_loss += loss.item() * b_labels.shape[0]
+            total += b_labels.shape[0]
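+            # soft mixup accuracy: a prediction gets lam credit for matching the
+            # first label and (1 - lam) credit for matching its partner's label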
+            _, predicted = torch.max(y_pred.data, 1)
+            correct += (
+                (
+                    lam * predicted.eq(b_labels.data).cpu().sum().float()
+                    + (1 - lam) * predicted.eq(b_labels1.data).cpu().sum().float()
+                )
+            ).item()
+
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+            # eval
+            self.iteration_number += 1
+            if self.iteration_number % self.args.eval_after == 0:
+                avg_loss = train_loss / total
+                acc = 100.0 * correct / total
+                # print('Train loss: {}, Train acc: {}'.format(avg_loss, acc))
+                train_loss = 0
+                total = 0
+                correct = 0
+
+                val_loss, val_acc = self.test(iterator=self.val_iterator)
+                # print('Val loss: {}, Val acc: {}'.format(val_loss, val_acc))
+                if val_acc > self.best_val_acc:
+                    torch.save(self.model.state_dict(), self.model_save_path)
+                    self.best_val_acc = val_acc
+                    self.val_patience = 0
+                else:
+                    self.val_patience += 1
+                    if self.val_patience == self.args.patience:
+                        self.early_stop = True
+                        return
+                with open(self.log_path, "a", newline="") as out:
+                    writer = csv.writer(out)
+                    writer.writerow(
+                        ["train", epoch, self.iteration_number, avg_loss, acc]
+                    )
+                    writer.writerow(
+                        ["val", epoch, self.iteration_number, val_loss, val_acc]
+                    )
+                self.model.train()
+
+    def run(self):
+        for epoch in range(self.args.epoch):
+            print(
+                "------------------------------------- Epoch {} -------------------------------------".format(
+                    epoch
+                )
+            )
+            if self.args.method == "none":
+                self.train(epoch)
+            else:
+                self.train_mixup(epoch)
+            if self.early_stop:
+                break
+        print("Training complete!")
+        print("Best Validation Acc: ", self.best_val_acc)
+
+        self.model.load_state_dict(
+            torch.load(self.model_save_path, map_location=self.device)
+        )
+        # train_loss, train_acc = self.test(self.train_iterator)
+        val_loss, val_acc = self.test(self.val_iterator)
+        test_loss, test_acc = self.test(self.test_iterator)
+
+        with open(self.log_path, "a", newline="") as out:
+            writer = csv.writer(out)
+            # writer.writerow(['train', -1, -1, train_loss, train_acc])
+            writer.writerow(["val", -1, -1, val_loss, val_acc])
+            writer.writerow(["test", -1, -1, test_loss, test_acc])
+
+        # print('Train loss: {}, Train acc: {}'.format(train_loss, train_acc))
+        print("Val loss: {}, Val acc: {}".format(val_loss, val_acc))
+        print("Test loss: {}, Test acc: {}".format(test_loss, test_acc))
+
+        return val_acc, test_acc
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    num_runs = args.num_runs
+
+    test_acc = []
+    val_acc = []
+
+    for i in range(num_runs):
+        cls = Classification(args)
+        val, test = cls.run()
+        val_acc.append(val)
+        test_acc.append(test)
+        args.seed += 1
+
+    with open(os.path.join(args.save_path, args.name + "_result.txt"), "a") as f:
+        f.write(str(args) + "\n")
+        f.write("val acc:" + str(val_acc) + "\n")
+        f.write("test acc:" + str(test_acc) + "\n")
+        f.write("mean val acc:" + str(np.mean(val_acc)) + "\n")
+        f.write("std val acc:" + str(np.std(val_acc, ddof=1)) + "\n")
+        f.write("mean test acc:" + str(np.mean(test_acc)) + "\n")
+        f.write("std test acc:" + str(np.std(test_acc, ddof=1)) + "\n\n\n")