{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "01ae6e4a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import os\n",
"import datasets\n",
"import numpy as np\n",
"from collections import defaultdict\n",
"from medcat.cat import CAT\n",
"from foresight.datasets import patient_concept_stream\n",
"from foresight.datasets.filters import filter_by_count, filter_by_type\n",
"from foresight.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \\\n",
" remove_parents_from_stream, bucket_concepts, cleanup_stream, \\\n",
" split_stream, add_age, get_all_splits, add_ttd, add_position_ids\n",
"from foresight.utils import pickle\n",
"from foresight.utils.cdb_utils import get_parents_map \n",
"from foresight.utils.stream_utils import docs2stream, calculate_counts\n",
"from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer\n",
"from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF\n",
"from medcat.cdb import CDB\n",
"from foresight.utils import pickle\n",
"import plotly.express as px"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f57b67c",
"metadata": {},
"outputs": [],
"source": [
"DAYS = 1 # Do: 1, 14, 30\n",
"MAX_SEQ_LEN = 256\n",
"#TYPES = ['T-45', 'T-55', 'T-26', 'T-29', 'T-40', 'T-9', 'T-27', 'T-11', 'T-39', 'T-18']\n",
"TYPES = ['ALL_TYPES']\n",
"#TYPES = ['T-11']\n",
"#TYPES = ['T-11', 'T-18']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ac3ee70",
"metadata": {},
"outputs": [],
"source": [
"BASE_NAME = 'annotated_february_2022'\n",
"DATASET_NAME = 'annotations_stream_phase2_v1'\n",
"RUN_NAME = f'{DATASET_NAME}_{DAYS}d_{MAX_SEQ_LEN}_{\"_\".join(TYPES)}'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c865ce1e",
"metadata": {},
"outputs": [],
"source": [
"ds_info = open(\"dataset-info/\" + RUN_NAME + '.txt', 'w')\n",
"def fprint(*texts):\n",
" for text in texts:\n",
" print(text)\n",
" ds_info.write(str(text) + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "922abc5f",
"metadata": {},
"outputs": [],
"source": [
"FROM_BASE = False\n",
"BASE_TOKENIZER_PATH = ''"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c32b0b92",
"metadata": {},
"outputs": [],
"source": [
"TYPES = set(TYPES)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6efddbcc",
"metadata": {},
"outputs": [],
"source": [
"DATA_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}.pickle\"\n",
"DATA_PATH_SPLITS = f\"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}_split/\"\n",
"TOKENIZER_PATH = f\"./data/timecat/models/gpt/tokenizer_{RUN_NAME}.pickle\"\n",
"ALMOST_PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_almost_prepared_split/\"\n",
"PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_prepared_split/\"\n",
"JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_just_before_encoding/\"\n",
"CAT_PATH = \"./data/models/modelpacks/mc_modelpack_phase2_snomed_190k_february_2022.zip\"\n",
"PT_DOB_PATH = \"./data/mimic/pt2dob_datetime.pickle\"\n",
"PT_DOD_PATH = \"./data/mimic/pt2dod_timestamp.pickle\"\n",
"PT_SEX_PATH = \"./data/mimic/pt2sex.pickle\"\n",
"PT_LNS_PATH = f\"./data/timecat/mimic/{BASE_NAME}/lns_{DATASET_NAME}.pickle\"\n",
"PT_CNTS_PATH = f\"./data/timecat/mimic/{BASE_NAME}/cnts_{DATASET_NAME}.pickle\"\n",
"PT_ETHNICITY_PATH = \"./data/mimic/pt2ethnicity.pickle\"\n",
"TOKEN_TYPES_PATH = f'./data/timecat/mimic/{BASE_NAME}/types_{DATASET_NAME}.pickle'\n",
"\n",
"BATCH_SIZE = 200\n",
"DEVICE = torch.device('cuda')\n",
"NUM_PROC = 16\n",
"MIN_COUNT = 2 # 3\n",
"MIN_GLOBAL_COUNT = 100 # 1000"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8eddb07",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"cat = CAT.load_model_pack(CAT_PATH, meta_cat_config_dict={'general': {'device': 'cpu'}})\n",
"cdb = cat.cdb"
]
},
{
"cell_type": "markdown",
"id": "c5b51800",
"metadata": {},
"source": [
"# Convert docs.pickle into patient stream used by HF datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9445cd0d",
"metadata": {},
"outputs": [],
"source": [
"doc2pt = pickle.load(\"./data/timecat/mimic/doc2pt.pickle\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41568f81",
"metadata": {},
"outputs": [],
"source": [
"doc2pt = {str(k):v for k,v in doc2pt.items()}"
]
},
{
"cell_type": "markdown",
"id": "e5bdc483",
"metadata": {},
"source": [
"### Get counts"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad82350f",
"metadata": {},
"outputs": [],
"source": [
"pt2cui2cnt = None\n",
"base_path = './data/timecat/mimic/annotated_february_2022/'\n",
"doc_paths = os.listdir(base_path)\n",
"doc_paths = [path for path in doc_paths if path.startswith(\"part_\")] # So we keep only annotations data\n",
"\n",
"for path in doc_paths:\n",
" docs = pickle.load(os.path.join(base_path, path))\n",
" \n",
" pt2cui2cnt = calculate_counts(docs=docs,\n",
" doc2pt=doc2pt,\n",
" pt2cui2cnt=pt2cui2cnt,\n",
" meta_requirements={'Presence': 'True', 'Subject': 'Patient'})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "063bf5b0",
"metadata": {},
"outputs": [],
"source": [
"pickle.dump(dict(pt2cui2cnt), f\"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0355f5e",
"metadata": {},
"outputs": [],
"source": [
"pt2cui2cnt = pickle.load(f\"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "355925a7",
"metadata": {},
"outputs": [],
"source": [
"# Total number of annotations per type\n",
"cnt_per_type = {}\n",
"other_cnt = 0\n",
"for pt in pt2cui2cnt:\n",
" for cui in pt2cui2cnt[pt]:\n",
" if cat.cdb.cui2type_ids[cui]:\n",
" t = list(cat.cdb.cui2type_ids[cui])[0]\n",
" cnt_per_type[t] = cnt_per_type.get(t, 0) + pt2cui2cnt[pt][cui]\n",
" else:\n",
" other_cnt += 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79c7b309",
"metadata": {},
"outputs": [],
"source": [
"fprint(\"Total number of annotations per type: \")\n",
"for t in cnt_per_type:\n",
" fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'][t].title(), cnt_per_type[t]))\n",
"fprint(\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c513bf0c",
"metadata": {},
"outputs": [],
"source": [
"fprint(\"Total number of annotations: \", sum([x for x in cnt_per_type.values()]))\n",
"fprint(\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a0375ab",
"metadata": {},
"outputs": [],
"source": [
"# Get total number of different concepts\n",
"all_cuis = set()\n",
"for pt in pt2cui2cnt.keys():\n",
" for cui in pt2cui2cnt[pt]: \n",
" all_cuis.add(cui)\n",
"fprint(\"Total number of different concepts: \", len(all_cuis))\n",
"fprint(\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e5e6c4f",
"metadata": {},
"outputs": [],
"source": [
"# Total number of patients\n",
"fprint(\"Total number of patients: \", len(pt2cui2cnt))\n",
"fprint(\"\")"
]
},
{
"cell_type": "markdown",
"id": "8e3471c9",
"metadata": {},
"source": [
"### Get pt2stream"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "deb8bf33",
"metadata": {},
"outputs": [],
"source": [
"base_path = f'./data/timecat/mimic/{BASE_NAME}/'\n",
"doc_paths = os.listdir(base_path)\n",
"doc_paths = [path for path in doc_paths if path.startswith(\"part_\")] # So we keep only annotations data\n",
"pt2stream = None\n",
"doc2time = {str(k):v for k,v in pickle.load(\"./data/timecat/mimic/doc2time.pickle\").items()}\n",
"\n",
"for path in doc_paths:\n",
" docs = pickle.load(os.path.join(base_path, path))\n",
" pt2stream = docs2stream(docs,\n",
" doc2pt=doc2pt,\n",
" pt2cui2cnt=pt2cui2cnt,\n",
" doc2time=doc2time,\n",
" entity_type_column='type_ids',\n",
" meta_requirements={'Subject': 'Patient', 'Presence': 'True'},\n",
" historical_meta='Time',\n",
" historical_meta_value='Past',\n",
" old_pt2stream=pt2stream,\n",
" skip_cuis={'S-418023006', '17971005'},\n",
" require_time=True)"
]
},
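  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa10bb20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exploratory peek at pt2stream before saving (a sketch, not required by the pipeline).\n",
    "# Assumption: pt2stream maps a patient id to an iterable of annotation events carrying at\n",
    "# least a concept token and a timestamp, as implied by how the stream is used further down.\n",
    "_example_pt = next(iter(pt2stream))\n",
    "print(_example_pt)\n",
    "print(list(pt2stream[_example_pt])[:3])"
   ]
  },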
{
"cell_type": "code",
"execution_count": null,
"id": "c89f0594",
"metadata": {},
"outputs": [],
"source": [
"pickle.dump(dict(pt2stream), DATA_PATH)"
]
},
{
"cell_type": "markdown",
"id": "6cfb0fa2",
"metadata": {},
"source": [
"# Load dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba72e618",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"dataset = datasets.load_dataset(os.path.abspath(patient_concept_stream.__file__), \n",
" data_files={'train': DATA_PATH})['train']"
]
},
{
"cell_type": "markdown",
"id": "305eed65",
"metadata": {},
"source": [
"# Calculate counts"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28bb5e02",
"metadata": {},
"outputs": [],
"source": [
"# Calculate counts for tokens\n",
"token_cnt = defaultdict(int)\n",
"for _dataset in get_all_splits(dataset):\n",
" for stream in _dataset['stream']:\n",
" unique_tokens = set([sample['token'] for sample in stream])\n",
" for token in unique_tokens:\n",
" token_cnt[token] += 1\n",
"token_cnt = dict(token_cnt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5fbd7c7",
"metadata": {},
"outputs": [],
"source": [
"pickle.dump(token_cnt, PT_CNTS_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c928e97",
"metadata": {},
"outputs": [],
"source": [
"MIN_GLOBAL_COUNT = 100 # 1000"
]
},
{
"cell_type": "markdown",
"id": "3b6766cf",
"metadata": {},
"source": [
"# Load and filter by count"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04085e31",
"metadata": {},
"outputs": [],
"source": [
"token_cnt = pickle.load(PT_CNTS_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "049dfd9f",
"metadata": {},
"outputs": [],
"source": [
"dataset = filter_by_count(dataset, min_count=MIN_COUNT, min_count_global=MIN_GLOBAL_COUNT, min_length=5, max_length=-1, \n",
" num_proc=NUM_PROC, token_cnt=token_cnt)"
]
},
{
"cell_type": "markdown",
"id": "0ca8e8e8",
"metadata": {},
"source": [
"### Split and save"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5feea2f",
"metadata": {},
"outputs": [],
"source": [
"# Total number of annotations per type\n",
"cnt_per_type = {}\n",
"for cui in token_cnt:\n",
" if cat.cdb.cui2type_ids[cui]:\n",
" t = list(cat.cdb.cui2type_ids[cui])[0]\n",
" cnt_per_type[t] = cnt_per_type.get(t, 0) + token_cnt[cui]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f743cfe",
"metadata": {},
"outputs": [],
"source": [
"dataset = dataset.train_test_split(test_size = 0.05)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49dffff3",
"metadata": {},
"outputs": [],
"source": [
"dataset.save_to_disk(DATA_PATH_SPLITS)"
]
},
{
"cell_type": "markdown",
"id": "2580b573",
"metadata": {},
"source": [
"# CONTINUE FROM HERE WHEN NOT THE FIRST RUN"
]
},
{
"cell_type": "markdown",
"id": "3c535b25",
"metadata": {},
"source": [
"### Load splits"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acf777de",
"metadata": {},
"outputs": [],
"source": [
"token_cnt = pickle.load(PT_CNTS_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab30780b",
"metadata": {},
"outputs": [],
"source": [
"dataset = datasets.load_from_disk(DATA_PATH_SPLITS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3929e28a",
"metadata": {},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fe3b048",
"metadata": {},
"outputs": [],
"source": [
"fprint(\"Total number of pts in train/test: {}/{}\".format(len(dataset['train']), len(dataset['test'])))"
]
},
{
"cell_type": "markdown",
"id": "2d7964a1",
"metadata": {},
"source": [
"# Filter to required type"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1067d23e",
"metadata": {},
"outputs": [],
"source": [
"if \"ALL_TYPES\" not in TYPES:\n",
" print(\"FILTERING\")\n",
" dataset = filter_by_type(dataset, types_to_keep=TYPES, num_proc=NUM_PROC)"
]
},
{
"cell_type": "markdown",
"id": "58033ab4",
"metadata": {},
"source": [
"# Add Death token"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9db46905",
"metadata": {},
"outputs": [],
"source": [
"pt2dod_timestamp = {str(k):v for k,v in pickle.load(PT_DOD_PATH).items()}\n",
"pt2death = {k:\"The patient has died\" for k in pt2dod_timestamp.keys()}\n",
"dataset = dataset.map(\n",
" lambda examples: add_to_stream(examples, pt2death, last=True, prefix=None, token_type='death'),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
{
"cell_type": "markdown",
"id": "b5c995e2",
"metadata": {},
"source": [
"# Bucket and split"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c019959",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"dataset = dataset.map(\n",
" lambda examples: bucket_concepts(examples, bucket_size_seconds=DAYS*24*60*60, duration_separator=False),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
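  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb21cc32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the bucketing idea, for readability only; bucket_concepts above is the\n",
    "# real implementation and its exact merging/separator behaviour is an assumption here.\n",
    "# Events are grouped into windows of DAYS days; a window boundary is where a separator\n",
    "# token such as '<SEP>' would be emitted.\n",
    "_bucket_size = DAYS * 24 * 60 * 60\n",
    "_example_times = [0, 3600, 90000, 200000]  # hypothetical event timestamps in seconds\n",
    "print([t // _bucket_size for t in _example_times])  # bucket index per event"
   ]
  },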
{
"cell_type": "code",
"execution_count": null,
"id": "18804746",
"metadata": {},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "markdown",
"id": "27383ac4",
"metadata": {},
"source": [
"## Trim long streams"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "460f779e",
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"lns = []\n",
"for _dataset in get_all_splits(dataset):\n",
" for stream in _dataset['stream']:\n",
" lns.append(len(stream))\n",
"pickle.dump(lns, PT_LNS_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2c18011",
"metadata": {},
"outputs": [],
"source": [
"lns = pickle.load(PT_LNS_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f44e23e0",
"metadata": {},
"outputs": [],
"source": [
"len(lns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ebf1af4",
"metadata": {},
"outputs": [],
"source": [
"max(lns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78d8b228",
"metadata": {},
"outputs": [],
"source": [
"max_len = int(np.percentile(lns, 95))\n",
"max_len"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f50eaf63",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"fig = px.histogram(x=[x for x in lns if x < max_len and x > 5], labels={'x': 'length'})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67531287",
"metadata": {},
"outputs": [],
"source": [
"fig.write_html(\"./dataset-info/\" + RUN_NAME + \".html\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5426144f",
"metadata": {},
"outputs": [],
"source": [
"dataset = filter_by_count(dataset, min_count=0, min_count_global=0, min_length=10, max_length=max_len, \n",
" num_proc=NUM_PROC, token_cnt=token_cnt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71c14b2b",
"metadata": {},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "markdown",
"id": "fb87c646",
"metadata": {},
"source": [
"## Split to max len"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ffc1da8",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"dataset = dataset.map(\n",
" lambda examples: split_stream(examples, max_seq_len=MAX_SEQ_LEN-32),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
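  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc43dd54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of splitting to the maximum length; split_stream above is the real\n",
    "# implementation, and whether it cuts on exact boundaries or respects separators is an\n",
    "# assumption here. A stream longer than MAX_SEQ_LEN-32 becomes several examples.\n",
    "_chunk = MAX_SEQ_LEN - 32\n",
    "_example_stream = list(range(500))  # hypothetical stream of 500 events\n",
    "_pieces = [_example_stream[i:i + _chunk] for i in range(0, len(_example_stream), _chunk)]\n",
    "print(len(_pieces), [len(p) for p in _pieces])"
   ]
  },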
{
"cell_type": "markdown",
"id": "577beded",
"metadata": {},
"source": [
"## Save again"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e18f654",
"metadata": {},
"outputs": [],
"source": [
"dataset.save_to_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da54723d",
"metadata": {},
"outputs": [],
"source": [
"dataset = datasets.load_from_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b10396a",
"metadata": {},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "markdown",
"id": "60b7eb32",
"metadata": {},
"source": [
"# Add DOD and TTD"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d04702ae",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"pt2dob_timestamp = {str(k):v for k,v in pickle.load(PT_DOB_PATH).items()}\n",
"dataset = dataset.map(\n",
" lambda examples: add_age(examples, pt2dob_timestamp=pt2dob_timestamp),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79d57596",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"pt2dod_timestamp = {str(k):v for k,v in pickle.load(PT_DOD_PATH).items()}\n",
"# ADD time to die\n",
"dataset = dataset.map(\n",
" lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"id": "312a702f",
"metadata": {},
"source": [
"### Another way for TTD"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88aca6b0",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"# ADD time to die\n",
"dataset['train'] = dataset['train'].map(\n",
" lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)\n",
"\n",
"dataset['test'] = dataset['test'].map(\n",
" lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60,\n",
" max_nttd=10, ttd_prob=1, duplicate_streams=True),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"id": "a520f26a",
"metadata": {},
"source": [
"# Add sex and ethnicity"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fbcf3d29",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Add Sex\n",
"pt2sex = pickle.load(PT_SEX_PATH)\n",
"dataset = dataset.map(\n",
" lambda examples: add_to_stream(examples, pt2sex, last=False, prefix='<SEX>', token_type='sex'),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14005ffc",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Ethnicity\n",
"pt2ethnicity = pickle.load(PT_ETHNICITY_PATH)\n",
"dataset = dataset.map(\n",
" lambda examples: add_to_stream(examples, pt2ethnicity, last=False, prefix='<ETHNICITY>', token_type='ethnicity'),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
{
"cell_type": "markdown",
"id": "b0dd9e11",
"metadata": {},
"source": [
"# Final filter"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bdb7f59b",
"metadata": {},
"outputs": [],
"source": [
"dataset = filter_by_count(dataset, min_count=None, min_count_global=None, min_length=10, num_proc=NUM_PROC)"
]
},
{
"cell_type": "markdown",
"id": "f1ce6130",
"metadata": {},
"source": [
"# Remove parents"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47f86d59",
"metadata": {},
"outputs": [],
"source": [
"# Diseases\n",
"cuis = [token for token in cdb.config.linking['filters']['cuis'] if token in cdb.cui2names]\n",
"ch2parents = get_parents_map(cuis, cdb.addl_info['pt2ch'], depth=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64b4664f",
"metadata": {},
"outputs": [],
"source": [
"dataset = dataset.map(\n",
" lambda examples: remove_parents_from_stream(examples, ch2parents=ch2parents, separator='<SEP>'),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
{
"cell_type": "markdown",
"id": "0a3aa069",
"metadata": {},
"source": [
"## Add position IDs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0858f524",
"metadata": {},
"outputs": [],
"source": [
"dataset = dataset.map(\n",
" lambda examples: add_position_ids(examples, separators={'<SEP>', '<SEP-1>', '<SEP-7>' '<SEP-14>', '<SEP-30>', '<SEP-365>'}),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
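  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd65ee76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch of position ids; add_position_ids above is the real implementation and\n",
    "# the exact rule (e.g. which position the separator itself gets) is an assumption here.\n",
    "# The position id is taken to be the index of the time bucket, so it increases by one each\n",
    "# time a separator token is crossed.\n",
    "def _sketch_position_ids(tokens, separators):\n",
    "    pos, out = 0, []\n",
    "    for tkn in tokens:\n",
    "        out.append(pos)\n",
    "        if tkn in separators:\n",
    "            pos += 1\n",
    "    return out\n",
    "\n",
    "print(_sketch_position_ids(['<SEX>', 'concept-a', '<SEP>', 'concept-b', '<SEP>', 'concept-c'], {'<SEP>'}))"
   ]
  },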
{
"cell_type": "markdown",
"id": "cf641980",
"metadata": {},
"source": [
"# Get token_type2tokens"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9bef3781",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"token_type2tokens = defaultdict(set)\n",
"total_cnt = 0\n",
"for _dataset in get_all_splits(dataset):\n",
" for stream in _dataset['stream']:\n",
" for example in stream:\n",
" token_type2tokens[example['token_type']].add(example['token'])\n",
" total_cnt += 1\n",
"token_type2tokens = dict(token_type2tokens)\n",
"pickle.dump(token_type2tokens, TOKEN_TYPES_PATH)\n",
"fprint(\"Total number of annotations: \", total_cnt)"
]
},
{
"cell_type": "markdown",
"id": "61c34003",
"metadata": {},
"source": [
"# Cleanup stream and leave only what we need"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33d5fbdb",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"dataset = dataset.map(\n",
" lambda examples: cleanup_stream(examples, keep_time=True, keep_type=True, keep_position_ids=True,\n",
" keep_context_representation=False),\n",
" batched=True,\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
{
"cell_type": "markdown",
"id": "71ad6184",
"metadata": {},
"source": [
"### Save"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc98c4f0",
"metadata": {},
"outputs": [],
"source": [
"dataset.save_to_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2b2c1e1",
"metadata": {},
"outputs": [],
"source": [
"dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41ac5775",
"metadata": {},
"outputs": [],
"source": [
"JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38c09fd5",
"metadata": {},
"outputs": [],
"source": [
"# Total number of patients fater intial filtering\n",
"train_len = len(dataset['train'])\n",
"test_len = len(dataset['test'])\n",
"fprint(\"Total number of pts in train: \", train_len)\n",
"fprint(\"Total number of pts in test: \", test_len)\n",
"fprint(\"Total number of pts: \", train_len + test_len)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29bcc0c4",
"metadata": {},
"outputs": [],
"source": [
"# Total number of annotations per type after filtering\n",
"cnt_per_type_after = {}\n",
"for _dataset in get_all_splits(dataset):\n",
" for stream in _dataset['stream']:\n",
" for cui in stream:\n",
" if cat.cdb.cui2type_ids.get(cui, None):\n",
" t = list(cat.cdb.cui2type_ids[cui])[0]\n",
" cnt_per_type_after[t] = cnt_per_type_after.get(t, 0) + 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "208c24fb",
"metadata": {},
"outputs": [],
"source": [
"fprint(\"Total number of annotations per type: \\n\")\n",
"for t in cnt_per_type_after:\n",
" fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'][t].title(), cnt_per_type_after[t]))"
]
},
{
"cell_type": "markdown",
"id": "b2b4d62e",
"metadata": {},
"source": [
"# Make tokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8cf2ec1a",
"metadata": {},
"outputs": [],
"source": [
"extra_tokenizer = None\n",
"#extra_tokenizer = SimpleMapTokenizer.load(\"./data/time/models/slam_tokenizer_annotations_stream_phase2_1d_200_ALL_TYPES.pickle\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e150f156",
"metadata": {},
"outputs": [],
"source": [
"token_type2tokens = pickle.load(TOKEN_TYPES_PATH)\n",
"extra_concepts = None\n",
"if extra_tokenizer is not None:\n",
" extra_concepts = list(extra_tokenizer.tkn2id.keys())\n",
"\n",
" for k,v in extra_tokenizer.token_type2tokens.items():\n",
" if k in token_type2tokens:\n",
" token_type2tokens[k].update(extra_tokenizer.token_type2tokens[k])\n",
" else:\n",
" token_type2tokens[k] = extra_tokenizer.token_type2tokens[k]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "767d8549",
"metadata": {},
"outputs": [],
"source": [
"_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())\n",
"embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types,\n",
" concepts=extra_concepts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3b9946e",
"metadata": {},
"outputs": [],
"source": [
"tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}\n",
"tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,\n",
" token_type2tokens=token_type2tokens, embeddings=embeddings,\n",
" global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51b9cce4",
"metadata": {},
"outputs": [],
"source": [
"assert len(tokenizer.tkn2id) == len(tokenizer.id2tkn)\n",
"assert len(tokenizer.embeddings) == len(tokenizer.id2tkn)\n",
"assert len(tokenizer.tkn2name) == len(tokenizer.id2tkn)\n",
"fprint(tokenizer.pad_id, tokenizer.id2tkn[tokenizer.pad_id])"
]
},
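  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee87ff98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tiny round-trip sanity check (a sketch; it relies only on tokenizer attributes already used\n",
    "# in this notebook: tkn2id, id2tkn and convert_ids2tokens).\n",
    "_some_tokens = list(tokenizer.tkn2id.keys())[:5]\n",
    "_some_ids = [tokenizer.tkn2id[t] for t in _some_tokens]\n",
    "print(_some_tokens)\n",
    "print(_some_ids)\n",
    "print(list(tokenizer.convert_ids2tokens(_some_ids)))"
   ]
  },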
{
"cell_type": "code",
"execution_count": null,
"id": "031b224b",
"metadata": {},
"outputs": [],
"source": [
"len(tokenizer.tkn2name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0295fe2d",
"metadata": {},
"outputs": [],
"source": [
"# save\n",
"tokenizer.save(TOKENIZER_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6cd9f04",
"metadata": {},
"outputs": [],
"source": [
"# Total number of different concepts after all filtering\n",
"fprint(\"Total number of concepts after filtering: \", len(tokenizer.tkn2id))\n",
"fprint(\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9f996f0",
"metadata": {},
"outputs": [],
"source": [
"# Total number annotations after all filtering\n",
"fprint(\"Total number of annotations after filtering: \", sum([x for x in cnt_per_type_after.values()]))\n",
"fprint(\"\")"
]
},
{
"cell_type": "markdown",
"id": "862166a9",
"metadata": {},
"source": [
"# Print number of different concepts per type after filtering"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82ebd72c",
"metadata": {},
"outputs": [],
"source": [
"cnt_per_type = {}\n",
"for cui in tkn2id:\n",
" if cat.cdb.cui2type_ids.get(cui, ['Other']):\n",
" t = list(cat.cdb.cui2type_ids.get(cui, ['Other']))[0]\n",
" cnt_per_type[t] = cnt_per_type.get(t, 0) + 1\n",
"fprint(\"Total number of <<different>> concepts per type after filtering\")\n",
"for t in cnt_per_type:\n",
" fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'].get(t, t).title(), cnt_per_type[t]))\n",
"fprint(\"\")"
]
},
{
"cell_type": "markdown",
"id": "19cf388d",
"metadata": {},
"source": [
"# Create global tokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac9c9839",
"metadata": {},
"outputs": [],
"source": [
"_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())\n",
"concepts = list(cat.config.linking['filters']['cuis'])\n",
"embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types, concepts=concepts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eabbbf88",
"metadata": {},
"outputs": [],
"source": [
"tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}\n",
"tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,\n",
" token_type2tokens=token_type2tokens, embeddings=embeddings,\n",
" global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22833eec",
"metadata": {},
"outputs": [],
"source": [
"tokenizer.save(BASE_TOKENIZER_PATH)"
]
},
{
"cell_type": "markdown",
"id": "4c06d09b",
"metadata": {},
"source": [
"# Convert tokens to IDs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "864df946",
"metadata": {},
"outputs": [],
"source": [
"if FROM_BASE:\n",
" print(\"USING BASE TOKENIZER\")\n",
" TOKENIZER_PATH = BASE_TOKENIZER_PATH"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4540c4a",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "548e8e3b",
"metadata": {},
"outputs": [],
"source": [
"encoded_dataset = dataset.map(\n",
" lambda examples: tokenizer.encode(examples),\n",
" batched=True,\n",
" remove_columns=['stream'],\n",
" load_from_cache_file=False,\n",
" num_proc=NUM_PROC)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8718ae76",
"metadata": {},
"outputs": [],
"source": [
"encoded_dataset.save_to_disk(PREPARED_DATASET_SPLIT_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "340dd11d",
"metadata": {},
"outputs": [],
"source": [
"PREPARED_DATASET_SPLIT_PATH"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e512ab8c",
"metadata": {},
"outputs": [],
"source": [
"TOKENIZER_PATH"
]
},
{
"cell_type": "markdown",
"id": "6715bb47",
"metadata": {},
"source": [
"# Test is all OK"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60881f72",
"metadata": {},
"outputs": [],
"source": [
"encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f3c066f",
"metadata": {},
"outputs": [],
"source": [
"dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6dab3686",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "035a55fb",
"metadata": {},
"outputs": [],
"source": [
"encoded_dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b89f84af",
"metadata": {},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21d66a24",
"metadata": {},
"outputs": [],
"source": [
"ind = 1096"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "664dd895",
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49b1cbc1",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"[cdb.get_name(x) for x in dataset['train'][ind]['stream']]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ec0594e",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"for ty, p, t, c in zip(encoded_dataset['train'][ind]['token_type'], encoded_dataset['train'][ind]['position_ids'], encoded_dataset['train'][ind]['time'], tokenizer.convert_ids2tokens(encoded_dataset['train'][ind]['input_ids'])):\n",
" print(datetime.fromtimestamp(t), p, \"{:20}\".format(ty), c)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19003ff8",
"metadata": {},
"outputs": [],
"source": [
"encoded_dataset['train'][ind]['patient_id']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1fc5460d",
"metadata": {},
"outputs": [],
"source": [
"ds_info.close()"
]
},
{
"cell_type": "markdown",
"id": "e5a3cc65",
"metadata": {},
"source": [
"# Preapre for Foresight"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2dec8a19",
"metadata": {},
"outputs": [],
"source": [
"ind = 32330"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2821af66",
"metadata": {},
"outputs": [],
"source": [
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55457df7",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"[cdb.get_name(x) for x in dataset['train'][ind]['stream']]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "196eaf9d",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"for i, c in enumerate(dataset['train'][ind]['stream']):\n",
" print(i)\n",
" if i > 20 and c not in dataset['train'][ind]['stream'][0:i]:\n",
" print(i, c, cdb.get_name(c))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a159ae3",
"metadata": {},
"outputs": [],
"source": [
"out = []\n",
"for i, cui in enumerate(dataset['train'][ind]['stream'][:161]):\n",
" d = {\n",
" 'id': cui,\n",
" 'label': cdb.get_name(cui),\n",
" 'count': 1000000,\n",
" 'name': cdb.get_name(cui),\n",
" 'cui': cui,\n",
" 'saliency': 0,\n",
" 'uid': i\n",
" }\n",
" out.append(d)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97f40a06",
"metadata": {},
"outputs": [],
"source": [
"json.dump(out, open(\"./data/tmp/timeline_example_1.json\", 'w'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da6e079d",
"metadata": {},
"outputs": [],
"source": [
"len(out)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2730480",
"metadata": {},
"outputs": [],
"source": [
"out"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}