Switch to unified view

a b/src/codellama-main/llama/tokenizer.py
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
2
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
4
import os
5
from logging import getLogger
6
from typing import List, Optional
7
8
from sentencepiece import SentencePieceProcessor
9
10
11
logger = getLogger()
12
13
14
class Tokenizer:
15
    def __init__(self, model_path: str):
16
        # reload tokenizer
17
        assert os.path.isfile(model_path), model_path
18
        self.sp_model = SentencePieceProcessor(model_file=model_path)
19
        logger.info(f"Reloaded SentencePiece model from {model_path}")
20
21
        # BOS / EOS token IDs
22
        self.n_words: int = self.sp_model.vocab_size()
23
        self.bos_id: int = self.sp_model.bos_id()
24
        self.eos_id: int = self.sp_model.eos_id()
25
        self.pad_id: int = self.sp_model.pad_id()
26
27
        # token IDs for special infilling tokens
28
        self.prefix_id: Optional[int] = self.sp_model.piece_to_id("▁<PRE>") or None
29
        self.middle_id: Optional[int] = self.sp_model.piece_to_id("▁<MID>") or None
30
        self.suffix_id: Optional[int] = self.sp_model.piece_to_id("▁<SUF>") or None
31
        self.eot_id: Optional[int] = self.sp_model.piece_to_id("▁<EOT>") or None
32
        logger.info(
33
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id} "
34
            f"- PRE ID: {self.prefix_id} - MID ID: {self.middle_id} - SUF ID: {self.suffix_id} - EOT ID: {self.eot_id}"
35
        )
36
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
37
38
    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
39
        assert type(s) is str
40
        t = self.sp_model.encode(s)
41
        if bos:
42
            t = [self.bos_id] + t
43
        if eos:
44
            t = t + [self.eos_id]
45
        return t
46
47
    def decode(self, t: List[int]) -> str:
48
        return self.sp_model.decode(t)
49
50
    def encode_infilling(self, s: str) -> List[int]:
51
        """Encode a string without an implicit leading space."""
52
        return self.sp_model.encode("☺" + s)[2:]
53
54
    def decode_infilling(self, t: List[int]) -> str:
55
        """Decode a string without an implicit leading space."""
56
        return self.sp_model.decode([self.sp_model.piece_to_id("☺")] + t)[1:]