src/codellama-main/llama/generation.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
import sys
import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str


class InfillingPrediction(TypedDict, total=False):
    generation: str
    full_text: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


Dialog = List[Message]

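# Llama 2 chat prompt formatting: each user turn is wrapped in [INST] ... [/INST],
# and an optional system message is folded between <<SYS>> ... <</SYS>> inside the
# first user turn (see `chat_completion` below).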
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."


class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
    ) -> "Llama":
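        """
        Load a Code Llama checkpoint and return a ready-to-use `Llama` instance.

        Sets up torch.distributed (NCCL on CUDA, gloo otherwise) and model
        parallelism if they are not already initialized, loads the checkpoint
        shard for this model-parallel rank, reads `params.json`, builds the
        tokenizer, and instantiates the Transformer on the selected `device`.
        """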
        if not torch.distributed.is_initialized():
            if device == "cuda":
                torch.distributed.init_process_group("nccl")
            else:
                torch.distributed.init_process_group("gloo")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if device == "cuda":
            torch.cuda.set_device(local_rank)

        # seed must be the same in all processes
        torch.manual_seed(1)

        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words
        # support for mac
        if device == "cuda":
            if torch.cuda.is_bf16_supported():
                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
            else:
                torch.set_default_tensor_type(torch.cuda.HalfTensor)
        # else:
        #     torch.set_default_tensor_type(torch.HalfTensor)
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        model.to(device)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama(model, tokenizer)

    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
        stop_token: Optional[int] = None,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
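        """
        Batched autoregressive decoding.

        Starting from `prompt_tokens`, samples up to `max_gen_len` new tokens
        per sequence using nucleus (top-p) sampling when `temperature > 0`, or
        greedy argmax otherwise. Returns the generated token ids (including the
        prompt when `echo` is set) and, if `logprobs` is requested, per-token
        log-probabilities. Generation stops early once every sequence has
        produced `stop_token` (the tokenizer's EOS id by default).
        """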
        if stop_token is None:
            stop_token = self.tokenizer.eos_id
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=device)

        prev_pos = 0
        stop_reached = torch.tensor([False] * bsz, device=device)
        input_text_mask = tokens != pad_id
        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
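            # When log-probabilities are requested, score the just-processed
            # positions: the negative cross-entropy of the logits against the
            # next tokens (shifted by one), with pad positions ignored.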
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            stop_reached |= (~input_text_mask[:, cur_pos]) & (next_token == stop_token)
            prev_pos = cur_pos
            if all(stop_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to stop token if present
            if stop_token in toks:
                stop_idx = toks.index(stop_token)
                toks = toks[:stop_idx]
                probs = probs[:stop_idx] if logprobs else None
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
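        """
        Generate text completions for a batch of prompt strings.

        Each prompt is tokenized with a leading BOS token, run through
        `generate`, and decoded back to text. When `logprobs` is set, the
        per-token strings and log-probabilities are returned as well.
        """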
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )
        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

    def text_infilling(
        self,
        prefixes: List[str],
        suffixes: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        suffix_first: bool = False,
    ) -> List[InfillingPrediction]:
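        """
        Fill in the middle between each (prefix, suffix) pair.

        Prompts are built with the infilling control tokens (see
        `infilling_prompt_tokens` below) and generation stops at the
        tokenizer's `eot_id`. Each result contains the decoded middle plus
        `full_text`, i.e. prefix + generation + suffix stitched together.
        """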
        assert self.tokenizer.eot_id is not None
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = [
            infilling_prompt_tokens(
                self.tokenizer, prefix, suffix, suffix_first=suffix_first
            )
            for prefix, suffix in zip(prefixes, suffixes)
        ]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=False,
            stop_token=self.tokenizer.eot_id,
        )

        generations = [self.tokenizer.decode_infilling(t) for t in generation_tokens]

        if logprobs:
            return [
                {
                    "generation": generation,
                    "logprobs": logprobs_i,
                    "tokens": t,
                    "full_text": prefix + generation + suffix,
                }
                for prefix, suffix, generation, t, logprobs_i in zip(
                    prefixes,
                    suffixes,
                    generations,
                    generation_tokens,
                    generation_logprobs,
                )
            ]
        else:
            return [
                {
                    "generation": generation,
                    "full_text": prefix + generation + suffix,
                }
                for prefix, suffix, generation in zip(prefixes, suffixes, generations)
            ]

    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
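        """
        Generate an assistant reply for each dialog in the batch.

        Dialogs must alternate user/assistant messages and end with a user
        message; an optional leading system message is folded into the first
        user turn between the <<SYS>> tags. Dialogs containing the special
        formatting tags are answered with `UNSAFE_ERROR` instead of model
        output.
        """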
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = []
        unsafe_requests = []
        for dialog in dialogs:
            unsafe_requests.append(
                any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog])
            )
            if dialog[0]["role"] == "system":
                dialog = [
                    {
                        "role": dialog[1]["role"],
                        "content": B_SYS
                        + dialog[0]["content"]
                        + E_SYS
                        + dialog[1]["content"],
                    }
                ] + dialog[2:]
            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
                [msg["role"] == "assistant" for msg in dialog[1::2]]
            ), (
                "model only supports 'system', 'user' and 'assistant' roles, "
                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
            )
            dialog_tokens: List[int] = sum(
                [
                    self.tokenizer.encode(
                        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                        bos=True,
                        eos=True,
                    )
                    for prompt, answer in zip(
                        dialog[::2],
                        dialog[1::2],
                    )
                ],
                [],
            )
            assert (
                dialog[-1]["role"] == "user"
            ), f"Last message must be from user, got {dialog[-1]['role']}"
            dialog_tokens += self.tokenizer.encode(
                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
                bos=True,
                eos=False,
            )
            prompt_tokens.append(dialog_tokens)

        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t)
                        if not unsafe
                        else UNSAFE_ERROR,
                    },
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i, unsafe in zip(
                    generation_tokens, generation_logprobs, unsafe_requests
                )
            ]
        return [
            {
                "generation": {
                    "role": "assistant",
                    "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR,
                }
            }
            for t, unsafe in zip(generation_tokens, unsafe_requests)
        ]


def sample_top_p(probs, p):
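    """
    Nucleus (top-p) sampling.

    Sorts the probabilities in descending order, keeps the smallest set of
    tokens whose cumulative mass reaches `p`, renormalizes that set, and draws
    one token from it. Returns sampled token indices in the original
    vocabulary order.
    """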
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


def infilling_prompt_tokens(
    tokenizer: Tokenizer,
    pre: str,
    suf: str,
    suffix_first: bool = False,
) -> List[int]:
    """
    Format and encode an infilling problem.
    If `suffix_first` is set, format in suffix-prefix-middle format.
    """
    assert tokenizer.prefix_id is not None
    assert tokenizer.middle_id is not None
    assert tokenizer.suffix_id is not None
    if suffix_first:
        # format as "<PRE> <SUF>{suf} <MID> {pre}"
        return (
            [tokenizer.bos_id, tokenizer.prefix_id, tokenizer.suffix_id]
            + tokenizer.encode_infilling(suf)
            + [tokenizer.middle_id]
            + tokenizer.encode(pre, bos=False, eos=False)
        )
    else:
        # format as "<PRE> {pre} <SUF>{suf} <MID>"
        return (
            [tokenizer.bos_id, tokenizer.prefix_id]
            + tokenizer.encode(pre, bos=False, eos=False)
            + [tokenizer.suffix_id]
            + tokenizer.encode_infilling(suf)
            + [tokenizer.middle_id]
        )
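

if __name__ == "__main__":
    # Illustrative usage sketch: the checkpoint directory and tokenizer path
    # below are placeholders, and this module is normally driven through
    # torchrun so that WORLD_SIZE / LOCAL_RANK are set for the distributed
    # and model-parallel setup in `Llama.build`.
    generator = Llama.build(
        ckpt_dir="CodeLlama-7b/",  # placeholder path to the downloaded weights
        tokenizer_path="CodeLlama-7b/tokenizer.model",  # placeholder tokenizer path
        max_seq_len=512,
        max_batch_size=1,
    )
    results = generator.text_completion(
        ["def fibonacci(n):"],
        max_gen_len=64,
        temperature=0.2,
        top_p=0.9,
    )
    print(results[0]["generation"])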