Switch to side-by-side view

--- a
+++ b/tutorials/6_Train MOTOR.ipynb
@@ -0,0 +1,564 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "1c81279e-a568-4e36-9906-06317accb622",
+   "metadata": {},
+   "source": [
+    "# Train MOTOR\n",
+    "\n",
+    "This tutorial walks through the various steps to train a MOTOR model.\n",
+    "\n",
+    "Training MOTOR is a four-step process:\n",
+    "\n",
+    "- Training a tokenizer\n",
+    "- Prefitting MOTOR\n",
+    "- Preparing batches\n",
+    "- Training the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "7dcdfd70-58a1-4460-80a8-db737a8c5cd6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import shutil\n",
+    "\n",
+    "# Optional: uncomment and edit to point the Hugging Face `datasets` cache at a directory of your choice.\n",
+    "# os.environ[\"HF_DATASETS_CACHE\"] = '/path/to/your/cache_dir'\n",
+    "\n",
+    "# Every artifact produced by this tutorial is written underneath TARGET_DIR.\n",
+    "TARGET_DIR = 'trash/tutorial_6'\n",
+    "\n",
+    "# Wipe any previous run so that re-executing the notebook starts from a clean directory.\n",
+    "if os.path.exists(TARGET_DIR):\n",
+    "    shutil.rmtree(TARGET_DIR)\n",
+    "\n",
+    "# makedirs (not mkdir) so the parent 'trash' directory is created too when it does not exist yet.\n",
+    "os.makedirs(TARGET_DIR)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "646f7590",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1395.58 examples/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[0, 1, 2, 4, 6, 7, 10, 11, 12, 13, 14, 15, 18, 20, 21, 23, 24, 26, 27, 28, 29, 30, 31, 33, 36, 37, 38, 40, 42, 44, 45, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 69, 70, 73, 74, 75, 76, 77, 79, 80, 83, 85, 86, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 100, 101, 102, 103, 104, 105, 107, 109, 110, 112, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 133, 134, 135, 136, 137, 139, 141, 142, 143, 144, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 165, 166, 168, 169, 171, 172, 173, 174, 178, 181, 182, 183, 184, 185, 186, 187, 189, 192, 193, 195, 196, 197, 198, 199]\n",
+      "[19, 22, 25, 39, 46, 71, 82, 84, 87, 92, 106, 108, 113, 131, 132, 138, 146, 147, 148, 155, 177, 179, 180, 188, 190, 191]\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1316.29 examples/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "DatasetDict({\n",
+      "    train: Dataset({\n",
+      "        features: ['patient_id', 'events'],\n",
+      "        num_rows: 144\n",
+      "    })\n",
+      "    test: Dataset({\n",
+      "        features: ['patient_id', 'events'],\n",
+      "        num_rows: 26\n",
+      "    })\n",
+      "})\n"
+     ]
+    }
+   ],
+   "source": [
+    "import datasets\n",
+    "import femr.index\n",
+    "import femr.splits\n",
+    "\n",
+    "# Before training, carve the data into train, valid, and test by applying the hash split twice:\n",
+    "#   1. main_split:  all patients              -> main train portion vs. test\n",
+    "#   2. train_split: the main train portion    -> train vs. valid\n",
+    "# The valid portion ends up under the 'test' key of train_dataset.\n",
+    "# NOTE(review): 97 and 87 are presumably the hash seeds - confirm against generate_hash_split.\n",
+    "\n",
+    "dataset = datasets.Dataset.from_parquet('input/meds/data/*')\n",
+    "\n",
+    "index = femr.index.PatientIndex(dataset, num_proc=4)\n",
+    "all_patient_ids = index.get_patient_ids()\n",
+    "main_split = femr.splits.generate_hash_split(all_patient_ids, 97, frac_test=0.15)\n",
+    "\n",
+    "# The split is important information, so it is stored alongside the model in the target directory.\n",
+    "model_dir = os.path.join(TARGET_DIR, 'motor_model')\n",
+    "os.mkdir(model_dir)\n",
+    "main_split.save_to_csv(os.path.join(model_dir, \"main_split.csv\"))\n",
+    "\n",
+    "train_split = femr.splits.generate_hash_split(main_split.train_patient_ids, 87, frac_test=0.15)\n",
+    "\n",
+    "print(train_split.train_patient_ids)\n",
+    "print(train_split.test_patient_ids)\n",
+    "\n",
+    "main_dataset = main_split.split_dataset(dataset, index)\n",
+    "\n",
+    "# The second split is applied to the main train portion, which needs its own patient index.\n",
+    "train_index = femr.index.PatientIndex(main_dataset['train'], num_proc=4)\n",
+    "train_dataset = train_split.split_dataset(main_dataset['train'], train_index)\n",
+    "\n",
+    "print(train_dataset)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "f60ab7df-e851-44a5-ab70-7bee292be00c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 331.19 examples/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import pickle\n",
+    "\n",
+    "import femr.models.tokenizer\n",
+    "\n",
+    "# Step 1: train a tokenizer.\n",
+    "# MOTOR needs a hierarchical tokenizer, so the ontology is loaded and passed in with is_hierarchical=True.\n",
+    "\n",
+    "# pickle.load can execute arbitrary code - only load ontology files that you trust.\n",
+    "ontology_path = 'input/meds/ontology.pkl'\n",
+    "with open(ontology_path, 'rb') as ontology_file:\n",
+    "    ontology = pickle.load(ontology_file)\n",
+    "\n",
+    "# NOTE: A vocab size of 128 is probably too low for a real model. 128 was chosen to make this tutorial quick to run\n",
+    "tokenizer = femr.models.tokenizer.train_tokenizer(\n",
+    "    main_dataset['train'],\n",
+    "    vocab_size=128,\n",
+    "    is_hierarchical=True,\n",
+    "    num_proc=4,\n",
+    "    ontology=ontology,\n",
+    ")\n",
+    "\n",
+    "# The tokenizer is saved into the same directory as the model.\n",
+    "tokenizer.save_pretrained(os.path.join(TARGET_DIR, \"motor_model\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "69b60daa",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 249.31 examples/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import femr.models.tasks\n",
+    "\n",
+    "# Step 2: prefit the MOTOR task.\n",
+    "# This is necessary because piecewise exponential models are unstable without an initial fit.\n",
+    "# The fit is an expensive operation, so it is recommended to save `motor_task` with pickle\n",
+    "# to avoid recomputing it.\n",
+    "motor_task = femr.models.tasks.MOTORTask.fit_pretraining_task_info(\n",
+    "    main_dataset['train'],\n",
+    "    tokenizer,\n",
+    "    num_tasks=64,\n",
+    "    num_bins=4,\n",
+    "    final_layer_size=32,\n",
+    "    num_proc=4,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "89611ba9-a242-4b87-9b8f-25670d838fc6",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Convert a single patient\n",
+      "Convert batches\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:00<00:00, 261.72 examples/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Creating batches 7\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Generating train split: 7 examples [00:00, 12.06 examples/s]\n",
+      "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 50.63 examples/s]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Creating batches 1\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting num_proc from 4 back to 1 for the train split to disable multiprocessing as it only contains one shard.\n",
+      "Generating train split: 1 examples [00:00, 57.97 examples/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Convert batches to pytorch\n",
+      "Done\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "import femr.models.processor\n",
+    "import femr.models.tasks\n",
+    "\n",
+    "# Step 3: turn patients into batches for the model.\n",
+    "# The processor is built from the tokenizer and the prefit MOTOR task.\n",
+    "processor = femr.models.processor.FEMRBatchProcessor(tokenizer, motor_task)\n",
+    "\n",
+    "# A single patient can be converted and collated into a batch on its own ...\n",
+    "print(\"Convert a single patient\")\n",
+    "first_patient = train_dataset['train'][0]\n",
+    "converted_patient = processor.convert_patient(first_patient, tensor_type='pt')\n",
+    "example_batch = processor.collate([converted_patient])\n",
+    "\n",
+    "# ... but generally we want to convert entire datasets in one go.\n",
+    "print(\"Convert batches\")\n",
+    "train_batches = processor.convert_dataset(\n",
+    "    train_dataset,\n",
+    "    tokens_per_batch=32,\n",
+    "    num_proc=4,\n",
+    ")\n",
+    "\n",
+    "# Have the batch dataset return pytorch tensors.\n",
+    "print(\"Convert batches to pytorch\")\n",
+    "train_batches.set_format(\"pt\")\n",
+    "print(\"Done\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "f654a46c-5aa7-465c-b6c5-73d8ba26ed67",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/torch/cuda/__init__.py:628: UserWarning: Can't initialize NVML\n",
+      "  warnings.warn(\"Can't initialize NVML\")\n",
+      "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
+      "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
+      "  warnings.warn(\n",
+      "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "    <div>\n",
+       "      \n",
+       "      <progress value='700' max='700' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+       "      [700/700 00:10, Epoch 100/100]\n",
+       "    </div>\n",
+       "    <table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       " <tr style=\"text-align: left;\">\n",
+       "      <th>Step</th>\n",
+       "      <th>Training Loss</th>\n",
+       "      <th>Validation Loss</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <td>20</td>\n",
+       "      <td>0.855400</td>\n",
+       "      <td>0.506942</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>40</td>\n",
+       "      <td>0.871100</td>\n",
+       "      <td>0.506998</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>60</td>\n",
+       "      <td>0.826900</td>\n",
+       "      <td>0.507056</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>80</td>\n",
+       "      <td>0.856700</td>\n",
+       "      <td>0.507116</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>100</td>\n",
+       "      <td>0.856200</td>\n",
+       "      <td>0.507181</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>120</td>\n",
+       "      <td>0.829800</td>\n",
+       "      <td>0.507251</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>140</td>\n",
+       "      <td>0.859700</td>\n",
+       "      <td>0.507321</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>160</td>\n",
+       "      <td>0.851600</td>\n",
+       "      <td>0.507393</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>180</td>\n",
+       "      <td>0.852500</td>\n",
+       "      <td>0.507467</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>200</td>\n",
+       "      <td>0.868400</td>\n",
+       "      <td>0.507540</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>220</td>\n",
+       "      <td>0.850800</td>\n",
+       "      <td>0.507617</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>240</td>\n",
+       "      <td>0.835900</td>\n",
+       "      <td>0.507696</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>260</td>\n",
+       "      <td>0.850000</td>\n",
+       "      <td>0.507768</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>280</td>\n",
+       "      <td>0.831500</td>\n",
+       "      <td>0.507841</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>300</td>\n",
+       "      <td>0.860700</td>\n",
+       "      <td>0.507915</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>320</td>\n",
+       "      <td>0.846000</td>\n",
+       "      <td>0.507988</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>340</td>\n",
+       "      <td>0.826800</td>\n",
+       "      <td>0.508055</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>360</td>\n",
+       "      <td>0.830600</td>\n",
+       "      <td>0.508123</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>380</td>\n",
+       "      <td>0.884700</td>\n",
+       "      <td>0.508188</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>400</td>\n",
+       "      <td>0.823900</td>\n",
+       "      <td>0.508248</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>420</td>\n",
+       "      <td>0.856200</td>\n",
+       "      <td>0.508309</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>440</td>\n",
+       "      <td>0.848400</td>\n",
+       "      <td>0.508360</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>460</td>\n",
+       "      <td>0.855900</td>\n",
+       "      <td>0.508413</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>480</td>\n",
+       "      <td>0.849500</td>\n",
+       "      <td>0.508458</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>500</td>\n",
+       "      <td>0.831200</td>\n",
+       "      <td>0.508502</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>520</td>\n",
+       "      <td>0.848300</td>\n",
+       "      <td>0.508542</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>540</td>\n",
+       "      <td>0.858700</td>\n",
+       "      <td>0.508577</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>560</td>\n",
+       "      <td>0.829200</td>\n",
+       "      <td>0.508608</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>580</td>\n",
+       "      <td>0.858500</td>\n",
+       "      <td>0.508636</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>600</td>\n",
+       "      <td>0.825800</td>\n",
+       "      <td>0.508659</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>620</td>\n",
+       "      <td>0.878200</td>\n",
+       "      <td>0.508677</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>640</td>\n",
+       "      <td>0.839500</td>\n",
+       "      <td>0.508692</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>660</td>\n",
+       "      <td>0.813000</td>\n",
+       "      <td>0.508703</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>680</td>\n",
+       "      <td>0.854800</td>\n",
+       "      <td>0.508709</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <td>700</td>\n",
+       "      <td>0.847300</td>\n",
+       "      <td>0.508711</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table><p>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import transformers\n",
+    "\n",
+    "# femr.models.config is imported explicitly instead of relying on femr.models.transformer pulling it in.\n",
+    "import femr.models.config\n",
+    "import femr.models.transformer\n",
+    "\n",
+    "# Finally, given the batches, we can train MOTOR.\n",
+    "# We can use huggingface's trainer to do this.\n",
+    "\n",
+    "# Transformer settings; vocabulary size and hierarchy flag are taken from the trained tokenizer.\n",
+    "# NOTE(review): the layer and hidden sizes are presumably kept tiny so the tutorial runs quickly - confirm.\n",
+    "transformer_config = femr.models.config.FEMRTransformerConfig(\n",
+    "    vocab_size=tokenizer.vocab_size, \n",
+    "    is_hierarchical=tokenizer.is_hierarchical, \n",
+    "    n_layers=2,\n",
+    "    hidden_size=64, \n",
+    "    intermediate_size=64*2,\n",
+    "    n_heads=8,\n",
+    ")\n",
+    "\n",
+    "# Combine the transformer settings with the task settings of the prefit MOTOR task.\n",
+    "config = femr.models.config.FEMRModelConfig.from_transformer_task_configs(\n",
+    "    transformer_config, motor_task.get_task_config())\n",
+    "\n",
+    "model = femr.models.transformer.FEMRModel(config)\n",
+    "\n",
+    "# The batch processor's collate function is handed to the trainer as its data collator.\n",
+    "collator = processor.collate\n",
+    "\n",
+    "trainer_config = transformers.TrainingArguments(\n",
+    "    # Each row of train_batches appears to already be a complete batch (see convert_dataset), hence size 1.\n",
+    "    per_device_train_batch_size=1,\n",
+    "    per_device_eval_batch_size=1,\n",
+    "\n",
+    "    output_dir='tmp_trainer',\n",
+    "    # Keep every column of the prepared batches (presumably all are needed by the collator - confirm).\n",
+    "    remove_unused_columns=False,\n",
+    "    num_train_epochs=100,\n",
+    "\n",
+    "    # Evaluate and log every 20 steps.\n",
+    "    eval_steps=20,\n",
+    "    evaluation_strategy=\"steps\",\n",
+    "\n",
+    "    logging_steps=20,\n",
+    "    logging_strategy='steps',\n",
+    "\n",
+    "    prediction_loss_only=True,\n",
+    ")\n",
+    "\n",
+    "# train_batches['test'] holds the validation portion produced by the second hash split.\n",
+    "trainer = transformers.Trainer(\n",
+    "    model=model,\n",
+    "    data_collator=collator,\n",
+    "    train_dataset=train_batches['train'],\n",
+    "    eval_dataset=train_batches['test'],\n",
+    "    args=trainer_config,\n",
+    ")\n",
+    "\n",
+    "trainer.train()\n",
+    "\n",
+    "# Save the trained model next to the tokenizer and the split file.\n",
+    "model.save_pretrained(os.path.join(TARGET_DIR, 'motor_model'))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.14"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}