# File: src/codellama-main/example_infilling.py
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
2
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
4
import fire
5
6
from llama import Llama
7
8
9
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.0,
    top_p: float = 0.9,
    max_seq_len: int = 192,
    max_gen_len: int = 128,
    max_batch_size: int = 4,
):
    """Run the infilling demo on a fixed set of example prompts.

    Builds a generator with ``Llama.build``, splits each hard-coded prompt
    around its ``<FILL>`` marker into a prefix and a suffix, asks the
    generator to infill the gap, and prints every prompt followed by the
    ``"full_text"`` entry of its result.

    Args:
        ckpt_dir: Forwarded to ``Llama.build`` as ``ckpt_dir``.
        tokenizer_path: Forwarded to ``Llama.build`` as ``tokenizer_path``.
        temperature: Forwarded to ``generator.text_infilling``.
        top_p: Forwarded to ``generator.text_infilling``.
        max_seq_len: Forwarded to ``Llama.build``.
        max_gen_len: Forwarded to ``generator.text_infilling``.
        max_batch_size: Forwarded to ``Llama.build``.
            NOTE(review): four prompts are submitted in a single call, so this
            presumably needs to be at least 4 -- confirm against ``Llama``.
    """
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    # Each prompt contains exactly one "<FILL>" marker at the spot the model
    # should complete. The literals are deliberately not dedented: their
    # whitespace is part of the text handed to the model.
    prompts = [
        '''def remove_non_ascii(s: str) -> str:
    """ <FILL>
    return result
''',
        """# Installation instructions:
    ```bash
<FILL>
    ```
This downloads the LLaMA inference code and installs the repository as a local pip package.
""",
        """class InterfaceManagerFactory(AbstractManagerFactory):
    def __init__(<FILL>
def main():
    factory = InterfaceManagerFactory(start=datetime.now())
    managers = []
    for i in range(10):
        managers.append(factory.build(id=i))
""",
        """/-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/
theorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :
  π₁ P = 0 ↔ <FILL> = 0 :=
begin
  split,
  { intros h f,
    rw pi_1_etalisation at h,
    simp [h],
    refl
  },
  { intro h,
    have := @quasi_adjoint C D P,
    simp [←pi_1_etalisation, this, h],
    refl
  }
end
""",
    ]

    # Split every prompt once and reuse the pieces for both lists: index 0 is
    # the text before the marker, index 1 the text after it.
    pieces = [prompt.split("<FILL>") for prompt in prompts]
    prefixes = [piece[0] for piece in pieces]
    suffixes = [piece[1] for piece in pieces]

    results = generator.text_infilling(
        prefixes=prefixes,
        suffixes=suffixes,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )

    for prompt, result in zip(prompts, results):
        print("\n================= Prompt text =================\n")
        print(prompt)
        print("\n================= Filled text =================\n")
        print(result["full_text"])
76
77
78
if __name__ == "__main__":
    # Hand main() to python-fire so its parameters can be supplied on the
    # command line when the file is run as a script.
    fire.Fire(main)