855 lines (854 with data), 23.0 kB
{
"cells": [
{
"cell_type": "markdown",
"id": "c4025824",
"metadata": {},
"source": [
"# Prototype for Training NN to Invert ODE"
]
},
{
"cell_type": "markdown",
"id": "623533c9",
"metadata": {},
"source": [
"## Imports / Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9b3b7ca",
"metadata": {},
"outputs": [],
"source": [
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"\n",
"import os\n",
"import scipy\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sys\n",
"import scanpy as sc\n",
"import scvelo as scv\n",
"import matplotlib.pyplot as plt\n",
"import math\n",
"import random\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d505d5d",
"metadata": {},
"outputs": [],
"source": [
"# Use this for progress bar help: https://stackoverflow.com/questions/3160699/python-progress-bar\n",
"from time import sleep\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4b54a5d",
"metadata": {},
"outputs": [],
"source": [
"#import pytorch\n",
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d03f795-e504-4e59-9655-44e8cc62dc1d",
"metadata": {},
"outputs": [],
"source": [
"# Seed every RNG library imported by this notebook so runs are repeatable.\n",
"# Python's `random` was imported but previously left unseeded; which RNG\n",
"# make_batches_random draws from is not visible here, so all three are set.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a8688eb",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"# import function that makes our data\n",
"from training_tools import *"
]
},
{
"cell_type": "markdown",
"id": "87d99b50",
"metadata": {},
"source": [
"## Construct the Matrix Used for Our Training Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28f2d4cf-0ff2-434a-bf7c-37197d45305d",
"metadata": {},
"outputs": [],
"source": [
"# Time normalisation constant: target times are mapped to (t + 1) / adj_factor\n",
"# before training, and model outputs are multiplied by adj_factor again when\n",
"# losses are reported or predictions are plotted.\n",
"adj_factor = 21"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71675430-7d69-4f0d-9312-56e44e832053",
"metadata": {},
"outputs": [],
"source": [
"# direction of the data we're fitting\n",
"dire = 2\n",
"\n",
"# model 1 repression genes are notoriously hard to train,\n",
"# raise this flag if training them and some extra help\n",
"# will be added\n",
"dir_2_model_1 = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "190f05ca-cd68-4823-941f-4215c7f3b4f9",
"metadata": {},
"outputs": [],
"source": [
"# The dir_2_model_1 flag is only meaningful when dire == 2; stop right away\n",
"# on any other combination.\n",
"if dir_2_model_1 and dire != 2:\n",
"    raise Exception(\"Don't set dir_2_model_1 to True if you aren't training repression genes!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a87b3044-ac00-49c4-bdf1-69fe49152137",
"metadata": {},
"outputs": [],
"source": [
"# Map the training configuration onto the suffix shared by the simulated-data\n",
"# folder and the saved network file.\n",
"if dire == 0:\n",
"    suffix_name = \"dir0\"\n",
"elif dire == 1:\n",
"    suffix_name = \"dir1\"\n",
"elif dire == 2:\n",
"    suffix_name = \"dir2_m1\" if dir_2_model_1 else \"dir2_m2\"\n",
"else:\n",
"    # Without this branch suffix_name would stay undefined and the mistake\n",
"    # would only surface later as a NameError.\n",
"    raise ValueError(\"dire must be 0, 1 or 2, got \" + repr(dire))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a484821",
"metadata": {},
"outputs": [],
"source": [
"read_folder = \"./data/simulated_data/\" + suffix_name"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93430a28",
"metadata": {},
"outputs": [],
"source": [
"X, t = X_from_file(read_folder, dire)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74782e63-6bcf-46d2-b1cb-ab72bfaab456",
"metadata": {},
"outputs": [],
"source": [
"# Shift and rescale the target times. The network ends in a Sigmoid, so its\n",
"# output lies in (0, 1); the plotting code further down maps predictions back\n",
"# with t_pred * adj_factor - 1.\n",
"t = (t + 1) / adj_factor\n",
"# Alternative without the +1 shift (disabled): t = (t / adj_factor)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a42597db",
"metadata": {},
"outputs": [],
"source": [
"print(X.shape)\n",
"print(t.shape)"
]
},
{
"cell_type": "markdown",
"id": "01f28593",
"metadata": {},
"source": [
"## Prepare Batches"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09701043",
"metadata": {},
"outputs": [],
"source": [
"batches = 800"
]
},
{
"cell_type": "markdown",
"id": "9bec7b96",
"metadata": {},
"source": [
"## Generate Validation Set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08e4ed27",
"metadata": {},
"outputs": [],
"source": [
"# Validation data is not shipped with this notebook: point read_folder at a\n",
"# folder of your own validation data (readable by X_from_file), or remove the\n",
"# validation code entirely.\n",
"read_folder = None  # TODO: set to your validation data folder\n",
"\n",
"# Fail with an explicit message instead of the SyntaxError that the empty\n",
"# assignment used to produce.\n",
"if read_folder is None:\n",
"    raise ValueError(\"Set read_folder to the folder holding your validation data before running this cell.\")\n",
"\n",
"val_X, val_t = X_from_file(read_folder, dire)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd350bfb-25c5-437a-ab36-4bcd5f8a614d",
"metadata": {},
"outputs": [],
"source": [
"print(np.any(np.isnan(X)))\n",
"print(np.any(np.isnan(t)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d825d30c-4b77-47bf-b227-4e14ffd9d608",
"metadata": {},
"outputs": [],
"source": [
"val_t = (val_t + 1) / adj_factor"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71397cb6",
"metadata": {},
"outputs": [],
"source": [
"# Validation inputs / targets as float32 tensors. They are only read, never\n",
"# optimised, so gradient tracking is left off.\n",
"val_X_ten = torch.tensor(val_X, dtype=torch.float).reshape(-1, val_X.shape[1])\n",
"val_t_ten = torch.tensor(val_t, dtype=torch.float).reshape(-1, 1)"
]
},
{
"cell_type": "markdown",
"id": "8d1fc534",
"metadata": {},
"source": [
"## Define Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a78ccefb",
"metadata": {},
"outputs": [],
"source": [
"def build_ode_model(n_inputs, base_n):\n",
"    \"\"\"Build the time-prediction network.\n",
"\n",
"    Layer widths are n_inputs -> 2*base_n -> 1.5*base_n -> base_n -> 1, with a\n",
"    ReLU after each hidden Linear layer and a final Sigmoid.\n",
"\n",
"    Args:\n",
"        n_inputs: number of input features of the first Linear layer.\n",
"        base_n: width unit from which the hidden layer sizes are derived.\n",
"\n",
"    Returns:\n",
"        The assembled nn.Sequential model.\n",
"    \"\"\"\n",
"    return nn.Sequential(\n",
"        nn.Linear(n_inputs, int(2*base_n)),\n",
"        nn.ReLU(),\n",
"        nn.Linear(int(2*base_n), int(1.5*base_n)),\n",
"        nn.ReLU(),\n",
"        nn.Linear(int(1.5*base_n), int(1*base_n)),\n",
"        nn.ReLU(),\n",
"        nn.Linear(int(1.0*base_n), 1),\n",
"        nn.Sigmoid()\n",
"    )\n",
"\n",
"\n",
"# Input width and base_n differ per configuration; the layer layout does not.\n",
"if dire == 0:\n",
"    # DIR 0:\n",
"    n_inputs, base_n = 21, 75\n",
"elif dire == 1:\n",
"    # DIR 1:\n",
"    n_inputs, base_n = 16, 32\n",
"elif dire == 2:\n",
"    if dir_2_model_1:\n",
"        # DIR 2 M1:\n",
"        n_inputs, base_n = 18, 110\n",
"    else:\n",
"        # DIR 2 M2:\n",
"        n_inputs, base_n = 16, 75\n",
"else:\n",
"    # Previously an unsupported value left ode_model undefined.\n",
"    raise ValueError(\"dire must be 0, 1 or 2, got \" + repr(dire))\n",
"\n",
"ode_model = build_ode_model(n_inputs, base_n)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ef0a45d",
"metadata": {},
"outputs": [],
"source": [
"# Used this: https://stackoverflow.com/questions/49433936/how-do-i-initialize-weights-in-pytorch\n",
"\n",
"def init_weights(m):\n",
"    \"\"\"Re-initialise the weight of a Linear layer in place.\n",
"\n",
"    Meant to be passed to Module.apply, which calls it once per submodule.\n",
"    Modules that are not nn.Linear are left untouched. The scheme depends on\n",
"    the notebook-level `dire`: Xavier uniform when dire == 0, Kaiming normal\n",
"    (ReLU gain) otherwise. Biases are not modified.\n",
"    \"\"\"\n",
"    if not isinstance(m, nn.Linear):\n",
"        return\n",
"\n",
"    if dire == 0:\n",
"        nn.init.xavier_uniform_(m.weight)\n",
"    else:\n",
"        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9215cb57",
"metadata": {},
"outputs": [],
"source": [
"ode_model.apply(init_weights)"
]
},
{
"cell_type": "markdown",
"id": "0212f3bb",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "markdown",
"id": "5e18d3b7",
"metadata": {},
"source": [
"I closely followed this tutorial: https://pytorch.org/tutorials/beginner/examples_nn/polynomial_nn.html?highlight=mse"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88847166",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Number of max epochs for each neural network.\n",
"# Each N value was the epoch at which the network ceased\n",
"# to improve performance on the developer's validation data\n",
"# set for five epochs.\n",
"if dire == 0:\n",
"    N = 51\n",
"elif dire == 1:\n",
"    N = 67\n",
"elif dire == 2:\n",
"    if dir_2_model_1:\n",
"        N = 43\n",
"    else:\n",
"        N = 51\n",
"\n",
"# Use MSE as our loss criterion\n",
"criterion = torch.nn.MSELoss(reduction=\"mean\")\n",
"val_criterion = torch.nn.MSELoss(reduction=\"mean\")\n",
"\n",
"# learning rate for adam\n",
"if dir_2_model_1:\n",
"    adam_lr = 1e-3\n",
"elif dire == 0:\n",
"    adam_lr = 3e-4\n",
"else:\n",
"    adam_lr = 1e-4\n",
"\n",
"# run the validation check every val_every epochs\n",
"# (raise this to check progress less regularly)\n",
"val_every = 1\n",
"\n",
"# keep track of the minimum validation loss\n",
"min_loss = float('inf')\n",
"\n",
"# mean training loss per epoch, validation loss per check,\n",
"# and the epoch index of each validation check\n",
"epoch_loss = []\n",
"val_loss_list = []\n",
"val_epoch = []\n",
"\n",
"# progress bar keeping track of\n",
"# how many training epochs have\n",
"# transpired\n",
"epoch_bar = tqdm(total=N)\n",
"\n",
"# progress bar keeping track of\n",
"# how many batches within the\n",
"# epoch have transpired\n",
"batch_bar = tqdm(total=batches)\n",
"\n",
"# set Adam as the training optimizer\n",
"optimizer = torch.optim.Adam(ode_model.parameters(), lr=adam_lr)\n",
"\n",
"# print out some hyperparameters\n",
"print(\"Running with lr =\", adam_lr, \"and base_n =\", base_n)\n",
"\n",
"# Loop-invariant setup, done once instead of once per batch: cast the\n",
"# parameters to float32 and switch on autograd anomaly detection.\n",
"# Anomaly detection is a debugging aid that slows training down; set it\n",
"# to False for faster runs once training is known to be stable.\n",
"ode_model.float()\n",
"torch.autograd.set_detect_anomaly(True)\n",
"\n",
"ode_model.train()\n",
"\n",
"# loop thru each epoch\n",
"for i in range(N):\n",
"\n",
"    # running sum of the batch losses of this epoch\n",
"    # (turned into a mean after the batch loop)\n",
"    mean_loss = 0\n",
"\n",
"    batch_bar.reset()\n",
"\n",
"    sleep(0.001)\n",
"\n",
"    # randomly select data points and assign them to batches\n",
"    x_tens, t_tens = make_batches_random(batches, X, t)\n",
"\n",
"    # loop thru each batch\n",
"    for j in range(batches):\n",
"\n",
"        # do a forward pass\n",
"        t_pred = ode_model(x_tens[j])\n",
"\n",
"        # for loss values, set t_pred and t to\n",
"        # the original time scale\n",
"        scaled_t_pred = torch.mul(t_pred, adj_factor)\n",
"        scaled_t = torch.mul(t_tens[j], adj_factor)\n",
"\n",
"        # check loss\n",
"        loss = criterion(scaled_t, scaled_t_pred)\n",
"\n",
"        # reset gradients\n",
"        optimizer.zero_grad()\n",
"\n",
"        # do a backward pass\n",
"        loss.backward()\n",
"\n",
"        # step through the optimizer\n",
"        optimizer.step()\n",
"\n",
"        batch_bar.update()\n",
"\n",
"        # add the loss of this batch\n",
"        # to the running batch loss sum\n",
"        mean_loss += loss.item()\n",
"\n",
"    # calculate the mean batch loss\n",
"    mean_loss /= batches\n",
"\n",
"    # add loss values to lists for graphing later\n",
"    epoch_loss.append(mean_loss)\n",
"\n",
"    # periodically check our progress on the validation data\n",
"    if i % val_every == 0:\n",
"\n",
"        ode_model.eval()\n",
"\n",
"        # Validation is evaluation only: no_grad stops autograd from\n",
"        # recording a graph over the whole validation set on every check.\n",
"        with torch.no_grad():\n",
"            # do a forward pass on validation data\n",
"            val_t_pred = ode_model(val_X_ten)\n",
"\n",
"            # check loss on validation data (original time scale)\n",
"            val_loss = val_criterion(torch.mul(val_t_pred, adj_factor), torch.mul(val_t_ten, adj_factor))\n",
"\n",
"        val_loss_list.append(val_loss.item())\n",
"        val_epoch.append(i)\n",
"\n",
"        print(\"Epoch\", i, \"loss:\", mean_loss)\n",
"        print(\"Validation loss:\", val_loss.item())\n",
"\n",
"        # From the second check onwards, report how far the newest\n",
"        # validation loss is from the best (minimum) one seen so far,\n",
"        # as a percentage of that minimum.\n",
"        if len(val_loss_list) >= 2:\n",
"            rel_diff = (val_loss_list[-1] - min_loss) / min_loss\n",
"            print(str(rel_diff*100) + \"%\")\n",
"\n",
"        # keep track of the lowest validation loss\n",
"        if val_loss_list[-1] < min_loss:\n",
"            min_loss = val_loss_list[-1]\n",
"            print(\"New min!\")\n",
"\n",
"        print()\n",
"\n",
"        ode_model.train()\n",
"\n",
"    epoch_bar.update()\n",
"\n",
"# release the progress bar widgets now that training is finished\n",
"batch_bar.close()\n",
"epoch_bar.close()\n",
"\n",
"ode_model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2305187",
"metadata": {},
"outputs": [],
"source": [
"torch.save(ode_model.state_dict(), \"../src/multivelo/neural_nets/\" + suffix_name + \".pt\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70c54804",
"metadata": {},
"outputs": [],
"source": [
"# number of leading epochs left out of the plot\n",
"# (presumably so the early, larger losses do not dominate the y-axis)\n",
"sub = 10\n",
"\n",
"# make a graph of the loss per epoch\n",
"epochs_shown = range(len(epoch_loss))[sub:]\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(epochs_shown, epoch_loss[sub:], s=2, label=\"Training\", color=\"blue\")\n",
"ax.set_xlabel(\"Epoch\")\n",
"ax.set_ylabel(\"Loss\")\n",
"ax.legend()"
]
},
{
"cell_type": "markdown",
"id": "9781af6f-73b8-48a9-afea-6a167652d092",
"metadata": {},
"source": [
"## Visualize Neural Network"
]
},
{
"cell_type": "markdown",
"id": "7f7652bd-5119-4db9-932e-80c2871ca13a",
"metadata": {},
"source": [
"For this section we visualize how well the neural network performs by passing in noiseless c/u/s data and graphing the results."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c455f6b5-0793-4ed0-adcb-c69b90eb4641",
"metadata": {},
"outputs": [],
"source": [
"read_folder = \"./data/simulated_data/\" + suffix_name + \"_noiseless\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26aa153c-0fae-42af-804d-d60d31cb9a95",
"metadata": {},
"outputs": [],
"source": [
"genes_to_graph = [1,12]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e359ea7-2375-4415-9e57-df8e21deab53",
"metadata": {},
"outputs": [],
"source": [
"graph_X, graph_t, subset_gene = X_from_file(read_folder, dire, subset_gene=genes_to_graph)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89b007c5-ec35-4064-aa91-495e7a3436fe",
"metadata": {},
"outputs": [],
"source": [
"f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 16))\n",
"\n",
"# reference diagonal: predicted time equal to real time\n",
"perf_line = np.arange(0, 20)\n",
"ax1.plot(perf_line, perf_line, label=\"perfect fit\", linewidth=1)\n",
"ax1.set_xlabel(\"Real\")\n",
"ax1.set_ylabel(\"Pred\")\n",
"\n",
"ax2.set_ylabel(\"Time\")\n",
"ax2.set_xlabel(\"C\")\n",
"\n",
"ax3.set_ylabel(\"Time\")\n",
"ax3.set_xlabel(\"U\")\n",
"\n",
"ax4.set_ylabel(\"Time\")\n",
"ax4.set_xlabel(\"S\")\n",
"\n",
"# offset removed after exponentiating the stored features\n",
"epsilon = 1e-5\n",
"\n",
"for i in range(len(genes_to_graph)):\n",
"\n",
"    # rows of graph_X / graph_t that belong to the i-th requested gene\n",
"    gene_rows = slice(subset_gene[i], subset_gene[i+1])\n",
"    subset_X = graph_X[gene_rows]\n",
"    subset_t = graph_t[gene_rows]\n",
"\n",
"    # rate parameters of this gene, taken from its first row (legend text only)\n",
"    alpha_c = str(np.round(np.exp(subset_X[0,3]) - epsilon, 4))\n",
"    alpha = str(np.round(np.exp(subset_X[0,4]) - epsilon, 4))\n",
"    beta = str(np.round(np.exp(subset_X[0,5]) - epsilon, 4))\n",
"    gamma = str(np.round(np.exp(subset_X[0,6]) - epsilon, 4))\n",
"\n",
"    c = subset_X[:,0]\n",
"    # Same exp(x) - epsilon inverse as the rate parameters above; epsilon was\n",
"    # previously subtracted inside the exponent for u and s.\n",
"    # NOTE(review): assumes u and s are stored as log(value + epsilon) like\n",
"    # the rate parameters - confirm against training_tools.\n",
"    u = np.exp(subset_X[:,1]) - epsilon\n",
"    s = np.exp(subset_X[:,2]) - epsilon\n",
"\n",
"    # dtype=torch.float matches the float32 model (as in every other forward\n",
"    # pass in this notebook); no_grad because this is evaluation only\n",
"    with torch.no_grad():\n",
"        t_pred = ode_model(torch.tensor(subset_X, dtype=torch.float).reshape(-1, subset_X.shape[1]))\n",
"\n",
"    # undo the (t + 1) / adj_factor normalisation used for training\n",
"    t_pred = (t_pred.numpy().reshape(-1) * adj_factor) - 1\n",
"\n",
"    ax1.plot(subset_t, t_pred, linewidth=2, label=\"alpha_c: \" + alpha_c)\n",
"\n",
"    clabel = \" t - alpha_c: \" + alpha_c\n",
"    ax2.plot(c, subset_t, label=\"real\" + clabel, linewidth=1)\n",
"    ax2.plot(c, t_pred, label=\"pred\" + clabel, linewidth=1)\n",
"\n",
"    ulabel = \" t - alpha: \" + alpha + \" and beta: \" + beta\n",
"    ax3.plot(u, subset_t, label=\"real\" + ulabel, linewidth=1)\n",
"    ax3.plot(u, t_pred, label=\"pred\" + ulabel, linewidth=1)\n",
"\n",
"    slabel = \" t - beta: \" + beta + \" and gamma: \" + gamma\n",
"    ax4.plot(s, subset_t, label=\"real\" + slabel, linewidth=1)\n",
"    ax4.plot(s, t_pred, label=\"pred\" + slabel, linewidth=1)\n",
"\n",
"\n",
"ax1.legend()\n",
"ax2.legend()\n",
"ax3.legend()\n",
"ax4.legend()"
]
},
{
"cell_type": "markdown",
"id": "cc11bf73-7775-4168-ae1b-d432825326d0",
"metadata": {},
"source": [
"## Visualize Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7379eb3-fe9e-4840-9226-2ad69ec881c9",
"metadata": {},
"outputs": [],
"source": [
"raise Exception (\"On large datasets, the next steps can crash the kernel! Only proceed if you have enough RAM\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adeb4ab5",
"metadata": {},
"outputs": [],
"source": [
"# Pass all the data through the final model. This is evaluation only, so\n",
"# no_grad keeps autograd from storing a graph over the whole data set\n",
"# (the data tensors no longer request gradients either).\n",
"with torch.no_grad():\n",
"    final_pred_t_alldata = ode_model(torch.tensor(X, dtype=torch.float))\n",
"\n",
"    # calculate final loss (normalised time scale)\n",
"    final_train_loss = val_criterion(\n",
"        final_pred_t_alldata,\n",
"        torch.tensor(t, dtype=torch.float).reshape(-1, 1),\n",
"    )\n",
"\n",
"print(final_train_loss)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aeeb8c07-475f-4da2-824a-73fa028b5cc9",
"metadata": {},
"outputs": [],
"source": [
"# calculate final loss with the original time scale\n",
"# (evaluation only: no gradients are needed for the prediction or the target)\n",
"with torch.no_grad():\n",
"    final_train_context_loss = val_criterion(\n",
"        final_pred_t_alldata*adj_factor,\n",
"        torch.tensor(t*adj_factor, dtype=torch.float).reshape(-1, 1),\n",
"    )\n",
"\n",
"print(final_train_context_loss)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17fb73a0",
"metadata": {},
"outputs": [],
"source": [
"# convert predicted time to numpy\n",
"t_pred_for_graph = final_pred_t_alldata.detach().numpy().reshape(-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30e90e0a",
"metadata": {},
"outputs": [],
"source": [
"# graph the final results of training the original data\n",
"graph_results(t, t_pred_for_graph, X, \"test\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d25291c0",
"metadata": {},
"outputs": [],
"source": [
"# 2-D histogram of predicted time against true time\n",
"# (a well-trained model puts most of the points along the x=y line)\n",
"fig, ax = plt.subplots()\n",
"counts, x_edges, y_edges, mesh = ax.hist2d(t, t_pred_for_graph, bins=50, cmap=\"PuOr\")\n",
"# the fourth item returned by hist2d is the drawn image, which the colorbar needs\n",
"fig.colorbar(mesh, ax=ax)\n",
"ax.set_ylabel(\"Predicted\")\n",
"ax.set_xlabel(\"True\")"
]
},
{
"cell_type": "markdown",
"id": "21ce473e",
"metadata": {},
"source": [
"## More Validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c18f07d",
"metadata": {},
"outputs": [],
"source": [
"# pass the full validation set through the model to get predicted time;\n",
"# evaluation only, so no autograd graph is recorded\n",
"with torch.no_grad():\n",
"    pred_val_t = ode_model(torch.tensor(val_X, dtype=torch.float))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c346252a",
"metadata": {},
"outputs": [],
"source": [
"# calculate the loss\n",
"loss = criterion(pred_val_t, torch.tensor(val_t, dtype=torch.float, requires_grad=True).reshape(-1, 1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f320b50",
"metadata": {},
"outputs": [],
"source": [
"# print the loss\n",
"print(loss.item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01b87c70",
"metadata": {},
"outputs": [],
"source": [
"# graph the validation results; the prediction is flattened to 1-D so it has\n",
"# the same layout as the one passed to graph_results for the training data\n",
"graph_results(val_t, pred_val_t.detach().numpy().reshape(-1), val_X, \"test\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6ecf431-2716-43bc-be7f-bc0ed2b88145",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}