--- a
+++ b/experiments/Foresight MIMIC -- Train and Test -- Final.ipynb
@@ -0,0 +1,1239 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2bbeaebc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "#os.environ[\"WANDB_DISABLED\"] = \"true\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9c5ba9e3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "log = logging.getLogger()\n",
+    "log.handlers.clear()\n",
+    "log.addHandler(logging.StreamHandler())\n",
+    "log.setLevel(logging.WARNING)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "78bc409f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from foresight.datasets.data_collator import CollataAndPad\n",
+    "\n",
+    "from foresight.utils import pickle\n",
+    "from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer\n",
+    "from medcat.cdb import CDB\n",
+    "from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF\n",
+    "from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments\n",
+    "from medcat.cat import CAT\n",
+    "from foresight.models.lucid_transformers import LucidLM2HF\n",
+    "from transformers import SchedulerType\n",
+    "\n",
+    "from datasets import Dataset\n",
+    "import math\n",
+    "import datasets\n",
+    "import numpy as np\n",
+    "from torch.utils.data import DataLoader\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.optim as optim\n",
+    "import os\n",
+    "import shutil\n",
+    "import random\n",
+    "import pandas as pd"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "04d1872a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "DAYS = 1\n",
+    "MAX_SEQ_LEN = 256\n",
+    "TYPES = ['ALL_TYPES']\n",
+    "#TYPES = ['T-11']\n",
+    "#TYPES = ['T-11', 'T-18']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bd25aabc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "FROM_BASE = False\n",
+    "#BASE_TOKENIZER_PATH = f\"./data/time/models/gpt/tokenizer_annotations_stream_phase2_v1_1d_256_ALL_TYPES_v7.pickle\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7bd1eb81",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "USE_POSITION_IDS = True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b4861809",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SMALL_TEST_SIZE = 1000"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0e9b04ae",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "BASE_NAME = 'annotated_february_2022'\n",
+    "DATASET_NAME = 'annotations_stream_phase2_v1'\n",
+    "RUN_NAME = f'{DATASET_NAME}_{DAYS}d_{MAX_SEQ_LEN}_{\"_\".join(TYPES)}'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d08660d5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "TOKENIZER_PATH = f\"./data/timecat/models/gpt/tokenizer_{RUN_NAME}.pickle\"\n",
+    "PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_prepared_split/\"\n",
+    "MODEL_PATH = f\"./data/timecat/models/gpt-phase3-{RUN_NAME}-Positions-{USE_POSITION_IDS}-fromBase-{FROM_BASE}-old-test/\"\n",
+    "RESULTS_HYPERPARAM = \"./data/timecat/models/gpt/results/\"\n",
+    "CAT_PATH = \"./data/models/modelpacks/mc_modelpack_phase2_snomed_190k_february_2022.zip\"\n",
+    "\n",
+    "DEVICE = torch.device('cuda')"
+   ]
+  },
+  {
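+   "cell_type": "markdown",
+   "id": "ad0c0001",
+   "metadata": {},
+   "source": [
+    "Quick sanity check (added sketch, not part of the original pipeline): confirm that the artefacts the path constants above point to actually exist before loading anything."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ad0c0002",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Added sketch: fail fast on missing artefacts; the paths are the constants defined above.\n",
+    "for _p in [TOKENIZER_PATH, PREPARED_DATASET_SPLIT_PATH, CAT_PATH]:\n",
+    "    print('OK     ' if os.path.exists(_p) else 'MISSING', _p)"
+   ]
+  },
+  {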
"cell_type": "markdown", + "id": "93bc3bfa", + "metadata": {}, + "source": [ + "# Load everything and prepare train/test set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53750c5d", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "cat = CAT.load_model_pack(CAT_PATH, meta_cat_config_dict={'general': {'device': 'cpu'}})\n", + "cdb = cat.cdb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17d627fa", + "metadata": {}, + "outputs": [], + "source": [ + "encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)\n", + "encoded_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f43d612c", + "metadata": {}, + "outputs": [], + "source": [ + "if FROM_BASE:\n", + " print(\"USING BASE\")\n", + " TOKENIZER_PATH = BASE_TOKENIZER_PATH\n", + "tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5f0e619", + "metadata": {}, + "outputs": [], + "source": [ + "collate_fn = CollataAndPad(max_seq_len=tokenizer.max_len + 1, pad_id=tokenizer.tkn2id['<PAD>'], \n", + " shift_labels=False,\n", + " use_position_ids=USE_POSITION_IDS,\n", + " use_token_type_ids=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5580622b", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_train = DataLoader(encoded_dataset['train'], batch_size=1000, shuffle=False, collate_fn=collate_fn)\n", + "dataset_test = DataLoader(encoded_dataset['test'], batch_size=1000, shuffle=False, collate_fn=collate_fn)" + ] + }, + { + "cell_type": "markdown", + "id": "622a99b9", + "metadata": {}, + "source": [ + "### Create a mini dataset for testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4083f3da", + "metadata": {}, + "outputs": [], + "source": [ + "if SMALL_TEST_SIZE:\n", + " random.seed(11)\n", + " inds = random.choices([i for i in range(len(encoded_dataset['test']))], k=SMALL_TEST_SIZE)\n", + " encoded_dataset_test_mini = Dataset.from_dict(encoded_dataset['test'][inds])\n", + " dataset_test_mini = DataLoader(encoded_dataset_test_mini, batch_size=1000, shuffle=False, collate_fn=collate_fn)\n", + "else:\n", + " encoded_dataset_test_mini = encoded_dataset['test']" + ] + }, + { + "cell_type": "markdown", + "id": "1a22d32c", + "metadata": {}, + "source": [ + "# Create GPT2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efed3e18", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Load existing if you want, skip all other cells in this section if YES\n", + "model = GPT2LMHeadModel.from_pretrained('./data/timecat/models/gpt/gpt-phase2-annotations_stream_phase2_v1_1d_256_ALL_TYPES-Positions-False-fromBase-False-old-test/')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58cc7f79", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Make a new model\n", + "config = GPT2Config(\n", + " vocab_size=len(tokenizer.embeddings),\n", + " n_positions=tokenizer.max_len+1,\n", + " n_ctx=tokenizer.max_len+1,\n", + " n_embd=512,\n", + " n_layer=16,\n", + " n_head=16,\n", + " bos_token_id=tokenizer.tkn2id['<PAD>'],\n", + " eos_token_id=tokenizer.tkn2id['<PAD>']\n", + ")\n", + "model = GPT2LMHeadModel(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb8edfca", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "#model.transformer.wte.load_state_dict({'weight': 
torch.tensor(tokenizer.embeddings, dtype=torch.float32)})\n",
+    "#model.transformer.wte.weight.requires_grad = True"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "82dffb5f",
+   "metadata": {},
+   "source": [
+    "# Lucid GPT"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7d644dc7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Make a new model\n",
+    "config = GPT2Config(\n",
+    "    vocab_size=len(tokenizer.embeddings),\n",
+    "    n_positions=tokenizer.max_len+1,\n",
+    "    n_ctx=tokenizer.max_len+1,\n",
+    "    n_embd=512,\n",
+    "    n_layer=16,\n",
+    "    n_head=16,\n",
+    "    bos_token_id=tokenizer.tkn2id['<PAD>'],\n",
+    "    eos_token_id=tokenizer.tkn2id['<PAD>']\n",
+    ")\n",
+    "\n",
+    "addl_decoder_config = {\n",
+    "    'rotary_pos_emb': True,\n",
+    "#    'ff_glu': True,\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "521f6f82",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = LucidLM2HF(config, addl_decoder_config=addl_decoder_config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b56ee1ea",
+   "metadata": {},
+   "source": [
+    "# Trainer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ca720485",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_set_to_use = encoded_dataset_test_mini  # this is the full test set when SMALL_TEST_SIZE is not set (see above)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "876c3fcc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "all_types = set(['T-11', 'T-45', 'T-55', 'T-18', 'T-26', 'T-40', 'T-39', 'T-49', 'T-29', 'T-34',\n",
+    "                 'T-9', 'T-33', 'T-44', 'T-6', 'T-27', 'T-38', 'T-35', 'T-3', 'T-58'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "28e17497",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "compute_metrics = ComputePrecisionHF(tokenizer.id2tkn,\n",
+    "                                     prediction_scope='time_range',\n",
+    "                                     topk=1,\n",
+    "                                     start=0,\n",
+    "                                     return_all_metrics=False,\n",
+    "                                     batch_size=1000,\n",
+    "                                     select_token_types=all_types,\n",
+    "                                     type_data=test_set_to_use['token_type'],\n",
+    "                                     token_type2tokens=tokenizer.token_type2tokens,\n",
+    "                                     time_data=test_set_to_use['time'],\n",
+    "                                     time_range=30*24*60*60,\n",
+    "                                     ignore_label_status=False,\n",
+    "                                     min_time_left=24*60*60)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "974a735d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "training_args = TrainingArguments(\n",
+    "    output_dir='./gpt-16-16_1day_no_base_data',  # output directory\n",
+    "    num_train_epochs=10,                         # total number of training epochs\n",
+    "    per_device_train_batch_size=4,               # batch size per device during training\n",
+    "    per_device_eval_batch_size=4,                # batch size for evaluation\n",
+    "    weight_decay=1e-2,                           # strength of weight decay\n",
+    "    logging_dir='./logs',                        # directory for storing logs\n",
+    "    warmup_ratio=0.01,\n",
+    "    learning_rate=3.14e-04,\n",
+    "    eval_accumulation_steps=1,\n",
+    "    gradient_accumulation_steps=16,              # effective train batch size = 4 * 16 = 64 per device\n",
+    "    do_eval=True,\n",
+    "    evaluation_strategy='epoch',\n",
+    "    save_strategy='epoch',\n",
+    "    metric_for_best_model='eval_precision',\n",
+    "    load_best_model_at_end=True,\n",
+    "    lr_scheduler_type=SchedulerType.LINEAR\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1e25889f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import wandb"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7c2b1847",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
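+    "# Assumes you are logged in to Weights & Biases and that the 'timecat' project\n",
+    "# under the 'wish' entity exists; uncomment the WANDB_DISABLED line in the first\n",
+    "# cell to skip logging.\n",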
"wandb.init(project='timecat', entity='wish', name=RUN_NAME + '-gpt-16-16_1day_no_base_data')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7f0c996", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model, # the instantiated 🤗 Transformers model to be trained\n", + " args=training_args, # training arguments, defined above\n", + " train_dataset=encoded_dataset['train'], # training dataset\n", + " eval_dataset=test_set_to_use, # evaluation dataset\n", + " compute_metrics=compute_metrics,\n", + " data_collator=collate_fn,\n", + " tokenizer=None,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1634717d", + "metadata": {}, + "source": [ + "#### Make sure stuff is correct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b2c2162", + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "236c2c6e", + "metadata": {}, + "outputs": [], + "source": [ + "ind = 1117" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02fa347e", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "for ty, p, t, c, ind_id in zip(encoded_dataset['train'][ind]['token_type'], encoded_dataset['train'][ind]['position_ids'], encoded_dataset['train'][ind]['time'], tokenizer.convert_ids2tokens(encoded_dataset['train'][ind]['input_ids']), encoded_dataset['train'][ind]['input_ids']):\n", + " print(datetime.fromtimestamp(t), p, \"{:20}\".format(ty), c, ind_id)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0af3ff2", + "metadata": {}, + "outputs": [], + "source": [ + "encoded_dataset['train'][ind]['patient_id']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff4910d4", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PATH" + ] + }, + { + "cell_type": "markdown", + "id": "d756cf7d", + "metadata": {}, + "source": [ + "# Run training " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f3c7ed86", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "647efb10", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_model(MODEL_PATH)" + ] + }, + { + "cell_type": "markdown", + "id": "967599dd", + "metadata": {}, + "source": [ + "# Test" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38cc2acd", + "metadata": {}, + "outputs": [], + "source": [ + "all_types = set(['T-11', 'T-45', 'T-55', 'T-18', 'T-26', 'T-40', 'T-39', 'T-49', 'T-29', 'T-34', \n", + " 'T-9', 'T-33', 'T-44', 'T-6', 'T-27', 'T-38', 'T-35', 'T-3', 'T-58'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a67dd7d7", + "metadata": {}, + "outputs": [], + "source": [ + "test_set_to_use = encoded_dataset['test']\n", + "test_set_to_use" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dadd7b0d", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model, # the instantiated 🤗 Transformers model to be trained\n", + " args=training_args, # training arguments, defined above\n", + " train_dataset=None, # training dataset\n", + " eval_dataset=None, # evaluation dataset\n", + " compute_metrics=None,\n", + " data_collator=collate_fn,\n", + " tokenizer=None,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"152ec628", + "metadata": {}, + "outputs": [], + "source": [ + "def get_metrics(metrics_data=None, test_set_to_use=None, trainer=None, m_file=None, f_name=None):\n", + " size = 1000\n", + " for i in range(int(math.ceil(len(test_set_to_use) / size))):\n", + " _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])\n", + " compute_metrics.time_data = _dataset['time']\n", + " compute_metrics.type_data = _dataset['token_type']\n", + " if len(_dataset):\n", + " p = trainer.predict(_dataset)\n", + " metrics_data = compute_metrics(p, metrics_data)['metrics_data']\n", + " m_file.write(\"{}, {}, {}, {}\\n\".format(f_name, metrics_data['precision']['all'], \n", + " metrics_data['precision']['new'], \n", + " metrics_data['precision']['old'],\n", + " metrics_data['recall']['all'],\n", + " metrics_data['recall']['new'],\n", + " metrics_data['recall']['old']))\n", + " print(f_name,\n", + " metrics_data['precision']['all'], \n", + " metrics_data['precision']['new'], \n", + " metrics_data['precision']['old'],\n", + " metrics_data['recall']['all'],\n", + " metrics_data['recall']['new'],\n", + " metrics_data['recall']['old']) \n", + " pickle.dump(metrics_data, f_name)\n", + "\n", + " return metrics_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4aee33b", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "m_file = open(\"./metrics/summary.txt\", 'w', buffering=1)\n", + "m_file.write(\"file_name, precision all, precision new, precision old\\n\")\n", + "\n", + "for types in [all_types, {'T-11'}, {'T-55'}, {'T-18'}, {'T-39'}]:\n", + " _types = list(types)[0] if len(types) == 1 else 'all_types'\n", + " for timerange in [30, 365, 1000000]:\n", + " compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n", + " prediction_scope='time_range', \n", + " topk=1, # 1, 5, 10\n", + " start=0, # 0, 10, 20, 50, 100\n", + " return_all_metrics=True, \n", + " batch_size=1000, \n", + " select_token_types=types,\n", + " type_data=test_set_to_use['token_type'],\n", + " token_type2tokens=tokenizer.token_type2tokens,\n", + " time_data=test_set_to_use['time'], \n", + " time_range=timerange*24*60*60, #30, 365, 1000000\n", + " ignore_label_status=False,\n", + " min_time_left=24*60*60)\n", + " f_name = f\"./metrics/start-0_topk-1_time_range-{timerange}_types-{_types}.pickle\"\n", + " get_metrics(None, test_set_to_use, trainer, m_file, f_name)\n", + "\n", + " for topk in [5, 10]:\n", + " compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n", + " prediction_scope='time_range', \n", + " topk=topk, # 1, 5, 10\n", + " start=0, # 0, 10, 20, 50, 100\n", + " return_all_metrics=True, \n", + " batch_size=1000, \n", + " select_token_types=types,\n", + " type_data=test_set_to_use['token_type'],\n", + " token_type2tokens=tokenizer.token_type2tokens,\n", + " time_data=test_set_to_use['time'], \n", + " time_range=30*24*60*60, #30, 365, 1000000\n", + " ignore_label_status=False,\n", + " min_time_left=24*60*60)\n", + " f_name = f\"./metrics/start-0_topk-{topk}_time_range-30_types-{_types}.pickle\"\n", + " get_metrics(None, test_set_to_use, trainer, m_file, f_name)\n", + "m_file.close()" + ] + }, + { + "cell_type": "markdown", + "id": "aa75db95", + "metadata": {}, + "source": [ + "# Test Death" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "daad35f1", + "metadata": {}, + "outputs": [], + "source": [ + "all_types = set(['death'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb93c9e7", + "metadata": {}, + "outputs": [], + "source": [ 
+ "test_set_to_use" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d20a61c", + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model, # the instantiated 🤗 Transformers model to be trained\n", + " args=training_args, # training arguments, defined above\n", + " train_dataset=None, # training dataset\n", + " eval_dataset=None, # evaluation dataset\n", + " compute_metrics=None,\n", + " data_collator=collate_fn,\n", + " tokenizer=None,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebabb840", + "metadata": {}, + "outputs": [], + "source": [ + "def get_metrics(metrics_data=None, test_set_to_use=None, trainer=None, m_file=None, f_name=None):\n", + " size = 1000\n", + " for i in range(int(math.ceil(len(test_set_to_use) / size))):\n", + " _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])\n", + " compute_metrics.time_data = _dataset['time']\n", + " compute_metrics.type_data = _dataset['token_type']\n", + " if len(_dataset):\n", + " p = trainer.predict(_dataset)\n", + " metrics_data = compute_metrics(p, metrics_data)['metrics_data']\n", + " m_file.write(\"{}, {}, {}, {}\\n\".format(f_name, metrics_data['precision']['all'], \n", + " metrics_data['precision']['new'], \n", + " metrics_data['precision']['old']))\n", + " print(f_name,\n", + " metrics_data['precision']['all'], \n", + " metrics_data['precision']['new'], \n", + " metrics_data['precision']['old'])\n", + " pickle.dump(metrics_data, f_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "202e1dc1", + "metadata": {}, + "outputs": [], + "source": [ + "compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n", + " topk=1, # 1, 5, 10\n", + " start=0, # 0, 10, 20, 50, 100\n", + " return_all_metrics=True, \n", + " batch_size=1000, \n", + " type_data=test_set_to_use['token_type'],\n", + " token_type2tokens=tokenizer.token_type2tokens,\n", + " time_data=test_set_to_use['time'], \n", + " time_range=24*60*60, #30, 365, 1000000\n", + " ignore_label_status=False,\n", + " min_time_left=0,\n", + " concept_id=270)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71282b28", + "metadata": {}, + "outputs": [], + "source": [ + "metrics_data = None\n", + "_dataset = Dataset.from_dict(test_set_to_use[0:1000])\n", + "compute_metrics.time_data = _dataset['time']\n", + "compute_metrics.type_data = _dataset['token_type']\n", + "if len(_dataset):\n", + " p = trainer.predict(_dataset)\n", + " metrics_data = compute_metrics(p, metrics_data)['metrics_data']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7b85a03", + "metadata": {}, + "outputs": [], + "source": [ + "metrics_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51c716a0", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer.tkn2id['The patient has died']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69a9a505", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(len(_dataset['input_ids'])):\n", + " if 270 in _dataset['input_ids'][i]:\n", + " print(i)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24bb8e62", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "metrics_data = None\n", + "size = 1000\n", + "for i in range(int(math.ceil(len(test_set_to_use) / size))):\n", + " _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])\n", + " compute_metrics.time_data = _dataset['time']\n", + " 
compute_metrics.type_data = _dataset['token_type']\n",
+    "    if len(_dataset):\n",
+    "        p = trainer.predict(_dataset)\n",
+    "        metrics_data = compute_metrics(p, metrics_data)['metrics_data']"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "055f6a43",
+   "metadata": {},
+   "source": [
+    "# Hyperparameter search"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fe6a39e1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ray.tune.schedulers import PopulationBasedTraining\n",
+    "from ray import tune\n",
+    "from ray.tune import CLIReporter\n",
+    "import ray"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5105ae66",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NOTE: this cell predates the current ComputePrecisionHF signature; `id2type`\n",
+    "# (an id -> token-type map) is not defined in this notebook and must be supplied.\n",
+    "compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, id2type, prediction_scope='age', topk=1, start=0, batch_size=2000)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ea5f2956",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "NUM_TRIALS = 20\n",
+    "N_GPU_PER_TRIAL = 1\n",
+    "METRIC_TO_OPTIMIZE = 'eval_precision'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "27783bf0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_model(params):\n",
+    "    torch.cuda.empty_cache()\n",
+    "    if params is None:\n",
+    "        params = {}\n",
+    "\n",
+    "    config = GPT2Config(\n",
+    "        vocab_size=len(tokenizer.embeddings),\n",
+    "        n_positions=MAX_SEQ_LEN+1,\n",
+    "        n_ctx=MAX_SEQ_LEN+1,\n",
+    "        n_embd=params.get('n_embd', 300),\n",
+    "        n_layer=params.get('n_layer', 1),\n",
+    "        n_head=params.get('n_head', 1),\n",
+    "        bos_token_id=tokenizer.tkn2id['<PAD>'],\n",
+    "        eos_token_id=tokenizer.tkn2id['<PAD>']\n",
+    "    )\n",
+    "    model = GPT2LMHeadModel(config)\n",
+    "\n",
+    "    if params.get('load_weights', 0):\n",
+    "        model.transformer.wte.load_state_dict({'weight': torch.tensor(tokenizer.embeddings, dtype=torch.float32)})\n",
+    "        model.transformer.wte.weight.requires_grad = True\n",
+    "\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "16e4a39d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "training_args = TrainingArguments(\n",
+    "    output_dir='./results',          # output directory\n",
+    "    num_train_epochs=5,              # total number of training epochs\n",
+    "    per_device_train_batch_size=16,  # batch size per device during training\n",
+    "    per_device_eval_batch_size=128,  # batch size for evaluation\n",
+    "    weight_decay=0.01,               # strength of weight decay\n",
+    "    logging_dir='./logs',            # directory for storing logs\n",
+    "    logging_steps=200,\n",
+    "    eval_steps=200,\n",
+    "    learning_rate=5e-5,\n",
+    "    eval_accumulation_steps=1,\n",
+    "    do_eval=True,\n",
+    "    evaluation_strategy='steps',\n",
+    "    skip_memory_metrics=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0edc2eb1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Defaults for the custom hyperparameters read by get_model via params.get\n",
+    "training_args.n_head = 1\n",
+    "training_args.n_layer = 1\n",
+    "training_args.n_embd = 300\n",
+    "training_args.load_weights = 0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e46e8cb6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tune_dataset = encoded_dataset['train'].train_test_split(test_size=0.1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2f111f0a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tune_train_dataset = tune_dataset['train']\n",
+    "tune_test_dataset = tune_dataset['test']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f085ab71",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = Trainer(\n",
+    "# 
model=model, # the instantiated 🤗 Transformers model to be trained\n", + " args=training_args, # training arguments, defined above\n", + " train_dataset=tune_train_dataset, # training dataset\n", + " eval_dataset=tune_test_dataset, # evaluation dataset\n", + " compute_metrics=compute_metrics,\n", + " data_collator=collate_fn,\n", + " tokenizer=None,\n", + " model_init=get_model,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0356cdc", + "metadata": {}, + "outputs": [], + "source": [ + "tune_config = {\n", + " \"num_train_epochs\": tune.choice([5]),\n", + " \"n_head\": tune.choice([2, 4, 6]),\n", + "}\n", + "scheduler = PopulationBasedTraining(\n", + " time_attr=\"training_iteration\",\n", + " metric=METRIC_TO_OPTIMIZE,\n", + " mode=\"max\",\n", + " perturbation_interval=1,\n", + " hyperparam_mutations={\n", + " \"weight_decay\": tune.uniform(0.0, 0.3),\n", + " \"learning_rate\": tune.uniform(1e-5, 5e-5),\n", + " \"per_device_train_batch_size\": [16, 32, 64, 128],\n", + " \"n_layer\": tune.choice([2, 4, 6, 8]),\n", + "# \"n_embd\": tune.choice([256, 512]),\n", + " \"load_weights\": tune.choice([0, 1]),\n", + " \"warmup_steps\": tune.choice([20, 40, 60, 100]),\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63f3fd7d", + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "def compute_objective(metrics):\n", + " metrics = copy.deepcopy(metrics)\n", + " eval_precision = metrics.pop('eval_precision')\n", + " \n", + " return eval_precision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23e5bff1", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "best_model = trainer.hyperparameter_search(\n", + " hp_space=lambda _: tune_config,\n", + " backend=\"ray\",\n", + " n_trials=NUM_TRIALS,\n", + " direction='maximize',\n", + " compute_objective=compute_objective,\n", + " resources_per_trial={\n", + " \"cpu\": 1,\n", + " \"gpu\": N_GPU_PER_TRIAL\n", + " },\n", + " scheduler=scheduler,\n", + " keep_checkpoints_num=1,\n", + " checkpoint_score_attr=METRIC_TO_OPTIMIZE,\n", + " stop=None,\n", + " local_dir=RESULTS_HYPERPARAM,\n", + " name=\"21_May_2021\",\n", + " log_to_file=False,\n", + " loggers=None,# (WandbLogger, ),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edf5a0ab", + "metadata": {}, + "outputs": [], + "source": [ + "best_model" + ] + }, + { + "cell_type": "markdown", + "id": "9db3fae5", + "metadata": {}, + "source": [ + "# Saliency " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "364f1896", + "metadata": {}, + "outputs": [], + "source": [ + "import ecco" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b237de8", + "metadata": {}, + "outputs": [], + "source": [ + "lm = ecco.LM(trainer.model, tokenizer, model_name='gpt2')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92ad3750", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "ind = 49\n", + "print(\"~~\".join([tokenizer.id2tkn[id] for id in encoded_dataset['test'][ind]['input_ids']]))\n", + "text = \"~~\".join([tokenizer.id2tkn[id] for id in encoded_dataset['test'][ind]['input_ids'][1:-1]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28108d90", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "output = lm.generate(text, generate=10, do_sample=True, temperature=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
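"id": "ad0c0005",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Added sketch: map the '~~'-joined prompt back to readable names via the MedCAT\n",
+    "# CDB loaded earlier (the same decoding idea is used in the last section below).\n",
+    "for _t in text.split('~~'):\n",
+    "    print(_t, '->', cdb.get_name(_t))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,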
"id": "7b1b3898", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "output.saliency(style=\"detailed\")" + ] + }, + { + "cell_type": "markdown", + "id": "3dcb6cf5", + "metadata": {}, + "source": [ + "# Probability prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09f0fbb7", + "metadata": {}, + "outputs": [], + "source": [ + "from foresight.sight import Sight" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4b0662d", + "metadata": {}, + "outputs": [], + "source": [ + "_ = model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6d741d6", + "metadata": {}, + "outputs": [], + "source": [ + "sight = Sight(tokenizer=tokenizer, device='cuda', model=model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a1b9bfec", + "metadata": {}, + "outputs": [], + "source": [ + "cdb.name2cuis['muscle~pain']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f92d2766", + "metadata": {}, + "outputs": [], + "source": [ + "cdb.get_name('pain')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b545e844", + "metadata": {}, + "outputs": [], + "source": [ + "text = '<ETHNICITY>~~White~~<SEX>~~Male~~<AGE>~~23~~49727002~~386661006'.split(\"~~\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a7fc62e", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Small with WD\n", + "r = sight.next_concepts(text, type_ids=['T-11'], n=40, p_new=True, create_position_ids=False)\n", + "print([cdb.get_name(x) for x in text])\n", + "for x in r:\n", + " print(x[0], x[1], cdb.get_name(x[0]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}