--- a
+++ b/experiments/Foresight MIMIC Final -- Prepare data.ipynb
@@ -0,0 +1,1626 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "01ae6e4a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import os\n",
+    "import datasets\n",
+    "import numpy as np\n",
+    "from collections import defaultdict\n",
+    "from medcat.cat import CAT\n",
+    "from foresight.datasets import patient_concept_stream\n",
+    "from foresight.datasets.filters import filter_by_count, filter_by_type\n",
+    "from foresight.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \\\n",
+    "                                  remove_parents_from_stream, bucket_concepts, cleanup_stream, \\\n",
+    "                                  split_stream, add_age, get_all_splits, add_ttd, add_position_ids\n",
+    "from foresight.utils import pickle\n",
+    "from foresight.utils.cdb_utils import get_parents_map \n",
+    "from foresight.utils.stream_utils import docs2stream, calculate_counts\n",
+    "from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer\n",
+    "from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF\n",
+    "from medcat.cdb import CDB\n",
+    "from foresight.utils import pickle\n",
+    "import plotly.express as px"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9f57b67c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "DAYS = 1 # Do: 1, 14, 30\n",
+    "MAX_SEQ_LEN = 256\n",
+    "#TYPES = ['T-45', 'T-55', 'T-26', 'T-29', 'T-40', 'T-9', 'T-27', 'T-11', 'T-39', 'T-18']\n",
+    "TYPES = ['ALL_TYPES']\n",
+    "#TYPES = ['T-11']\n",
+    "#TYPES = ['T-11', 'T-18']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5ac3ee70",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "BASE_NAME = 'annotated_february_2022'\n",
+    "DATASET_NAME = 'annotations_stream_phase2_v1'\n",
+    "RUN_NAME = f'{DATASET_NAME}_{DAYS}d_{MAX_SEQ_LEN}_{\"_\".join(TYPES)}'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c865ce1e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ds_info = open(\"dataset-info/\" + RUN_NAME + '.txt', 'w')\n",
+    "def fprint(*texts):\n",
+    "    for text in texts:\n",
+    "        print(text)\n",
+    "        ds_info.write(str(text) + \"\\n\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "922abc5f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "FROM_BASE = False\n",
+    "BASE_TOKENIZER_PATH = ''"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c32b0b92",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "TYPES = set(TYPES)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6efddbcc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "DATA_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}.pickle\"\n",
+    "DATA_PATH_SPLITS = f\"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}_split/\"\n",
+    "TOKENIZER_PATH = f\"./data/timecat/models/gpt/tokenizer_{RUN_NAME}.pickle\"\n",
+    "ALMOST_PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_almost_prepared_split/\"\n",
+    "PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_prepared_split/\"\n",
+    "JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_just_before_encoding/\"\n",
+    "CAT_PATH = \"./data/models/modelpacks/mc_modelpack_phase2_snomed_190k_february_2022.zip\"\n",
+    "PT_DOB_PATH = \"./data/mimic/pt2dob_datetime.pickle\"\n",
+    "PT_DOD_PATH = \"./data/mimic/pt2dod_timestamp.pickle\"\n",
+    "PT_SEX_PATH = \"./data/mimic/pt2sex.pickle\"\n",
+    "PT_LNS_PATH = f\"./data/timecat/mimic/{BASE_NAME}/lns_{DATASET_NAME}.pickle\"\n",
+    "PT_CNTS_PATH = f\"./data/timecat/mimic/{BASE_NAME}/cnts_{DATASET_NAME}.pickle\"\n",
+    "PT_ETHNICITY_PATH = \"./data/mimic/pt2ethnicity.pickle\"\n",
+    "TOKEN_TYPES_PATH = f'./data/timecat/mimic/{BASE_NAME}/types_{DATASET_NAME}.pickle'\n",
+    "\n",
+    "BATCH_SIZE = 200\n",
+    "DEVICE = torch.device('cuda')\n",
+    "NUM_PROC = 16\n",
+    "MIN_COUNT = 2 # 3\n",
+    "MIN_GLOBAL_COUNT = 100 # 1000"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a8eddb07",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "cat = CAT.load_model_pack(CAT_PATH, meta_cat_config_dict={'general': {'device': 'cpu'}})\n",
+    "cdb = cat.cdb"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c5b51800",
+   "metadata": {},
+   "source": [
+    "# Convert docs.pickle into patient stream used by HF datasets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9445cd0d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "doc2pt = pickle.load(\"./data/timecat/mimic/doc2pt.pickle\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "41568f81",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "doc2pt = {str(k):v for k,v in doc2pt.items()}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e5bdc483",
+   "metadata": {},
+   "source": [
+    "### Get counts"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ad82350f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pt2cui2cnt = None\n",
+    "base_path = './data/timecat/mimic/annotated_february_2022/'\n",
+    "doc_paths = os.listdir(base_path)\n",
+    "doc_paths = [path for path in doc_paths if path.startswith(\"part_\")] # So we keep only annotations data\n",
+    "\n",
+    "for path in doc_paths:\n",
+    "    docs = pickle.load(os.path.join(base_path, path))\n",
+    "    \n",
+    "    pt2cui2cnt = calculate_counts(docs=docs,\n",
+    "                     doc2pt=doc2pt,\n",
+    "                     pt2cui2cnt=pt2cui2cnt,\n",
+    "                     meta_requirements={'Presence': 'True', 'Subject': 'Patient'})"
+   ]
+  },
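+  {
+   "cell_type": "markdown",
+   "id": "0d1e2f3a",
+   "metadata": {},
+   "source": [
+    "Quick sanity check (illustrative only): judging by how it is indexed below, `calculate_counts` returns a nested mapping `patient -> cui -> count`. The next cell peeks at one entry to confirm the shape."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1e2f3a4b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch, assuming pt2cui2cnt is a nested dict: patient -> cui -> count\n",
+    "_pt = next(iter(pt2cui2cnt))\n",
+    "_cui = next(iter(pt2cui2cnt[_pt]))\n",
+    "print(f\"patient={_pt} cui={_cui} count={pt2cui2cnt[_pt][_cui]}\")"
+   ]
+  },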
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "063bf5b0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pickle.dump(dict(pt2cui2cnt), f\"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d0355f5e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pt2cui2cnt = pickle.load(f\"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "355925a7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Total number of annotations per type\n",
+    "cnt_per_type = {}\n",
+    "other_cnt = 0\n",
+    "for pt in pt2cui2cnt:\n",
+    "    for cui in pt2cui2cnt[pt]:\n",
+    "        if cat.cdb.cui2type_ids[cui]:\n",
+    "            t = list(cat.cdb.cui2type_ids[cui])[0]\n",
+    "            cnt_per_type[t] = cnt_per_type.get(t, 0) + pt2cui2cnt[pt][cui]\n",
+    "        else:\n",
+    "            other_cnt += 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "79c7b309",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fprint(\"Total number of annotations per type: \")\n",
+    "for t in cnt_per_type:\n",
+    "    fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'][t].title(), cnt_per_type[t]))\n",
+    "fprint(\"\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c513bf0c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fprint(\"Total number of annotations: \", sum([x for x in cnt_per_type.values()]))\n",
+    "fprint(\"\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6a0375ab",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Get total number of different concepts\n",
+    "all_cuis = set()\n",
+    "for pt in pt2cui2cnt.keys():\n",
+    "    for cui in pt2cui2cnt[pt]: \n",
+    "        all_cuis.add(cui)\n",
+    "fprint(\"Total number of different concepts: \", len(all_cuis))\n",
+    "fprint(\"\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5e5e6c4f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Total number of patients\n",
+    "fprint(\"Total number of patients: \", len(pt2cui2cnt))\n",
+    "fprint(\"\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8e3471c9",
+   "metadata": {},
+   "source": [
+    "### Get pt2stream"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "deb8bf33",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "base_path = f'./data/timecat/mimic/{BASE_NAME}/'\n",
+    "doc_paths = os.listdir(base_path)\n",
+    "doc_paths = [path for path in doc_paths if path.startswith(\"part_\")] # So we keep only annotations data\n",
+    "pt2stream = None\n",
+    "doc2time =  {str(k):v for k,v in pickle.load(\"./data/timecat/mimic/doc2time.pickle\").items()}\n",
+    "\n",
+    "for path in doc_paths:\n",
+    "    docs = pickle.load(os.path.join(base_path, path))\n",
+    "    pt2stream = docs2stream(docs,\n",
+    "                            doc2pt=doc2pt,\n",
+    "                            pt2cui2cnt=pt2cui2cnt,\n",
+    "                            doc2time=doc2time,\n",
+    "                            entity_type_column='type_ids',\n",
+    "                            meta_requirements={'Subject': 'Patient', 'Presence': 'True'},\n",
+    "                            historical_meta='Time',\n",
+    "                            historical_meta_value='Past',\n",
+    "                            old_pt2stream=pt2stream,\n",
+    "                            skip_cuis={'S-418023006', '17971005'},\n",
+    "                            require_time=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c89f0594",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pickle.dump(dict(pt2stream), DATA_PATH)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6cfb0fa2",
+   "metadata": {},
+   "source": [
+    "# Load dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ba72e618",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "dataset = datasets.load_dataset(os.path.abspath(patient_concept_stream.__file__), \n",
+    "                                data_files={'train': DATA_PATH})['train']"
+   ]
+  },
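+  {
+   "cell_type": "markdown",
+   "id": "2f3a4b5c",
+   "metadata": {},
+   "source": [
+    "Illustrative peek at the raw record layout: from how the stream is consumed later (`sample['token']`, `sample['token_type']`), each row holds a patient id and a `stream` of per-annotation dicts."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3a4b5c6d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative: inspect one raw row and its first stream entry\n",
+    "example = dataset[0]\n",
+    "print(example.keys())\n",
+    "print(example['stream'][0])"
+   ]
+  },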
+  {
+   "cell_type": "markdown",
+   "id": "305eed65",
+   "metadata": {},
+   "source": [
+    "# Calculate counts"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "28bb5e02",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Calculate counts for tokens\n",
+    "token_cnt = defaultdict(int)\n",
+    "for _dataset in get_all_splits(dataset):\n",
+    "    for stream in _dataset['stream']:\n",
+    "        unique_tokens = set([sample['token'] for sample in stream])\n",
+    "        for token in unique_tokens:\n",
+    "            token_cnt[token] += 1\n",
+    "token_cnt = dict(token_cnt)"
+   ]
+  },
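+  {
+   "cell_type": "markdown",
+   "id": "4b5c6d7e",
+   "metadata": {},
+   "source": [
+    "A quick look at what will survive the count filter (illustrative only): `token_cnt` maps token -> number of patients whose stream contains it, so the most frequent entries are easy to list."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5c6d7e8f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative: top 10 tokens by patient count\n",
+    "for token, cnt in sorted(token_cnt.items(), key=lambda kv: kv[1], reverse=True)[:10]:\n",
+    "    print(\"{:15} {:8} {}\".format(token, cnt, cdb.get_name(token)))"
+   ]
+  },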
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d5fbd7c7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pickle.dump(token_cnt, PT_CNTS_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9c928e97",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "MIN_GLOBAL_COUNT = 100 # 1000"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3b6766cf",
+   "metadata": {},
+   "source": [
+    "# Load and filter by count"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "04085e31",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "token_cnt = pickle.load(PT_CNTS_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "049dfd9f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = filter_by_count(dataset, min_count=MIN_COUNT, min_count_global=MIN_GLOBAL_COUNT, min_length=5, max_length=-1, \n",
+    "                          num_proc=NUM_PROC, token_cnt=token_cnt)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0ca8e8e8",
+   "metadata": {},
+   "source": [
+    "### Split and save"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b5feea2f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Total number of annotations per type\n",
+    "cnt_per_type = {}\n",
+    "for cui in token_cnt:\n",
+    "    if cat.cdb.cui2type_ids[cui]:\n",
+    "        t = list(cat.cdb.cui2type_ids[cui])[0]\n",
+    "        cnt_per_type[t] = cnt_per_type.get(t, 0) + token_cnt[cui]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5f743cfe",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = dataset.train_test_split(test_size = 0.05)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "49dffff3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset.save_to_disk(DATA_PATH_SPLITS)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2580b573",
+   "metadata": {},
+   "source": [
+    "# CONTINUE FROM HERE WHEN NOT THE FIRST RUN"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3c535b25",
+   "metadata": {},
+   "source": [
+    "### Load splits"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "acf777de",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "token_cnt = pickle.load(PT_CNTS_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ab30780b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = datasets.load_from_disk(DATA_PATH_SPLITS)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3929e28a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6fe3b048",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fprint(\"Total number of pts in train/test: {}/{}\".format(len(dataset['train']), len(dataset['test'])))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2d7964a1",
+   "metadata": {},
+   "source": [
+    "# Filter to required type"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1067d23e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if \"ALL_TYPES\" not in TYPES:\n",
+    "    print(\"FILTERING\")\n",
+    "    dataset = filter_by_type(dataset, types_to_keep=TYPES, num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "58033ab4",
+   "metadata": {},
+   "source": [
+    "# Add Death token"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9db46905",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pt2dod_timestamp = {str(k):v for k,v in pickle.load(PT_DOD_PATH).items()}\n",
+    "pt2death = {k:\"The patient has died\" for k in pt2dod_timestamp.keys()}\n",
+    "dataset = dataset.map(\n",
+    "        lambda examples: add_to_stream(examples, pt2death, last=True, prefix=None, token_type='death'),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b5c995e2",
+   "metadata": {},
+   "source": [
+    "# Bucket and split"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3c019959",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "dataset = dataset.map(\n",
+    "        lambda examples: bucket_concepts(examples, bucket_size_seconds=DAYS*24*60*60, duration_separator=False),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
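+  {
+   "cell_type": "markdown",
+   "id": "6d7e8f9a",
+   "metadata": {},
+   "source": [
+    "`bucket_concepts` groups each stream into `DAYS`-sized time windows; judging by the separator tokens referenced further down (`<SEP>`, `<SEP-1>`, ...), it also inserts separator tokens between buckets. The illustrative cell below counts them in one stream, assuming separators start with `<SEP`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7e8f9a0b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative check, assuming bucket separators are tokens starting with '<SEP'\n",
+    "_stream = dataset['train'][0]['stream']\n",
+    "n_sep = sum(1 for s in _stream if str(s['token']).startswith('<SEP'))\n",
+    "print(f\"{n_sep} separator tokens in a stream of length {len(_stream)}\")"
+   ]
+  },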
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "18804746",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "27383ac4",
+   "metadata": {},
+   "source": [
+    "## Trim long streams"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "460f779e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from collections import defaultdict\n",
+    "lns = []\n",
+    "for _dataset in get_all_splits(dataset):\n",
+    "    for stream in _dataset['stream']:\n",
+    "        lns.append(len(stream))\n",
+    "pickle.dump(lns, PT_LNS_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e2c18011",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lns = pickle.load(PT_LNS_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f44e23e0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "len(lns)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9ebf1af4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "max(lns)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "78d8b228",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "max_len = int(np.percentile(lns, 95))\n",
+    "max_len"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f50eaf63",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "fig = px.histogram(x=[x for x in lns if x < max_len and x > 5], labels={'x': 'length'})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "67531287",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig.write_html(\"./dataset-info/\" + RUN_NAME + \".html\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5426144f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = filter_by_count(dataset, min_count=0, min_count_global=0, min_length=10, max_length=max_len, \n",
+    "                          num_proc=NUM_PROC, token_cnt=token_cnt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "71c14b2b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fb87c646",
+   "metadata": {},
+   "source": [
+    "## Split to max len"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0ffc1da8",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "dataset = dataset.map(\n",
+    "        lambda examples: split_stream(examples, max_seq_len=MAX_SEQ_LEN-32),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "577beded",
+   "metadata": {},
+   "source": [
+    "## Save again"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9e18f654",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset.save_to_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "da54723d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = datasets.load_from_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2b10396a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "60b7eb32",
+   "metadata": {},
+   "source": [
+    "# Add DOD and TTD"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d04702ae",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "pt2dob_timestamp = {str(k):v for k,v in pickle.load(PT_DOB_PATH).items()}\n",
+    "dataset = dataset.map(\n",
+    "        lambda examples: add_age(examples, pt2dob_timestamp=pt2dob_timestamp),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "79d57596",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\"\"\"\n",
+    "pt2dod_timestamp = {str(k):v for k,v in pickle.load(PT_DOD_PATH).items()}\n",
+    "# ADD time to die\n",
+    "dataset = dataset.map(\n",
+    "        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "312a702f",
+   "metadata": {},
+   "source": [
+    "### Another way for TTD"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "88aca6b0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\"\"\"\n",
+    "# ADD time to die\n",
+    "dataset['train'] = dataset['train'].map(\n",
+    "        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)\n",
+    "\n",
+    "dataset['test'] = dataset['test'].map(\n",
+    "        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60,\n",
+    "                                 max_nttd=10, ttd_prob=1, duplicate_streams=True),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a520f26a",
+   "metadata": {},
+   "source": [
+    "# Add sex and ethnicity"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fbcf3d29",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# Add Sex\n",
+    "pt2sex = pickle.load(PT_SEX_PATH)\n",
+    "dataset = dataset.map(\n",
+    "        lambda examples: add_to_stream(examples, pt2sex, last=False, prefix='<SEX>', token_type='sex'),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "14005ffc",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# Ethnicity\n",
+    "pt2ethnicity = pickle.load(PT_ETHNICITY_PATH)\n",
+    "dataset = dataset.map(\n",
+    "        lambda examples: add_to_stream(examples, pt2ethnicity, last=False, prefix='<ETHNICITY>', token_type='ethnicity'),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b0dd9e11",
+   "metadata": {},
+   "source": [
+    "# Final filter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bdb7f59b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = filter_by_count(dataset, min_count=None, min_count_global=None, min_length=10, num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f1ce6130",
+   "metadata": {},
+   "source": [
+    "# Remove parents"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "47f86d59",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Diseases\n",
+    "cuis = [token for token in cdb.config.linking['filters']['cuis'] if token in cdb.cui2names]\n",
+    "ch2parents = get_parents_map(cuis, cdb.addl_info['pt2ch'], depth=2)"
+   ]
+  },
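+  {
+   "cell_type": "markdown",
+   "id": "8f9a0b1c",
+   "metadata": {},
+   "source": [
+    "Judging by its name and usage, `get_parents_map` builds a child -> parents mapping (up to `depth` levels) over the hierarchy in `pt2ch`, which `remove_parents_from_stream` then uses to drop a parent concept when a more specific child is present. The cell below peeks at one entry (illustrative only)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9a0b1c2d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative peek, assuming ch2parents maps a concept to a collection of parent CUIs\n",
+    "_ch = next(iter(ch2parents))\n",
+    "print(cdb.get_name(_ch), '->', [cdb.get_name(p) for p in ch2parents[_ch]])"
+   ]
+  },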
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "64b4664f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = dataset.map(\n",
+    "        lambda examples: remove_parents_from_stream(examples, ch2parents=ch2parents, separator='<SEP>'),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0a3aa069",
+   "metadata": {},
+   "source": [
+    "## Add position IDs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0858f524",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = dataset.map(\n",
+    "        lambda examples: add_position_ids(examples, separators={'<SEP>', '<SEP-1>', '<SEP-7>' '<SEP-14>', '<SEP-30>', '<SEP-365>'}),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cf641980",
+   "metadata": {},
+   "source": [
+    "# Get token_type2tokens"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9bef3781",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "token_type2tokens = defaultdict(set)\n",
+    "total_cnt = 0\n",
+    "for _dataset in get_all_splits(dataset):\n",
+    "    for stream in _dataset['stream']:\n",
+    "        for example in stream:\n",
+    "            token_type2tokens[example['token_type']].add(example['token'])\n",
+    "            total_cnt += 1\n",
+    "token_type2tokens = dict(token_type2tokens)\n",
+    "pickle.dump(token_type2tokens, TOKEN_TYPES_PATH)\n",
+    "fprint(\"Total number of annotations: \", total_cnt)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "61c34003",
+   "metadata": {},
+   "source": [
+    "# Cleanup stream and leave only what we need"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "33d5fbdb",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "dataset = dataset.map(\n",
+    "        lambda examples: cleanup_stream(examples, keep_time=True, keep_type=True, keep_position_ids=True,\n",
+    "                                        keep_context_representation=False),\n",
+    "        batched=True,\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "71ad6184",
+   "metadata": {},
+   "source": [
+    "### Save"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fc98c4f0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset.save_to_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c2b2c1e1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "41ac5775",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "38c09fd5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Total number of patients fater intial filtering\n",
+    "train_len = len(dataset['train'])\n",
+    "test_len = len(dataset['test'])\n",
+    "fprint(\"Total number of pts in train: \", train_len)\n",
+    "fprint(\"Total number of pts in test: \", test_len)\n",
+    "fprint(\"Total number of pts: \", train_len + test_len)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "29bcc0c4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Total number of annotations per type after filtering\n",
+    "cnt_per_type_after = {}\n",
+    "for _dataset in get_all_splits(dataset):\n",
+    "    for stream in _dataset['stream']:\n",
+    "        for cui in stream:\n",
+    "            if cat.cdb.cui2type_ids.get(cui, None):\n",
+    "                t = list(cat.cdb.cui2type_ids[cui])[0]\n",
+    "                cnt_per_type_after[t] = cnt_per_type_after.get(t, 0) + 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "208c24fb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fprint(\"Total number of annotations per type: \\n\")\n",
+    "for t in cnt_per_type_after:\n",
+    "    fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'][t].title(), cnt_per_type_after[t]))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b2b4d62e",
+   "metadata": {},
+   "source": [
+    "# Make tokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8cf2ec1a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "extra_tokenizer = None\n",
+    "#extra_tokenizer = SimpleMapTokenizer.load(\"./data/time/models/slam_tokenizer_annotations_stream_phase2_1d_200_ALL_TYPES.pickle\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e150f156",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "token_type2tokens = pickle.load(TOKEN_TYPES_PATH)\n",
+    "extra_concepts = None\n",
+    "if extra_tokenizer is not None:\n",
+    "    extra_concepts = list(extra_tokenizer.tkn2id.keys())\n",
+    "\n",
+    "    for k,v in extra_tokenizer.token_type2tokens.items():\n",
+    "        if k in token_type2tokens:\n",
+    "            token_type2tokens[k].update(extra_tokenizer.token_type2tokens[k])\n",
+    "        else:\n",
+    "            token_type2tokens[k] = extra_tokenizer.token_type2tokens[k]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "767d8549",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())\n",
+    "embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types,\n",
+    "                                                        concepts=extra_concepts)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e3b9946e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}\n",
+    "tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,\n",
+    "                               token_type2tokens=token_type2tokens, embeddings=embeddings,\n",
+    "                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "51b9cce4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "assert len(tokenizer.tkn2id) == len(tokenizer.id2tkn)\n",
+    "assert len(tokenizer.embeddings) == len(tokenizer.id2tkn)\n",
+    "assert len(tokenizer.tkn2name) == len(tokenizer.id2tkn)\n",
+    "fprint(tokenizer.pad_id, tokenizer.id2tkn[tokenizer.pad_id])"
+   ]
+  },
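+  {
+   "cell_type": "markdown",
+   "id": "0b1c2d3e",
+   "metadata": {},
+   "source": [
+    "A tiny round-trip, purely illustrative: encode one batched example with the freshly built tokenizer and map the ids back to tokens, mirroring what the full `dataset.map(...)` encoding step further down does (assumes `encode` accepts the batched dict that a slice of the dataset returns)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1c2d3e4f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative round-trip on a single batched example\n",
+    "_enc = tokenizer.encode(dataset['train'][0:1])\n",
+    "print(_enc['input_ids'][0][:10])\n",
+    "print(tokenizer.convert_ids2tokens(_enc['input_ids'][0][:10]))"
+   ]
+  },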
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "031b224b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "len(tokenizer.tkn2name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0295fe2d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# save\n",
+    "tokenizer.save(TOKENIZER_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f6cd9f04",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Total number of different concepts after all filtering\n",
+    "fprint(\"Total number of concepts after filtering: \", len(tokenizer.tkn2id))\n",
+    "fprint(\"\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b9f996f0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Total number annotations after all filtering\n",
+    "fprint(\"Total number of annotations after filtering: \", sum([x for x in cnt_per_type_after.values()]))\n",
+    "fprint(\"\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "862166a9",
+   "metadata": {},
+   "source": [
+    "# Print number of different concepts per type after filtering"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "82ebd72c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cnt_per_type = {}\n",
+    "for cui in tkn2id:\n",
+    "    if cat.cdb.cui2type_ids.get(cui, ['Other']):\n",
+    "        t = list(cat.cdb.cui2type_ids.get(cui, ['Other']))[0]\n",
+    "        cnt_per_type[t] = cnt_per_type.get(t, 0) + 1\n",
+    "fprint(\"Total number of <<different>> concepts per type after filtering\")\n",
+    "for t in cnt_per_type:\n",
+    "    fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'].get(t, t).title(), cnt_per_type[t]))\n",
+    "fprint(\"\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "19cf388d",
+   "metadata": {},
+   "source": [
+    "# Create global tokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ac9c9839",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())\n",
+    "concepts = list(cat.config.linking['filters']['cuis'])\n",
+    "embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types, concepts=concepts)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "eabbbf88",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}\n",
+    "tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,\n",
+    "                               token_type2tokens=token_type2tokens, embeddings=embeddings,\n",
+    "                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "22833eec",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer.save(BASE_TOKENIZER_PATH)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4c06d09b",
+   "metadata": {},
+   "source": [
+    "# Convert tokens to IDs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "864df946",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if FROM_BASE:\n",
+    "    print(\"USING BASE TOKENIZER\")\n",
+    "    TOKENIZER_PATH = BASE_TOKENIZER_PATH"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e4540c4a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer =  SimpleMapTokenizer.load(TOKENIZER_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "548e8e3b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoded_dataset = dataset.map(\n",
+    "        lambda examples: tokenizer.encode(examples),\n",
+    "        batched=True,\n",
+    "        remove_columns=['stream'],\n",
+    "        load_from_cache_file=False,\n",
+    "        num_proc=NUM_PROC)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8718ae76",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoded_dataset.save_to_disk(PREPARED_DATASET_SPLIT_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "340dd11d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "PREPARED_DATASET_SPLIT_PATH"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e512ab8c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "TOKENIZER_PATH"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6715bb47",
+   "metadata": {},
+   "source": [
+    "# Test is all OK"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "60881f72",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2f3c066f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6dab3686",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "035a55fb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoded_dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b89f84af",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21d66a24",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ind = 1096"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "664dd895",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datetime import datetime"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "49b1cbc1",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "[cdb.get_name(x) for x in dataset['train'][ind]['stream']]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9ec0594e",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "for ty, p, t, c in zip(encoded_dataset['train'][ind]['token_type'], encoded_dataset['train'][ind]['position_ids'], encoded_dataset['train'][ind]['time'], tokenizer.convert_ids2tokens(encoded_dataset['train'][ind]['input_ids'])):\n",
+    "    print(datetime.fromtimestamp(t), p, \"{:20}\".format(ty), c)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "19003ff8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "encoded_dataset['train'][ind]['patient_id']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1fc5460d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ds_info.close()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e5a3cc65",
+   "metadata": {},
+   "source": [
+    "# Preapre for Foresight"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2dec8a19",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ind = 32330"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2821af66",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "55457df7",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "[cdb.get_name(x) for x in dataset['train'][ind]['stream']]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "196eaf9d",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "for i, c in enumerate(dataset['train'][ind]['stream']):\n",
+    "    print(i)\n",
+    "    if i > 20 and c not in dataset['train'][ind]['stream'][0:i]:\n",
+    "        print(i, c, cdb.get_name(c))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4a159ae3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out = []\n",
+    "for i, cui in enumerate(dataset['train'][ind]['stream'][:161]):\n",
+    "    d = {\n",
+    "        'id': cui,\n",
+    "        'label': cdb.get_name(cui),\n",
+    "        'count': 1000000,\n",
+    "        'name': cdb.get_name(cui),\n",
+    "        'cui': cui,\n",
+    "        'saliency': 0,\n",
+    "        'uid': i\n",
+    "    }\n",
+    "    out.append(d)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "97f40a06",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "json.dump(out, open(\"./data/tmp/timeline_example_1.json\", 'w'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "da6e079d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "len(out)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e2730480",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "out"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}