[befbfc]: / synthetics / 02_create_synthetic_mouse_phenomes.ipynb

Download this file

394 lines (393 with data), 13.2 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create synthetic mouse phenome data\n",
    "\n",
    "Create a synthetic version of the mouse phenomes from the original experiment, which are available after running `01_create_phenome_training_data.ipynb`. To run this notebook, you will need an API key from the Gretel console at https://console.gretel.cloud."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get set up with Gretel Synthetics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "# %pip (unlike !pip) installs into the environment of the running kernel\n",
    "%pip install -U gretel-client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the Gretel client session. The API key is read with getpass so it is\n",
    "# not echoed to the screen or saved in the notebook.\n",
    "\n",
    "import pandas as pd\n",
    "from getpass import getpass\n",
    "from gretel_client import ClientConfig, configure_session\n",
    "\n",
    "# Show full cell contents when displaying DataFrames\n",
    "pd.set_option('display.max_colwidth', None)\n",
    "\n",
    "configure_session(\n",
    "    ClientConfig(\n",
    "        api_key=getpass(prompt=\"Enter Gretel API key\"),\n",
    "        endpoint=\"https://api.gretel.cloud\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure model hyperparameters\n",
    "Load the default configuration template, which works well for most datasets, then override several of its parameters for this complex dataset. View other templates at https://github.com/gretelai/gretel-blueprints/tree/main/config_templates/gretel/synthetics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import yaml\n",
    "# smart_open's open() shadows the builtin and can also read from URLs\n",
    "from smart_open import open\n",
    "\n",
    "template_url = \"https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/config_templates/gretel/synthetics/default.yml\"\n",
    "with open(template_url, 'r') as stream:\n",
    "    config = yaml.safe_load(stream)\n",
    "\n",
    "# Optimize parameters for complex dataset\n",
    "synthetics_cfg = config['models'][0]['synthetics']\n",
    "synthetics_cfg['params'].update({\n",
    "    'epochs': 200,\n",
    "    'vocab_size': 0,\n",
    "    'rnn_units': 640,\n",
    "    'reset_states': False,\n",
    "    'learning_rate': 0.001,\n",
    "    'dropout_rate': 0.4312,\n",
    "    'gen_temp': 1.003,\n",
    "})\n",
    "# Set the similarity privacy filter to None\n",
    "synthetics_cfg['privacy_filters']['similarity'] = None\n",
    "\n",
    "print(json.dumps(config, indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the synthetic model\n",
    "In this step, we ask a Gretel worker (running in the Gretel cloud or locally) to train one synthetic model per phenome batch of the source dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the location of the phenome training data.\n",
    "# This notebook is expected to run from <repo>/synthetics, with the data in\n",
    "# <repo>/mice_data_set/data. Taking the parent directory (instead of removing every\n",
    "# \"/synthetics\" substring from the path) avoids mangling paths that contain that text\n",
    "# elsewhere and does not depend on the path separator.\n",
    "\n",
    "import os\n",
    "import pathlib\n",
    "\n",
    "cwd = pathlib.Path(os.getcwd())\n",
    "base_path = cwd.parent if cwd.name == \"synthetics\" else cwd\n",
    "data_path = base_path / 'mice_data_set' / 'data'\n",
    "data_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a function to submit a new model for a specific phenome batch dataset\n",
    "\n",
    "def create_model(batch_num):\n",
    "    \"\"\"Create a Gretel project and submit a model trained on pheno_batch<batch_num>.csv.\n",
    "\n",
    "    Uses the notebook-level `config` and `data_path` set in earlier cells. The training\n",
    "    file is uploaded along with the job. Returns the submitted model object.\n",
    "    \"\"\"\n",
    "    # Imported here so the function does not rely on a later cell having run first\n",
    "    import time\n",
    "    from gretel_client import create_project\n",
    "\n",
    "    project_name = f\"Training phenomes{int(time.time())}\"\n",
    "    project = create_project(display_name=project_name)\n",
    "    trainpath = str(data_path / f\"pheno_batch{batch_num}.csv\")\n",
    "    model = project.create_model_obj(model_config=config)\n",
    "    model.data_source = trainpath\n",
    "    model.submit(upload_data_source=True)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Submit all the phenome batches to train in parallel; poll for completion\n",
    "\n",
    "from gretel_client.helpers import poll\n",
    "from gretel_client import create_project\n",
    "import time\n",
    "\n",
    "num_batches = 7      # pheno_batch0.csv ... pheno_batch6.csv\n",
    "max_retries = 3      # resubmissions allowed per batch before giving up\n",
    "poll_seconds = 60\n",
    "\n",
    "# Job statuses that mean \"still working\". \"created\" is included so a job that has\n",
    "# not been scheduled yet does not end the polling loop early.\n",
    "in_progress_statuses = {\"created\", \"pending\", \"active\"}\n",
    "# Failed states that are worth resubmitting\n",
    "retry_statuses = {\"error\", \"lost\"}\n",
    "\n",
    "def job_status(model):\n",
    "    \"\"\"Refresh the model's job record and return its status string.\"\"\"\n",
    "    # NOTE: relies on private gretel_client internals (_poll_job_endpoint / _data)\n",
    "    model._poll_job_endpoint()\n",
    "    return model.__dict__['_data']['model']['status']\n",
    "\n",
    "# Create a model for each batch\n",
    "models = [create_model(i) for i in range(num_batches)]\n",
    "retries = [0] * num_batches\n",
    "\n",
    "# Poll for completion. Resubmit failed batches, at most max_retries times each,\n",
    "# so a batch that always fails cannot keep the loop running forever.\n",
    "training = True\n",
    "while training:\n",
    "    time.sleep(poll_seconds)\n",
    "    training = False\n",
    "    print()\n",
    "    for i, model in enumerate(models):\n",
    "        status = job_status(model)\n",
    "        print(\"Batch \" + str(i) + \" has status: \" + status)\n",
    "        if status in in_progress_statuses:\n",
    "            training = True\n",
    "        elif status in retry_statuses and retries[i] < max_retries:\n",
    "            retries[i] += 1\n",
    "            models[i] = create_model(i)\n",
    "            training = True\n",
    "\n",
    "# Now that models are complete, get each batch's Synthetic Quality Score (SQS)\n",
    "print()\n",
    "for batch, model in enumerate(models):\n",
    "    status = job_status(model)\n",
    "    if status != \"completed\":\n",
    "        # Only completed jobs are expected to have a quality report to peek at\n",
    "        print(\"Batch \" + str(batch) + \" ended with \" + status)\n",
    "        continue\n",
    "    report = model.peek_report()\n",
    "    sqs = report['synthetic_data_quality_score']['score']\n",
    "    if sqs >= 80:\n",
    "        label = \"Excellent\"\n",
    "    elif sqs >= 60:\n",
    "        label = \"Good\"\n",
    "    else:\n",
    "        label = \"Moderate\"\n",
    "    print(\"Batch \" + str(batch) + \" completes with SQS: \" + label + \" (\" + str(sqs) + \")\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the original phenome set from phenome_alldata.csv\n",
    "\n",
    "filepath = data_path / \"phenome_alldata.csv\"\n",
    "filename = filepath.name\n",
    "phenome_orig = pd.read_csv(filepath_or_buffer=filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first rows of the original phenome table\n",
    "phenome_orig.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the synthetic batches into one dataframe\n",
    "# First gather the synthetic data for each batch. Iterate over `models` rather than a\n",
    "# hard-coded batch count so this cell always matches the models that were trained.\n",
    "\n",
    "synth_batches = [\n",
    "    pd.read_csv(model.get_artifact_link(\"data_preview\"), compression='gzip')\n",
    "    for model in models\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the synthetic batches back into one wide table.\n",
    "# Batches that share columns are joined on them. Shared values can repeat, so each key\n",
    "# also gets a per-occurrence counter 'g': the n-th left row with a given key pairs with the\n",
    "# n-th right row with that key, instead of producing a many-to-many cross product.\n",
    "# Batches 3 and 6 are placed side by side (pd.concat, aligned on the row index).\n",
    "\n",
    "def merge_on_shared(left, right, keys):\n",
    "    \"\"\"Left-join `right` onto `left` on `keys`, pairing duplicate keys by occurrence order.\n",
    "\n",
    "    Works on copies (via `assign`), so the input frames are not modified.\n",
    "    \"\"\"\n",
    "    left = left.assign(g=left.groupby(keys).cumcount())\n",
    "    right = right.assign(g=right.groupby(keys).cumcount())\n",
    "    return pd.merge(left, right, on=keys + ['g'], how='left').drop('g', axis=1)\n",
    "\n",
    "methcage_keys = ['methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10']\n",
    "\n",
    "# Merge batch 0 and 1 on common field sacweight\n",
    "synth_allbatches = merge_on_shared(synth_batches[0], synth_batches[1], ['sacweight'])\n",
    "\n",
    "# Now merge in batch 2 on common fields SW16, SW20, SW17\n",
    "synth_allbatches = merge_on_shared(synth_allbatches, synth_batches[2], ['SW16', 'SW20', 'SW17'])\n",
    "\n",
    "# Now add batch 3 side by side\n",
    "synth_allbatches = pd.concat([synth_allbatches, synth_batches[3]], axis=1)\n",
    "\n",
    "# Now merge in batches 4 and 5, each on the common methcage fields\n",
    "synth_allbatches = merge_on_shared(synth_allbatches, synth_batches[4], methcage_keys)\n",
    "synth_allbatches = merge_on_shared(synth_allbatches, synth_batches[5], methcage_keys)\n",
    "\n",
    "# Now add batch 6 side by side\n",
    "synth_allbatches = pd.concat([synth_allbatches, synth_batches[6]], axis=1)\n",
    "synth_allbatches\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add back in the \"id\" and \"discard\" fields, and save off complete synthetic data\n",
    "# (id is the 0-based row number; every row gets discard = \"no\")\n",
    "\n",
    "n_rows = len(synth_allbatches.index)\n",
    "synth_allbatches[\"id\"] = list(range(n_rows))\n",
    "synth_allbatches[\"discard\"] = [\"no\"] * n_rows\n",
    "\n",
    "filepath = data_path / 'phenome_alldata_synth.csv'\n",
    "synth_allbatches.to_csv(filepath, index=False, header=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: if the synthetic phenomes were already created and you only need a new\n",
    "# seed file to analyze a new pheno, reload them here instead of retraining.\n",
    "# Run the cell that sets data_path (near the top of the notebook) first.\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "filepath = data_path / 'phenome_alldata_synth.csv'\n",
    "synth_allbatches = pd.read_csv(filepath_or_buffer=filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save off the phenotype values you plan to analyze so we can later condition the genotype synthesis\n",
    "# with these values\n",
    "\n",
    "pheno_file = data_path / 'pheno_and_covariates.csv'\n",
    "pheno_analysis_df = pd.read_csv(pheno_file)\n",
    "pheno_seeds = list(pheno_analysis_df[\"pheno_and_cov\"])\n",
    "\n",
    "print(pheno_seeds)\n",
    "\n",
    "# DataFrame.filter silently skips names that are not columns, so report any that are\n",
    "# missing (they will be absent from the seed file)\n",
    "missing_seeds = [col for col in pheno_seeds if col not in synth_allbatches.columns]\n",
    "if missing_seeds:\n",
    "    print(\"Warning: these seed columns are not in the synthetic data: \" + str(missing_seeds))\n",
    "\n",
    "seeds_df = synth_allbatches.filter(pheno_seeds)\n",
    "# The seeding won't work if there are any NaN's in the seedfile\n",
    "seeds_df = seeds_df.fillna(0)\n",
    "seeds_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of seed rows\n",
    "seeds_df.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Write the seeds to phenome_seeds.csv. Any rounding or casting to int applied to\n",
    "# seeds_df must be replicated when creating the genome training files.\n",
    "\n",
    "seedfile = data_path / 'phenome_seeds.csv'\n",
    "seeds_df.to_csv(seedfile, header=True, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the seed rows\n",
    "seeds_df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Generate report that shows the statistical performance between the training and synthetic data\n",
    "# Use the synthetic batch that includes abBMD\n",
    "\n",
    "from smart_open import open\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Change batch_num to any index into `models` (0-6 for the 7 batches) to view the\n",
    "# performance report for other batches\n",
    "batch_num = 0\n",
    "# Read the report inside a with-block so the stream is closed afterwards\n",
    "with open(models[batch_num].get_artifact_link(\"report\")) as report_file:\n",
    "    report_html = report_file.read()\n",
    "display(HTML(data=report_html, metadata=dict(isolated=True)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}