--- a
+++ b/process_mimic.ipynb
@@ -0,0 +1,298 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "Copy of process_mimic.ipynb",
+      "version": "0.3.2",
+      "provenance": [],
+      "collapsed_sections": [],
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "<a href=\"https://colab.research.google.com/github/BenM1215/medgan/blob/master/process_mimic.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+      ]
+    },
+    {
+      "metadata": {
+        "id": "dkIm2oJnyjVK",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "cell_type": "code",
+      "source": [
+        "import sys\n",
+        "import _pickle as pickle\n",
+        "import numpy as np\n",
+        "from datetime import datetime"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "metadata": {
+        "id": "F3LE_9PVyBfI",
+        "colab_type": "code",
+        "outputId": "d6ad558d-3bd1-4da1-be4f-248213b11409",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 34
+        }
+      },
+      "cell_type": "code",
+      "source": [
+        "from google.colab import drive\n",
+        "drive.mount('/content/gdrive', force_remount=True)"
+      ],
+      "execution_count": 0,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Mounted at /content/gdrive\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "metadata": {
+        "id": "QbwB-8Fdylim",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "cell_type": "code",
+      "source": [
+        "def convert_to_icd9(dxStr):\n",
+        "    \"\"\"Insert the decimal point into a dot-less ICD-9 code string.\n",
+        "\n",
+        "    Codes starting with 'E' keep four characters before the dot; all other\n",
+        "    codes keep three. A code with nothing after that prefix is returned\n",
+        "    unchanged.\n",
+        "    \"\"\"\n",
+        "    prefix_len = 4 if dxStr.startswith('E') else 3\n",
+        "    # Nothing follows the prefix, so there is no decimal part to separate.\n",
+        "    if len(dxStr) <= prefix_len:\n",
+        "        return dxStr\n",
+        "    return dxStr[:prefix_len] + '.' + dxStr[prefix_len:]"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "metadata": {
+        "id": "DEHXM6G2yqLR",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "cell_type": "code",
+      "source": [
+        "def convert_to_3digit_icd9(dxStr):\n",
+        "    \"\"\"Truncate a dot-less ICD-9 code string to its leading prefix.\n",
+        "\n",
+        "    Codes starting with 'E' keep their first four characters; all other\n",
+        "    codes keep their first three. Shorter codes are returned unchanged.\n",
+        "    \"\"\"\n",
+        "    prefix_len = 4 if dxStr.startswith('E') else 3\n",
+        "    # Slicing past the end of a short code simply yields the whole code.\n",
+        "    return dxStr[:prefix_len]"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "metadata": {
+        "id": "WTP8qLpe4GsQ",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "cell_type": "code",
+      "source": [
+        "# input arguments\n",
+        "# binary_count selects the output mode and must be 'binary' or 'count':\n",
+        "#   'binary' sets a matrix entry to 1 when the patient has the code,\n",
+        "#   'count'  adds 1 to the entry for every occurrence of the code.\n",
+        "binary_count = 'binary'"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "metadata": {
+        "id": "m40Qy9Ok4KB7",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "cell_type": "code",
+      "source": [
+        "root_dir = \"/content/gdrive/My Drive/\"\n",
+        "\n",
+        "# Outputs go to a mode-specific tree; any value other than 'count' selects\n",
+        "# the binary tree.\n",
+        "base_dir = (\n",
+        "    root_dir + 'GOSH/Synthetic Data/medgan/count/'\n",
+        "    if binary_count == 'count'\n",
+        "    else root_dir + 'GOSH/Synthetic Data/medgan/binary/'\n",
+        ")\n",
+        "\n",
+        "# The raw input directory does not depend on the mode.\n",
+        "raw_data_dir = root_dir + 'GOSH/Synthetic Data/medgan/mimic/'\n",
+        "\n",
+        "# Everything below is derived from the mode-specific base directory.\n",
+        "processed_data_dir = base_dir + 'processed_mimic/'\n",
+        "model_dir = base_dir + 'models/'\n",
+        "gen_data_dir = base_dir + 'generated_data/'"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "metadata": {
+        "id": "84a4GULe5NmG",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "cell_type": "code",
+      "source": [
+        "# Input CSVs read from raw_data_dir, plus the path prefix for the pickled\n",
+        "# outputs ('.pids', '.matrix' and '.types' are appended when saving).\n",
+        "admissionFile = raw_data_dir + 'ADMISSIONS.csv'\n",
+        "diagnosisFile = raw_data_dir + 'DIAGNOSES_ICD.csv'\n",
+        "outFile = processed_data_dir + 'processed_mimic'"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "metadata": {
+        "id": "0o1e-8_RyttW",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "cell_type": "code",
+      "source": [
+        "# Fail fast on an unknown mode. Merely printing a warning would let the\n",
+        "# notebook carry on: the path setup treats anything other than 'count' as\n",
+        "# binary, while the matrix construction treats anything other than 'binary'\n",
+        "# as count, so a typo would write count data into the binary directory.\n",
+        "if binary_count not in ('binary', 'count'):\n",
+        "    raise ValueError('You must choose either binary or count.')"
+      ],
+      "execution_count": 0,
+      "outputs": []
+    },
+    {
+      "metadata": {
+        "id": "RYeXZcXe4-Zj",
+        "colab_type": "code",
+        "outputId": "0c259433-92d5-497b-da62-b9357d23d938",
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 119
+        }
+      },
+      "cell_type": "code",
+      "source": [
+        "# Build a patient-by-diagnosis-code matrix from the admissions and diagnoses\n",
+        "# CSVs, then pickle the patient ids, the matrix and the code-to-column map.\n",
+        "\n",
+        "print('Building pid-admission mapping, admission-date mapping')\n",
+        "pidAdmMap = {}   # patient id -> list of admission ids\n",
+        "admDateMap = {}  # admission id -> admission datetime\n",
+        "# The context manager closes the file even if a row fails to parse.\n",
+        "with open(admissionFile, 'r') as infd:\n",
+        "    infd.readline()  # discard the first line (header row)\n",
+        "    for line in infd:\n",
+        "        tokens = line.strip().split(',')\n",
+        "        pid = int(tokens[1])\n",
+        "        admId = int(tokens[2])\n",
+        "        admTime = datetime.strptime(tokens[3], '%Y-%m-%d %H:%M:%S')\n",
+        "        admDateMap[admId] = admTime\n",
+        "        pidAdmMap.setdefault(pid, []).append(admId)\n",
+        "\n",
+        "print('Building admission-dxList mapping')\n",
+        "admDxMap = {}  # admission id -> list of 'D_'-prefixed diagnosis codes\n",
+        "with open(diagnosisFile, 'r') as infd:\n",
+        "    infd.readline()  # discard the first line (header row)\n",
+        "    for line in infd:\n",
+        "        tokens = line.strip().split(',')\n",
+        "        admId = int(tokens[2])\n",
+        "        # tokens[4][1:-1] drops the first and last character of the code field\n",
+        "        # (presumably the surrounding quotes -- verify against the CSV export).\n",
+        "        # Full ICD-9 codes are used here. To collapse codes to their 3-digit\n",
+        "        # form instead, comment out the next line and uncomment the one after.\n",
+        "        dxStr = 'D_' + convert_to_icd9(tokens[4][1:-1])\n",
+        "        #dxStr = 'D_' + convert_to_3digit_icd9(tokens[4][1:-1])\n",
+        "        admDxMap.setdefault(admId, []).append(dxStr)\n",
+        "\n",
+        "print('Building pid-sortedVisits mapping')\n",
+        "pidSeqMap = {}  # patient id -> list of (admission time, code list), sorted\n",
+        "for pid, admIdList in pidAdmMap.items():\n",
+        "    #if len(admIdList) < 2: continue\n",
+        "    sortedList = sorted([(admDateMap[admId], admDxMap[admId]) for admId in admIdList])\n",
+        "    pidSeqMap[pid] = sortedList\n",
+        "\n",
+        "print('Building pids, dates, strSeqs')\n",
+        "pids = []\n",
+        "dates = []\n",
+        "seqs = []\n",
+        "for pid, visits in pidSeqMap.items():\n",
+        "    pids.append(pid)\n",
+        "    dates.append([visit[0] for visit in visits])\n",
+        "    seqs.append([visit[1] for visit in visits])\n",
+        "\n",
+        "print('Converting strSeqs to intSeqs, and making types')\n",
+        "types = {}  # code string -> integer index, assigned in first-seen order\n",
+        "newSeqs = []\n",
+        "for patient in seqs:\n",
+        "    newPatient = []\n",
+        "    for visit in patient:\n",
+        "        newVisit = []\n",
+        "        for code in visit:\n",
+        "            if code not in types:\n",
+        "                types[code] = len(types)\n",
+        "            newVisit.append(types[code])\n",
+        "        newPatient.append(newVisit)\n",
+        "    newSeqs.append(newPatient)\n",
+        "\n",
+        "print('Constructing the matrix')\n",
+        "numPatients = len(newSeqs)\n",
+        "numCodes = len(types)\n",
+        "# Allocate as float32 directly rather than building float64 and copying.\n",
+        "matrix = np.zeros((numPatients, numCodes), dtype='float32')\n",
+        "isBinary = binary_count == 'binary'\n",
+        "for i, patient in enumerate(newSeqs):\n",
+        "    for visit in patient:\n",
+        "        for code in visit:\n",
+        "            if isBinary:\n",
+        "                matrix[i, code] = 1.\n",
+        "            else:\n",
+        "                matrix[i, code] += 1.\n",
+        "\n",
+        "# Write each output inside a context manager so the file is flushed and\n",
+        "# closed; protocol -1 selects the highest pickle protocol available.\n",
+        "for suffix, payload in (('.pids', pids), ('.matrix', matrix), ('.types', types)):\n",
+        "    with open(outFile + suffix, 'wb') as outfd:\n",
+        "        pickle.dump(payload, outfd, -1)"
+      ],
+      "execution_count": 0,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "text": [
+            "Building pid-admission mapping, admission-date mapping\n",
+            "Building admission-dxList mapping\n",
+            "Building pid-sortedVisits mapping\n",
+            "Building pids, dates, strSeqs\n",
+            "Converting strSeqs to intSeqs, and making types\n",
+            "Constructing the matrix\n"
+          ],
+          "name": "stdout"
+        }
+      ]
+    },
+    {
+      "metadata": {
+        "id": "AaF-cC2t0pT6",
+        "colab_type": "code",
+        "colab": {}
+      },
+      "cell_type": "code",
+      "source": [
+        ""
+      ],
+      "execution_count": 0,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file