[87e8bf]: / docs / source / tutorials / 3_Analyse_movements.ipynb

Download this file

156 lines (155 with data), 4.3 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from myosuite.utils import gym\n",
    "import skvideo.io\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from stable_baselines3 import PPO\n",
    "# policy = \"ElbowPose_policy.zip\"\n",
    "\n",
    "# pi = PPO.load(policy)\n",
    "\n",
    "from stable_baselines3 import PPO\n",
    "env = gym.make('myoElbowPose1D6MRandom-v0')\n",
    "\n",
    "env.reset()\n",
    "pi = PPO(\"MlpPolicy\", env, verbose=0)\n",
    "\n",
    "\n",
    "pi.learn(total_timesteps=1000)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_store = []\n",
    "for _ in range(10): # 10 episodes\n",
    "    for _ in range(100): # 100 samples for each episode\n",
    "        o = env.get_obs()\n",
    "        a = pi.predict(o)[0]\n",
    "        next_o, r, done, *_, ifo = env.step(a) # take a random action\n",
    "                    \n",
    "        data_store.append({\"action\":a.copy(), \n",
    "                            \"jpos\":env.sim.data.qpos.copy(), \n",
    "                            \"mlen\":env.sim.data.actuator_length.copy(), \n",
    "                            \"act\":env.sim.data.act.copy()})\n",
    "\n",
    "env.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def VAF(W, H, A):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        W: ndarray, m x rank matrix, m-muscles x activation coefficients obtained from (# rank) nmf\n",
    "        H: ndarray, rank x L matrix, basis vectors obtained from nmf where L is the length of the signal\n",
    "        A: ndarray, m x L matrix, original time-invariant sEMG signal\n",
    "    Returns:\n",
    "        global_VAF: float, VAF calculated for the entire A based on the W&H\n",
    "        local_VAF: 1D array, VAF calculated for each muscle (column) in A based on W&H\n",
    "    \"\"\"\n",
    "    SSE_matrix = (A - np.dot(W, H))**2\n",
    "    SST_matrix = (A)**2\n",
    "\n",
    "    global_SSE = np.sum(SSE_matrix)\n",
    "    global_SST = np.sum(SST_matrix)\n",
    "    global_VAF = 100 * (1 - global_SSE / global_SST)\n",
    "\n",
    "    local_SSE = np.sum(SSE_matrix, axis = 0)\n",
    "    local_SST = np.sum(SST_matrix, axis = 0)\n",
    "    local_VAF = 100 * (1 - np.divide(local_SSE, local_SST))\n",
    "\n",
    "    return global_VAF, local_VAF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install scikit-learn\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import NMF\n",
    "\n",
    "act = np.array([dd['act'] for dd in data_store])\n",
    "\n",
    "VAFstore=[]\n",
    "SSE, SST = [], []\n",
    "\n",
    "sample_points = [1,2,3,4,5,10,20,30]\n",
    "for isyn in sample_points:\n",
    "    nmf_model = NMF(n_components=isyn, init='random', random_state=0);\n",
    "    W = nmf_model.fit_transform(act)\n",
    "    H = nmf_model.components_\n",
    "\n",
    "    global_VAF, local_VAF = VAF(W, H, act)\n",
    "\n",
    "    VAFstore.append(global_VAF)\n",
    "\n",
    "plt.plot(sample_points,VAFstore,'-o')\n",
    "plt.xlabel('Number of Muscle Synergies')\n",
    "plt.ylabel('Explained Variance R^2')\n",
    "plt.gca().spines['top'].set_visible(False)\n",
    "plt.gca().spines['right'].set_visible(False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}