{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create synthetic mouse phenome data\n",
"\n",
"Create a synthetic version of the mouse phenomes from the original experiment, which are available after running `01_create_phenome_training_data.ipynb`. To run this notebook, you will need an API key from the Gretel console, at https://console.gretel.cloud."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get setup with Gretel Synthetics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!python3 -m pip install -U gretel-client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Specify your Gretel API key\n",
"\n",
"from getpass import getpass\n",
"import pandas as pd\n",
"from gretel_client import configure_session, ClientConfig\n",
"\n",
"pd.set_option('max_colwidth', None)\n",
"\n",
"configure_session(ClientConfig(api_key=getpass(prompt=\"Enter Gretel API key\"), \n",
" endpoint=\"https://api.gretel.cloud\"))\n",
"\n",
" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configure model hyper parameters\n",
"Load the default configuration template. This template will work well for most datasets. View other templates at https://github.com/gretelai/gretel-blueprints/tree/main/config_templates/gretel/synthetics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from smart_open import open\n",
"import yaml\n",
"\n",
"with open(\"https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/config_templates/gretel/synthetics/default.yml\", 'r') as stream:\n",
" config = yaml.safe_load(stream)\n",
"\n",
"# Optimize parameters for complex dataset\n",
"config['models'][0]['synthetics']['params']['epochs'] = 200\n",
"config['models'][0]['synthetics']['params']['vocab_size'] = 0\n",
"config['models'][0]['synthetics']['params']['rnn_units'] = 640\n",
"config['models'][0]['synthetics']['params']['reset_states'] = False\n",
"config['models'][0]['synthetics']['params']['learning_rate'] = 0.001\n",
"config['models'][0]['synthetics']['params']['dropout_rate'] = 0.4312\n",
"config['models'][0]['synthetics']['params']['gen_temp'] = 1.003\n",
"config['models'][0]['synthetics']['privacy_filters']['similarity'] = None\n",
"\n",
"print(json.dumps(config, indent=2))"
]
},
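{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: save the tuned configuration to a local YAML file so the run can be reproduced\n",
"# without re-editing the template. A convenience sketch; the filename is arbitrary and not\n",
"# part of the original workflow.\n",
"\n",
"with open(\"phenome_synthetics_config.yml\", \"w\") as out:\n",
"    yaml.safe_dump(config, out, sort_keys=False)"
]
},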
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the synthetic model\n",
"In this step, we will task the worker running in the Gretel cloud, or locally, to train a synthetic model on the source dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get the location of the phenome training data\n",
"\n",
"import os\n",
"import pathlib\n",
"\n",
"base_path = pathlib.Path(os.getcwd().replace(\"/synthetics\", \"\"))\n",
"data_path = base_path / 'mice_data_set' / 'data'\n",
"data_path"
]
},
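{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (a small sketch, not part of the original workflow): verify that the\n",
"# seven batch files produced by 01_create_phenome_training_data.ipynb are present before\n",
"# submitting any training jobs.\n",
"\n",
"missing = [f\"pheno_batch{i}.csv\" for i in range(7) if not (data_path / f\"pheno_batch{i}.csv\").exists()]\n",
"if missing:\n",
"    print(\"Missing training batches:\", missing)\n",
"else:\n",
"    print(\"All 7 phenome batch files found in\", data_path)"
]
},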
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define a function to submit a new model for a specific phenome batch dataset\n",
"\n",
"def create_model(batch_num):\n",
" seconds = int(time.time())\n",
" project_name = \"Training phenomes\" + str(seconds)\n",
" project = create_project(display_name=project_name)\n",
" batchfile = \"pheno_batch\" + str(batch_num) + \".csv\"\n",
" trainpath = str(data_path / batchfile)\n",
" model = project.create_model_obj(model_config=config)\n",
" model.data_source = trainpath\n",
" model.submit(upload_data_source=True) \n",
" return(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Submit all the phenome batches to train in parallel; poll for completion\n",
"\n",
"from gretel_client.helpers import poll\n",
"from gretel_client import create_project\n",
"import time\n",
"\n",
"# Create a model for each batch\n",
"models = []\n",
"for i in range(7):\n",
" model = create_model(i)\n",
" models.append(model)\n",
"\n",
"# Poll for completion. Resubmit errors.\n",
"training = True\n",
"while training:\n",
" time.sleep(60)\n",
" training = False\n",
" print()\n",
" for i in range(7):\n",
" model = models[i]\n",
" model._poll_job_endpoint()\n",
" status = model.__dict__['_data']['model']['status']\n",
" print(\"Batch \" + str(i) + \" has status: \" + status)\n",
" if ((status == \"active\") or (status == \"pending\")):\n",
" training = True\n",
" if status == \"error\":\n",
" model = create_model(i)\n",
" models[i] = model\n",
" training = True \n",
"\n",
"# Now that models are complete, get each batches Synthetic Quality Score (SQS) \n",
"batch = 0\n",
"print()\n",
"for model in models:\n",
" model._poll_job_endpoint()\n",
" status = model.__dict__['_data']['model']['status']\n",
" if status == \"error\":\n",
" print(\"Batch \" + str(batch) + \" ended with error\")\n",
" else:\n",
" report = model.peek_report()\n",
" sqs = report['synthetic_data_quality_score']['score']\n",
" label = \"Moderate\"\n",
" if sqs >= 80:\n",
" label = \"Excellent\"\n",
" elif sqs >= 60:\n",
" label = \"Good\"\n",
" print(\"Batch \" + str(batch) + \" completes with SQS: \" + label + \" (\" + str(sqs) + \")\")\n",
" batch += 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the original phenome set\n",
"\n",
"filename = \"phenome_alldata.csv\"\n",
"filepath = data_path / filename\n",
"phenome_orig = pd.read_csv(filepath)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"phenome_orig.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Merge the synthetic batches into one dataframe\n",
"# First gather the synthetic data for each batch\n",
"\n",
"synth_batches = []\n",
"for i in range(7):\n",
" model = models[i]\n",
" synth = pd.read_csv(model.get_artifact_link(\"data_preview\"), compression='gzip')\n",
" synth_batches.append(synth)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Merge batch 0 and 1 on common field sacweight\n",
"synth_batches[0]['g'] = synth_batches[0].groupby('sacweight').cumcount()\n",
"synth_batches[1]['g'] = synth_batches[1].groupby('sacweight').cumcount()\n",
"synth_allbatches = pd.merge(synth_batches[0],synth_batches[1],on=[\"sacweight\", 'g'],how='left').drop('g', axis=1)\n",
"\n",
"# Now merge in batch 2 on common fields SW16, SW20, SW17\n",
"synth_allbatches['g'] = synth_allbatches.groupby(['SW16','SW20', 'SW17']).cumcount()\n",
"synth_batches[2]['g'] = synth_batches[2].groupby(['SW16', 'SW20', 'SW17']).cumcount()\n",
"synth_allbatches = pd.merge(synth_allbatches,synth_batches[2],on=['SW16', 'SW20', 'SW17', 'g'],how='left').drop('g', axis=1)\n",
"\n",
"# Now merge in batches 3 \n",
"synth_allbatches = pd.concat([synth_allbatches, synth_batches[3]], axis=1)\n",
"\n",
"# Now merge in batch 4 using common fields 'methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10'\n",
"synth_allbatches['g'] = synth_allbatches.groupby(['methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10']).cumcount()\n",
"synth_batches[4]['g'] = synth_batches[4].groupby(['methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10']).cumcount()\n",
"synth_allbatches = pd.merge(synth_allbatches,synth_batches[4],on=['methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10', 'g'],how='left').drop('g', axis=1)\n",
"\n",
"# Now merge in batch 5 using common fields 'methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10'\n",
"synth_allbatches['g'] = synth_allbatches.groupby(['methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10']).cumcount()\n",
"synth_batches[5]['g'] = synth_batches[5].groupby(['methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10']).cumcount()\n",
"synth_allbatches = pd.merge(synth_allbatches,synth_batches[5],on=['methcage12', 'methcage9', 'methcage11', 'methcage8', 'methcage7', 'methcage10', 'g'],how='left').drop('g', axis=1)\n",
"\n",
"# Now merge in batches 6\n",
"synth_allbatches = pd.concat([synth_allbatches, synth_batches[6]], axis=1)\n",
"synth_allbatches\n",
"\n"
]
},
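{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional check (a sketch, assuming the seven batches together cover every phenome column):\n",
"# compare the merged synthetic columns against the original phenome set. Only \"id\" and\n",
"# \"discard\" should be missing here, since they are added back in the next cell.\n",
"\n",
"missing_cols = set(phenome_orig.columns) - set(synth_allbatches.columns)\n",
"extra_cols = set(synth_allbatches.columns) - set(phenome_orig.columns)\n",
"print(\"Columns missing from the synthetic data:\", missing_cols)\n",
"print(\"Columns not present in the original data:\", extra_cols)\n",
"print(\"Original shape:\", phenome_orig.shape, \"Synthetic shape:\", synth_allbatches.shape)"
]
},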
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add back in the \"id\" and \"discard\" fields, and save off complete synthetic data\n",
"\n",
"id_col = []\n",
"discard_col = []\n",
"for i in range(len(synth_allbatches.index)):\n",
" id_col.append(i)\n",
" discard_col.append(\"no\")\n",
" \n",
"synth_allbatches[\"id\"] = id_col\n",
"synth_allbatches[\"discard\"] = discard_col\n",
"filepath = data_path / 'phenome_alldata_synth.csv'\n",
"synth_allbatches.to_csv(filepath, index=False, header=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional cell if you have already created the synthetic phenomes and just need a new seed file\n",
"# to analyze a new pheno. Be sure to set the data path at the top of the notebook first\n",
"\n",
"import pandas as pd\n",
"filepath = data_path / 'phenome_alldata_synth.csv'\n",
"synth_allbatches = pd.read_csv(filepath)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save off the phenotypes values you plan to analyze so we can later condition the genotype synthesis\n",
"# with these values\n",
"\n",
"filename = data_path / 'pheno_and_covariates.csv'\n",
"pheno_analysis_df = pd.read_csv(filename)\n",
"pheno_seeds = list(pheno_analysis_df[\"pheno_and_cov\"])\n",
"\n",
"print(pheno_seeds)\n",
"\n",
"seeds_df = synth_allbatches.filter(pheno_seeds)\n",
"# The seeding won't work if there are any NaN's in the seedfile\n",
"seeds_df = seeds_df.fillna(0)\n",
"seeds_df.head()"
]
},
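{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional guard (a small sketch): DataFrame.filter silently drops any names it cannot find,\n",
"# so confirm that every entry in pheno_and_cov exists as a column in the synthetic data\n",
"# before writing the seed file.\n",
"\n",
"unmatched = [name for name in pheno_seeds if name not in synth_allbatches.columns]\n",
"if unmatched:\n",
"    print(\"Warning, these seed fields were not found in the synthetic data:\", unmatched)\n",
"else:\n",
"    print(\"All\", len(pheno_seeds), \"seed fields were found\")"
]
},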
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(seeds_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# When you create the seeds df, you must make sure that any rounding or casting to int\n",
"# is replicated when creating genome training files.\n",
"\n",
"seedfile = data_path / 'phenome_seeds.csv'\n",
"seeds_df.to_csv(seedfile, index=False, header=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"seeds_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# Generate report that shows the statistical performance between the training and synthetic data\n",
"# Use the synthetic batch that includes abBMD\n",
"\n",
"from smart_open import open\n",
"from IPython.core.display import display, HTML\n",
"\n",
"\n",
"# Change batch_num to any value between 0 and 6 to view performance report for other batches\n",
"batch_num = 0\n",
"display(HTML(data=open(models[0].get_artifact_link(\"report\")).read(), metadata=dict(isolated=True)))\n"
]
},
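{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: keep a local copy of the performance report for the selected batch. A convenience\n",
"# sketch; the output filename is arbitrary and not part of the original workflow.\n",
"\n",
"report_html = open(models[batch_num].get_artifact_link(\"report\")).read()\n",
"with open(f\"synthetic_report_batch{batch_num}.html\", \"w\") as out:\n",
"    out.write(report_html)"
]
},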
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}