[b5ec00]: / MLP_Model / MLP_models.ipynb

Download this file

1 lines (1 with data), 158.3 kB

{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"MLP_models.ipynb","provenance":[],"collapsed_sections":["pMOATtfckQrD","-5k5NxKfkW7p","zyvdoJ0Vmuog","WA8cHUO8pMlb","qWjp7FlTrTWg"]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":1,"metadata":{"id":"84nJFHFrK4Tp","executionInfo":{"status":"ok","timestamp":1651886496422,"user_tz":240,"elapsed":11193,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}}},"outputs":[],"source":["import tensorflow as tf\n","import tensorflow.keras.layers as layers\n","import torch\n","import numpy as np\n","import pandas as pd\n","import os\n","import pickle\n","import matplotlib.pyplot as plt\n","from sklearn import metrics"]},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount(\"/content/gdrive\")"],"metadata":{"id":"NwTz__WVK_sU","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1651886514076,"user_tz":240,"elapsed":17660,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"07a021e3-bd4a-4664-f783-7281dc89b820"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive\n"]}]},{"cell_type":"markdown","source":["Loading the datasets"],"metadata":{"id":"kMU4jE5qK6-l"}},{"cell_type":"code","source":["# project folder\n","deep_learning_dir = \"/content/gdrive/My Drive/BMI 707 Project\""],"metadata":{"id":"lifWknGULF_d","executionInfo":{"status":"ok","timestamp":1651886514078,"user_tz":240,"elapsed":15,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}}},"execution_count":3,"outputs":[]},{"cell_type":"code","source":["# get nctid's and labels\n","train = pd.read_csv(deep_learning_dir + \"/data_formatting/training_data.tsv\", sep=\"\\t\")\n","val = pd.read_csv(deep_learning_dir + \"/data_formatting/validation_data.tsv\", sep=\"\\t\")\n","test = 
pd.read_csv(deep_learning_dir + \"/data_formatting/testing_data.tsv\", sep=\"\\t\")\n","\n","train_nctids = train[\"nctid\"].to_numpy()\n","train_labels = train[\"label\"].to_numpy()\n","\n","val_nctids = val[\"nctid\"].to_numpy()\n","val_labels = val[\"label\"].to_numpy()\n","\n","test_nctids = test[\"nctid\"].to_numpy()\n","test_labels = test[\"label\"].to_numpy()"],"metadata":{"id":"7-XMLJW8LH6x","executionInfo":{"status":"ok","timestamp":1651886516727,"user_tz":240,"elapsed":2660,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}}},"execution_count":4,"outputs":[]},{"cell_type":"markdown","source":["Loading the embeddings"],"metadata":{"id":"LHlFydAxK8SM"}},{"cell_type":"code","source":["## diseases\n","with open(deep_learning_dir + \"/embeddings/nctid2diseases.pkl\", \"rb\") as handle:\n","  nctid2diseases = pickle.load(handle)\n","\n","\n","## eligibility criteria\n","with open(deep_learning_dir + \"/embeddings/nctid2incl_criteria.pkl\", \"rb\") as handle:\n","  nctid2incl_criteria = pickle.load(handle)\n","\n","with open(deep_learning_dir + \"/embeddings/nctid2excl_criteria.pkl\", \"rb\") as handle:\n","  nctid2excl_criteria = pickle.load(handle)\n","\n","\n","## drugs\n","with open(deep_learning_dir + \"/embeddings/nctid2drugs.pkl\", \"rb\") as handle:\n","  nctid2drugs = pickle.load(handle)\n","\n","\n","## targets\n","with open(deep_learning_dir + \"/embeddings/nctid2drug_targets.pkl\", \"rb\") as handle:\n","  nctid2target = pickle.load(handle)"],"metadata":{"id":"KrPTAj29LABG","executionInfo":{"status":"ok","timestamp":1651886541109,"user_tz":240,"elapsed":24390,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}}},"execution_count":5,"outputs":[]},{"cell_type":"code","source":["# convert to np array\n","for nctid in nctid2diseases:\n","  nctid2incl_criteria[nctid] = nctid2incl_criteria[nctid].numpy()\n","  nctid2excl_criteria[nctid] = nctid2excl_criteria[nctid].numpy()\n","  \n","  if nctid in 
nctid2target:\n","    nctid2target[nctid] = nctid2target[nctid].numpy()\n","  else:\n","    nctid2target[nctid] = np.zeros(1024, dtype=np.float32)"],"metadata":{"id":"RqpIFddbZ4MI","executionInfo":{"status":"ok","timestamp":1651886541113,"user_tz":240,"elapsed":19,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}}},"execution_count":6,"outputs":[]},{"cell_type":"code","source":["## meta data\n","with open(deep_learning_dir + \"/embeddings/nctid2n_diseases.pkl\", \"rb\") as handle:\n","  nctid2n_diseases = pickle.load(handle)\n","\n","\n","with open(deep_learning_dir + \"/embeddings/length_nctid2incl_criteria.pkl\", \"rb\") as handle:\n","  nctid2n_incl = pickle.load(handle)\n","\n","with open(deep_learning_dir + \"/embeddings/length_nctid2excl_criteria.pkl\", \"rb\") as handle:\n","  nctid2n_excl = pickle.load(handle)\n","\n","\n","with open(deep_learning_dir + \"/embeddings/nctid2npart_success.pkl\", \"rb\") as handle:\n","  nctid2npart_success = pickle.load(handle)"],"metadata":{"id":"oOPY5HPIj4ot","executionInfo":{"status":"ok","timestamp":1651886542289,"user_tz":240,"elapsed":1188,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["# concatenate meta info\n","nctid2meta = {}\n","\n","for nctid in nctid2diseases:\n","  nctid2meta[nctid] = np.array([nctid2n_diseases[nctid], nctid2n_incl[nctid], nctid2n_excl[nctid], \n","                                nctid2npart_success[nctid][0], nctid2npart_success[nctid][1]])"],"metadata":{"id":"NHQ96_4dxEDj","executionInfo":{"status":"ok","timestamp":1651886542291,"user_tz":240,"elapsed":14,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}}},"execution_count":8,"outputs":[]},{"cell_type":"markdown","source":["Build input data for the neural network."],"metadata":{"id":"LCV1mhpans5A"}},{"cell_type":"code","source":["def build_data(nctids, nctid_dict):\n","  \"\"\"\n","  Build NN input matrix 
for given list of nctid's and a dictionary,\n","  which maps an nctid to an embedding.\n","  \"\"\"\n","  n_rows = len(nctids)\n","  n_cols = len(list(nctid_dict.values())[0])\n","\n","  data = np.zeros((n_rows, n_cols), dtype=float)\n","\n","  for i,nctid in enumerate(nctids):\n","    \n","    data[i,:] = nctid_dict[nctid]\n","\n","  return data"],"metadata":{"id":"wvWOoXLdnSd7","executionInfo":{"status":"ok","timestamp":1651886542292,"user_tz":240,"elapsed":12,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}}},"execution_count":9,"outputs":[]},{"cell_type":"code","source":["# train inputs\n","train_data_diseases = build_data(train_nctids, nctid2diseases)\n","train_data_incl = build_data(train_nctids, nctid2incl_criteria)\n","train_data_excl = build_data(train_nctids, nctid2excl_criteria)\n","train_data_drug = build_data(train_nctids, nctid2drugs)\n","train_data_target = build_data(train_nctids, nctid2target)\n","train_data_meta = build_data(train_nctids, nctid2meta)\n","\n","# validation inputs\n","val_data_diseases = build_data(val_nctids, nctid2diseases)\n","val_data_incl = build_data(val_nctids, nctid2incl_criteria)\n","val_data_excl = build_data(val_nctids, nctid2excl_criteria)\n","val_data_drug = build_data(val_nctids, nctid2drugs)\n","val_data_target = build_data(val_nctids, nctid2target)\n","val_data_meta = build_data(val_nctids, nctid2meta)\n","\n","# test inputs\n","test_data_diseases = build_data(test_nctids, nctid2diseases)\n","test_data_incl = build_data(test_nctids, nctid2incl_criteria)\n","test_data_excl = build_data(test_nctids, nctid2excl_criteria)\n","test_data_drug = build_data(test_nctids, nctid2drugs)\n","test_data_target = build_data(test_nctids, nctid2target)\n","test_data_meta = build_data(test_nctids, nctid2meta)"],"metadata":{"id":"kaPmiRbXqilL","executionInfo":{"status":"ok","timestamp":1651886544063,"user_tz":240,"elapsed":1780,"user":{"displayName":"Benedikt 
Geiger","userId":"17925887631246406508"}}},"execution_count":10,"outputs":[]},{"cell_type":"code","source":["# example\n","train_data_target.shape"],"metadata":{"id":"ay8w9686rFwa","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1651886555520,"user_tz":240,"elapsed":179,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"2efe023d-4536-463d-d84f-00dcfd57f479"},"execution_count":13,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(3094, 1024)"]},"metadata":{},"execution_count":13}]},{"cell_type":"code","source":["# input dims of the model, i.e. the embedding dimensions\n","diseases_dim = train_data_diseases.shape[1]\n","\n","incl_criteria_dim = train_data_incl.shape[1]\n","excl_criteria_dim = train_data_excl.shape[1]\n","\n","drug_dim = train_data_drug.shape[1]\n","\n","targets_dim = train_data_target.shape[1]\n","\n","meta_dim = train_data_meta.shape[1]"],"metadata":{"id":"-l_kLfPsbe0A","executionInfo":{"status":"ok","timestamp":1651886544067,"user_tz":240,"elapsed":10,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}}},"execution_count":12,"outputs":[]},{"cell_type":"markdown","source":["# Final Model"],"metadata":{"id":"s9LU1NeKE1pV"}},{"cell_type":"markdown","source":["# Automatic model evaluation\n","\n","Includes: ROC-plot, Precision-Recall curve, accuracy, ROC-AUC, PR-AUC, F1-score"],"metadata":{"id":"_HUYggIvVSeS"}},{"cell_type":"code","source":["def mean_and_sd(l):\n","  return np.mean(l), np.sqrt(np.var(l))"],"metadata":{"id":"LEmf4c6ZlR2d"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def plot_learning_curve(history, save_path=None):\n","  \"\"\"\n","  Plot training and validation loss per epoch.\n","  \"\"\"\n","  plt.figure(figsize=(8,5))\n","  plt.title(\"Model accuracy\")\n","\n","  plt.plot(history.history[\"loss\"], label=\"Train\", color=\"black\")\n","  plt.plot(history.history[\"val_loss\"], 
label=\"Validation\", color=\"#990000\")\n","  \n","  plt.xlabel(\"epoch\")\n","  plt.ylabel(\"loss\")\n","  \n","  plt.legend(loc=\"best\")\n","\n","  if save_path: plt.savefig(save_path)\n","  \n","  plt.show()\n","\n","  \n","\n","\n","\n","def model_performance(model, input, y_true, save_path=None):\n","  \"\"\"\n","  Compute performance of 'model' applied to 'input' data with true labels 'y_true'\n","  Saves the ROC plot and precision-recall curve to 'save_path'\n","\n","  Return: dict\n","    accuracy, ROC-AUC, PR-AUC, F1-score\n","  \"\"\"\n","\n","  pred_prob = model.predict(input).squeeze()\n","  y_pred = (pred_prob > 0.5).astype(\"int32\")\n","\n","  #fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n","\n","  fpr, tpr, _ = metrics.roc_curve(y_true, pred_prob)\n","  #metrics.RocCurveDisplay(fpr=fpr, tpr=tpr).plot(ax=axes[0], color=\"black\")\n","\n","  prec, recall, _ = metrics.precision_recall_curve(y_true, pred_prob)\n","  #metrics.PrecisionRecallDisplay(precision=prec, recall=recall).plot(ax=axes[1], color=\"black\")\n","\n","  #axes[0].set(xlabel=\"1 - specificity\", ylabel=\"sensitivity\", title=\"ROC plot\")\n","  #axes[1].set(xlabel=\"recall\", ylabel=\"precision\", title=\"Precision-Recall Curve\")\n","\n","  if save_path: plt.savefig(save_path)\n","  \n","  roc_auc = metrics.auc(fpr, tpr)\n","  pr_auc = metrics.auc(recall, prec)\n","\n","  ac = metrics.accuracy_score(y_true, y_pred)\n","  f1 = metrics.precision_recall_fscore_support(y_true, y_pred, average=\"binary\")[2]\n","\n","  return {\"accuracy\": np.round(ac, 3), \"roc_auc\": np.round(roc_auc, 3), \n","          \"pr_auc\": np.round(pr_auc, 3), \"F1\": np.round(f1, 3)}"],"metadata":{"id":"4ZuUb5tX1FlO"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# number of successful test trials\n","sum(test_labels) / 
len(test_labels)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"9zyIEbrLGv2T","executionInfo":{"status":"ok","timestamp":1651874955515,"user_tz":240,"elapsed":248,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"38d2f868-09cb-49d0-ad41-77b21e6fab15"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["0.7495636998254799"]},"metadata":{},"execution_count":15}]},{"cell_type":"code","source":["min_val_loss = np.inf\n","best_dropout = 1.\n","best_lr = np.inf"],"metadata":{"id":"rtnEdJoGmjXP"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dropouts = np.linspace(0.2, 0.4, 6)\n","lrs = [0.01, 0.005, 0.001, 0.0005, 0.0001]\n","n_runs = 5"],"metadata":{"id":"Os3SK3xpm3Ct"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for do in dropouts:\n","\n","  for lr in lrs:\n","    \n","    print(f\"Current dropout: {do}, current learning rate:{lr}\")\n","    val_losses = np.zeros(n_runs, dtype=float)\n","\n","    for i in range(n_runs):\n","\n","      if i%5 == 0: print(f\"Run {i} for dropout {do} and learning rate {lr}\")\n","\n","      # inclusion and exclusion criteria\n","      inclusion_input = layers.Input(shape=(incl_criteria_dim,), name=\"inclusion\")\n","      inclusion_emb = layers.Dense(256, activation=\"relu\")(inclusion_input)\n","\n","      exclusion_input = layers.Input(shape=(excl_criteria_dim,), name=\"exclusion\")\n","      exclusion_emb = layers.Dense(256, activation=\"relu\")(exclusion_input)\n","\n","      inclusion_exclusion_raw = layers.Concatenate()([inclusion_emb, exclusion_emb])\n","      ie_dropout = layers.Dropout(rate=do)(inclusion_exclusion_raw)\n","      inclusion_exclusion_emb = layers.Dense(256, activation=\"relu\")(ie_dropout)\n","\n","      # diseases\n","      diseases_input = layers.Input(shape=(diseases_dim,), name=\"diseases\")\n","      diseases_emb = layers.Dense(128, 
activation=\"relu\")(diseases_input)\n","\n","\n","      # drug\n","      drug_input = layers.Input(shape=(drug_dim,), name=\"drug\")\n","      drug_emb = layers.Dense(64, activation=\"relu\")(drug_input)\n","\n","      # targets\n","      targets_input = layers.Input(shape=(targets_dim,), name=\"targets\")\n","      targets_emb = layers.Dense(64, activation=\"relu\")(targets_input)\n","\n","      # drug-target interaction\n","      drug_target_raw = layers.Concatenate()([drug_emb, targets_emb])\n","      dt = layers.Dropout(rate=do)(drug_target_raw)\n","      drug_target_emb = layers.Dense(64, activation=\"relu\")(dt)\n","\n","\n","      all_emb = layers.Concatenate()([inclusion_exclusion_emb, diseases_emb, drug_target_emb])\n","      ae = layers.Dropout(rate=do)(all_emb)\n","      trial_embedding1 = layers.Dense(128, activation=\"relu\")(ae)\n","\n","      meta_input = layers.Input(shape=(meta_dim,), name=\"meta\")\n","      emb_and_meta = layers.Concatenate()([trial_embedding1, meta_input])\n","\n","      trial_embedding2 = layers.Dense(64, activation=\"relu\")(emb_and_meta)\n","      trial_embedding3 = layers.Dense(32, activation=\"relu\")(trial_embedding2)\n","\n","      o = layers.Dense(1, activation=\"sigmoid\", name=\"trial_success\")(trial_embedding3)\n","\n","      model = tf.keras.Model(inputs=[inclusion_input, exclusion_input, diseases_input, drug_input, targets_input, meta_input], outputs=[o])\n","      # model.summary()\n","\n","\n","      # compile and train\n","\n","      model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), loss=\"binary_crossentropy\", metrics=[\"accuracy\", \"AUC\"])\n","      callback = tf.keras.callbacks.EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n","\n","      history = model.fit(\n","        x={\"inclusion\": train_data_incl, \"exclusion\": train_data_excl, \n","          \"diseases\": train_data_diseases, \"drug\": train_data_drug,\n","          \"targets\": train_data_target, 
\"meta\": train_data_meta},\n","        y={\"trial_success\": train_labels},\n","        validation_data=(\n","            {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl,\n","            \"diseases\": val_data_diseases, \"drug\": val_data_drug,\n","            \"targets\": val_data_target, \"meta\": val_data_meta},\n","            {\"trial_success\": val_labels}\n","        ), \n","        epochs=10,\n","        batch_size=128,\n","        callbacks=[callback]\n","      )\n","\n","      val_losses[i] = min(history.history[\"val_loss\"])\n","    \n","    if np.mean(val_losses) < min_val_loss:\n","        \n","        best_dropout = do\n","        best_lr = lr\n","        min_val_loss = np.mean(val_losses)\n","        print(f\"Current best dropout: {best_dropout}, current best learning rate:{best_lr}, current best average loss: {np.mean(val_losses)}\")"],"metadata":{"id":"1XUeqe8AE3aL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(best_dropout, best_lr)\n","\n","# 0.32, 0.01"],"metadata":{"id":"Ls0cC6WzzJgO"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# RUN BEST MODEL AFTER GRID SEARCH 30 TIMES\n","n_runs = 30"],"metadata":{"id":"5-aBXz_fiL-3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["val_F1 = np.zeros(n_runs, dtype=float)\n","val_acc = np.zeros(n_runs, dtype=float)\n","val_pr_auc = np.zeros(n_runs, dtype=float)\n","val_roc_auc = np.zeros(n_runs, dtype=float)\n","\n","test_F1 = np.zeros(n_runs, dtype=float)\n","test_acc = np.zeros(n_runs, dtype=float)\n","test_pr_auc = np.zeros(n_runs, dtype=float)\n","test_roc_auc = np.zeros(n_runs, dtype=float)"],"metadata":{"id":"RVuo6IPlh4HP"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for i in range(n_runs):\n","\n","  # inclusion and exclusion criteria\n","  inclusion_input = layers.Input(shape=(incl_criteria_dim,), name=\"inclusion\")\n","  inclusion_emb = layers.Dense(256, activation=\"relu\", 
name=\"lower_dim_inclusion\")(inclusion_input)\n","\n","  exclusion_input = layers.Input(shape=(excl_criteria_dim,), name=\"exclusion\")\n","  exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_exclusion\")(exclusion_input)\n","\n","  inclusion_exclusion_raw = layers.Concatenate(name=\"criteria\")([inclusion_emb, exclusion_emb])\n","  ie_dropout = layers.Dropout(rate=0.32)(inclusion_exclusion_raw)\n","  inclusion_exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_criteria\")(ie_dropout)\n","\n","  # diseases\n","  diseases_input = layers.Input(shape=(diseases_dim,), name=\"diseases\")\n","  diseases_emb = layers.Dense(128, activation=\"relu\", name=\"lower_dim_diseases\")(diseases_input)\n","\n","\n","  # drug\n","  drug_input = layers.Input(shape=(drug_dim,), name=\"drug\")\n","  drug_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_drug\")(drug_input)\n","\n","  # targets\n","  targets_input = layers.Input(shape=(targets_dim,), name=\"targets\")\n","  targets_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_targets\")(targets_input)\n","\n","  # drug-target interaction\n","  drug_target_raw = layers.Concatenate(name=\"drug-targets\")([drug_emb, targets_emb])\n","  dt = layers.Dropout(rate=0.32)(drug_target_raw)\n","  drug_target_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_drug-targets\")(dt)\n","\n","\n","  all_emb = layers.Concatenate(name=\"all_embeddings\")([inclusion_exclusion_emb, diseases_emb, drug_target_emb])\n","  ae = layers.Dropout(rate=0.32)(all_emb)\n","  trial_embedding1 = layers.Dense(128, activation=\"relu\", name=\"trial_embedding_1\")(ae)\n","\n","  meta_input = layers.Input(shape=(meta_dim,), name=\"meta\")\n","  emb_and_meta = layers.Concatenate()([trial_embedding1, meta_input])\n","\n","  trial_embedding2 = layers.Dense(64, activation=\"relu\", name=\"trial_embedding_2\")(emb_and_meta)\n","  trial_embedding3 = layers.Dense(32, activation=\"relu\", 
name=\"trial_embedding_3\")(trial_embedding2)\n","\n","  o = layers.Dense(1, activation=\"sigmoid\", name=\"trial_success\")(trial_embedding3)\n","\n","  model = tf.keras.Model(inputs=[inclusion_input, exclusion_input, diseases_input, drug_input, targets_input, meta_input], outputs=[o])\n","  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=\"binary_crossentropy\", metrics=[\"accuracy\", \"AUC\"])\n","  callback = tf.keras.callbacks.EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n","\n","  history = model.fit(\n","    x={\"inclusion\": train_data_incl, \"exclusion\": train_data_excl, \n","      \"diseases\": train_data_diseases, \"drug\": train_data_drug,\n","      \"targets\": train_data_target, \"meta\": train_data_meta},\n","    y={\"trial_success\": train_labels},\n","    validation_data=(\n","        {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl,\n","        \"diseases\": val_data_diseases, \"drug\": val_data_drug,\n","        \"targets\": val_data_target, \"meta\": val_data_meta},\n","        {\"trial_success\": val_labels}\n","    ), \n","    epochs=10,\n","    batch_size=128,\n","    callbacks=[callback]\n","  )\n","\n","\n","  val_metrics = model_performance(\n","    model,\n","    {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl, \n","    \"diseases\": val_data_diseases, \"drug\": val_data_drug,\n","    \"targets\": val_data_target, \"meta\": val_data_meta},\n","    val_labels,\n","    save_path=None\n","  )\n","\n","  test_metrics = model_performance(\n","    model,\n","    {\"inclusion\": test_data_incl, \"exclusion\": test_data_excl, \n","    \"diseases\": test_data_diseases, \"drug\": test_data_drug,\n","    \"targets\": test_data_target, \"meta\": test_data_meta},\n","    test_labels,\n","    save_path=None\n","  )\n","\n","  val_F1[i] = val_metrics[\"F1\"]\n","  val_acc[i] = val_metrics[\"accuracy\"]\n","  val_pr_auc[i] = val_metrics[\"pr_auc\"]\n","  val_roc_auc[i] = 
val_metrics[\"roc_auc\"]\n","\n","  test_F1[i] = test_metrics[\"F1\"]\n","  test_acc[i] = test_metrics[\"accuracy\"]\n","  test_pr_auc[i] = test_metrics[\"pr_auc\"]\n","  test_roc_auc[i] = test_metrics[\"roc_auc\"]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"Abp5DL43f7ic","executionInfo":{"status":"ok","timestamp":1651518004414,"user_tz":240,"elapsed":88484,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"856e52cb-9425-40e7-d0a0-20a17b7ac8ed"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/10\n","25/25 [==============================] - 11s 327ms/step - loss: 5.1901 - accuracy: 0.5837 - auc: 0.4969 - val_loss: 0.6553 - val_accuracy: 0.6831 - val_auc: 0.5050\n","Epoch 2/10\n","25/25 [==============================] - 7s 287ms/step - loss: 0.6480 - accuracy: 0.6658 - auc: 0.5857 - val_loss: 0.6123 - val_accuracy: 0.6802 - val_auc: 0.6481\n","Epoch 3/10\n","25/25 [==============================] - 7s 284ms/step - loss: 0.6103 - accuracy: 0.6813 - auc: 0.6581 - val_loss: 0.6064 - val_accuracy: 0.6831 - val_auc: 0.6698\n","Epoch 4/10\n","25/25 [==============================] - 5s 193ms/step - loss: 0.5893 - accuracy: 0.6894 - auc: 0.6982 - val_loss: 0.5899 - val_accuracy: 0.6744 - val_auc: 0.6827\n","Epoch 5/10\n","25/25 [==============================] - 5s 199ms/step - loss: 0.5504 - accuracy: 0.7143 - auc: 0.7569 - val_loss: 0.5811 - val_accuracy: 0.7122 - val_auc: 0.7064\n","Epoch 6/10\n","25/25 [==============================] - 5s 194ms/step - loss: 0.5093 - accuracy: 0.7524 - auc: 0.8060 - val_loss: 0.6642 - val_accuracy: 0.6831 - val_auc: 0.6762\n","Epoch 7/10\n","25/25 [==============================] - 5s 198ms/step - loss: 0.4902 - accuracy: 0.7592 - auc: 0.8183 - val_loss: 0.6573 - val_accuracy: 0.6948 - val_auc: 0.6892\n","Epoch 8/10\n","25/25 [==============================] - 5s 197ms/step - loss: 0.4225 - accuracy: 0.8009 - auc: 0.8712 - 
val_loss: 0.6614 - val_accuracy: 0.6221 - val_auc: 0.6569\n","Epoch 9/10\n","25/25 [==============================] - 5s 191ms/step - loss: 0.3865 - accuracy: 0.8332 - auc: 0.8969 - val_loss: 0.7221 - val_accuracy: 0.7006 - val_auc: 0.6865\n","Epoch 10/10\n","25/25 [==============================] - 5s 193ms/step - loss: 0.3507 - accuracy: 0.8516 - auc: 0.9149 - val_loss: 0.7224 - val_accuracy: 0.6802 - val_auc: 0.7006\n"]},{"output_type":"display_data","data":{"text/plain":["<Figure size 720x360 with 2 Axes>"],"image/png":"\n"},"metadata":{"needs_background":"light"}},{"output_type":"display_data","data":{"text/plain":["<Figure size 720x360 with 2 Axes>"],"image/png":"\n"},"metadata":{"needs_background":"light"}}]},{"cell_type":"code","source":["val_F1_mean, val_F1_sd = mean_and_sd(val_F1)\n","val_acc_mean, val_acc_sd = mean_and_sd(val_acc)\n","val_pr_auc_mean, val_pr_auc_sd = mean_and_sd(val_pr_auc)\n","val_roc_auc_mean, val_roc_auc_sd = mean_and_sd(val_roc_auc)\n","\n","test_F1_mean, test_F1_sd = mean_and_sd(test_F1)\n","test_acc_mean, test_acc_sd = mean_and_sd(test_acc)\n","test_pr_auc_mean, test_pr_auc_sd = mean_and_sd(test_pr_auc)\n","test_roc_auc_mean, test_roc_auc_sd = mean_and_sd(test_roc_auc)"],"metadata":{"id":"OcYnEyGKnn5n"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(f\"validation F1: {round(val_F1_mean, 3)} +/- {round(val_F1_sd, 3)}\")\n","print(f\"test F1: {round(test_F1_mean, 3)} +/- {round(test_F1_sd, 3)}\")\n","\n","print(f\"validation accuracy: {round(val_acc_mean, 3)} +/- {round(val_acc_sd, 3)}\")\n","print(f\"test accuracy: {round(test_acc_mean, 3)} +/- {round(test_acc_sd, 3)}\")\n","\n","print(f\"validation PR-AUC: {round(val_pr_auc_mean, 3)} +/- {round(val_pr_auc_sd, 3)}\")\n","print(f\"test PR-AUC: {round(test_pr_auc_mean, 3)} +/- {round(test_pr_auc_sd, 3)}\")\n","\n","print(f\"validation ROC-AUC: {round(val_roc_auc_mean, 3)} +/- {round(val_roc_auc_sd, 3)}\")\n","print(f\"test ROC-AUC: {round(test_roc_auc_mean, 
3)} +/- {round(test_roc_auc_sd, 3)}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"-liHd50Ypx4b","executionInfo":{"status":"ok","timestamp":1651514835741,"user_tz":240,"elapsed":25,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"b4c679af-24eb-49c1-8846-2f83d9dad6e9"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["validation F1: 0.787 +/- 0.016\n","test F1: 0.831 +/- 0.027\n","validation accuracy: 0.679 +/- 0.012\n","test accuracy: 0.728 +/- 0.028\n","validation PR-AUC: 0.799 +/- 0.013\n","test PR-AUC: 0.841 +/- 0.009\n","validation ROC-AUC: 0.665 +/- 0.026\n","test ROC-AUC: 0.65 +/- 0.032\n"]}]},{"cell_type":"code","source":["model.save(deep_learning_dir + \"/MLP_models/model_weights_plot\")\n","\n","# model = tf.keras.models.load_model(deep_learning_dir + \"clinical_trial_model_weights\")"],"metadata":{"id":"t5boyGKM0sg7","executionInfo":{"status":"ok","timestamp":1651518148927,"user_tz":240,"elapsed":6838,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"c7ee6b89-47d8-4e0a-a6c2-af162ae10323"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:Assets written to: /content/gdrive/My Drive/BMI 707 Project/MLP_models/model_weights_plot/assets\n"]}]},{"cell_type":"markdown","source":["Model structure"],"metadata":{"id":"Wg5dCkNvK-Bu"}},{"cell_type":"code","source":["# print model structure\n","tf.keras.utils.plot_model(model, \"clinical_trial_model.pdf\", show_shapes=True)"],"metadata":{"id":"Ei7D7SQmlq3E"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!mv clinical_trial_model.pdf \"/content/gdrive/My Drive/BMI 707 Project/MLP_models/\""],"metadata":{"id":"dh87OEEm-QZI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["plot_learning_curve(history, save_path=\"/content/gdrive/My Drive/BMI 707 
Project/MLP_models/loss_history.pdf\")"],"metadata":{"id":"eAUGpVh2KLc5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model_metrics = model_performance(\n","  model,\n","  {\"inclusion\": test_data_incl, \"exclusion\": test_data_excl, \n","   \"diseases\": test_data_diseases, \"drug\": test_data_drug,\n","   \"targets\": test_data_target, \"meta\": test_data_meta},\n","  test_labels,\n","  save_path=deep_learning_dir+\"/MLP_models/model_performance.pdf\"\n",")"],"metadata":{"id":"_w6DnF131LWv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model_metrics"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"X1fZY5urQunF","executionInfo":{"status":"ok","timestamp":1651518476978,"user_tz":240,"elapsed":143,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"c24e7bf0-3599-4532-f895-b8403f57b2dd"},"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'F1': 0.83, 'accuracy': 0.727, 'pr_auc': 0.844, 'roc_auc': 0.67}"]},"metadata":{},"execution_count":67}]},{"cell_type":"markdown","source":["# Leave one out models"],"metadata":{"id":"GCVkz9nCiezw"}},{"cell_type":"markdown","source":["### Without eligibility criteria"],"metadata":{"id":"pMOATtfckQrD"}},{"cell_type":"code","source":["n_runs = 30"],"metadata":{"id":"41-pz1IXigzs"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["val_F1 = np.zeros(n_runs, dtype=float)\n","val_acc = np.zeros(n_runs, dtype=float)\n","val_pr_auc = np.zeros(n_runs, dtype=float)\n","val_roc_auc = np.zeros(n_runs, dtype=float)\n","\n","test_F1 = np.zeros(n_runs, dtype=float)\n","test_acc = np.zeros(n_runs, dtype=float)\n","test_pr_auc = np.zeros(n_runs, dtype=float)\n","test_roc_auc = np.zeros(n_runs, dtype=float)"],"metadata":{"id":"BPpr81y9ijnC"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for i in range(n_runs):\n","\n","  # diseases\n","  diseases_input = 
layers.Input(shape=(diseases_dim,), name=\"diseases\")\n","  diseases_emb = layers.Dense(128, activation=\"relu\", name=\"lower_dim_diseases\")(diseases_input)\n","\n","\n","  # drug\n","  drug_input = layers.Input(shape=(drug_dim,), name=\"drug\")\n","  drug_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_drug\")(drug_input)\n","\n","  # targets\n","  targets_input = layers.Input(shape=(targets_dim,), name=\"targets\")\n","  targets_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_targets\")(targets_input)\n","\n","  # drug-target interaction\n","  drug_target_raw = layers.Concatenate(name=\"drug-targets\")([drug_emb, targets_emb])\n","  dt = layers.Dropout(rate=0.32)(drug_target_raw)\n","  drug_target_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_drug-targets\")(dt)\n","\n","\n","  all_emb = layers.Concatenate(name=\"all_embeddings\")([diseases_emb, drug_target_emb])\n","  ae = layers.Dropout(rate=0.32)(all_emb)\n","  trial_embedding1 = layers.Dense(128, activation=\"relu\", name=\"trial_embedding_1\")(ae)\n","\n","  meta_input = layers.Input(shape=(meta_dim,), name=\"meta\")\n","  emb_and_meta = layers.Concatenate()([trial_embedding1, meta_input])\n","\n","  trial_embedding2 = layers.Dense(64, activation=\"relu\", name=\"trial_embedding_2\")(emb_and_meta)\n","  trial_embedding3 = layers.Dense(32, activation=\"relu\", name=\"trial_embedding_3\")(trial_embedding2)\n","\n","  o = layers.Dense(1, activation=\"sigmoid\", name=\"trial_success\")(trial_embedding3)\n","\n","  model = tf.keras.Model(inputs=[diseases_input, drug_input, targets_input, meta_input], outputs=[o])\n","  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=\"binary_crossentropy\", metrics=[\"accuracy\", \"AUC\"])\n","\n","  history = model.fit(\n","    x={\"diseases\": train_data_diseases, \"drug\": train_data_drug,\n","      \"targets\": train_data_target, \"meta\": train_data_meta},\n","    y={\"trial_success\": train_labels},\n"," 
   validation_data=(\n","        {\"diseases\": val_data_diseases, \"drug\": val_data_drug,\n","        \"targets\": val_data_target, \"meta\": val_data_meta},\n","        {\"trial_success\": val_labels}\n","    ), \n","    epochs=5,\n","    batch_size=128\n","  )\n","\n","\n","  val_metrics = model_performance(\n","    model,\n","    {\"diseases\": val_data_diseases, \"drug\": val_data_drug,\n","    \"targets\": val_data_target, \"meta\": val_data_meta},\n","    val_labels,\n","    save_path=None\n","  )\n","\n","  test_metrics = model_performance(\n","    model,\n","    {\"diseases\": test_data_diseases, \"drug\": test_data_drug,\n","    \"targets\": test_data_target, \"meta\": test_data_meta},\n","    test_labels,\n","    save_path=None\n","  )\n","\n","  val_F1[i] = val_metrics[\"F1\"]\n","  val_acc[i] = val_metrics[\"accuracy\"]\n","  val_pr_auc[i] = val_metrics[\"pr_auc\"]\n","  val_roc_auc[i] = val_metrics[\"roc_auc\"]\n","\n","  test_F1[i] = test_metrics[\"F1\"]\n","  test_acc[i] = test_metrics[\"accuracy\"]\n","  test_pr_auc[i] = test_metrics[\"pr_auc\"]\n","  test_roc_auc[i] = test_metrics[\"roc_auc\"]"],"metadata":{"id":"ND3BoknyisOR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["val_F1_mean, val_F1_sd = mean_and_sd(val_F1)\n","val_acc_mean, val_acc_sd = mean_and_sd(val_acc)\n","val_pr_auc_mean, val_pr_auc_sd = mean_and_sd(val_pr_auc)\n","val_roc_auc_mean, val_roc_auc_sd = mean_and_sd(val_roc_auc)\n","\n","test_F1_mean, test_F1_sd = mean_and_sd(test_F1)\n","test_acc_mean, test_acc_sd = mean_and_sd(test_acc)\n","test_pr_auc_mean, test_pr_auc_sd = mean_and_sd(test_pr_auc)\n","test_roc_auc_mean, test_roc_auc_sd = mean_and_sd(test_roc_auc)"],"metadata":{"id":"05s6938Oi7LY"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(f\"validation F1: {round(val_F1_mean, 3)} +/- {round(val_F1_sd, 3)}\")\n","print(f\"test F1: {round(test_F1_mean, 3)} +/- {round(test_F1_sd, 3)}\")\n","\n","print(f\"validation accuracy: 
{round(val_acc_mean, 3)} +/- {round(val_acc_sd, 3)}\")\n","print(f\"test accuracy: {round(test_acc_mean, 3)} +/- {round(test_acc_sd, 3)}\")\n","\n","print(f\"validation PR-AUC: {round(val_pr_auc_mean, 3)} +/- {round(val_pr_auc_sd, 3)}\")\n","print(f\"test PR-AUC: {round(test_pr_auc_mean, 3)} +/- {round(test_pr_auc_sd, 3)}\")\n","\n","print(f\"validation ROC-AUC: {round(val_roc_auc_mean, 3)} +/- {round(val_roc_auc_sd, 3)}\")\n","print(f\"test ROC-AUC: {round(test_roc_auc_mean, 3)} +/- {round(test_roc_auc_sd, 3)}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"nsTA6lp8i8R5","executionInfo":{"status":"ok","timestamp":1651525179337,"user_tz":240,"elapsed":12,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"aca0e6dc-3f33-480f-8828-2f518587b3ef"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["validation F1: 0.781 +/- 0.012\n","test F1: 0.817 +/- 0.02\n","validation accuracy: 0.667 +/- 0.012\n","test accuracy: 0.71 +/- 0.021\n","validation PR-AUC: 0.795 +/- 0.013\n","test PR-AUC: 0.83 +/- 0.011\n","validation ROC-AUC: 0.659 +/- 0.02\n","test ROC-AUC: 0.631 +/- 0.022\n"]}]},{"cell_type":"markdown","source":["### Without target data"],"metadata":{"id":"-5k5NxKfkW7p"}},{"cell_type":"code","source":["n_runs = 30"],"metadata":{"id":"TxIzT1AUkM0m"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["val_F1 = np.zeros(n_runs, dtype=float)\n","val_acc = np.zeros(n_runs, dtype=float)\n","val_pr_auc = np.zeros(n_runs, dtype=float)\n","val_roc_auc = np.zeros(n_runs, dtype=float)\n","\n","test_F1 = np.zeros(n_runs, dtype=float)\n","test_acc = np.zeros(n_runs, dtype=float)\n","test_pr_auc = np.zeros(n_runs, dtype=float)\n","test_roc_auc = np.zeros(n_runs, dtype=float)"],"metadata":{"id":"j2HytVzakfMk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Ablation without the targets input: build and train n_runs freshly initialised models and record validation/test metrics per run.\n","for i in range(n_runs):\n","\n","  # Reset Keras global state so it does not accumulate across the n_runs models built in this loop.\n","  tf.keras.backend.clear_session()\n","\n","  # inclusion and exclusion criteria\n","  inclusion_input = 
layers.Input(shape=(incl_criteria_dim,), name=\"inclusion\")\n","  inclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_inclusion\")(inclusion_input)\n","\n","  exclusion_input = layers.Input(shape=(excl_criteria_dim,), name=\"exclusion\")\n","  exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_exclusion\")(exclusion_input)\n","\n","  inclusion_exclusion_raw = layers.Concatenate(name=\"criteria\")([inclusion_emb, exclusion_emb])\n","  ie_dropout = layers.Dropout(rate=0.32)(inclusion_exclusion_raw)\n","  inclusion_exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_criteria\")(ie_dropout)\n","\n","  # diseases\n","  diseases_input = layers.Input(shape=(diseases_dim,), name=\"diseases\")\n","  diseases_emb = layers.Dense(128, activation=\"relu\", name=\"lower_dim_diseases\")(diseases_input)\n","\n","\n","  # drug\n","  drug_input = layers.Input(shape=(drug_dim,), name=\"drug\")\n","  drug_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_drug\")(drug_input)\n","\n","\n","  all_emb = layers.Concatenate(name=\"all_embeddings\")([inclusion_exclusion_emb, diseases_emb, drug_emb])\n","  ae = layers.Dropout(rate=0.32)(all_emb)\n","  trial_embedding1 = layers.Dense(128, activation=\"relu\", name=\"trial_embedding_1\")(ae)\n","\n","  meta_input = layers.Input(shape=(meta_dim,), name=\"meta\")\n","  emb_and_meta = layers.Concatenate()([trial_embedding1, meta_input])\n","\n","  trial_embedding2 = layers.Dense(64, activation=\"relu\", name=\"trial_embedding_2\")(emb_and_meta)\n","  trial_embedding3 = layers.Dense(32, activation=\"relu\", name=\"trial_embedding_3\")(trial_embedding2)\n","\n","  o = layers.Dense(1, activation=\"sigmoid\", name=\"trial_success\")(trial_embedding3)\n","\n","  model = tf.keras.Model(inputs=[inclusion_input, exclusion_input, diseases_input, drug_input, meta_input], outputs=[o])\n","  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), 
loss=\"binary_crossentropy\", metrics=[\"accuracy\", \"AUC\"])\n","\n","  history = model.fit(\n","    x={\"inclusion\": train_data_incl, \"exclusion\": train_data_excl, \n","      \"diseases\": train_data_diseases, \"drug\": train_data_drug,\n","      \"meta\": train_data_meta},\n","    y={\"trial_success\": train_labels},\n","    validation_data=(\n","        {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl,\n","        \"diseases\": val_data_diseases, \"drug\": val_data_drug,\n","        \"meta\": val_data_meta},\n","        {\"trial_success\": val_labels}\n","    ), \n","    epochs=5,\n","    batch_size=128\n","  )\n","\n","\n","  val_metrics = model_performance(\n","    model,\n","    {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl, \n","    \"diseases\": val_data_diseases, \"drug\": val_data_drug,\n","    \"meta\": val_data_meta},\n","    val_labels,\n","    save_path=None\n","  )\n","\n","  test_metrics = model_performance(\n","    model,\n","    {\"inclusion\": test_data_incl, \"exclusion\": test_data_excl, \n","    \"diseases\": test_data_diseases, \"drug\": test_data_drug, \n","    \"meta\": test_data_meta},\n","    test_labels,\n","    save_path=None\n","  )\n","\n","  val_F1[i] = val_metrics[\"F1\"]\n","  val_acc[i] = val_metrics[\"accuracy\"]\n","  val_pr_auc[i] = val_metrics[\"pr_auc\"]\n","  val_roc_auc[i] = val_metrics[\"roc_auc\"]\n","\n","  test_F1[i] = test_metrics[\"F1\"]\n","  test_acc[i] = test_metrics[\"accuracy\"]\n","  test_pr_auc[i] = test_metrics[\"pr_auc\"]\n","  test_roc_auc[i] = test_metrics[\"roc_auc\"]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"QPWTu9XYkkjD","executionInfo":{"status":"ok","timestamp":1651526772411,"user_tz":240,"elapsed":1286826,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"f1f9e993-c82c-4d6c-e230-bea851a7f242"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/5\n","25/25 
[==============================] - 7s 210ms/step - loss: 3.0516 - accuracy: 0.5915 - auc: 0.5278 - val_loss: 0.7818 - val_accuracy: 0.6628 - val_auc: 0.5612\n","Epoch 2/5\n","25/25 [==============================] - 5s 195ms/step - loss: 0.7825 - accuracy: 0.6574 - auc: 0.6027 - val_loss: 0.6396 - val_accuracy: 0.6744 - val_auc: 0.6260\n","Epoch 3/5\n","25/25 [==============================] - 5s 192ms/step - loss: 0.7362 - accuracy: 0.6642 - auc: 0.6320 - val_loss: 0.6168 - val_accuracy: 0.6773 - val_auc: 0.6550\n","Epoch 4/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.6388 - accuracy: 0.6800 - auc: 0.6745 - val_loss: 0.6311 - val_accuracy: 0.6773 - val_auc: 0.6396\n","Epoch 5/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.6032 - accuracy: 0.6949 - auc: 0.7047 - val_loss: 0.6036 - val_accuracy: 0.6890 - val_auc: 0.6658\n","Epoch 1/5\n","25/25 [==============================] - 8s 223ms/step - loss: 3.2503 - accuracy: 0.5695 - auc: 0.5043 - val_loss: 0.6779 - val_accuracy: 0.6541 - val_auc: 0.6130\n","Epoch 2/5\n","25/25 [==============================] - 5s 193ms/step - loss: 0.6538 - accuracy: 0.6542 - auc: 0.5508 - val_loss: 0.6417 - val_accuracy: 0.6657 - val_auc: 0.5747\n","Epoch 3/5\n","25/25 [==============================] - 5s 193ms/step - loss: 0.6393 - accuracy: 0.6532 - auc: 0.5856 - val_loss: 0.6247 - val_accuracy: 0.6657 - val_auc: 0.6029\n","Epoch 4/5\n","25/25 [==============================] - 5s 207ms/step - loss: 0.6362 - accuracy: 0.6532 - auc: 0.6055 - val_loss: 0.6184 - val_accuracy: 0.6657 - val_auc: 0.6306\n","Epoch 5/5\n","25/25 [==============================] - 5s 190ms/step - loss: 0.6073 - accuracy: 0.6532 - auc: 0.6814 - val_loss: 0.6163 - val_accuracy: 0.6657 - val_auc: 0.6367\n","Epoch 1/5\n","25/25 [==============================] - 7s 225ms/step - loss: 2.7811 - accuracy: 0.5944 - auc: 0.5325 - val_loss: 0.6888 - val_accuracy: 0.6715 - val_auc: 0.5158\n","Epoch 2/5\n","25/25 
[==============================] - 5s 200ms/step - loss: 0.6513 - accuracy: 0.6668 - auc: 0.5851 - val_loss: 0.6220 - val_accuracy: 0.6831 - val_auc: 0.6167\n","Epoch 3/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6226 - accuracy: 0.6797 - auc: 0.6495 - val_loss: 0.6090 - val_accuracy: 0.6773 - val_auc: 0.6392\n","Epoch 4/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.5957 - accuracy: 0.6894 - auc: 0.6904 - val_loss: 0.6031 - val_accuracy: 0.6657 - val_auc: 0.6670\n","Epoch 5/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.5666 - accuracy: 0.6981 - auc: 0.7306 - val_loss: 0.6022 - val_accuracy: 0.6860 - val_auc: 0.6755\n","Epoch 1/5\n","25/25 [==============================] - 7s 216ms/step - loss: 2.7659 - accuracy: 0.5895 - auc: 0.5276 - val_loss: 0.6746 - val_accuracy: 0.6599 - val_auc: 0.5285\n","Epoch 2/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.6566 - accuracy: 0.6642 - auc: 0.6104 - val_loss: 0.6700 - val_accuracy: 0.6744 - val_auc: 0.6584\n","Epoch 3/5\n","25/25 [==============================] - 5s 208ms/step - loss: 0.6450 - accuracy: 0.6726 - auc: 0.6325 - val_loss: 0.6229 - val_accuracy: 0.6715 - val_auc: 0.6678\n","Epoch 4/5\n","25/25 [==============================] - 5s 220ms/step - loss: 0.6456 - accuracy: 0.6800 - auc: 0.6793 - val_loss: 0.6611 - val_accuracy: 0.6744 - val_auc: 0.6819\n","Epoch 5/5\n","25/25 [==============================] - 5s 204ms/step - loss: 0.7195 - accuracy: 0.6923 - auc: 0.7027 - val_loss: 0.6543 - val_accuracy: 0.6337 - val_auc: 0.6687\n","Epoch 1/5\n","25/25 [==============================] - 8s 271ms/step - loss: 2.0996 - accuracy: 0.5750 - auc: 0.5257 - val_loss: 1.0610 - val_accuracy: 0.6453 - val_auc: 0.5251\n","Epoch 2/5\n","25/25 [==============================] - 8s 322ms/step - loss: 0.7630 - accuracy: 0.6542 - auc: 0.5899 - val_loss: 0.6516 - val_accuracy: 0.6715 - val_auc: 0.6333\n","Epoch 3/5\n","25/25 
[==============================] - 7s 279ms/step - loss: 0.6190 - accuracy: 0.6761 - auc: 0.6478 - val_loss: 0.6308 - val_accuracy: 0.6773 - val_auc: 0.6626\n","Epoch 4/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.7346 - accuracy: 0.6694 - auc: 0.6555 - val_loss: 0.6943 - val_accuracy: 0.6017 - val_auc: 0.6423\n","Epoch 5/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.6278 - accuracy: 0.6858 - auc: 0.7035 - val_loss: 0.6227 - val_accuracy: 0.6860 - val_auc: 0.6876\n","Epoch 1/5\n","25/25 [==============================] - 9s 265ms/step - loss: 3.1197 - accuracy: 0.5782 - auc: 0.4986 - val_loss: 0.7682 - val_accuracy: 0.6512 - val_auc: 0.6309\n","Epoch 2/5\n","25/25 [==============================] - 5s 204ms/step - loss: 0.7477 - accuracy: 0.6571 - auc: 0.5851 - val_loss: 0.7635 - val_accuracy: 0.6715 - val_auc: 0.5649\n","Epoch 3/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6401 - accuracy: 0.6791 - auc: 0.6595 - val_loss: 0.6444 - val_accuracy: 0.6802 - val_auc: 0.6620\n","Epoch 4/5\n","25/25 [==============================] - 10s 398ms/step - loss: 0.5945 - accuracy: 0.6942 - auc: 0.6967 - val_loss: 0.5965 - val_accuracy: 0.6831 - val_auc: 0.6819\n","Epoch 5/5\n","25/25 [==============================] - 8s 316ms/step - loss: 0.5739 - accuracy: 0.7065 - auc: 0.7269 - val_loss: 0.5841 - val_accuracy: 0.6890 - val_auc: 0.6943\n","Epoch 1/5\n","25/25 [==============================] - 10s 317ms/step - loss: 4.2090 - accuracy: 0.5792 - auc: 0.5125 - val_loss: 0.7939 - val_accuracy: 0.5116 - val_auc: 0.6493\n","Epoch 2/5\n","25/25 [==============================] - 6s 238ms/step - loss: 0.7171 - accuracy: 0.6354 - auc: 0.5632 - val_loss: 0.7715 - val_accuracy: 0.6831 - val_auc: 0.6399\n","Epoch 3/5\n","25/25 [==============================] - 7s 283ms/step - loss: 0.6826 - accuracy: 0.6829 - auc: 0.6556 - val_loss: 1.5883 - val_accuracy: 0.6831 - val_auc: 0.6741\n","Epoch 4/5\n","25/25 
[==============================] - 5s 219ms/step - loss: 0.6835 - accuracy: 0.6761 - auc: 0.6738 - val_loss: 0.8813 - val_accuracy: 0.6948 - val_auc: 0.6968\n","Epoch 5/5\n","25/25 [==============================] - 10s 385ms/step - loss: 0.6380 - accuracy: 0.6955 - auc: 0.7093 - val_loss: 0.7119 - val_accuracy: 0.6628 - val_auc: 0.6652\n","Epoch 1/5\n","25/25 [==============================] - 7s 212ms/step - loss: 2.6217 - accuracy: 0.5798 - auc: 0.5231 - val_loss: 0.6823 - val_accuracy: 0.6628 - val_auc: 0.5858\n","Epoch 2/5\n","25/25 [==============================] - 6s 244ms/step - loss: 0.7247 - accuracy: 0.6500 - auc: 0.6172 - val_loss: 0.6562 - val_accuracy: 0.6628 - val_auc: 0.6441\n","Epoch 3/5\n","25/25 [==============================] - 7s 292ms/step - loss: 0.7213 - accuracy: 0.6629 - auc: 0.6381 - val_loss: 0.7133 - val_accuracy: 0.6192 - val_auc: 0.6104\n","Epoch 4/5\n","25/25 [==============================] - 5s 211ms/step - loss: 0.5931 - accuracy: 0.6959 - auc: 0.7046 - val_loss: 0.6356 - val_accuracy: 0.6686 - val_auc: 0.6598\n","Epoch 5/5\n","25/25 [==============================] - 6s 253ms/step - loss: 0.5955 - accuracy: 0.6952 - auc: 0.7158 - val_loss: 0.5910 - val_accuracy: 0.6860 - val_auc: 0.6851\n","Epoch 1/5\n","25/25 [==============================] - 9s 216ms/step - loss: 2.7464 - accuracy: 0.5999 - auc: 0.4886 - val_loss: 0.6688 - val_accuracy: 0.6686 - val_auc: 0.5023\n","Epoch 2/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.6545 - accuracy: 0.6587 - auc: 0.5608 - val_loss: 0.6301 - val_accuracy: 0.6657 - val_auc: 0.5938\n","Epoch 3/5\n","25/25 [==============================] - 5s 194ms/step - loss: 0.6286 - accuracy: 0.6532 - auc: 0.6185 - val_loss: 0.6206 - val_accuracy: 0.6657 - val_auc: 0.6217\n","Epoch 4/5\n","25/25 [==============================] - 7s 277ms/step - loss: 0.6137 - accuracy: 0.6532 - auc: 0.6570 - val_loss: 0.6051 - val_accuracy: 0.6657 - val_auc: 0.6605\n","Epoch 5/5\n","25/25 
[==============================] - 9s 364ms/step - loss: 0.5848 - accuracy: 0.6532 - auc: 0.7185 - val_loss: 0.6058 - val_accuracy: 0.6657 - val_auc: 0.6611\n","Epoch 1/5\n","25/25 [==============================] - 10s 328ms/step - loss: 2.4270 - accuracy: 0.5856 - auc: 0.5219 - val_loss: 0.7504 - val_accuracy: 0.6657 - val_auc: 0.4887\n","Epoch 2/5\n","25/25 [==============================] - 7s 287ms/step - loss: 0.6812 - accuracy: 0.6422 - auc: 0.5762 - val_loss: 0.6269 - val_accuracy: 0.6802 - val_auc: 0.6338\n","Epoch 3/5\n","25/25 [==============================] - 8s 321ms/step - loss: 0.6170 - accuracy: 0.6781 - auc: 0.6529 - val_loss: 0.6180 - val_accuracy: 0.6773 - val_auc: 0.6544\n","Epoch 4/5\n","25/25 [==============================] - 7s 289ms/step - loss: 0.5969 - accuracy: 0.6813 - auc: 0.6915 - val_loss: 0.6128 - val_accuracy: 0.6744 - val_auc: 0.6599\n","Epoch 5/5\n","25/25 [==============================] - 5s 203ms/step - loss: 0.5770 - accuracy: 0.6849 - auc: 0.7179 - val_loss: 0.5948 - val_accuracy: 0.6860 - val_auc: 0.6854\n","Epoch 1/5\n","25/25 [==============================] - 7s 219ms/step - loss: 2.5684 - accuracy: 0.6021 - auc: 0.5325 - val_loss: 0.6450 - val_accuracy: 0.6744 - val_auc: 0.5836\n","Epoch 2/5\n","25/25 [==============================] - 7s 296ms/step - loss: 0.6439 - accuracy: 0.6580 - auc: 0.5895 - val_loss: 0.6288 - val_accuracy: 0.6831 - val_auc: 0.6062\n","Epoch 3/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.6153 - accuracy: 0.6846 - auc: 0.6428 - val_loss: 0.6061 - val_accuracy: 0.6802 - val_auc: 0.6580\n","Epoch 4/5\n","25/25 [==============================] - 8s 326ms/step - loss: 0.5905 - accuracy: 0.6949 - auc: 0.7025 - val_loss: 0.6118 - val_accuracy: 0.6860 - val_auc: 0.6483\n","Epoch 5/5\n","25/25 [==============================] - 7s 298ms/step - loss: 0.5737 - accuracy: 0.7078 - auc: 0.7322 - val_loss: 0.5962 - val_accuracy: 0.6802 - val_auc: 0.6767\n","Epoch 1/5\n","25/25 
[==============================] - 10s 322ms/step - loss: 2.5810 - accuracy: 0.5659 - auc: 0.4943 - val_loss: 0.8021 - val_accuracy: 0.6802 - val_auc: 0.5091\n","Epoch 2/5\n","25/25 [==============================] - 6s 247ms/step - loss: 0.6780 - accuracy: 0.6438 - auc: 0.5843 - val_loss: 0.6288 - val_accuracy: 0.6715 - val_auc: 0.6130\n","Epoch 3/5\n","25/25 [==============================] - 7s 280ms/step - loss: 0.6062 - accuracy: 0.6878 - auc: 0.6680 - val_loss: 0.6049 - val_accuracy: 0.6919 - val_auc: 0.6681\n","Epoch 4/5\n","25/25 [==============================] - 7s 279ms/step - loss: 0.5883 - accuracy: 0.6810 - auc: 0.6909 - val_loss: 0.6058 - val_accuracy: 0.6802 - val_auc: 0.6630\n","Epoch 5/5\n","25/25 [==============================] - 7s 279ms/step - loss: 0.5733 - accuracy: 0.6984 - auc: 0.7280 - val_loss: 0.6048 - val_accuracy: 0.6831 - val_auc: 0.6794\n","Epoch 1/5\n","25/25 [==============================] - 10s 257ms/step - loss: 3.4692 - accuracy: 0.6138 - auc: 0.5288 - val_loss: 0.6400 - val_accuracy: 0.6802 - val_auc: 0.5729\n","Epoch 2/5\n","25/25 [==============================] - 6s 255ms/step - loss: 0.6350 - accuracy: 0.6723 - auc: 0.5813 - val_loss: 0.6633 - val_accuracy: 0.6657 - val_auc: 0.6267\n","Epoch 3/5\n","25/25 [==============================] - 7s 288ms/step - loss: 0.6235 - accuracy: 0.6736 - auc: 0.6505 - val_loss: 0.6079 - val_accuracy: 0.6831 - val_auc: 0.6738\n","Epoch 4/5\n","25/25 [==============================] - 7s 276ms/step - loss: 0.5702 - accuracy: 0.6981 - auc: 0.7286 - val_loss: 0.6071 - val_accuracy: 0.6948 - val_auc: 0.6758\n","Epoch 5/5\n","25/25 [==============================] - 9s 351ms/step - loss: 0.5300 - accuracy: 0.7356 - auc: 0.7828 - val_loss: 0.6174 - val_accuracy: 0.6744 - val_auc: 0.6771\n","Epoch 1/5\n","25/25 [==============================] - 11s 353ms/step - loss: 3.8252 - accuracy: 0.5633 - auc: 0.4954 - val_loss: 0.7803 - val_accuracy: 0.5610 - val_auc: 0.5570\n","Epoch 2/5\n","25/25 
[==============================] - 10s 376ms/step - loss: 0.8035 - accuracy: 0.6367 - auc: 0.5596 - val_loss: 1.2881 - val_accuracy: 0.6017 - val_auc: 0.6355\n","Epoch 3/5\n","25/25 [==============================] - 8s 303ms/step - loss: 0.7916 - accuracy: 0.6645 - auc: 0.6412 - val_loss: 0.6234 - val_accuracy: 0.6657 - val_auc: 0.6305\n","Epoch 4/5\n","25/25 [==============================] - 6s 222ms/step - loss: 0.6503 - accuracy: 0.6745 - auc: 0.6759 - val_loss: 0.6906 - val_accuracy: 0.6657 - val_auc: 0.6468\n","Epoch 5/5\n","25/25 [==============================] - 5s 194ms/step - loss: 0.6085 - accuracy: 0.6836 - auc: 0.7120 - val_loss: 0.5927 - val_accuracy: 0.6599 - val_auc: 0.6736\n","Epoch 1/5\n","25/25 [==============================] - 7s 214ms/step - loss: 2.4624 - accuracy: 0.5944 - auc: 0.5385 - val_loss: 0.6212 - val_accuracy: 0.6890 - val_auc: 0.5988\n","Epoch 2/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.6638 - accuracy: 0.6677 - auc: 0.6151 - val_loss: 0.6547 - val_accuracy: 0.6802 - val_auc: 0.6844\n","Epoch 3/5\n","25/25 [==============================] - 5s 195ms/step - loss: 0.6306 - accuracy: 0.6858 - auc: 0.6718 - val_loss: 0.6037 - val_accuracy: 0.6860 - val_auc: 0.6720\n","Epoch 4/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.5911 - accuracy: 0.6975 - auc: 0.7089 - val_loss: 0.6075 - val_accuracy: 0.6541 - val_auc: 0.6816\n","Epoch 5/5\n","25/25 [==============================] - 5s 205ms/step - loss: 0.6108 - accuracy: 0.6936 - auc: 0.7174 - val_loss: 0.5936 - val_accuracy: 0.6860 - val_auc: 0.6746\n","Epoch 1/5\n","25/25 [==============================] - 7s 224ms/step - loss: 3.3112 - accuracy: 0.6047 - auc: 0.5420 - val_loss: 0.6967 - val_accuracy: 0.6686 - val_auc: 0.6080\n","Epoch 2/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6366 - accuracy: 0.6826 - auc: 0.6482 - val_loss: 0.6372 - val_accuracy: 0.6424 - val_auc: 0.6573\n","Epoch 3/5\n","25/25 
[==============================] - 5s 200ms/step - loss: 0.5902 - accuracy: 0.6813 - auc: 0.6974 - val_loss: 0.5977 - val_accuracy: 0.6860 - val_auc: 0.6912\n","Epoch 4/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.5932 - accuracy: 0.6962 - auc: 0.7236 - val_loss: 0.5954 - val_accuracy: 0.6860 - val_auc: 0.6813\n","Epoch 5/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.5564 - accuracy: 0.7185 - auc: 0.7500 - val_loss: 0.5941 - val_accuracy: 0.6831 - val_auc: 0.6968\n","Epoch 1/5\n","25/25 [==============================] - 7s 220ms/step - loss: 3.0030 - accuracy: 0.5963 - auc: 0.5402 - val_loss: 0.7164 - val_accuracy: 0.6628 - val_auc: 0.5879\n","Epoch 2/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.7052 - accuracy: 0.6516 - auc: 0.5897 - val_loss: 0.6358 - val_accuracy: 0.6860 - val_auc: 0.6255\n","Epoch 3/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.6307 - accuracy: 0.6778 - auc: 0.6575 - val_loss: 0.8902 - val_accuracy: 0.6773 - val_auc: 0.6436\n","Epoch 4/5\n","25/25 [==============================] - 5s 204ms/step - loss: 0.6705 - accuracy: 0.6710 - auc: 0.6700 - val_loss: 0.5994 - val_accuracy: 0.6802 - val_auc: 0.6770\n","Epoch 5/5\n","25/25 [==============================] - 5s 217ms/step - loss: 0.5874 - accuracy: 0.6942 - auc: 0.7184 - val_loss: 0.6224 - val_accuracy: 0.6802 - val_auc: 0.6846\n","Epoch 1/5\n","25/25 [==============================] - 7s 217ms/step - loss: 2.0739 - accuracy: 0.5989 - auc: 0.4932 - val_loss: 0.6955 - val_accuracy: 0.5233 - val_auc: 0.5777\n","Epoch 2/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6504 - accuracy: 0.6626 - auc: 0.5616 - val_loss: 0.6218 - val_accuracy: 0.6802 - val_auc: 0.5804\n","Epoch 3/5\n","25/25 [==============================] - 5s 220ms/step - loss: 0.6237 - accuracy: 0.6781 - auc: 0.6261 - val_loss: 0.6372 - val_accuracy: 0.6657 - val_auc: 0.6127\n","Epoch 4/5\n","25/25 
[==============================] - 5s 196ms/step - loss: 0.6153 - accuracy: 0.6791 - auc: 0.6415 - val_loss: 0.6023 - val_accuracy: 0.6802 - val_auc: 0.6617\n","Epoch 5/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.5883 - accuracy: 0.6955 - auc: 0.6956 - val_loss: 0.5989 - val_accuracy: 0.6860 - val_auc: 0.6759\n","Epoch 1/5\n","25/25 [==============================] - 7s 213ms/step - loss: 5.2044 - accuracy: 0.5721 - auc: 0.5154 - val_loss: 0.7226 - val_accuracy: 0.6134 - val_auc: 0.5367\n","Epoch 2/5\n","25/25 [==============================] - 6s 244ms/step - loss: 0.6862 - accuracy: 0.6545 - auc: 0.5759 - val_loss: 0.6411 - val_accuracy: 0.6890 - val_auc: 0.6032\n","Epoch 3/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.6571 - accuracy: 0.6690 - auc: 0.6047 - val_loss: 0.6166 - val_accuracy: 0.6773 - val_auc: 0.6466\n","Epoch 4/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.5997 - accuracy: 0.6816 - auc: 0.6851 - val_loss: 0.6233 - val_accuracy: 0.6831 - val_auc: 0.6883\n","Epoch 5/5\n","25/25 [==============================] - 5s 195ms/step - loss: 0.5794 - accuracy: 0.6933 - auc: 0.7242 - val_loss: 0.6075 - val_accuracy: 0.6686 - val_auc: 0.6663\n","Epoch 1/5\n","25/25 [==============================] - 7s 223ms/step - loss: 3.5444 - accuracy: 0.5866 - auc: 0.5215 - val_loss: 0.8044 - val_accuracy: 0.6366 - val_auc: 0.5591\n","Epoch 2/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6924 - accuracy: 0.6658 - auc: 0.6150 - val_loss: 0.9770 - val_accuracy: 0.6715 - val_auc: 0.5864\n","Epoch 3/5\n","25/25 [==============================] - 5s 192ms/step - loss: 0.9201 - accuracy: 0.6564 - auc: 0.6389 - val_loss: 0.8999 - val_accuracy: 0.6628 - val_auc: 0.6413\n","Epoch 4/5\n","25/25 [==============================] - 6s 248ms/step - loss: 0.8226 - accuracy: 0.6677 - auc: 0.6724 - val_loss: 0.8026 - val_accuracy: 0.6599 - val_auc: 0.6561\n","Epoch 5/5\n","25/25 
[==============================] - 11s 430ms/step - loss: 0.7156 - accuracy: 0.6900 - auc: 0.7079 - val_loss: 0.6913 - val_accuracy: 0.6453 - val_auc: 0.6535\n","Epoch 1/5\n","25/25 [==============================] - 9s 310ms/step - loss: 4.0803 - accuracy: 0.6021 - auc: 0.5279 - val_loss: 1.1291 - val_accuracy: 0.6570 - val_auc: 0.6175\n","Epoch 2/5\n","25/25 [==============================] - 11s 432ms/step - loss: 0.7530 - accuracy: 0.6632 - auc: 0.5903 - val_loss: 0.6945 - val_accuracy: 0.6831 - val_auc: 0.6541\n","Epoch 3/5\n","25/25 [==============================] - 8s 298ms/step - loss: 0.6382 - accuracy: 0.6745 - auc: 0.6458 - val_loss: 0.6029 - val_accuracy: 0.6744 - val_auc: 0.6458\n","Epoch 4/5\n","25/25 [==============================] - 7s 278ms/step - loss: 0.6262 - accuracy: 0.6813 - auc: 0.6768 - val_loss: 0.7143 - val_accuracy: 0.6657 - val_auc: 0.6664\n","Epoch 5/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.5877 - accuracy: 0.6991 - auc: 0.7234 - val_loss: 0.6996 - val_accuracy: 0.6773 - val_auc: 0.6711\n","Epoch 1/5\n","25/25 [==============================] - 7s 215ms/step - loss: 3.6606 - accuracy: 0.5769 - auc: 0.5110 - val_loss: 0.7526 - val_accuracy: 0.6424 - val_auc: 0.6245\n","Epoch 2/5\n","25/25 [==============================] - 5s 203ms/step - loss: 0.8067 - accuracy: 0.6341 - auc: 0.5831 - val_loss: 0.6776 - val_accuracy: 0.6744 - val_auc: 0.6348\n","Epoch 3/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.6468 - accuracy: 0.6816 - auc: 0.6608 - val_loss: 0.6210 - val_accuracy: 0.6744 - val_auc: 0.6522\n","Epoch 4/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.6583 - accuracy: 0.6752 - auc: 0.6697 - val_loss: 0.8086 - val_accuracy: 0.6308 - val_auc: 0.6648\n","Epoch 5/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.6531 - accuracy: 0.6849 - auc: 0.6914 - val_loss: 0.6765 - val_accuracy: 0.6657 - val_auc: 0.6181\n","Epoch 1/5\n","25/25 
[==============================] - 7s 217ms/step - loss: 2.1563 - accuracy: 0.6044 - auc: 0.5169 - val_loss: 0.6735 - val_accuracy: 0.6802 - val_auc: 0.5441\n","Epoch 2/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6321 - accuracy: 0.6774 - auc: 0.6040 - val_loss: 0.6285 - val_accuracy: 0.6657 - val_auc: 0.6545\n","Epoch 3/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6422 - accuracy: 0.6668 - auc: 0.6571 - val_loss: 0.6944 - val_accuracy: 0.6860 - val_auc: 0.6636\n","Epoch 4/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6196 - accuracy: 0.6878 - auc: 0.6764 - val_loss: 0.6032 - val_accuracy: 0.7035 - val_auc: 0.6629\n","Epoch 5/5\n","25/25 [==============================] - 6s 239ms/step - loss: 0.5737 - accuracy: 0.7033 - auc: 0.7331 - val_loss: 0.6076 - val_accuracy: 0.6831 - val_auc: 0.6676\n","Epoch 1/5\n","25/25 [==============================] - 7s 221ms/step - loss: 3.1150 - accuracy: 0.5750 - auc: 0.5051 - val_loss: 0.7842 - val_accuracy: 0.5698 - val_auc: 0.6041\n","Epoch 2/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.7039 - accuracy: 0.6493 - auc: 0.5539 - val_loss: 0.6307 - val_accuracy: 0.6860 - val_auc: 0.5702\n","Epoch 3/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.6311 - accuracy: 0.6765 - auc: 0.6229 - val_loss: 0.5930 - val_accuracy: 0.6860 - val_auc: 0.6865\n","Epoch 4/5\n","25/25 [==============================] - 5s 194ms/step - loss: 0.6057 - accuracy: 0.6765 - auc: 0.6671 - val_loss: 0.6420 - val_accuracy: 0.6831 - val_auc: 0.6627\n","Epoch 5/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.5961 - accuracy: 0.6871 - auc: 0.7037 - val_loss: 0.5829 - val_accuracy: 0.6977 - val_auc: 0.6990\n","Epoch 1/5\n","25/25 [==============================] - 8s 278ms/step - loss: 3.5574 - accuracy: 0.5792 - auc: 0.5238 - val_loss: 0.6546 - val_accuracy: 0.6512 - val_auc: 0.5775\n","Epoch 2/5\n","25/25 
[==============================] - 7s 296ms/step - loss: 0.6565 - accuracy: 0.6736 - auc: 0.6220 - val_loss: 1.0334 - val_accuracy: 0.6657 - val_auc: 0.5666\n","Epoch 3/5\n","25/25 [==============================] - 7s 284ms/step - loss: 0.7975 - accuracy: 0.6312 - auc: 0.6064 - val_loss: 0.6256 - val_accuracy: 0.6715 - val_auc: 0.6544\n","Epoch 4/5\n","25/25 [==============================] - 9s 351ms/step - loss: 0.6417 - accuracy: 0.6729 - auc: 0.6732 - val_loss: 0.6585 - val_accuracy: 0.6802 - val_auc: 0.6541\n","Epoch 5/5\n","25/25 [==============================] - 9s 380ms/step - loss: 0.5990 - accuracy: 0.6826 - auc: 0.7089 - val_loss: 0.6161 - val_accuracy: 0.6831 - val_auc: 0.6796\n","Epoch 1/5\n","25/25 [==============================] - 12s 369ms/step - loss: 2.9459 - accuracy: 0.5837 - auc: 0.5284 - val_loss: 0.7329 - val_accuracy: 0.6483 - val_auc: 0.5562\n","Epoch 2/5\n","25/25 [==============================] - 8s 310ms/step - loss: 0.8338 - accuracy: 0.6490 - auc: 0.5790 - val_loss: 0.7063 - val_accuracy: 0.6715 - val_auc: 0.6165\n","Epoch 3/5\n","25/25 [==============================] - 5s 211ms/step - loss: 0.6300 - accuracy: 0.6668 - auc: 0.6546 - val_loss: 0.6363 - val_accuracy: 0.6715 - val_auc: 0.6586\n","Epoch 4/5\n","25/25 [==============================] - 5s 206ms/step - loss: 0.7162 - accuracy: 0.6622 - auc: 0.6708 - val_loss: 0.7727 - val_accuracy: 0.6977 - val_auc: 0.6665\n","Epoch 5/5\n","25/25 [==============================] - 8s 333ms/step - loss: 0.7299 - accuracy: 0.6907 - auc: 0.7066 - val_loss: 0.6007 - val_accuracy: 0.6860 - val_auc: 0.6948\n","Epoch 1/5\n","25/25 [==============================] - 10s 287ms/step - loss: 2.9779 - accuracy: 0.5708 - auc: 0.5143 - val_loss: 0.6720 - val_accuracy: 0.5785 - val_auc: 0.5434\n","Epoch 2/5\n","25/25 [==============================] - 7s 288ms/step - loss: 0.6610 - accuracy: 0.6506 - auc: 0.5826 - val_loss: 0.6947 - val_accuracy: 0.6657 - val_auc: 0.5686\n","Epoch 3/5\n","25/25 
[==============================] - 7s 291ms/step - loss: 0.6333 - accuracy: 0.6784 - auc: 0.6283 - val_loss: 0.5968 - val_accuracy: 0.6831 - val_auc: 0.6740\n","Epoch 4/5\n","25/25 [==============================] - 8s 309ms/step - loss: 0.6014 - accuracy: 0.6807 - auc: 0.6991 - val_loss: 0.6325 - val_accuracy: 0.6744 - val_auc: 0.6826\n","Epoch 5/5\n","25/25 [==============================] - 7s 292ms/step - loss: 0.5949 - accuracy: 0.6842 - auc: 0.7130 - val_loss: 0.5897 - val_accuracy: 0.6890 - val_auc: 0.6807\n","Epoch 1/5\n","25/25 [==============================] - 11s 361ms/step - loss: 2.9335 - accuracy: 0.5850 - auc: 0.5299 - val_loss: 0.9148 - val_accuracy: 0.6657 - val_auc: 0.4665\n","Epoch 2/5\n","25/25 [==============================] - 8s 301ms/step - loss: 0.7654 - accuracy: 0.6461 - auc: 0.5910 - val_loss: 0.6279 - val_accuracy: 0.6919 - val_auc: 0.6306\n","Epoch 3/5\n","25/25 [==============================] - 7s 284ms/step - loss: 0.6131 - accuracy: 0.6816 - auc: 0.6600 - val_loss: 0.6027 - val_accuracy: 0.6919 - val_auc: 0.6693\n","Epoch 4/5\n","25/25 [==============================] - 7s 289ms/step - loss: 0.6090 - accuracy: 0.6778 - auc: 0.6803 - val_loss: 0.6420 - val_accuracy: 0.6715 - val_auc: 0.6623\n","Epoch 5/5\n","25/25 [==============================] - 7s 286ms/step - loss: 0.6012 - accuracy: 0.6849 - auc: 0.7031 - val_loss: 0.6013 - val_accuracy: 0.6628 - val_auc: 0.6686\n","Epoch 1/5\n","25/25 [==============================] - 11s 334ms/step - loss: 3.6025 - accuracy: 0.5837 - auc: 0.5193 - val_loss: 0.6376 - val_accuracy: 0.6628 - val_auc: 0.5889\n","Epoch 2/5\n","25/25 [==============================] - 7s 289ms/step - loss: 0.6586 - accuracy: 0.6538 - auc: 0.5961 - val_loss: 0.6665 - val_accuracy: 0.6570 - val_auc: 0.6319\n","Epoch 3/5\n","25/25 [==============================] - 7s 288ms/step - loss: 0.6030 - accuracy: 0.6723 - auc: 0.6746 - val_loss: 0.6088 - val_accuracy: 0.6512 - val_auc: 0.6527\n","Epoch 4/5\n","25/25 
[==============================] - 8s 331ms/step - loss: 0.5913 - accuracy: 0.6823 - auc: 0.7041 - val_loss: 0.6493 - val_accuracy: 0.6744 - val_auc: 0.6662\n","Epoch 5/5\n","25/25 [==============================] - 7s 293ms/step - loss: 0.5881 - accuracy: 0.6913 - auc: 0.7131 - val_loss: 0.6047 - val_accuracy: 0.6919 - val_auc: 0.6995\n","Epoch 1/5\n","25/25 [==============================] - 10s 303ms/step - loss: 4.2498 - accuracy: 0.5928 - auc: 0.5234 - val_loss: 0.6393 - val_accuracy: 0.6802 - val_auc: 0.5735\n","Epoch 2/5\n","25/25 [==============================] - 7s 275ms/step - loss: 0.6363 - accuracy: 0.6784 - auc: 0.6186 - val_loss: 0.6330 - val_accuracy: 0.6715 - val_auc: 0.6473\n","Epoch 3/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6446 - accuracy: 0.6774 - auc: 0.6666 - val_loss: 0.7163 - val_accuracy: 0.6017 - val_auc: 0.6487\n","Epoch 4/5\n","25/25 [==============================] - 5s 195ms/step - loss: 0.6313 - accuracy: 0.6600 - auc: 0.6718 - val_loss: 0.6187 - val_accuracy: 0.6773 - val_auc: 0.6567\n","Epoch 5/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.5806 - accuracy: 0.6962 - auc: 0.7185 - val_loss: 0.5942 - val_accuracy: 0.6890 - val_auc: 0.6696\n"]}]},{"cell_type":"code","source":["val_F1_mean, val_F1_sd = mean_and_sd(val_F1)\n","val_acc_mean, val_acc_sd = mean_and_sd(val_acc)\n","val_pr_auc_mean, val_pr_auc_sd = mean_and_sd(val_pr_auc)\n","val_roc_auc_mean, val_roc_auc_sd = mean_and_sd(val_roc_auc)\n","\n","test_F1_mean, test_F1_sd = mean_and_sd(test_F1)\n","test_acc_mean, test_acc_sd = mean_and_sd(test_acc)\n","test_pr_auc_mean, test_pr_auc_sd = mean_and_sd(test_pr_auc)\n","test_roc_auc_mean, test_roc_auc_sd = mean_and_sd(test_roc_auc)"],"metadata":{"id":"9WTovGNylc-Q"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(f\"validation F1: {round(val_F1_mean, 3)} +/- {round(val_F1_sd, 3)}\")\n","print(f\"test F1: {round(test_F1_mean, 3)} +/- {round(test_F1_sd, 
3)}\")\n","\n","print(f\"validation accuracy: {round(val_acc_mean, 3)} +/- {round(val_acc_sd, 3)}\")\n","print(f\"test accuracy: {round(test_acc_mean, 3)} +/- {round(test_acc_sd, 3)}\")\n","\n","print(f\"validation PR-AUC: {round(val_pr_auc_mean, 3)} +/- {round(val_pr_auc_sd, 3)}\")\n","print(f\"test PR-AUC: {round(test_pr_auc_mean, 3)} +/- {round(test_pr_auc_sd, 3)}\")\n","\n","print(f\"validation ROC-AUC: {round(val_roc_auc_mean, 3)} +/- {round(val_roc_auc_sd, 3)}\")\n","print(f\"test ROC-AUC: {round(test_roc_auc_mean, 3)} +/- {round(test_roc_auc_sd, 3)}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"E7zHHIs6ls4f","executionInfo":{"status":"ok","timestamp":1651526772416,"user_tz":240,"elapsed":25,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"a89c59de-e86e-42ae-93dd-bdd4df2d2a86"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["validation F1: 0.784 +/- 0.02\n","test F1: 0.822 +/- 0.042\n","validation accuracy: 0.677 +/- 0.014\n","test accuracy: 0.719 +/- 0.04\n","validation PR-AUC: 0.798 +/- 0.018\n","test PR-AUC: 0.839 +/- 0.008\n","validation ROC-AUC: 0.674 +/- 0.017\n","test ROC-AUC: 0.649 +/- 0.015\n"]}]},{"cell_type":"markdown","source":["### Without disease data"],"metadata":{"id":"zyvdoJ0Vmuog"}},{"cell_type":"code","source":["n_runs = 30"],"metadata":{"id":"-m1MCcOcm5vi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["val_F1 = np.zeros(n_runs, dtype=float)\n","val_acc = np.zeros(n_runs, dtype=float)\n","val_pr_auc = np.zeros(n_runs, dtype=float)\n","val_roc_auc = np.zeros(n_runs, dtype=float)\n","\n","test_F1 = np.zeros(n_runs, dtype=float)\n","test_acc = np.zeros(n_runs, dtype=float)\n","test_pr_auc = np.zeros(n_runs, dtype=float)\n","test_roc_auc = np.zeros(n_runs, dtype=float)"],"metadata":{"id":"tS9OXzy2m8Nz"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for i in range(n_runs):\n","\n","  # inclusion and 
exclusion criteria\n","  inclusion_input = layers.Input(shape=(incl_criteria_dim,), name=\"inclusion\")\n","  inclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_inclusion\")(inclusion_input)\n","\n","  exclusion_input = layers.Input(shape=(excl_criteria_dim,), name=\"exclusion\")\n","  exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_exclusion\")(exclusion_input)\n","\n","  inclusion_exclusion_raw = layers.Concatenate(name=\"criteria\")([inclusion_emb, exclusion_emb])\n","  ie_dropout = layers.Dropout(rate=0.32)(inclusion_exclusion_raw)\n","  inclusion_exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_criteria\")(ie_dropout)\n","\n","\n","  # drug\n","  drug_input = layers.Input(shape=(drug_dim,), name=\"drug\")\n","  drug_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_drug\")(drug_input)\n","\n","  # targets\n","  targets_input = layers.Input(shape=(targets_dim,), name=\"targets\")\n","  targets_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_targets\")(targets_input)\n","\n","  # drug-target interaction\n","  drug_target_raw = layers.Concatenate(name=\"drug-targets\")([drug_emb, targets_emb])\n","  dt = layers.Dropout(rate=0.32)(drug_target_raw)\n","  drug_target_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_drug-targets\")(dt)\n","\n","\n","  all_emb = layers.Concatenate(name=\"all_embeddings\")([inclusion_exclusion_emb, drug_target_emb])\n","  ae = layers.Dropout(rate=0.32)(all_emb)\n","  trial_embedding1 = layers.Dense(128, activation=\"relu\", name=\"trial_embedding_1\")(ae)\n","\n","  meta_input = layers.Input(shape=(meta_dim,), name=\"meta\")\n","  emb_and_meta = layers.Concatenate()([trial_embedding1, meta_input])\n","\n","  trial_embedding2 = layers.Dense(64, activation=\"relu\", name=\"trial_embedding_2\")(emb_and_meta)\n","  trial_embedding3 = layers.Dense(32, activation=\"relu\", name=\"trial_embedding_3\")(trial_embedding2)\n","\n","  o = 
layers.Dense(1, activation=\"sigmoid\", name=\"trial_success\")(trial_embedding3)\n","\n","  model = tf.keras.Model(inputs=[inclusion_input, exclusion_input, drug_input, targets_input, meta_input], outputs=[o])\n","  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=\"binary_crossentropy\", metrics=[\"accuracy\", \"AUC\"])\n","\n","  history = model.fit(\n","    x={\"inclusion\": train_data_incl, \"exclusion\": train_data_excl, \n","      \"drug\": train_data_drug, \"targets\": train_data_target,\n","      \"meta\": train_data_meta},\n","    y={\"trial_success\": train_labels},\n","    validation_data=(\n","        {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl,\n","        \"drug\": val_data_drug, \"targets\": val_data_target,\n","        \"meta\": val_data_meta},\n","        {\"trial_success\": val_labels}\n","    ), \n","    epochs=5,\n","    batch_size=128\n","  )\n","\n","\n","  val_metrics = model_performance(\n","    model,\n","    {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl, \n","    \"drug\": val_data_drug, \"targets\": val_data_target,\n","    \"meta\": val_data_meta},\n","    val_labels,\n","    save_path=None\n","  )\n","\n","  test_metrics = model_performance(\n","    model,\n","    {\"inclusion\": test_data_incl, \"exclusion\": test_data_excl, \n","    \"drug\": test_data_drug, \"targets\": test_data_target,\n","    \"meta\": test_data_meta},\n","    test_labels,\n","    save_path=None\n","  )\n","\n","  val_F1[i] = val_metrics[\"F1\"]\n","  val_acc[i] = val_metrics[\"accuracy\"]\n","  val_pr_auc[i] = val_metrics[\"pr_auc\"]\n","  val_roc_auc[i] = val_metrics[\"roc_auc\"]\n","\n","  test_F1[i] = test_metrics[\"F1\"]\n","  test_acc[i] = test_metrics[\"accuracy\"]\n","  test_pr_auc[i] = test_metrics[\"pr_auc\"]\n","  test_roc_auc[i] = test_metrics[\"roc_auc\"]"],"metadata":{"id":"KIfFSkeenCmF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["val_F1_mean, val_F1_sd = 
mean_and_sd(val_F1)\n","val_acc_mean, val_acc_sd = mean_and_sd(val_acc)\n","val_pr_auc_mean, val_pr_auc_sd = mean_and_sd(val_pr_auc)\n","val_roc_auc_mean, val_roc_auc_sd = mean_and_sd(val_roc_auc)\n","\n","test_F1_mean, test_F1_sd = mean_and_sd(test_F1)\n","test_acc_mean, test_acc_sd = mean_and_sd(test_acc)\n","test_pr_auc_mean, test_pr_auc_sd = mean_and_sd(test_pr_auc)\n","test_roc_auc_mean, test_roc_auc_sd = mean_and_sd(test_roc_auc)"],"metadata":{"id":"RggtOU2noIDa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(f\"validation F1: {round(val_F1_mean, 3)} +/- {round(val_F1_sd, 3)}\")\n","print(f\"test F1: {round(test_F1_mean, 3)} +/- {round(test_F1_sd, 3)}\")\n","\n","print(f\"validation accuracy: {round(val_acc_mean, 3)} +/- {round(val_acc_sd, 3)}\")\n","print(f\"test accuracy: {round(test_acc_mean, 3)} +/- {round(test_acc_sd, 3)}\")\n","\n","print(f\"validation PR-AUC: {round(val_pr_auc_mean, 3)} +/- {round(val_pr_auc_sd, 3)}\")\n","print(f\"test PR-AUC: {round(test_pr_auc_mean, 3)} +/- {round(test_pr_auc_sd, 3)}\")\n","\n","print(f\"validation ROC-AUC: {round(val_roc_auc_mean, 3)} +/- {round(val_roc_auc_sd, 3)}\")\n","print(f\"test ROC-AUC: {round(test_roc_auc_mean, 3)} +/- {round(test_roc_auc_sd, 3)}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"kHliRYWLoKND","executionInfo":{"status":"ok","timestamp":1651527968428,"user_tz":240,"elapsed":14,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"063c2103-e15e-422a-c19c-b6199d2dcd20"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["validation F1: 0.798 +/- 0.007\n","test F1: 0.838 +/- 0.019\n","validation accuracy: 0.685 +/- 0.011\n","test accuracy: 0.729 +/- 0.023\n","validation PR-AUC: 0.754 +/- 0.012\n","test PR-AUC: 0.814 +/- 0.012\n","validation ROC-AUC: 0.639 +/- 0.013\n","test ROC-AUC: 0.604 +/- 0.024\n"]}]},{"cell_type":"markdown","source":["### Without drug 
data"],"metadata":{"id":"WA8cHUO8pMlb"}},{"cell_type":"code","source":["n_runs = 30"],"metadata":{"id":"qRZwNXrQpMle"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["val_F1 = np.zeros(n_runs, dtype=float)\n","val_acc = np.zeros(n_runs, dtype=float)\n","val_pr_auc = np.zeros(n_runs, dtype=float)\n","val_roc_auc = np.zeros(n_runs, dtype=float)\n","\n","test_F1 = np.zeros(n_runs, dtype=float)\n","test_acc = np.zeros(n_runs, dtype=float)\n","test_pr_auc = np.zeros(n_runs, dtype=float)\n","test_roc_auc = np.zeros(n_runs, dtype=float)"],"metadata":{"id":"jL8YVYJ3pMlg"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for i in range(n_runs):\n","\n","  # inclusion and exclusion criteria\n","  inclusion_input = layers.Input(shape=(incl_criteria_dim,), name=\"inclusion\")\n","  inclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_inclusion\")(inclusion_input)\n","\n","  exclusion_input = layers.Input(shape=(excl_criteria_dim,), name=\"exclusion\")\n","  exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_exclusion\")(exclusion_input)\n","\n","  inclusion_exclusion_raw = layers.Concatenate(name=\"criteria\")([inclusion_emb, exclusion_emb])\n","  ie_dropout = layers.Dropout(rate=0.32)(inclusion_exclusion_raw)\n","  inclusion_exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_criteria\")(ie_dropout)\n","\n","  # diseases\n","  diseases_input = layers.Input(shape=(diseases_dim,), name=\"diseases\")\n","  diseases_emb = layers.Dense(128, activation=\"relu\", name=\"lower_dim_diseases\")(diseases_input)\n","\n","  # targets\n","  targets_input = layers.Input(shape=(targets_dim,), name=\"targets\")\n","  targets_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_targets\")(targets_input)\n","\n","\n","  all_emb = layers.Concatenate(name=\"all_embeddings\")([inclusion_exclusion_emb, diseases_emb, targets_emb])\n","  ae = layers.Dropout(rate=0.32)(all_emb)\n","  
trial_embedding1 = layers.Dense(128, activation=\"relu\", name=\"trial_embedding_1\")(ae)\n","\n","  meta_input = layers.Input(shape=(meta_dim,), name=\"meta\")\n","  emb_and_meta = layers.Concatenate()([trial_embedding1, meta_input])\n","\n","  trial_embedding2 = layers.Dense(64, activation=\"relu\", name=\"trial_embedding_2\")(emb_and_meta)\n","  trial_embedding3 = layers.Dense(32, activation=\"relu\", name=\"trial_embedding_3\")(trial_embedding2)\n","\n","  o = layers.Dense(1, activation=\"sigmoid\", name=\"trial_success\")(trial_embedding3)\n","\n","  model = tf.keras.Model(inputs=[inclusion_input, exclusion_input, diseases_input, targets_input, meta_input], outputs=[o])\n","  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=\"binary_crossentropy\", metrics=[\"accuracy\", \"AUC\"])\n","\n","  history = model.fit(\n","    x={\"inclusion\": train_data_incl, \"exclusion\": train_data_excl,\n","       \"diseases\": train_data_diseases, \"targets\": train_data_target,\n","      \"meta\": train_data_meta},\n","    y={\"trial_success\": train_labels},\n","    validation_data=(\n","        {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl,\n","        \"diseases\": val_data_diseases, \"targets\": val_data_target,\n","        \"meta\": val_data_meta},\n","        {\"trial_success\": val_labels}\n","    ), \n","    epochs=5,\n","    batch_size=128\n","  )\n","\n","\n","  val_metrics = model_performance(\n","    model,\n","    {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl, \n","    \"diseases\": val_data_diseases, \"targets\": val_data_target,\n","    \"meta\": val_data_meta},\n","    val_labels,\n","    save_path=None\n","  )\n","\n","  test_metrics = model_performance(\n","    model,\n","    {\"inclusion\": test_data_incl, \"exclusion\": test_data_excl, \n","    \"diseases\": test_data_diseases, \"targets\": test_data_target,\n","    \"meta\": test_data_meta},\n","    test_labels,\n","    save_path=None\n","  )\n","\n"," 
 val_F1[i] = val_metrics[\"F1\"]\n","  val_acc[i] = val_metrics[\"accuracy\"]\n","  val_pr_auc[i] = val_metrics[\"pr_auc\"]\n","  val_roc_auc[i] = val_metrics[\"roc_auc\"]\n","\n","  test_F1[i] = test_metrics[\"F1\"]\n","  test_acc[i] = test_metrics[\"accuracy\"]\n","  test_pr_auc[i] = test_metrics[\"pr_auc\"]\n","  test_roc_auc[i] = test_metrics[\"roc_auc\"]"],"metadata":{"id":"jzs9CyMZpMlh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["val_F1_mean, val_F1_sd = mean_and_sd(val_F1)\n","val_acc_mean, val_acc_sd = mean_and_sd(val_acc)\n","val_pr_auc_mean, val_pr_auc_sd = mean_and_sd(val_pr_auc)\n","val_roc_auc_mean, val_roc_auc_sd = mean_and_sd(val_roc_auc)\n","\n","test_F1_mean, test_F1_sd = mean_and_sd(test_F1)\n","test_acc_mean, test_acc_sd = mean_and_sd(test_acc)\n","test_pr_auc_mean, test_pr_auc_sd = mean_and_sd(test_pr_auc)\n","test_roc_auc_mean, test_roc_auc_sd = mean_and_sd(test_roc_auc)"],"metadata":{"id":"CNHR5bcJpMll"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(f\"validation F1: {round(val_F1_mean, 3)} +/- {round(val_F1_sd, 3)}\")\n","print(f\"test F1: {round(test_F1_mean, 3)} +/- {round(test_F1_sd, 3)}\")\n","\n","print(f\"validation accuracy: {round(val_acc_mean, 3)} +/- {round(val_acc_sd, 3)}\")\n","print(f\"test accuracy: {round(test_acc_mean, 3)} +/- {round(test_acc_sd, 3)}\")\n","\n","print(f\"validation PR-AUC: {round(val_pr_auc_mean, 3)} +/- {round(val_pr_auc_sd, 3)}\")\n","print(f\"test PR-AUC: {round(test_pr_auc_mean, 3)} +/- {round(test_pr_auc_sd, 3)}\")\n","\n","print(f\"validation ROC-AUC: {round(val_roc_auc_mean, 3)} +/- {round(val_roc_auc_sd, 3)}\")\n","print(f\"test ROC-AUC: {round(test_roc_auc_mean, 3)} +/- {round(test_roc_auc_sd, 3)}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"oJx1NwE3pMlo","executionInfo":{"status":"ok","timestamp":1651529173009,"user_tz":240,"elapsed":17,"user":{"displayName":"Benedikt 
Geiger","userId":"17925887631246406508"}},"outputId":"3dcdd25f-42a1-4e50-957c-d7897af4ca4e"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["validation F1: 0.787 +/- 0.014\n","test F1: 0.83 +/- 0.027\n","validation accuracy: 0.683 +/- 0.01\n","test accuracy: 0.729 +/- 0.028\n","validation PR-AUC: 0.802 +/- 0.012\n","test PR-AUC: 0.842 +/- 0.007\n","validation ROC-AUC: 0.679 +/- 0.011\n","test ROC-AUC: 0.667 +/- 0.015\n"]}]},{"cell_type":"markdown","source":["### Without meta data"],"metadata":{"id":"qWjp7FlTrTWg"}},{"cell_type":"code","source":["n_runs = 30"],"metadata":{"id":"7V4NvDTUrXvz"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["val_F1 = np.zeros(n_runs, dtype=float)\n","val_acc = np.zeros(n_runs, dtype=float)\n","val_pr_auc = np.zeros(n_runs, dtype=float)\n","val_roc_auc = np.zeros(n_runs, dtype=float)\n","\n","test_F1 = np.zeros(n_runs, dtype=float)\n","test_acc = np.zeros(n_runs, dtype=float)\n","test_pr_auc = np.zeros(n_runs, dtype=float)\n","test_roc_auc = np.zeros(n_runs, dtype=float)"],"metadata":{"id":"Uk51fO2graji"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["for i in range(n_runs):\n","\n","  # inclusion and exclusion criteria\n","  inclusion_input = layers.Input(shape=(incl_criteria_dim,), name=\"inclusion\")\n","  inclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_inclusion\")(inclusion_input)\n","\n","  exclusion_input = layers.Input(shape=(excl_criteria_dim,), name=\"exclusion\")\n","  exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_exclusion\")(exclusion_input)\n","\n","  inclusion_exclusion_raw = layers.Concatenate(name=\"criteria\")([inclusion_emb, exclusion_emb])\n","  ie_dropout = layers.Dropout(rate=0.32)(inclusion_exclusion_raw)\n","  inclusion_exclusion_emb = layers.Dense(256, activation=\"relu\", name=\"lower_dim_criteria\")(ie_dropout)\n","\n","\n","  # diseases\n","  diseases_input = 
layers.Input(shape=(diseases_dim,), name=\"diseases\")\n","  diseases_emb = layers.Dense(128, activation=\"relu\", name=\"lower_dim_diseases\")(diseases_input)\n","\n","\n","  # drug\n","  drug_input = layers.Input(shape=(drug_dim,), name=\"drug\")\n","  drug_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_drug\")(drug_input)\n","\n","  # targets\n","  targets_input = layers.Input(shape=(targets_dim,), name=\"targets\")\n","  targets_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_targets\")(targets_input)\n","\n","  # drug-target interaction\n","  drug_target_raw = layers.Concatenate(name=\"drug-targets\")([drug_emb, targets_emb])\n","  dt = layers.Dropout(rate=0.32)(drug_target_raw)\n","  drug_target_emb = layers.Dense(64, activation=\"relu\", name=\"lower_dim_drug-targets\")(dt)\n","\n","\n","  all_emb = layers.Concatenate(name=\"all_embeddings\")([inclusion_exclusion_emb, diseases_emb, drug_target_emb])\n","  ae = layers.Dropout(rate=0.32)(all_emb)\n","  trial_embedding1 = layers.Dense(128, activation=\"relu\", name=\"trial_embedding_1\")(ae)\n","  trial_embedding2 = layers.Dense(64, activation=\"relu\", name=\"trial_embedding_2\")(trial_embedding1)\n","  trial_embedding3 = layers.Dense(32, activation=\"relu\", name=\"trial_embedding_3\")(trial_embedding2)\n","\n","  o = layers.Dense(1, activation=\"sigmoid\", name=\"trial_success\")(trial_embedding3)\n","\n","  model = tf.keras.Model(inputs=[inclusion_input, exclusion_input, diseases_input, drug_input, targets_input], outputs=[o])\n","  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss=\"binary_crossentropy\", metrics=[\"accuracy\", \"AUC\"])\n","\n","  history = model.fit(\n","    x={\"inclusion\": train_data_incl, \"exclusion\": train_data_excl,\n","       \"drug\": train_data_drug, \"diseases\": train_data_diseases, \n","       \"targets\": train_data_target},\n","    y={\"trial_success\": train_labels},\n","    validation_data=(\n","        {\"inclusion\": 
val_data_incl, \"exclusion\": val_data_excl,\n","         \"drug\": val_data_drug, \"diseases\": val_data_diseases, \n","         \"targets\": val_data_target\n","        },\n","        {\"trial_success\": val_labels}\n","    ), \n","    epochs=5,\n","    batch_size=128\n","  )\n","\n","\n","  val_metrics = model_performance(\n","    model,\n","    {\"inclusion\": val_data_incl, \"exclusion\": val_data_excl,\n","     \"drug\": val_data_drug, \"diseases\": val_data_diseases, \n","     \"targets\": val_data_target},\n","    val_labels,\n","    save_path=None\n","  )\n","\n","  test_metrics = model_performance(\n","    model,\n","    {\"inclusion\": test_data_incl, \"exclusion\": test_data_excl,\n","     \"drug\": test_data_drug, \"diseases\": test_data_diseases, \n","     \"targets\": test_data_target},\n","    test_labels,\n","    save_path=None\n","  )\n","\n","  val_F1[i] = val_metrics[\"F1\"]\n","  val_acc[i] = val_metrics[\"accuracy\"]\n","  val_pr_auc[i] = val_metrics[\"pr_auc\"]\n","  val_roc_auc[i] = val_metrics[\"roc_auc\"]\n","\n","  test_F1[i] = test_metrics[\"F1\"]\n","  test_acc[i] = test_metrics[\"accuracy\"]\n","  test_pr_auc[i] = test_metrics[\"pr_auc\"]\n","  test_roc_auc[i] = test_metrics[\"roc_auc\"]"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"eTOicaYtrfD5","executionInfo":{"status":"ok","timestamp":1651530561600,"user_tz":240,"elapsed":1191457,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"52aa55d1-02e5-46cb-853a-e3f9dbe38ad3"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch 1/5\n","25/25 [==============================] - 7s 221ms/step - loss: 2.8720 - accuracy: 0.5957 - auc: 0.5178 - val_loss: 0.6364 - val_accuracy: 0.6802 - val_auc: 0.5521\n","Epoch 2/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6398 - accuracy: 0.6739 - auc: 0.5995 - val_loss: 0.6445 - val_accuracy: 0.6831 - val_auc: 0.6133\n","Epoch 3/5\n","25/25 
[==============================] - 5s 202ms/step - loss: 0.6252 - accuracy: 0.6723 - auc: 0.6217 - val_loss: 0.6202 - val_accuracy: 0.6628 - val_auc: 0.6541\n","Epoch 4/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.6048 - accuracy: 0.6732 - auc: 0.6838 - val_loss: 0.5979 - val_accuracy: 0.6773 - val_auc: 0.6830\n","Epoch 5/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.5752 - accuracy: 0.6871 - auc: 0.7254 - val_loss: 0.6089 - val_accuracy: 0.6802 - val_auc: 0.6766\n","Epoch 1/5\n","25/25 [==============================] - 7s 228ms/step - loss: 2.2530 - accuracy: 0.5798 - auc: 0.5297 - val_loss: 0.6348 - val_accuracy: 0.6570 - val_auc: 0.5560\n","Epoch 2/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6528 - accuracy: 0.6690 - auc: 0.6302 - val_loss: 0.6147 - val_accuracy: 0.6802 - val_auc: 0.6401\n","Epoch 3/5\n","25/25 [==============================] - 6s 223ms/step - loss: 0.5987 - accuracy: 0.6787 - auc: 0.6730 - val_loss: 0.6055 - val_accuracy: 0.6686 - val_auc: 0.6477\n","Epoch 4/5\n","25/25 [==============================] - 5s 205ms/step - loss: 0.5866 - accuracy: 0.6862 - auc: 0.6983 - val_loss: 0.6094 - val_accuracy: 0.6744 - val_auc: 0.6683\n","Epoch 5/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.5718 - accuracy: 0.7004 - auc: 0.7258 - val_loss: 0.6061 - val_accuracy: 0.6890 - val_auc: 0.6649\n","Epoch 1/5\n","25/25 [==============================] - 7s 213ms/step - loss: 4.1644 - accuracy: 0.5960 - auc: 0.5108 - val_loss: 0.6416 - val_accuracy: 0.6686 - val_auc: 0.5058\n","Epoch 2/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6529 - accuracy: 0.6622 - auc: 0.5624 - val_loss: 0.6288 - val_accuracy: 0.6686 - val_auc: 0.6495\n","Epoch 3/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.6168 - accuracy: 0.6758 - auc: 0.6300 - val_loss: 0.5982 - val_accuracy: 0.6831 - val_auc: 0.6803\n","Epoch 4/5\n","25/25 
[==============================] - 5s 200ms/step - loss: 0.5979 - accuracy: 0.6810 - auc: 0.6803 - val_loss: 0.5886 - val_accuracy: 0.6773 - val_auc: 0.6971\n","Epoch 5/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.5944 - accuracy: 0.6823 - auc: 0.6990 - val_loss: 0.6495 - val_accuracy: 0.6628 - val_auc: 0.6731\n","Epoch 1/5\n","25/25 [==============================] - 7s 222ms/step - loss: 3.5438 - accuracy: 0.5727 - auc: 0.5115 - val_loss: 0.6754 - val_accuracy: 0.6570 - val_auc: 0.5285\n","Epoch 2/5\n","25/25 [==============================] - 5s 209ms/step - loss: 0.6461 - accuracy: 0.6652 - auc: 0.6196 - val_loss: 0.5988 - val_accuracy: 0.6773 - val_auc: 0.6633\n","Epoch 3/5\n","25/25 [==============================] - 5s 207ms/step - loss: 0.6084 - accuracy: 0.6794 - auc: 0.6943 - val_loss: 0.5917 - val_accuracy: 0.6744 - val_auc: 0.6721\n","Epoch 4/5\n","25/25 [==============================] - 5s 208ms/step - loss: 0.5733 - accuracy: 0.6997 - auc: 0.7280 - val_loss: 0.5934 - val_accuracy: 0.6831 - val_auc: 0.6758\n","Epoch 5/5\n","25/25 [==============================] - 5s 212ms/step - loss: 0.5467 - accuracy: 0.7127 - auc: 0.7653 - val_loss: 0.5903 - val_accuracy: 0.7093 - val_auc: 0.6985\n","Epoch 1/5\n","25/25 [==============================] - 7s 223ms/step - loss: 2.9898 - accuracy: 0.5886 - auc: 0.5237 - val_loss: 0.6816 - val_accuracy: 0.6628 - val_auc: 0.5093\n","Epoch 2/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.6733 - accuracy: 0.6522 - auc: 0.5819 - val_loss: 0.6484 - val_accuracy: 0.6657 - val_auc: 0.5529\n","Epoch 3/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6358 - accuracy: 0.6677 - auc: 0.6220 - val_loss: 0.6174 - val_accuracy: 0.6715 - val_auc: 0.6534\n","Epoch 4/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.6154 - accuracy: 0.6687 - auc: 0.6667 - val_loss: 0.5978 - val_accuracy: 0.6773 - val_auc: 0.6751\n","Epoch 5/5\n","25/25 
[==============================] - 5s 198ms/step - loss: 0.6034 - accuracy: 0.6745 - auc: 0.6911 - val_loss: 0.6256 - val_accuracy: 0.6715 - val_auc: 0.7009\n","Epoch 1/5\n","25/25 [==============================] - 33s 221ms/step - loss: 2.8836 - accuracy: 0.5866 - auc: 0.5262 - val_loss: 0.6612 - val_accuracy: 0.6424 - val_auc: 0.6132\n","Epoch 2/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.6504 - accuracy: 0.6412 - auc: 0.5913 - val_loss: 0.6649 - val_accuracy: 0.6657 - val_auc: 0.5088\n","Epoch 3/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6178 - accuracy: 0.6723 - auc: 0.6478 - val_loss: 0.5954 - val_accuracy: 0.6831 - val_auc: 0.6776\n","Epoch 4/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.5925 - accuracy: 0.6823 - auc: 0.6983 - val_loss: 0.6142 - val_accuracy: 0.6831 - val_auc: 0.6552\n","Epoch 5/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.5782 - accuracy: 0.6871 - auc: 0.7194 - val_loss: 0.5957 - val_accuracy: 0.6773 - val_auc: 0.6703\n","Epoch 1/5\n","25/25 [==============================] - 7s 224ms/step - loss: 2.4289 - accuracy: 0.5902 - auc: 0.5076 - val_loss: 0.7202 - val_accuracy: 0.5552 - val_auc: 0.5643\n","Epoch 2/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6485 - accuracy: 0.6487 - auc: 0.6008 - val_loss: 0.6181 - val_accuracy: 0.6657 - val_auc: 0.6296\n","Epoch 3/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6157 - accuracy: 0.6755 - auc: 0.6414 - val_loss: 0.6223 - val_accuracy: 0.6773 - val_auc: 0.6396\n","Epoch 4/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.5965 - accuracy: 0.6836 - auc: 0.6816 - val_loss: 0.5999 - val_accuracy: 0.6802 - val_auc: 0.6619\n","Epoch 5/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.5775 - accuracy: 0.6878 - auc: 0.7153 - val_loss: 0.5888 - val_accuracy: 0.6831 - val_auc: 0.6833\n","Epoch 1/5\n","25/25 
[==============================] - 7s 212ms/step - loss: 4.5040 - accuracy: 0.5960 - auc: 0.5215 - val_loss: 0.6671 - val_accuracy: 0.6657 - val_auc: 0.4863\n","Epoch 2/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6476 - accuracy: 0.6661 - auc: 0.5737 - val_loss: 0.6202 - val_accuracy: 0.6831 - val_auc: 0.6112\n","Epoch 3/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6220 - accuracy: 0.6813 - auc: 0.6296 - val_loss: 0.6093 - val_accuracy: 0.6831 - val_auc: 0.6521\n","Epoch 4/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.5944 - accuracy: 0.6826 - auc: 0.6877 - val_loss: 0.6211 - val_accuracy: 0.6831 - val_auc: 0.6859\n","Epoch 5/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.5684 - accuracy: 0.7023 - auc: 0.7407 - val_loss: 0.7594 - val_accuracy: 0.6831 - val_auc: 0.6877\n","Epoch 1/5\n","25/25 [==============================] - 7s 231ms/step - loss: 4.1338 - accuracy: 0.5876 - auc: 0.5455 - val_loss: 0.7275 - val_accuracy: 0.6802 - val_auc: 0.5735\n","Epoch 2/5\n","25/25 [==============================] - 5s 221ms/step - loss: 0.6829 - accuracy: 0.6610 - auc: 0.5789 - val_loss: 0.6488 - val_accuracy: 0.6773 - val_auc: 0.6175\n","Epoch 3/5\n","25/25 [==============================] - 9s 352ms/step - loss: 0.6134 - accuracy: 0.6742 - auc: 0.6668 - val_loss: 0.5989 - val_accuracy: 0.6802 - val_auc: 0.6738\n","Epoch 4/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.5915 - accuracy: 0.6862 - auc: 0.7027 - val_loss: 0.5909 - val_accuracy: 0.6773 - val_auc: 0.6913\n","Epoch 5/5\n","25/25 [==============================] - 5s 204ms/step - loss: 0.5916 - accuracy: 0.6965 - auc: 0.7269 - val_loss: 0.6117 - val_accuracy: 0.6890 - val_auc: 0.6905\n","Epoch 1/5\n","25/25 [==============================] - 7s 233ms/step - loss: 2.5749 - accuracy: 0.5805 - auc: 0.5319 - val_loss: 0.6398 - val_accuracy: 0.6831 - val_auc: 0.6407\n","Epoch 2/5\n","25/25 
[==============================] - 6s 237ms/step - loss: 0.6370 - accuracy: 0.6707 - auc: 0.6048 - val_loss: 0.6161 - val_accuracy: 0.6890 - val_auc: 0.6377\n","Epoch 3/5\n","25/25 [==============================] - 5s 215ms/step - loss: 0.6032 - accuracy: 0.6829 - auc: 0.6664 - val_loss: 0.5961 - val_accuracy: 0.6860 - val_auc: 0.6715\n","Epoch 4/5\n","25/25 [==============================] - 5s 212ms/step - loss: 0.5931 - accuracy: 0.6749 - auc: 0.6835 - val_loss: 0.6027 - val_accuracy: 0.6773 - val_auc: 0.6706\n","Epoch 5/5\n","25/25 [==============================] - 6s 226ms/step - loss: 0.5835 - accuracy: 0.6765 - auc: 0.7150 - val_loss: 0.5963 - val_accuracy: 0.6831 - val_auc: 0.6802\n","Epoch 1/5\n","25/25 [==============================] - 8s 219ms/step - loss: 3.4536 - accuracy: 0.5911 - auc: 0.5146 - val_loss: 0.6735 - val_accuracy: 0.6599 - val_auc: 0.5618\n","Epoch 2/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6388 - accuracy: 0.6606 - auc: 0.5981 - val_loss: 0.6432 - val_accuracy: 0.6657 - val_auc: 0.5528\n","Epoch 3/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.6382 - accuracy: 0.6616 - auc: 0.6246 - val_loss: 0.6080 - val_accuracy: 0.6802 - val_auc: 0.6459\n","Epoch 4/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.5926 - accuracy: 0.6842 - auc: 0.6899 - val_loss: 0.5954 - val_accuracy: 0.6890 - val_auc: 0.6902\n","Epoch 5/5\n","25/25 [==============================] - 5s 203ms/step - loss: 0.5673 - accuracy: 0.6917 - auc: 0.7375 - val_loss: 0.5926 - val_accuracy: 0.6919 - val_auc: 0.6896\n","Epoch 1/5\n","25/25 [==============================] - 7s 214ms/step - loss: 2.6731 - accuracy: 0.5795 - auc: 0.5031 - val_loss: 0.6879 - val_accuracy: 0.6715 - val_auc: 0.5343\n","Epoch 2/5\n","25/25 [==============================] - 5s 205ms/step - loss: 0.6630 - accuracy: 0.6571 - auc: 0.5982 - val_loss: 0.6247 - val_accuracy: 0.6831 - val_auc: 0.5846\n","Epoch 3/5\n","25/25 
[==============================] - 5s 202ms/step - loss: 0.6162 - accuracy: 0.6726 - auc: 0.6483 - val_loss: 0.6007 - val_accuracy: 0.6890 - val_auc: 0.6618\n","Epoch 4/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.5995 - accuracy: 0.6833 - auc: 0.6810 - val_loss: 0.6026 - val_accuracy: 0.6831 - val_auc: 0.6724\n","Epoch 5/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.5658 - accuracy: 0.6984 - auc: 0.7394 - val_loss: 0.5948 - val_accuracy: 0.6890 - val_auc: 0.6843\n","Epoch 1/5\n","25/25 [==============================] - 7s 221ms/step - loss: 2.6600 - accuracy: 0.5417 - auc: 0.5000 - val_loss: 0.6164 - val_accuracy: 0.6715 - val_auc: 0.6074\n","Epoch 2/5\n","25/25 [==============================] - 5s 208ms/step - loss: 0.6761 - accuracy: 0.6484 - auc: 0.5843 - val_loss: 0.6438 - val_accuracy: 0.6628 - val_auc: 0.6369\n","Epoch 3/5\n","25/25 [==============================] - 5s 195ms/step - loss: 0.6233 - accuracy: 0.6639 - auc: 0.6448 - val_loss: 0.6349 - val_accuracy: 0.6657 - val_auc: 0.6229\n","Epoch 4/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.5971 - accuracy: 0.6800 - auc: 0.6992 - val_loss: 0.6054 - val_accuracy: 0.6773 - val_auc: 0.6791\n","Epoch 5/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.5664 - accuracy: 0.6994 - auc: 0.7372 - val_loss: 0.6021 - val_accuracy: 0.6948 - val_auc: 0.6888\n","Epoch 1/5\n","25/25 [==============================] - 7s 219ms/step - loss: 2.5732 - accuracy: 0.5779 - auc: 0.5090 - val_loss: 0.6879 - val_accuracy: 0.6686 - val_auc: 0.5807\n","Epoch 2/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.6411 - accuracy: 0.6645 - auc: 0.5978 - val_loss: 0.6221 - val_accuracy: 0.6773 - val_auc: 0.6302\n","Epoch 3/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.6049 - accuracy: 0.6755 - auc: 0.6682 - val_loss: 0.6131 - val_accuracy: 0.6802 - val_auc: 0.6627\n","Epoch 4/5\n","25/25 
[==============================] - 5s 199ms/step - loss: 0.6021 - accuracy: 0.6816 - auc: 0.6863 - val_loss: 0.5960 - val_accuracy: 0.6715 - val_auc: 0.6727\n","Epoch 5/5\n","25/25 [==============================] - 5s 203ms/step - loss: 0.5743 - accuracy: 0.6910 - auc: 0.7273 - val_loss: 0.6088 - val_accuracy: 0.6657 - val_auc: 0.6730\n","Epoch 1/5\n","25/25 [==============================] - 7s 219ms/step - loss: 3.8925 - accuracy: 0.5604 - auc: 0.4935 - val_loss: 0.6604 - val_accuracy: 0.6744 - val_auc: 0.5165\n","Epoch 2/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6801 - accuracy: 0.6526 - auc: 0.5695 - val_loss: 0.6119 - val_accuracy: 0.6802 - val_auc: 0.6244\n","Epoch 3/5\n","25/25 [==============================] - 5s 208ms/step - loss: 0.6338 - accuracy: 0.6684 - auc: 0.6405 - val_loss: 0.6125 - val_accuracy: 0.6802 - val_auc: 0.6526\n","Epoch 4/5\n","25/25 [==============================] - 5s 207ms/step - loss: 0.6058 - accuracy: 0.6716 - auc: 0.6673 - val_loss: 0.6034 - val_accuracy: 0.6773 - val_auc: 0.6758\n","Epoch 5/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.5824 - accuracy: 0.6875 - auc: 0.7059 - val_loss: 0.6517 - val_accuracy: 0.6860 - val_auc: 0.6799\n","Epoch 1/5\n","25/25 [==============================] - 7s 226ms/step - loss: 4.9008 - accuracy: 0.5663 - auc: 0.5274 - val_loss: 0.7584 - val_accuracy: 0.5698 - val_auc: 0.5564\n","Epoch 2/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.7499 - accuracy: 0.6454 - auc: 0.5571 - val_loss: 0.6380 - val_accuracy: 0.6657 - val_auc: 0.5872\n","Epoch 3/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.7764 - accuracy: 0.6496 - auc: 0.5699 - val_loss: 0.7585 - val_accuracy: 0.6308 - val_auc: 0.6328\n","Epoch 4/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.6173 - accuracy: 0.6855 - auc: 0.6681 - val_loss: 0.6167 - val_accuracy: 0.6890 - val_auc: 0.6697\n","Epoch 5/5\n","25/25 
[==============================] - 5s 196ms/step - loss: 0.6650 - accuracy: 0.6752 - auc: 0.6717 - val_loss: 0.7828 - val_accuracy: 0.6512 - val_auc: 0.6817\n","Epoch 1/5\n","25/25 [==============================] - 7s 214ms/step - loss: 1.7056 - accuracy: 0.5915 - auc: 0.5269 - val_loss: 0.7462 - val_accuracy: 0.5610 - val_auc: 0.6078\n","Epoch 2/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6439 - accuracy: 0.6707 - auc: 0.5802 - val_loss: 0.6226 - val_accuracy: 0.6831 - val_auc: 0.6194\n","Epoch 3/5\n","25/25 [==============================] - 5s 205ms/step - loss: 0.6154 - accuracy: 0.6784 - auc: 0.6438 - val_loss: 0.6133 - val_accuracy: 0.6773 - val_auc: 0.6632\n","Epoch 4/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.5876 - accuracy: 0.6836 - auc: 0.6993 - val_loss: 0.6013 - val_accuracy: 0.6802 - val_auc: 0.6765\n","Epoch 5/5\n","25/25 [==============================] - 5s 206ms/step - loss: 0.5599 - accuracy: 0.7046 - auc: 0.7466 - val_loss: 0.6355 - val_accuracy: 0.6831 - val_auc: 0.6735\n","Epoch 1/5\n","25/25 [==============================] - 7s 217ms/step - loss: 3.0620 - accuracy: 0.5672 - auc: 0.5159 - val_loss: 0.6734 - val_accuracy: 0.6744 - val_auc: 0.5102\n","Epoch 2/5\n","25/25 [==============================] - 5s 207ms/step - loss: 0.6898 - accuracy: 0.6606 - auc: 0.5807 - val_loss: 0.7096 - val_accuracy: 0.6715 - val_auc: 0.5750\n","Epoch 3/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.6862 - accuracy: 0.6600 - auc: 0.5845 - val_loss: 0.6150 - val_accuracy: 0.6773 - val_auc: 0.6545\n","Epoch 4/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6260 - accuracy: 0.6710 - auc: 0.6603 - val_loss: 0.6002 - val_accuracy: 0.6831 - val_auc: 0.6587\n","Epoch 5/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.5950 - accuracy: 0.6807 - auc: 0.6976 - val_loss: 0.5901 - val_accuracy: 0.6802 - val_auc: 0.6872\n","Epoch 1/5\n","25/25 
[==============================] - 7s 216ms/step - loss: 3.9171 - accuracy: 0.5856 - auc: 0.5272 - val_loss: 0.7199 - val_accuracy: 0.6657 - val_auc: 0.4941\n","Epoch 2/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.7404 - accuracy: 0.6632 - auc: 0.5929 - val_loss: 0.6110 - val_accuracy: 0.6686 - val_auc: 0.6590\n","Epoch 3/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.6375 - accuracy: 0.6703 - auc: 0.6658 - val_loss: 0.6786 - val_accuracy: 0.6599 - val_auc: 0.6700\n","Epoch 4/5\n","25/25 [==============================] - 5s 204ms/step - loss: 0.6918 - accuracy: 0.6749 - auc: 0.6590 - val_loss: 0.6109 - val_accuracy: 0.6831 - val_auc: 0.6824\n","Epoch 5/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.6019 - accuracy: 0.6994 - auc: 0.7352 - val_loss: 0.6349 - val_accuracy: 0.6657 - val_auc: 0.6710\n","Epoch 1/5\n","25/25 [==============================] - 7s 217ms/step - loss: 3.0466 - accuracy: 0.5769 - auc: 0.5065 - val_loss: 0.8265 - val_accuracy: 0.6657 - val_auc: 0.4470\n","Epoch 2/5\n","25/25 [==============================] - 5s 204ms/step - loss: 0.6797 - accuracy: 0.6671 - auc: 0.5468 - val_loss: 0.6872 - val_accuracy: 0.6628 - val_auc: 0.4842\n","Epoch 3/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.6483 - accuracy: 0.6739 - auc: 0.5667 - val_loss: 0.6231 - val_accuracy: 0.6686 - val_auc: 0.6074\n","Epoch 4/5\n","25/25 [==============================] - 5s 194ms/step - loss: 0.6126 - accuracy: 0.6749 - auc: 0.6345 - val_loss: 0.6163 - val_accuracy: 0.6802 - val_auc: 0.6312\n","Epoch 5/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.5961 - accuracy: 0.6813 - auc: 0.6799 - val_loss: 0.5989 - val_accuracy: 0.6831 - val_auc: 0.6853\n","Epoch 1/5\n","25/25 [==============================] - 7s 219ms/step - loss: 4.5273 - accuracy: 0.5921 - auc: 0.5254 - val_loss: 0.7853 - val_accuracy: 0.6657 - val_auc: 0.4446\n","Epoch 2/5\n","25/25 
[==============================] - 5s 209ms/step - loss: 0.7613 - accuracy: 0.6587 - auc: 0.5443 - val_loss: 0.6757 - val_accuracy: 0.6773 - val_auc: 0.6110\n","Epoch 3/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6511 - accuracy: 0.6719 - auc: 0.6004 - val_loss: 0.6433 - val_accuracy: 0.6831 - val_auc: 0.6267\n","Epoch 4/5\n","25/25 [==============================] - 5s 201ms/step - loss: 0.6152 - accuracy: 0.6739 - auc: 0.6647 - val_loss: 0.6606 - val_accuracy: 0.6802 - val_auc: 0.6514\n","Epoch 5/5\n","25/25 [==============================] - 5s 204ms/step - loss: 0.6104 - accuracy: 0.6913 - auc: 0.7186 - val_loss: 0.6421 - val_accuracy: 0.6686 - val_auc: 0.6477\n","Epoch 1/5\n","25/25 [==============================] - 8s 219ms/step - loss: 3.7306 - accuracy: 0.5844 - auc: 0.5119 - val_loss: 0.6445 - val_accuracy: 0.6744 - val_auc: 0.5951\n","Epoch 2/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6539 - accuracy: 0.6642 - auc: 0.5985 - val_loss: 0.6208 - val_accuracy: 0.6773 - val_auc: 0.6324\n","Epoch 3/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6287 - accuracy: 0.6752 - auc: 0.6417 - val_loss: 0.6222 - val_accuracy: 0.6773 - val_auc: 0.6464\n","Epoch 4/5\n","25/25 [==============================] - 5s 199ms/step - loss: 0.6140 - accuracy: 0.6803 - auc: 0.6796 - val_loss: 0.6114 - val_accuracy: 0.6831 - val_auc: 0.6637\n","Epoch 5/5\n","25/25 [==============================] - 5s 216ms/step - loss: 0.5772 - accuracy: 0.6952 - auc: 0.7249 - val_loss: 0.5922 - val_accuracy: 0.6860 - val_auc: 0.6866\n","Epoch 1/5\n","25/25 [==============================] - 7s 223ms/step - loss: 4.5425 - accuracy: 0.5847 - auc: 0.5176 - val_loss: 0.6939 - val_accuracy: 0.6831 - val_auc: 0.5232\n","Epoch 2/5\n","25/25 [==============================] - 5s 205ms/step - loss: 0.6923 - accuracy: 0.6519 - auc: 0.5525 - val_loss: 0.6322 - val_accuracy: 0.6831 - val_auc: 0.5536\n","Epoch 3/5\n","25/25 
[==============================] - 5s 201ms/step - loss: 0.6402 - accuracy: 0.6765 - auc: 0.5879 - val_loss: 0.6132 - val_accuracy: 0.6831 - val_auc: 0.6510\n","Epoch 4/5\n","25/25 [==============================] - 5s 196ms/step - loss: 0.6011 - accuracy: 0.6736 - auc: 0.6726 - val_loss: 0.5867 - val_accuracy: 0.6773 - val_auc: 0.6796\n","Epoch 5/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.5907 - accuracy: 0.6875 - auc: 0.6961 - val_loss: 0.5903 - val_accuracy: 0.6744 - val_auc: 0.6831\n","Epoch 1/5\n","25/25 [==============================] - 8s 260ms/step - loss: 3.9043 - accuracy: 0.5656 - auc: 0.5284 - val_loss: 0.7677 - val_accuracy: 0.6657 - val_auc: 0.5425\n","Epoch 2/5\n","25/25 [==============================] - 6s 248ms/step - loss: 0.6546 - accuracy: 0.6574 - auc: 0.6255 - val_loss: 0.6155 - val_accuracy: 0.6802 - val_auc: 0.6242\n","Epoch 3/5\n","25/25 [==============================] - 6s 235ms/step - loss: 0.6418 - accuracy: 0.6781 - auc: 0.6610 - val_loss: 0.7782 - val_accuracy: 0.6657 - val_auc: 0.5988\n","Epoch 4/5\n","25/25 [==============================] - 6s 235ms/step - loss: 0.6388 - accuracy: 0.6884 - auc: 0.6731 - val_loss: 0.6123 - val_accuracy: 0.6686 - val_auc: 0.6650\n","Epoch 5/5\n","25/25 [==============================] - 6s 238ms/step - loss: 0.5754 - accuracy: 0.7101 - auc: 0.7312 - val_loss: 0.6156 - val_accuracy: 0.6802 - val_auc: 0.6607\n","Epoch 1/5\n","25/25 [==============================] - 7s 220ms/step - loss: 3.8313 - accuracy: 0.5983 - auc: 0.5299 - val_loss: 0.6652 - val_accuracy: 0.6366 - val_auc: 0.6259\n","Epoch 2/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6979 - accuracy: 0.6616 - auc: 0.5947 - val_loss: 0.6335 - val_accuracy: 0.6599 - val_auc: 0.6156\n","Epoch 3/5\n","25/25 [==============================] - 5s 210ms/step - loss: 0.6567 - accuracy: 0.6629 - auc: 0.6154 - val_loss: 0.6761 - val_accuracy: 0.6628 - val_auc: 0.6290\n","Epoch 4/5\n","25/25 
[==============================] - 5s 207ms/step - loss: 0.6218 - accuracy: 0.6823 - auc: 0.6780 - val_loss: 0.6184 - val_accuracy: 0.6831 - val_auc: 0.6506\n","Epoch 5/5\n","25/25 [==============================] - 5s 213ms/step - loss: 0.5812 - accuracy: 0.6939 - auc: 0.7180 - val_loss: 0.5961 - val_accuracy: 0.6802 - val_auc: 0.6729\n","Epoch 1/5\n","25/25 [==============================] - 8s 241ms/step - loss: 2.5413 - accuracy: 0.6228 - auc: 0.5427 - val_loss: 0.6404 - val_accuracy: 0.6831 - val_auc: 0.5618\n","Epoch 2/5\n","25/25 [==============================] - 5s 203ms/step - loss: 0.6500 - accuracy: 0.6568 - auc: 0.5697 - val_loss: 0.6176 - val_accuracy: 0.6831 - val_auc: 0.6197\n","Epoch 3/5\n","25/25 [==============================] - 5s 203ms/step - loss: 0.6111 - accuracy: 0.6833 - auc: 0.6404 - val_loss: 0.6153 - val_accuracy: 0.6831 - val_auc: 0.6523\n","Epoch 4/5\n","25/25 [==============================] - 5s 205ms/step - loss: 0.5867 - accuracy: 0.6829 - auc: 0.6977 - val_loss: 0.6094 - val_accuracy: 0.6948 - val_auc: 0.6777\n","Epoch 5/5\n","25/25 [==============================] - 5s 203ms/step - loss: 0.5532 - accuracy: 0.7133 - auc: 0.7542 - val_loss: 0.6215 - val_accuracy: 0.6686 - val_auc: 0.6757\n","Epoch 1/5\n","25/25 [==============================] - 8s 254ms/step - loss: 3.5067 - accuracy: 0.5569 - auc: 0.5000 - val_loss: 0.7088 - val_accuracy: 0.6744 - val_auc: 0.5994\n","Epoch 2/5\n","25/25 [==============================] - 5s 205ms/step - loss: 0.6726 - accuracy: 0.6687 - auc: 0.6145 - val_loss: 0.6164 - val_accuracy: 0.6831 - val_auc: 0.6198\n","Epoch 3/5\n","25/25 [==============================] - 5s 198ms/step - loss: 0.6123 - accuracy: 0.6784 - auc: 0.6532 - val_loss: 0.5993 - val_accuracy: 0.6919 - val_auc: 0.6564\n","Epoch 4/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6325 - accuracy: 0.6729 - auc: 0.6424 - val_loss: 0.6196 - val_accuracy: 0.6657 - val_auc: 0.6642\n","Epoch 5/5\n","25/25 
[==============================] - 5s 205ms/step - loss: 0.5922 - accuracy: 0.6823 - auc: 0.6986 - val_loss: 0.6073 - val_accuracy: 0.6802 - val_auc: 0.6899\n","Epoch 1/5\n","25/25 [==============================] - 7s 220ms/step - loss: 2.6486 - accuracy: 0.5989 - auc: 0.5402 - val_loss: 0.6307 - val_accuracy: 0.6773 - val_auc: 0.6406\n","Epoch 2/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.6253 - accuracy: 0.6781 - auc: 0.6328 - val_loss: 0.6429 - val_accuracy: 0.6860 - val_auc: 0.6678\n","Epoch 3/5\n","25/25 [==============================] - 5s 197ms/step - loss: 0.6027 - accuracy: 0.6871 - auc: 0.6918 - val_loss: 0.5823 - val_accuracy: 0.6919 - val_auc: 0.6977\n","Epoch 4/5\n","25/25 [==============================] - 5s 203ms/step - loss: 0.5689 - accuracy: 0.7020 - auc: 0.7297 - val_loss: 0.5980 - val_accuracy: 0.6802 - val_auc: 0.6912\n","Epoch 5/5\n","25/25 [==============================] - 5s 200ms/step - loss: 0.5483 - accuracy: 0.7227 - auc: 0.7707 - val_loss: 0.6037 - val_accuracy: 0.6773 - val_auc: 0.6863\n","Epoch 1/5\n","25/25 [==============================] - 7s 218ms/step - loss: 3.1665 - accuracy: 0.5882 - auc: 0.5263 - val_loss: 0.7622 - val_accuracy: 0.6715 - val_auc: 0.5948\n","Epoch 2/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.7056 - accuracy: 0.6574 - auc: 0.5758 - val_loss: 0.6127 - val_accuracy: 0.6890 - val_auc: 0.6552\n","Epoch 3/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6594 - accuracy: 0.6687 - auc: 0.6273 - val_loss: 0.6376 - val_accuracy: 0.6831 - val_auc: 0.6591\n","Epoch 4/5\n","25/25 [==============================] - 5s 203ms/step - loss: 0.6155 - accuracy: 0.6836 - auc: 0.6801 - val_loss: 0.6262 - val_accuracy: 0.6773 - val_auc: 0.6800\n","Epoch 5/5\n","25/25 [==============================] - 5s 208ms/step - loss: 0.5923 - accuracy: 0.6800 - auc: 0.7087 - val_loss: 0.5978 - val_accuracy: 0.6686 - val_auc: 0.6814\n","Epoch 1/5\n","25/25 
[==============================] - 7s 223ms/step - loss: 3.4884 - accuracy: 0.5850 - auc: 0.5119 - val_loss: 0.6640 - val_accuracy: 0.6570 - val_auc: 0.5234\n","Epoch 2/5\n","25/25 [==============================] - 5s 204ms/step - loss: 0.7236 - accuracy: 0.6464 - auc: 0.5523 - val_loss: 0.6302 - val_accuracy: 0.6715 - val_auc: 0.6432\n","Epoch 3/5\n","25/25 [==============================] - 5s 202ms/step - loss: 0.6131 - accuracy: 0.6703 - auc: 0.6536 - val_loss: 0.6320 - val_accuracy: 0.6686 - val_auc: 0.6223\n","Epoch 4/5\n","25/25 [==============================] - 5s 209ms/step - loss: 0.6318 - accuracy: 0.6703 - auc: 0.6489 - val_loss: 0.6045 - val_accuracy: 0.6831 - val_auc: 0.6806\n","Epoch 5/5\n","25/25 [==============================] - 5s 211ms/step - loss: 0.5921 - accuracy: 0.6816 - auc: 0.6978 - val_loss: 0.5951 - val_accuracy: 0.6860 - val_auc: 0.6858\n"]}]},{"cell_type":"code","source":["val_F1_mean, val_F1_sd = mean_and_sd(val_F1)\n","val_acc_mean, val_acc_sd = mean_and_sd(val_acc)\n","val_pr_auc_mean, val_pr_auc_sd = mean_and_sd(val_pr_auc)\n","val_roc_auc_mean, val_roc_auc_sd = mean_and_sd(val_roc_auc)\n","\n","test_F1_mean, test_F1_sd = mean_and_sd(test_F1)\n","test_acc_mean, test_acc_sd = mean_and_sd(test_acc)\n","test_pr_auc_mean, test_pr_auc_sd = mean_and_sd(test_pr_auc)\n","test_roc_auc_mean, test_roc_auc_sd = mean_and_sd(test_roc_auc)"],"metadata":{"id":"z-ueEJk4rjoj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(f\"validation F1: {round(val_F1_mean, 3)} +/- {round(val_F1_sd, 3)}\")\n","print(f\"test F1: {round(test_F1_mean, 3)} +/- {round(test_F1_sd, 3)}\")\n","\n","print(f\"validation accuracy: {round(val_acc_mean, 3)} +/- {round(val_acc_sd, 3)}\")\n","print(f\"test accuracy: {round(test_acc_mean, 3)} +/- {round(test_acc_sd, 3)}\")\n","\n","print(f\"validation PR-AUC: {round(val_pr_auc_mean, 3)} +/- {round(val_pr_auc_sd, 3)}\")\n","print(f\"test PR-AUC: {round(test_pr_auc_mean, 3)} +/- {round(test_pr_auc_sd, 
3)}\")\n","\n","print(f\"validation ROC-AUC: {round(val_roc_auc_mean, 3)} +/- {round(val_roc_auc_sd, 3)}\")\n","print(f\"test ROC-AUC: {round(test_roc_auc_mean, 3)} +/- {round(test_roc_auc_sd, 3)}\")"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"rlBVo3xdrno9","executionInfo":{"status":"ok","timestamp":1651530561603,"user_tz":240,"elapsed":13,"user":{"displayName":"Benedikt Geiger","userId":"17925887631246406508"}},"outputId":"2572a1c9-cb77-4436-accb-0e79c11e724f"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["validation F1: 0.787 +/- 0.015\n","test F1: 0.824 +/- 0.035\n","validation accuracy: 0.68 +/- 0.011\n","test accuracy: 0.721 +/- 0.035\n","validation PR-AUC: 0.802 +/- 0.008\n","test PR-AUC: 0.84 +/- 0.008\n","validation ROC-AUC: 0.681 +/- 0.011\n","test ROC-AUC: 0.654 +/- 0.019\n"]}]}]}