# lit_gpt/packed_dataset.py
# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py


import os
import random
import struct

import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info

dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}


def code(dtype):
    # Return the integer header code for a given numpy dtype.
    for k in dtypes:
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)
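
# For example, code(np.uint16) == 8; this is the dtype code that _write_chunk packs into each chunk header.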


HDR_MAGIC = b"LITPKDS"
HDR_SIZE = 24  # bytes
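# Header layout (written by PackedDatasetBuilder._write_chunk, read by
# PackedDatasetIterator._read_header):
#   7 bytes  HDR_MAGIC ("LITPKDS")
#   8 bytes  version, little-endian uint64 ("<Q"), currently 1
#   1 byte   dtype code, a key of `dtypes` ("<B")
#   8 bytes  chunk_size in tokens ("<Q")
# 7 + 8 + 1 + 8 = 24 bytes, followed by chunk_size tokens of the stored dtype.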


class PackedDataset(IterableDataset):
    def __init__(
        self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
    ):
        self._filenames = filenames
        self._n_chunks = n_chunks
        self._block_size = block_size
        self._seed = seed
        self._shuffle = shuffle
        self._wrap = wrap
        self._num_processes = num_processes
        self._process_rank = process_rank

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        num_shards = num_workers * self._num_processes
        shard_id = self._process_rank * num_workers + worker_id

        # Drop trailing files that do not divide evenly across shards, then take
        # every num_shards-th file starting at this shard's offset.
        max_num_files = len(self._filenames) // num_shards * num_shards
        filenames = self._filenames[shard_id:max_num_files:num_shards]

        return PackedDatasetIterator(
            filenames=filenames,
            n_chunks=self._n_chunks,
            block_size=self._block_size,
            seed=self._seed,
            shuffle=self._shuffle,
            wrap=self._wrap,
        )
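
# Example (sketch): sharding across DDP ranks and DataLoader workers. The glob
# pattern, `world_size`, `rank`, and the hyperparameters are illustrative, not
# part of this module.
#
#     import glob
#     from torch.utils.data import DataLoader
#
#     filenames = sorted(glob.glob("data/train_*.bin"))
#     dataset = PackedDataset(
#         filenames, n_chunks=4, block_size=2048, num_processes=world_size, process_rank=rank
#     )
#     loader = DataLoader(dataset, batch_size=8, num_workers=2)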


class PackedDatasetBuilder(object):
    def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
        if dtype == "auto":
            if vocab_size is None:
                raise ValueError("vocab_size cannot be None when dtype='auto'")
            if vocab_size < 65500:
                self._dtype = np.uint16
            else:
                self._dtype = np.int32
        else:
            self._dtype = dtype
        self._counter = 0
        self._chunk_size = chunk_size
        self._outdir = outdir
        self._prefix = prefix
        self._sep_token = sep_token
        self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
        self._arr.fill(self._sep_token)
        self._idx = 0
        self._version = 1
        self._filenames = []

    def _write_chunk(self):
        filename = f"{self._prefix}_{self._counter:010d}.bin"
        filename = os.path.join(self._outdir, filename)

        with open(filename, "wb") as f:
            f.write(HDR_MAGIC)
            f.write(struct.pack("<Q", self._version))
            f.write(struct.pack("<B", code(self._dtype)))
            f.write(struct.pack("<Q", self._chunk_size))
            f.write(self._arr.tobytes(order="C"))

        self._filenames.append(filename)
        self._counter += 1
        self._arr.fill(self._sep_token)
        self._idx = 0

    @property
    def dtype(self):
        return self._dtype

    @property
    def filenames(self):
        return self._filenames.copy()

    def add_array(self, arr):
        # Fill the current chunk; whenever it would overflow, flush it and continue
        # writing the remainder of `arr` into fresh chunks.
        while self._idx + arr.shape[0] > self._chunk_size:
            part_len = self._chunk_size - self._idx
            self._arr[self._idx : self._idx + part_len] = arr[:part_len]
            self._write_chunk()
            arr = arr[part_len:]

        arr_len = arr.shape[0]
        self._arr[self._idx : self._idx + arr_len] = arr
        self._idx += arr_len

    def write_reminder(self):
        # Flush the final, partially filled chunk (remaining slots stay sep_token).
        self._write_chunk()
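
# Example (sketch): packing tokenized documents during preprocessing. The token
# source, paths, and sizes are illustrative; chunk_size is commonly chosen as a
# multiple of (block_size + 1) so whole training blocks fit in each chunk.
#
#     builder = PackedDatasetBuilder(
#         outdir="data", prefix="train", chunk_size=2049 * 1024, sep_token=0,
#         dtype="auto", vocab_size=32000,
#     )
#     for tokens in tokenized_documents:  # each an array-like of token ids
#         builder.add_array(np.array(tokens, dtype=builder.dtype))
#     builder.write_reminder()  # flush the last, partially filled chunk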


class PackedDatasetIterator:
    def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
        self._seed = seed
        self._shuffle = shuffle
        self._rng = np.random.default_rng(seed) if shuffle else None
        self._block_idxs = None

        self._wrap = wrap

        # TODO: instead of filenames, we could have a single text stream
        #       (or text file) with the sequence of all files to be
        #       fetched/loaded.
        self._filenames = filenames
        self._file_idx = 0

        self._n_chunks = n_chunks

        self._dtype = None
        self._block_size = block_size
        self._n_blocks = None

        self._mmaps = []
        self._buffers = []

        self._block_idxs = []
        self._curr_idx = 0

        self._load_n_chunks()

    def _read_header(self, path):
        with open(path, "rb") as f:
            magic = f.read(len(HDR_MAGIC))
            assert magic == HDR_MAGIC, "File doesn't match expected format."
            version = struct.unpack("<Q", f.read(8))
            assert version == (1,)
            (dtype_code,) = struct.unpack("<B", f.read(1))
            dtype = dtypes[dtype_code]
            (chunk_size,) = struct.unpack("<Q", f.read(8))
        return dtype, chunk_size

    def _close_mmaps(self):
        for mmap in self._mmaps:
            mmap._mmap.close()

    def _load_n_chunks(self):
        self._close_mmaps()
        self._mmaps = []
        self._buffers = []

        if self._n_chunks > len(self._filenames[self._file_idx :]):
            if not self._wrap:
                raise StopIteration
            self._file_idx = 0

        for i in range(self._n_chunks):
            filename = self._filenames[self._file_idx + i]
            if self._dtype is None:
                self._dtype, self._chunk_size = self._read_header(filename)
                self._n_blocks = self._chunk_size // self._block_size
            # TODO: check that the header matches the previous files
            mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
            self._mmaps.append(mmap)
            self._buffers.append(memoryview(mmap))

        self._file_idx += self._n_chunks
        n_all_blocks = self._n_chunks * self._n_blocks

        # Visit blocks across the loaded chunks in random order when shuffling.
        self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)

        self._curr_idx = 0

    def __del__(self):
        self._close_mmaps()
        del self._mmaps
        del self._buffers

    def __iter__(self):
        return self

    def __next__(self):
        if self._curr_idx >= len(self._block_idxs):
            self._load_n_chunks()
            # TODO: trigger fetching the next n_chunks if remote
        # Map the global block index back to a chunk and a token offset within it.
        block_idx = self._block_idxs[self._curr_idx]
        chunk_id = block_idx // self._n_blocks
        buffer = self._buffers[chunk_id]
        elem_id = (block_idx % self._n_blocks) * self._block_size
        offset = np.dtype(self._dtype).itemsize * elem_id
        arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
        self._curr_idx += 1
        return torch.from_numpy(arr.astype(np.int64))
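
# Worked example (illustrative numbers): with chunk_size=2048 and block_size=512,
# each chunk holds n_blocks = 4 blocks; block_idx=6 falls in chunk_id=1 at
# elem_id=(6 % 4)*512=1024, i.e. a byte offset of 1024 * itemsize into that chunk's data.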


class CombinedDataset(IterableDataset):
    def __init__(self, datasets, seed, weights=None):
        self._seed = seed
        self._datasets = datasets
        n_datasets = len(datasets)
        if weights is None:
            # Default to sampling each dataset uniformly.
            self._weights = [1 / n_datasets] * n_datasets
        else:
            # Normalize the given weights so they sum to 1.
            self._weights = [w / sum(weights) for w in weights]

    def __iter__(self):
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)


class CombinedDatasetIterator:
    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(el) for el in datasets]
        self._weights = weights
        self._rng = random.Random(seed)

    def __next__(self):
        # Pick one dataset according to the sampling weights and yield its next block.
        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        return next(dataset)
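
# Example (sketch): mixing two packed datasets with sampling weights; the variable
# names and weights are illustrative.
#
#     train_data = CombinedDataset([web_dataset, code_dataset], seed=42, weights=[0.7, 0.3])
#     block = next(iter(train_data))  # a 1-D LongTensor of length block_size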