In [None]:
import torch
import os
import datasets
import numpy as np
from collections import defaultdict
from medcat.cat import CAT
from foresight.datasets import patient_concept_stream
from foresight.datasets.filters import filter_by_count, filter_by_type
from foresight.datasets.utils import get_embeddings_for_tokens, stream_to_separate_examples, add_to_stream, \
                                  remove_parents_from_stream, bucket_concepts, cleanup_stream, \
                                  split_stream, add_age, get_all_splits, add_ttd, add_position_ids
from foresight.utils import pickle
from foresight.utils.cdb_utils import get_parents_map 
from foresight.utils.stream_utils import docs2stream, calculate_counts
from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer
from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF
from medcat.cdb import CDB
from foresight.utils import pickle
import plotly.express as px

In [None]:
# Time-bucket width in days; passed to bucket_concepts below as DAYS*24*60*60 seconds.
DAYS = 1 # Do: 1, 14, 30
# Tokenizer max_len; streams are split at MAX_SEQ_LEN-32 tokens in the "Split to max len" cell.
MAX_SEQ_LEN = 256
# Concept types to keep; when 'ALL_TYPES' is present the filter_by_type step is skipped.
#TYPES = ['T-45', 'T-55', 'T-26', 'T-29', 'T-40', 'T-9', 'T-27', 'T-11', 'T-39', 'T-18']
TYPES = ['ALL_TYPES']
#TYPES = ['T-11']
#TYPES = ['T-11', 'T-18']

In [None]:
# Sub-folder of ./data/timecat/mimic/ that holds the annotated "part_*" files and derived data.
BASE_NAME = 'annotated_february_2022'
# Base name of the pickled patient stream and of the files derived from it.
DATASET_NAME = 'annotations_stream_phase2_v1'
# Identifies this configuration (dataset, bucket size, max length, types) in output file names.
RUN_NAME = f'{DATASET_NAME}_{DAYS}d_{MAX_SEQ_LEN}_{"_".join(TYPES)}'

In [None]:
# Plain-text report for this run; it stays open until the ds_info.close() cell.
os.makedirs("dataset-info", exist_ok=True)  # open() below fails if the folder is missing
ds_info = open("dataset-info/" + RUN_NAME + '.txt', 'w')
def fprint(*texts):
    """Print every argument on its own line and mirror it into the run report.

    Args:
        *texts: Values to report. Each one is printed as-is and written to
            `ds_info` as `str(text)` followed by a newline.
    """
    for text in texts:
        print(text)
        ds_info.write(str(text) + "\n")
    # Flush so the report on disk is complete even if the run stops before ds_info.close()
    ds_info.flush()

In [None]:
# When True, the encoding step loads the tokenizer from BASE_TOKENIZER_PATH instead of TOKENIZER_PATH.
FROM_BASE = False
# Destination of the "global" tokenizer (also read back when FROM_BASE is True); empty by default.
BASE_TOKENIZER_PATH = ''

In [None]:
# Converted to a set after RUN_NAME was built from the ordered list; used for `in` checks later.
TYPES = set(TYPES)

In [None]:
# --- Input / output locations (all relative to the notebook's working directory) ---
# Pickled patient stream written by the "Get pt2stream" section and read by load_dataset.
DATA_PATH = f"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}.pickle"
# Train/test split saved right after the first count filter.
DATA_PATH_SPLITS = f"./data/timecat/mimic/{BASE_NAME}/{DATASET_NAME}_split/"
TOKENIZER_PATH = f"./data/timecat/models/gpt/tokenizer_{RUN_NAME}.pickle"
# Checkpoints of the dataset at successive preparation stages.
ALMOST_PREPARED_DATASET_SPLIT_PATH = f"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_almost_prepared_split/"
PREPARED_DATASET_SPLIT_PATH = f"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_prepared_split/"
JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH = f"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_just_before_encoding/"
CAT_PATH = "./data/models/modelpacks/mc_modelpack_phase2_snomed_190k_february_2022.zip"
# Per-patient lookup tables (presumably keyed by patient id - verify against the pickles).
PT_DOB_PATH = "./data/mimic/pt2dob_datetime.pickle"
PT_DOD_PATH = "./data/mimic/pt2dod_timestamp.pickle"
PT_SEX_PATH = "./data/mimic/pt2sex.pickle"
# Caches for stream lengths and token counts computed in this notebook.
PT_LNS_PATH = f"./data/timecat/mimic/{BASE_NAME}/lns_{DATASET_NAME}.pickle"
PT_CNTS_PATH = f"./data/timecat/mimic/{BASE_NAME}/cnts_{DATASET_NAME}.pickle"
PT_ETHNICITY_PATH = "./data/mimic/pt2ethnicity.pickle"
TOKEN_TYPES_PATH = f'./data/timecat/mimic/{BASE_NAME}/types_{DATASET_NAME}.pickle'

# --- Processing parameters ---
# NOTE(review): BATCH_SIZE and DEVICE are not referenced by any cell visible in this notebook.
BATCH_SIZE = 200
DEVICE = torch.device('cuda')
# Worker processes used by every dataset.map / filter call.
NUM_PROC = 16
# Passed to filter_by_count as min_count / min_count_global.
MIN_COUNT = 2 # 3
MIN_GLOBAL_COUNT = 100 # 1000

In [None]:
# Load the MedCAT model pack, overriding the MetaCAT config so its device is 'cpu'
meta_cat_overrides = {'general': {'device': 'cpu'}}
cat = CAT.load_model_pack(CAT_PATH, meta_cat_config_dict=meta_cat_overrides)
# Shortcut to the concept database used throughout the notebook
cdb = cat.cdb

# Convert docs.pickle into patient stream used by HF datasets

In [None]:
# Lookup handed to calculate_counts / docs2stream as `doc2pt` (presumably document id -> patient id).
doc2pt = pickle.load("./data/timecat/mimic/doc2pt.pickle")

In [None]:
# Re-key the mapping with string document ids (doc2time is re-keyed the same way further down)
doc2pt = dict((str(doc_id), pt) for doc_id, pt in doc2pt.items())

### Get counts

In [None]:
# Accumulate per-patient concept counts over every annotated "part_*" file.
pt2cui2cnt = None  # calculate_counts receives the running result and returns the updated one
# Built from BASE_NAME so this cell reads the same folder as the "Get pt2stream" cell
base_path = f'./data/timecat/mimic/{BASE_NAME}/'
# So we keep only annotations data; sorted() because os.listdir returns names in arbitrary order
doc_paths = sorted(path for path in os.listdir(base_path) if path.startswith("part_"))

for path in doc_paths:
    docs = pickle.load(os.path.join(base_path, path))

    # Only annotations whose meta-annotations match meta_requirements are passed through
    pt2cui2cnt = calculate_counts(docs=docs,
                                  doc2pt=doc2pt,
                                  pt2cui2cnt=pt2cui2cnt,
                                  meta_requirements={'Presence': 'True', 'Subject': 'Patient'})

In [None]:
# Persist the counts as a plain dict so they can be reloaded by the next cell.
pickle.dump(dict(pt2cui2cnt), f"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle")

In [None]:
# Reload the cached per-patient concept counts written by the previous cell.
pt2cui2cnt = pickle.load(f"./data/timecat/mimic/{BASE_NAME}/pt2cui2cnt.pickle")

In [None]:
# Total number of annotations per type
cnt_per_type = {}
other_cnt = 0  # number of (patient, concept) pairs whose concept has no type id
for cui2cnt in pt2cui2cnt.values():
    for cui, cnt in cui2cnt.items():
        # .get(): a CUI missing from cui2type_ids is treated like one without types
        # instead of raising KeyError (same lookup style as the later counting cells)
        type_ids = cat.cdb.cui2type_ids.get(cui)
        if type_ids:
            # Each concept is attributed to the first type id that list() yields
            t = list(type_ids)[0]
            cnt_per_type[t] = cnt_per_type.get(t, 0) + cnt
        else:
            other_cnt += 1

In [None]:
# Report the annotation count of every type, using the title-cased type name as the label
fprint("Total number of annotations per type: ")
type_id2name = cat.cdb.addl_info['type_id2name']
for t, type_total in cnt_per_type.items():
    fprint(f"{type_id2name[t].title():30}: {type_total}")
fprint("")

In [None]:
# Sum of the per-type totals (concepts without a type id are not part of cnt_per_type)
fprint("Total number of annotations: ", sum(cnt_per_type.values()))
fprint("")

In [None]:
# Get total number of different concepts (union of the CUIs seen for any patient)
all_cuis = {cui for cui2cnt in pt2cui2cnt.values() for cui in cui2cnt}
fprint("Total number of different concepts: ", len(all_cuis))
fprint("")

In [None]:
# Total number of patients (one top-level key per patient in pt2cui2cnt)
fprint("Total number of patients: ", len(pt2cui2cnt))
fprint("")

### Get pt2stream

In [None]:
# Build the per-patient concept stream from every annotated "part_*" file.
base_path = f'./data/timecat/mimic/{BASE_NAME}/'
# So we keep only annotations data; sorted() because os.listdir returns names in
# arbitrary order and the stream should be built the same way on every run
doc_paths = sorted(path for path in os.listdir(base_path) if path.startswith("part_"))
pt2stream = None  # docs2stream receives the running result through old_pt2stream
# Re-key with string document ids, matching how doc2pt was re-keyed
doc2time = {str(k): v for k, v in pickle.load("./data/timecat/mimic/doc2time.pickle").items()}

for path in doc_paths:
    docs = pickle.load(os.path.join(base_path, path))
    pt2stream = docs2stream(docs,
                            doc2pt=doc2pt,
                            pt2cui2cnt=pt2cui2cnt,
                            doc2time=doc2time,
                            entity_type_column='type_ids',
                            meta_requirements={'Subject': 'Patient', 'Presence': 'True'},
                            historical_meta='Time',
                            historical_meta_value='Past',
                            old_pt2stream=pt2stream,
                            skip_cuis={'S-418023006', '17971005'},
                            require_time=True)

In [None]:
# Save the stream as a plain dict; DATA_PATH is what load_dataset reads in the next section.
pickle.dump(dict(pt2stream), DATA_PATH)

# Load dataset

In [None]:
# Load the pickled stream through the project's patient_concept_stream loading script
loader_script = os.path.abspath(patient_concept_stream.__file__)
loaded_splits = datasets.load_dataset(loader_script, data_files={'train': DATA_PATH})
dataset = loaded_splits['train']

# Calculate counts

In [None]:
# Calculate counts for tokens: for each token, the number of streams containing it at least once
stream_freq = defaultdict(int)
for split in get_all_splits(dataset):
    for stream in split['stream']:
        # A set, so repeated mentions inside one stream are counted once
        for token in {sample['token'] for sample in stream}:
            stream_freq[token] += 1
token_cnt = dict(stream_freq)

In [None]:
# Cache the per-token stream counts; they are reloaded from PT_CNTS_PATH in later cells.
pickle.dump(token_cnt, PT_CNTS_PATH)

In [None]:
# NOTE(review): repeats the value already set in the configuration cell; keep the two in sync.
MIN_GLOBAL_COUNT = 100 # 1000

# Load and filter by count

In [None]:
# Reload the cached token counts so this section can run without recomputing them.
token_cnt = pickle.load(PT_CNTS_PATH)

In [None]:
# First filter pass using the MIN_COUNT / MIN_GLOBAL_COUNT thresholds and a minimum stream length of 5.
# NOTE(review): max_length=-1 presumably means "no upper limit" - confirm in filter_by_count.
dataset = filter_by_count(dataset, min_count=MIN_COUNT, min_count_global=MIN_GLOBAL_COUNT, min_length=5, max_length=-1, 
                          num_proc=NUM_PROC, token_cnt=token_cnt)

### Split and save

In [None]:
# Total number of annotations per type, recomputed from the per-token stream counts.
# NOTE(review): no later visible cell reads this result before cnt_per_type is rebuilt.
cnt_per_type = {}
for cui, cui_total in token_cnt.items():
    # .get(): tokens absent from cui2type_ids are skipped instead of raising KeyError
    type_ids = cat.cdb.cui2type_ids.get(cui)
    if type_ids:
        # Each concept is attributed to the first type id that list() yields
        t = list(type_ids)[0]
        cnt_per_type[t] = cnt_per_type.get(t, 0) + cui_total

In [None]:
# Hold out 5% of the rows for testing; the fixed seed makes the split reproducible across runs
dataset = dataset.train_test_split(test_size=0.05, seed=42)

In [None]:
# Persist the train/test split so later runs can start from the "Load splits" section.
dataset.save_to_disk(DATA_PATH_SPLITS)

# CONTINUE FROM HERE WHEN NOT THE FIRST RUN

### Load splits

In [None]:
# Token counts are needed again by filter_by_count and by the tokenizer (global_token_cnt).
token_cnt = pickle.load(PT_CNTS_PATH)

In [None]:
# Load the saved train/test split; from here on `dataset` has 'train' and 'test' entries.
dataset = datasets.load_from_disk(DATA_PATH_SPLITS)

In [None]:
dataset

In [None]:
# Record the number of rows in each split in the run report
n_train, n_test = len(dataset['train']), len(dataset['test'])
fprint(f"Total number of pts in train/test: {n_train}/{n_test}")

# Filter to required type

In [None]:
# The type filter runs only when the 'ALL_TYPES' sentinel is absent from TYPES.
if "ALL_TYPES" not in TYPES:
    print("FILTERING")
    dataset = filter_by_type(dataset, types_to_keep=TYPES, num_proc=NUM_PROC)

# Add Death token

In [None]:
# Every patient id found in the DOD mapping gets the same 'death' token text;
# add_to_stream is called with last=True (presumably placing it at the end of the stream).
pt2dod_timestamp = {str(pt): dod for pt, dod in pickle.load(PT_DOD_PATH).items()}
pt2death = dict.fromkeys(pt2dod_timestamp, "The patient has died")

def _add_death_token(examples):
    """Batched map function that adds the death token via add_to_stream."""
    return add_to_stream(examples, pt2death, last=True, prefix=None, token_type='death')

dataset = dataset.map(_add_death_token, batched=True, load_from_cache_file=False, num_proc=NUM_PROC)

# Bucket and split

In [None]:
# Bucket width in seconds, derived from the DAYS setting
bucket_seconds = DAYS*24*60*60
dataset = dataset.map(
        lambda examples: bucket_concepts(examples, bucket_size_seconds=bucket_seconds, duration_separator=False),
        num_proc=NUM_PROC,
        load_from_cache_file=False,
        batched=True)

In [None]:
dataset

## Trim long streams

In [None]:
from collections import defaultdict
# Length of every stream across all splits, cached for the percentile-based trimming below
lns = [len(stream) for split in get_all_splits(dataset) for stream in split['stream']]
pickle.dump(lns, PT_LNS_PATH)

In [None]:
# Reload the cached stream lengths.
lns = pickle.load(PT_LNS_PATH)

In [None]:
len(lns)

In [None]:
max(lns)

In [None]:
# 95th percentile of stream lengths; used as max_length in the filter_by_count call below.
max_len = int(np.percentile(lns, 95))
max_len

In [None]:
# Histogram of stream lengths strictly between 5 and max_len
plotted_lengths = [x for x in lns if 5 < x < max_len]
fig = px.histogram(x=plotted_lengths, labels={'x': 'length'})

In [None]:
# Store the histogram next to the text report for this run.
fig.write_html("./dataset-info/" + RUN_NAME + ".html")

In [None]:
# Length-only pass: both count thresholds are 0, streams must have between 10 and max_len entries.
dataset = filter_by_count(dataset, min_count=0, min_count_global=0, min_length=10, max_length=max_len, 
                          num_proc=NUM_PROC, token_cnt=token_cnt)

In [None]:
dataset

## Split to max len

In [None]:
# Streams are split at MAX_SEQ_LEN-32
# NOTE(review): the 32-token margin is presumably headroom for tokens added in later steps - confirm
chunk_len = MAX_SEQ_LEN-32
dataset = dataset.map(
        lambda examples: split_stream(examples, max_seq_len=chunk_len),
        num_proc=NUM_PROC,
        load_from_cache_file=False,
        batched=True)

## Save again

In [None]:
# Checkpoint after bucketing, trimming and splitting.
dataset.save_to_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)

In [None]:
# Resume from the checkpoint written by the previous cell.
dataset = datasets.load_from_disk(ALMOST_PREPARED_DATASET_SPLIT_PATH)

In [None]:
dataset

# Add age (from DOB) and TTD

In [None]:
# Date-of-birth lookup re-keyed with string ids, then handed to add_age for every batch
pt2dob_timestamp = {str(pt): dob for pt, dob in pickle.load(PT_DOB_PATH).items()}

def _add_age(examples):
    """Batched map function that calls add_age with the DOB lookup."""
    return add_age(examples, pt2dob_timestamp=pt2dob_timestamp)

dataset = dataset.map(_add_age, batched=True, load_from_cache_file=False, num_proc=NUM_PROC)

In [None]:
# Disabled step: the add_ttd call below lives inside a string literal, so none of it is executed.
"""
pt2dod_timestamp = {str(k):v for k,v in pickle.load(PT_DOD_PATH).items()}
# ADD time to die
dataset = dataset.map(
        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)
"""

### Another way for TTD

In [None]:
# Disabled alternative: separate add_ttd settings for train and test, kept only as a string literal.
"""
# ADD time to die
dataset['train'] = dataset['train'].map(
        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)

dataset['test'] = dataset['test'].map(
        lambda examples: add_ttd(examples, pt2dod_timestamp=pt2dod_timestamp, ttd_normalizer=14 * 24 * 60 * 60,
                                 max_nttd=10, ttd_prob=1, duplicate_streams=True),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)
"""

# Add sex and ethnicity

In [None]:
# Add Sex: add_to_stream is called with last=False, the '<SEX>' prefix and token type 'sex'
pt2sex = pickle.load(PT_SEX_PATH)

def _add_sex(examples):
    """Batched map function that adds the sex token via add_to_stream."""
    return add_to_stream(examples, pt2sex, last=False, prefix='<SEX>', token_type='sex')

dataset = dataset.map(_add_sex, batched=True, load_from_cache_file=False, num_proc=NUM_PROC)

In [None]:
# Ethnicity: same call as for sex, with the '<ETHNICITY>' prefix and token type 'ethnicity'
pt2ethnicity = pickle.load(PT_ETHNICITY_PATH)

def _add_ethnicity(examples):
    """Batched map function that adds the ethnicity token via add_to_stream."""
    return add_to_stream(examples, pt2ethnicity, last=False, prefix='<ETHNICITY>', token_type='ethnicity')

dataset = dataset.map(_add_ethnicity, batched=True, load_from_cache_file=False, num_proc=NUM_PROC)

# Final filter

In [None]:
# Final pass with both count thresholds set to None; only min_length=10 is given as a constraint.
dataset = filter_by_count(dataset, min_count=None, min_count_global=None, min_length=10, num_proc=NUM_PROC)

# Remove parents

In [None]:
# Diseases: CUIs from the linking filter that also have names in the CDB
filter_cuis = cdb.config.linking['filters']['cuis']
cuis = [token for token in filter_cuis if token in cdb.cui2names]
# Parent lookup built from the CDB's 'pt2ch' info with depth=2
ch2parents = get_parents_map(cuis, cdb.addl_info['pt2ch'], depth=2)

In [None]:
# Apply remove_parents_from_stream with the ch2parents lookup built above and '<SEP>' as separator.
dataset = dataset.map(
        lambda examples: remove_parents_from_stream(examples, ch2parents=ch2parents, separator='<SEP>'),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)

## Add position IDs

In [None]:
# Separator tokens handed to add_position_ids. Each one must be its own set element:
# two adjacent string literals without a comma are concatenated by Python into a single string.
position_separators = {'<SEP>', '<SEP-1>', '<SEP-7>', '<SEP-14>', '<SEP-30>', '<SEP-365>'}
dataset = dataset.map(
        lambda examples: add_position_ids(examples, separators=position_separators),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)

# Get token_type2tokens

In [None]:
# Collect the distinct tokens of every token type and count all stream entries.
# Saving to TOKEN_TYPES_PATH and reporting total_cnt are left to the following cell,
# which already does both; doing them here too wrote the report line twice.
token_type2tokens = defaultdict(set)
total_cnt = 0
for _dataset in get_all_splits(dataset):
    for stream in _dataset['stream']:
        for example in stream:
            token_type2tokens[example['token_type']].add(example['token'])
            total_cnt += 1
# Plain dict so the pickled object does not carry the defaultdict factory
token_type2tokens = dict(token_type2tokens)

In [None]:
# Persist the token-type lookup (reloaded when the tokenizer is built) and report the entry count.
pickle.dump(token_type2tokens, TOKEN_TYPES_PATH)
fprint("Total number of annotations: ", total_cnt)

# Cleanup stream and leave only what we need

In [None]:
# Run cleanup_stream keeping time, type and position ids; context representations are not kept.
dataset = dataset.map(
        lambda examples: cleanup_stream(examples, keep_time=True, keep_type=True, keep_position_ids=True,
                                        keep_context_representation=False),
        batched=True,
        load_from_cache_file=False,
        num_proc=NUM_PROC)

### Save

In [None]:
# Checkpoint of the cleaned, still un-encoded dataset.
dataset.save_to_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)

In [None]:
# Resume from the pre-encoding checkpoint.
dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)

In [None]:
JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH

In [None]:
# Total number of patients after initial filtering
train_len, test_len = len(dataset['train']), len(dataset['test'])
for label, value in (("Total number of pts in train: ", train_len),
                     ("Total number of pts in test: ", test_len),
                     ("Total number of pts: ", train_len + test_len)):
    fprint(label, value)

In [None]:
# Total number of annotations per type after filtering
cnt_per_type_after = {}
for split in get_all_splits(dataset):
    for stream in split['stream']:
        for cui in stream:
            type_ids = cat.cdb.cui2type_ids.get(cui, None)
            if not type_ids:
                # Entries that are unknown to the CDB or have no type ids are not counted
                continue
            t = list(type_ids)[0]
            cnt_per_type_after[t] = cnt_per_type_after.get(t, 0) + 1

In [None]:
# Report the post-filtering annotation count of every type
fprint("Total number of annotations per type: \n")
type_id2name = cat.cdb.addl_info['type_id2name']
for t, type_total in cnt_per_type_after.items():
    fprint(f"{type_id2name[t].title():30}: {type_total}")

# Make tokenizer

In [None]:
# Optional existing tokenizer whose tokens are merged into the new one; None disables the merge.
extra_tokenizer = None
#extra_tokenizer = SimpleMapTokenizer.load("./data/time/models/slam_tokenizer_annotations_stream_phase2_1d_200_ALL_TYPES.pickle")

In [None]:
# Start from the saved token-type lookup and, if an extra tokenizer is configured, merge its tokens in
token_type2tokens = pickle.load(TOKEN_TYPES_PATH)
extra_concepts = None
if extra_tokenizer is not None:
    extra_concepts = list(extra_tokenizer.tkn2id.keys())

    for token_type, extra_tokens in extra_tokenizer.token_type2tokens.items():
        if token_type not in token_type2tokens:
            # Unknown type: adopt the extra tokenizer's token collection as-is
            token_type2tokens[token_type] = extra_tokens
        else:
            token_type2tokens[token_type].update(extra_tokens)

In [None]:
# Types given to get_embeddings_for_tokens: every CDB type id plus the token types found in the dataset.
_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())
# extra_concepts is None unless an extra tokenizer was configured above.
embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types,
                                                        concepts=extra_concepts)

In [None]:
# Name of every token as returned by cdb.get_name
tkn2name = {}
for tkn in tkn2id.keys():
    tkn2name[tkn] = cdb.get_name(tkn)

# Dataset-specific tokenizer; '<PAD>' supplies the padding id
tokenizer = SimpleMapTokenizer(
    tkn2id=tkn2id,
    pad_id=tkn2id['<PAD>'],
    tkn2name=tkn2name,
    token_type2tokens=token_type2tokens,
    embeddings=embeddings,
    global_token_cnt=token_cnt,
    max_len=MAX_SEQ_LEN,
)

In [None]:
# Sanity checks: the token<->id maps, the embeddings and the name map must all have the same size.
assert len(tokenizer.tkn2id) == len(tokenizer.id2tkn)
assert len(tokenizer.embeddings) == len(tokenizer.id2tkn)
assert len(tokenizer.tkn2name) == len(tokenizer.id2tkn)
# Record the padding id and the token it maps back to.
fprint(tokenizer.pad_id, tokenizer.id2tkn[tokenizer.pad_id])

In [None]:
len(tokenizer.tkn2name)

In [None]:
# Save the dataset-specific tokenizer; it is loaded again from TOKENIZER_PATH before encoding.
tokenizer.save(TOKENIZER_PATH)

In [None]:
# Total number of different concepts after all filtering (size of the tokenizer vocabulary)
fprint("Total number of concepts after filtering: ", len(tokenizer.tkn2id))
fprint("")

In [None]:
# Total number of annotations after all filtering (sum over the per-type counts)
fprint("Total number of annotations after filtering: ", sum(cnt_per_type_after.values()))
fprint("")

# Print number of different concepts per type after filtering

In [None]:
# Number of distinct vocabulary tokens per type; tokens unknown to the CDB fall under 'Other'
cnt_per_type = {}
for cui in tkn2id:
    type_ids = cat.cdb.cui2type_ids.get(cui, ['Other'])
    if not type_ids:
        # Known CUI with an empty type collection: not counted
        continue
    t = list(type_ids)[0]
    cnt_per_type[t] = cnt_per_type.get(t, 0) + 1
fprint("Total number of <<different>> concepts per type after filtering")
type_id2name = cat.cdb.addl_info['type_id2name']
for t, n_concepts in cnt_per_type.items():
    # Types without a registered name (e.g. 'Other') are labelled with their own id
    fprint(f"{type_id2name.get(t, t).title():30}: {n_concepts}")
fprint("")

# Create global tokenizer

In [None]:
# Same call as for the dataset tokenizer, but `concepts` is every CUI in the model's linking filter.
# NOTE: this rebinds embeddings, tkn2id and id2tkn from the dataset-specific tokenizer cells.
_types = list(cdb.addl_info['type_id2name'].keys()) + list(token_type2tokens.keys())
concepts = list(cat.config.linking['filters']['cuis'])
embeddings, tkn2id, id2tkn, = get_embeddings_for_tokens(dataset, cdb, context_type='xlong', types=_types, concepts=concepts)

In [None]:
# Build the "global" tokenizer. This rebinds `tokenizer`; the dataset-specific one is
# loaded again from TOKENIZER_PATH before encoding.
tkn2name = {tkn:cdb.get_name(tkn) for tkn in tkn2id.keys()}
tokenizer = SimpleMapTokenizer(tkn2id=tkn2id, pad_id=tkn2id['<PAD>'], tkn2name=tkn2name,
                               token_type2tokens=token_type2tokens, embeddings=embeddings,
                               global_token_cnt=token_cnt, max_len=MAX_SEQ_LEN)

In [None]:
# Persist the global tokenizer only when a destination is configured. BASE_TOKENIZER_PATH
# defaults to '', and saving to an empty path would presumably fail and halt a full run.
if BASE_TOKENIZER_PATH:
    tokenizer.save(BASE_TOKENIZER_PATH)
else:
    print("BASE_TOKENIZER_PATH is empty - global tokenizer not saved")

# Convert tokens to IDs

In [None]:
# With FROM_BASE set, every later use of TOKENIZER_PATH points at the base tokenizer instead.
if FROM_BASE:
    print("USING BASE TOKENIZER")
    TOKENIZER_PATH = BASE_TOKENIZER_PATH

In [None]:
# Load the tokenizer used for encoding from whichever path was selected above.
tokenizer =  SimpleMapTokenizer.load(TOKENIZER_PATH)

In [None]:
# Encode every batch with the tokenizer; the raw 'stream' column is dropped from the result.
encoded_dataset = dataset.map(
        lambda examples: tokenizer.encode(examples),
        batched=True,
        remove_columns=['stream'],
        load_from_cache_file=False,
        num_proc=NUM_PROC)

In [None]:
# Final output of the notebook: the encoded train/test dataset.
encoded_dataset.save_to_disk(PREPARED_DATASET_SPLIT_PATH)

In [None]:
PREPARED_DATASET_SPLIT_PATH

In [None]:
TOKENIZER_PATH

# Test if all is OK

In [None]:
# Reload the encoded dataset from disk to check what was actually written.
encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)

In [None]:
# Reload the un-encoded counterpart so the same row can be compared before and after encoding.
dataset = datasets.load_from_disk(JUST_BEFORE_ENCODING_DATASET_SPLIT_PATH)

In [None]:
# Reload the saved tokenizer to convert ids back to tokens in the check below.
tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)

In [None]:
encoded_dataset

In [None]:
dataset

In [None]:
# Row of the train split inspected by the following cells (arbitrary example index).
ind = 1096

In [None]:
from datetime import datetime

In [None]:
[cdb.get_name(x) for x in dataset['train'][ind]['stream']]

In [None]:
# Print one line per encoded token of the chosen row: time, position id, token type, token
encoded_row = encoded_dataset['train'][ind]
decoded_tokens = tokenizer.convert_ids2tokens(encoded_row['input_ids'])
columns = zip(encoded_row['token_type'], encoded_row['position_ids'], encoded_row['time'], decoded_tokens)
for ty, p, t, c in columns:
    print(datetime.fromtimestamp(t), p, f"{ty:20}", c)

In [None]:
encoded_dataset['train'][ind]['patient_id']

In [None]:
# Close the run report; fprint must not be called after this point.
ds_info.close()

# Prepare for Foresight

In [None]:
# Row of the train split exported as the timeline example below.
ind = 32330

In [None]:
import json

In [None]:
[cdb.get_name(x) for x in dataset['train'][ind]['stream']]

In [None]:
# Walk the chosen stream and, past position 20, print every concept on its first occurrence.
# The row is fetched once and earlier concepts are tracked in a set, instead of
# re-reading the dataset row and scanning a slice of it on every iteration.
example_stream = dataset['train'][ind]['stream']
seen_concepts = set()
for i, c in enumerate(example_stream):
    print(i)
    if i > 20 and c not in seen_concepts:
        print(i, c, cdb.get_name(c))
    # Added after the check so membership reflects positions 0..i-1 only
    seen_concepts.add(c)

In [None]:
# One record per concept for the first 161 entries of the chosen stream
timeline_cuis = dataset['train'][ind]['stream'][:161]
out = [
    {
        'id': cui,
        'label': cdb.get_name(cui),
        'count': 1000000,
        'name': cdb.get_name(cui),
        'cui': cui,
        'saliency': 0,
        'uid': position,
    }
    for position, cui in enumerate(timeline_cuis)
]

In [None]:
# Write the timeline example; the with-block closes the file so the JSON is fully flushed to disk
with open("./data/tmp/timeline_example_1.json", 'w') as timeline_file:
    json.dump(out, timeline_file)

In [None]:
len(out)

In [None]:
out