{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "2bbeaebc",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"#os.environ[\"WANDB_DISABLED\"] = \"true\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c5ba9e3",
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"log = logging.getLogger()\n",
"log.handlers.clear()\n",
"log.addHandler(logging.StreamHandler())\n",
"log.setLevel(logging.WARNING)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78bc409f",
"metadata": {},
"outputs": [],
"source": [
"from foresight.datasets.data_collator import CollataAndPad\n",
"\n",
"from foresight.utils import pickle\n",
"from foresight.tokenizers.simple_map_tokenizer import SimpleMapTokenizer\n",
"from medcat.cdb import CDB\n",
"from foresight.datasets.data_collator import CollataAndPad\n",
"from foresight.metrics.next_concept_prediction import precision, metrics_data2df, ComputePrecisionHF\n",
"from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments\n",
"from medcat.cat import CAT\n",
"from foresight.models.lucid_transformers import LucidLM2HF\n",
"from transformers import SchedulerType\n",
"\n",
"from datasets import Dataset\n",
"import math\n",
"import datasets\n",
"import numpy as np\n",
"from torch.utils.data import DataLoader\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import os\n",
"import shutil\n",
"import random\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04d1872a",
"metadata": {},
"outputs": [],
"source": [
"DAYS = 1\n",
"MAX_SEQ_LEN = 256\n",
"TYPES = ['ALL_TYPES']\n",
"#TYPES = ['T-11']\n",
"#TYPES = ['T-11', 'T-18']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd25aabc",
"metadata": {},
"outputs": [],
"source": [
"FROM_BASE = False\n",
"#BASE_TOKENIZER_PATH = f\"./data/time/models/gpt/tokenizer_annotations_stream_phase2_v1_1d_256_ALL_TYPES_v7.pickle\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bd1eb81",
"metadata": {},
"outputs": [],
"source": [
"USE_POSITION_IDS = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4861809",
"metadata": {},
"outputs": [],
"source": [
"SMALL_TEST_SIZE = 1000"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e9b04ae",
"metadata": {},
"outputs": [],
"source": [
"BASE_NAME = 'annotated_february_2022'\n",
"DATASET_NAME = 'annotations_stream_phase2_v1'\n",
"RUN_NAME = f'{DATASET_NAME}_{DAYS}d_{MAX_SEQ_LEN}_{\"_\".join(TYPES)}'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d08660d5",
"metadata": {},
"outputs": [],
"source": [
"TOKENIZER_PATH = f\"./data/timecat/models/gpt/tokenizer_{RUN_NAME}.pickle\"\n",
"PREPARED_DATASET_SPLIT_PATH = f\"./data/timecat/mimic/{BASE_NAME}/{RUN_NAME}_prepared_split/\"\n",
"MODEL_PATH = f\"./data/timecat/models/gpt-phase3-{RUN_NAME}-Positions-{USE_POSITION_IDS}-fromBase-{FROM_BASE}-old-test/\"\n",
"RESULTS_HYPERPARAM = \"./data/timecat/models/gpt/results/\"\n",
"CAT_PATH = \"./data/models/modelpacks/mc_modelpack_phase2_snomed_190k_february_2022.zip\"\n",
"\n",
"DEVICE = torch.device('cuda')"
]
},
{
"cell_type": "markdown",
"id": "93bc3bfa",
"metadata": {},
"source": [
"# Load everything and prepare train/test set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53750c5d",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"cat = CAT.load_model_pack(CAT_PATH, meta_cat_config_dict={'general': {'device': 'cpu'}})\n",
"cdb = cat.cdb"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17d627fa",
"metadata": {},
"outputs": [],
"source": [
"encoded_dataset = datasets.load_from_disk(PREPARED_DATASET_SPLIT_PATH)\n",
"encoded_dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f43d612c",
"metadata": {},
"outputs": [],
"source": [
"if FROM_BASE:\n",
" print(\"USING BASE\")\n",
" TOKENIZER_PATH = BASE_TOKENIZER_PATH\n",
"tokenizer = SimpleMapTokenizer.load(TOKENIZER_PATH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5f0e619",
"metadata": {},
"outputs": [],
"source": [
"collate_fn = CollataAndPad(max_seq_len=tokenizer.max_len + 1, pad_id=tokenizer.tkn2id['<PAD>'], \n",
" shift_labels=False,\n",
" use_position_ids=USE_POSITION_IDS,\n",
" use_token_type_ids=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5580622b",
"metadata": {},
"outputs": [],
"source": [
"dataset_train = DataLoader(encoded_dataset['train'], batch_size=1000, shuffle=False, collate_fn=collate_fn)\n",
"dataset_test = DataLoader(encoded_dataset['test'], batch_size=1000, shuffle=False, collate_fn=collate_fn)"
]
},
{
"cell_type": "markdown",
"id": "622a99b9",
"metadata": {},
"source": [
"### Create a mini dataset for testing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4083f3da",
"metadata": {},
"outputs": [],
"source": [
"if SMALL_TEST_SIZE:\n",
" random.seed(11)\n",
" inds = random.choices([i for i in range(len(encoded_dataset['test']))], k=SMALL_TEST_SIZE)\n",
" encoded_dataset_test_mini = Dataset.from_dict(encoded_dataset['test'][inds])\n",
" dataset_test_mini = DataLoader(encoded_dataset_test_mini, batch_size=1000, shuffle=False, collate_fn=collate_fn)\n",
"else:\n",
" encoded_dataset_test_mini = encoded_dataset['test']"
]
},
{
"cell_type": "markdown",
"id": "1a22d32c",
"metadata": {},
"source": [
"# Create GPT2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efed3e18",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Load existing if you want, skip all other cells in this section if YES\n",
"model = GPT2LMHeadModel.from_pretrained('./data/timecat/models/gpt/gpt-phase2-annotations_stream_phase2_v1_1d_256_ALL_TYPES-Positions-False-fromBase-False-old-test/')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58cc7f79",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Make a new model\n",
"config = GPT2Config(\n",
" vocab_size=len(tokenizer.embeddings),\n",
" n_positions=tokenizer.max_len+1,\n",
" n_ctx=tokenizer.max_len+1,\n",
" n_embd=512,\n",
" n_layer=16,\n",
" n_head=16,\n",
" bos_token_id=tokenizer.tkn2id['<PAD>'],\n",
" eos_token_id=tokenizer.tkn2id['<PAD>']\n",
")\n",
"model = GPT2LMHeadModel(config)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb8edfca",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"#model.transformer.wte.load_state_dict({'weight': torch.tensor(tokenizer.embeddings, dtype=torch.float32)})\n",
"#model.transformer.wte.weight.requires_grad = True"
]
},
{
"cell_type": "markdown",
"id": "82dffb5f",
"metadata": {},
"source": [
"# Lucid GPT"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d644dc7",
"metadata": {},
"outputs": [],
"source": [
"# Make a new model\n",
"config = GPT2Config(\n",
" vocab_size=len(tokenizer.embeddings),\n",
" n_positions=tokenizer.max_len+1,\n",
" n_ctx=tokenizer.max_len+1,\n",
" n_embd=512,\n",
" n_layer=16,\n",
" n_head=16,\n",
" bos_token_id=tokenizer.tkn2id['<PAD>'],\n",
" eos_token_id=tokenizer.tkn2id['<PAD>']\n",
")\n",
"\n",
"addl_decoder_config = {\n",
" 'rotary_pos_emb': True,\n",
"# 'ff_glu': True,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "521f6f82",
"metadata": {},
"outputs": [],
"source": [
"model = LucidLM2HF(config, addl_decoder_config=addl_decoder_config)"
]
},
{
"cell_type": "markdown",
"id": "b56ee1ea",
"metadata": {},
"source": [
"# Trainer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca720485",
"metadata": {},
"outputs": [],
"source": [
"test_set_to_use = encoded_dataset_test_mini # This will be automatically the whole test set if mini is not assigned"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "876c3fcc",
"metadata": {},
"outputs": [],
"source": [
"all_types = set(['T-11', 'T-45', 'T-55', 'T-18', 'T-26', 'T-40', 'T-39', 'T-49', 'T-29', 'T-34', \n",
" 'T-9', 'T-33', 'T-44', 'T-6', 'T-27', 'T-38', 'T-35', 'T-3', 'T-58'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28e17497",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n",
" prediction_scope='time_range', \n",
" topk=1, \n",
" start=0, \n",
" return_all_metrics=False, \n",
" batch_size=1000, \n",
" select_token_types=all_types,\n",
" type_data=test_set_to_use['token_type'],\n",
" token_type2tokens=tokenizer.token_type2tokens,\n",
" time_data=test_set_to_use['time'], \n",
" time_range=30*24*60*60,\n",
" ignore_label_status=False,\n",
" min_time_left=24*60*60)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "974a735d",
"metadata": {},
"outputs": [],
"source": [
"training_args = TrainingArguments(\n",
" output_dir='./gpt-16-16_1day_no_base_data', # output directory\n",
" num_train_epochs=10, # total number of training epochs\n",
" per_device_train_batch_size=4, # batch size per device during training\n",
" per_device_eval_batch_size=4, # batch size for evaluation\n",
" weight_decay=1e-2, # strength of weight decay\n",
" logging_dir='./logs', # directory for storing logs\n",
" warmup_ratio=0.01,\n",
" learning_rate= 3.14e-04,\n",
" eval_accumulation_steps=1,\n",
" gradient_accumulation_steps=16,\n",
" do_eval=True,\n",
" evaluation_strategy='epoch',\n",
" save_strategy='epoch',\n",
" metric_for_best_model='eval_precision',\n",
" load_best_model_at_end=True,\n",
" lr_scheduler_type=SchedulerType.LINEAR\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e25889f",
"metadata": {},
"outputs": [],
"source": [
"import wandb"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c2b1847",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"wandb.init(project='timecat', entity='wish', name=RUN_NAME + '-gpt-16-16_1day_no_base_data')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7f0c996",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model=model, # the instantiated 🤗 Transformers model to be trained\n",
" args=training_args, # training arguments, defined above\n",
" train_dataset=encoded_dataset['train'], # training dataset\n",
" eval_dataset=test_set_to_use, # evaluation dataset\n",
" compute_metrics=compute_metrics,\n",
" data_collator=collate_fn,\n",
" tokenizer=None,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "1634717d",
"metadata": {},
"source": [
"#### Make sure stuff is correct"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b2c2162",
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "236c2c6e",
"metadata": {},
"outputs": [],
"source": [
"ind = 1117"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02fa347e",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"for ty, p, t, c, ind_id in zip(encoded_dataset['train'][ind]['token_type'], encoded_dataset['train'][ind]['position_ids'], encoded_dataset['train'][ind]['time'], tokenizer.convert_ids2tokens(encoded_dataset['train'][ind]['input_ids']), encoded_dataset['train'][ind]['input_ids']):\n",
" print(datetime.fromtimestamp(t), p, \"{:20}\".format(ty), c, ind_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0af3ff2",
"metadata": {},
"outputs": [],
"source": [
"encoded_dataset['train'][ind]['patient_id']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff4910d4",
"metadata": {},
"outputs": [],
"source": [
"MODEL_PATH"
]
},
{
"cell_type": "markdown",
"id": "d756cf7d",
"metadata": {},
"source": [
"# Run training "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3c7ed86",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "647efb10",
"metadata": {},
"outputs": [],
"source": [
"trainer.save_model(MODEL_PATH)"
]
},
{
"cell_type": "markdown",
"id": "967599dd",
"metadata": {},
"source": [
"# Test"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38cc2acd",
"metadata": {},
"outputs": [],
"source": [
"all_types = set(['T-11', 'T-45', 'T-55', 'T-18', 'T-26', 'T-40', 'T-39', 'T-49', 'T-29', 'T-34', \n",
" 'T-9', 'T-33', 'T-44', 'T-6', 'T-27', 'T-38', 'T-35', 'T-3', 'T-58'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a67dd7d7",
"metadata": {},
"outputs": [],
"source": [
"test_set_to_use = encoded_dataset['test']\n",
"test_set_to_use"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dadd7b0d",
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model=model, # the instantiated 🤗 Transformers model to be trained\n",
" args=training_args, # training arguments, defined above\n",
" train_dataset=None, # training dataset\n",
" eval_dataset=None, # evaluation dataset\n",
" compute_metrics=None,\n",
" data_collator=collate_fn,\n",
" tokenizer=None,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "152ec628",
"metadata": {},
"outputs": [],
"source": [
"def get_metrics(metrics_data=None, test_set_to_use=None, trainer=None, m_file=None, f_name=None):\n",
" size = 1000\n",
" for i in range(int(math.ceil(len(test_set_to_use) / size))):\n",
" _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])\n",
" compute_metrics.time_data = _dataset['time']\n",
" compute_metrics.type_data = _dataset['token_type']\n",
" if len(_dataset):\n",
" p = trainer.predict(_dataset)\n",
" metrics_data = compute_metrics(p, metrics_data)['metrics_data']\n",
" m_file.write(\"{}, {}, {}, {}\\n\".format(f_name, metrics_data['precision']['all'], \n",
" metrics_data['precision']['new'], \n",
" metrics_data['precision']['old'],\n",
" metrics_data['recall']['all'],\n",
" metrics_data['recall']['new'],\n",
" metrics_data['recall']['old']))\n",
" print(f_name,\n",
" metrics_data['precision']['all'], \n",
" metrics_data['precision']['new'], \n",
" metrics_data['precision']['old'],\n",
" metrics_data['recall']['all'],\n",
" metrics_data['recall']['new'],\n",
" metrics_data['recall']['old']) \n",
" pickle.dump(metrics_data, f_name)\n",
"\n",
" return metrics_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4aee33b",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"m_file = open(\"./metrics/summary.txt\", 'w', buffering=1)\n",
"m_file.write(\"file_name, precision all, precision new, precision old\\n\")\n",
"\n",
"for types in [all_types, {'T-11'}, {'T-55'}, {'T-18'}, {'T-39'}]:\n",
" _types = list(types)[0] if len(types) == 1 else 'all_types'\n",
" for timerange in [30, 365, 1000000]:\n",
" compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n",
" prediction_scope='time_range', \n",
" topk=1, # 1, 5, 10\n",
" start=0, # 0, 10, 20, 50, 100\n",
" return_all_metrics=True, \n",
" batch_size=1000, \n",
" select_token_types=types,\n",
" type_data=test_set_to_use['token_type'],\n",
" token_type2tokens=tokenizer.token_type2tokens,\n",
" time_data=test_set_to_use['time'], \n",
" time_range=timerange*24*60*60, #30, 365, 1000000\n",
" ignore_label_status=False,\n",
" min_time_left=24*60*60)\n",
" f_name = f\"./metrics/start-0_topk-1_time_range-{timerange}_types-{_types}.pickle\"\n",
" get_metrics(None, test_set_to_use, trainer, m_file, f_name)\n",
"\n",
" for topk in [5, 10]:\n",
" compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n",
" prediction_scope='time_range', \n",
" topk=topk, # 1, 5, 10\n",
" start=0, # 0, 10, 20, 50, 100\n",
" return_all_metrics=True, \n",
" batch_size=1000, \n",
" select_token_types=types,\n",
" type_data=test_set_to_use['token_type'],\n",
" token_type2tokens=tokenizer.token_type2tokens,\n",
" time_data=test_set_to_use['time'], \n",
" time_range=30*24*60*60, #30, 365, 1000000\n",
" ignore_label_status=False,\n",
" min_time_left=24*60*60)\n",
" f_name = f\"./metrics/start-0_topk-{topk}_time_range-30_types-{_types}.pickle\"\n",
" get_metrics(None, test_set_to_use, trainer, m_file, f_name)\n",
"m_file.close()"
]
},
{
"cell_type": "markdown",
"id": "aa75db95",
"metadata": {},
"source": [
"# Test Death"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "daad35f1",
"metadata": {},
"outputs": [],
"source": [
"all_types = set(['death'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb93c9e7",
"metadata": {},
"outputs": [],
"source": [
"test_set_to_use"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d20a61c",
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model=model, # the instantiated 🤗 Transformers model to be trained\n",
" args=training_args, # training arguments, defined above\n",
" train_dataset=None, # training dataset\n",
" eval_dataset=None, # evaluation dataset\n",
" compute_metrics=None,\n",
" data_collator=collate_fn,\n",
" tokenizer=None,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ebabb840",
"metadata": {},
"outputs": [],
"source": [
"def get_metrics(metrics_data=None, test_set_to_use=None, trainer=None, m_file=None, f_name=None):\n",
" size = 1000\n",
" for i in range(int(math.ceil(len(test_set_to_use) / size))):\n",
" _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])\n",
" compute_metrics.time_data = _dataset['time']\n",
" compute_metrics.type_data = _dataset['token_type']\n",
" if len(_dataset):\n",
" p = trainer.predict(_dataset)\n",
" metrics_data = compute_metrics(p, metrics_data)['metrics_data']\n",
" m_file.write(\"{}, {}, {}, {}\\n\".format(f_name, metrics_data['precision']['all'], \n",
" metrics_data['precision']['new'], \n",
" metrics_data['precision']['old']))\n",
" print(f_name,\n",
" metrics_data['precision']['all'], \n",
" metrics_data['precision']['new'], \n",
" metrics_data['precision']['old'])\n",
" pickle.dump(metrics_data, f_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "202e1dc1",
"metadata": {},
"outputs": [],
"source": [
"compute_metrics = ComputePrecisionHF(tokenizer.id2tkn, \n",
" topk=1, # 1, 5, 10\n",
" start=0, # 0, 10, 20, 50, 100\n",
" return_all_metrics=True, \n",
" batch_size=1000, \n",
" type_data=test_set_to_use['token_type'],\n",
" token_type2tokens=tokenizer.token_type2tokens,\n",
" time_data=test_set_to_use['time'], \n",
" time_range=24*60*60, #30, 365, 1000000\n",
" ignore_label_status=False,\n",
" min_time_left=0,\n",
" concept_id=270)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71282b28",
"metadata": {},
"outputs": [],
"source": [
"metrics_data = None\n",
"_dataset = Dataset.from_dict(test_set_to_use[0:1000])\n",
"compute_metrics.time_data = _dataset['time']\n",
"compute_metrics.type_data = _dataset['token_type']\n",
"if len(_dataset):\n",
" p = trainer.predict(_dataset)\n",
" metrics_data = compute_metrics(p, metrics_data)['metrics_data']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7b85a03",
"metadata": {},
"outputs": [],
"source": [
"metrics_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51c716a0",
"metadata": {},
"outputs": [],
"source": [
"tokenizer.tkn2id['The patient has died']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69a9a505",
"metadata": {},
"outputs": [],
"source": [
"for i in range(len(_dataset['input_ids'])):\n",
" if 270 in _dataset['input_ids'][i]:\n",
" print(i)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24bb8e62",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"metrics_data = None\n",
"size = 1000\n",
"for i in range(int(math.ceil(len(test_set_to_use) / size))):\n",
" _dataset = Dataset.from_dict(test_set_to_use[i*size:(i+1)*size])\n",
" compute_metrics.time_data = _dataset['time']\n",
" compute_metrics.type_data = _dataset['token_type']\n",
" if len(_dataset):\n",
" p = trainer.predict(_dataset)\n",
" metrics_data = compute_metrics(p, metrics_data)['metrics_data']"
]
},
{
"cell_type": "markdown",
"id": "055f6a43",
"metadata": {},
"source": [
"# Hyperparameter search"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe6a39e1",
"metadata": {},
"outputs": [],
"source": [
"from ray.tune.schedulers import PopulationBasedTraining\n",
"from ray import tune\n",
"from ray.tune import CLIReporter \n",
"import ray"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5105ae66",
"metadata": {},
"outputs": [],
"source": [
"compute_metrics = ComputePrecisionHF(id2tkn, id2type, prediction_scope='age', topk=1, start=0, batch_size=2000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea5f2956",
"metadata": {},
"outputs": [],
"source": [
"NUM_TRIALS = 20\n",
"N_GPU_PER_TRIAL = 1\n",
"METRIC_TO_OPTIMIZE = 'eval_precision'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27783bf0",
"metadata": {},
"outputs": [],
"source": [
"def get_model(params):\n",
" torch.cuda.empty_cache()\n",
" if params is None:\n",
" params = {}\n",
" \n",
" config = GPT2Config(\n",
" vocab_size=len(embeddings),\n",
" n_positions=MAX_SEQ_LEN+1,\n",
" n_ctx=MAX_SEQ_LEN+1,\n",
" n_embd=params.get('n_embd', 300),\n",
" n_layer=params.get('n_layer', 1),\n",
" n_head=params.get('n_head', 1),\n",
" bos_token_id=tkn2id['<PAD>'],\n",
" eos_token_id=tkn2id['<PAD>']\n",
" )\n",
" model = GPT2LMHeadModel(config)\n",
" \n",
" if params.get('load_weights', 0):\n",
" model.transformer.wte.load_state_dict({'weight': torch.tensor(embeddings, dtype=torch.float32)})\n",
" model.transformer.wte.weight.requires_grad = True\n",
" \n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16e4a39d",
"metadata": {},
"outputs": [],
"source": [
"training_args = TrainingArguments(\n",
" output_dir='./results', # output directory\n",
" num_train_epochs=5, # total number of training epochs\n",
" per_device_train_batch_size=16, # batch size per device during training\n",
" per_device_eval_batch_size=128, # batch size for evaluation\n",
" weight_decay=0.01, # strength of weight decay\n",
" logging_dir='./logs', # directory for storing logs\n",
" logging_steps=200,\n",
" eval_steps=200,\n",
" learning_rate= 5e-5,\n",
" eval_accumulation_steps=1,\n",
" do_eval=True,\n",
" evaluation_strategy='steps',\n",
" skip_memory_metrics=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0edc2eb1",
"metadata": {},
"outputs": [],
"source": [
"training_args.n_head = 1\n",
"training_args.n_layer = 1\n",
"training_args.n_embd = 300\n",
"training_args.load_weights = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e46e8cb6",
"metadata": {},
"outputs": [],
"source": [
"tune_dataset = encoded_dataset['train'].train_test_split(test_size=0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f111f0a",
"metadata": {},
"outputs": [],
"source": [
"tune_train_dataset = tune_dataset['train']\n",
"tune_test_dataset = tune_dataset['test']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f085ab71",
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(\n",
"# model=model, # the instantiated 🤗 Transformers model to be trained\n",
" args=training_args, # training arguments, defined above\n",
" train_dataset=tune_train_dataset, # training dataset\n",
" eval_dataset=tune_test_dataset, # evaluation dataset\n",
" compute_metrics=compute_metrics,\n",
" data_collator=collate_fn,\n",
" tokenizer=None,\n",
" model_init=get_model,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0356cdc",
"metadata": {},
"outputs": [],
"source": [
"tune_config = {\n",
" \"num_train_epochs\": tune.choice([5]),\n",
" \"n_head\": tune.choice([2, 4, 6]),\n",
"}\n",
"scheduler = PopulationBasedTraining(\n",
" time_attr=\"training_iteration\",\n",
" metric=METRIC_TO_OPTIMIZE,\n",
" mode=\"max\",\n",
" perturbation_interval=1,\n",
" hyperparam_mutations={\n",
" \"weight_decay\": tune.uniform(0.0, 0.3),\n",
" \"learning_rate\": tune.uniform(1e-5, 5e-5),\n",
" \"per_device_train_batch_size\": [16, 32, 64, 128],\n",
" \"n_layer\": tune.choice([2, 4, 6, 8]),\n",
"# \"n_embd\": tune.choice([256, 512]),\n",
" \"load_weights\": tune.choice([0, 1]),\n",
" \"warmup_steps\": tune.choice([20, 40, 60, 100]),\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63f3fd7d",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"def compute_objective(metrics):\n",
" metrics = copy.deepcopy(metrics)\n",
" eval_precision = metrics.pop('eval_precision')\n",
" \n",
" return eval_precision"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23e5bff1",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"best_model = trainer.hyperparameter_search(\n",
" hp_space=lambda _: tune_config,\n",
" backend=\"ray\",\n",
" n_trials=NUM_TRIALS,\n",
" direction='maximize',\n",
" compute_objective=compute_objective,\n",
" resources_per_trial={\n",
" \"cpu\": 1,\n",
" \"gpu\": N_GPU_PER_TRIAL\n",
" },\n",
" scheduler=scheduler,\n",
" keep_checkpoints_num=1,\n",
" checkpoint_score_attr=METRIC_TO_OPTIMIZE,\n",
" stop=None,\n",
" local_dir=RESULTS_HYPERPARAM,\n",
" name=\"21_May_2021\",\n",
" log_to_file=False,\n",
" loggers=None,# (WandbLogger, ),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "edf5a0ab",
"metadata": {},
"outputs": [],
"source": [
"best_model"
]
},
{
"cell_type": "markdown",
"id": "9db3fae5",
"metadata": {},
"source": [
"# Saliency "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "364f1896",
"metadata": {},
"outputs": [],
"source": [
"import ecco"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b237de8",
"metadata": {},
"outputs": [],
"source": [
"lm = ecco.LM(trainer.model, tokenizer, model_name='gpt2')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92ad3750",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"ind = 49\n",
"print(\"~~\".join([tokenizer.id2tkn[id] for id in encoded_dataset['test'][ind]['input_ids']]))\n",
"text = \"~~\".join([tokenizer.id2tkn[id] for id in encoded_dataset['test'][ind]['input_ids'][1:-1]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28108d90",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"output = lm.generate(text, generate=10, do_sample=True, temperature=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b1b3898",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"output.saliency(style=\"detailed\")"
]
},
{
"cell_type": "markdown",
"id": "3dcb6cf5",
"metadata": {},
"source": [
"# Probability prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09f0fbb7",
"metadata": {},
"outputs": [],
"source": [
"from foresight.sight import Sight"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4b0662d",
"metadata": {},
"outputs": [],
"source": [
"_ = model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6d741d6",
"metadata": {},
"outputs": [],
"source": [
"sight = Sight(tokenizer=tokenizer, device='cuda', model=model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1b9bfec",
"metadata": {},
"outputs": [],
"source": [
"cdb.name2cuis['muscle~pain']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f92d2766",
"metadata": {},
"outputs": [],
"source": [
"cdb.get_name('pain')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b545e844",
"metadata": {},
"outputs": [],
"source": [
"text = '<ETHNICITY>~~White~~<SEX>~~Male~~<AGE>~~23~~49727002~~386661006'.split(\"~~\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a7fc62e",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Small with WD\n",
"r = sight.next_concepts(text, type_ids=['T-11'], n=40, p_new=True, create_position_ids=False)\n",
"print([cdb.get_name(x) for x in text])\n",
"for x in r:\n",
" print(x[0], x[1], cdb.get_name(x[0]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}