--- a
+++ b/scripts/gpt-pretrain.py
@@ -0,0 +1,223 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+from typing import List, Optional, Sequence, TypedDict
+
+import lightning.pytorch as pl
+import numpy as np
+import torch
+
+# In lightning.pytorch 2.0 these are documented as "any iterable or collection of iterables".
+# For now we use them in case the lightning type becomes something more specific in a future release.
+from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+from nemo import lightning as nl
+from nemo.collections import llm
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+from nemo.lightning.megatron_parallel import DataT
+from nemo.lightning.pytorch.plugins import MegatronDataSampler
+from torch.utils import data
+from torch.utils.data import DataLoader, Dataset
+
+from bionemo.llm.model.biobert.lightning import LossLoggingCallback
+
+
+__all__: Sequence[str] = ()
+
+
+class MockDataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        seq_length: int = 2048,
+        tokenizer: Optional[TokenizerSpec] = None,
+        micro_batch_size: int = 4,
+        global_batch_size: int = 8,
+        rampup_batch_size: Optional[List[int]] = None,
+        num_train_samples: int = 10_000,
+        num_val_samples: int = 10_000,
+        num_test_samples: int = 10_000,
+        num_workers: int = 8,
+        pin_memory: bool = True,
+        persistent_workers: bool = False,
+    ):
+        super().__init__()
+        self.seq_length = seq_length
+        self.num_train_samples = num_train_samples
+        self.num_val_samples = num_val_samples
+        self.num_test_samples = num_test_samples
+        self.num_workers = num_workers
+        self.pin_memory = pin_memory
+        self.persistent_workers = persistent_workers
+
+        self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
+        self.data_sampler = MegatronDataSampler(
+            seq_len=self.seq_length,
+            micro_batch_size=micro_batch_size,
+            global_batch_size=global_batch_size,
+            rampup_batch_size=rampup_batch_size,
+        )
+        # NOTE: the datasets and other distributed state are instantiated in `setup` rather than in `__init__`
+        # to support the different kinds of accelerators/strategies that lightning supports. This is a common
+        # pattern in lightning.
+
+    def setup(self, stage: str = "") -> None:
+        """Instantiate the train, validation, and test datasets.
+
+        See the lightning documentation for more information on the `setup` method and the `stage` argument.
+        Valid values of `stage` match the available calls to trainer.{fit,validate,test,predict}, for example
+        stage="fit". Using `stage` is optional, but it lets you initialize only the data needed for a particular
+        phase, e.g. only the train/val datasets during "fit" and only the test dataset during "test".
+        """
+        self._train_ds = _MockGPTDataset(self.tokenizer, "train", self.num_train_samples, self.seq_length)
+        self._validation_ds = _MockGPTDataset(self.tokenizer, "valid", self.num_val_samples, self.seq_length)
+        self._test_ds = _MockGPTDataset(self.tokenizer, "test", self.num_test_samples, self.seq_length)
+
+    def train_dataloader(self) -> TRAIN_DATALOADERS:
+        return self._create_dataloader(self._train_ds)
+
+    def val_dataloader(self) -> EVAL_DATALOADERS:
+        return self._create_dataloader(self._validation_ds)
+
+    def test_dataloader(self) -> EVAL_DATALOADERS:
+        return self._create_dataloader(self._test_ds)
+
+    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
+        return DataLoader(
+            dataset,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            persistent_workers=self.persistent_workers,
+            collate_fn=dataset.collate_fn,
+            **kwargs,
+        )
+
+
+class GptDataItem(TypedDict):
+    tokens: torch.Tensor
+    labels: torch.Tensor
+    attention_mask: torch.Tensor
+    loss_mask: torch.Tensor
+    position_ids: torch.Tensor
+
+
+class _MockGPTDataset(Dataset):
+    def __init__(
+        self,
+        tokenizer: TokenizerSpec,
+        name: str,
+        num_samples: int,
+        seq_length: int,
+        seed: int = 42,
+    ):
+        super().__init__()
+        self.name = name
+        self.seq_length = seq_length
+        self.vocab_size = tokenizer.vocab_size
+        self.length = num_samples
+        self.seed = seed
+
+        # Lower-triangular causal mask, inverted so that True marks positions that are masked out
+        # (the convention expected by Megatron-style attention).
+        self.attention_mask = torch.tril(torch.ones((self.seq_length, self.seq_length))).unsqueeze(0)
+        self.attention_mask = self.attention_mask < 0.5
+        self.loss_mask = torch.ones(self.seq_length, dtype=torch.float)
+        self.position_ids = torch.arange(self.seq_length, dtype=torch.int64)
+
+    def __len__(self) -> int:
+        return self.length
+
+    def _get_text(self, idx: int) -> np.ndarray:
+        np_gen = np.random.default_rng(seed=(self.seed + idx))
+        return np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64)
+
+    def __getitem__(self, idx) -> GptDataItem:
+        # Generate data of the expected size and datatype (based on GPTDataset). The generator is re-seeded
+        # with the same fixed seed on every call, so every index returns the same mock sample.
+        np_gen = np.random.default_rng(seed=self.seed)
+        tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))
+        labels = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64))
+
+        return {
+            "tokens": tokens,
+            "labels": labels,
+            "attention_mask": self.attention_mask,
+            "loss_mask": self.loss_mask,
+            "position_ids": self.position_ids,
+        }
+
+    def _collate_fn(self, batch: DataT) -> DataT:
+        """A default implementation of a collation function.
+
+        Subclasses can override this method to customize how samples are collated into a batch.
+        """
+        return data.dataloader.default_collate(batch)
+
+    def collate_fn(self, batch: DataT) -> DataT:
+        """Method that the user passes as the `collate_fn` to the DataLoader.
+
+        The method optionally performs neural type checking and adds types to the outputs.
+
+        Please note that subclasses of Dataset should not implement `input_types`.
+
+        Usage:
+            dataloader = torch.utils.data.DataLoader(
+                ....,
+                collate_fn=dataset.collate_fn,
+                ....,
+            )
+
+        Returns:
+            Collated batch, with or without types.
+ """ + return self._collate_fn(batch) + + +def main() -> None: + devices, seq_length = 1, 2048 + + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + pipeline_dtype=torch.float32, + ckpt_async_save=False, + ) + trainer = nl.Trainer( + devices=devices, + max_steps=100, + accelerator="gpu", + strategy=strategy, + callbacks=[LossLoggingCallback()], + # TODO(@jstjohn) See if we can get the example working with mixed precision + # plugins=nl.MegatronMixedPrecision(precision="float32", amp_O2=False), + ) + + _data = MockDataModule(seq_length=seq_length, global_batch_size=32) + + gpt_config = llm.GPTConfig( + num_layers=4, + hidden_size=256, + ffn_hidden_size=512, + num_attention_heads=4, + seq_length=seq_length, + pipeline_dtype=torch.float32, + ) + model = llm.GPTModel(gpt_config, tokenizer=_data.tokenizer) + + trainer.fit(model, _data) + checkpoint_path = Path(trainer.logger.log_dir) / "ckpt" + trainer.save_checkpoint(checkpoint_path) + + +if __name__ == "__main__": + main()
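For reference, here is a minimal standalone sketch (not part of the patch above) that mirrors the per-sample layout produced by `_MockGPTDataset` and the default collation applied in `MockDataModule._create_dataloader`. The vocabulary size, sequence length, and seed below are illustrative assumptions chosen so the snippet runs with only numpy and torch installed; the real script takes the vocabulary size from the Megatron GPT2 BPE tokenizer and uses seq_length=2048.

    import numpy as np
    import torch
    from torch.utils import data

    # Illustrative values only, not taken from the script.
    vocab_size, seq_length, seed = 50_257, 8, 42


    def mock_sample() -> dict:
        # Same recipe as _MockGPTDataset.__getitem__: fixed-seed random tokens/labels plus static masks.
        np_gen = np.random.default_rng(seed=seed)
        # Causal mask where True marks positions that are masked out.
        attention_mask = torch.tril(torch.ones((seq_length, seq_length))).unsqueeze(0) < 0.5
        return {
            "tokens": torch.from_numpy(np_gen.integers(vocab_size, size=[seq_length], dtype=np.int64)),
            "labels": torch.from_numpy(np_gen.integers(vocab_size, size=[seq_length], dtype=np.int64)),
            "attention_mask": attention_mask,
            "loss_mask": torch.ones(seq_length, dtype=torch.float),
            "position_ids": torch.arange(seq_length, dtype=torch.int64),
        }


    # Collate a micro-batch of two samples the same way the DataLoader in MockDataModule would.
    batch = data.dataloader.default_collate([mock_sample(), mock_sample()])
    for key, value in batch.items():
        print(key, tuple(value.shape))
    # tokens/labels/loss_mask/position_ids come out as (2, 8); attention_mask as (2, 1, 8, 8).

The boolean attention mask follows the Megatron convention in which True marks positions excluded from attention, which is why the lower-triangular matrix is inverted with the `< 0.5` comparison.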