lit_gpt/packed_dataset.py

# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py


import os
import random
import struct

import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info

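# Integer codes used in each chunk file's header to record the numpy dtype of the stored tokens.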
dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16}


def code(dtype):
    for k in dtypes:
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


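# On-disk chunk layout: a 24-byte header (7-byte magic, uint64 version, uint8 dtype code,
# uint64 chunk size in tokens) followed by the raw token array.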
HDR_MAGIC = b"LITPKDS"
HDR_SIZE = 24  # bytes


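# IterableDataset that streams fixed-size token blocks from a collection of packed chunk
# files, sharding the file list across distributed processes and dataloader workers.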
class PackedDataset(IterableDataset):
    def __init__(
        self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
    ):
        self._filenames = filenames
        self._n_chunks = n_chunks
        self._block_size = block_size
        self._seed = seed
        self._shuffle = shuffle
        self._wrap = wrap
        self._num_processes = num_processes
        self._process_rank = process_rank

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        num_shards = num_workers * self._num_processes
        shard_id = self._process_rank * num_workers + worker_id

        # Drop the trailing files that do not divide evenly into the shards, then take
        # every `num_shards`-th file starting from this shard's offset.
        max_num_files = len(self._filenames) // num_shards * num_shards
        filenames = self._filenames[shard_id:max_num_files:num_shards]

        return PackedDatasetIterator(
            filenames=filenames,
            n_chunks=self._n_chunks,
            block_size=self._block_size,
            seed=self._seed,
            shuffle=self._shuffle,
            wrap=self._wrap,
        )


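# Accumulates token arrays into fixed-size chunks padded with `sep_token` and writes each
# full chunk to `<prefix>_<counter>.bin` under `outdir`.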
class PackedDatasetBuilder(object):
    def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
        if dtype == "auto":
            if vocab_size is None:
                raise ValueError("vocab_size cannot be None when dtype='auto'")
            if vocab_size is not None and vocab_size < 65500:
                self._dtype = np.uint16
            else:
                self._dtype = np.int32
        else:
            self._dtype = dtype
        self._counter = 0
        self._chunk_size = chunk_size
        self._outdir = outdir
        self._prefix = prefix
        self._sep_token = sep_token
        self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
        self._arr.fill(self._sep_token)
        self._idx = 0
        self._version = 1
        self._filenames = []

    def _write_chunk(self):
        filename = f"{self._prefix}_{self._counter:010d}.bin"
        filename = os.path.join(self._outdir, filename)

        with open(filename, "wb") as f:
            f.write(HDR_MAGIC)
            f.write(struct.pack("<Q", self._version))
            f.write(struct.pack("<B", code(self._dtype)))
            f.write(struct.pack("<Q", self._chunk_size))
            f.write(self._arr.tobytes(order="C"))

        self._filenames.append(filename)
        self._counter += 1
        self._arr.fill(self._sep_token)
        self._idx = 0

    @property
    def dtype(self):
        return self._dtype

    @property
    def filenames(self):
        return self._filenames.copy()

    def add_array(self, arr):
        # If the incoming array does not fit in the current chunk, fill the chunk to the
        # end, flush it to disk, and continue with the remaining tokens.
        while self._idx + arr.shape[0] > self._chunk_size:
            part_len = self._chunk_size - self._idx
            self._arr[self._idx : self._idx + part_len] = arr[:part_len]
            self._write_chunk()
            arr = arr[part_len:]

        arr_len = arr.shape[0]
        self._arr[self._idx : self._idx + arr_len] = arr
        self._idx += arr_len

    def write_reminder(self):
        self._write_chunk()


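# Iterator that memory-maps `n_chunks` chunk files at a time and yields `block_size`-token
# blocks from them, optionally in shuffled order, loading the next group of files once the
# current blocks are exhausted.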
class PackedDatasetIterator:
    def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
        self._seed = seed
        self._shuffle = shuffle
        self._rng = np.random.default_rng(seed) if shuffle else None
        self._block_idxs = None

        self._wrap = wrap

        # TODO: instead of filenames, we could have a single text stream
        # (or text file) with the sequence of all files to be
        # fetched/loaded.
        self._filenames = filenames
        self._file_idx = 0

        self._n_chunks = n_chunks

        self._dtype = None
        self._block_size = block_size
        self._n_blocks = None

        self._mmaps = []
        self._buffers = []

        self._block_idxs = []
        self._curr_idx = 0

        self._load_n_chunks()

    def _read_header(self, path):
        with open(path, "rb") as f:
            magic = f.read(len(HDR_MAGIC))
            assert magic == HDR_MAGIC, "File doesn't match expected format."
            version = struct.unpack("<Q", f.read(8))
            assert version == (1,)
            (dtype_code,) = struct.unpack("<B", f.read(1))
            dtype = dtypes[dtype_code]
            (chunk_size,) = struct.unpack("<Q", f.read(8))
        return dtype, chunk_size

    def _close_mmaps(self):
        for mmap in self._mmaps:
            mmap._mmap.close()

    def _load_n_chunks(self):
        self._close_mmaps()
        self._mmaps = []
        self._buffers = []

        if self._n_chunks > len(self._filenames[self._file_idx :]):
            # Not enough files left for another full group of chunks; either stop or wrap around.
            if not self._wrap:
                raise StopIteration
            self._file_idx = 0

        for i in range(self._n_chunks):
            filename = self._filenames[self._file_idx + i]
            if self._dtype is None:
                self._dtype, self._chunk_size = self._read_header(filename)
                self._n_blocks = self._chunk_size // self._block_size
            # TODO: check header matches with previous files
            mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
            self._mmaps.append(mmap)
            self._buffers.append(memoryview(mmap))

        self._file_idx += self._n_chunks
        n_all_blocks = self._n_chunks * self._n_blocks

        self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)

        self._curr_idx = 0

    def __del__(self):
        self._close_mmaps()
        del self._mmaps
        del self._buffers

    def __iter__(self):
        return self

    def __next__(self):
        if self._curr_idx >= len(self._block_idxs):
            self._load_n_chunks()
            # TODO: trigger fetching next next n_chunks if remote
        block_idx = self._block_idxs[self._curr_idx]
        # Map the (possibly shuffled) block index to a chunk and an element offset within it.
        chunk_id = block_idx // self._n_blocks
        buffer = self._buffers[chunk_id]
        elem_id = (block_idx % self._n_blocks) * self._block_size
        offset = np.dtype(self._dtype).itemsize * elem_id
        arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
        self._curr_idx += 1
        return torch.from_numpy(arr.astype(np.int64))


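# Wraps several IterableDatasets and, on each step, draws the next sample from one of them
# chosen at random according to the (normalized) weights.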
class CombinedDataset(IterableDataset):
    def __init__(self, datasets, seed, weights=None):
        self._seed = seed
        self._datasets = datasets
        self._weights = weights
        n_datasets = len(datasets)
        if weights is None:
            self._weights = [1 / n_datasets] * n_datasets
        else:
            self._weights = [w / sum(weights) for w in weights]

    def __iter__(self):
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)

class CombinedDatasetIterator:
    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(el) for el in datasets]
        self._weights = weights
        self._rng = random.Random(seed)

    def __next__(self):
        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        return next(dataset)
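

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the output directory, prefix, chunk/block
# sizes, sep_token and vocab_size below are assumptions for the example, not
# values defined by this module):
#
#   builder = PackedDatasetBuilder(
#       outdir="data", prefix="train", chunk_size=2049 * 1024, sep_token=0,
#       dtype="auto", vocab_size=32000,
#   )
#   for token_ids in tokenized_documents:  # hypothetical iterable of 1-D token arrays
#       builder.add_array(np.array(token_ids, dtype=builder.dtype))
#   builder.write_reminder()  # flush the partially filled last chunk
#
#   dataset = PackedDataset(builder.filenames, n_chunks=4, block_size=2049)
#   for block in dataset:  # each block is a torch.int64 tensor of length block_size
#       ...
# ---------------------------------------------------------------------------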