606 lines (605 with data), 25.9 kB
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/varadi_kristof/llms-for-trials/src/hint\n"
]
}
],
"source": [
"# NOTE: the target is relative to the current directory, so this cell is not idempotent --\n",
"# running it a second time moves the kernel up one more level.\n",
"%cd ../"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"# %reload_ext loads the extension when it is missing and reloads it otherwise, so re-running\n",
"# this cell no longer prints the \"already loaded\" notice that %load_ext emits.\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7febc6a6dbf0>"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import warnings\n",
"import numpy\n",
"import random\n",
"\n",
"seed = 42\n",
"warnings.filterwarnings(\"ignore\")\n",
"# Seed Python's and NumPy's global RNGs as well as torch's, so that anything drawing from\n",
"# them in the main process is reproducible too (seed_worker only covers DataLoader workers).\n",
"random.seed(seed)\n",
"numpy.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"# Use the first GPU when CUDA is available, otherwise fall back to the CPU.\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"def seed_worker(worker_id):\n",
"    \"\"\"Re-seed NumPy and Python's `random` from torch's current initial seed.\n",
"\n",
"    Passed to the DataLoaders as `worker_init_fn`. `worker_id` is accepted because that\n",
"    callback signature requires it; it is not used.\n",
"    \"\"\"\n",
"    # Fold the torch seed into 32 bits before handing it to the other RNGs.\n",
"    derived_seed = torch.initial_seed() % (2 ** 32)\n",
"    for reseed in (numpy.random.seed, random.seed):\n",
"        reseed(derived_seed)\n",
"\n",
"# Dedicated generator handed to the DataLoaders further down. Generator.manual_seed returns\n",
"# the generator itself, so chaining it into the assignment stops the cell from echoing a\n",
"# memory-address repr that differs on every run.\n",
"g = torch.Generator().manual_seed(seed)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"import ast\n",
"import pandas as pd\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"from toxicity.model import MultitaskToxicityModel, load_ckp\n",
"from trial.model import TrialModel, Trainer as TrialTrainer\n",
"from trial.protocol import ProtocolEmbedding\n",
"from trial.disease_encoder import GRAM, build_icdcode2ancestor_dict"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n",
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"# Mapping passed to the GRAM encoder as `icdcode2ancestor`.\n",
"icdcode2ancestor_dict = build_icdcode2ancestor_dict()\n",
"\n",
"gram_model = GRAM(\n",
"    embedding_dim=50,\n",
"    icdcode2ancestor=icdcode2ancestor_dict,\n",
"    device=device,\n",
")\n",
"gram_model = gram_model.to(device)\n",
"\n",
"# Criteria-text encoder built on the Bio_ClinicalBERT checkpoint.\n",
"protocol_model = ProtocolEmbedding(\n",
"    hf_model=\"emilyalsentzer/Bio_ClinicalBERT\",\n",
"    device=device,\n",
")\n",
"protocol_model = protocol_model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"data_dir = \"./data\"\n",
"trial_data_dir = f\"{data_dir}/trial\"\n",
"model_dir = \"./checkpoints/toxicity\"\n",
"protocol_embedding_file = f\"{data_dir}/protocol_embeddings.pth\"\n",
"smiles_embedding_dir = \"toxicity/smiles_embedding\"\n",
"\n",
"\n",
"def load_smiles_embeddings(split):\n",
"    \"\"\"Load the saved SMILES-embedding object for one split ('train', 'valid' or 'test').\"\"\"\n",
"    return torch.load(f\"{data_dir}/{smiles_embedding_dir}/smiles_embed_{split}.pt\")\n",
"\n",
"\n",
"smiles_embed_train = load_smiles_embeddings(\"train\")\n",
"smiles_embed_valid = load_smiles_embeddings(\"valid\")\n",
"smiles_embed_test = load_smiles_embeddings(\"test\")\n",
"\n",
"# Merged lookup over all splits; on a duplicate key the later split (valid, then test) wins.\n",
"smiles_embeddings = {**smiles_embed_train, **smiles_embed_valid, **smiles_embed_test}"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"clintox_task = ['CT_TOX']\n",
"tox21_tasks = [\n",
"    'NR-AR', 'NR-Aromatase', 'NR-PPAR-gamma', 'SR-HSE',\n",
"    'NR-AR-LBD', 'NR-ER', 'SR-ARE', 'SR-MMP',\n",
"    'NR-AhR', 'NR-ER-LBD', 'SR-ATAD5', 'SR-p53',\n",
"]\n",
"\n",
"# Tox21 tasks first, ClinTox last.\n",
"all_tasks = [*tox21_tasks, *clintox_task]\n",
"\n",
"# The model's input width is read off the first training embedding.\n",
"first_smiles = next(iter(smiles_embed_train.values()))\n",
"input_shape = first_smiles.shape[0]\n",
"\n",
"model = MultitaskToxicityModel(input_shape, all_tasks)\n",
"model = model.to(device)\n",
"# load_ckp returns four values; only the first is kept here.\n",
"toxicity_model, _, _, _ = load_ckp(f\"{model_dir}/best_model_by_valid.pt\", model, None)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model_name = 'mtdnn_multiphase_small'\n",
"\n",
"# One CSV per split inside the trial data directory.\n",
"split_filenames = ('train.csv', 'valid.csv', 'test.csv')\n",
"train_file, valid_file, test_file = (\n",
"    os.path.join(trial_data_dir, filename) for filename in split_filenames\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"def explode_list(row):\n",
"    \"\"\"Turn the string form of a Python literal (such as a list) back into the object.\"\"\"\n",
"    return ast.literal_eval(row)\n",
"\n",
"def extract_smiles_embed(smiles_row: str):\n",
"    \"\"\"Return the embedding of the first SMILES listed in `smiles_row`.\n",
"\n",
"    `smiles_row` is the string form of a list of SMILES. Only the first entry is used; any\n",
"    further entries are ignored. A zero vector of length `input_shape` is returned when the\n",
"    list is empty or its first entry has no stored embedding.\n",
"    \"\"\"\n",
"    smiles_list = explode_list(smiles_row)\n",
"    # Look at the first entry only, instead of building an embedding (and a fresh zero\n",
"    # tensor) for every entry and then throwing all but the first away.\n",
"    lead_smiles = next(iter(smiles_list), None)\n",
"    if lead_smiles is None or lead_smiles not in smiles_embeddings:\n",
"        return torch.zeros(input_shape)\n",
"    return smiles_embeddings[lead_smiles]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"def extract_icd(text):\n",
"    \"\"\"Parse a stringified list of stringified ICD-code lists into a list of lists.\n",
"\n",
"    Works by string slicing only, nothing is evaluated: two characters are dropped from\n",
"    each end of `text`, the rest is split into groups, each group loses its surrounding\n",
"    brackets, and every comma-separated code loses its surrounding quote characters.\n",
"    The input '[]' therefore yields [['']], not an empty list.\n",
"    \"\"\"\n",
"    groups = text[2:-2].split('\", \"')\n",
"    return [\n",
"        [code.strip()[1:-1] for code in group[1:-1].split(',')]\n",
"        for group in groups\n",
"    ]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"def prepare_trial_df(df):\n",
"    \"\"\"Fill missing text fields of a trial dataframe in place and return the same frame.\n",
"\n",
"    Missing `criteria` become an empty string; missing `smiless` and `icdcodes` become the\n",
"    string '[]'.\n",
"    \"\"\"\n",
"    fill_values = {\"criteria\": \"\", \"smiless\": \"[]\", \"icdcodes\": \"[]\"}\n",
"    for column, placeholder in fill_values.items():\n",
"        df[column] = df[column].fillna(placeholder)\n",
"    return df"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"train_df = prepare_trial_df(pd.read_csv(train_file))\n",
"valid_df = prepare_trial_df(pd.read_csv(valid_file))\n",
"test_df = prepare_trial_df(pd.read_csv(test_file))\n",
"\n",
"# ignore_index=True gives the pooled frame a fresh 0..n-1 index; without it each CSV\n",
"# contributes its own 0-based labels and the pooled index contains duplicates.\n",
"multiphase_df = pd.concat([train_df, valid_df, test_df], ignore_index=True)\n",
"# Distinct phase values, in order of first appearance across the pooled data.\n",
"all_phase_categories = multiphase_df['phase'].unique()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"if os.path.exists(protocol_embedding_file):\n",
"    # Reuse the embeddings cached by an earlier run.\n",
"    protocol_embeddings = torch.load(protocol_embedding_file)\n",
"else:\n",
"    # Only the trial id and the criteria text are needed to build the cache, so the frame's\n",
"    # columns are batched directly. This keeps the cell independent of TrialDataset, which\n",
"    # is defined in a later cell and would not exist yet on a fresh kernel.\n",
"    embedding_batch_size = 64\n",
"    nctid_column = multiphase_df[\"nctid\"].tolist()\n",
"    criteria_column = multiphase_df[\"criteria\"].tolist()\n",
"\n",
"    protocol_embeddings = {}\n",
"    # The cached vectors are fixed inputs later on, so no autograd graph is needed here.\n",
"    with torch.no_grad():\n",
"        for start in tqdm(range(0, len(nctid_column), embedding_batch_size)):\n",
"            stop = start + embedding_batch_size\n",
"            # Each criteria text is tokenized on its own, one tokenizer call per trial.\n",
"            criteria = [protocol_model.tokenizer(text, padding=True) for text in criteria_column[start:stop]]\n",
"            # Average over dim 1 to get one vector per trial -- presumably the token axis; confirm.\n",
"            criteria_embs = protocol_model(criteria).mean(dim=1).cpu()\n",
"            for nctid, emb in zip(nctid_column[start:stop], criteria_embs):\n",
"                protocol_embeddings[nctid] = emb\n",
"\n",
"    torch.save(protocol_embeddings, protocol_embedding_file)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"class TrialDataset(Dataset):\n",
"    \"\"\"Dataset over a trial dataframe that adds a one-hot encoding of the `phase` column.\n",
"\n",
"    One indicator column per entry of `phase_categories` is appended to the frame; a\n",
"    category absent from this frame's `phase` column gets an all-zero column.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, dataframe, phase_categories):\n",
"        self.phase_categories = phase_categories\n",
"        self.phase_columns = phase_categories\n",
"        one_hot = pd.get_dummies(dataframe['phase']).reindex(columns=phase_categories, fill_value=0)\n",
"        self.dataframe = pd.concat([dataframe, one_hot], axis=1)\n",
"\n",
"    def __len__(self):\n",
"        return self.dataframe.shape[0]\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Return the row at position `idx` as a dict of raw fields plus the phase tensor.\"\"\"\n",
"        row = self.dataframe.iloc[idx]\n",
"        phase_vector = torch.tensor(row[self.phase_columns].values.astype(float))\n",
"        sample = {\n",
"            \"nctids\": row['nctid'],\n",
"            \"labels\": row['label'],\n",
"            \"smiless\": row['smiless'],\n",
"            \"criteria\": row['criteria'],\n",
"            \"icdcodes\": row['icdcodes'],\n",
"        }\n",
"        sample[\"phase\"] = phase_vector\n",
"        return sample\n",
"\n",
"\n",
"def trial_collate_fn(batch):\n",
"    \"\"\"Collate TrialDataset samples into one batch tuple.\n",
"\n",
"    Returns `(nctids, labels, smiless, icdcodes, criteria, phase)`: the trial ids as a\n",
"    list, the labels as a tensor, the stacked SMILES embeddings, the parsed ICD codes as a\n",
"    list, the stacked cached protocol embeddings (float) and the stacked phase vectors\n",
"    (float).\n",
"\n",
"    Raises:\n",
"        KeyError: if a trial in the batch has no entry in `protocol_embeddings`.\n",
"    \"\"\"\n",
"    batch_inputs = {key: [sample[key] for sample in batch] for key in batch[0]}\n",
"    nctids = batch_inputs[\"nctids\"]\n",
"\n",
"    # Fail with the offending ids up front; a plain .get() would put None into torch.stack\n",
"    # and surface as an unrelated-looking TypeError.\n",
"    uncached = [nctid for nctid in nctids if nctid not in protocol_embeddings]\n",
"    if uncached:\n",
"        raise KeyError(f\"No cached protocol embedding for trial(s): {uncached}\")\n",
"\n",
"    smiless = torch.stack([extract_smiles_embed(smiles) for smiles in batch_inputs[\"smiless\"]])\n",
"    icdcodes = [extract_icd(icd) for icd in batch_inputs[\"icdcodes\"]]\n",
"    criteria = torch.stack([protocol_embeddings[nctid] for nctid in nctids]).float()\n",
"    labels = torch.tensor(batch_inputs[\"labels\"])\n",
"    phase = torch.stack(batch_inputs[\"phase\"]).float()\n",
"\n",
"    return (nctids, labels, smiless, icdcodes, criteria, phase)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# The pooled data is re-split at random here, replacing the train/valid/test frames read\n",
"# from the CSV files: 20% is held out for test, then 25% of the remaining 80% (20% of the\n",
"# total) goes to validation. Both splits use the notebook-wide seed instead of a second\n",
"# hard-coded copy of it.\n",
"train_val_df, test_df = train_test_split(multiphase_df, test_size=0.2, random_state=seed)\n",
"train_df, valid_df = train_test_split(train_val_df, test_size=0.25, random_state=seed)\n",
"\n",
"train_dataset = TrialDataset(train_df, all_phase_categories)\n",
"valid_dataset = TrialDataset(valid_df, all_phase_categories)\n",
"test_dataset = TrialDataset(test_df, all_phase_categories)\n",
"\n",
"# Settings shared by all three loaders; only the training loader shuffles.\n",
"loader_kwargs = dict(batch_size=32, collate_fn=trial_collate_fn, worker_init_fn=seed_worker, generator=g)\n",
"train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)\n",
"valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)\n",
"test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"# Width of the phase vector, read from the first training sample.\n",
"first_sample = train_dataset[0]\n",
"phase_dim = first_sample[\"phase\"].size(-1)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"hint_model_path = f\"./checkpoints/{model_name}.ckpt\"\n",
"\n",
"# Assemble the trial model from the toxicity and disease encoders built in earlier cells.\n",
"# Note that this rebinds `model`, which previously held the toxicity network.\n",
"model = TrialModel(\n",
"    toxicity_encoder=toxicity_model,\n",
"    disease_encoder=gram_model,\n",
"    protocol_embedding_size=protocol_model.embedding_size,\n",
"    embedding_size=100,\n",
"    num_ffn_layers=2,\n",
"    num_pred_layers=3,\n",
"    dropout=0.0,\n",
"    phase_dim=phase_dim,\n",
"    name=model_name,\n",
"    device=device,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Finishing last run (ID:6ql5atha) before initializing another..."
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<style>\n",
" table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
" .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
" .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
" </style>\n",
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>train_loss</td><td>█▃▂▂▁</td></tr><tr><td>valid_loss</td><td>▁█▂▄▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>train_loss</td><td>0.54669</td></tr><tr><td>valid_loss</td><td>0.62324</td></tr></table><br/></div></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run <strong style=\"color:#cdcd00\">vocal-universe-32</strong> at: <a href='https://wandb.ai/betonitcso/trial_outcome_prediction/runs/6ql5atha' target=\"_blank\">https://wandb.ai/betonitcso/trial_outcome_prediction/runs/6ql5atha</a><br/> View project at: <a href='https://wandb.ai/betonitcso/trial_outcome_prediction' target=\"_blank\">https://wandb.ai/betonitcso/trial_outcome_prediction</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Find logs at: <code>./wandb/run-20240520_145828-6ql5atha/logs</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Successfully finished last run (ID:6ql5atha). Initializing new run:<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0a1b12cbdfa44967aab3b34cc46b45b6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011113153977526559, max=1.0…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.17.0"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in <code>/home/varadi_kristof/llms-for-trials/src/hint/wandb/run-20240520_150346-s3dwczig</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/betonitcso/trial_outcome_prediction/runs/s3dwczig' target=\"_blank\">resilient-feather-33</a></strong> to <a href='https://wandb.ai/betonitcso/trial_outcome_prediction' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at <a href='https://wandb.ai/betonitcso/trial_outcome_prediction' target=\"_blank\">https://wandb.ai/betonitcso/trial_outcome_prediction</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/betonitcso/trial_outcome_prediction/runs/s3dwczig' target=\"_blank\">https://wandb.ai/betonitcso/trial_outcome_prediction/runs/s3dwczig</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:28<00:00, 8.15it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:25<00:00, 9.04it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:21<00:00, 11.05it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:19<00:00, 11.76it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:21<00:00, 10.73it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:18<00:00, 12.59it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:22<00:00, 10.23it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:21<00:00, 10.90it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:19<00:00, 11.79it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:20<00:00, 11.34it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:19<00:00, 12.27it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:23<00:00, 9.99it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:18<00:00, 12.94it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:20<00:00, 11.58it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:23<00:00, 9.84it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:24<00:00, 9.46it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:20<00:00, 11.30it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:18<00:00, 12.72it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:19<00:00, 11.87it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████| 234/234 [00:18<00:00, 12.53it/s]\n"
]
}
],
"source": [
"# Train the trial model, save it, then evaluate it on the test loader.\n",
"trainer = TrialTrainer(model, lr=1e-3, weight_decay=0, device=device)\n",
"num_epochs = 20\n",
"# The returned metrics are kept for inspection; no visible cell reads them afterwards.\n",
"metrics = trainer.train(num_epochs, train_loader, valid_loader, test_loader)\n",
"# NOTE(review): this saves the whole model object rather than a state_dict, which presumably\n",
"# ties the checkpoint to the current class definitions -- confirm that is intended.\n",
"torch.save(model, hint_model_path)\n",
"\n",
"print(\"Test results\\n\\n\")\n",
"test_results = trainer.test(test_loader, all_phase_categories)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Bootstrap evaluation with the same test loader and phase categories as the plain test run.\n",
"print(\"Bootstrap test results\\n\")\n",
"bootstrap_results = trainer.bootstrap_test(test_loader, all_phase_categories)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: (rows, columns) of the pooled train/valid/test frame.\n",
"multiphase_df.shape"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.16"
}
},
"nbformat": 4,
"nbformat_minor": 4
}