--- a
+++ b/tutorials/6_Train MOTOR.ipynb
@@ -0,0 +1,564 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "1c81279e-a568-4e36-9906-06317accb622",
+   "metadata": {},
+   "source": [
+    "# Train MOTOR\n",
+    "\n",
+    "This tutorial walks through the steps required to train a MOTOR model.\n",
+    "\n",
+    "Training MOTOR is a four-step process:\n",
+    "\n",
+    "- Training a tokenizer\n",
+    "- Prefitting MOTOR\n",
+    "- Preparing batches\n",
+    "- Training the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "7dcdfd70-58a1-4460-80a8-db737a8c5cd6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import shutil\n",
+    "import os\n",
+    "\n",
+    "TARGET_DIR = 'trash/tutorial_6'\n",
+    "\n",
+    "# Start from a clean target directory; makedirs also creates missing parent directories\n",
+    "if os.path.exists(TARGET_DIR):\n",
+    "    shutil.rmtree(TARGET_DIR)\n",
+    "\n",
+    "os.makedirs(TARGET_DIR)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "646f7590",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1395.58 examples/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[0, 1, 2, 4, 6, 7, 10, 11, 12, 13, 14, 15, 18, 20, 21, 23, 24, 26, 27, 28, 29, 30, 31, 33, 36, 37, 38, 40, 42, 44, 45, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 69, 70, 73, 74, 75, 76, 77, 79, 80, 83, 85, 86, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 100, 101, 102, 103, 104, 105, 107, 109, 110, 112, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 133, 134, 135, 136, 137, 139, 141, 142, 143, 144, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 165, 166, 168, 169, 171, 172, 173, 174, 178, 181, 182, 183, 184, 185, 186, 187, 189, 192, 193, 195, 196, 197, 198, 199]\n",
+      "[19, 22, 25, 39, 46, 71, 82, 84, 87, 92, 106, 108, 113, 131, 132, 138, 146, 147, 148, 155, 177, 179, 180, 188, 190, 191]\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1316.29 examples/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "DatasetDict({\n",
+      "    train: Dataset({\n",
+      "        features: ['patient_id', 'events'],\n",
+      "        num_rows: 144\n",
+      "    })\n",
+      "    test: Dataset({\n",
+      "        features: ['patient_id', 'events'],\n",
+      "        num_rows: 26\n",
+      "    })\n",
+      "})\n"
+     ]
+    }
+   ],
+   "source": [
+    "import datasets\n",
+    "import femr.index\n",
+    "import femr.splits\n",
+    "\n",
+    "# First, we want to split our dataset into train, valid, and test sets\n",
+    "# We do this by calling our split functionality twice: once for the main train/test split,\n",
+    "# and once more to carve a validation split out of the training set\n",
+    "\n",
+    "dataset = datasets.Dataset.from_parquet('input/meds/data/*')\n",
+    "\n",
+    "index = femr.index.PatientIndex(dataset, num_proc=4)\n",
+    "main_split = femr.splits.generate_hash_split(index.get_patient_ids(), 97, frac_test=0.15)\n",
+    "\n",
+    "os.mkdir(os.path.join(TARGET_DIR, 'motor_model'))\n",
+    "# Note that we want to save this split to the target directory, since this is important information that later steps rely on\n",
+    "\n",
+    "main_split.save_to_csv(os.path.join(TARGET_DIR, \"motor_model\", \"main_split.csv\"))\n",
+    "\n",
+    "train_split = femr.splits.generate_hash_split(main_split.train_patient_ids, 87, frac_test=0.15)\n",
+    "\n",
+    "print(train_split.train_patient_ids)\n",
+    "print(train_split.test_patient_ids)\n",
+    "\n",
+    "main_dataset = main_split.split_dataset(dataset, index)\n",
+    "train_dataset = train_split.split_dataset(main_dataset['train'], femr.index.PatientIndex(main_dataset['train'], num_proc=4))\n",
+    "\n",
+    "print(train_dataset)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "f60ab7df-e851-44a5-ab70-7bee292be00c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 331.19 examples/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import femr.models.tokenizer\n",
+    "import pickle\n",
+    "\n",
+    "# First, we need to train a tokenizer\n",
+    "# Note that MOTOR requires a hierarchical tokenizer\n",
+    "\n",
+    "with open('input/meds/ontology.pkl', 'rb') as f:\n",
+    "    ontology = pickle.load(f)\n",
+    "\n",
+    "# NOTE: A vocab size of 128 is probably too low for a real model; 128 was chosen to make this tutorial quick to run\n",
+    "tokenizer = femr.models.tokenizer.train_tokenizer(\n",
+    "    main_dataset['train'], vocab_size=128, is_hierarchical=True, num_proc=4, ontology=ontology)\n",
+    "\n",
+    "# Save the tokenizer to the same directory as the model\n",
+    "tokenizer.save_pretrained(os.path.join(TARGET_DIR, \"motor_model\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "69b60daa",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 249.31 examples/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import femr.models.tasks\n",
+    "\n",
+    "# Second, we need to prefit the MOTOR model. This is necessary because piecewise exponential models are unstable without an initial fit\n",
+    "\n",
+    "motor_task = femr.models.tasks.MOTORTask.fit_pretraining_task_info(\n",
+    "    main_dataset['train'], tokenizer, num_tasks=64, num_bins=4, final_layer_size=32, num_proc=4)\n",
+    "\n",
+    "# Since this is an expensive operation, it's recommended to save the result with pickle to avoid recomputing it\n",
+    "# A minimal sketch of doing so (the filename here is just an illustration):\n",
+    "with open(os.path.join(TARGET_DIR, 'motor_task.pkl'), 'wb') as f:\n",
+    "    pickle.dump(motor_task, f)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "89611ba9-a242-4b87-9b8f-25670d838fc6",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Convert a single patient\n",
+      "Convert batches\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:00<00:00, 261.72 examples/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Creating batches 7\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Generating train split: 7 examples [00:00, 12.06 examples/s]\n",
+      "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 50.63 examples/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Creating batches 1\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting num_proc from 4 back to 1 for the train split to disable multiprocessing as it only contains one shard.\n",
+      "Generating train split: 1 examples [00:00, 57.97 examples/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Convert batches to pytorch\n",
+      "Done\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "import femr.models.processor\n",
+    "import femr.models.tasks\n",
+    "\n",
+    "# Third, we need to create batches\n",
+    "\n",
+    "processor = femr.models.processor.FEMRBatchProcessor(tokenizer, motor_task)\n",
+    "\n",
+    "# We can do this one patient at a time\n",
+    "print(\"Convert a single patient\")\n",
+    "example_batch = processor.collate([processor.convert_patient(train_dataset['train'][0], tensor_type='pt')])\n",
+    "\n",
+    "print(\"Convert batches\")\n",
+    "# But generally we want to convert entire datasets\n",
+    "train_batches = processor.convert_dataset(train_dataset, tokens_per_batch=32, num_proc=4)\n",
+    "\n",
+    "print(\"Convert batches to pytorch\")\n",
+    "# Convert our batches to pytorch tensors\n",
+    "train_batches.set_format(\"pt\")\n",
+    "print(\"Done\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "f654a46c-5aa7-465c-b6c5-73d8ba26ed67",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/torch/cuda/__init__.py:628: UserWarning: Can't initialize NVML\n",
+      "  warnings.warn(\"Can't initialize NVML\")\n",
+      "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
+      "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
+      "  warnings.warn(\n",
+      "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "    <div>\n",
+       "      \n",
+       "      <progress value='700' max='700' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+       "      [700/700 00:10, Epoch 100/100]\n",
+       "    </div>\n",
+       "    <table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: left;\">\n",
+       "      <th>Step</th>\n",
+       "      <th>Training Loss</th>\n",
+       "      <th>Validation Loss</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <td>20</td>\n",
+       "      <td>0.855400</td>\n",
+       "      <td>0.506942</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>40</td>\n",
+       "      <td>0.871100</td>\n",
+       "      <td>0.506998</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>60</td>\n",
+       "      <td>0.826900</td>\n",
+       "      <td>0.507056</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>80</td>\n",
+       "      <td>0.856700</td>\n",
+       "      <td>0.507116</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>100</td>\n",
+       "      <td>0.856200</td>\n",
+       "      <td>0.507181</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>120</td>\n",
+       "      <td>0.829800</td>\n",
+       "      <td>0.507251</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>140</td>\n",
+       "      <td>0.859700</td>\n",
+       "      <td>0.507321</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>160</td>\n",
+       "      <td>0.851600</td>\n",
+       "      <td>0.507393</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>180</td>\n",
+       "      <td>0.852500</td>\n",
+       "      <td>0.507467</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>200</td>\n",
+       "      <td>0.868400</td>\n",
+       "      <td>0.507540</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>220</td>\n",
+       "      <td>0.850800</td>\n",
+       "      <td>0.507617</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>240</td>\n",
+       "      <td>0.835900</td>\n",
+       "      <td>0.507696</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>260</td>\n",
+       "      <td>0.850000</td>\n",
+       "      <td>0.507768</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>280</td>\n",
+       "      <td>0.831500</td>\n",
+       "      <td>0.507841</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>300</td>\n",
+       "      <td>0.860700</td>\n",
+       "      <td>0.507915</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>320</td>\n",
+       "      <td>0.846000</td>\n",
+       "      <td>0.507988</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>340</td>\n",
+       "      <td>0.826800</td>\n",
+       "      <td>0.508055</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>360</td>\n",
+       "      <td>0.830600</td>\n",
+       "      <td>0.508123</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>380</td>\n",
+       "      <td>0.884700</td>\n",
+       "      <td>0.508188</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>400</td>\n",
+       "      <td>0.823900</td>\n",
+       "      <td>0.508248</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>420</td>\n",
+       "      <td>0.856200</td>\n",
+       "      <td>0.508309</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>440</td>\n",
+       "      <td>0.848400</td>\n",
+       "      <td>0.508360</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>460</td>\n",
+       "      <td>0.855900</td>\n",
+       "      <td>0.508413</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>480</td>\n",
+       "      <td>0.849500</td>\n",
+       "      <td>0.508458</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>500</td>\n",
+       "      <td>0.831200</td>\n",
+       "      <td>0.508502</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>520</td>\n",
+       "      <td>0.848300</td>\n",
+       "      <td>0.508542</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>540</td>\n",
+       "      <td>0.858700</td>\n",
+       "      <td>0.508577</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>560</td>\n",
+       "      <td>0.829200</td>\n",
+       "      <td>0.508608</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>580</td>\n",
+       "      <td>0.858500</td>\n",
+       "      <td>0.508636</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>600</td>\n",
+       "      <td>0.825800</td>\n",
+       "      <td>0.508659</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>620</td>\n",
+       "      <td>0.878200</td>\n",
+       "      <td>0.508677</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>640</td>\n",
+       "      <td>0.839500</td>\n",
+       "      <td>0.508692</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>660</td>\n",
+       "      <td>0.813000</td>\n",
+       "      <td>0.508703</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>680</td>\n",
+       "      <td>0.854800</td>\n",
+       "      <td>0.508709</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>700</td>\n",
+       "      <td>0.847300</td>\n",
+       "      <td>0.508711</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table><p>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import transformers\n",
+    "\n",
+    "import femr.models.config\n",
+    "import femr.models.transformer\n",
+    "\n",
+    "# Finally, given the batches, we can train MOTOR\n",
+    "# We can use Hugging Face's Trainer to do this\n",
+    "\n",
+    "transformer_config = femr.models.config.FEMRTransformerConfig(\n",
+    "    vocab_size=tokenizer.vocab_size,\n",
+    "    is_hierarchical=tokenizer.is_hierarchical,\n",
+    "    n_layers=2,\n",
+    "    hidden_size=64,\n",
+    "    intermediate_size=64*2,\n",
+    "    n_heads=8,\n",
+    ")\n",
+    "\n",
+    "config = femr.models.config.FEMRModelConfig.from_transformer_task_configs(transformer_config, motor_task.get_task_config())\n",
+    "\n",
+    "model = femr.models.transformer.FEMRModel(config)\n",
+    "\n",
+    "trainer_config = transformers.TrainingArguments(\n",
+    "    per_device_train_batch_size=1,\n",
+    "    per_device_eval_batch_size=1,\n",
+    "\n",
+    "    output_dir='tmp_trainer',\n",
+    "    remove_unused_columns=False,\n",
+    "    num_train_epochs=100,\n",
+    "\n",
+    "    eval_steps=20,\n",
+    "    evaluation_strategy=\"steps\",\n",
+    "\n",
+    "    logging_steps=20,\n",
+    "    logging_strategy='steps',\n",
+    "\n",
+    "    prediction_loss_only=True,\n",
+    ")\n",
+    "\n",
+    "trainer = transformers.Trainer(\n",
+    "    model=model,\n",
+    "    data_collator=processor.collate,\n",
+    "    train_dataset=train_batches['train'],\n",
+    "    eval_dataset=train_batches['test'],\n",
+    "    args=trainer_config,\n",
+    ")\n",
+    "\n",
+    "trainer.train()\n",
+    "\n",
+    "model.save_pretrained(os.path.join(TARGET_DIR, 'motor_model'))"
+   ]
+  },
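+  {
+   "cell_type": "markdown",
+   "id": "9f1f7a4e-reload-note",
+   "metadata": {},
+   "source": [
+    "As a final sanity check, we can reload the saved model. This is a minimal sketch rather than part of the original pipeline: it assumes `FEMRModel` follows the standard Hugging Face `from_pretrained` convention, mirroring the `save_pretrained` call above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9f1f7a4e-reload-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import femr.models.transformer\n",
+    "\n",
+    "# Minimal sketch: reload the trained MOTOR model from disk.\n",
+    "# ASSUMPTION: FEMRModel.from_pretrained follows the usual Hugging Face convention,\n",
+    "# as the counterpart of the save_pretrained call in the previous cell.\n",
+    "loaded_model = femr.models.transformer.FEMRModel.from_pretrained(\n",
+    "    os.path.join(TARGET_DIR, 'motor_model'))\n",
+    "\n",
+    "# The reloaded config should match the configuration we trained with\n",
+    "print(loaded_model.config)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}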