{
"cells": [
{
"cell_type": "markdown",
"id": "cf343c1c-ad8d-4fdb-a142-c501e579e288",
"metadata": {},
"source": [
"# Count Featurization And Models\n",
"\n",
"FEMR contains several utilities to implement common tabular featurization strategies.\n",
"\n",
"[CountFeaturizer](https://github.com/som-shahlab/femr/blob/main/src/femr/featurizers/featurizers.py#L180) is the main class, and its documentation describes the supported options.\n",
"\n",
"To use it, construct a `FeaturizerList`, preprocess the featurizers against your dataset and labels, and then featurize."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "892ab2d5-0c5a-43c9-a210-9201f775e4fb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 36758.28 examples/s]\n",
"Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 3295.78 examples/s]\n",
"Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 2998.20 examples/s]\n"
]
}
],
"source": [
"import datasets\n",
"import femr.featurizers\n",
"import femr.index\n",
"import pyarrow.csv\n",
"\n",
"# Load some labels\n",
"labels = pyarrow.csv.read_csv('input/labels.csv').to_pylist()\n",
"\n",
"# Load our data\n",
"dataset = datasets.Dataset.from_parquet(\"input/meds/data/*\")\n",
"\n",
"# We need to create an index to allow us to find patients quickly\n",
"index = femr.index.PatientIndex(dataset)\n",
"\n",
"# Define our featurizers\n",
"# Note that we are using both ages and counts here\n",
"age = femr.featurizers.AgeFeaturizer(is_normalize=False)\n",
"count = femr.featurizers.CountFeaturizer(string_value_combination=True)\n",
"featurizer_age_count = femr.featurizers.FeaturizerList([age, count])\n",
"\n",
"# Preprocess the featurizers. This fits any state they need from the data,\n",
"# such as age normalization statistics (normalization is disabled here via is_normalize=False).\n",
"featurizer_age_count.preprocess_featurizers(dataset, index, labels)\n",
"\n",
"# Actually do the featurization\n",
"features = featurizer_age_count.featurize(dataset, index, labels)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "112fe99d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"patient_ids (200,)\n",
"feature_times (200,)\n",
"features (200, 1884)\n"
]
}
],
"source": [
"# The result has three entries: the patient ids, the feature times, and the features themselves\n",
"\n",
"for k, v in features.items():\n",
"    print(k, v.shape)"
]
},
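{
"cell_type": "markdown",
"id": "toy-count-featurization-sketch",
"metadata": {},
"source": [
"To make the `features` matrix concrete, the following is a minimal toy sketch of what count featurization computes: one row per label and one column per distinct code, holding how often that code occurred up to the prediction time. It is not FEMR's implementation. The `events`, `label_times`, and `count_features` names are hypothetical, and counting events at or before the prediction time is an assumption made for illustration.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"from datetime import datetime\n",
"\n",
"import numpy as np\n",
"\n",
"# Hypothetical toy events: (patient_id, time, code). Not FEMR data structures.\n",
"events = [\n",
"    (1, datetime(2020, 1, 1), 'ICD10/E11'),\n",
"    (1, datetime(2020, 3, 1), 'ICD10/E11'),\n",
"    (1, datetime(2021, 1, 1), 'RxNorm/860975'),\n",
"    (2, datetime(2020, 2, 1), 'RxNorm/860975'),\n",
"]\n",
"# Hypothetical labels: (patient_id, prediction_time).\n",
"label_times = [(1, datetime(2020, 6, 1)), (2, datetime(2020, 6, 1))]\n",
"\n",
"# One column per distinct code, in sorted order.\n",
"vocab = {code: i for i, code in enumerate(sorted({c for _, _, c in events}))}\n",
"\n",
"\n",
"def count_features(patient_id, prediction_time):\n",
"    # Assumption: only events at or before the prediction time are counted,\n",
"    # so that no information from the future leaks into the features.\n",
"    counts = Counter(\n",
"        code\n",
"        for pid, time, code in events\n",
"        if pid == patient_id and time <= prediction_time\n",
"    )\n",
"    row = np.zeros(len(vocab), dtype=np.int64)\n",
"    for code, n in counts.items():\n",
"        row[vocab[code]] = n\n",
"    return row\n",
"\n",
"\n",
"# One row per label, one column per code.\n",
"toy_features = np.stack([count_features(pid, t) for pid, t in label_times])\n",
"print(toy_features.shape)\n",
"```\n",
"\n",
"In practice such matrices are mostly zeros, which is why sparse storage and sparsity-preserving scalers are a natural fit for count features."
]
},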
{
"cell_type": "markdown",
"id": "fafa8ea8",
"metadata": {},
"source": [
"# Joining features and labels\n",
"\n",
"To train a model on a feature set, you need to join a set of labels to those features.\n",
"\n",
"This can be done with `femr.featurizers.join_labels`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cd0f43fd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"boolean_values (200,)\n",
"patient_ids (200,)\n",
"times (200,)\n",
"features (200, 1884)\n"
]
}
],
"source": [
"features_and_labels = femr.featurizers.join_labels(features, labels)\n",
"\n",
"for k, v in features_and_labels.items():\n",
"    print(k, v.shape)"
]
},
{
"cell_type": "markdown",
"id": "66934476-c40a-467c-8702-b0d7021d92bf",
"metadata": {},
"source": [
"# Data Splitting\n",
"\n",
"FEMR contains utilities for hash-based patient splitting, where each patient's split is determined by a hash of the patient id.\n",
"\n",
"This approach is deterministic, stable, and scalable. It is also approximate: the resulting split sizes are only close to the requested fractions."
]
},
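{
"cell_type": "markdown",
"id": "toy-hash-split-sketch",
"metadata": {},
"source": [
"The idea behind hash-based splitting can be sketched in a few lines. This is a toy illustration, not FEMR's implementation: the `toy_hash_split` helper and the choice of SHA-256 are assumptions made for the example. Each patient's assignment depends only on the seed and that patient's id, so it does not change when other patients are added or removed, and the test fraction is met only approximately.\n",
"\n",
"```python\n",
"import hashlib\n",
"\n",
"\n",
"def toy_hash_split(patient_ids, frac_test, seed):\n",
"    # Hash each (seed, patient_id) pair to a value in [0, 1) and send\n",
"    # patients whose value falls below frac_test to the test set.\n",
"    train_ids, test_ids = [], []\n",
"    for pid in patient_ids:\n",
"        digest = hashlib.sha256(f'{seed}:{pid}'.encode()).digest()\n",
"        bucket = int.from_bytes(digest[:8], 'big') / 2**64\n",
"        (test_ids if bucket < frac_test else train_ids).append(pid)\n",
"    return train_ids, test_ids\n",
"\n",
"\n",
"train_ids, test_ids = toy_hash_split(range(1000), frac_test=0.3, seed=87)\n",
"# The sizes are close to 700 / 300, but not exactly that.\n",
"print(len(train_ids), len(test_ids))\n",
"```\n",
"\n",
"Because no shared state or shuffling is involved, the same patient lands in the same split on every run and on every machine."
]
},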
{
"cell_type": "code",
"execution_count": 4,
"id": "01acd922-668b-481b-8dbb-54ab6ae433af",
"metadata": {},
"outputs": [],
"source": [
"import femr.splits\n",
"import numpy as np\n",
"\n",
"# Split patients into global training and test sets (roughly 70% / 30%)\n",
"split = femr.splits.generate_hash_split(set(features_and_labels['patient_ids']), seed=87, frac_test=0.3)\n",
"\n",
"train_mask = np.isin(features_and_labels['patient_ids'], split.train_patient_ids)\n",
"test_mask = np.isin(features_and_labels['patient_ids'], split.test_patient_ids)\n",
"\n",
"X_train, y_train = (\n",
"    features_and_labels['features'][train_mask],\n",
"    features_and_labels['boolean_values'][train_mask],\n",
")\n",
"X_test, y_test = (\n",
"    features_and_labels['features'][test_mask],\n",
"    features_and_labels['boolean_values'][test_mask],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9aaeb7e5-eb48-46f5-ae59-9abfbc0dcef5",
"metadata": {},
"source": [
"# Building Models\n",
"\n",
"The generated features can then be used to build standard models. Here we fit both a logistic regression model and an XGBoost model and evaluate them.\n",
"\n",
"Performance is perfect because our task (predicting gender) is fully determined by the features."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "caae3126-1437-408e-b25f-04568e15c96a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---- Logistic Regression ----\n",
"Train:\n",
"\tAUROC: 1.0\n",
"\tAPS: 1.0\n",
"\tAccuracy: 1.0\n",
"\tF1 Score: 1.0\n",
"Test:\n",
"\tAUROC: 1.0\n",
"\tAPS: 1.0\n",
"\tAccuracy: 1.0\n",
"\tF1 Score: 1.0\n",
"---- XGBoost ----\n",
"Train:\n",
"\tAUROC: 1.0\n",
"\tAPS: 1.0\n",
"\tAccuracy: 1.0\n",
"\tF1 Score: 1.0\n",
"Test:\n",
"\tAUROC: 1.0\n",
"\tAPS: 1.0\n",
"\tAccuracy: 1.0\n",
"\tF1 Score: 1.0\n"
]
}
],
"source": [
"import xgboost as xgb\n",
"import sklearn.linear_model\n",
"import sklearn.metrics\n",
"import sklearn.preprocessing\n",
"\n",
"\n",
"def run_analysis(title: str, y_train, y_train_proba, y_test, y_test_proba):\n",
"    print(f\"---- {title} ----\")\n",
"    print(\"Train:\")\n",
"    print_metrics(y_train, y_train_proba)\n",
"    print(\"Test:\")\n",
"    print_metrics(y_test, y_test_proba)\n",
"\n",
"\n",
"def print_metrics(y_true, y_proba):\n",
"    # Threshold the predicted probabilities at 0.5 to get hard predictions\n",
"    y_pred = y_proba > 0.5\n",
"    auroc = sklearn.metrics.roc_auc_score(y_true, y_proba)\n",
"    aps = sklearn.metrics.average_precision_score(y_true, y_proba)\n",
"    accuracy = sklearn.metrics.accuracy_score(y_true, y_pred)\n",
"    f1 = sklearn.metrics.f1_score(y_true, y_pred)\n",
"    print(\"\\tAUROC:\", auroc)\n",
"    print(\"\\tAPS:\", aps)\n",
"    print(\"\\tAccuracy:\", accuracy)\n",
"    print(\"\\tF1 Score:\", f1)\n",
"\n",
"\n",
"# Logistic regression\n",
"# MaxAbsScaler is suited to sparse data: see https://scikit-learn.org/stable/modules/preprocessing.html#scaling-sparse-data\n",
"# Fit the scaler on the training set only, then apply the same scaling to the test set\n",
"scaler = sklearn.preprocessing.MaxAbsScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train)\n",
"X_test_scaled = scaler.transform(X_test)\n",
"model = sklearn.linear_model.LogisticRegressionCV(penalty=\"l2\", solver=\"liblinear\").fit(X_train_scaled, y_train)\n",
"y_train_proba = model.predict_proba(X_train_scaled)[:, 1]\n",
"y_test_proba = model.predict_proba(X_test_scaled)[:, 1]\n",
"run_analysis(\"Logistic Regression\", y_train, y_train_proba, y_test, y_test_proba)\n",
"\n",
"\n",
"# XGBoost\n",
"model = xgb.XGBClassifier()\n",
"model.fit(X_train, y_train)\n",
"y_train_proba = model.predict_proba(X_train)[:, 1]\n",
"y_test_proba = model.predict_proba(X_test)[:, 1]\n",
"run_analysis(\"XGBoost\", y_train, y_train_proba, y_test, y_test_proba)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}