{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "43f4d50c-4e7b-4652-9701-be9366ff70c4",
   "metadata": {},
   "source": [
    "# Labeling\n",
    "\n",
    "A core component of FEMR is labeling patients.\n",
    "\n",
    "Labels within FEMR follow the [label schema within MEDS](https://github.com/Medical-Event-Data-Standard/meds/blob/e93f63a2f9642123c49a31ecffcdb84d877dc54a/src/meds/__init__.py#L70).\n",
    "\n",
    "Per MEDS, each label consists of three attributes:\n",
    "\n",
    "* `patient_id` (int64): The identifier for the patient to predict on\n",
    "* `prediction_time` (datetime.datetime): The time at which the prediction is made. It determines which features may be used, since a model should only see data available by this time.\n",
    "* `boolean_value` (bool): The target to predict\n",
    "\n",
    "Additional types of labels will be added to MEDS over time, and then supported here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c6ac5c41-bc99-4731-ad82-7152274c67e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "import os\n",
    "\n",
    "TARGET_DIR = 'trash/tutorial_2'\n",
    "\n",
    "if os.path.exists(TARGET_DIR):\n",
    "    shutil.rmtree(TARGET_DIR)\n",
    "\n",
    "# makedirs also creates the parent 'trash' directory if it does not exist yet\n",
    "os.makedirs(TARGET_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e98dd85",
   "metadata": {},
   "source": [
    "# Demonstration of some example labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8d9e2ccd-71c2-4ae0-897b-7ec022f9fdf4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# We can construct these labels manually\n",
    "\n",
    "import femr.labelers\n",
    "import datetime\n",
    "import meds\n",
    "\n",
    "# Predict False on March 2nd, 1994\n",
    "example_label = {'patient_id': 100, 'prediction_time': datetime.datetime(1994, 3, 2), 'boolean_value': False}\n",
    "\n",
    "# Predict True on March 2nd, 2009\n",
    "example_label2 = {'patient_id': 100, 'prediction_time': datetime.datetime(2009, 3, 2), 'boolean_value': True}\n",
    "\n",
    "\n",
    "# Multiple labels are stored using a list\n",
    "labels = [example_label, example_label2]"
   ]
  },
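  {
   "cell_type": "markdown",
   "id": "label-sanity-check",
   "metadata": {},
   "source": [
    "The three attributes described above can be checked mechanically before labels are passed to downstream code. The following is a minimal sketch and not part of FEMR or MEDS: the helper name `check_label` and its error messages are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import datetime\n",
    "\n",
    "# Hypothetical helper (not a FEMR or MEDS function): verify that a label\n",
    "# dictionary carries the three attributes with the expected Python types.\n",
    "# Note: bool is a subclass of int, so this sketch does not reject a\n",
    "# boolean patient_id.\n",
    "EXPECTED_TYPES = {\n",
    "    'patient_id': int,\n",
    "    'prediction_time': datetime.datetime,\n",
    "    'boolean_value': bool,\n",
    "}\n",
    "\n",
    "def check_label(label):\n",
    "    for name, expected in EXPECTED_TYPES.items():\n",
    "        if name not in label:\n",
    "            raise KeyError('missing attribute: ' + name)\n",
    "        if not isinstance(label[name], expected):\n",
    "            raise TypeError(name + ' should be ' + expected.__name__)\n",
    "    return True\n",
    "\n",
    "example = {'patient_id': 100, 'prediction_time': datetime.datetime(1994, 3, 2), 'boolean_value': False}\n",
    "assert check_label(example)\n",
    "```\n",
    "\n",
    "In this notebook the same check could be run over the manually built list, for example `all(check_label(l) for l in labels)`."
   ]
  },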
  {
   "cell_type": "markdown",
   "id": "e77b1bfc-8d2d-4f79-b855-f90b3a73736e",
   "metadata": {},
   "source": [
    "# Generating labels programmatically within FEMR\n",
    "\n",
    "One core feature of FEMR is the ability to algorithmically generate labels through the use of a labeling function class.\n",
    "\n",
    "The core of FEMR's labeling code is the abstract base class [Labeler](https://github.com/som-shahlab/femr/blob/main/src/femr/labelers/core.py#L40).\n",
    "\n",
    "Labeler has one abstract method:\n",
    "\n",
    "```python\n",
    "def label(self, patient: meds.Patient) -> List[meds.Label]:\n",
    "    \"\"\"Generate a list of labels for a patient.\"\"\"\n",
    "```\n",
    "\n",
    "Note that the patient is assumed to follow the [MEDS Patient schema](https://github.com/Medical-Event-Data-Standard/meds/blob/e93f63a2f9642123c49a31ecffcdb84d877dc54a/src/meds/__init__.py#L18).\n",
    "\n",
    "Once this method is implemented, the `apply` method becomes available for generating labels across a dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9ac22dbe-ef34-468a-8ab3-673e58e5a920",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 3040.98 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'patient_id': 100, 'prediction_time': datetime.datetime(1992, 7, 15, 0, 0), 'boolean_value': False}\n",
      "{'patient_id': 101, 'prediction_time': datetime.datetime(1992, 8, 20, 0, 0), 'boolean_value': False}\n",
      "{'patient_id': 102, 'prediction_time': datetime.datetime(1991, 4, 13, 0, 0), 'boolean_value': True}\n",
      "{'patient_id': 103, 'prediction_time': datetime.datetime(1990, 10, 19, 0, 0), 'boolean_value': False}\n",
      "{'patient_id': 104, 'prediction_time': datetime.datetime(1990, 6, 15, 0, 0), 'boolean_value': True}\n",
      "{'patient_id': 105, 'prediction_time': datetime.datetime(1990, 6, 29, 0, 0), 'boolean_value': True}\n",
      "{'patient_id': 106, 'prediction_time': datetime.datetime(1992, 5, 25, 0, 0), 'boolean_value': True}\n",
      "{'patient_id': 107, 'prediction_time': datetime.datetime(1992, 5, 29, 0, 0), 'boolean_value': False}\n",
      "{'patient_id': 108, 'prediction_time': datetime.datetime(1991, 10, 20, 0, 0), 'boolean_value': True}\n",
      "{'patient_id': 109, 'prediction_time': datetime.datetime(1991, 6, 25, 0, 0), 'boolean_value': True}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from typing import List\n",
    "import femr.pat_utils\n",
    "import datasets\n",
    "\n",
    "class IsMaleLabeler(femr.labelers.Labeler):\n",
    "    # Dummy labeler to predict gender at birth\n",
    "    \n",
    "    def label(self, patient: meds.Patient) -> List[meds.Label]:\n",
    "        is_male = any('Gender/M' == measurement['code'] for event in patient['events'] for measurement in event['measurements'])\n",
    "        return [{\n",
    "            'patient_id': patient['patient_id'], \n",
    "            'prediction_time': femr.pat_utils.get_patient_birthdate(patient),\n",
    "            'boolean_value': is_male,\n",
    "        }]\n",
    "    \n",
    "dataset = datasets.Dataset.from_parquet(\"input/meds/data/*\")\n",
    "\n",
    "labeler = IsMaleLabeler()\n",
    "labeled_patients = labeler.apply(dataset)\n",
    "\n",
    "for i in range(10):\n",
    "    print(labeled_patients[100 + i])\n",
    "\n"
   ]
  },
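  {
   "cell_type": "markdown",
   "id": "label-summary-sketch",
   "metadata": {},
   "source": [
    "Before using generated labels it is often worth checking how many are positive. The following is a minimal sketch, assuming the labels are an iterable of dictionaries with a `boolean_value` key, as in the printed output above. The helper `summarize_labels` is hypothetical and not a FEMR API.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "# Hypothetical helper (not part of FEMR): count positive and negative\n",
    "# labels and report the fraction of positives.\n",
    "def summarize_labels(labels):\n",
    "    counts = Counter(label['boolean_value'] for label in labels)\n",
    "    total = sum(counts.values())\n",
    "    prevalence = counts[True] / total if total else 0.0\n",
    "    return {'positive': counts[True], 'negative': counts[False], 'prevalence': prevalence}\n",
    "\n",
    "# Toy labels for illustration only; these are not taken from the dataset.\n",
    "toy_labels = [\n",
    "    {'patient_id': 1, 'boolean_value': False},\n",
    "    {'patient_id': 2, 'boolean_value': True},\n",
    "    {'patient_id': 3, 'boolean_value': True},\n",
    "]\n",
    "summary = summarize_labels(toy_labels)\n",
    "print(summary)\n",
    "```\n",
    "\n",
    "With the labels generated in this notebook, the call would be `summarize_labels(labeled_patients)`."
   ]
  },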
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "20bd7859",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We can use pyarrow to save these labels to a CSV file\n",
    "import pyarrow\n",
    "import pyarrow.csv\n",
    "\n",
    "table = pyarrow.Table.from_pylist(labeled_patients, schema=meds.label)\n",
    "pyarrow.csv.write_csv(table, os.path.join(TARGET_DIR, \"labels.csv\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}