[d3ab9c]: / bert_mixup / early_mixup / args_parser.py

Download this file

111 lines (109 with data), 3.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
def _str_to_bool(value):
    """Interpret a command-line string as a boolean flag value.

    Only the case-insensitive string ``"true"`` maps to ``True``; every other
    value (including ``"1"`` or ``"yes"``) maps to ``False``.
    """
    return str(value).lower() == "true"


def parse_args(argv=None):
    """Build the argument parser for the mixup experiments and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which matches the
            original zero-argument behaviour.

    Returns:
        argparse.Namespace holding the parsed options. Dashes in option names
        become underscores (e.g. ``--batch-size`` -> ``args.batch_size``).
    """
    parser = argparse.ArgumentParser(description="Mixup for text classification")
    parser.add_argument(
        "--name", default="cnn-text-fine-tune", type=str, help="name of the experiment"
    )
    parser.add_argument(
        "--num-labels",
        type=int,
        default=2,
        metavar="L",
        help="number of labels of the train dataset (default: %(default)s)",
    )
    parser.add_argument(
        "--model-name-or-path",
        type=str,
        default="shahrukhx01/smole-bert",
        metavar="M",
        help="name of the pre-trained transformer model from hf hub",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="bace",
        metavar="D",
        help="name of the molecule net dataset (default: %(default)s; available: bace, bbbp)",
    )
    # Boolean options take an explicit value, e.g. `--cuda false`; see
    # _str_to_bool for how the string is interpreted.
    parser.add_argument(
        "--cuda",
        default=True,
        type=_str_to_bool,
        help="use cuda if available",
    )
    parser.add_argument("--lr", default=0.001, type=float, help="learning rate")
    parser.add_argument("--dropout", default=0.5, type=float, help="dropout rate")
    parser.add_argument("--decay", default=0.0, type=float, help="weight decay")
    parser.add_argument("--seed", default=1, type=int, help="random seed")
    # Help strings use %(default)s so the documented default can never drift
    # from the actual `default=` value again.
    parser.add_argument(
        "--batch-size", default=50, type=int, help="batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--epoch", default=20, type=int, help="total epochs (default: %(default)s)"
    )
    parser.add_argument(
        "--fine-tune",
        default=True,
        type=_str_to_bool,
        help="whether to fine-tune embedding or not",
    )
    parser.add_argument(
        "--method",
        default="embed",
        type=str,
        help="which mixing method to use (default: %(default)s)",
    )
    parser.add_argument(
        "--alpha",
        default=1.0,
        type=float,
        help="mixup interpolation coefficient (default: %(default)s)",
    )
    parser.add_argument(
        "--save-path", default="out", type=str, help="output log/result directory"
    )
    parser.add_argument("--num-runs", default=1, type=int, help="number of runs")
    parser.add_argument(
        "--debug",
        type=int,
        default=0,
        metavar="DB",
        help="flag to enable debug mode for dev (default: %(default)s)",
    )
    # NOTE(review): -1 presumably means "use every sample"; confirm against
    # the data-loading code that consumes samples_per_class.
    parser.add_argument(
        "--samples-per-class",
        type=int,
        default=-1,
        metavar="SPC",
        help="no. of samples per class label to sample for SSL (default: %(default)s)",
    )
    parser.add_argument(
        "--n-augment",
        type=int,
        default=0,
        metavar="NAUG",
        help="number of enumeration augmentations",
    )
    parser.add_argument(
        "--eval-after",
        type=int,
        default=10,
        metavar="EA",
        help="number of epochs after which model is evaluated on test set (default: %(default)s)",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=10,
        metavar="PAT",
        help="Patience epochs when doing model selection if the model does not improve",
    )
    parser.add_argument(
        "--out-file",
        type=str,
        default="eval_result.csv",
        metavar="OF",
        help="output file for logging metrics",
    )
    return parser.parse_args(argv)