# src/llama-main/llama/generation.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
import sys
import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class ChatPrediction(TypedDict, total=False):
    generation: Message
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


Dialog = List[Message]

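# A Dialog is an ordered list of messages, e.g. (illustrative values):
#   [
#       {"role": "system", "content": "Always answer concisely."},
#       {"role": "user", "content": "What is the capital of France?"},
#   ]
# Roles must alternate user/assistant, optionally preceded by a single system
# message (enforced in chat_completion below).
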
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."

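# Chat prompt format used by chat_completion below: each completed (user, assistant)
# turn is encoded as "[INST] user [/INST] assistant" between BOS and EOS tokens, the
# trailing user message is encoded as "[INST] user [/INST]" with BOS only, and an
# optional system message is folded into the first user turn between <<SYS>> and
# <</SYS>>. If a dialog contains any of SPECIAL_TAGS, the returned content is
# replaced with UNSAFE_ERROR.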

class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
        seed: int = 1,
    ) -> "Llama":
        """
        Build a Llama instance by initializing and loading a pre-trained model.

        Args:
            ckpt_dir (str): Path to the directory containing checkpoint files.
            tokenizer_path (str): Path to the tokenizer file.
            max_seq_len (int): Maximum sequence length for input text.
            max_batch_size (int): Maximum batch size for inference.
            model_parallel_size (Optional[int], optional): Number of model parallel processes.
                If not provided, it's determined from the environment. Defaults to None.
            seed (int, optional): Random seed; must be identical across all processes.
                Defaults to 1.

        Returns:
            Llama: An instance of the Llama class with the loaded model and tokenizer.

        Raises:
            AssertionError: If there are no checkpoint files in the specified directory,
                or if the model parallel size does not match the number of checkpoint files.

        Note:
            This method initializes the distributed process group, sets the device to CUDA,
            and loads the pre-trained model and tokenizer.

        """
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)

        # seed must be the same in all processes
        torch.manual_seed(seed)

        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        model_args.vocab_size = tokenizer.n_words
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama(model, tokenizer)

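    # Illustrative usage (a sketch, not part of the library; paths and the nproc value
    # are placeholders). build() expects a torch.distributed environment, e.g. launched
    # with torchrun:
    #
    #   torchrun --nproc_per_node 1 your_script.py
    #
    # where your_script.py does something like:
    #
    #   generator = Llama.build(
    #       ckpt_dir="llama-2-7b/",
    #       tokenizer_path="tokenizer.model",
    #       max_seq_len=512,
    #       max_batch_size=4,
    #   )
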
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
        """
        Generate text sequences based on provided prompts using the language generation model.

        Args:
            prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
            max_gen_len (int): Maximum length of the generated text sequence.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.

        Note:
            This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.

        """
        params = self.model.params
        bsz = len(prompt_tokens)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= params.max_seq_len
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        input_text_mask = tokens != pad_id
        if min_prompt_len == total_len:
            logits = self.model.forward(tokens, prev_pos)
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

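        # Decode one position at a time. Each forward pass only receives the slice
        # tokens[:, prev_pos:cur_pos] plus the starting position, on the assumption
        # that the Transformer caches keys/values for earlier positions (see
        # llama.model). -F.cross_entropy with reduction="none" is used below as the
        # per-token log-probability of the chosen target tokens.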
        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
            )
            prev_pos = cur_pos
            if all(eos_reached):
                break

        if logprobs:
            token_logprobs = token_logprobs.tolist()
        out_tokens, out_logprobs = [], []
        for i, toks in enumerate(tokens.tolist()):
            # cut to max gen len
            start = 0 if echo else len(prompt_tokens[i])
            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
            probs = None
            if logprobs:
                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            if self.tokenizer.eos_id in toks:
                eos_idx = toks.index(self.tokenizer.eos_id)
                toks = toks[:eos_idx]
                probs = probs[:eos_idx] if logprobs else None
            out_tokens.append(toks)
            out_logprobs.append(probs)
        return (out_tokens, out_logprobs if logprobs else None)

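    # Illustrative direct call (a sketch; `llama` is assumed to be an instance returned
    # by Llama.build; most callers should use text_completion or chat_completion below,
    # which handle tokenization):
    #
    #   prompt_tokens = [llama.tokenizer.encode("Hello, my name is", bos=True, eos=False)]
    #   out_tokens, _ = llama.generate(prompt_tokens=prompt_tokens, max_gen_len=32)
    #   print(llama.tokenizer.decode(out_tokens[0]))
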
    def text_completion(
        self,
        prompts: List[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> List[CompletionPrediction]:
        """
        Perform text completion for a list of prompts using the language generation model.

        Args:
            prompts (List[str]): List of text prompts for completion.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.

        Note:
            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
            If logprobs is True, token log probabilities are computed for each generated token.

        """
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )
        if logprobs:
            return [
                {
                    "generation": self.tokenizer.decode(t),
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
            ]
        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

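    # Illustrative usage (a sketch; `llama` is assumed to be the instance returned by
    # Llama.build, and the prompt text is a placeholder):
    #
    #   results = llama.text_completion(
    #       ["I believe the meaning of life is"],
    #       max_gen_len=64,
    #       temperature=0.6,
    #       top_p=0.9,
    #   )
    #   print(results[0]["generation"])
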
    def chat_completion(
        self,
        dialogs: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> List[ChatPrediction]:
        """
        Generate assistant responses for a list of conversational dialogs using the language generation model.

        Args:
            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
                If not provided, it's set to the model's maximum sequence length minus 1.
            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.

        Returns:
            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.

        Raises:
            AssertionError: If the last message in a dialog is not from the user.
            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.

        Note:
            This method generates assistant responses for the provided conversational dialogs.
            It employs nucleus sampling to introduce controlled randomness in text generation.
            If logprobs is True, token log probabilities are computed for each generated token.

        """
        if max_gen_len is None:
            max_gen_len = self.model.params.max_seq_len - 1
        prompt_tokens = []
        unsafe_requests = []
        for dialog in dialogs:
            unsafe_requests.append(
                any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog])
            )
            if dialog[0]["role"] == "system":
                dialog = [
                    {
                        "role": dialog[1]["role"],
                        "content": B_SYS
                        + dialog[0]["content"]
                        + E_SYS
                        + dialog[1]["content"],
                    }
                ] + dialog[2:]
            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
                [msg["role"] == "assistant" for msg in dialog[1::2]]
            ), (
                "model only supports 'system', 'user' and 'assistant' roles, "
                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
            )
            dialog_tokens: List[int] = sum(
                [
                    self.tokenizer.encode(
                        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                        bos=True,
                        eos=True,
                    )
                    for prompt, answer in zip(
                        dialog[::2],
                        dialog[1::2],
                    )
                ],
                [],
            )
            assert (
                dialog[-1]["role"] == "user"
            ), f"Last message must be from user, got {dialog[-1]['role']}"
            dialog_tokens += self.tokenizer.encode(
                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
                bos=True,
                eos=False,
            )
            prompt_tokens.append(dialog_tokens)

        generation_tokens, generation_logprobs = self.generate(
            prompt_tokens=prompt_tokens,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
        )
        if logprobs:
            return [
                {
                    "generation": {
                        "role": "assistant",
                        "content": self.tokenizer.decode(t)
                        if not unsafe
                        else UNSAFE_ERROR,
                    },
                    "tokens": [self.tokenizer.decode(x) for x in t],
                    "logprobs": logprobs_i,
                }
                for t, logprobs_i, unsafe in zip(
                    generation_tokens, generation_logprobs, unsafe_requests
                )
            ]
        return [
            {
                "generation": {
                    "role": "assistant",
                    "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR,
                }
            }
            for t, unsafe in zip(generation_tokens, unsafe_requests)
        ]
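
    # Illustrative usage (a sketch; `llama` is assumed to be the instance returned by
    # Llama.build, and the messages are placeholders):
    #
    #   results = llama.chat_completion(
    #       [[
    #           {"role": "system", "content": "Always answer concisely."},
    #           {"role": "user", "content": "What is the capital of France?"},
    #       ]],
    #       temperature=0.6,
    #       top_p=0.9,
    #   )
    #   print(results[0]["generation"]["content"])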


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.

    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
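
# Worked example (illustrative numbers): with p = 0.9 and sorted probabilities
# [0.5, 0.3, 0.15, 0.05], the cumulative sums are [0.5, 0.8, 0.95, 1.0], so
# (cumsum - prob) = [0.0, 0.5, 0.8, 0.95]. Only the last entry exceeds 0.9 and is
# zeroed out; the remaining mass [0.5, 0.3, 0.15] is renormalized by 0.95 and one
# token is drawn from the result.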