experiments/Foresight MIMIC Final -- Prepare data.ipynb

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01ae6e4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import datasets\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from medcat.cat import CAT\n",
    "from foresight.datasets import patient_concept_stream\n",
    "from foresight.datasets.filters import filter_by_count, filter_by_type\n",
    "from foresight.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \\\n",
    "                                  remove_parents_from_stream, bucket_concepts, cleanup_stream, \\\n",
    "                                  split_stream, add_age, get_all_splits, add_ttd, add_position_ids\n",
    "from foresight.utils import pickle\n",
    "from foresight.utils.cdb_utils import get_parents_map \n",
    "from foresight.utils.stream_utils import docs2stream, calculate_counts\n",
    "from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer\n",
    "from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF\n",
    "from medcat.cdb import CDB\n",
    "from foresight.utils import pickle\n",
    "import plotly.express as px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f57b67c",
   "metadata": {},
   "outputs": [],
   "source": [
    "DAYS = 1 # Do: 1, 14, 30\n",
    "MAX_SEQ_LEN = 256\n",
    "#TYPES = ['T-45', 'T-55', 'T-26', 'T-29', 'T-40', 'T-9', 'T-27', 'T-11', 'T-39', 'T-18']\n",
    "TYPES = ['ALL_TYPES']\n",
    "#TYPES = ['T-11']\n",
    "#TYPES = ['T-11', 'T-18']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ac3ee70",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE_NAME = 'annotated_february_2022'\n",
    "DATASET_NAME = 'annotations_stream_phase2_v1'\n",
    "RUN_NAME = f'{DATASET_NAME}_{DAYS}d_{MAX_SEQ_LEN}_{\"_\".join(TYPES)}'"
   ]
  },
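  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1e0c3b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional (added as a small sketch): the fprint helper in the next cell writes into\n",
    "# 'dataset-info/', so make sure that folder exists before the file is opened for writing.\n",
    "os.makedirs('dataset-info', exist_ok=True)"
   ]
  },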
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c865ce1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_info = open(\"dataset-info/\" + RUN_NAME + '.txt', 'w')\n",
    "def fprint(*texts):\n",
    "    for text in texts:\n",
    "        print(text)\n",
    "        ds_info.write(str(text) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "922abc5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "FROM_BASE = False\n",
    "BASE_TOKENIZER_PATH = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c32b0b92",
   "metadata": {},
   "outputs": [],
   "source": [
    "TYPES = set(TYPES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6efddbcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}.pickle\"\n",
    "DATA_PATH_SPLITS = f\"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}_split/\"\n",
    "TOKENIZER_PATH = f\"./data/timecat/models/gpt/tokenizer_{RUN_NAME}.pickle\"\n",
    "ALMOST_PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_almost_prepared_split/\"\n",
    "PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_prepared_split/\"\n",
    "JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_just_before_encoding/\"\n",
    "CAT_PATH = \"./data/models/modelpacks/mc_modelpack_phase2_snomed_190k_february_2022.zip\"\n",
    "PT_DOB_PATH = \"./data/mimic/pt2dob_datetime.pickle\"\n",
    "PT_DOD_PATH = \"./data/mimic/pt2dod_timestamp.pickle\"\n",
    "PT_SEX_PATH = \"./data/mimic/pt2sex.pickle\"\n",
    "PT_LNS_PATH = f\"./data/timecat/mimic/{BASE_NAME}/lns_{DATASET_NAME}.pickle\"\n",
    "PT_CNTS_PATH = f\"./data/timecat/mimic/{BASE_NAME}/cnts_{DATASET_NAME}.pickle\"\n",
    "PT_ETHNICITY_PATH = \"./data/mimic/pt2ethnicity.pickle\"\n",
    "TOKEN_TYPES_PATH = f'./data/timecat/mimic/{BASE_NAME}/types_{DATASET_NAME}.pickle'\n",
    "\n",
    "BATCH_SIZE = 200\n",
    "DEVICE = torch.device('cuda')\n",
    "NUM_PROC = 16\n",
    "MIN_COUNT = 2 # 3\n",
    "MIN_GLOBAL_COUNT = 100 # 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8eddb07",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "cat = CAT.load_model_pack(CAT_PATH, meta_cat_config_dict={'general': {'device': 'cpu'}})\n",
    "cdb = cat.cdb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5b51800",
   "metadata": {},
   "source": [
    "# Convert docs.pickle into patient stream used by HF datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9445cd0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc2pt = pickle.load(\"./data/timecat/mimic/doc2pt.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41568f81",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc2pt = {str(k):v for k,v in doc2pt.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5bdc483",
   "metadata": {},
   "source": [
    "### Get counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad82350f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pt2cui2cnt = None\n",
    "base_path = './data/timecat/mimic/annotated_february_2022/'\n",
    "doc_paths = os.listdir(base_path)\n",
    "doc_paths = [path for path in doc_paths if path.startswith(\"part_\")] # So we keep only annotations data\n",
    "\n",
    "for path in doc_paths:\n",
    "    docs = pickle.load(os.path.join(base_path, path))\n",
    "    \n",
    "    pt2cui2cnt = calculate_counts(docs=docs,\n",
    "                     doc2pt=doc2pt,\n",
    "                     pt2cui2cnt=pt2cui2cnt,\n",
    "                     meta_requirements={'Presence': 'True', 'Subject': 'Patient'})"
   ]
  },
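  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7f4d2e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check (illustrative, not part of the original pipeline): pt2cui2cnt maps\n",
    "# patient_id -> {cui: count}, so peek at the most frequent concepts for one patient.\n",
    "example_pt = next(iter(pt2cui2cnt))\n",
    "sorted(pt2cui2cnt[example_pt].items(), key=lambda kv: kv[1], reverse=True)[:10]"
   ]
  },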
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "063bf5b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(dict(pt2cui2cnt), f\"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0355f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pt2cui2cnt = pickle.load(f\"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "355925a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of annotations per type\n",
    "cnt_per_type = {}\n",
    "other_cnt = 0\n",
    "for pt in pt2cui2cnt:\n",
    "    for cui in pt2cui2cnt[pt]:\n",
    "        if cat.cdb.cui2type_ids[cui]:\n",
    "            t = list(cat.cdb.cui2type_ids[cui])[0]\n",
    "            cnt_per_type[t] = cnt_per_type.get(t, 0) + pt2cui2cnt[pt][cui]\n",
    "        else:\n",
    "            other_cnt += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c7b309",
   "metadata": {},
   "outputs": [],
   "source": [
    "fprint(\"Total number of annotations per type: \")\n",
    "for t in cnt_per_type:\n",
    "    fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'][t].title(), cnt_per_type[t]))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c513bf0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fprint(\"Total number of annotations: \", sum([x for x in cnt_per_type.values()]))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a0375ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get total number of different concepts\n",
    "all_cuis = set()\n",
    "for pt in pt2cui2cnt.keys():\n",
    "    for cui in pt2cui2cnt[pt]: \n",
    "        all_cuis.add(cui)\n",
    "fprint(\"Total number of different concepts: \", len(all_cuis))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e5e6c4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of patients\n",
    "fprint(\"Total number of patients: \", len(pt2cui2cnt))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e3471c9",
   "metadata": {},
   "source": [
    "### Get pt2stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deb8bf33",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = f'./data/timecat/mimic/{BASE_NAME}/'\n",
    "doc_paths = os.listdir(base_path)\n",
    "doc_paths = [path for path in doc_paths if path.startswith(\"part_\")] # So we keep only annotations data\n",
    "pt2stream = None\n",
    "doc2time =  {str(k):v for k,v in pickle.load(\"./data/timecat/mimic/doc2time.pickle\").items()}\n",
    "\n",
    "for path in doc_paths:\n",
    "    docs = pickle.load(os.path.join(base_path, path))\n",
    "    pt2stream = docs2stream(docs,\n",
    "                            doc2pt=doc2pt,\n",
    "                            pt2cui2cnt=pt2cui2cnt,\n",
    "                            doc2time=doc2time,\n",
    "                            entity_type_column='type_ids',\n",
    "                            meta_requirements={'Subject': 'Patient', 'Presence': 'True'},\n",
    "                            historical_meta='Time',\n",
    "                            historical_meta_value='Past',\n",
    "                            old_pt2stream=pt2stream,\n",
    "                            skip_cuis={'S-418023006', '17971005'},\n",
    "                            require_time=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c89f0594",
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(dict(pt2stream), DATA_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cfb0fa2",
   "metadata": {},
   "source": [
    "# Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba72e618",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dataset = datasets.load_dataset(os.path.abspath(patient_concept_stream.__file__), \n",
    "                                data_files={'train': DATA_PATH})['train']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "305eed65",
   "metadata": {},
   "source": [
    "# Calculate counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28bb5e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate counts for tokens\n",
    "token_cnt = defaultdict(int)\n",
    "for _dataset in get_all_splits(dataset):\n",
    "    for stream in _dataset['stream']:\n",
    "        unique_tokens = set([sample['token'] for sample in stream])\n",
    "        for token in unique_tokens:\n",
    "            token_cnt[token] += 1\n",
    "token_cnt = dict(token_cnt)"
   ]
  },
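  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2a9e8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check (sketch): token_cnt counts, per token, the number of patients whose\n",
    "# stream contains it; list the most common ones with their CDB names to eyeball the data.\n",
    "for token, cnt in sorted(token_cnt.items(), key=lambda kv: kv[1], reverse=True)[:10]:\n",
    "    print('{:20} {:8} {}'.format(token, cnt, cdb.get_name(token)))"
   ]
  },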
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5fbd7c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(token_cnt, PT_CNTS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c928e97",
   "metadata": {},
   "outputs": [],
   "source": [
    "MIN_GLOBAL_COUNT = 100 # 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b6766cf",
   "metadata": {},
   "source": [
    "# Load and filter by count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04085e31",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_cnt = pickle.load(PT_CNTS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "049dfd9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = filter_by_count(dataset, min_count=MIN_COUNT, min_count_global=MIN_GLOBAL_COUNT, min_length=5, max_length=-1, \n",
    "                          num_proc=NUM_PROC, token_cnt=token_cnt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ca8e8e8",
   "metadata": {},
   "source": [
    "### Split and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5feea2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of annotations per type\n",
    "cnt_per_type = {}\n",
    "for cui in token_cnt:\n",
    "    if cat.cdb.cui2type_ids[cui]:\n",
    "        t = list(cat.cdb.cui2type_ids[cui])[0]\n",
    "        cnt_per_type[t] = cnt_per_type.get(t, 0) + token_cnt[cui]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f743cfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.train_test_split(test_size = 0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49dffff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.save_to_disk(DATA_PATH_SPLITS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2580b573",
   "metadata": {},
   "source": [
    "# CONTINUE FROM HERE WHEN NOT THE FIRST RUN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c535b25",
   "metadata": {},
   "source": [
    "### Load splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acf777de",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_cnt = pickle.load(PT_CNTS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab30780b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(DATA_PATH_SPLITS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3929e28a",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fe3b048",
   "metadata": {},
   "outputs": [],
   "source": [
    "fprint(\"Total number of pts in train/test: {}/{}\".format(len(dataset['train']), len(dataset['test'])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d7964a1",
   "metadata": {},
   "source": [
    "# Filter to required type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1067d23e",
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"ALL_TYPES\" not in TYPES:\n",
    "    print(\"FILTERING\")\n",
    "    dataset = filter_by_type(dataset, types_to_keep=TYPES, num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58033ab4",
   "metadata": {},
   "source": [
    "# Add Death token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9db46905",
   "metadata": {},
   "outputs": [],
   "source": [
    "pt2dod_timestamp = {str(k):v for k,v in pickle.load(PT_DOD_PATH).items()}\n",
    "pt2death = {k:\"The patient has died\" for k in pt2dod_timestamp.keys()}\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_to_stream(examples, pt2death, last=True, prefix=None, token_type='death'),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
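  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4b1f7a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check (sketch, assuming add_to_stream with last=True appends the token at the\n",
    "# end of the stream): count how many training streams now finish with a death token.\n",
    "sum(1 for stream in dataset['train']['stream'] if stream and stream[-1]['token_type'] == 'death')"
   ]
  },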
  {
   "cell_type": "markdown",
   "id": "b5c995e2",
   "metadata": {},
   "source": [
    "# Bucket and split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c019959",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dataset = dataset.map(\n",
    "        lambda examples: bucket_concepts(examples, bucket_size_seconds=DAYS*24*60*60, duration_separator=False),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18804746",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27383ac4",
   "metadata": {},
   "source": [
    "## Trim long streams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "460f779e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "lns = []\n",
    "for _dataset in get_all_splits(dataset):\n",
    "    for stream in _dataset['stream']:\n",
    "        lns.append(len(stream))\n",
    "pickle.dump(lns, PT_LNS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2c18011",
   "metadata": {},
   "outputs": [],
   "source": [
    "lns = pickle.load(PT_LNS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f44e23e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(lns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ebf1af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "max(lns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78d8b228",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_len = int(np.percentile(lns, 95))\n",
    "max_len"
   ]
  },
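  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6c3a9b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative: a few more quantiles of the stream lengths, to put the 95th-percentile\n",
    "# cut-off above into context.\n",
    "{q: int(np.percentile(lns, q)) for q in (50, 75, 90, 95, 99)}"
   ]
  },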
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f50eaf63",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fig = px.histogram(x=[x for x in lns if x < max_len and x > 5], labels={'x': 'length'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67531287",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.write_html(\"./dataset-info/\" + RUN_NAME + \".html\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5426144f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = filter_by_count(dataset, min_count=0, min_count_global=0, min_length=10, max_length=max_len, \n",
    "                          num_proc=NUM_PROC, token_cnt=token_cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71c14b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb87c646",
   "metadata": {},
   "source": [
    "## Split to max len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ffc1da8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dataset = dataset.map(\n",
    "        lambda examples: split_stream(examples, max_seq_len=MAX_SEQ_LEN-32),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
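  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1d8c4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check (sketch): after splitting, no stream should be longer than\n",
    "# MAX_SEQ_LEN - 32; the margin presumably leaves room for the tokens added later\n",
    "# (age, sex and ethnicity).\n",
    "max(len(stream) for _dataset in get_all_splits(dataset) for stream in _dataset['stream'])"
   ]
  },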
  {
   "cell_type": "markdown",
   "id": "577beded",
   "metadata": {},
   "source": [
    "## Save again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e18f654",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.save_to_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da54723d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b10396a",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60b7eb32",
   "metadata": {},
   "source": [
    "# Add DOD and TTD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d04702ae",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pt2dob_timestamp = {str(k):v for k,v in pickle.load(PT_DOB_PATH).items()}\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_age(examples, pt2dob_timestamp=pt2dob_timestamp),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79d57596",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "pt2dod_timestamp = {str(k):v for k,v in pickle.load(PT_DOD_PATH).items()}\n",
    "# ADD time to die\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "312a702f",
   "metadata": {},
   "source": [
    "### Another way for TTD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88aca6b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "# ADD time to die\n",
    "dataset['train'] = dataset['train'].map(\n",
    "        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)\n",
    "\n",
    "dataset['test'] = dataset['test'].map(\n",
    "        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60,\n",
    "                                 max_nttd=10, ttd_prob=1, duplicate_streams=True),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a520f26a",
   "metadata": {},
   "source": [
    "# Add sex and ethnicity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbcf3d29",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Add Sex\n",
    "pt2sex = pickle.load(PT_SEX_PATH)\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_to_stream(examples, pt2sex, last=False, prefix='<SEX>', token_type='sex'),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14005ffc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Ethnicity\n",
    "pt2ethnicity = pickle.load(PT_ETHNICITY_PATH)\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_to_stream(examples, pt2ethnicity, last=False, prefix='<ETHNICITY>', token_type='ethnicity'),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0dd9e11",
   "metadata": {},
   "source": [
    "# Final filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdb7f59b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = filter_by_count(dataset, min_count=None, min_count_global=None, min_length=10, num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1ce6130",
   "metadata": {},
   "source": [
    "# Remove parents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47f86d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diseases\n",
    "cuis = [token for token in cdb.config.linking['filters']['cuis'] if token in cdb.cui2names]\n",
    "ch2parents = get_parents_map(cuis, cdb.addl_info['pt2ch'], depth=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64b4664f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(\n",
    "        lambda examples: remove_parents_from_stream(examples, ch2parents=ch2parents, separator='<SEP>'),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a3aa069",
   "metadata": {},
   "source": [
    "## Add position IDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0858f524",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(\n",
    "        lambda examples: add_position_ids(examples, separators={'<SEP>', '<SEP-1>', '<SEP-7>' '<SEP-14>', '<SEP-30>', '<SEP-365>'}),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf641980",
   "metadata": {},
   "source": [
    "# Get token_type2tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bef3781",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "token_type2tokens = defaultdict(set)\n",
    "total_cnt = 0\n",
    "for _dataset in get_all_splits(dataset):\n",
    "    for stream in _dataset['stream']:\n",
    "        for example in stream:\n",
    "            token_type2tokens[example['token_type']].add(example['token'])\n",
    "            total_cnt += 1\n",
    "token_type2tokens = dict(token_type2tokens)\n",
    "pickle.dump(token_type2tokens, TOKEN_TYPES_PATH)\n",
    "fprint(\"Total number of annotations: \", total_cnt)"
   ]
  },
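  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b2e7c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative: number of distinct tokens per token type, as a quick overview of the\n",
    "# vocabulary that the tokenizer will be built from.\n",
    "{k: len(v) for k, v in token_type2tokens.items()}"
   ]
  },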
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bb1f2a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(token_type2tokens, TOKEN_TYPES_PATH)\n",
    "fprint(\"Total number of annotations: \", total_cnt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61c34003",
   "metadata": {},
   "source": [
    "# Cleanup stream and leave only what we need"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33d5fbdb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dataset = dataset.map(\n",
    "        lambda examples: cleanup_stream(examples, keep_time=True, keep_type=True, keep_position_ids=True,\n",
    "                                        keep_context_representation=False),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71ad6184",
   "metadata": {},
   "source": [
    "### Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc98c4f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.save_to_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2b2c1e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41ac5775",
   "metadata": {},
   "outputs": [],
   "source": [
    "JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38c09fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of patients fater intial filtering\n",
    "train_len = len(dataset['train'])\n",
    "test_len = len(dataset['test'])\n",
    "fprint(\"Total number of pts in train: \", train_len)\n",
    "fprint(\"Total number of pts in test: \", test_len)\n",
    "fprint(\"Total number of pts: \", train_len + test_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29bcc0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of annotations per type after filtering\n",
    "cnt_per_type_after = {}\n",
    "for _dataset in get_all_splits(dataset):\n",
    "    for stream in _dataset['stream']:\n",
    "        for cui in stream:\n",
    "            if cat.cdb.cui2type_ids.get(cui, None):\n",
    "                t = list(cat.cdb.cui2type_ids[cui])[0]\n",
    "                cnt_per_type_after[t] = cnt_per_type_after.get(t, 0) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "208c24fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "fprint(\"Total number of annotations per type: \\n\")\n",
    "for t in cnt_per_type_after:\n",
    "    fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'][t].title(), cnt_per_type_after[t]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2b4d62e",
   "metadata": {},
   "source": [
    "# Make tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cf2ec1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "extra_tokenizer = None\n",
    "#extra_tokenizer = SimpleMapTokenizer.load(\"./data/time/models/slam_tokenizer_annotations_stream_phase2_1d_200_ALL_TYPES.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e150f156",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_type2tokens = pickle.load(TOKEN_TYPES_PATH)\n",
    "extra_concepts = None\n",
    "if extra_tokenizer is not None:\n",
    "    extra_concepts = list(extra_tokenizer.tkn2id.keys())\n",
    "\n",
    "    for k,v in extra_tokenizer.token_type2tokens.items():\n",
    "        if k in token_type2tokens:\n",
    "            token_type2tokens[k].update(extra_tokenizer.token_type2tokens[k])\n",
    "        else:\n",
    "            token_type2tokens[k] = extra_tokenizer.token_type2tokens[k]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "767d8549",
   "metadata": {},
   "outputs": [],
   "source": [
    "_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())\n",
    "embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types,\n",
    "                                                        concepts=extra_concepts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b9946e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}\n",
    "tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,\n",
    "                               token_type2tokens=token_type2tokens, embeddings=embeddings,\n",
    "                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51b9cce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(tokenizer.tkn2id) == len(tokenizer.id2tkn)\n",
    "assert len(tokenizer.embeddings) == len(tokenizer.id2tkn)\n",
    "assert len(tokenizer.tkn2name) == len(tokenizer.id2tkn)\n",
    "fprint(tokenizer.pad_id, tokenizer.id2tkn[tokenizer.pad_id])"
   ]
  },
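  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7e1a3f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative round-trip check (sketch): map a token to its id and back again;\n",
    "# convert_ids2tokens is the same helper used for decoding further down.\n",
    "some_token = next(iter(tokenizer.tkn2id))\n",
    "some_token, tokenizer.convert_ids2tokens([tokenizer.tkn2id[some_token]])"
   ]
  },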
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "031b224b",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(tokenizer.tkn2name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0295fe2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save\n",
    "tokenizer.save(TOKENIZER_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6cd9f04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of different concepts after all filtering\n",
    "fprint(\"Total number of concepts after filtering: \", len(tokenizer.tkn2id))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9f996f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number annotations after all filtering\n",
    "fprint(\"Total number of annotations after filtering: \", sum([x for x in cnt_per_type_after.values()]))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "862166a9",
   "metadata": {},
   "source": [
    "# Print number of different concepts per type after filtering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82ebd72c",
   "metadata": {},
   "outputs": [],
   "source": [
    "cnt_per_type = {}\n",
    "for cui in tkn2id:\n",
    "    if cat.cdb.cui2type_ids.get(cui, ['Other']):\n",
    "        t = list(cat.cdb.cui2type_ids.get(cui, ['Other']))[0]\n",
    "        cnt_per_type[t] = cnt_per_type.get(t, 0) + 1\n",
    "fprint(\"Total number of <<different>> concepts per type after filtering\")\n",
    "for t in cnt_per_type:\n",
    "    fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'].get(t, t).title(), cnt_per_type[t]))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19cf388d",
   "metadata": {},
   "source": [
    "# Create global tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac9c9839",
   "metadata": {},
   "outputs": [],
   "source": [
    "_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())\n",
    "concepts = list(cat.config.linking['filters']['cuis'])\n",
    "embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types, concepts=concepts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eabbbf88",
   "metadata": {},
   "outputs": [],
   "source": [
    "tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}\n",
    "tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,\n",
    "                               token_type2tokens=token_type2tokens, embeddings=embeddings,\n",
    "                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22833eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.save(BASE_TOKENIZER_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c06d09b",
   "metadata": {},
   "source": [
    "# Convert tokens to IDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "864df946",
   "metadata": {},
   "outputs": [],
   "source": [
    "if FROM_BASE:\n",
    "    print(\"USING BASE TOKENIZER\")\n",
    "    TOKENIZER_PATH = BASE_TOKENIZER_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4540c4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer =  SimpleMapTokenizer.load(TOKENIZER_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "548e8e3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset = dataset.map(\n",
    "        lambda examples: tokenizer.encode(examples),\n",
    "        batched=True,\n",
    "        remove_columns=['stream'],\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
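  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3f6b8d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check (sketch): confirm the encoded examples respect the tokenizer's\n",
    "# max_len (MAX_SEQ_LEN) before saving to disk.\n",
    "max(len(x) for x in encoded_dataset['train']['input_ids'])"
   ]
  },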
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8718ae76",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset.save_to_disk(PREPARED_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "340dd11d",
   "metadata": {},
   "outputs": [],
   "source": [
    "PREPARED_DATASET_SPLIT_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e512ab8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "TOKENIZER_PATH"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6715bb47",
   "metadata": {},
   "source": [
    "# Test is all OK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60881f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f3c066f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dab3686",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035a55fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89f84af",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d66a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "ind = 1096"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "664dd895",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b1cbc1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "[cdb.get_name(x) for x in dataset['train'][ind]['stream']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ec0594e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for ty, p, t, c in zip(encoded_dataset['train'][ind]['token_type'], encoded_dataset['train'][ind]['position_ids'], encoded_dataset['train'][ind]['time'], tokenizer.convert_ids2tokens(encoded_dataset['train'][ind]['input_ids'])):\n",
    "    print(datetime.fromtimestamp(t), p, \"{:20}\".format(ty), c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19003ff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset['train'][ind]['patient_id']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fc5460d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_info.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5a3cc65",
   "metadata": {},
   "source": [
    "# Preapre for Foresight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dec8a19",
   "metadata": {},
   "outputs": [],
   "source": [
    "ind = 32330"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2821af66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55457df7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "[cdb.get_name(x) for x in dataset['train'][ind]['stream']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "196eaf9d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for i, c in enumerate(dataset['train'][ind]['stream']):\n",
    "    print(i)\n",
    "    if i > 20 and c not in dataset['train'][ind]['stream'][0:i]:\n",
    "        print(i, c, cdb.get_name(c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a159ae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = []\n",
    "for i, cui in enumerate(dataset['train'][ind]['stream'][:161]):\n",
    "    d = {\n",
    "        'id': cui,\n",
    "        'label': cdb.get_name(cui),\n",
    "        'count': 1000000,\n",
    "        'name': cdb.get_name(cui),\n",
    "        'cui': cui,\n",
    "        'saliency': 0,\n",
    "        'uid': i\n",
    "    }\n",
    "    out.append(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97f40a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "json.dump(out, open(\"./data/tmp/timeline_example_1.json\", 'w'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da6e079d",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2730480",
   "metadata": {},
   "outputs": [],
   "source": [
    "out"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}