In [None]:
import os
#os.environ["WANDB_DISABLED"] = "true"

In [None]:
import logging

# Reset the root logger: discard any handlers already attached, then emit
# WARNING-and-above records through a single fresh stream handler.
log = logging.getLogger()
del log.handlers[:]
log.setLevel(logging.WARNING)
log.addHandler(logging.StreamHandler())

In [None]:
from foresight.datasets.data_collator import CollataAndPad

from foresight.utils import pickle
from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer
from medcat.cdb import CDB
from foresight.datasets.data_collator import CollataAndPad
from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from medcat.cat import CAT
from foresight.models.lucid_transformers import LucidLM2HF
from transformers import SchedulerType

from datasets import Dataset
import math
import datasets
import numpy as np
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import shutil
import random
import pandas as pd

In [None]:
# Run configuration. DAYS and MAX_SEQ_LEN are embedded in RUN_NAME
# (as "<DAYS>d" and the raw length); MAX_SEQ_LEN + 1 also sizes the GPT-2
# context built by get_model in the hyperparameter-search section.
DAYS = 1
MAX_SEQ_LEN = 256
# Token types for this run; they are joined with "_" to form the RUN_NAME suffix.
TYPES = ['ALL_TYPES']
#TYPES = ['T-11']
#TYPES = ['T-11', 'T-18']

In [None]:
# When True, the tokenizer-loading cell swaps TOKENIZER_PATH for BASE_TOKENIZER_PATH.
FROM_BASE = False
# Defined unconditionally so that flipping FROM_BASE to True does not fail with a
# NameError when the tokenizer cell reads it.
# NOTE(review): this path was previously commented out and lives under ./data/time/
# rather than ./data/timecat/ like the other paths — confirm it is still valid.
BASE_TOKENIZER_PATH = "./data/time/models/gpt/tokenizer_annotations_stream_phase2_v1_1d_256_ALL_TYPES_v7.pickle"

In [None]:
USE_POSITION_IDS = True

In [None]:
SMALL_TEST_SIZE = 1000

In [None]:
# Dataset identifiers. RUN_NAME combines DATASET_NAME with DAYS, MAX_SEQ_LEN and
# TYPES, and is reused when building the tokenizer / dataset / model paths.
BASE_NAME = 'annotated_february_2022'
DATASET_NAME = 'annotations_stream_phase2_v1'
RUN_NAME = '{}_{}d_{}_{}'.format(DATASET_NAME, DAYS, MAX_SEQ_LEN, '_'.join(TYPES))

In [None]:
# Artefact locations. The tokenizer, prepared-dataset and model paths are derived
# from RUN_NAME; the model path additionally records USE_POSITION_IDS and FROM_BASE.
TOKENIZER_PATH = f"./data/timecat/models/gpt/tokenizer_{RUN_NAME}.pickle"
PREPARED_DATASET_SPLIT_PATH = f"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_prepared_split/"
MODEL_PATH = f"./data/timecat/models/gpt-phase3-{RUN_NAME}-Positions-{USE_POSITION_IDS}-fromBase-{FROM_BASE}-old-test/"
# Output directory handed to the Ray Tune hyperparameter search (local_dir).
RESULTS_HYPERPARAM = "./data/timecat/models/gpt/results/"
# MedCAT model pack loaded with CAT.load_model_pack.
CAT_PATH = "./data/models/modelpacks/mc_modelpack_phase2_snomed_190k_february_2022.zip"

# Prefer the GPU, but fall back to the CPU so the configuration does not
# name a CUDA device on machines that have none.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load everything and prepare train/test set

In [None]:
# Load the MedCAT model pack from CAT_PATH, overriding the MetaCAT config so its
# 'general.device' is 'cpu'.
cat = CAT.load_model_pack(CAT_PATH, meta_cat_config_dict={'general': {'device': 'cpu'}})
# Keep a handle on the pack's concept database; later cells use cdb.get_name and cdb.name2cuis.
cdb = cat.cdb

In [None]:
encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)
encoded_dataset

In [None]:
# Load the tokenizer. With FROM_BASE set, TOKENIZER_PATH is rebound to
# BASE_TOKENIZER_PATH first, so that name must be defined by the configuration
# cells before this cell runs.
if FROM_BASE:
    print("USING BASE")
    TOKENIZER_PATH = BASE_TOKENIZER_PATH
tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)

In [None]:
# Collator shared by the DataLoaders and every Trainer in this notebook.
# max_seq_len is tokenizer.max_len + 1, the same value used for n_positions/n_ctx
# in the GPT2Config cells; padding uses the tokenizer's '<PAD>' id.
collate_fn = CollataAndPad(max_seq_len=tokenizer.max_len + 1, pad_id=tokenizer.tkn2id['<PAD>'], 
                           shift_labels=False,
                           use_position_ids=USE_POSITION_IDS,
                           use_token_type_ids=False)

In [None]:
# Plain PyTorch DataLoaders over the train/test splits (fixed order, 1000 rows per
# batch, padded by collate_fn). The Trainer cells below are given the datasets
# directly rather than these loaders.
dataset_train = DataLoader(encoded_dataset['train'], batch_size=1000, shuffle=False, collate_fn=collate_fn)
dataset_test = DataLoader(encoded_dataset['test'], batch_size=1000, shuffle=False, collate_fn=collate_fn)

### Create a mini dataset for testing

In [None]:
# Build the evaluation subset used during training.
# A truthy SMALL_TEST_SIZE selects that many distinct test rows with a fixed seed;
# otherwise the whole test split is used.
if SMALL_TEST_SIZE:
    random.seed(11)
    _n_test = len(encoded_dataset['test'])
    # Sample WITHOUT replacement so no row is evaluated twice, and cap the sample
    # size at the split size so a small test split cannot raise.
    inds = random.sample(range(_n_test), k=min(SMALL_TEST_SIZE, _n_test))
    encoded_dataset_test_mini = Dataset.from_dict(encoded_dataset['test'][inds])
else:
    encoded_dataset_test_mini = encoded_dataset['test']
# Defined on both branches so later cells can rely on the loader existing.
dataset_test_mini = DataLoader(encoded_dataset_test_mini, batch_size=1000, shuffle=False, collate_fn=collate_fn)

# Create GPT2

In [None]:
# Optional: load an already-trained model from disk instead of building a new one.
# If you run this cell, skip the remaining cells of this section and the
# "Lucid GPT" section — both rebind `model` and would replace the loaded weights.
# NOTE(review): the checkpoint path is hard-coded and is not derived from MODEL_PATH.
model = GPT2LMHeadModel.from_pretrained('./data/timecat/models/gpt/gpt-phase2-annotations_stream_phase2_v1_1d_256_ALL_TYPES-Positions-False-fromBase-False-old-test/')

In [None]:
# Make a new model: a GPT-2 LM built from the config alone (no checkpoint is loaded here).
# The vocabulary size is the number of tokenizer embeddings, the context length is
# tokenizer.max_len + 1 (matching collate_fn), and '<PAD>' doubles as both the
# BOS and EOS token id.
config = GPT2Config(
    vocab_size=len(tokenizer.embeddings),
    n_positions=tokenizer.max_len+1,
    n_ctx=tokenizer.max_len+1,
    n_embd=512,
    n_layer=16,
    n_head=16,
    bos_token_id=tokenizer.tkn2id['<PAD>'],
    eos_token_id=tokenizer.tkn2id['<PAD>']
)
model = GPT2LMHeadModel(config)

In [None]:
#model.transformer.wte.load_state_dict({'weight': torch.tensor(tokenizer.embeddings, dtype=torch.float32)})
#model.transformer.wte.weight.requires_grad = True

# Lucid GPT

In [None]:
# Make a new model (Lucid variant): same GPT2Config values as the plain GPT-2
# section; the config is handed to LucidLM2HF together with addl_decoder_config.
config = GPT2Config(
    vocab_size=len(tokenizer.embeddings),
    n_positions=tokenizer.max_len+1,
    n_ctx=tokenizer.max_len+1,
    n_embd=512,
    n_layer=16,
    n_head=16,
    bos_token_id=tokenizer.tkn2id['<PAD>'],
    eos_token_id=tokenizer.tkn2id['<PAD>']
)

# Extra decoder options forwarded to LucidLM2HF; only rotary position embeddings
# are switched on ('ff_glu' is left disabled).
addl_decoder_config = {
    'rotary_pos_emb': True,
#    'ff_glu': True,
}

In [None]:
model = LucidLM2HF(config, addl_decoder_config=addl_decoder_config)

# Trainer

In [None]:
test_set_to_use = encoded_dataset_test_mini # This will be automatically the whole test set if mini is not assigned

In [None]:
# Token types passed as select_token_types when scoring predictions during training.
all_types = {
    'T-11', 'T-45', 'T-55', 'T-18', 'T-26',
    'T-40', 'T-39', 'T-49', 'T-29', 'T-34',
    'T-9', 'T-33', 'T-44', 'T-6', 'T-27',
    'T-38', 'T-35', 'T-3', 'T-58',
}

In [None]:
# Metric callback used by the training Trainer (its result supplies 'eval_precision',
# the metric_for_best_model in training_args).
# It scores top-1 predictions for the types in all_types, using the token-type and
# time columns of test_set_to_use.
# time_range = 30*24*60*60 and min_time_left = 24*60*60 — i.e. 30 days and 1 day
# if the time column is in seconds (TODO confirm the unit).
compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, 
                                     prediction_scope='time_range', 
                                     topk=1, 
                                     start=0, 
                                     return_all_metrics=False, 
                                     batch_size=1000, 
                                     select_token_types=all_types,
                                     type_data=test_set_to_use['token_type'],
                                     token_type2tokens=tokenizer.token_type2tokens,
                                     time_data=test_set_to_use['time'], 
                                     time_range=30*24*60*60,
                                     ignore_label_status=False,
                                     min_time_left=24*60*60)

In [None]:
# Training configuration for the main run.
# With per_device_train_batch_size=4 and gradient_accumulation_steps=16, each
# optimizer step accumulates 64 sequences per device.
# Evaluation and checkpointing both happen once per epoch, and the checkpoint with
# the best 'eval_precision' is reloaded when training finishes.
training_args = TrainingArguments(
    output_dir='./gpt-16-16_1day_no_base_data',          # output directory
    num_train_epochs=10,              # total number of training epochs
    per_device_train_batch_size=4,  # batch size per device during training
    per_device_eval_batch_size=4,   # batch size for evaluation
    weight_decay=1e-2,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    warmup_ratio=0.01,
    learning_rate= 3.14e-04,
    eval_accumulation_steps=1,
    gradient_accumulation_steps=16,
    do_eval=True,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    metric_for_best_model='eval_precision',
    load_best_model_at_end=True,
    lr_scheduler_type=SchedulerType.LINEAR
)

In [None]:
import wandb

In [None]:
wandb.init(project='timecat', entity='wish', name=RUN_NAME + '-gpt-16-16_1day_no_base_data')

In [None]:
# Trainer for the main training run: trains on the full train split and evaluates
# on test_set_to_use with compute_metrics.
trainer = Trainer(
    model=model,                         # the instantiated Hugging Face Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=encoded_dataset['train'],         # training dataset
    eval_dataset=test_set_to_use,             # evaluation dataset
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    tokenizer=None,
)

#### Sanity check: inspect one training example

In [None]:
from datetime import datetime

In [None]:
ind = 1117

In [None]:
# Dump training example `ind` token by token:
# timestamp, position id, token type (padded to 20 chars), token text, token id.
_row = encoded_dataset['train'][ind]
_row_tokens = tokenizer.convert_ids2tokens(_row['input_ids'])
_columns = zip(_row['token_type'], _row['position_ids'], _row['time'], _row_tokens, _row['input_ids'])
for token_type, position, timestamp, token, token_id in _columns:
    print(datetime.fromtimestamp(timestamp), position, f"{token_type:20}", token, token_id)

In [None]:
encoded_dataset['train'][ind]['patient_id']

In [None]:
MODEL_PATH

# Run training 

In [None]:
trainer.train()

In [None]:
trainer.save_model(MODEL_PATH)

# Test

In [None]:
# Token types evaluated together as the 'all_types' selection in the test sweep.
all_types = {
    'T-11', 'T-45', 'T-55', 'T-18', 'T-26',
    'T-40', 'T-39', 'T-49', 'T-29', 'T-34',
    'T-9', 'T-33', 'T-44', 'T-6', 'T-27',
    'T-38', 'T-35', 'T-3', 'T-58',
}

In [None]:
test_set_to_use = encoded_dataset['test']
test_set_to_use

In [None]:
# Prediction-only Trainer: no datasets or metric callback are attached, because
# the evaluation code calls trainer.predict on explicit chunks and scores them itself.
trainer = Trainer(
    model=model,                         # the instantiated Hugging Face Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=None,         # training dataset
    eval_dataset=None,             # evaluation dataset
    compute_metrics=None,
    data_collator=collate_fn,
    tokenizer=None,
)

In [None]:
def get_metrics(metrics_data=None, test_set_to_use=None, trainer=None, m_file=None, f_name=None):
    """Evaluate `trainer` on `test_set_to_use` in chunks and persist the accumulated metrics.

    The dataset is processed 1000 rows at a time. For every chunk, the module-level
    `compute_metrics` object (not a parameter — it must be bound before this function
    is called) has its `time_data` / `type_data` replaced with the chunk's columns,
    the trainer predicts on the chunk, and the running `metrics_data` is updated
    from the callback's 'metrics_data' entry.

    Args:
        metrics_data: Running metrics from a previous call, or None to start fresh.
        test_set_to_use: Dataset to evaluate; must support len() and slicing.
        trainer: Object exposing `predict(dataset)`.
        m_file: Open, writable text file that receives one summary row.
        f_name: Path the final metrics are pickled to; also the first field of the row.

    Returns:
        The accumulated metrics_data.

    Raises:
        ValueError: If nothing was evaluated and no `metrics_data` was supplied.

    The summary row holds seven comma-separated fields: f_name, precision
    (all/new/old) and recall (all/new/old). Previously the recall values were passed
    to a four-placeholder format string and silently dropped; a header written by
    the caller should name the recall columns too.
    """
    size = 1000
    for start in range(0, len(test_set_to_use), size):
        _dataset = Dataset.from_dict(test_set_to_use[start:start + size])
        # Point the metric callback at this chunk's columns before predicting on it.
        compute_metrics.time_data = _dataset['time']
        compute_metrics.type_data = _dataset['token_type']
        if len(_dataset):
            p = trainer.predict(_dataset)
            metrics_data = compute_metrics(p, metrics_data)['metrics_data']

    if metrics_data is None:
        raise ValueError("get_metrics: no data was evaluated (empty test set?)")

    summary = [metrics_data[kind][scope]
               for kind in ('precision', 'recall')
               for scope in ('all', 'new', 'old')]
    m_file.write(", ".join(str(value) for value in [f_name, *summary]) + "\n")
    print(f_name, *summary)
    pickle.dump(metrics_data, f_name)

    return metrics_data

In [None]:
def _build_compute_metrics(types, topk, time_range_days):
    """Create the ComputePrecisionHF callback for one (types, top-k, time range) setting."""
    return ComputePrecisionHF(tokenizer.id2tkn,
                              prediction_scope='time_range',
                              topk=topk,
                              start=0,
                              return_all_metrics=True,
                              batch_size=1000,
                              select_token_types=types,
                              type_data=test_set_to_use['token_type'],
                              token_type2tokens=tokenizer.token_type2tokens,
                              time_data=test_set_to_use['time'],
                              time_range=time_range_days*24*60*60,
                              ignore_label_status=False,
                              min_time_left=24*60*60)


# Sweep over type selections. For each one, evaluate top-1 predictions at three
# time ranges (30, 365, 1000000 days if times are in seconds — TODO confirm unit),
# then top-5 and top-10 at the 30-day range. Each setting pickles its metrics to
# ./metrics/ and appends one row to the summary file.
# The `with` block guarantees the summary file is closed even if an evaluation fails.
with open("./metrics/summary.txt", 'w', buffering=1) as m_file:
    # The header names the precision and recall values that get_metrics is handed for each row.
    m_file.write("file_name, precision all, precision new, precision old, recall all, recall new, recall old\n")

    for types in [all_types, {'T-11'}, {'T-55'}, {'T-18'}, {'T-39'}]:
        _types = list(types)[0] if len(types) == 1 else 'all_types'

        for timerange in [30, 365, 1000000]:
            # get_metrics reads the module-level `compute_metrics`, so rebind it here.
            compute_metrics = _build_compute_metrics(types, topk=1, time_range_days=timerange)
            f_name = f"./metrics/start-0_topk-1_time_range-{timerange}_types-{_types}.pickle"
            get_metrics(None, test_set_to_use, trainer, m_file, f_name)

        for topk in [5, 10]:
            compute_metrics = _build_compute_metrics(types, topk=topk, time_range_days=30)
            f_name = f"./metrics/start-0_topk-{topk}_time_range-30_types-{_types}.pickle"
            get_metrics(None, test_set_to_use, trainer, m_file, f_name)

# Test Death

In [None]:
# From here on `all_types` is rebound to the single 'death' type.
all_types = {'death'}

In [None]:
test_set_to_use

In [None]:
# Prediction-only Trainer for the death evaluation: no datasets or metric callback
# are attached; predictions are requested per chunk via trainer.predict.
trainer = Trainer(
    model=model,                         # the instantiated Hugging Face Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=None,         # training dataset
    eval_dataset=None,             # evaluation dataset
    compute_metrics=None,
    data_collator=collate_fn,
    tokenizer=None,
)

In [None]:
def get_metrics(metrics_data=None, test_set_to_use=None, trainer=None, m_file=None, f_name=None):
    """Precision-only variant of the chunked evaluation helper.

    NOTE: defining this rebinds the name `get_metrics` for the rest of the kernel
    session. This version writes and prints precision values only.

    The dataset is evaluated 1000 rows at a time. For each chunk, the module-level
    `compute_metrics` object (not a parameter — it must be bound before calling) is
    pointed at the chunk's 'time' and 'token_type' columns, the trainer predicts on
    the chunk, and `metrics_data` is updated from the callback's 'metrics_data' entry.

    Args:
        metrics_data: Running metrics from a previous call, or None to start fresh.
        test_set_to_use: Dataset to evaluate; must support len() and slicing.
        trainer: Object exposing `predict(dataset)`.
        m_file: Open, writable text file that receives one summary row
            (f_name, precision all, precision new, precision old).
        f_name: Path the final metrics are pickled to.

    Returns:
        The accumulated metrics_data (previously nothing was returned).

    Raises:
        ValueError: If nothing was evaluated and no `metrics_data` was supplied.
    """
    size = 1000
    for start in range(0, len(test_set_to_use), size):
        _dataset = Dataset.from_dict(test_set_to_use[start:start + size])
        compute_metrics.time_data = _dataset['time']
        compute_metrics.type_data = _dataset['token_type']
        if len(_dataset):
            p = trainer.predict(_dataset)
            metrics_data = compute_metrics(p, metrics_data)['metrics_data']

    if metrics_data is None:
        raise ValueError("get_metrics: no data was evaluated (empty test set?)")

    precision = metrics_data['precision']
    m_file.write("{}, {}, {}, {}\n".format(f_name, precision['all'], precision['new'], precision['old']))
    print(f_name, precision['all'], precision['new'], precision['old'])
    pickle.dump(metrics_data, f_name)

    return metrics_data

In [None]:
# Metric callback for the death evaluation: restricted to a single concept id
# instead of a set of token types, with no minimum time left.
# NOTE(review): concept_id=270 is presumably tokenizer.tkn2id['The patient has died']
# (a later cell looks that token up) — confirm, since the id depends on the tokenizer.
# time_range = 24*60*60, i.e. one day if the time column is in seconds (TODO confirm).
compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, 
                                 topk=1, # 1, 5, 10
                                 start=0, # 0, 10, 20, 50, 100
                                 return_all_metrics=True, 
                                 batch_size=1000, 
                                 type_data=test_set_to_use['token_type'],
                                 token_type2tokens=tokenizer.token_type2tokens,
                                 time_data=test_set_to_use['time'], 
                                 time_range=24*60*60, #30, 365, 1000000
                                 ignore_label_status=False,
                                 min_time_left=0,
                                 concept_id=270)

In [None]:
# One-off check on the first 1000 test rows only.
metrics_data = None
_dataset = Dataset.from_dict(test_set_to_use[0:1000])
# Point the metric callback at this chunk's columns before predicting
# (presumably so they line up with the predictions — confirm).
compute_metrics.time_data = _dataset['time']
compute_metrics.type_data = _dataset['token_type']
if len(_dataset):
    p = trainer.predict(_dataset)
    metrics_data = compute_metrics(p, metrics_data)['metrics_data']

In [None]:
metrics_data

In [None]:
tokenizer.tkn2id['The patient has died']

In [None]:
# Print the index of every row in the current chunk whose input ids contain token 270.
# The column is fetched once up front; indexing _dataset['input_ids'] inside the
# loop would re-read the whole column on every iteration.
for i, input_ids in enumerate(_dataset['input_ids']):
    if 270 in input_ids:
        print(i)

In [None]:
# Full death evaluation: walk the test set in chunks of 1000 rows, accumulating
# metrics_data across chunks (same loop as get_metrics, without writing any files).
metrics_data = None
size = 1000
for i in range(int(math.ceil(len(test_set_to_use) / size))):
    _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])
    # Re-point the metric callback at the current chunk's columns.
    compute_metrics.time_data = _dataset['time']
    compute_metrics.type_data = _dataset['token_type']
    if len(_dataset):
        p = trainer.predict(_dataset)
        metrics_data = compute_metrics(p, metrics_data)['metrics_data']

# Hyperparameter search

In [None]:
from ray.tune.schedulers import PopulationBasedTraining
from ray import tune
from ray.tune import CLIReporter 
import ray

In [None]:
# Metric callback for the hyperparameter search.
# NOTE(review): `id2tkn` and `id2type` are not defined in any visible cell (other
# cells use tokenizer.id2tkn), and this call shape (second positional argument,
# prediction_scope='age') differs from every other ComputePrecisionHF call in the
# notebook — confirm it still matches the current signature before running.
compute_metrics = ComputePrecisionHF(id2tkn, id2type, prediction_scope='age', topk=1, start=0, batch_size=2000)

In [None]:
NUM_TRIALS = 20
N_GPU_PER_TRIAL = 1
METRIC_TO_OPTIMIZE = 'eval_precision'

In [None]:
def get_model(params):
    """Build a fresh GPT-2 language model for one hyperparameter-search trial.

    Passed to Trainer as `model_init`, so it is called with the trial's
    hyperparameters (or None when no trial is active).

    Args:
        params: Mapping of hyperparameters, or None. Recognised keys:
            'n_embd' (default 300), 'n_layer' (default 1), 'n_head' (default 1)
            and 'load_weights' (falsy by default).

    Returns:
        A GPT2LMHeadModel sized to the tokenizer vocabulary with a context of
        MAX_SEQ_LEN + 1 positions.

    The vocabulary, padding id and optional embedding weights are read from the
    notebook's `tokenizer` (tokenizer.embeddings / tokenizer.tkn2id). The previous
    bare names `embeddings` and `tkn2id` are not defined anywhere in the visible
    notebook and raised NameError.
    """
    # Release cached GPU memory left over from the previous trial's model.
    torch.cuda.empty_cache()
    if params is None:
        params = {}

    pad_id = tokenizer.tkn2id['<PAD>']
    config = GPT2Config(
        vocab_size=len(tokenizer.embeddings),
        n_positions=MAX_SEQ_LEN+1,
        n_ctx=MAX_SEQ_LEN+1,
        n_embd=params.get('n_embd', 300),
        n_layer=params.get('n_layer', 1),
        n_head=params.get('n_head', 1),
        bos_token_id=pad_id,
        eos_token_id=pad_id
    )
    model = GPT2LMHeadModel(config)

    if params.get('load_weights', 0):
        # Initialise the token embedding matrix from the tokenizer's embeddings and
        # keep it trainable.
        # NOTE(review): this presumably requires the embedding width to equal n_embd — confirm.
        model.transformer.wte.load_state_dict({'weight': torch.tensor(tokenizer.embeddings, dtype=torch.float32)})
        model.transformer.wte.weight.requires_grad = True

    return model

In [None]:
# Training configuration for the hyperparameter search.
# This rebinds `training_args`, replacing the main-run settings. Evaluation and
# logging both happen every 200 steps.
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=5,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=128,   # batch size for evaluation
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=200,
    eval_steps=200,
    learning_rate= 5e-5,
    eval_accumulation_steps=1,
    do_eval=True,
    evaluation_strategy='steps',
    skip_memory_metrics=True,
)

In [None]:
# Extra attributes attached to training_args; the values mirror get_model's defaults
# (n_head=1, n_layer=1, n_embd=300, load_weights=0).
# NOTE(review): presumably required so the hyperparameter search can assign these
# non-standard keys onto TrainingArguments — confirm against the Trainer version in use.
training_args.n_head = 1
training_args.n_layer = 1
training_args.n_embd = 300
training_args.load_weights = 0

In [None]:
# Hold out 10% of the training split for the hyperparameter search.
# The seed makes the split reproducible across kernel restarts (11 matches the
# seed used for the mini test set).
tune_dataset = encoded_dataset['train'].train_test_split(test_size=0.1, seed=11)

In [None]:
tune_train_dataset = tune_dataset['train']
tune_test_dataset = tune_dataset['test']

In [None]:
# Trainer for the hyperparameter search: no fixed model is passed; get_model is
# supplied as model_init instead, so the model is constructed by the Trainer.
trainer = Trainer(
#    model=model,                         # the instantiated Hugging Face Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=tune_train_dataset,         # training dataset
    eval_dataset=tune_test_dataset,             # evaluation dataset
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    tokenizer=None,
    model_init=get_model,
)

In [None]:
# Search space handed to hyperparameter_search as hp_space.
tune_config = {
    "num_train_epochs": tune.choice([5]),
    "n_head": tune.choice([2, 4, 6]),
}
# Population Based Training scheduler: maximises METRIC_TO_OPTIMIZE, perturbing
# every training iteration; the hyperparameters listed below are the ones it mutates.
scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    metric=METRIC_TO_OPTIMIZE,
    mode="max",
    perturbation_interval=1,
    hyperparam_mutations={
        "weight_decay": tune.uniform(0.0, 0.3),
        "learning_rate": tune.uniform(1e-5, 5e-5),
        "per_device_train_batch_size": [16, 32, 64, 128],
        "n_layer": tune.choice([2, 4, 6, 8]),
#       "n_embd": tune.choice([256, 512]),
        "load_weights": tune.choice([0, 1]),
        "warmup_steps": tune.choice([20, 40, 60, 100]),
    })

In [None]:
import copy
def compute_objective(metrics):
    """Return the 'eval_precision' entry of `metrics` as the search objective.

    Works on a deep copy, so the caller's dict is left untouched.
    Raises KeyError if 'eval_precision' is missing.
    """
    return copy.deepcopy(metrics).pop('eval_precision')

In [None]:
# Run the Ray Tune search: NUM_TRIALS trials, each given 1 CPU and N_GPU_PER_TRIAL
# GPUs, maximising the value returned by compute_objective under the PBT scheduler.
# Results go to RESULTS_HYPERPARAM.
# NOTE(review): the experiment name is a hard-coded date string; reusing it may mix
# results with an earlier run in the same directory — confirm this is intended.
best_model = trainer.hyperparameter_search(
    hp_space=lambda _: tune_config,
    backend="ray",
    n_trials=NUM_TRIALS,
    direction='maximize',
    compute_objective=compute_objective,
    resources_per_trial={
        "cpu": 1,
        "gpu": N_GPU_PER_TRIAL
    },
    scheduler=scheduler,
    keep_checkpoints_num=1,
    checkpoint_score_attr=METRIC_TO_OPTIMIZE,
    stop=None,
    local_dir=RESULTS_HYPERPARAM,
    name="21_May_2021",
    log_to_file=False,
    loggers=None,# (WandbLogger, ),
    )

In [None]:
best_model

# Saliency 

In [None]:
import ecco

In [None]:
lm = ecco.LM(trainer.model, tokenizer, model_name='gpt2')

In [None]:
# Pick one test example, show its full token sequence, and build the "~~"-joined
# prompt from everything except the first and last token.
ind = 49
_example_ids = encoded_dataset['test'][ind]['input_ids']
print("~~".join(tokenizer.id2tkn[token_id] for token_id in _example_ids))
text = "~~".join(tokenizer.id2tkn[token_id] for token_id in _example_ids[1:-1])

In [None]:
output = lm.generate(text, generate=10, do_sample=True, temperature=1)

In [None]:
output.saliency(style="detailed")

# Probability prediction

In [None]:
from foresight.sight import Sight

In [None]:
_ = model.eval()

In [None]:
sight = Sight(tokenizer=tokenizer, device='cuda', model=model)

In [None]:
cdb.name2cuis['muscle~pain']

In [None]:
cdb.get_name('pain')

In [None]:
text = '<ETHNICITY>~~White~~<SEX>~~Male~~<AGE>~~23~~49727002~~386661006'.split("~~")

In [None]:
# Small with WD
# NOTE(review): presumably "small model trained with weight decay" — confirm which checkpoint this refers to.
# Ask for 40 next-concept candidates of type 'T-11' for the token sequence in `text`.
r = sight.next_concepts(text, type_ids=['T-11'], n=40, p_new=True, create_position_ids=False)
# Show the input sequence using the names cdb.get_name returns for each token.
print([cdb.get_name(x) for x in text])
# Each result is indexed as x[0], x[1]; x[0] is also passed to cdb.get_name for a readable name.
for x in r:
    print(x[0], x[1], cdb.get_name(x[0]))