{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %env MUJOCO_GL=egl\n",
    "import myosuite\n",
    "from myosuite.utils import gym\n",
    "import skvideo.io\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import time\n",
    "import os\n",
    "import mujoco"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML\n",
    "from base64 import b64encode\n",
    " \n",
    "def show_video(video_path, video_width=400):\n",
    "    # Embed the video as a base64 data URL so that it plays inline in the notebook\n",
    "    with open(video_path, \"rb\") as f:\n",
    "        video_file = f.read()\n",
    "\n",
    "    video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n",
    "    return HTML(f\"\"\"<video autoplay width={video_width} controls><source src=\"{video_url}\"></video>\"\"\")\n",
    "\n",
    "\n",
    "import PIL.Image, PIL.ImageDraw, PIL.ImageFont\n",
    "\n",
    "def add_text_to_frame(frame, text, pos=(20, 20), color=(255, 0, 0), fontsize=12):\n",
    "    if isinstance(frame, np.ndarray):\n",
    "        frame = PIL.Image.fromarray(frame)\n",
    "    \n",
    "    draw = PIL.ImageDraw.Draw(frame)\n",
    "    draw.text(pos, text, fill=color)\n",
    "    return frame"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fatigue Modeling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MyoSuite includes a fatigue model based on the (modified) [\"Three Compartment Controller (3CC-r)\" model](https://doi.org/10.1016/j.jbiomech.2018.06.005). \\\n",
    "The implementation is based on the *CumulativeFatigue* model included in the [\"User-in-the-Box\" framework](https://github.com/aikkala/user-in-the-box/blob/main/uitb/bm_models/effort_models.py).\n",
    "For details on the dynamics of the 3CC-r model, we refer interested readers to the papers by [Looft et al.](https://doi.org/10.1016/j.jbiomech.2018.06.005) and [Cheema et al.](https://doi.org/10.1145/3313831.3376701).\n",
    "\n",
    "Crucially, the 3CC-r model is implemented on a muscle level, i.e., fatigue is computed for each muscle individually rather than for a single (shoulder) joint. \\\n",
    "While originally built and tested for models of the arm, elbow and hand, it can also be used with models of the lower extremity, e.g., the MyoLeg model."
   ]
  },
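  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough orientation, the dynamics can be sketched as follows. This is our reading of the three-compartment formulation in the cited papers, not a transcription of the MyoSuite code; the symbols $TL$ (target load, i.e., the intended muscle control in percent), $C(t)$ (transfer rate between resting and active motor units) and $\\tilde{R}$ (effective recovery rate) are notation introduced here for illustration. With $M_A$, $M_R$ and $M_F$ denoting the percentages of active, resting and fatigued motor units of a muscle:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\frac{dM_A}{dt} &= C(t) - F\\,M_A, \\\\\n",
    "\\frac{dM_R}{dt} &= -C(t) + \\tilde{R}\\,M_F, \\\\\n",
    "\\frac{dM_F}{dt} &= F\\,M_A - \\tilde{R}\\,M_F,\n",
    "\\end{aligned}\n",
    "\\qquad\n",
    "\\tilde{R} = \\begin{cases} r\\,R & \\text{if } M_A \\ge TL \\text{ (rest)}, \\\\ R & \\text{otherwise}. \\end{cases}\n",
    "$$\n",
    "\n",
    "The transfer rate depends on whether more or fewer motor units are required than are currently active, with $L_D$ and $L_R$ denoting the force development and relaxation factors:\n",
    "\n",
    "$$\n",
    "C(t) = \\begin{cases}\n",
    "L_D\\,(TL - M_A) & \\text{if } M_A < TL \\text{ and } M_R > TL - M_A, \\\\\n",
    "L_D\\,M_R & \\text{if } M_A < TL \\text{ and } M_R \\le TL - M_A, \\\\\n",
    "L_R\\,(TL - M_A) & \\text{if } M_A \\ge TL.\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "In this sketch the three rates sum to zero, so $M_A + M_R + M_F$ remains at 100% at all times."
   ]
  },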
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use the Fatigue Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To model fatigue, load the desired MyoSuite environment in its \"Fati\" variant:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "envFatigue = gym.make('myoFatiElbowPose1D6MRandom-v0', normalize_act=False)\n",
    "\n",
    "## for comparison\n",
    "env = gym.make('myoElbowPose1D6MRandom-v0', normalize_act=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This adds the \"muscle_fatigue\" attribute, which entails the current fatigue state of each muscle:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "envFatigue.unwrapped.muscle_fatigue.MF   #percentage of fatigued motor units for each muscle\n",
    "envFatigue.unwrapped.muscle_fatigue.MR   #percentage of resting motor units for each muscle\n",
    "envFatigue.unwrapped.muscle_fatigue.MA   #percentage of active motor units for each muscle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fatigue/recovery constants F and R as well as the recovery multiplier r, which determines how much faster motor units recover during rest periods (i.e., when fewer motor units are required than are currently active), can be set using the following methods:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "envFatigue.unwrapped.muscle_fatigue.set_RecoveryMultiplier(10)\n",
    "envFatigue.unwrapped.muscle_fatigue.set_RecoveryCoefficient(0.0022)\n",
    "envFatigue.unwrapped.muscle_fatigue.set_FatigueCoefficient(0.0146)\n",
    "\n",
    "envFatigue.unwrapped.muscle_fatigue.r, envFatigue.unwrapped.muscle_fatigue.R, envFatigue.unwrapped.muscle_fatigue.F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** The parameters F and R (and in particular their ratio, which defines the percentage of active motor units in a totally fatigued state) generally depend on the joints and muscles, and thus need to be chosen manually for each model!\n",
    "\n",
    "The muscle force development/relaxation factors LD and LR are automatically computed for each muscle based on its activation and deactivation time constants."
   ]
  },
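  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the ratio matters, consider a sustained maximal effort. The following is an illustrative sketch, not taken from the MyoSuite code. It assumes the usual three-compartment dynamics, in which active motor units fatigue at rate $F\\,M_A$ and fatigued units recover at rate $R\\,M_F$ (the multiplier $r$ does not apply, since this is not a rest period), and it neglects the small remaining share of resting units ($M_R \\approx 0$). Here $M_A$ and $M_F$ correspond to the attributes `MA` and `MF`. At equilibrium,\n",
    "\n",
    "$$\n",
    "F\\,M_A = R\\,M_F, \\qquad M_A + M_F \\approx 100\\% \\quad\\Longrightarrow\\quad M_A \\approx \\frac{R}{F+R}\\cdot 100\\%, \\qquad M_F \\approx \\frac{F}{F+R}\\cdot 100\\%.\n",
    "$$\n",
    "\n",
    "For the values set above ($F = 0.0146$, $R = 0.0022$), this gives $M_A \\approx 13.1\\%$ and $M_F \\approx 86.9\\%$ under these assumptions."
   ]
  },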
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Initial Fatigue States"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, the simulation starts in the \"non-fatigued\" muscle state with 100% resting motor units."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "envFatigue.reset()\n",
    "envFatigue.unwrapped.muscle_fatigue.MF, envFatigue.unwrapped.muscle_fatigue.MR, envFatigue.unwrapped.muscle_fatigue.MA"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To reset to a randomly chosen distribution of fatigued, resting and active motor units per muscle, the following command can be used:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "envFatigue.unwrapped.set_fatigue_reset_random(True)\n",
    "\n",
    "envFatigue.reset()\n",
    "envFatigue.unwrapped.muscle_fatigue.MF, envFatigue.unwrapped.muscle_fatigue.MR, envFatigue.unwrapped.muscle_fatigue.MA"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the fatigue variant, the simulation applies the \"currently available\" muscle controls as defined by the fatigue state and dynamics, rather than the intended muscle control signals a:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "envFatigue.unwrapped.set_fatigue_reset_random(False)\n",
    "a = np.zeros(envFatigue.unwrapped.sim.model.nu,)\n",
    "a[0] = 1\n",
    "\n",
    "envFatigue.reset()\n",
    "for i in range(10):\n",
    "    next_o, r, done, *_, ifo = envFatigue.step(a) # take an action\n",
    "\n",
    "# Comparison: without fatigue\n",
    "env.reset()\n",
    "for i in range(10):\n",
    "    next_o, r, done, *_, ifo = env.step(a) # take an action\n",
    "\n",
    "env.unwrapped.last_ctrl, envFatigue.unwrapped.last_ctrl, envFatigue.unwrapped.muscle_fatigue.MF"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This makes it possible to predict how fatigue evolves over time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.reset()\n",
    "envFatigue.reset()\n",
    "data_store = []\n",
    "data_store_f = []\n",
    "for i in range(7*3): # 7 batches of 3 episodes, with 2 episodes of maximum muscle controls for some muscles followed by a resting episode (i.e., zero muscle controls) in each batch\n",
    "    a = np.zeros(env.unwrapped.sim.model.nu,)\n",
    "    if i%3!=2:\n",
    "        a[3:]=1\n",
    "    else:\n",
    "        a[:]=0\n",
    "    \n",
    "    for _ in range(500): # 500 samples (=10s) for each episode\n",
    "        next_o, r, done, *_, ifo = env.step(a) # take an action\n",
    "        next_f_o, r_f, done_F, *_, ifo_f = envFatigue.step(a) # take an action\n",
    "                    \n",
    "        data_store.append({\"action\":a.copy(), \n",
    "                            \"jpos\":env.unwrapped.sim.data.qpos.copy(), \n",
    "                            \"mlen\":env.unwrapped.sim.data.actuator_length.copy(), \n",
    "                            \"act\":env.unwrapped.sim.data.act.copy()})\n",
    "        data_store_f.append({\"action\":a.copy(), \n",
    "                            \"jpos\":envFatigue.unwrapped.sim.data.qpos.copy(), \n",
    "                            \"mlen\":envFatigue.unwrapped.sim.data.actuator_length.copy(),\n",
    "                            \"MF\":envFatigue.unwrapped.muscle_fatigue.MF.copy(),\n",
    "                            \"MR\":envFatigue.unwrapped.muscle_fatigue.MR.copy(),\n",
    "                            \"MA\":envFatigue.unwrapped.muscle_fatigue.MA.copy(), \n",
    "                            \"act\":envFatigue.unwrapped.sim.data.act.copy()})\n",
    "\n",
    "env.close()\n",
    "envFatigue.close()\n",
    "\n",
    "muscle_names = [env.unwrapped.sim.model.id2name(i, \"actuator\") for i in range(env.unwrapped.sim.model.nu) if env.unwrapped.sim.model.actuator_dyntype[i] == mujoco.mjtDyn.mjDYN_MUSCLE]\n",
    "muscle_id = -1\n",
    "\n",
    "plt.figure(figsize=(12, 6))\n",
    "plt.subplot(221)\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['act'][muscle_id] for d in data_store]), label=\"Normal model/Desired activations\")\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['act'][muscle_id] for d in data_store_f]), label='Fatigued model')\n",
    "plt.legend()\n",
    "plt.title(f'Muscle activations over time ({muscle_names[muscle_id]})')\n",
    "plt.xlabel('time (s)'),plt.ylabel('act')\n",
    "\n",
    "plt.subplot(222)\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['jpos'] for d in data_store]), label=\"Normal model\")\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['jpos'] for d in data_store_f]), label=\"Fatigued model\")\n",
    "plt.legend()\n",
    "plt.title('Joint angle over time')\n",
    "plt.xlabel('time (s)'),plt.ylabel('angle')\n",
    "\n",
    "plt.subplot(223)\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['mlen'][muscle_id] for d in data_store]), label=\"Normal model\")\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['mlen'][muscle_id] for d in data_store_f]), label=\"Fatigued model\")\n",
    "plt.legend()\n",
    "plt.title(f'Muscle lengths over time ({muscle_names[muscle_id]})')\n",
    "plt.xlabel('time (s)'),plt.ylabel('muscle length')\n",
    "\n",
    "plt.subplot(224)\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['MF'][muscle_id] for d in data_store_f]), color=\"tab:orange\")\n",
    "plt.title(f'Fatigued motor units over time ({muscle_names[muscle_id]})')\n",
    "plt.xlabel('time (s)'),plt.ylabel('%MVC')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.reset()\n",
    "envFatigue.reset()\n",
    "data_store = []\n",
    "data_store_f = []\n",
    "for i in range(2*3): # 2 batches of 3 episodes, with 0.5*MVC in first and 1*MVC in second episode, followed by a resting episode (i.e., zero muscle controls) in each batch\n",
    "    a = np.zeros(env.unwrapped.sim.model.nu,)\n",
    "    if i%3==0:\n",
    "        a[3:]=0.5\n",
    "    elif i%3==1:\n",
    "        a[3:]=1\n",
    "    else:\n",
    "        a[:]=0\n",
    "    \n",
    "    for _ in range(9000): # 9000 samples (=3 minutes) for each episode\n",
    "        next_o, r, done, *_, ifo = env.step(a) # take an action\n",
    "        next_f_o, r_f, done_F, *_, ifo_f = envFatigue.step(a) # take an action\n",
    "                    \n",
    "        data_store.append({\"action\":a.copy(), \n",
    "                            \"jpos\":env.unwrapped.sim.data.qpos.copy(), \n",
    "                            \"mlen\":env.unwrapped.sim.data.actuator_length.copy(), \n",
    "                            \"act\":env.unwrapped.sim.data.act.copy()})\n",
    "        data_store_f.append({\"action\":a.copy(), \n",
    "                            \"jpos\":envFatigue.unwrapped.sim.data.qpos.copy(), \n",
    "                            \"mlen\":envFatigue.unwrapped.sim.data.actuator_length.copy(),\n",
    "                            \"MF\":envFatigue.unwrapped.muscle_fatigue.MF.copy(),\n",
    "                            \"MR\":envFatigue.unwrapped.muscle_fatigue.MR.copy(),\n",
    "                            \"MA\":envFatigue.unwrapped.muscle_fatigue.MA.copy(),\n",
    "                            \"act\":envFatigue.unwrapped.sim.data.act.copy()})\n",
    "\n",
    "env.close()\n",
    "envFatigue.close()\n",
    "\n",
    "muscle_names = [env.unwrapped.sim.model.id2name(i, \"actuator\") for i in range(env.unwrapped.sim.model.nu) if env.unwrapped.sim.model.actuator_dyntype[i] == mujoco.mjtDyn.mjDYN_MUSCLE]\n",
    "muscle_id = -1\n",
    "\n",
    "plt.figure(figsize=(12, 6))\n",
    "plt.subplot(221)\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['act'][muscle_id] for d in data_store]), label=\"Normal model/Desired activations\")\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['act'][muscle_id] for d in data_store_f]), label='Fatigued model')\n",
    "plt.legend()\n",
    "plt.title(f'Muscle activations over time ({muscle_names[muscle_id]})')\n",
    "plt.xlabel('time (s)'),plt.ylabel('act')\n",
    "\n",
    "plt.subplot(222)\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['jpos'] for d in data_store]), label=\"Normal model\")\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['jpos'] for d in data_store_f]), label=\"Fatigued model\")\n",
    "plt.legend()\n",
    "plt.title('Joint angle over time')\n",
    "plt.xlabel('time (s)'),plt.ylabel('angle')\n",
    "\n",
    "plt.subplot(223)\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['mlen'][muscle_id] for d in data_store]), label=\"Normal model\")\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['mlen'][muscle_id] for d in data_store_f]), label=\"Fatigued model\")\n",
    "plt.legend()\n",
    "plt.title(f'Muscle lengths over time ({muscle_names[muscle_id]})')\n",
    "plt.xlabel('time (s)'),plt.ylabel('muscle length')\n",
    "\n",
    "plt.subplot(224)\n",
    "plt.plot(env.unwrapped.dt*np.arange(len(data_store)), np.array([d['MF'][muscle_id] for d in data_store_f]), color=\"tab:orange\")\n",
    "plt.title(f'Fatigued motor units over time ({muscle_names[muscle_id]})')\n",
    "plt.xlabel('time (s)'),plt.ylabel('%MVC')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Agents with Fatigue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.callbacks import CheckpointCallback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"myoFatiElbowPose1D6MRandom-v0\"\n",
    "\n",
    "env = gym.make(env_name)\n",
    "env.unwrapped.set_fatigue_reset_random(True)\n",
    "env.reset()\n",
    "\n",
    "# Save a checkpoint every 50000 steps\n",
    "checkpoint_callback = CheckpointCallback(\n",
    "  save_freq=50000,\n",
    "  save_path=f\"./{env_name}/iterations/\",\n",
    "  name_prefix=\"rl_model\",\n",
    "  save_replay_buffer=True,\n",
    "  save_vecnormalize=True,\n",
    ")\n",
    "\n",
    "model = PPO(\"MlpPolicy\", env, verbose=0)\n",
    "# NOTE: 200 timesteps only smoke-test the setup and never reach save_freq, so no checkpoint is written;\n",
    "# the evaluation cells below load \"rl_model_200000_steps\", which requires total_timesteps >= 200000\n",
    "model.learn(total_timesteps=200, callback=checkpoint_callback)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**NOTE:** With `set_fatigue_reset_random(True)`, as in the cell above, a random fatigue state is sampled at the beginning of each training episode. \\\n",
    "To start with a specific, fixed fatigue state, set `fatigue_reset_random=False` and define `fatigue_reset_vec` as the vector MF of fatigued motor units per muscle.\n",
    "\n",
    "Best practice is to create a new variant of the desired environment, i.e., to call `register_env_variant()` with\n",
    "`variants={'muscle_condition': 'fatigue',\n",
    "            'fatigue_reset_vec': np.array([0., 0., 0.]),\n",
    "            'fatigue_reset_random': False}`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulate and Evaluate Trained Agents"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the following, we evaluate the latest policy trained for a given fatigue environment (and for comparison also the policy trained in the respective non-fatigue environment).\n",
    "\n",
    "To this end, we simulate several episodes per policy, with fatigue accumulating across episodes, starting in the default zero fatigue state. \\\n",
    "Videos of the first and the last few episodes are generated, and simulation data is logged (and later visualised) for all episodes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 1: Fatigue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"myoFatiElbowPose1D6MRandom-v0\"\n",
    "\n",
    "GENERATE_VIDEO = True\n",
    "GENERATE_VIDEO_EPS = 4  #number of episodes that are rendered BOTH at the beginning (i.e., without fatigue) and at the end (i.e., with fatigue)\n",
    "\n",
    "STORE_DATA = True  #store collected data from evaluation run in .npy file\n",
    "n_eps = 250\n",
    "\n",
    "###################################\n",
    "\n",
    "env = gym.make(env_name)\n",
    "\n",
    "from stable_baselines3 import PPO\n",
    "model = PPO.load(f\"{env_name}/iterations/rl_model_200000_steps\")\n",
    "\n",
    "env.unwrapped.set_fatigue_reset_random(False)\n",
    "env.reset(fatigue_reset=True)  #ensure that fatigue is reset before the simulation starts\n",
    "\n",
    "env.unwrapped.sim.model.cam_poscom0[0]= np.array([-1.3955, -0.3287,  0.6579])\n",
    "\n",
    "data_store = []\n",
    "if GENERATE_VIDEO:\n",
    "    frames = []\n",
    "\n",
    "env.unwrapped.target_jnt_value = env.unwrapped.target_jnt_range[:, 1]\n",
    "env.unwrapped.target_type = 'fixed'\n",
    "env.unwrapped.update_target(restore_sim=True)\n",
    "\n",
    "start_time = time.time()\n",
    "for ep in range(n_eps):\n",
    "    print(\"Ep {} of {}\".format(ep, n_eps))\n",
    "    \n",
    "    for _cstep in range(env.spec.max_episode_steps):\n",
    "        if GENERATE_VIDEO and (ep in range(GENERATE_VIDEO_EPS) or ep in range(n_eps-GENERATE_VIDEO_EPS, n_eps)):\n",
    "            frame = env.unwrapped.sim.renderer.render_offscreen(width=400, height=400, camera_id=0)\n",
    "            \n",
    "            # Add text overlay\n",
    "            _current_time = (ep*env.spec.max_episode_steps + _cstep)*env.unwrapped.dt\n",
    "            frame = np.array(add_text_to_frame(frame,\n",
    "                    f\"t={str(int(_current_time//60)).zfill(2)}:{str(int(_current_time%60)).zfill(2)}min\",\n",
    "                    pos=(285, 3), color=(0, 0, 0), fontsize=18))\n",
    "            \n",
    "            frames.append(frame)\n",
    "        o = env.unwrapped.get_obs()\n",
    "        a = model.predict(o)[0]\n",
    "        next_o, r, done, _, ifo = env.step(a) # take an action based on the current observation\n",
    "\n",
    "        data_store.append({\"action\":a.copy(), \n",
    "                            \"jpos\":env.unwrapped.sim.data.qpos.copy(), \n",
    "                            \"mlen\":env.unwrapped.sim.data.actuator_length.copy(), \n",
    "                            \"act\":env.unwrapped.sim.data.act.copy(),\n",
    "                            \"reward\":r,\n",
    "                            \"solved\":env.unwrapped.rwd_dict['solved'].item(),\n",
    "                            \"pose_err\":env.unwrapped.get_obs_dict(env.unwrapped.sim)[\"pose_err\"],\n",
    "                            \"MA\":env.unwrapped.muscle_fatigue.MA.copy(),\n",
    "                            \"MR\":env.unwrapped.muscle_fatigue.MR.copy(),\n",
    "                            \"MF\":env.unwrapped.muscle_fatigue.MF.copy(),\n",
    "                            \"ctrl\":env.unwrapped.last_ctrl.copy()})\n",
    "env.close()\n",
    "\n",
    "## OPTIONALLY: Stored simulated data\n",
    "if STORE_DATA:\n",
    "    os.makedirs(f\"{env_name}/logs\", exist_ok=True)\n",
    "    np.save(f\"{env_name}/logs/fatitest.npy\", data_store)\n",
    "\n",
    "## OPTIONALLY: Render video\n",
    "if GENERATE_VIDEO:\n",
    "    os.makedirs(f'{env_name}/videos', exist_ok=True)\n",
    "    # make a local copy\n",
    "    skvideo.io.vwrite(f'{env_name}/videos/fatitest.mp4', np.asarray(frames),inputdict={'-r': str(int(1/env.unwrapped.dt))},outputdict={\"-pix_fmt\": \"yuv420p\"})\n",
    "\n",
    "end_time = time.time()\n",
    "print(f\"DURATION: {end_time - start_time:.2f}s\")\n",
    "\n",
    "if GENERATE_VIDEO:\n",
    "    display(show_video(f'{env_name}/videos/fatitest.mp4'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"myoFatiElbowPose1D6MRandom-v0\"\n",
    "\n",
    "####################\n",
    "\n",
    "env_test = gym.make(env_name, normalize_act=False)\n",
    "muscle_names = [env_test.unwrapped.sim.model.id2name(i, \"actuator\") for i in range(env_test.unwrapped.sim.model.nu) if env_test.unwrapped.sim.model.actuator_dyntype[i] == mujoco.mjtDyn.mjDYN_MUSCLE]\n",
    "_env_dt = env_test.unwrapped.dt  #0.02\n",
    "\n",
    "data_store = np.load(f\"{env_name}/logs/fatitest.npy\", allow_pickle=True)\n",
    "\n",
    "plt.figure()\n",
    "for _muscleid in range(len(data_store[0]['MF'])):\n",
    "    plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['MF'][_muscleid] for d in data_store]), label=muscle_names[_muscleid])\n",
    "plt.legend()\n",
    "plt.title('Fatigued Motor Units')\n",
    "\n",
    "plt.figure()\n",
    "for _muscleid in range(len(data_store[0]['MR'])):\n",
    "    plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['MR'][_muscleid] for d in data_store]), label=muscle_names[_muscleid])\n",
    "plt.legend()\n",
    "plt.title('Resting Motor Units')\n",
    "\n",
    "plt.figure()\n",
    "for _muscleid in range(len(data_store[0]['MA'])):\n",
    "    plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['MA'][_muscleid] for d in data_store]), label=muscle_names[_muscleid])\n",
    "plt.legend()\n",
    "plt.title('Active Motor Units')\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(_env_dt*np.arange(len(data_store)), np.array([np.linalg.norm(d['pose_err']) for d in data_store])), plt.title('Pose Error')\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['reward'] for d in data_store])), plt.title(f\"Reward (Total: {np.array([d['reward'] for d in data_store]).sum():.2f})\")\n",
    "\n",
    "if \"solved\" in data_store[0]:\n",
    "    plt.figure()\n",
    "    plt.scatter(_env_dt*np.arange(len(data_store))[np.array([d['solved'] for d in data_store])], np.array([d['solved'] for d in data_store])[np.array([d['solved'] for d in data_store])]), plt.title(f\"Success\")\n",
    "\n",
    "print(f\"Muscle Fatigue Equilibrium: {data_store[-1]['MF']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comparison: Policy trained without fatigue"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 2: Fatigue + Resting Period"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this variant, the same target needs to be reached for 5 minutes, followed by a resting period of 2:30 minutes, and another 2:30 minutes of the same task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"myoFatiElbowPose1D6MRandom-v0\"\n",
    "\n",
    "GENERATE_VIDEO = True\n",
    "GENERATE_VIDEO_EPS = 300  #number of episodes that are rendered BOTH at the beginning and at the end; since this equals n_eps, every episode (task and resting phases) is rendered\n",
    "\n",
    "STORE_DATA = True  #store collected data from evaluation run in .npy file\n",
    "n_eps = 300\n",
    "\n",
    "###################################\n",
    "\n",
    "env = gym.make(env_name)\n",
    "\n",
    "from stable_baselines3 import PPO\n",
    "model = PPO.load(f\"{env_name}/iterations/rl_model_200000_steps\")\n",
    "\n",
    "env.unwrapped.set_fatigue_reset_random(False)\n",
    "env.reset(fatigue_reset=True)  #ensure that fatigue is reset before the simulation starts\n",
    "\n",
    "env.unwrapped.sim.model.cam_poscom0[0]= np.array([-1.3955, -0.3287,  0.6579])\n",
    "\n",
    "data_store = []\n",
    "if GENERATE_VIDEO:\n",
    "    frames = []\n",
    "\n",
    "env.unwrapped.target_jnt_value = env.unwrapped.target_jnt_range[:, 1]\n",
    "env.unwrapped.target_type = 'fixed'\n",
    "env.unwrapped.update_target(restore_sim=True)\n",
    "\n",
    "start_time = time.time()\n",
    "for ep in range(n_eps):\n",
    "    print(\"Ep {} of {}\".format(ep, n_eps))\n",
    "    \n",
    "    for _cstep in range(env.spec.max_episode_steps):\n",
    "        if GENERATE_VIDEO and (ep in range(GENERATE_VIDEO_EPS) or ep in range(n_eps-GENERATE_VIDEO_EPS, n_eps)):\n",
    "            frame = env.unwrapped.sim.renderer.render_offscreen(width=480, height=480, camera_id=0)\n",
    "            \n",
    "            # Add text overlay\n",
    "            _current_time = (ep*env.spec.max_episode_steps + _cstep)*env.unwrapped.dt\n",
    "            frame = np.array(add_text_to_frame(frame,\n",
    "                    f\"t={str(int(_current_time//60)).zfill(2)}:{str(int(_current_time%60)).zfill(2)}min\",\n",
    "                    pos=(365, 3), color=(0, 0, 0), fontsize=18))\n",
    "            \n",
    "            if ep >= n_eps*0.5 and ep < n_eps*0.75:\n",
    "                frame = np.array(add_text_to_frame(frame,\n",
    "                    f\"Resting Phase\",\n",
    "                    pos=(320, 450), color=(84, 184, 81), fontsize=18))\n",
    "\n",
    "            frames.append(frame)\n",
    "        o = env.unwrapped.get_obs()\n",
    "        a = model.predict(o)[0]\n",
    "\n",
    "        if ep >= n_eps*0.5 and ep < n_eps*0.75:\n",
    "            a[:] = -100000  #resting period (corresponds to zero muscle activations)\n",
    "            env.unwrapped.sim.model.site_rgba[env.unwrapped.target_sids[0]][-1] = 0  #hide target during resting period\n",
    "            env.unwrapped.sim.model.tendon_rgba[-1][-1] = 0  #hide error line during resting period\n",
    "        else:\n",
    "            env.unwrapped.sim.model.site_rgba[env.unwrapped.target_sids[0]][-1] = 0.2  #visualise target during task\n",
    "            env.unwrapped.sim.model.tendon_rgba[-1][-1] = 0.2  #visualise error line during task\n",
    "\n",
    "        next_o, r, done, _, ifo = env.step(a) # take an action based on the current observation\n",
    "\n",
    "        data_store.append({\"action\":a.copy(), \n",
    "                            \"jpos\":env.unwrapped.sim.data.qpos.copy(), \n",
    "                            \"mlen\":env.unwrapped.sim.data.actuator_length.copy(), \n",
    "                            \"act\":env.unwrapped.sim.data.act.copy(),\n",
    "                            \"reward\":r,\n",
    "                            \"solved\":env.unwrapped.rwd_dict['solved'].item(),\n",
    "                            \"pose_err\":env.unwrapped.get_obs_dict(env.unwrapped.sim)[\"pose_err\"],\n",
    "                            \"MA\":env.unwrapped.muscle_fatigue.MA.copy(),\n",
    "                            \"MR\":env.unwrapped.muscle_fatigue.MR.copy(),\n",
    "                            \"MF\":env.unwrapped.muscle_fatigue.MF.copy(),\n",
    "                            \"ctrl\":env.unwrapped.last_ctrl.copy()})\n",
    "env.close()\n",
    "\n",
    "## OPTIONALLY: Stored simulated data\n",
    "if STORE_DATA:\n",
    "    os.makedirs(f\"{env_name}/logs\", exist_ok=True)\n",
    "    np.save(f\"{env_name}/logs/fatitest_recovery.npy\", data_store)\n",
    "\n",
    "## OPTIONALLY: Render video\n",
    "if GENERATE_VIDEO:\n",
    "    os.makedirs(f'{env_name}/videos', exist_ok=True)\n",
    "    # make a local copy\n",
    "    skvideo.io.vwrite(f'{env_name}/videos/fatitest_recovery.mp4', np.asarray(frames),inputdict={'-r': str(int(1/env.unwrapped.dt))},outputdict={\"-pix_fmt\": \"yuv420p\"})\n",
    "\n",
    "end_time = time.time()\n",
    "print(f\"DURATION: {end_time - start_time:.2f}s\")\n",
    "\n",
    "if GENERATE_VIDEO:\n",
    "    display(show_video(f'{env_name}/videos/fatitest_recovery.mp4'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env_name = \"myoFatiElbowPose1D6MRandom-v0\"\n",
    "\n",
    "####################\n",
    "\n",
    "env_test = gym.make(env_name, normalize_act=False)\n",
    "muscle_names = [env_test.unwrapped.sim.model.id2name(i, \"actuator\") for i in range(env_test.unwrapped.sim.model.nu) if env_test.unwrapped.sim.model.actuator_dyntype[i] == mujoco.mjtDyn.mjDYN_MUSCLE]\n",
    "_env_dt = env_test.unwrapped.dt  #0.02\n",
    "\n",
    "data_store = np.load(f\"{env_name}/logs/fatitest_recovery.npy\", allow_pickle=True)\n",
    "\n",
    "plt.figure()\n",
    "for _muscleid in range(len(data_store[0]['MF'])):\n",
    "    plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['MF'][_muscleid] for d in data_store]), label=muscle_names[_muscleid])\n",
    "plt.legend()\n",
    "plt.title('Fatigued Motor Units')\n",
    "\n",
    "plt.figure()\n",
    "for _muscleid in range(len(data_store[0]['MR'])):\n",
    "    plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['MR'][_muscleid] for d in data_store]), label=muscle_names[_muscleid])\n",
    "plt.legend()\n",
    "plt.title('Resting Motor Units')\n",
    "\n",
    "plt.figure()\n",
    "for _muscleid in range(len(data_store[0]['MA'])):\n",
    "    plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['MA'][_muscleid] for d in data_store]), label=muscle_names[_muscleid])\n",
    "plt.legend()\n",
    "plt.title('Active Motor Units')\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(_env_dt*np.arange(len(data_store)), np.array([np.linalg.norm(d['pose_err']) for d in data_store])), plt.title('Pose Error')\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(_env_dt*np.arange(len(data_store)), np.array([d['reward'] for d in data_store])), plt.title(f\"Reward (Total: {np.array([d['reward'] for d in data_store]).sum():.2f})\")\n",
    "\n",
    "if \"solved\" in data_store[0]:\n",
    "    plt.figure()\n",
    "    plt.scatter(_env_dt*np.arange(len(data_store))[np.array([d['solved'] for d in data_store])], np.array([d['solved'] for d in data_store])[np.array([d['solved'] for d in data_store])]), plt.title(f\"Success\")\n",
    "\n",
    "print(f\"Muscle Fatigue Equilibrium: {data_store[-1]['MF']}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}