b/scripts/gpt-pretrain.py

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List, Optional, Sequence, TypedDict

import lightning.pytorch as pl
import numpy as np
import torch

# In lightning.pytorch 2.0 these are commented as being "any iterable or collection of iterables";
# for now we'll use them in case the lightning type becomes something more specific in a future release.
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.lightning.megatron_parallel import DataT
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from torch.utils import data
from torch.utils.data import DataLoader, Dataset

from bionemo.llm.model.biobert.lightning import LossLoggingCallback

__all__: Sequence[str] = ()


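# A LightningDataModule that serves randomly generated GPT-style batches; useful for
# smoke-testing the NeMo/Megatron training loop without any real data on disk.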
class MockDataModule(pl.LightningDataModule):
    def __init__(
        self,
        seq_length: int = 2048,
        tokenizer: Optional[TokenizerSpec] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        num_train_samples: int = 10_000,
        num_val_samples: int = 10_000,
        num_test_samples: int = 10_000,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
    ):
        super().__init__()
        self.seq_length = seq_length
        self.num_train_samples = num_train_samples
        self.num_val_samples = num_val_samples
        self.num_test_samples = num_test_samples
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers

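        # Default to the Megatron GPT2 BPE tokenizer when no tokenizer is supplied.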
        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
        )
        # NOTE: the datasets and other distributed state are instantiated in `setup` rather than in `__init__` to
        # support the different kinds of accelerators/strategies that lightning supports. This is a common pattern
        # in lightning.

    def setup(self, stage: str = "") -> None:
        """See the lightning documentation for more information on the stage argument and the setup method.

        Using the stage is not required, but if you want to be efficient and only initialize the data needed
        for a particular stage you can do that here. According to the documentation, valid values match the
        available calls to trainer.{fit,validate,test,predict}, for example stage="fit". If we wanted to be
        fancy we could initialize only train/val data during "fit" and only "test" data during "test", etc.
        """
        self._train_ds = _MockGPTDataset(self.tokenizer, "train", self.num_train_samples, self.seq_length)
        self._validation_ds = _MockGPTDataset(self.tokenizer, "valid", self.num_val_samples, self.seq_length)
        self._test_ds = _MockGPTDataset(self.tokenizer, "test", self.num_test_samples, self.seq_length)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return self._create_dataloader(self._train_ds)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return self._create_dataloader(self._validation_ds)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return self._create_dataloader(self._test_ds)

    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=dataset.collate_fn,
            **kwargs,
        )


class GptDataItem(TypedDict):
    tokens: torch.Tensor
    labels: torch.Tensor
    attention_mask: torch.Tensor
    loss_mask: torch.Tensor
    position_ids: torch.Tensor


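# A map-style Dataset that produces randomly generated, fixed-length token/label tensors;
# for a given seed every index returns the same sample.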
class _MockGPTDataset(Dataset):
    def __init__(
        self,
        tokenizer: TokenizerSpec,
        name: str,
        num_samples: int,
        seq_length: int,
        seed: int = 42,
    ):
        super().__init__()
        self.name = name
        self.seq_length = seq_length
        self.vocab_size = tokenizer.vocab_size
        self.length = num_samples
        self.seed = seed

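        # Build a causal attention mask: after the `< 0.5` comparison below, entries above the
        # diagonal are True, i.e. future positions are masked out.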
        self.attention_mask = torch.tril(torch.ones((self.seq_length, self.seq_length))).unsqueeze(0)
        self.attention_mask = self.attention_mask < 0.5
        self.loss_mask = torch.ones(self.seq_length, dtype=torch.float)
        self.position_ids = torch.arange(self.seq_length, dtype=torch.int64)

    def __len__(self) -> int:
        return self.length

    def _get_text(self, idx: int) -> np.ndarray:
        np_gen = np.random.default_rng(seed=(self.seed + idx))
        return np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64)

    def __getitem__(self, idx) -> GptDataItem:
        # Generate data of the expected size and datatype (based on GPTDataset).
        # Note: the generator is seeded with `self.seed` only, ignoring `idx`, so every index
        # always returns the same thing.
        np_gen = np.random.default_rng(seed=self.seed)
        tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))
        labels = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))

        return {
            "tokens": tokens,
            "labels": labels,
            "attention_mask": self.attention_mask,
            "loss_mask": self.loss_mask,
            "position_ids": self.position_ids,
        }

    def _collate_fn(self, batch: DataT) -> DataT:
        """A default implementation of a collation function.

        Users should override this method to define custom data loaders.
        """
        return data.dataloader.default_collate(batch)

    def collate_fn(self, batch: DataT) -> DataT:
        """Method that users pass as the `collate_fn` functor to the DataLoader.

        The method optionally performs neural type checking and adds types to the outputs.

        Please note, subclasses of Dataset should not implement `input_types`.

        Usage:
            dataloader = torch.utils.data.DataLoader(
                ....,
                collate_fn=dataset.collate_fn,
                ....
            )

        Returns
        -------
        Collated batch, with or without types.
        """
        return self._collate_fn(batch)


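# A minimal single-device pretraining loop: a tiny GPT trained on the mock data for 100 steps.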
def main() -> None:
    devices, seq_length = 1, 2048

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        pipeline_dtype=torch.float32,
        ckpt_async_save=False,
    )
    trainer = nl.Trainer(
        devices=devices,
        max_steps=100,
        accelerator="gpu",
        strategy=strategy,
        callbacks=[LossLoggingCallback()],
        # TODO(@jstjohn) See if we can get the example working with mixed precision
        # plugins=nl.MegatronMixedPrecision(precision="float32", amp_O2=False),
    )

    _data = MockDataModule(seq_length=seq_length, global_batch_size=32)

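    # A deliberately small GPT configuration so the example trains quickly on a single device.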
    gpt_config = llm.GPTConfig(
        num_layers=4,
        hidden_size=256,
        ffn_hidden_size=512,
        num_attention_heads=4,
        seq_length=seq_length,
        pipeline_dtype=torch.float32,
    )
    model = llm.GPTModel(gpt_config, tokenizer=_data.tokenizer)

    trainer.fit(model, _data)
    checkpoint_path = Path(trainer.logger.log_dir) / "ckpt"
    trainer.save_checkpoint(checkpoint_path)


if __name__ == "__main__":
    main()