{
"cells": [
{
"cell_type": "markdown",
"id": "1c81279e-a568-4e36-9906-06317accb622",
"metadata": {},
"source": [
"# Train CLMBR\n",
"\n",
"This tutorial walks through the various steps to train a CLMBR model.\n",
"\n",
"Training CLMBR is a three step process:\n",
"\n",
"- Training a tokenizer\n",
"- Preparing batches\n",
"- Training the model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7dcdfd70-58a1-4460-80a8-db737a8c5cd6",
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"import os\n",
"\n",
"TARGET_DIR = 'trash/tutorial_4'\n",
"\n",
"if os.path.exists(TARGET_DIR):\n",
" shutil.rmtree(TARGET_DIR)\n",
"\n",
"os.mkdir(TARGET_DIR)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "646f7590",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1181.95 examples/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 1, 2, 4, 6, 7, 10, 11, 12, 13, 14, 15, 18, 20, 21, 23, 24, 26, 27, 28, 29, 30, 31, 33, 36, 37, 38, 40, 42, 44, 45, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63, 64, 65, 66, 67, 69, 70, 73, 74, 75, 76, 77, 79, 80, 83, 85, 86, 88, 89, 90, 91, 93, 94, 95, 96, 97, 98, 100, 101, 102, 103, 104, 105, 107, 109, 110, 112, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 133, 134, 135, 136, 137, 139, 141, 142, 143, 144, 149, 150, 151, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 165, 166, 168, 169, 171, 172, 173, 174, 178, 181, 182, 183, 184, 185, 186, 187, 189, 192, 193, 195, 196, 197, 198, 199]\n",
"[19, 22, 25, 39, 46, 71, 82, 84, 87, 92, 106, 108, 113, 131, 132, 138, 146, 147, 148, 155, 177, 179, 180, 188, 190, 191]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1359.83 examples/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['patient_id', 'events'],\n",
" num_rows: 144\n",
" })\n",
" test: Dataset({\n",
" features: ['patient_id', 'events'],\n",
" num_rows: 26\n",
" })\n",
"})\n"
]
}
],
"source": [
"import datasets\n",
"import femr.index\n",
"import femr.splits\n",
"\n",
"# First, we want to split our dataset into train, valid, and test\n",
"# We do this by calling our split functionality twice\n",
"\n",
"dataset = datasets.Dataset.from_parquet('input/meds/data/*')\n",
"\n",
"\n",
"index = femr.index.PatientIndex(dataset, num_proc=4)\n",
"main_split = femr.splits.generate_hash_split(index.get_patient_ids(), 97, frac_test=0.15)\n",
"\n",
"\n",
"os.mkdir(os.path.join(TARGET_DIR, 'clmbr_model'))\n",
"# Note that we want to save this to the target directory since this is important information\n",
"\n",
"main_split.save_to_csv(os.path.join(TARGET_DIR, \"clmbr_model\", \"main_split.csv\"))\n",
"\n",
"train_split = femr.splits.generate_hash_split(main_split.train_patient_ids, 87, frac_test=0.15)\n",
"\n",
"print(train_split.train_patient_ids)\n",
"print(train_split.test_patient_ids)\n",
"\n",
"main_dataset = main_split.split_dataset(dataset, index)\n",
"train_dataset = train_split.split_dataset(main_dataset['train'], femr.index.PatientIndex(main_dataset['train'], num_proc=4))\n",
"\n",
"print(train_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f60ab7df-e851-44a5-ab70-7bee292be00c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1110.06 examples/s]\n"
]
}
],
"source": [
"import femr.models.tokenizer\n",
"\n",
"# First, we need to train a tokenizer\n",
"\n",
"# NOTE: A vocab size of 128 is probably too low for a real model. 128 was chosen to make this tutorial quick to run\n",
"tokenizer = femr.models.tokenizer.train_tokenizer(\n",
" main_dataset['train'], vocab_size=128, num_proc=4)\n",
"\n",
"# Save the tokenizer to the same directory as the model\n",
"tokenizer.save_pretrained(os.path.join(TARGET_DIR, \"clmbr_model\"))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "89611ba9-a242-4b87-9b8f-25670d838fc6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:00<00:00, 930.24 examples/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating batches 12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Generating train split: 12 examples [00:00, 58.74 examples/s]\n",
"Map (num_proc=4): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 162.43 examples/s]\n",
"Setting num_proc from 4 to 3 for the train split as it only contains 3 shards.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating batches 3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Generating train split: 3 examples [00:00, 17.22 examples/s]\n"
]
}
],
"source": [
"import femr.models.processor\n",
"import femr.models.tasks\n",
"\n",
"# Second, we need to create batches. We define the CLMBR task at this time\n",
"\n",
"clmbr_task = femr.models.tasks.CLMBRTask(clmbr_vocab_size=64)\n",
"\n",
"processor = femr.models.processor.FEMRBatchProcessor(tokenizer, clmbr_task)\n",
"\n",
"# We can do this one patient at a time\n",
"example_batch = processor.collate([processor.convert_patient(train_dataset['train'][0], tensor_type='pt')])\n",
"\n",
"# But generally we want to convert entire datasets\n",
"train_batches = processor.convert_dataset(train_dataset, tokens_per_batch=32, num_proc=4)\n",
"\n",
"# Convert our batches to pytorch tensors\n",
"train_batches.set_format(\"pt\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f654a46c-5aa7-465c-b6c5-73d8ba26ed67",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/torch/cuda/__init__.py:628: UserWarning: Can't initialize NVML\n",
" warnings.warn(\"Can't initialize NVML\")\n",
"/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
"dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
" warnings.warn(\n",
"Could not estimate the number of tokens of the input, floating-point operations will not be computed\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='240' max='240' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [240/240 00:03, Epoch 20/20]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>3.953800</td>\n",
" <td>3.833314</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40</td>\n",
" <td>3.626800</td>\n",
" <td>3.582646</td>\n",
" </tr>\n",
" <tr>\n",
" <td>60</td>\n",
" <td>3.392900</td>\n",
" <td>3.366862</td>\n",
" </tr>\n",
" <tr>\n",
" <td>80</td>\n",
" <td>3.171500</td>\n",
" <td>3.182760</td>\n",
" </tr>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>2.962200</td>\n",
" <td>3.030737</td>\n",
" </tr>\n",
" <tr>\n",
" <td>120</td>\n",
" <td>2.828600</td>\n",
" <td>2.907934</td>\n",
" </tr>\n",
" <tr>\n",
" <td>140</td>\n",
" <td>2.702800</td>\n",
" <td>2.809832</td>\n",
" </tr>\n",
" <tr>\n",
" <td>160</td>\n",
" <td>2.601000</td>\n",
" <td>2.735075</td>\n",
" </tr>\n",
" <tr>\n",
" <td>180</td>\n",
" <td>2.525200</td>\n",
" <td>2.680505</td>\n",
" </tr>\n",
" <tr>\n",
" <td>200</td>\n",
" <td>2.434700</td>\n",
" <td>2.642956</td>\n",
" </tr>\n",
" <tr>\n",
" <td>220</td>\n",
" <td>2.514500</td>\n",
" <td>2.621699</td>\n",
" </tr>\n",
" <tr>\n",
" <td>240</td>\n",
" <td>2.396900</td>\n",
" <td>2.614570</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import transformers\n",
"\n",
"import femr.models.transformer\n",
"\n",
"# Finally, given the batches, we can train CLMBR.\n",
"# We can use huggingface's trainer to do this.\n",
"\n",
"transformer_config = femr.models.config.FEMRTransformerConfig(\n",
" vocab_size=tokenizer.vocab_size, \n",
" is_hierarchical=tokenizer.is_hierarchical, \n",
" n_layers=2,\n",
" hidden_size=64, \n",
" intermediate_size=64*2,\n",
" n_heads=8,\n",
")\n",
"\n",
"config = femr.models.config.FEMRModelConfig.from_transformer_task_configs(transformer_config, clmbr_task.get_task_config())\n",
"\n",
"model = femr.models.transformer.FEMRModel(config)\n",
"\n",
"trainer_config = transformers.TrainingArguments(\n",
" per_device_train_batch_size=1,\n",
" per_device_eval_batch_size=1,\n",
"\n",
" output_dir='tmp_trainer',\n",
" remove_unused_columns=False,\n",
" num_train_epochs=20,\n",
"\n",
" eval_steps=20,\n",
" evaluation_strategy=\"steps\",\n",
"\n",
" logging_steps=20,\n",
" logging_strategy='steps',\n",
"\n",
" prediction_loss_only=True,\n",
")\n",
"\n",
"trainer = transformers.Trainer(\n",
" model=model,\n",
" data_collator=processor.collate,\n",
" train_dataset=train_batches['train'],\n",
" eval_dataset=train_batches['test'],\n",
" args=trainer_config,\n",
")\n",
"\n",
"trainer.train()\n",
"\n",
"model.save_pretrained(os.path.join(TARGET_DIR, 'clmbr_model'))"
]
}
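,
{
"cell_type": "markdown",
"id": "b7c0d1a2-3e45-4f67-89ab-cdef01234567",
"metadata": {},
"source": [
"The trained model and tokenizer are now saved under `trash/tutorial_4/clmbr_model`. As a quick sanity check, the sketch below reloads the model; it assumes `FEMRModel` follows the standard Hugging Face `from_pretrained` convention (which matches the `save_pretrained` call above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8d1e2f3-4a56-4b78-9cde-f01234567890",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of reloading the saved model for later use.\n",
"# Assumption: FEMRModel supports the standard Hugging Face from_pretrained convention,\n",
"# matching the save_pretrained call above.\n",
"reloaded_model = femr.models.transformer.FEMRModel.from_pretrained(os.path.join(TARGET_DIR, 'clmbr_model'))\n",
"\n",
"# The tokenizer was saved to the same directory earlier, so clmbr_model is self-contained\n",
"print(reloaded_model.config)"
]
}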
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}