--- /dev/null
+++ b/experiments/Foresight MIMIC -- Train and Test -- Final.ipynb
@@ -0,0 +1,1239 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2bbeaebc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "#os.environ[\"WANDB_DISABLED\"] = \"true\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9c5ba9e3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "log = logging.getLogger()\n",
+    "log.handlers.clear()\n",
+    "log.addHandler(logging.StreamHandler())\n",
+    "log.setLevel(logging.WARNING)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "78bc409f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from foresight.datasets.data_collator import CollataAndPad\n",
+    "\n",
+    "from foresight.utils import pickle\n",
+    "from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer\n",
+    "from medcat.cdb import CDB\n",
+    "from foresight.datasets.data_collator import CollataAndPad\n",
+    "from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF\n",
+    "from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments\n",
+    "from medcat.cat import CAT\n",
+    "from foresight.models.lucid_transformers import LucidLM2HF\n",
+    "from transformers import SchedulerType\n",
+    "\n",
+    "from datasets import Dataset\n",
+    "import math\n",
+    "import datasets\n",
+    "import numpy as np\n",
+    "from torch.utils.data import DataLoader\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.optim as optim\n",
+    "import os\n",
+    "import shutil\n",
+    "import random\n",
+    "import pandas as pd"
+   ]
+  },
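+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a3f19c0d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional: fix the global seeds for a more reproducible run (not part of the\n",
+    "# original experiment; the seed value is arbitrary).\n",
+    "SEED = 11\n",
+    "random.seed(SEED)\n",
+    "np.random.seed(SEED)\n",
+    "torch.manual_seed(SEED)"
+   ]
+  },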
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "04d1872a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "DAYS = 1\n",
+    "MAX_SEQ_LEN = 256\n",
+    "TYPES = ['ALL_TYPES']\n",
+    "#TYPES = ['T-11']\n",
+    "#TYPES = ['T-11', 'T-18']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bd25aabc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "FROM_BASE = False\n",
+    "#BASE_TOKENIZER_PATH = f\"./data/time/models/gpt/tokenizer_annotations_stream_phase2_v1_1d_256_ALL_TYPES_v7.pickle\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7bd1eb81",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "USE_POSITION_IDS = True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b4861809",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SMALL_TEST_SIZE = 1000"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0e9b04ae",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "BASE_NAME = 'annotated_february_2022'\n",
+    "DATASET_NAME = 'annotations_stream_phase2_v1'\n",
+    "RUN_NAME = f'{DATASET_NAME}_{DAYS}d_{MAX_SEQ_LEN}_{\"_\".join(TYPES)}'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d08660d5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "TOKENIZER_PATH = f\"./data/timecat/models/gpt/tokenizer_{RUN_NAME}.pickle\"\n",
+    "PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_prepared_split/\"\n",
+    "MODEL_PATH = f\"./data/timecat/models/gpt-phase3-{RUN_NAME}-Positions-{USE_POSITION_IDS}-fromBase-{FROM_BASE}-old-test/\"\n",
+    "RESULTS_HYPERPARAM = \"./data/timecat/models/gpt/results/\"\n",
+    "CAT_PATH = \"./data/models/modelpacks/mc_modelpack_phase2_snomed_190k_february_2022.zip\"\n",
+    "\n",
+    "DEVICE = torch.device('cuda')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "93bc3bfa",
+   "metadata": {},
+   "source": [
+    "# Load everything and prepare train/test set"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "53750c5d",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "cat = CAT.load_model_pack(CAT_PATH, meta_cat_config_dict={'general': {'device': 'cpu'}})\n",
+    "cdb = cat.cdb"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "17d627fa",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)\n",
+    "encoded_dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f43d612c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if FROM_BASE:\n",
+    "    print(\"USING BASE\")\n",
+    "    TOKENIZER_PATH = BASE_TOKENIZER_PATH\n",
+    "tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e5f0e619",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "collate_fn = CollataAndPad(max_seq_len=tokenizer.max_len + 1, pad_id=tokenizer.tkn2id['<PAD>'], \n",
+    "                           shift_labels=False,\n",
+    "                           use_position_ids=USE_POSITION_IDS,\n",
+    "                           use_token_type_ids=False)"
+   ]
+  },
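+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b7e24d1a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check (an assumption-level sketch, not required for training):\n",
+    "# the HF Trainer expects the collator to return a dict of tensors, so peeking at a\n",
+    "# tiny batch shows the padded shapes. Exact field names depend on the CollataAndPad\n",
+    "# flags set above.\n",
+    "batch = collate_fn([encoded_dataset['train'][0], encoded_dataset['train'][1]])\n",
+    "print({k: getattr(v, 'shape', type(v)) for k, v in batch.items()})"
+   ]
+  },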
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5580622b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset_train = DataLoader(encoded_dataset['train'], batch_size=1000, shuffle=False, collate_fn=collate_fn)\n",
+    "dataset_test = DataLoader(encoded_dataset['test'], batch_size=1000, shuffle=False, collate_fn=collate_fn)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "622a99b9",
+   "metadata": {},
+   "source": [
+    "### Create a mini dataset for testing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4083f3da",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if SMALL_TEST_SIZE:\n",
+    "    random.seed(11)\n",
+    "    inds = random.choices([i for i in range(len(encoded_dataset['test']))], k=SMALL_TEST_SIZE)\n",
+    "    encoded_dataset_test_mini = Dataset.from_dict(encoded_dataset['test'][inds])\n",
+    "    dataset_test_mini = DataLoader(encoded_dataset_test_mini, batch_size=1000, shuffle=False, collate_fn=collate_fn)\n",
+    "else:\n",
+    "    encoded_dataset_test_mini = encoded_dataset['test']"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1a22d32c",
+   "metadata": {},
+   "source": [
+    "# Create GPT2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "efed3e18",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# Load existing if you want, skip all other cells in this section if YES\n",
+    "model = GPT2LMHeadModel.from_pretrained('./data/timecat/models/gpt/gpt-phase2-annotations_stream_phase2_v1_1d_256_ALL_TYPES-Positions-False-fromBase-False-old-test/')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "58cc7f79",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# Make a new model\n",
+    "config = GPT2Config(\n",
+    "    vocab_size=len(tokenizer.embeddings),\n",
+    "    n_positions=tokenizer.max_len+1,\n",
+    "    n_ctx=tokenizer.max_len+1,\n",
+    "    n_embd=512,\n",
+    "    n_layer=16,\n",
+    "    n_head=16,\n",
+    "    bos_token_id=tokenizer.tkn2id['<PAD>'],\n",
+    "    eos_token_id=tokenizer.tkn2id['<PAD>']\n",
+    ")\n",
+    "model = GPT2LMHeadModel(config)"
+   ]
+  },
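+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c9d05e2b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Quick size check for the freshly built model.\n",
+    "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+    "print(f\"{n_params / 1e6:.1f}M trainable parameters\")"
+   ]
+  },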
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fb8edfca",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "#model.transformer.wte.load_state_dict({'weight': torch.tensor(tokenizer.embeddings, dtype=torch.float32)})\n",
+    "#model.transformer.wte.weight.requires_grad = True"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "82dffb5f",
+   "metadata": {},
+   "source": [
+    "# Lucid GPT"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7d644dc7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Make a new model\n",
+    "config = GPT2Config(\n",
+    "    vocab_size=len(tokenizer.embeddings),\n",
+    "    n_positions=tokenizer.max_len+1,\n",
+    "    n_ctx=tokenizer.max_len+1,\n",
+    "    n_embd=512,\n",
+    "    n_layer=16,\n",
+    "    n_head=16,\n",
+    "    bos_token_id=tokenizer.tkn2id['<PAD>'],\n",
+    "    eos_token_id=tokenizer.tkn2id['<PAD>']\n",
+    ")\n",
+    "\n",
+    "addl_decoder_config = {\n",
+    "    'rotary_pos_emb': True,\n",
+    "#    'ff_glu': True,\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "521f6f82",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = LucidLM2HF(config, addl_decoder_config=addl_decoder_config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b56ee1ea",
+   "metadata": {},
+   "source": [
+    "# Trainer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ca720485",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_set_to_use = encoded_dataset_test_mini # This will be automatically the whole test set if mini is not assigned"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "876c3fcc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "all_types = set(['T-11', 'T-45', 'T-55', 'T-18', 'T-26', 'T-40', 'T-39', 'T-49', 'T-29', 'T-34', \n",
+    "                 'T-9', 'T-33', 'T-44', 'T-6', 'T-27', 'T-38', 'T-35', 'T-3', 'T-58'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "28e17497",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n",
+    "                                     prediction_scope='time_range', \n",
+    "                                     topk=1, \n",
+    "                                     start=0, \n",
+    "                                     return_all_metrics=False, \n",
+    "                                     batch_size=1000, \n",
+    "                                     select_token_types=all_types,\n",
+    "                                     type_data=test_set_to_use['token_type'],\n",
+    "                                     token_type2tokens=tokenizer.token_type2tokens,\n",
+    "                                     time_data=test_set_to_use['time'], \n",
+    "                                     time_range=30*24*60*60,\n",
+    "                                     ignore_label_status=False,\n",
+    "                                     min_time_left=24*60*60)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "974a735d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "training_args = TrainingArguments(\n",
+    "    output_dir='./gpt-16-16_1day_no_base_data',          # output directory\n",
+    "    num_train_epochs=10,              # total number of training epochs\n",
+    "    per_device_train_batch_size=4,  # batch size per device during training\n",
+    "    per_device_eval_batch_size=4,   # batch size for evaluation\n",
+    "    weight_decay=1e-2,               # strength of weight decay\n",
+    "    logging_dir='./logs',            # directory for storing logs\n",
+    "    warmup_ratio=0.01,\n",
+    "    learning_rate= 3.14e-04,\n",
+    "    eval_accumulation_steps=1,\n",
+    "    gradient_accumulation_steps=16,\n",
+    "    do_eval=True,\n",
+    "    evaluation_strategy='epoch',\n",
+    "    save_strategy='epoch',\n",
+    "    metric_for_best_model='eval_precision',\n",
+    "    load_best_model_at_end=True,\n",
+    "    lr_scheduler_type=SchedulerType.LINEAR\n",
+    ")"
+   ]
+  },
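+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d1c36f3c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The effective train batch size per optimizer step is\n",
+    "# per_device_train_batch_size * gradient_accumulation_steps (per GPU), i.e. 4 * 16 = 64.\n",
+    "print(training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)"
+   ]
+  },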
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1e25889f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import wandb"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7c2b1847",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "wandb.init(project='timecat', entity='wish', name=RUN_NAME + '-gpt-16-16_1day_no_base_data')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e7f0c996",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "trainer = Trainer(\n",
+    "    model=model,                         # the instantiated 🤗 Transformers model to be trained\n",
+    "    args=training_args,                  # training arguments, defined above\n",
+    "    train_dataset=encoded_dataset['train'],         # training dataset\n",
+    "    eval_dataset=test_set_to_use,             # evaluation dataset\n",
+    "    compute_metrics=compute_metrics,\n",
+    "    data_collator=collate_fn,\n",
+    "    tokenizer=None,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1634717d",
+   "metadata": {},
+   "source": [
+    "#### Make sure stuff is correct"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7b2c2162",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datetime import datetime"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "236c2c6e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ind = 1117"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "02fa347e",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "for ty, p, t, c, ind_id in zip(encoded_dataset['train'][ind]['token_type'], encoded_dataset['train'][ind]['position_ids'], encoded_dataset['train'][ind]['time'], tokenizer.convert_ids2tokens(encoded_dataset['train'][ind]['input_ids']), encoded_dataset['train'][ind]['input_ids']):\n",
+    "    print(datetime.fromtimestamp(t), p, \"{:20}\".format(ty), c, ind_id)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b0af3ff2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoded_dataset['train'][ind]['patient_id']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ff4910d4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "MODEL_PATH"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d756cf7d",
+   "metadata": {},
+   "source": [
+    "# Run training "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f3c7ed86",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "trainer.train()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "647efb10",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer.save_model(MODEL_PATH)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "967599dd",
+   "metadata": {},
+   "source": [
+    "# Test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "38cc2acd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "all_types = set(['T-11', 'T-45', 'T-55', 'T-18', 'T-26', 'T-40', 'T-39', 'T-49', 'T-29', 'T-34', \n",
+    "                 'T-9', 'T-33', 'T-44', 'T-6', 'T-27', 'T-38', 'T-35', 'T-3', 'T-58'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a67dd7d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_set_to_use = encoded_dataset['test']\n",
+    "test_set_to_use"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dadd7b0d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = Trainer(\n",
+    "    model=model,                         # the instantiated 🤗 Transformers model to be trained\n",
+    "    args=training_args,                  # training arguments, defined above\n",
+    "    train_dataset=None,         # training dataset\n",
+    "    eval_dataset=None,             # evaluation dataset\n",
+    "    compute_metrics=None,\n",
+    "    data_collator=collate_fn,\n",
+    "    tokenizer=None,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "152ec628",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_metrics(metrics_data=None, test_set_to_use=None, trainer=None, m_file=None, f_name=None):\n",
+    "    size = 1000\n",
+    "    for i in range(int(math.ceil(len(test_set_to_use) / size))):\n",
+    "        _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])\n",
+    "        compute_metrics.time_data = _dataset['time']\n",
+    "        compute_metrics.type_data = _dataset['token_type']\n",
+    "        if len(_dataset):\n",
+    "            p = trainer.predict(_dataset)\n",
+    "            metrics_data = compute_metrics(p, metrics_data)['metrics_data']\n",
+    "    m_file.write(\"{}, {}, {}, {}\\n\".format(f_name, metrics_data['precision']['all'], \n",
+    "                                 metrics_data['precision']['new'], \n",
+    "                                 metrics_data['precision']['old'],\n",
+    "                                 metrics_data['recall']['all'],\n",
+    "                                 metrics_data['recall']['new'],\n",
+    "                                 metrics_data['recall']['old']))\n",
+    "    print(f_name,\n",
+    "          metrics_data['precision']['all'], \n",
+    "          metrics_data['precision']['new'], \n",
+    "          metrics_data['precision']['old'],\n",
+    "          metrics_data['recall']['all'],\n",
+    "          metrics_data['recall']['new'],\n",
+    "          metrics_data['recall']['old']) \n",
+    "    pickle.dump(metrics_data, f_name)\n",
+    "\n",
+    "    return metrics_data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a4aee33b",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "m_file = open(\"./metrics/summary.txt\", 'w', buffering=1)\n",
+    "m_file.write(\"file_name, precision all, precision new, precision old\\n\")\n",
+    "\n",
+    "for types in [all_types, {'T-11'}, {'T-55'}, {'T-18'}, {'T-39'}]:\n",
+    "    _types = list(types)[0] if len(types) == 1 else 'all_types'\n",
+    "    for timerange in [30, 365, 1000000]:\n",
+    "        compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n",
+    "                                         prediction_scope='time_range', \n",
+    "                                         topk=1, # 1, 5, 10\n",
+    "                                         start=0, # 0, 10, 20, 50, 100\n",
+    "                                         return_all_metrics=True, \n",
+    "                                         batch_size=1000, \n",
+    "                                         select_token_types=types,\n",
+    "                                         type_data=test_set_to_use['token_type'],\n",
+    "                                         token_type2tokens=tokenizer.token_type2tokens,\n",
+    "                                         time_data=test_set_to_use['time'], \n",
+    "                                         time_range=timerange*24*60*60, #30, 365, 1000000\n",
+    "                                         ignore_label_status=False,\n",
+    "                                         min_time_left=24*60*60)\n",
+    "        f_name = f\"./metrics/start-0_topk-1_time_range-{timerange}_types-{_types}.pickle\"\n",
+    "        get_metrics(None, test_set_to_use, trainer, m_file, f_name)\n",
+    "\n",
+    "    for topk in [5, 10]:\n",
+    "        compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n",
+    "                                         prediction_scope='time_range', \n",
+    "                                         topk=topk, # 1, 5, 10\n",
+    "                                         start=0, # 0, 10, 20, 50, 100\n",
+    "                                         return_all_metrics=True, \n",
+    "                                         batch_size=1000, \n",
+    "                                         select_token_types=types,\n",
+    "                                         type_data=test_set_to_use['token_type'],\n",
+    "                                         token_type2tokens=tokenizer.token_type2tokens,\n",
+    "                                         time_data=test_set_to_use['time'], \n",
+    "                                         time_range=30*24*60*60, #30, 365, 1000000\n",
+    "                                         ignore_label_status=False,\n",
+    "                                         min_time_left=24*60*60)\n",
+    "        f_name = f\"./metrics/start-0_topk-{topk}_time_range-30_types-{_types}.pickle\"\n",
+    "        get_metrics(None, test_set_to_use, trainer, m_file, f_name)\n",
+    "m_file.close()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "aa75db95",
+   "metadata": {},
+   "source": [
+    "# Test Death"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "daad35f1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "all_types = set(['death'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eb93c9e7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_set_to_use"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5d20a61c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = Trainer(\n",
+    "    model=model,                         # the instantiated 🤗 Transformers model to be trained\n",
+    "    args=training_args,                  # training arguments, defined above\n",
+    "    train_dataset=None,         # training dataset\n",
+    "    eval_dataset=None,             # evaluation dataset\n",
+    "    compute_metrics=None,\n",
+    "    data_collator=collate_fn,\n",
+    "    tokenizer=None,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ebabb840",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_metrics(metrics_data=None, test_set_to_use=None, trainer=None, m_file=None, f_name=None):\n",
+    "    size = 1000\n",
+    "    for i in range(int(math.ceil(len(test_set_to_use) / size))):\n",
+    "        _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])\n",
+    "        compute_metrics.time_data = _dataset['time']\n",
+    "        compute_metrics.type_data = _dataset['token_type']\n",
+    "        if len(_dataset):\n",
+    "            p = trainer.predict(_dataset)\n",
+    "            metrics_data = compute_metrics(p, metrics_data)['metrics_data']\n",
+    "    m_file.write(\"{}, {}, {}, {}\\n\".format(f_name, metrics_data['precision']['all'], \n",
+    "                                 metrics_data['precision']['new'], \n",
+    "                                 metrics_data['precision']['old']))\n",
+    "    print(f_name,\n",
+    "          metrics_data['precision']['all'], \n",
+    "          metrics_data['precision']['new'], \n",
+    "          metrics_data['precision']['old'])\n",
+    "    pickle.dump(metrics_data, f_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "202e1dc1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n",
+    "                                 topk=1, # 1, 5, 10\n",
+    "                                 start=0, # 0, 10, 20, 50, 100\n",
+    "                                 return_all_metrics=True, \n",
+    "                                 batch_size=1000, \n",
+    "                                 type_data=test_set_to_use['token_type'],\n",
+    "                                 token_type2tokens=tokenizer.token_type2tokens,\n",
+    "                                 time_data=test_set_to_use['time'], \n",
+    "                                 time_range=24*60*60, #30, 365, 1000000\n",
+    "                                 ignore_label_status=False,\n",
+    "                                 min_time_left=0,\n",
+    "                                 concept_id=270)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "71282b28",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metrics_data = None\n",
+    "_dataset = Dataset.from_dict(test_set_to_use[0:1000])\n",
+    "compute_metrics.time_data = _dataset['time']\n",
+    "compute_metrics.type_data = _dataset['token_type']\n",
+    "if len(_dataset):\n",
+    "    p = trainer.predict(_dataset)\n",
+    "    metrics_data = compute_metrics(p, metrics_data)['metrics_data']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d7b85a03",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metrics_data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "51c716a0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer.tkn2id['The patient has died']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "69a9a505",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for i in range(len(_dataset['input_ids'])):\n",
+    "    if 270 in _dataset['input_ids'][i]:\n",
+    "        print(i)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "24bb8e62",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "metrics_data = None\n",
+    "size = 1000\n",
+    "for i in range(int(math.ceil(len(test_set_to_use) / size))):\n",
+    "    _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])\n",
+    "    compute_metrics.time_data = _dataset['time']\n",
+    "    compute_metrics.type_data = _dataset['token_type']\n",
+    "    if len(_dataset):\n",
+    "        p = trainer.predict(_dataset)\n",
+    "        metrics_data = compute_metrics(p, metrics_data)['metrics_data']"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "055f6a43",
+   "metadata": {},
+   "source": [
+    "# Hyperparameter search"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fe6a39e1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ray.tune.schedulers import PopulationBasedTraining\n",
+    "from ray import tune\n",
+    "from ray.tune import CLIReporter \n",
+    "import ray"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5105ae66",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "compute_metrics = ComputePrecisionHF(id2tkn, id2type, prediction_scope='age', topk=1, start=0, batch_size=2000)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ea5f2956",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "NUM_TRIALS = 20\n",
+    "N_GPU_PER_TRIAL = 1\n",
+    "METRIC_TO_OPTIMIZE = 'eval_precision'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "27783bf0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_model(params):\n",
+    "    torch.cuda.empty_cache()\n",
+    "    if params is None:\n",
+    "        params = {}\n",
+    "    \n",
+    "    config = GPT2Config(\n",
+    "        vocab_size=len(embeddings),\n",
+    "        n_positions=MAX_SEQ_LEN+1,\n",
+    "        n_ctx=MAX_SEQ_LEN+1,\n",
+    "        n_embd=params.get('n_embd', 300),\n",
+    "        n_layer=params.get('n_layer', 1),\n",
+    "        n_head=params.get('n_head', 1),\n",
+    "        bos_token_id=tkn2id['<PAD>'],\n",
+    "        eos_token_id=tkn2id['<PAD>']\n",
+    "    )\n",
+    "    model = GPT2LMHeadModel(config)\n",
+    "    \n",
+    "    if params.get('load_weights', 0):\n",
+    "        model.transformer.wte.load_state_dict({'weight': torch.tensor(embeddings, dtype=torch.float32)})\n",
+    "        model.transformer.wte.weight.requires_grad = True\n",
+    "    \n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "16e4a39d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "training_args = TrainingArguments(\n",
+    "    output_dir='./results',          # output directory\n",
+    "    num_train_epochs=5,              # total number of training epochs\n",
+    "    per_device_train_batch_size=16,  # batch size per device during training\n",
+    "    per_device_eval_batch_size=128,   # batch size for evaluation\n",
+    "    weight_decay=0.01,               # strength of weight decay\n",
+    "    logging_dir='./logs',            # directory for storing logs\n",
+    "    logging_steps=200,\n",
+    "    eval_steps=200,\n",
+    "    learning_rate= 5e-5,\n",
+    "    eval_accumulation_steps=1,\n",
+    "    do_eval=True,\n",
+    "    evaluation_strategy='steps',\n",
+    "    skip_memory_metrics=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0edc2eb1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "training_args.n_head = 1\n",
+    "training_args.n_layer = 1\n",
+    "training_args.n_embd = 300\n",
+    "training_args.load_weights = 0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e46e8cb6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tune_dataset = encoded_dataset['train'].train_test_split(test_size=0.1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2f111f0a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tune_train_dataset = tune_dataset['train']\n",
+    "tune_test_dataset = tune_dataset['test']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f085ab71",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = Trainer(\n",
+    "#    model=model,                         # the instantiated 🤗 Transformers model to be trained\n",
+    "    args=training_args,                  # training arguments, defined above\n",
+    "    train_dataset=tune_train_dataset,         # training dataset\n",
+    "    eval_dataset=tune_test_dataset,             # evaluation dataset\n",
+    "    compute_metrics=compute_metrics,\n",
+    "    data_collator=collate_fn,\n",
+    "    tokenizer=None,\n",
+    "    model_init=get_model,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d0356cdc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tune_config = {\n",
+    "    \"num_train_epochs\": tune.choice([5]),\n",
+    "    \"n_head\": tune.choice([2, 4, 6]),\n",
+    "}\n",
+    "scheduler = PopulationBasedTraining(\n",
+    "    time_attr=\"training_iteration\",\n",
+    "    metric=METRIC_TO_OPTIMIZE,\n",
+    "    mode=\"max\",\n",
+    "    perturbation_interval=1,\n",
+    "    hyperparam_mutations={\n",
+    "        \"weight_decay\": tune.uniform(0.0, 0.3),\n",
+    "        \"learning_rate\": tune.uniform(1e-5, 5e-5),\n",
+    "        \"per_device_train_batch_size\": [16, 32, 64, 128],\n",
+    "        \"n_layer\": tune.choice([2, 4, 6, 8]),\n",
+    "#       \"n_embd\": tune.choice([256, 512]),\n",
+    "        \"load_weights\": tune.choice([0, 1]),\n",
+    "        \"warmup_steps\": tune.choice([20, 40, 60, 100]),\n",
+    "    })"
+   ]
+  },
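+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e5b47a4d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# How the PBT scheduler above behaves (generic Ray Tune mechanics, not\n",
+    "# project-specific): every `perturbation_interval` training iterations, weights of\n",
+    "# well-performing trials are cloned into under-performing ones and the values in\n",
+    "# `hyperparam_mutations` are perturbed, so learning_rate, weight_decay, etc. can\n",
+    "# change during a trial instead of staying fixed.\n",
+    "print(scheduler)"
+   ]
+  },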
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "63f3fd7d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import copy\n",
+    "def compute_objective(metrics):\n",
+    "    metrics = copy.deepcopy(metrics)\n",
+    "    eval_precision = metrics.pop('eval_precision')\n",
+    "    \n",
+    "    return eval_precision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "23e5bff1",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "best_model = trainer.hyperparameter_search(\n",
+    "    hp_space=lambda _: tune_config,\n",
+    "    backend=\"ray\",\n",
+    "    n_trials=NUM_TRIALS,\n",
+    "    direction='maximize',\n",
+    "    compute_objective=compute_objective,\n",
+    "    resources_per_trial={\n",
+    "        \"cpu\": 1,\n",
+    "        \"gpu\": N_GPU_PER_TRIAL\n",
+    "    },\n",
+    "    scheduler=scheduler,\n",
+    "    keep_checkpoints_num=1,\n",
+    "    checkpoint_score_attr=METRIC_TO_OPTIMIZE,\n",
+    "    stop=None,\n",
+    "    local_dir=RESULTS_HYPERPARAM,\n",
+    "    name=\"21_May_2021\",\n",
+    "    log_to_file=False,\n",
+    "    loggers=None,# (WandbLogger, ),\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "edf5a0ab",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "best_model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9db3fae5",
+   "metadata": {},
+   "source": [
+    "# Saliency "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "364f1896",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import ecco"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1b237de8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lm = ecco.LM(trainer.model, tokenizer, model_name='gpt2')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "92ad3750",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "ind = 49\n",
+    "print(\"~~\".join([tokenizer.id2tkn[id] for id in encoded_dataset['test'][ind]['input_ids']]))\n",
+    "text = \"~~\".join([tokenizer.id2tkn[id] for id in encoded_dataset['test'][ind]['input_ids'][1:-1]])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "28108d90",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "output = lm.generate(text, generate=10, do_sample=True, temperature=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7b1b3898",
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "output.saliency(style=\"detailed\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3dcb6cf5",
+   "metadata": {},
+   "source": [
+    "# Probability prediction"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "09f0fbb7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from foresight.sight import Sight"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a4b0662d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "_ = model.eval()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f6d741d6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sight = Sight(tokenizer=tokenizer, device='cuda', model=model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a1b9bfec",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cdb.name2cuis['muscle~pain']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f92d2766",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cdb.get_name('pain')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b545e844",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text = '<ETHNICITY>~~White~~<SEX>~~Male~~<AGE>~~23~~49727002~~386661006'.split(\"~~\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7a7fc62e",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# Small with WD\n",
+    "r = sight.next_concepts(text, type_ids=['T-11'], n=40, p_new=True, create_position_ids=False)\n",
+    "print([cdb.get_name(x) for x in text])\n",
+    "for x in r:\n",
+    "    print(x[0], x[1], cdb.get_name(x[0]))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}