{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UTRxpSlaczHY"
   },
   "source": [
    "# Create synthetic mouse genome data\n",
    "\n",
    "* Create a synthetic version of the mouse genomes from the original experiment. \n",
    "* To run this notebook, you will need an API key from the Gretel console,  at https://console.gretel.cloud. \n",
    "* This notebook will create synthetic data for all the batch training sets chosen in the previous notebook, 03_build_genome_training_data. \n",
    "* The most important variable to customize in this notebook is MAX_MODELS_RUNNING which will enable no more than that number of models to train or generate in parallel.\n",
    "* If you'd like to hold onto your models so you can later run the synthetic performance report, please see the\n",
    "notes in the function \"chk_project_completion\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VEM6kjRsczHd"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install -U gretel-client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZQ-TmAdwczHd"
   },
   "outputs": [],
   "source": [
    "# Specify your Gretel API key\n",
    "\n",
    "from getpass import getpass\n",
    "import pandas as pd\n",
    "from gretel_client import configure_session, ClientConfig\n",
    "\n",
    "pd.set_option('max_colwidth', None)\n",
    "\n",
    "configure_session(ClientConfig(api_key=getpass(prompt=\"Enter Gretel API key\"), \n",
    "                               endpoint=\"https://api.gretel.cloud\"))\n",
    "\n",
    "                            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pathlib\n",
    "import pandas as pd\n",
    "\n",
    "base_path = pathlib.Path(os.getcwd().replace(\"/synthetics\", \"\"))\n",
    "data_path = base_path / 'mice_data_set' / 'data'\n",
    "data_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read in the list of training batches\n",
    "* This file was created in the previous notebook, 03_build_genome_training_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = data_path / \"batch_training_list.csv\"\n",
    "training_list_df = pd.read_csv(filename)\n",
    "batches = list(training_list_df[\"batch\"])\n",
    "batch_cnt = len(batches)\n",
    "batch_cnt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure model hyper parameters\n",
    "* Load the default configuration template, then update the parameters to those chosen by an Optuna hyperparameter optimization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from smart_open import open\n",
    "import yaml\n",
    "\n",
    "# Read in the phenome seed df\n",
    "\n",
    "seedfile = str(data_path / 'phenome_seeds.csv')\n",
    "seed_df = pd.read_csv(seedfile)\n",
    "\n",
    "# Load the default synthetic config options\n",
    "with open(\"https://raw.githubusercontent.com/gretelai/gretel-blueprints/main/config_templates/gretel/synthetics/default.yml\", 'r') as stream:\n",
    "    config = yaml.safe_load(stream)\n",
    "\n",
    "# Add a seed task which will enable us to tie together the genome and phenome data\n",
    "\n",
    "filename = data_path / 'pheno_analysis.csv'\n",
    "pheno_analysis_df = pd.read_csv(filename)\n",
    "pheno_analysis = list(pheno_analysis_df[\"pheno\"])\n",
    "\n",
    "filename = data_path / 'pheno_and_covariates.csv'\n",
    "pheno_and_cov_df = pd.read_csv(filename)\n",
    "seed_fields = list(pheno_and_cov_df[\"pheno_and_cov\"])\n",
    "\n",
    "task = {\n",
    "    'type': 'seed',\n",
    "    'attrs': {\n",
    "        'fields': seed_fields\n",
    "    }\n",
    "}\n",
    "\n",
    "\n",
    "# Optimize parameters for complex dataset\n",
    "\n",
    "config['models'][0]['synthetics']['task'] = task\n",
    "config['models'][0]['synthetics']['params']['epochs'] = 150\n",
    "config['models'][0]['synthetics']['params']['vocab_size'] = 0\n",
    "config['models'][0]['synthetics']['params']['rnn_units'] = 256\n",
    "config['models'][0]['synthetics']['params']['reset_states'] = False\n",
    "config['models'][0]['synthetics']['params']['learning_rate'] = 0.001\n",
    "config['models'][0]['synthetics']['privacy_filters']['similarity'] = None\n",
    "config['models'][0]['synthetics']['params']['dropout_rate'] = 0.5\n",
    "config['models'][0]['synthetics']['params']['gen_temp'] = 1.0\n",
    "config['models'][0]['synthetics']['generate']['num_records'] = len(seed_df)\n",
    "config['models'][0]['synthetics']['generate']['max_invalid'] = len(seed_df) * 10\n"
   ]
  },
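  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optionally, sanity-check the configuration\n",
    "* A quick check, not part of the original flow, that echoes the synthetics settings applied above so a typo in a config key surfaces before any models are launched. It assumes the config cell above has run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: echo the parameters set above before training anything\n",
    "synth_section = config['models'][0]['synthetics']\n",
    "print(json.dumps(synth_section['params'], indent=2))\n",
    "print(\"Seed fields:\", synth_section['task']['attrs']['fields'])\n",
    "print(\"Records to generate per batch:\", synth_section['generate']['num_records'])"
   ]
  },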
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_fields"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(seed_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get set up for training genome batches in parallel\n",
    "* Be sure to modify the MAX_MODELS_RUNNING to an appropriate number"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gretel_client import create_project\n",
    "\n",
    "# Estabilish some variables that will enable the running of models in parallel\n",
    "\n",
    "MAX_MODELS_RUNNING = 58\n",
    "MAX_MODEL_CNT_IN_PROJECT = 8\n",
    "total_model_cnt = batch_cnt\n",
    "models_training = []\n",
    "models_generating = []\n",
    "models_complete = []\n",
    "models_training_cnt = 0\n",
    "models_generating_cnt = 0\n",
    "models_complete_cnt = 0\n",
    "models_error_cnt = 0\n",
    "model_cnt_in_project = 0\n",
    "project_to_model_mapping = {}\n",
    "project_num = 0\n",
    "batch_num = 0\n",
    "base_project_name = \"Illumina Genome Batch \"\n",
    "moreToDo = True\n",
    "model_info = {}\n",
    "gwas_errors = 0\n",
    "gwas_error_batches = []\n",
    "\n",
    "# Initialize first project\n",
    "\n",
    "project_name = base_project_name + str(project_num)\n",
    "current_project = create_project(display_name=project_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optionally, start where you left off\n",
    "* Run this cell if the notebook stopped before all jobs were completed and you want to\n",
    "pick up where you left off"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this cell if the notebook stopped before all jobs were completed and you want to\n",
    "# pick up where you left off\n",
    "\n",
    "filename_for_stats =  data_path / \"Completed_genome_batch_stats.csv\"\n",
    "results_df = pd.read_csv(filename_for_stats)\n",
    "for i in range(len(results_df)):\n",
    "    model_name = results_df.loc[i][\"model_name\"]\n",
    "    filename = results_df.loc[i][\"filename\"]\n",
    "    sqs = results_df.loc[i][\"sqs\"]\n",
    "    f1 = results_df.loc[i][\"F1\"]\n",
    "    model_info[model_name] = {}\n",
    "    model_info[model_name][\"sqs\"] = sqs\n",
    "    model_info[model_name][\"F1\"] = f1\n",
    "    model_info[model_name][\"filename\"] = filename\n",
    "    models_complete.append(model_name)\n",
    "    models_complete_cnt += 1\n",
    "    model_num = int(model_name[5:])\n",
    "    batches.remove(model_num)\n",
    "    if \"gwas_error\" in results_df.loc[i]:\n",
    "        gwas_err = results_df.loc[i][\"gwas_error\"]\n",
    "        if gwas_err:\n",
    "            gwas_errors += 1\n",
    "            gwas_error_batches.append(model_name)\n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions for training and generating multiple models in parallel"
   ]
  },
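  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* The functions below guard each Gretel API call with a paired try/except that retries the call once on failure.\n",
    "* The sketch below illustrates that idiom as a generic helper; retry_once is hypothetical (not part of gretel_client), shown only to document the pattern the functions use inline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def retry_once(fn, *args, delay=5, **kwargs):\n",
    "    # Hypothetical helper: call fn and, on any exception, wait briefly\n",
    "    # and try exactly once more, mirroring the try/except pairs below\n",
    "    try:\n",
    "        return fn(*args, **kwargs)\n",
    "    except Exception:\n",
    "        time.sleep(delay)\n",
    "        return fn(*args, **kwargs)\n",
    "\n",
    "# Example usage (equivalent to the paired calls in the functions below):\n",
    "# artifact_id = retry_once(current_project.upload_artifact, filepath)"
   ]
  },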
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "def chk_model_completion():\n",
    "    \n",
    "    global models_training\n",
    "    global models_generating\n",
    "    global model_info\n",
    "    global models_complete\n",
    "    global models_training_cnt\n",
    "    global models_generating_cnt\n",
    "    global models_complete_cnt\n",
    "    global models_error_cnt\n",
    "    global model_cnt_in_project\n",
    "    global project_num\n",
    "    global current_project\n",
    "    global project_to_model_mapping\n",
    "    global project_name\n",
    "    \n",
    "    for model_name in models_training:\n",
    "        \n",
    "        model = model_info[model_name][\"model\"]\n",
    "        try:\n",
    "            model._poll_job_endpoint()\n",
    "        except:\n",
    "            model._poll_job_endpoint()\n",
    "        status = model.__dict__['_data']['model']['status']\n",
    "        print(\"Model \" + model_name + \" has training status: \" + status)\n",
    "        \n",
    "        # If model ended in error or was lost, restart it\n",
    "        if ((status == \"error\") or (status == \"lost\")):\n",
    "            \n",
    "            # Check if we need a new project\n",
    "            if model_cnt_in_project >= MAX_MODEL_CNT_IN_PROJECT:\n",
    "                project_num += 1\n",
    "                project_name = base_project_name + str(project_num)\n",
    "                try:\n",
    "                    current_project = create_project(display_name=project_name)\n",
    "                except:\n",
    "                    current_project = create_project(display_name=project_name)\n",
    "                model_cnt_in_project = 0\n",
    "            \n",
    "            # Start a new model\n",
    "            \n",
    "            filename = model_info[model_name][\"filename\"]\n",
    "            filepath = data_path / \"genome_training_data\" / filename\n",
    "            try:\n",
    "                artifact_id = current_project.upload_artifact(filepath)\n",
    "            except:\n",
    "                artifact_id = current_project.upload_artifact(filepath)\n",
    "            \n",
    "            try:\n",
    "                model = current_project.create_model_obj(model_config=config, data_source=artifact_id)\n",
    "            except:\n",
    "                model = current_project.create_model_obj(model_config=config, data_source=artifact_id)\n",
    "            \n",
    "            model_info[model_name][\"model\"] = model\n",
    "            \n",
    "            try:\n",
    "                model.submit()\n",
    "            except:\n",
    "                model.submit()\n",
    "                \n",
    "            model_cnt_in_project += 1           \n",
    "            print(\"Model restarted training due to error: \" + str(model_name))\n",
    "            models_error_cnt += 1\n",
    "            if project_name in project_to_model_mapping:\n",
    "                project_to_model_mapping[project_name][\"models\"].append(model_name)\n",
    "            else:\n",
    "                project_to_model_mapping[project_name] = {}\n",
    "                project_to_model_mapping[project_name][\"models\"] = [model_name]\n",
    "                project_to_model_mapping[project_name][\"project\"] = current_project\n",
    "                project_to_model_mapping[project_name][\"status\"] = \"active\"\n",
    "            \n",
    "        # If completed, get SQS and start generating records from seeds\n",
    "        if (status == 'completed'): \n",
    "            models_training.remove(model_name)\n",
    "            models_training_cnt -= 1\n",
    "            report = model.peek_report()\n",
    "            if report:\n",
    "                sqs = report['synthetic_data_quality_score']['score']\n",
    "                model_info[model_name][\"sqs\"] = sqs  \n",
    "                print(\"Model \" + str(model_name) + \" has SQS: \" + str(sqs))\n",
    "            \n",
    "            # Generate more records using seeds\n",
    "            \n",
    "            print(\"Model started generating: \" + model_name)\n",
    "            try:\n",
    "                rh = model.create_record_handler_obj(data_source=seedfile, params={\"num_records\": len(seed_df)})\n",
    "                rh.submit_cloud()\n",
    "            except:\n",
    "                rh = model.create_record_handler_obj(data_source=seedfile, params={\"num_records\": len(seed_df)})\n",
    "                rh.submit_cloud()\n",
    "            model_info[model_name][\"rh\"] = rh\n",
    "            models_generating.append(model_name)\n",
    "            models_generating_cnt += 1\n",
    "            \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chk_model_generation_completion():\n",
    "\n",
    "    import numpy as np\n",
    "    global models_generating\n",
    "    global models_training\n",
    "    global model_info\n",
    "    global models_complete\n",
    "    global models_generating_cnt\n",
    "    global models_training_cnt\n",
    "    global models_complete_cnt\n",
    "    global models_error_cnt\n",
    "    global model_cnt_in_project\n",
    "    global project_num\n",
    "    global current_project\n",
    "    global project_to_model_mapping\n",
    "    global project_name\n",
    "    global gwas_errors\n",
    "    global gwas_error_batches\n",
    "    \n",
    "    for model_name in models_generating:\n",
    "        rh = model_info[model_name][\"rh\"]\n",
    "        try:\n",
    "            rh._poll_job_endpoint()\n",
    "        except:\n",
    "            rh._poll_job_endpoint()\n",
    "        status = rh.__dict__['_data']['handler']['status']\n",
    "        print(\"Model \" + model_name + \" has generating status: \" + status)\n",
    "\n",
    "        # If generation ends in error, restart by training a fresh model\n",
    "        \n",
    "        if ((status == \"error\") or (status == \"lost\")):\n",
    "            \n",
    "            # Check if we need a new project\n",
    "            if model_cnt_in_project >= MAX_MODEL_CNT_IN_PROJECT:\n",
    "                project_num += 1\n",
    "                project_name = base_project_name + str(project_num)\n",
    "                try:\n",
    "                    current_project = create_project(display_name=project_name)\n",
    "                except:\n",
    "                    current_project = create_project(display_name=project_name)\n",
    "                model_cnt_in_project = 0\n",
    "            \n",
    "            # Start a new model\n",
    "            filename = model_info[model_name][\"filename\"]\n",
    "            filepath = data_path / \"genome_training_data\" / filename\n",
    "            try:\n",
    "                artifact_id = current_project.upload_artifact(filepath)\n",
    "            except:\n",
    "                artifact_id = current_project.upload_artifact(filepath)\n",
    "            \n",
    "            \n",
    "            try:\n",
    "                model = current_project.create_model_obj(model_config=config, data_source=artifact_id)\n",
    "            except:\n",
    "                model = current_project.create_model_obj(model_config=config, data_source=artifact_id)\n",
    "\n",
    "            model_info[model_name][\"model\"] = model\n",
    "            try:\n",
    "                model.submit()\n",
    "            except:\n",
    "                model.submit()\n",
    "            model_cnt_in_project += 1           \n",
    "            print(\"Model restarted training due to error in generation: \" + str(model_name))\n",
    "            models_error_cnt += 1\n",
    "                  \n",
    "            models_generating.remove(model_name)\n",
    "            models_training.append(model_name)\n",
    "            models_generating_cnt -= 1\n",
    "            models_training_cnt += 1\n",
    "            \n",
    "            if project_name in project_to_model_mapping:\n",
    "                project_to_model_mapping[project_name][\"models\"].append(model_name)\n",
    "            else:\n",
    "                project_to_model_mapping[project_name] = {}\n",
    "                project_to_model_mapping[project_name][\"models\"] = [model_name]\n",
    "                project_to_model_mapping[project_name][\"project\"] = current_project\n",
    "                project_to_model_mapping[project_name][\"status\"] = \"active\"\n",
    "            \n",
    "        if status == \"completed\":\n",
    "            \n",
    "            models_generating.remove(model_name)\n",
    "            models_generating_cnt -= 1\n",
    "            models_complete.append(model_name)\n",
    "            models_complete_cnt += 1\n",
    "         \n",
    "            synthetic_genomes = pd.read_csv(rh.get_artifact_link(\"data\"), compression='gzip') \n",
    "            \n",
    "            # Drop the phenome information from the genome synth data and add back in the fields \"id\" and \"discard\"\n",
    "\n",
    "            id_col = []\n",
    "            discard_col = []\n",
    "            for i in range(len(synthetic_genomes.index)):\n",
    "                id_col.append(i)\n",
    "                discard_col.append(\"no\")\n",
    "\n",
    "            synthetic_genomes = synthetic_genomes.drop(seed_fields, axis=1)\n",
    "            \n",
    "            columns = ['id', 'discard']\n",
    "            columns = columns + list(synthetic_genomes.columns)   \n",
    "            synthetic_genomes[\"id\"] = id_col\n",
    "            synthetic_genomes[\"discard\"] = discard_col\n",
    "            synthetic_genomes = synthetic_genomes.filter(columns)\n",
    "    \n",
    "            # Save the synthetic data\n",
    "        \n",
    "            filename = \"synthetic_genomes_\" + model_name + \".txt\"  \n",
    "            filepath = data_path / \"synthetic_genome_data\" / filename\n",
    "            synthetic_genomes.to_csv(filepath, index=False, sep=' ')  \n",
    "            print(\"Synthetic data for \" + model_name + \" saved to: \" + filename)\n",
    "            \n",
    "            # Compute an initial GWAS F1\n",
    "            \n",
    "            try:\n",
    "                F1s = computeF1(model_name)\n",
    "                print(\"GWAS F1 score fors \" + model_name + \" are: \")\n",
    "                for i, next_pheno in enumerate(pheno_analysis):\n",
    "                    print(\"\\t\" + next_pheno + \": \" + str(F1s[i]))\n",
    "                model_info[model_name][\"F1\"] = np.mean(F1s)\n",
    "                    \n",
    "            except:\n",
    "                F1s = []\n",
    "                print(\"GWAS error for model \" + model_name)\n",
    "                gwas_errors += 1\n",
    "                gwas_error_batches.append(model_name)\n",
    "                model_info[model_name][\"F1\"] = 0\n",
    "                \n",
    "            \n",
    "            \n",
    "                                                   \n",
    "          "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def chk_project_completion():\n",
    "    \n",
    "    global models_generating\n",
    "    global models_training\n",
    "    global project_to_model_mapping\n",
    "\n",
    "    for next in project_to_model_mapping:\n",
    "        if project_to_model_mapping[next][\"status\"] == \"active\":\n",
    "            project_active = False\n",
    "            for next_model in project_to_model_mapping[next][\"models\"]:\n",
    "                if ((next_model in models_generating) or (next_model in models_training)):\n",
    "                    project_active = True\n",
    "            if (project_active == False):\n",
    "                this_project = project_to_model_mapping[next][\"project\"]\n",
    "                \n",
    "                # Note the below line is what you'd comment out if you'd like to hold onto your\n",
    "                # synthetic models and later run the synthetic performance report. This means you'll\n",
    "                # have to manually go into the Gretel Console and delete unneeded projects, which can\n",
    "                # grow quickly if you're processing all genomic batches\n",
    "                \n",
    "                this_project.delete()\n",
    "                project_to_model_mapping[next][\"status\"] = \"completed\"\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def start_more_models():\n",
    "    \n",
    "    import time\n",
    "    \n",
    "    global models_training\n",
    "    global models_generating\n",
    "    global model_info\n",
    "    global models_training_cnt\n",
    "    global models_generating_cnt\n",
    "    global models_complete_cnt\n",
    "    global model_cnt_in_project\n",
    "    global current_project\n",
    "    global filelist\n",
    "    global batch_num\n",
    "    global project_num\n",
    "    global project_to_model_mapping\n",
    "    global project_name\n",
    "    \n",
    "    while (((models_training_cnt + models_generating_cnt) < MAX_MODELS_RUNNING) and \n",
    "          ((models_training_cnt + models_generating_cnt + models_complete_cnt) < total_model_cnt)):\n",
    "        \n",
    "        # Check if we need a new project\n",
    "        if model_cnt_in_project >= MAX_MODEL_CNT_IN_PROJECT:\n",
    "            project_num += 1\n",
    "            project_name = base_project_name + str(project_num)\n",
    "            try:\n",
    "                current_project = create_project(display_name=project_name)\n",
    "            except:\n",
    "                current_project = create_project(display_name=project_name)\n",
    "            model_cnt_in_project = 0\n",
    "            \n",
    "        # Start a new model\n",
    "\n",
    "        batch = batches[batch_num]\n",
    "        batch_num += 1\n",
    "        filename = \"geno_batch\" + str(batch) + \"_train.csv\"\n",
    "        filepath = data_path / \"genome_training_data\" / filename\n",
    "        df = pd.read_csv(filepath)\n",
    "        cluster_size = len(df.columns)\n",
    "        config['models'][0]['synthetics']['params']['field_cluster_size'] = cluster_size\n",
    "        \n",
    "        try:\n",
    "            artifact_id = current_project.upload_artifact(filepath)\n",
    "        except:\n",
    "            artifact_id = current_project.upload_artifact(filepath)\n",
    "          \n",
    "        try:\n",
    "            model = current_project.create_model_obj(model_config=config, data_source=artifact_id)\n",
    "        except:\n",
    "            model = current_project.create_model_obj(model_config=config, data_source=artifact_id)\n",
    "        model_name = \"batch\" + str(batch)\n",
    "        \n",
    "        models_training.append(model_name)\n",
    "        models_training_cnt += 1\n",
    "        model_info[model_name] = {}\n",
    "        model_info[model_name][\"model\"] = model\n",
    "        model_info[model_name][\"filename\"] = filename\n",
    "        try:\n",
    "            model.submit()\n",
    "        except:\n",
    "            model.submit()\n",
    "        model_cnt_in_project += 1\n",
    "        print(\"Model started training: \" + str(model_name))\n",
    "        \n",
    "        if project_name in project_to_model_mapping:\n",
    "            project_to_model_mapping[project_name][\"models\"].append(model_name)\n",
    "        else:\n",
    "            project_to_model_mapping[project_name] = {}\n",
    "            project_to_model_mapping[project_name][\"models\"] = [model_name]\n",
    "            project_to_model_mapping[project_name][\"project\"] = current_project\n",
    "            project_to_model_mapping[project_name][\"status\"] = \"active\"\n",
    "            \n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def computeF1(model_name):\n",
    "    \n",
    "    from sklearn.metrics import f1_score\n",
    "\n",
    "    base_path = pathlib.Path(os.getcwd().replace(\"/synthetics\", \"\"))\n",
    "    data_path = base_path / 'mice_data_set' / 'data' \n",
    "    real_gwas_path = base_path / 'mice_data_set' / 'out' \n",
    "    synthetic_gwas_path = base_path / 'mice_data_set' / 'out_synth_working'\n",
    "    \n",
    "    # Copy the relevent synth and map files to the files gwas uses\n",
    "    filename = \"synthetic_genomes_\" + model_name + \".txt\"  \n",
    "    filepath = data_path / \"synthetic_genome_data\" / filename\n",
    "    synth_df = pd.read_csv(filepath, sep=' ')\n",
    "    newfile = data_path / \"synthetic_genomes.txt\"\n",
    "    synth_df.to_csv(newfile, index=False, sep=' ')  \n",
    "    filename = \"map_\" + model_name + \".txt\"  \n",
    "    filepath = data_path / \"genome_map_data\" / filename\n",
    "    map_df = pd.read_csv(filepath, sep=' ')\n",
    "    newfile = data_path / \"map_batch.txt\"\n",
    "    map_df.to_csv(newfile, index=False, sep=' ')\n",
    "   \n",
    "    # Run GWAS\n",
    "    !rm ../mice_data_set/out_synth/*.csv\n",
    "    !R --vanilla < ../research_paper_code/src/map_gwas_batch.R &> /tmp/R.log  \n",
    "    \n",
    "    filename = data_path / 'pheno_analysis.csv'\n",
    "    pheno_analysis_df = pd.read_csv(filename)\n",
    "    pheno_analysis = list(pheno_analysis_df[\"pheno\"])\n",
    "    \n",
    "    all_f1s = []\n",
    "\n",
    "    for phenotype in pheno_analysis:\n",
    "        \n",
    "        # Read in the new results\n",
    "\n",
    "        try:\n",
    "            synthetic_snps = pd.read_csv(synthetic_gwas_path / f'lm_{phenotype}.csv')  \n",
    "        except:\n",
    "            !R --vanilla < ../research_paper_code/src/run_map_batch1.R &> /tmp/R.log\n",
    "            synthetic_snps = pd.read_csv(synthetic_gwas_path / f'lm_{phenotype}.csv')\n",
    "        \n",
    "        synthetic_snps = synthetic_snps.rename(columns={synthetic_snps.columns[0]: 'index'})\n",
    "        synthetic_snps = synthetic_snps[['index', 'snp', 'p']]\n",
    "        synthetic_snps['interest'] = synthetic_snps['p'].apply(lambda x: True if x <= 5e-8 else False)\n",
    "         \n",
    "        # Read in the original results\n",
    "\n",
    "        real_snps = pd.read_csv(real_gwas_path / f'lm_{phenotype}_1_79646.csv') #, usecols=['snp', 'p']) # , usecols=['snp', 'p']\n",
    "        real_snps = real_snps.rename(columns={real_snps.columns[0]: 'index'})\n",
    "        real_snps = real_snps[['index', 'snp', 'p']]\n",
    "        real_snps['interest'] = real_snps['p'].apply(lambda x: True if x <= 5e-8 else False)\n",
    "    \n",
    "\n",
    "        combined = pd.merge(synthetic_snps, \n",
    "             real_snps, \n",
    "             how='inner', \n",
    "             on=['snp'],\n",
    "             suffixes=['_synthetic', '_real'])\n",
    "    \n",
    "        f1 = round(f1_score(combined['interest_real'], combined['interest_synthetic'], average='weighted'), 4)\n",
    "\n",
    "        all_f1s.append(f1)\n",
    "    \n",
    "    return all_f1s"
   ]
  },
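  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* computeF1 flags a SNP as interesting when its GWAS p-value clears the genome-wide significance threshold of 5e-8, then scores how well the synthetic flags agree with the real ones.\n",
    "* The toy example below (made-up p-values, not real GWAS output) shows that scoring step in isolation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Toy p-values for five SNPs (hypothetical, for illustration only)\n",
    "toy = pd.DataFrame({\n",
    "    'snp': ['rs1', 'rs2', 'rs3', 'rs4', 'rs5'],\n",
    "    'p_real': [1e-9, 0.2, 3e-8, 0.7, 1e-10],\n",
    "    'p_synthetic': [2e-9, 0.4, 0.05, 0.6, 5e-9],\n",
    "})\n",
    "toy['interest_real'] = toy['p_real'] <= 5e-8\n",
    "toy['interest_synthetic'] = toy['p_synthetic'] <= 5e-8\n",
    "\n",
    "# Weighted F1, matching the call in computeF1 above\n",
    "round(f1_score(toy['interest_real'], toy['interest_synthetic'], average='weighted'), 4)"
   ]
  },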
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_model_stats(filename):\n",
    " \n",
    "    global models_completed\n",
    "    global model_info\n",
    "    global gwas_error_batches\n",
    "    \n",
    "    model_names = []\n",
    "    filenames = []\n",
    "    sqss = []\n",
    "    F1s = []\n",
    "    gwas_errors = []\n",
    "    \n",
    "    for model_name in models_complete:\n",
    "        model_names.append(model_name)\n",
    "        filenames.append(model_info[model_name][\"filename\"])\n",
    "        if \"sqs\" in model_info[model_name]:\n",
    "            sqss.append(model_info[model_name][\"sqs\"])\n",
    "        else:\n",
    "            sqss.append(None)\n",
    "        if \"F1\" in model_info[model_name]:\n",
    "            F1s.append(model_info[model_name][\"F1\"])\n",
    "        else:\n",
    "            F1s.append(None)\n",
    "        if model_name in gwas_error_batches:\n",
    "            gwas_errors.append(True)\n",
    "        else:\n",
    "            gwas_errors.append(False)\n",
    "            \n",
    "    results_df = pd.DataFrame({\"model_name\": model_names, \"filename\": filenames,\n",
    "                              \"sqs\": sqss, \"F1\": F1s, \"gwas_error\": gwas_errors})\n",
    "    \n",
    "    results_df.to_csv(filename, index=False, header=True)\n",
    "    print(\"Completed model stats saved to: \" + str(filename))\n",
    "    \n",
    "    return results_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start training synthetic models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "global pheno_analysis\n",
    "pass_num = 0\n",
    "\n",
    "starttime = time.time()\n",
    "\n",
    "while moreToDo:\n",
    "    \n",
    "    pass_num += 1\n",
    "    print()\n",
    "    print_pass = \"************************************** PASS \" + str(pass_num) + \" **************************************\"\n",
    "    print(print_pass)\n",
    "    print(\"Models training: \" + str(models_training_cnt))\n",
    "    print(\"Models generating: \" + str(models_generating_cnt))\n",
    "    print(\"Models complete: \" + str(len(models_complete)))\n",
    "    still_to_start = total_model_cnt - models_training_cnt - models_generating_cnt - len(models_complete)\n",
    "    print(\"Models still to start: \" + str(still_to_start))\n",
    "    print(\"Training errors encountered: \" + str(models_error_cnt))\n",
    "    print(\"GWAS errors encountered: \" + str(gwas_errors))\n",
    "    print()\n",
    "    \n",
    "    # Check for model completion\n",
    "    chk_model_completion()\n",
    "    \n",
    "    # Check for generation completion\n",
    "    chk_model_generation_completion()\n",
    "    \n",
    "    # Check for project completion\n",
    "    chk_project_completion()\n",
    "    \n",
    "    # Start more models if room\n",
    "    start_more_models()\n",
    "    \n",
    "    # Gather complete model stats and save to file\n",
    "    if len(models_complete) > 0:\n",
    "        filename_for_stats =  data_path / \"Completed_genome_batch_stats.csv\"\n",
    "        results_df = save_model_stats(filename_for_stats)\n",
    "    \n",
    "    # Check if we're all done\n",
    "    if models_complete_cnt == total_model_cnt:\n",
    "        moreToDo = False\n",
    "    \n",
    "    # Sleep for 1  to 5 minutes, adjust to value desired\n",
    "    time.sleep(60)\n",
    "    \n",
    "endtime = time.time() \n",
    "exectime = endtime - starttime\n",
    "exectime_min = round((exectime / 60), 2)\n",
    "print()\n",
    "print(\"************************************** MODELS ALL COMPLETE **************************************\")\n",
    "print(str(total_model_cnt) + \" models completed in \" + str(exectime_min) + \" minutes\")\n",
    "print(\"A maximum of \" + str(MAX_MODELS_RUNNING) + \" were allowed to run in parallel\")\n",
    "print(str(models_error_cnt) + \" errors occurred which resulted in model retraining\")\n",
    "avg_sec_per_model = exectime / total_model_cnt\n",
    "avg_min_per_model = round((avg_sec_per_model / 60), 2)\n",
    "print(\"Each model took an average of \" + str(avg_min_per_model) + \" minutes\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# View all model results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first models trained all have interesting assocations in them, if you chose batches that way\n",
    "results_df.head(40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The last models trained all have do not have interesting assocations in them\n",
    "results_df.tail(40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# What is the average overall F1?\n",
    "import numpy as np\n",
    "np.mean(results_df[results_df[\"gwas_error\"] == False][\"F1\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# What is the average F1 for just batches with interesting associations\n",
    "# This cell is only usable is your model batch list contains the interesting associations at the top\n",
    "\n",
    "import numpy as np\n",
    "pos_batch_cnt = 40\n",
    "np.mean(results_df[\"F1\"].head(pos_batch_cnt))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "fig = px.histogram(results_df, x=\"F1\", title=\"Synthetic Genome Batch GWAS F1 Scores\",)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "69XYfU9k7fq4"
   },
   "source": [
    "# View the synthetic data quality reports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zX8qsizqczHg",
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Generate report that shows the statistical performance between the training and synthetic data\n",
    "# Note, for this to work you would have had to comment out the deletion of projects in function \n",
    "# chk_project_completion\n",
    "\n",
    "from smart_open import open\n",
    "from IPython.core.display import display, HTML\n",
    "\n",
    "# Change model_name to one from the list chosen above:\n",
    "model_name = \"batch1920\"\n",
    "display(HTML(data=open(model_info[model_name][\"model\"].get_artifact_link(\"report\")).read(), metadata=dict(isolated=True)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optionally save off results to another safe location"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename_for_stats =  data_path / \"synthetic_genome_allbatches_results.csv\"\n",
    "results_df.to_csv(filename_for_stats, index=False, header=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combine the synthetic genome batches \n",
    "* Cell to combine the individual synthetic genome batch files into one file where the ID in the synthetic genome data corresponds to the ID in the synthetic phenome data\n",
    "* Be sure to set the final filename appropriately, as you'll need to reuse it in the 05 notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First combine all the synthetic batch results into one dataframe.  This cell can take a while\n",
    "\n",
    "model_name = models_complete[0]\n",
    "filename = \"synthetic_genomes_\" + model_name + \".txt\"  \n",
    "filepath = data_path / \"synthetic_genome_data\" / filename \n",
    "synthetic_genomes = pd.read_csv(filepath, sep=' ')\n",
    "min_file_length = len(synthetic_genomes.index)\n",
    "\n",
    "batch_num = 1\n",
    "while batch_num < len(models_complete):\n",
    "    model_name = models_complete[batch_num]\n",
    "    print(batch_num)\n",
    "    batch_num += 1 \n",
    "    filename = \"synthetic_genomes_\" + model_name + \".txt\"  \n",
    "    filepath = data_path / \"synthetic_genome_data\" / filename \n",
    "    synthetic_genomes_batch = pd.read_csv(filepath, sep=' ')\n",
    "    if (len(synthetic_genomes_batch.index) < min_file_length):\n",
    "        min_file_length = len(synthetic_genomes_batch.index)\n",
    "    synthetic_genomes_batch = synthetic_genomes_batch.drop(['id', 'discard'], axis=1)\n",
    "    synthetic_genomes = pd.concat([synthetic_genomes, synthetic_genomes_batch], axis=1)\n",
    "            \n",
    "synthetic_genomes = synthetic_genomes.dropna().reindex\n",
    "synthetic_genomes= synthetic_genomes.astype({'id': 'int32'})"
   ]
  },
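  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* The concatenation above is positional, so every batch file should contain the same number of rows; the dropna() then discards any row a shorter batch could not fill.\n",
    "* The optional check below reports any batch whose row count differs from the first. It re-reads every batch file, so it can take a while."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional check: report any synthetic batch file whose row count differs from the first\n",
    "expected_rows = None\n",
    "for model_name in models_complete:\n",
    "    filename = \"synthetic_genomes_\" + model_name + \".txt\"\n",
    "    filepath = data_path / \"synthetic_genome_data\" / filename\n",
    "    n_rows = len(pd.read_csv(filepath, sep=' ').index)\n",
    "    if expected_rows is None:\n",
    "        expected_rows = n_rows\n",
    "    elif n_rows != expected_rows:\n",
    "        print(model_name + \": \" + str(n_rows) + \" rows (expected \" + str(expected_rows) + \")\")"
   ]
  },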
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now create of version of map.txt with just the SNPs in this final synthetic dataset\n",
    "# This will be used in the map code to run GWAS one more time on the overall synthetic dataset\n",
    "# NOTE: IT'S IMPORTANT TO REMEMBER THE NAME YOU CHOOSE FOR THE MAP FILE.  YOU WILL USE IT IN FURTHER NOTEBOOKS\n",
    "\n",
    "mapfile = data_path / \"map.txt\"\n",
    "mapdata = pd.read_csv(mapfile, sep=' ')\n",
    "\n",
    "snps = list(synthetic_genomes.columns)\n",
    "snps.remove(\"id\")\n",
    "snps.remove(\"discard\")\n",
    "\n",
    "mapdata_use = mapdata[mapdata[\"id\"].isin(snps)]\n",
    "filename = \"map_allbatches_abBMD.txt\"\n",
    "mapfile_new = data_path / \"genome_map_data\" / filename\n",
    "mapdata_use.to_csv(mapfile_new, sep=' ', header=True, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the synthetic genome file in chromosome order\n",
    "map_ids = list(mapdata_use[\"id\"])\n",
    "col_use = [\"id\", \"discard\"]\n",
    "col_use = col_use + map_ids\n",
    "synthetic_genomes = synthetic_genomes.filter(col_use)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Modify the filename to be what you want\n",
    "# NOTE: IT'S IMPORTANT TO REMEMBER THE NAME YOU CHOOSE FOR THE GENOME FILE.  YOU WILL USE IT IN FURTHER NOTEBOOKS\n",
    "\n",
    "filename = \"synthetic_genomes_allbatches_abBMD.txt\"  \n",
    "filepath = data_path / \"synthetic_genome_data\" / filename \n",
    "synthetic_genomes = synthetic_genomes.fillna(0)\n",
    "synthetic_genomes.to_csv(filepath, index=False, header=True, sep=' ')"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Create synthetic data from a DataFrame or CSV",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}