# Source file: src/llama-main/llama/tokenizer.py
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
2
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
4
import os
5
from logging import getLogger
6
from typing import List
7
8
from sentencepiece import SentencePieceProcessor
9
10
11
# Module-scoped logger (rather than the root logger) so this module's log
# output can be filtered/configured independently; records still propagate
# to the root logger's handlers.
logger = getLogger(__name__)
12
13
14
class Tokenizer:
    """Tokenize and detokenize text with a SentencePiece model.

    Wraps a ``SentencePieceProcessor`` loaded from a model file and exposes
    the vocabulary size and special-token ids read from that model, plus
    ``encode``/``decode`` helpers that can add BOS/EOS markers.
    """

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a SentencePiece model.

        Args:
            model_path (str): The path to the SentencePiece model file.

        Raises:
            FileNotFoundError: If ``model_path`` is not an existing regular file.
        """
        # Explicit check instead of ``assert`` so validation still runs when
        # Python is started with -O (which strips assert statements).
        if not os.path.isfile(model_path):
            raise FileNotFoundError(
                f"SentencePiece model file not found: {model_path}"
            )
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        logger.info("Reloaded SentencePiece model from %s", model_path)

        # Vocabulary size and special-token ids, as reported by the model.
        # NOTE(review): SentencePiece reports -1 for a special id the model
        # does not define (often pad) — confirm for the model in use.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            "#words: %s - BOS ID: %s - EOS ID: %s",
            self.n_words,
            self.bos_id,
            self.eos_id,
        )
        # Internal invariant: both size accessors must agree.
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.

        Returns:
            List[int]: A list of token IDs.

        Raises:
            TypeError: If ``s`` is not a ``str``.
        """
        # Explicit check (not ``assert``) so it survives ``python -O``;
        # ``isinstance`` also accepts ``str`` subclasses.
        if not isinstance(s, str):
            raise TypeError(f"expected str, got {type(s).__name__}")
        ids = self.sp_model.encode(s)
        if bos:
            ids = [self.bos_id] + ids
        if eos:
            ids = ids + [self.eos_id]
        return ids

    def decode(self, t: List[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        return self.sp_model.decode(t)