--- a +++ b/src/codellama-main/example_completion.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from typing import Optional + +import fire + +from llama import Llama + + +def main( + ckpt_dir: str, + tokenizer_path: str, + temperature: float = 0.2, + top_p: float = 0.9, + max_seq_len: int = 256, + max_batch_size: int = 4, + max_gen_len: Optional[int] = None, +): + generator = Llama.build( + ckpt_dir=ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + ) + + prompts = [ + # For these prompts, the expected answer is the natural continuation of the prompt + """\ +import socket + +def ping_exponential_backoff(host: str):""", + """\ +import argparse + +def main(string: str): + print(string) + print(string[::-1]) + +if __name__ == "__main__":""" + ] + results = generator.text_completion( + prompts, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + for prompt, result in zip(prompts, results): + print(prompt) + print(f"> {result['generation']}") + print("\n==================================\n") + + +if __name__ == "__main__": + fire.Fire(main) +