--- a +++ b/docproduct/train_gpt2.py @@ -0,0 +1,96 @@ +import os +import json +from shutil import copyfile + +import tensorflow as tf +# import tensorflow.compat.v1 as tf +import tensorflow_estimator as tf_estimator + +import gpt2_estimator + +from docproduct.mqa_load_dataset import Sampler, load_dataset + +DEVICE = ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"] + + +def train_gpt2( + model_dir='models/gpt2', + pretrained_path='models/117M', + steps=100000, + batch_size=1, + max_seq_len=1024, + num_gpu=3, + learning_rate=0.0001): + """Function to train the GPT2 model + + For each question, we use topk qa pairs that retreived by FAISS and the question + as features, and correct answer as target to train GPT2. + + Data: my eyes hurt, go see a doctor + Feature: + question: aaa, answer: bbb, question: ccc, answer: ddd, question: my eyes hurt, answer: + Target: + go see a doctor + + + Keyword Arguments: + model_dir {str} -- Path to save the GPT2 model (default: {'models/gpt2'}) + pretrained_path {str} -- Pretrained GPT2 model path, + usually the output file of train_embedding_to_gpt2_data (default: {'models/117M'}) + steps {int} -- Number of steps of training (default: {100000}) + batch_size {int} -- Batch size per GPU (default: {4}) + num_gpu {int} -- Number of GPU to use (default: {4}) + learning_rate {float} -- Learning rate (default: {0.0001}) + """ + os.makedirs(model_dir, exist_ok=True) + + tf.compat.v1.disable_eager_execution() + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG) + mirrored_strategy = tf.distribute.MirroredStrategy( + devices=DEVICE[:num_gpu]) + learning_rate = learning_rate*num_gpu + session_config = tf.compat.v1.ConfigProto( + allow_soft_placement=True) + session_config.gpu_options.allow_growth = False + config = tf_estimator.estimator.RunConfig( + session_config=session_config, + train_distribute=mirrored_strategy, + eval_distribute=mirrored_strategy, + log_step_count_steps=50) + + gpt2_model_fn = gpt2_estimator.get_gpt2_model_fn( + accumulate_gradients=3, + learning_rate=learning_rate, + length=max_seq_len, + batch_size=batch_size, + temperature=0.7, + top_k=1 + ) + copyfile(os.path.join(pretrained_path, 'hparams.json'), + os.path.join(model_dir, 'hparams.json')) + copyfile(os.path.join(pretrained_path, 'vocab.bpe'), + os.path.join(model_dir, 'vocab.bpe')) + copyfile(os.path.join(pretrained_path, 'encoder.json'), + os.path.join(model_dir, 'encoder.json')) + hparams = gpt2_estimator.default_hparams() + with open(os.path.join(pretrained_path, 'hparams.json')) as f: + hparams.override_from_dict(json.load(f)) + estimator = tf_estimator.estimator.Estimator( + gpt2_model_fn, + model_dir=model_dir, + params=hparams, + config=config) + + restore_hook = gpt2_estimator.RestoreCheckpointHook(pretrained_path) + estimator.train( + lambda: gpt2_estimator.train_input_fn(batch_size=batch_size, dataset_load_fn=load_dataset, sampler=Sampler, max_seq_len=max_seq_len), max_steps=steps, hooks=[restore_hook]) + + # keep as an example + # pred = estimator.predict( + # lambda: gpt2_estimator.predict_input_fn( + # 'i am sick', batch_size=batch_size) + # ) + + +if __name__ == "__main__": + train_gpt2(steps=5000000)