Diff of /scripts/gpt-pretrain.py [000000] .. [b9e282]

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List, Optional, Sequence, TypedDict

import lightning.pytorch as pl
import numpy as np
import torch

# In lightning.pytorch 2.0 these are documented as "any iterable or collection of iterables".
#  For now we'll use them in case the lightning type becomes something more specific in a future release.
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.lightning.megatron_parallel import DataT
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from torch.utils import data
from torch.utils.data import DataLoader, Dataset

from bionemo.llm.model.biobert.lightning import LossLoggingCallback


__all__: Sequence[str] = ()


class MockDataModule(pl.LightningDataModule):
    def __init__(
        self,
        seq_length: int = 2048,
        tokenizer: Optional[TokenizerSpec] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        num_train_samples: int = 10_000,
        num_val_samples: int = 10_000,
        num_test_samples: int = 10_000,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
    ):
        super().__init__()
        self.seq_length = seq_length
        self.num_train_samples = num_train_samples
        self.num_val_samples = num_val_samples
        self.num_test_samples = num_test_samples
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers

        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
        )
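        # Note (added for clarity): with Megatron-style data sampling, the global batch is assembled from
        #  micro-batches across gradient-accumulation steps and data-parallel ranks, so global_batch_size
        #  is expected to be a multiple of micro_batch_size * data_parallel_size.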
        # NOTE: the datasets and other distributed state are instantiated in `setup` rather than in `__init__` to support
        #  the different kinds of accelerators/strategies that lightning supports. This is a common pattern in lightning.

    def setup(self, stage: str = "") -> None:
        """See the lightning documentation for more information on the `setup` method and the `stage` argument.
        Using `stage` is not required, but it lets you initialize only the data needed for a particular stage.
        According to the documentation, valid values match the available calls to trainer.{fit,validate,test,predict},
        for example stage="fit". If we wanted to be fancy we could initialize only train/val data during "fit"
        and only instantiate "test" data during "test", etc.
        """
        self._train_ds = _MockGPTDataset(self.tokenizer, "train", self.num_train_samples, self.seq_length)
        self._validation_ds = _MockGPTDataset(self.tokenizer, "valid", self.num_val_samples, self.seq_length)
        self._test_ds = _MockGPTDataset(self.tokenizer, "test", self.num_test_samples, self.seq_length)

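    # A stage-aware variant (sketch only, not enabled here) would build just the splits a given stage needs:
    #
    #     if stage in ("", "fit"):
    #         self._train_ds = _MockGPTDataset(self.tokenizer, "train", self.num_train_samples, self.seq_length)
    #         self._validation_ds = _MockGPTDataset(self.tokenizer, "valid", self.num_val_samples, self.seq_length)
    #     elif stage == "validate":
    #         self._validation_ds = _MockGPTDataset(self.tokenizer, "valid", self.num_val_samples, self.seq_length)
    #     elif stage == "test":
    #         self._test_ds = _MockGPTDataset(self.tokenizer, "test", self.num_test_samples, self.seq_length)
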
    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return self._create_dataloader(self._train_ds)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return self._create_dataloader(self._validation_ds)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return self._create_dataloader(self._test_ds)

    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
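        # Note (added for clarity): no batch_size is passed here; with the MegatronDataSampler attached in
        #  `__init__`, batching into micro/global batches is expected to be handled by the sampler when it
        #  transforms the dataloader.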
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=dataset.collate_fn,
            **kwargs,
        )


class GptDataItem(TypedDict):
    tokens: torch.Tensor  # int64, shape [seq_length]
    labels: torch.Tensor  # int64, shape [seq_length]
    attention_mask: torch.Tensor  # bool, shape [1, seq_length, seq_length]
    loss_mask: torch.Tensor  # float32, shape [seq_length]
    position_ids: torch.Tensor  # int64, shape [seq_length]


class _MockGPTDataset(Dataset):
    def __init__(
        self,
        tokenizer: TokenizerSpec,
        name: str,
        num_samples: int,
        seq_length: int,
        seed: int = 42,
    ):
        super().__init__()
        self.name = name
        self.seq_length = seq_length
        self.vocab_size = tokenizer.vocab_size
        self.length = num_samples
        self.seed = seed

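        # Build a causal mask once and reuse it for every sample. After `tril`, the `< 0.5` comparison
        #  inverts the matrix so that True marks the upper-triangle (future) positions, i.e. the positions
        #  excluded from attention in the Megatron-style masking convention.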
        self.attention_mask = torch.tril(torch.ones((self.seq_length, self.seq_length))).unsqueeze(0)
        self.attention_mask = self.attention_mask < 0.5
        self.loss_mask = torch.ones(self.seq_length, dtype=torch.float)
        self.position_ids = torch.arange(self.seq_length, dtype=torch.int64)

    def __len__(self) -> int:
        return self.length

    def _get_text(self, idx: int) -> np.ndarray:
        np_gen = np.random.default_rng(seed=(self.seed + idx))
        return np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64)

    def __getitem__(self, idx: int) -> GptDataItem:
        # Generate data of the expected size and datatype (based on GPTDataset). The generator is seeded
        #  with `self.seed` only, so every index returns the same mock sample.
        np_gen = np.random.default_rng(seed=self.seed)
        tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))
        labels = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))

        return {
            "tokens": tokens,
            "labels": labels,
            "attention_mask": self.attention_mask,
            "loss_mask": self.loss_mask,
            "position_ids": self.position_ids,
        }

    def _collate_fn(self, batch: DataT) -> DataT:
        """A default implementation of a collation function.

        Users should override this method to customize how batches are collated.
        """
        return data.dataloader.default_collate(batch)

    def collate_fn(self, batch: DataT) -> DataT:
        """Method that users pass as the `collate_fn` callable to the DataLoader.

        The method optionally performs neural type checking and adds types to the outputs.

        Please note, subclasses of Dataset should not implement `input_types`.

        # Usage:
        dataloader = torch.utils.data.DataLoader(
                ....,
                collate_fn=dataset.collate_fn,
                ....
        )

        Returns
        -------
            Collated batch, with or without types.
        """
        return self._collate_fn(batch)


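# The example below assembles a minimal single-GPU pretraining run over the mock data module above.
# Presumably it is launched directly, e.g. `python scripts/gpt-pretrain.py`.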
def main() -> None:
    devices, seq_length = 1, 2048

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        pipeline_dtype=torch.float32,
        ckpt_async_save=False,
    )
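    # With tensor and pipeline model parallel sizes of 1, no model parallelism is used and the whole model
    #  runs on a single GPU; the Megatron strategy still handles the distributed setup and checkpointing.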
    trainer = nl.Trainer(
        devices=devices,
        max_steps=100,
        accelerator="gpu",
        strategy=strategy,
        callbacks=[LossLoggingCallback()],
        # TODO(@jstjohn) See if we can get the example working with mixed precision
        # plugins=nl.MegatronMixedPrecision(precision="float32", amp_O2=False),
    )

    _data = MockDataModule(seq_length=seq_length, global_batch_size=32)

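    # A deliberately tiny GPT configuration so this smoke-test example trains quickly; a real pretraining
    #  run would use far larger values for layers, hidden sizes, and attention heads.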
    gpt_config = llm.GPTConfig(
        num_layers=4,
        hidden_size=256,
        ffn_hidden_size=512,
        num_attention_heads=4,
        seq_length=seq_length,
        pipeline_dtype=torch.float32,
    )
    model = llm.GPTModel(gpt_config, tokenizer=_data.tokenizer)

    trainer.fit(model, _data)
    checkpoint_path = Path(trainer.logger.log_dir) / "ckpt"
    trainer.save_checkpoint(checkpoint_path)


if __name__ == "__main__":
    main()