# src/codellama-main/example_completion.py
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
2
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
4
from typing import Optional
5
6
import fire
7
8
from llama import Llama
9
10
11
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.2,
    top_p: float = 0.9,
    max_seq_len: int = 256,
    max_batch_size: int = 4,
    max_gen_len: Optional[int] = None,
):
    """Run text completion on two built-in Python prompts and print the results.

    A generator is built with ``Llama.build``; both prompts are then sent in a
    single ``text_completion`` call. Each prompt is printed, followed by the
    generated continuation prefixed with ``"> "`` and a separator line.

    Args:
        ckpt_dir: Checkpoint directory, forwarded to ``Llama.build``.
        tokenizer_path: Tokenizer path, forwarded to ``Llama.build``.
        temperature: Sampling temperature, forwarded to ``text_completion``.
        top_p: Top-p value, forwarded to ``text_completion``.
        max_seq_len: Maximum sequence length, forwarded to ``Llama.build``.
        max_batch_size: Maximum batch size, forwarded to ``Llama.build``.
            NOTE(review): presumably this must be at least the number of
            prompts sent below (2) — confirm in ``llama``'s generator.
        max_gen_len: Generation length limit, forwarded to
            ``text_completion``. ``None`` is passed through unchanged;
            presumably the library then picks its own limit — confirm.
    """
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    # The backslash right after the opening triple quote suppresses the
    # leading newline, so each prompt starts directly with its first line.
    prompts = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        """\
import socket

def ping_exponential_backoff(host: str):""",
        """\
import argparse

def main(string: str):
    print(string)
    print(string[::-1])

if __name__ == "__main__":""",
    ]
    results = generator.text_completion(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
    # Each result is read as a mapping with a "generation" key.
    for prompt, result in zip(prompts, results):
        print(prompt)
        print(f"> {result['generation']}")
        print("\n==================================\n")
52
53
54
if __name__ == "__main__":
    # Expose main()'s parameters as command-line arguments via Python Fire.
    fire.Fire(main)
56