{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01ae6e4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import datasets\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from medcat.cat import CAT\n",
    "from medcat.cdb import CDB\n",
    "from foresight.datasets import patient_concept_stream\n",
    "from foresight.datasets.filters import filter_by_count, filter_by_type\n",
    "from foresight.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \\\n",
    "                                  remove_parents_from_stream, bucket_concepts, cleanup_stream, \\\n",
    "                                  split_stream, add_age, get_all_splits, add_ttd, add_position_ids\n",
    "from foresight.utils import pickle\n",
    "from foresight.utils.cdb_utils import get_parents_map\n",
    "from foresight.utils.stream_utils import docs2stream, calculate_counts\n",
    "from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer\n",
    "from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF\n",
    "import plotly.express as px"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f57b67c",
   "metadata": {},
   "outputs": [],
   "source": [
    "DAYS = 1 # Do: 1, 14, 30\n",
    "MAX_SEQ_LEN = 256\n",
    "#TYPES = ['T-45', 'T-55', 'T-26', 'T-29', 'T-40', 'T-9', 'T-27', 'T-11', 'T-39', 'T-18']\n",
    "TYPES = ['ALL_TYPES']\n",
    "#TYPES = ['T-11']\n",
    "#TYPES = ['T-11', 'T-18']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ac3ee70",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE_NAME = 'annotated_february_2022'\n",
    "DATASET_NAME = 'annotations_stream_phase2_v1'\n",
    "RUN_NAME = f'{DATASET_NAME}_{DAYS}d_{MAX_SEQ_LEN}_{\"_\".join(TYPES)}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c865ce1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mirror everything we print into a per-run dataset-info file (closed at the end of the notebook)\n",
    "ds_info = open(\"dataset-info/\" + RUN_NAME + '.txt', 'w')\n",
    "def fprint(*texts):\n",
    "    for text in texts:\n",
    "        print(text)\n",
    "        ds_info.write(str(text) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "922abc5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "FROM_BASE = False\n",
    "BASE_TOKENIZER_PATH = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c32b0b92",
   "metadata": {},
   "outputs": [],
   "source": [
    "TYPES = set(TYPES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6efddbcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}.pickle\"\n",
    "DATA_PATH_SPLITS = f\"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}_split/\"\n",
    "TOKENIZER_PATH = f\"./data/timecat/models/gpt/tokenizer_{RUN_NAME}.pickle\"\n",
    "ALMOST_PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_almost_prepared_split/\"\n",
    "PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_prepared_split/\"\n",
    "JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_just_before_encoding/\"\n",
    "CAT_PATH = \"./data/models/modelpacks/mc_modelpack_phase2_snomed_190k_february_2022.zip\"\n",
    "PT_DOB_PATH = \"./data/mimic/pt2dob_datetime.pickle\"\n",
    "PT_DOD_PATH = \"./data/mimic/pt2dod_timestamp.pickle\"\n",
    "PT_SEX_PATH = \"./data/mimic/pt2sex.pickle\"\n",
    "PT_LNS_PATH = f\"./data/timecat/mimic/{BASE_NAME}/lns_{DATASET_NAME}.pickle\"\n",
    "PT_CNTS_PATH = f\"./data/timecat/mimic/{BASE_NAME}/cnts_{DATASET_NAME}.pickle\"\n",
    "PT_ETHNICITY_PATH = \"./data/mimic/pt2ethnicity.pickle\"\n",
    "TOKEN_TYPES_PATH = f'./data/timecat/mimic/{BASE_NAME}/types_{DATASET_NAME}.pickle'\n",
    "\n",
    "BATCH_SIZE = 200\n",
    "DEVICE = torch.device('cuda')\n",
    "NUM_PROC = 16\n",
    "MIN_COUNT = 2 # 3\n",
    "MIN_GLOBAL_COUNT = 100 # 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8eddb07",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "cat = CAT.load_model_pack(CAT_PATH, meta_cat_config_dict={'general': {'device': 'cpu'}})\n",
    "cdb = cat.cdb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5b51800",
   "metadata": {},
   "source": [
    "# Convert docs.pickle into patient stream used by HF datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9445cd0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc2pt = pickle.load(\"./data/timecat/mimic/doc2pt.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41568f81",
   "metadata": {},
   "outputs": [],
   "source": [
    "doc2pt = {str(k):v for k,v in doc2pt.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5bdc483",
   "metadata": {},
   "source": [
    "### Get counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad82350f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pt2cui2cnt = None\n",
    "base_path = f'./data/timecat/mimic/{BASE_NAME}/'\n",
    "doc_paths = os.listdir(base_path)\n",
    "doc_paths = [path for path in doc_paths if path.startswith(\"part_\")] # Keep only the annotation files\n",
    "\n",
    "for path in doc_paths:\n",
    "    docs = pickle.load(os.path.join(base_path, path))\n",
    "\n",
    "    pt2cui2cnt = calculate_counts(docs=docs,\n",
    "                                  doc2pt=doc2pt,\n",
    "                                  pt2cui2cnt=pt2cui2cnt,\n",
    "                                  meta_requirements={'Presence': 'True', 'Subject': 'Patient'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "063bf5b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(dict(pt2cui2cnt), f\"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0355f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pt2cui2cnt = pickle.load(f\"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "355925a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of annotations per type; concepts without a type go into other_cnt\n",
    "cnt_per_type = {}\n",
    "other_cnt = 0\n",
    "for pt in pt2cui2cnt:\n",
    "    for cui in pt2cui2cnt[pt]:\n",
    "        if cat.cdb.cui2type_ids[cui]:\n",
    "            t = list(cat.cdb.cui2type_ids[cui])[0]\n",
    "            cnt_per_type[t] = cnt_per_type.get(t, 0) + pt2cui2cnt[pt][cui]\n",
    "        else:\n",
    "            other_cnt += pt2cui2cnt[pt][cui]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c7b309",
   "metadata": {},
   "outputs": [],
   "source": [
    "fprint(\"Total number of annotations per type: \")\n",
    "for t in cnt_per_type:\n",
    "    fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'][t].title(), cnt_per_type[t]))\n",
    "fprint(\"\")"
   ]
  },
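  {
   "cell_type": "markdown",
   "id": "aa10cc01",
   "metadata": {},
   "source": [
    "Optional sanity check (a minimal sketch, not part of the original pipeline): visualise the per-type annotation counts with Plotly. Assumes `cnt_per_type` and the loaded `cat` from the cells above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa10cc02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: bar chart of annotation counts per semantic type (assumes cnt_per_type from above)\n",
    "type_names = [cat.cdb.addl_info['type_id2name'][t].title() for t in cnt_per_type]\n",
    "fig_types = px.bar(x=type_names, y=list(cnt_per_type.values()), labels={'x': 'type', 'y': 'annotations'})\n",
    "fig_types.show()"
   ]
  },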
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c513bf0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fprint(\"Total number of annotations: \", sum(cnt_per_type.values()))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a0375ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get total number of different concepts\n",
    "all_cuis = set()\n",
    "for pt in pt2cui2cnt.keys():\n",
    "    for cui in pt2cui2cnt[pt]:\n",
    "        all_cuis.add(cui)\n",
    "fprint(\"Total number of different concepts: \", len(all_cuis))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e5e6c4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of patients\n",
    "fprint(\"Total number of patients: \", len(pt2cui2cnt))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e3471c9",
   "metadata": {},
   "source": [
    "### Get pt2stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deb8bf33",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = f'./data/timecat/mimic/{BASE_NAME}/'\n",
    "doc_paths = os.listdir(base_path)\n",
    "doc_paths = [path for path in doc_paths if path.startswith(\"part_\")] # Keep only the annotation files\n",
    "pt2stream = None\n",
    "doc2time = {str(k):v for k,v in pickle.load(\"./data/timecat/mimic/doc2time.pickle\").items()}\n",
    "\n",
    "for path in doc_paths:\n",
    "    docs = pickle.load(os.path.join(base_path, path))\n",
    "    pt2stream = docs2stream(docs,\n",
    "                            doc2pt=doc2pt,\n",
    "                            pt2cui2cnt=pt2cui2cnt,\n",
    "                            doc2time=doc2time,\n",
    "                            entity_type_column='type_ids',\n",
    "                            meta_requirements={'Subject': 'Patient', 'Presence': 'True'},\n",
    "                            historical_meta='Time',\n",
    "                            historical_meta_value='Past',\n",
    "                            old_pt2stream=pt2stream,\n",
    "                            skip_cuis={'S-418023006', '17971005'},\n",
    "                            require_time=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c89f0594",
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(dict(pt2stream), DATA_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cfb0fa2",
   "metadata": {},
   "source": [
    "# Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba72e618",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dataset = datasets.load_dataset(os.path.abspath(patient_concept_stream.__file__),\n",
    "                                data_files={'train': DATA_PATH})['train']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "305eed65",
   "metadata": {},
   "source": [
    "# Calculate counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28bb5e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count in how many patient streams each token appears (each token counted once per stream)\n",
    "token_cnt = defaultdict(int)\n",
    "for _dataset in get_all_splits(dataset):\n",
    "    for stream in _dataset['stream']:\n",
    "        unique_tokens = set([sample['token'] for sample in stream])\n",
    "        for token in unique_tokens:\n",
    "            token_cnt[token] += 1\n",
    "token_cnt = dict(token_cnt)"
   ]
  },
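  {
   "cell_type": "markdown",
   "id": "aa10cc03",
   "metadata": {},
   "source": [
    "Optional sketch: list the most frequent tokens as a sanity check. Assumes `token_cnt` from the previous cell and the `cdb` loaded earlier; a token's count here is the number of patient streams it appears in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa10cc04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: the 20 tokens that appear in the most patient streams\n",
    "for token, cnt in sorted(token_cnt.items(), key=lambda kv: kv[1], reverse=True)[:20]:\n",
    "    print(\"{:15} {:7} {}\".format(token, cnt, cdb.get_name(token)))"
   ]
  },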
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5fbd7c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(token_cnt, PT_CNTS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c928e97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-declared so the notebook can be resumed from this point\n",
    "MIN_GLOBAL_COUNT = 100 # 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b6766cf",
   "metadata": {},
   "source": [
    "# Load and filter by count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04085e31",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_cnt = pickle.load(PT_CNTS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "049dfd9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = filter_by_count(dataset, min_count=MIN_COUNT, min_count_global=MIN_GLOBAL_COUNT, min_length=5, max_length=-1,\n",
    "                          num_proc=NUM_PROC, token_cnt=token_cnt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ca8e8e8",
   "metadata": {},
   "source": [
    "### Split and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5feea2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of token occurrences per type (token_cnt counts each token once per patient stream)\n",
    "cnt_per_type = {}\n",
    "for cui in token_cnt:\n",
    "    if cat.cdb.cui2type_ids[cui]:\n",
    "        t = list(cat.cdb.cui2type_ids[cui])[0]\n",
    "        cnt_per_type[t] = cnt_per_type.get(t, 0) + token_cnt[cui]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f743cfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pass seed=<int> here for a reproducible split\n",
    "dataset = dataset.train_test_split(test_size=0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49dffff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.save_to_disk(DATA_PATH_SPLITS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2580b573",
   "metadata": {},
   "source": [
    "# CONTINUE FROM HERE WHEN NOT THE FIRST RUN"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c535b25",
   "metadata": {},
   "source": [
    "### Load splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acf777de",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_cnt = pickle.load(PT_CNTS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab30780b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(DATA_PATH_SPLITS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3929e28a",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fe3b048",
   "metadata": {},
   "outputs": [],
   "source": [
    "fprint(\"Total number of pts in train/test: {}/{}\".format(len(dataset['train']), len(dataset['test'])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d7964a1",
   "metadata": {},
   "source": [
    "# Filter to required type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1067d23e",
   "metadata": {},
   "outputs": [],
   "source": [
    "if \"ALL_TYPES\" not in TYPES:\n",
    "    print(\"FILTERING\")\n",
    "    dataset = filter_by_type(dataset, types_to_keep=TYPES, num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58033ab4",
   "metadata": {},
   "source": [
    "# Add Death token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9db46905",
   "metadata": {},
   "outputs": [],
   "source": [
    "pt2dod_timestamp = {str(k):v for k,v in pickle.load(PT_DOD_PATH).items()}\n",
    "pt2death = {k:\"The patient has died\" for k in pt2dod_timestamp.keys()}\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_to_stream(examples, pt2death, last=True, prefix=None, token_type='death'),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5c995e2",
   "metadata": {},
   "source": [
    "# Bucket and split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c019959",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dataset = dataset.map(\n",
    "        lambda examples: bucket_concepts(examples, bucket_size_seconds=DAYS*24*60*60, duration_separator=False),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18804746",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27383ac4",
   "metadata": {},
   "source": [
    "## Trim long streams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "460f779e",
   "metadata": {},
   "outputs": [],
   "source": [
    "lns = []\n",
    "for _dataset in get_all_splits(dataset):\n",
    "    for stream in _dataset['stream']:\n",
    "        lns.append(len(stream))\n",
    "pickle.dump(lns, PT_LNS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2c18011",
   "metadata": {},
   "outputs": [],
   "source": [
    "lns = pickle.load(PT_LNS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f44e23e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(lns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ebf1af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "max(lns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78d8b228",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_len = int(np.percentile(lns, 95))\n",
    "max_len"
   ]
  },
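  {
   "cell_type": "markdown",
   "id": "aa10cc05",
   "metadata": {},
   "source": [
    "A small sketch summarising the stream-length distribution, to make the 95th-percentile cut-off above easier to judge. Uses only `lns` and NumPy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa10cc06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: length distribution summary for the bucketed streams\n",
    "for p in (50, 75, 90, 95, 99):\n",
    "    print(\"p{}: {}\".format(p, int(np.percentile(lns, p))))\n",
    "print(\"mean: {:.1f}\".format(np.mean(lns)))"
   ]
  },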
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f50eaf63",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "fig = px.histogram(x=[x for x in lns if 5 < x < max_len], labels={'x': 'length'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67531287",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.write_html(\"./dataset-info/\" + RUN_NAME + \".html\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5426144f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = filter_by_count(dataset, min_count=0, min_count_global=0, min_length=10, max_length=max_len,\n",
    "                          num_proc=NUM_PROC, token_cnt=token_cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71c14b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb87c646",
   "metadata": {},
   "source": [
    "## Split to max len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ffc1da8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Leave headroom below MAX_SEQ_LEN for tokens added later (age, sex, ethnicity, separators)\n",
    "dataset = dataset.map(\n",
    "        lambda examples: split_stream(examples, max_seq_len=MAX_SEQ_LEN-32),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
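  {
   "cell_type": "markdown",
   "id": "aa10cc07",
   "metadata": {},
   "source": [
    "Optional check (a sketch, assuming `split_stream` caps every stream at `max_seq_len`): verify that no stream is longer than `MAX_SEQ_LEN - 32` after splitting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa10cc08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: every stream should now fit within MAX_SEQ_LEN - 32\n",
    "for _dataset in get_all_splits(dataset):\n",
    "    for stream in _dataset['stream']:\n",
    "        assert len(stream) <= MAX_SEQ_LEN - 32"
   ]
  },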
  {
   "cell_type": "markdown",
   "id": "577beded",
   "metadata": {},
   "source": [
    "## Save again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e18f654",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.save_to_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da54723d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b10396a",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60b7eb32",
   "metadata": {},
   "source": [
    "# Add age (DOB) and TTD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d04702ae",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pt2dob_timestamp = {str(k):v for k,v in pickle.load(PT_DOB_PATH).items()}\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_age(examples, pt2dob_timestamp=pt2dob_timestamp),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79d57596",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "pt2dod_timestamp = {str(k):v for k,v in pickle.load(PT_DOD_PATH).items()}\n",
    "# Add time to death (TTD)\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "312a702f",
   "metadata": {},
   "source": [
    "### Another way for TTD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88aca6b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "# Add time to death (TTD)\n",
    "dataset['train'] = dataset['train'].map(\n",
    "        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)\n",
    "\n",
    "dataset['test'] = dataset['test'].map(\n",
    "        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60,\n",
    "                                 max_nttd=10, ttd_prob=1, duplicate_streams=True),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a520f26a",
   "metadata": {},
   "source": [
    "# Add sex and ethnicity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbcf3d29",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Add Sex\n",
    "pt2sex = pickle.load(PT_SEX_PATH)\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_to_stream(examples, pt2sex, last=False, prefix='<SEX>', token_type='sex'),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14005ffc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Add Ethnicity\n",
    "pt2ethnicity = pickle.load(PT_ETHNICITY_PATH)\n",
    "dataset = dataset.map(\n",
    "        lambda examples: add_to_stream(examples, pt2ethnicity, last=False, prefix='<ETHNICITY>', token_type='ethnicity'),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0dd9e11",
   "metadata": {},
   "source": [
    "# Final filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdb7f59b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = filter_by_count(dataset, min_count=None, min_count_global=None, min_length=10, num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1ce6130",
   "metadata": {},
   "source": [
    "# Remove parents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47f86d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a child -> parents map (depth 2) for the concepts in the linking filter\n",
    "cuis = [token for token in cdb.config.linking['filters']['cuis'] if token in cdb.cui2names]\n",
    "ch2parents = get_parents_map(cuis, cdb.addl_info['pt2ch'], depth=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64b4664f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(\n",
    "        lambda examples: remove_parents_from_stream(examples, ch2parents=ch2parents, separator='<SEP>'),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a3aa069",
   "metadata": {},
   "source": [
    "## Add position IDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0858f524",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.map(\n",
    "        lambda examples: add_position_ids(examples, separators={'<SEP>', '<SEP-1>', '<SEP-7>', '<SEP-14>', '<SEP-30>', '<SEP-365>'}),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf641980",
   "metadata": {},
   "source": [
    "# Get token_type2tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bef3781",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "token_type2tokens = defaultdict(set)\n",
    "total_cnt = 0\n",
    "for _dataset in get_all_splits(dataset):\n",
    "    for stream in _dataset['stream']:\n",
    "        for example in stream:\n",
    "            token_type2tokens[example['token_type']].add(example['token'])\n",
    "            total_cnt += 1\n",
    "token_type2tokens = dict(token_type2tokens)\n",
    "pickle.dump(token_type2tokens, TOKEN_TYPES_PATH)\n",
    "fprint(\"Total number of annotations: \", total_cnt)"
   ]
  },
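  {
   "cell_type": "markdown",
   "id": "aa10cc09",
   "metadata": {},
   "source": [
    "For reference (a sketch): how many distinct tokens each token type contributes, based on the `token_type2tokens` map just built."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa10cc10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: number of distinct tokens per token type\n",
    "for token_type, tokens in sorted(token_type2tokens.items()):\n",
    "    print(\"{:15}: {}\".format(token_type, len(tokens)))"
   ]
  },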
  {
   "cell_type": "markdown",
   "id": "61c34003",
   "metadata": {},
   "source": [
    "# Cleanup stream and leave only what we need"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33d5fbdb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dataset = dataset.map(\n",
    "        lambda examples: cleanup_stream(examples, keep_time=True, keep_type=True, keep_position_ids=True,\n",
    "                                        keep_context_representation=False),\n",
    "        batched=True,\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71ad6184",
   "metadata": {},
   "source": [
    "### Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc98c4f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.save_to_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2b2c1e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41ac5775",
   "metadata": {},
   "outputs": [],
   "source": [
    "JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38c09fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of patients after initial filtering\n",
    "train_len = len(dataset['train'])\n",
    "test_len = len(dataset['test'])\n",
    "fprint(\"Total number of pts in train: \", train_len)\n",
    "fprint(\"Total number of pts in test: \", test_len)\n",
    "fprint(\"Total number of pts: \", train_len + test_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29bcc0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of annotations per type after filtering\n",
    "cnt_per_type_after = {}\n",
    "for _dataset in get_all_splits(dataset):\n",
    "    for stream in _dataset['stream']:\n",
    "        for cui in stream:\n",
    "            if cat.cdb.cui2type_ids.get(cui, None):\n",
    "                t = list(cat.cdb.cui2type_ids[cui])[0]\n",
    "                cnt_per_type_after[t] = cnt_per_type_after.get(t, 0) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "208c24fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "fprint(\"Total number of annotations per type: \\n\")\n",
    "for t in cnt_per_type_after:\n",
    "    fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'][t].title(), cnt_per_type_after[t]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2b4d62e",
   "metadata": {},
   "source": [
    "# Make tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cf2ec1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "extra_tokenizer = None\n",
    "#extra_tokenizer = SimpleMapTokenizer.load(\"./data/time/models/slam_tokenizer_annotations_stream_phase2_1d_200_ALL_TYPES.pickle\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e150f156",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_type2tokens = pickle.load(TOKEN_TYPES_PATH)\n",
    "extra_concepts = None\n",
    "if extra_tokenizer is not None:\n",
    "    extra_concepts = list(extra_tokenizer.tkn2id.keys())\n",
    "\n",
    "    for k, v in extra_tokenizer.token_type2tokens.items():\n",
    "        if k in token_type2tokens:\n",
    "            token_type2tokens[k].update(v)\n",
    "        else:\n",
    "            token_type2tokens[k] = v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "767d8549",
   "metadata": {},
   "outputs": [],
   "source": [
    "_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())\n",
    "embeddings, tkn2id, id2tkn = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types,\n",
    "                                                       concepts=extra_concepts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b9946e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}\n",
    "tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,\n",
    "                               token_type2tokens=token_type2tokens, embeddings=embeddings,\n",
    "                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51b9cce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(tokenizer.tkn2id) == len(tokenizer.id2tkn)\n",
    "assert len(tokenizer.embeddings) == len(tokenizer.id2tkn)\n",
    "assert len(tokenizer.tkn2name) == len(tokenizer.id2tkn)\n",
    "fprint(tokenizer.pad_id, tokenizer.id2tkn[tokenizer.pad_id])"
   ]
  },
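  {
   "cell_type": "markdown",
   "id": "aa10cc11",
   "metadata": {},
   "source": [
    "A small round-trip sketch on top of the asserts above: map a few tokens to IDs and back. Assumes `tkn2id` and `convert_ids2tokens` behave as they are used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa10cc12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: token <-> id round trip on a handful of entries\n",
    "sample_tokens = list(tokenizer.tkn2id.keys())[:5]\n",
    "sample_ids = [tokenizer.tkn2id[t] for t in sample_tokens]\n",
    "print(sample_tokens)\n",
    "print(tokenizer.convert_ids2tokens(sample_ids))"
   ]
  },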
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "031b224b",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(tokenizer.tkn2name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0295fe2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save\n",
    "tokenizer.save(TOKENIZER_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6cd9f04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of different concepts after all filtering\n",
    "fprint(\"Total number of concepts after filtering: \", len(tokenizer.tkn2id))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9f996f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of annotations after all filtering\n",
    "fprint(\"Total number of annotations after filtering: \", sum(cnt_per_type_after.values()))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "862166a9",
   "metadata": {},
   "source": [
    "# Print number of different concepts per type after filtering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82ebd72c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokens unknown to the CDB (e.g. special tokens) fall back to the 'Other' type\n",
    "cnt_per_type = {}\n",
    "for cui in tkn2id:\n",
    "    if cat.cdb.cui2type_ids.get(cui, ['Other']):\n",
    "        t = list(cat.cdb.cui2type_ids.get(cui, ['Other']))[0]\n",
    "        cnt_per_type[t] = cnt_per_type.get(t, 0) + 1\n",
    "fprint(\"Total number of <<different>> concepts per type after filtering\")\n",
    "for t in cnt_per_type:\n",
    "    fprint(\"{:30}: {}\".format(cat.cdb.addl_info['type_id2name'].get(t, t).title(), cnt_per_type[t]))\n",
    "fprint(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19cf388d",
   "metadata": {},
   "source": [
    "# Create global tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac9c9839",
   "metadata": {},
   "outputs": [],
   "source": [
    "_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())\n",
    "concepts = list(cat.config.linking['filters']['cuis'])\n",
    "embeddings, tkn2id, id2tkn = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types, concepts=concepts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eabbbf88",
   "metadata": {},
   "outputs": [],
   "source": [
    "tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}\n",
    "tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,\n",
    "                               token_type2tokens=token_type2tokens, embeddings=embeddings,\n",
    "                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22833eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.save(BASE_TOKENIZER_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c06d09b",
   "metadata": {},
   "source": [
    "# Convert tokens to IDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "864df946",
   "metadata": {},
   "outputs": [],
   "source": [
    "if FROM_BASE:\n",
    "    print(\"USING BASE TOKENIZER\")\n",
    "    TOKENIZER_PATH = BASE_TOKENIZER_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4540c4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "548e8e3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset = dataset.map(\n",
    "        lambda examples: tokenizer.encode(examples),\n",
    "        batched=True,\n",
    "        remove_columns=['stream'],\n",
    "        load_from_cache_file=False,\n",
    "        num_proc=NUM_PROC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8718ae76",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset.save_to_disk(PREPARED_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "340dd11d",
   "metadata": {},
   "outputs": [],
   "source": [
    "PREPARED_DATASET_SPLIT_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e512ab8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "TOKENIZER_PATH"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6715bb47",
   "metadata": {},
   "source": [
    "# Test that all is OK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60881f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f3c066f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dab3686",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035a55fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89f84af",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d66a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "ind = 1096"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "664dd895",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49b1cbc1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "[cdb.get_name(x) for x in dataset['train'][ind]['stream']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ec0594e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for ty, p, t, c in zip(encoded_dataset['train'][ind]['token_type'],\n",
    "                       encoded_dataset['train'][ind]['position_ids'],\n",
    "                       encoded_dataset['train'][ind]['time'],\n",
    "                       tokenizer.convert_ids2tokens(encoded_dataset['train'][ind]['input_ids'])):\n",
    "    print(datetime.fromtimestamp(t), p, \"{:20}\".format(ty), c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19003ff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_dataset['train'][ind]['patient_id']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fc5460d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_info.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5a3cc65",
   "metadata": {},
   "source": [
    "# Prepare for Foresight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dec8a19",
   "metadata": {},
   "outputs": [],
   "source": [
    "ind = 32330"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2821af66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55457df7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "[cdb.get_name(x) for x in dataset['train'][ind]['stream']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "196eaf9d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Print concepts that first appear after position 20 in the stream\n",
    "for i, c in enumerate(dataset['train'][ind]['stream']):\n",
    "    if i > 20 and c not in dataset['train'][ind]['stream'][:i]:\n",
    "        print(i, c, cdb.get_name(c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a159ae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = []\n",
    "for i, cui in enumerate(dataset['train'][ind]['stream'][:161]):\n",
    "    d = {\n",
    "        'id': cui,\n",
    "        'label': cdb.get_name(cui),\n",
    "        'count': 1000000,\n",
    "        'name': cdb.get_name(cui),\n",
    "        'cui': cui,\n",
    "        'saliency': 0,\n",
    "        'uid': i\n",
    "    }\n",
    "    out.append(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97f40a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./data/tmp/timeline_example_1.json\", 'w') as f:\n",
    "    json.dump(out, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da6e079d",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2730480",
   "metadata": {},
   "outputs": [],
   "source": [
    "out"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}