[0ad989]: / train_neural_network / agg_fig_data.ipynb

Download this file

373 lines (372 with data), 23.4 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "78ad58c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import anndata as ad\n",
    "import os\n",
    "import scipy\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sys\n",
    "import multivelo as mv\n",
    "import scanpy as sc\n",
    "import scvelo as scv\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# add the parent directory (the one holding ../Examples) to the import path;\n",
    "# note that \"/..\" would resolve to the filesystem root. This runs after the\n",
    "# imports above, so it only affects modules imported later.\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "923af518-b096-4e63-94e5-06f08c944fee",
   "metadata": {},
   "source": [
    "## Read in the Appropriate Data"
   ]
  },
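  {
   "cell_type": "markdown",
   "id": "51361505-5c31-4f11-9212-9a5cfab0f6d1",
   "metadata": {},
   "source": [
    "The top cell below loads the training data (the MultiVelo results used to generate the figures in the paper) and is active by default. To aggregate validation data instead, uncomment the bottom cell; it runs after the top cell, so it replaces `figs` with your validation data. (You will need to supply your own AnnData object of validation data.)"
   ]
  },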
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ebbb4e15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# training data: the MultiVelo results used to generate the figures in\n",
    "# the paper, read in figure order (3, 4, 5, 6)\n",
    "fig_paths = [\n",
    "    \"../Examples/multivelo_result_fig3.h5ad\",\n",
    "    \"../Examples/multivelo_result_fig4.h5ad\",\n",
    "    \"../Examples/multivelo_result_fig5.h5ad\",\n",
    "    \"../Examples/multivelo_result_fig6.h5ad\",\n",
    "]\n",
    "fig3, fig4, fig5, fig6 = [sc.read_h5ad(path) for path in fig_paths]\n",
    "\n",
    "figs = [fig3, fig4, fig5, fig6]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "13dca1bb-4ab9-4168-a696-fa5e3bb305d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# validation data\n",
    "# You will need to supply your own validation data in order to run this code!\n",
    "# val_data = sc.read_h5ad(\"path/to/your_validation_data.h5ad\")\n",
    "# figs = [val_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "46b3767a-63fa-43c7-ad58-4af3b67357fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of AnnData objects to process; the loops below iterate over range(N)\n",
    "N = len(figs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34e2fb82-18e3-4fd0-922a-a0187ae68b2b",
   "metadata": {},
   "source": [
    "## Define Appropriate Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dc5347e3-21bf-4d87-94fa-ae96f264ace5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# graph a set of rate parameters\n",
    "def graph_params(alpha_c, alpha, beta, gamma):\n",
    "    \"\"\"Draw a 2D histogram for every ordered pair of distinct rate parameters.\n",
    "\n",
    "    The result is a 4 x 3 grid of panels. Row ``r`` puts parameter ``r``\n",
    "    (alpha_c, alpha, beta, gamma in that order) on the x-axis and each of\n",
    "    the other three parameters, in the same order, on the y-axis. Every\n",
    "    panel uses 10 bins, the \"Greens\" colormap and its own colorbar.\n",
    "    \"\"\"\n",
    "    labels = [\"alpha_c\", \"alpha\", \"beta\", \"gamma\"]\n",
    "\n",
    "    fig, axes = plt.subplots(4, 3, figsize=(12, 12))\n",
    "\n",
    "    # one row per parameter, stacked so rates[k, :] holds parameter k\n",
    "    rates = np.array([alpha_c, alpha, beta, gamma])\n",
    "\n",
    "    for row in range(4):\n",
    "        # every other parameter, skipping the one on this row's x-axis\n",
    "        partners = [k for k in range(4) if k != row]\n",
    "        for col, other in enumerate(partners):\n",
    "            ax = axes[row][col]\n",
    "            # hist2d returns (counts, xedges, yedges, image); the image feeds the colorbar\n",
    "            hist = ax.hist2d(np.ravel(rates[row, :]), np.ravel(rates[other, :]),\n",
    "                             bins=10, cmap=\"Greens\")\n",
    "            ax.set_xlabel(labels[row])\n",
    "            ax.set_ylabel(labels[other])\n",
    "            fig.colorbar(hist[3], ax=ax)\n",
    "\n",
    "    fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "936b562d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# a function for filtering out genes that fall below a certain fit_likelihood threshold\n",
    "def filter_likelihood(adata, thresh, bins=40, hist_range=(0, 0.2)):\n",
    "    \"\"\"Keep only the genes whose ``fit_likelihood`` is strictly above ``thresh``.\n",
    "\n",
    "    The AnnData shape is printed before and after filtering, and a histogram\n",
    "    of the per-gene likelihoods is drawn with ``plt.hist`` on the current\n",
    "    matplotlib axes (so calls made in the same cell overlay on one plot).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    adata : AnnData\n",
    "        Object whose ``.var`` has a ``'fit_likelihood'`` column.\n",
    "    thresh : float\n",
    "        Genes with a likelihood ``<= thresh`` are dropped.\n",
    "    bins : int, default 40\n",
    "        Number of histogram bins.\n",
    "    hist_range : tuple or None, default (0, 0.2)\n",
    "        Range of the histogram. Likelihoods outside it are not drawn (they\n",
    "        are still filtered normally); pass ``None`` to plot the full range.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    AnnData\n",
    "        ``adata[:, mask]``; anndata may hand this back as a view of\n",
    "        ``adata`` (TODO confirm whether a ``.copy()`` is wanted).\n",
    "    \"\"\"\n",
    "    # print out the shape of the AnnData object before filtering\n",
    "    print(\"Before shape:\", adata.shape)\n",
    "\n",
    "    # get the fit likelihood of each gene\n",
    "    likelihood = adata.var['fit_likelihood']\n",
    "\n",
    "    # draw (and label) a histogram of the fit likelihood\n",
    "    plt.hist(likelihood, bins=bins, range=hist_range)\n",
    "    plt.xlabel(\"fit_likelihood\")\n",
    "    plt.ylabel(\"number of genes\")\n",
    "\n",
    "    # do the filtering: keep genes strictly above the threshold\n",
    "    keep = likelihood > thresh\n",
    "    filtered = adata[:, keep]\n",
    "\n",
    "    # print out the shape of the AnnData object after filtering\n",
    "    print(\"After shape:\", filtered.shape)\n",
    "    print()\n",
    "\n",
    "    return filtered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a591b8fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# a function for choosing a subset of the data to use\n",
    "# like if we wanted to only save enough for a validation set, for example\n",
    "def subset_data(adata, n=None, random_state=None):\n",
    "    \"\"\"Return ``n`` randomly chosen genes (columns) of ``adata``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    adata : AnnData (or anything 2-D indexable as ``adata[:, idx]``)\n",
    "        Object to subset along axis 1.\n",
    "    n : int or None, default None\n",
    "        Number of genes to keep, sampled without replacement. ``None``\n",
    "        returns ``adata`` itself, unchanged.\n",
    "    random_state : int, numpy Generator or None, default None\n",
    "        Seed (or Generator) for a reproducible subset. ``None`` draws from\n",
    "        NumPy's global random state via ``np.random.choice``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``n`` is larger than the number of genes available.\n",
    "    \"\"\"\n",
    "    # if a subset size isn't specified just return the original dataset\n",
    "    if n is None:\n",
    "        return adata\n",
    "\n",
    "    total_genes = adata.shape[1]\n",
    "    if n > total_genes:\n",
    "        raise ValueError(\n",
    "            f\"cannot choose {n} genes from an object with only {total_genes}\"\n",
    "        )\n",
    "\n",
    "    # np.random.default_rng accepts either a seed or an existing Generator\n",
    "    rng = np.random if random_state is None else np.random.default_rng(random_state)\n",
    "\n",
    "    # make a random choice of column indices (no repeats)\n",
    "    choice = rng.choice(total_genes, size=n, replace=False)\n",
    "\n",
    "    # subset the AnnData object with our random set\n",
    "    return adata[:, choice]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "eef193fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# function responsible for saving the relevant data\n",
    "def write_files(outfile, adatas, graph=False):\n",
    "    \"\"\"Concatenate the fitted rate parameters of several AnnData objects and save them.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    outfile : str\n",
    "        Output *directory* (despite the name). One text file per rate\n",
    "        parameter -- alpha_c.txt, alpha.txt, beta.txt, gamma.txt -- is\n",
    "        written inside it with ``np.savetxt``. The directory is created\n",
    "        if it does not exist yet.\n",
    "    adatas : iterable of AnnData\n",
    "        Objects whose ``.var`` has the columns 'fit_alpha_c', 'fit_alpha',\n",
    "        'fit_beta' and 'fit_gamma'. Their genes are stacked in the given order.\n",
    "    graph : bool, default False\n",
    "        If True, plot the pairwise rate-parameter histograms with\n",
    "        ``graph_params`` before saving.\n",
    "    \"\"\"\n",
    "    # print the name of the file we're saving to\n",
    "    print(outfile)\n",
    "\n",
    "    # (output file stem, column in adata.var) for each rate parameter\n",
    "    rate_columns = [\n",
    "        (\"alpha_c\", \"fit_alpha_c\"),\n",
    "        (\"alpha\", \"fit_alpha\"),\n",
    "        (\"beta\", \"fit_beta\"),\n",
    "        (\"gamma\", \"fit_gamma\"),\n",
    "    ]\n",
    "\n",
    "    # materialise once so a generator of AnnData objects can be read per column\n",
    "    adatas = list(adatas)\n",
    "\n",
    "    # gather each object's column and concatenate once instead of re-copying a\n",
    "    # growing array; values are cast to float, like the empty np.array([]) case\n",
    "    rates = {}\n",
    "    for stem, column in rate_columns:\n",
    "        parts = [np.asarray(adata.var[column], dtype=float) for adata in adatas]\n",
    "        rates[stem] = np.concatenate(parts) if parts else np.array([])\n",
    "\n",
    "    # graph results if the user specifies it\n",
    "    if graph:\n",
    "        graph_params(rates[\"alpha_c\"], rates[\"alpha\"], rates[\"beta\"], rates[\"gamma\"])\n",
    "\n",
    "    # save all of the data, creating the output directory if needed\n",
    "    os.makedirs(outfile, exist_ok=True)\n",
    "    for stem, _ in rate_columns:\n",
    "        np.savetxt(os.path.join(outfile, stem + \".txt\"), rates[stem])\n",
    "\n",
    "    # print the total number of genes saved\n",
    "    print(\"Number of genes:\", rates[\"alpha_c\"].shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "780c9b89-5cf8-406d-adde-227cbd2af61b",
   "metadata": {},
   "source": [
    "## Process and Save Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e966528d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before shape: (3365, 865)\n",
      "After shape: (3365, 665)\n",
      "\n",
      "Before shape: (6436, 960)\n",
      "After shape: (6436, 771)\n",
      "\n",
      "Before shape: (11605, 936)\n",
      "After shape: (11605, 655)\n",
      "\n",
      "Before shape: (4693, 747)\n",
      "After shape: (4693, 507)\n",
      "\n",
      "The number of remaining genes is: 2598\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# filter out genes whose fit likelihood is not above\n",
    "# this specified value\n",
    "\n",
    "# value noted for the figure data:\n",
    "# likelihood_thresh = 0.0\n",
    "\n",
    "# value noted for the hspc data (the one currently active;\n",
    "# NOTE(review): this notebook loads the figure data -- confirm which is intended):\n",
    "likelihood_thresh = 0.025\n",
    "\n",
    "# keep a running total of the final number of genes\n",
    "new_gene_num = 0\n",
    "\n",
    "# For each AnnData object we're pulling data from,\n",
    "# filter out the bad data\n",
    "for i in range(N):\n",
    "\n",
    "    figs[i] = filter_likelihood(figs[i], likelihood_thresh)\n",
    "    # genes are the columns (axis 1) of an AnnData object; axis 0 indexes\n",
    "    # observations (cells), which is what shape[0] used to count\n",
    "    new_gene_num += figs[i].shape[1]\n",
    "\n",
    "print(\"The number of remaining genes is:\", new_gene_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "13590e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of genes to keep from each AnnData object. Set it (e.g. to 30)\n",
    "# when generating validation data; None keeps every gene.\n",
    "# subset_size = 30\n",
    "subset_size = None\n",
    "\n",
    "# pull a subset out of each AnnData object we're pulling data from\n",
    "test_figs = [subset_data(figs[i], subset_size) for i in range(N)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b86e54b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./data/rates/fig_rates_train\n",
      "Number of genes: 2598\n"
     ]
    }
   ],
   "source": [
    "# directory the aggregated rates are written to;\n",
    "# use \"./data/rates/val_rates\" when aggregating validation data\n",
    "output_dir = \"./data/rates/fig_rates_train\"\n",
    "write_files(output_dir, test_figs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}