src/codellama-main/example_infilling.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import fire

from llama import Llama


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.0,
    top_p: float = 0.9,
    max_seq_len: int = 192,
    max_gen_len: int = 128,
    max_batch_size: int = 4,
):
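    # Load the model once up front: Llama.build reads the checkpoint and
    # tokenizer from disk and fixes the sequence-length and batch-size limits
    # used during generation.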
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

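    # Four infilling prompts: a Python docstring, a shell snippet, a class
    # definition, and a Lean proof. The <FILL> marker denotes the span the
    # model should generate between the surrounding prefix and suffix.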
    prompts = [
        '''def remove_non_ascii(s: str) -> str:
    """ <FILL>
    return result
''',
        """# Installation instructions:
    ```bash
<FILL>
    ```
This downloads the LLaMA inference code and installs the repository as a local pip package.
""",
        """class InterfaceManagerFactory(AbstractManagerFactory):
    def __init__(<FILL>
def main():
    factory = InterfaceManagerFactory(start=datetime.now())
    managers = []
    for i in range(10):
        managers.append(factory.build(id=i))
""",
        """/-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/
theorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :
π₁ P = 0 ↔ <FILL> = 0 :=
begin
split,
{ intros h f,
    rw pi_1_etalisation at h,
    simp [h],
    refl
},
{ intro h,
    have := @quasi_adjoint C D P,
    simp [←pi_1_etalisation, this, h],
    refl
}
end
""",
    ]
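    # Split each prompt at the placeholder: the text before <FILL> becomes the
    # prefix, the text after it becomes the suffix.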
    prefixes = [p.split("<FILL>")[0] for p in prompts]
    suffixes = [p.split("<FILL>")[1] for p in prompts]
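    # Generate the missing middle for every (prefix, suffix) pair in a single
    # batched call. With the default temperature of 0.0 decoding is greedy,
    # so top_p only takes effect when temperature > 0.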
    results = generator.text_infilling(
        prefixes=prefixes,
        suffixes=suffixes,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
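    # Print each original prompt alongside the completed text so the fills
    # are easy to inspect side by side.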
    for prompt, result in zip(prompts, results):
        print("\n================= Prompt text =================\n")
        print(prompt)
        print("\n================= Filled text =================\n")
        print(result["full_text"])


if __name__ == "__main__":
    fire.Fire(main)
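
# Example launch (a sketch; the checkpoint and tokenizer paths below are
# placeholders to adjust for your local download):
#   torchrun --nproc_per_node 1 example_infilling.py \
#       --ckpt_dir CodeLlama-7b/ \
#       --tokenizer_path CodeLlama-7b/tokenizer.model \
#       --max_seq_len 192 --max_batch_size 4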