[87e8bf]: / docs / source / tutorials / 2_Load_policy.ipynb

Download this file

121 lines (120 with data), 3.3 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from myosuite.utils import gym\n",
    "import skvideo.io\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML\n",
    "from base64 import b64encode\n",
    "\n",
    "def show_video(video_path, video_width=400):\n",
    "    \"\"\"Return an inline HTML5 player for the video file at `video_path`.\n",
    "\n",
    "    The whole file is read and embedded in the markup as a base64 `data:` URL.\n",
    "\n",
    "    Args:\n",
    "        video_path: path of the video file to embed.\n",
    "        video_width: value used for the <video> element's width attribute.\n",
    "\n",
    "    Returns:\n",
    "        An IPython.display.HTML object wrapping the <video> element.\n",
    "    \"\"\"\n",
    "    # Read-only binary mode is sufficient (no write permission needed), and the\n",
    "    # context manager closes the handle as soon as the bytes are in memory.\n",
    "    with open(video_path, \"rb\") as video_file:\n",
    "        video_bytes = video_file.read()\n",
    "\n",
    "    video_url = f\"data:video/mp4;base64,{b64encode(video_bytes).decode()}\"\n",
    "    return HTML(f\"\"\"<video autoplay width={video_width} controls><source src=\"{video_url}\"></video>\"\"\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Directory holding the pre-trained policies, given relative to this notebook.\n",
    "# NOTE(review): the 'baslines_NPG' spelling is kept exactly as found; presumably\n",
    "# it matches the directory name on disk -- verify before \"correcting\" it.\n",
    "pth = '../../../myosuite/agents/baslines_NPG/'\n",
    "\n",
    "policy = pth+\"myoElbowPose1D6MExoRandom-v0/2022-02-26_21-16-27/36_env=myoElbowPose1D6MExoRandom-v0,seed=1/iterations/best_policy.pickle\"\n",
    "\n",
    "# SECURITY: unpickling can execute arbitrary code, so only load policy files\n",
    "# that come from a trusted source.\n",
    "# The context manager closes the file handle once the policy is loaded.\n",
    "with open(policy, 'rb') as policy_file:\n",
    "    pi = pickle.load(policy_file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the registered task by its id, then reset it; reset() stays the\n",
    "# last expression so its return value is shown as the cell output.\n",
    "env_id = 'myoElbowPose1D6MExoRandom-v0'\n",
    "env = gym.make(env_id)\n",
    "env.reset()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Discrete sequence of target joint angles to test, in degrees (converted to\n",
    "# radians with np.deg2rad below); they are visited in this order.\n",
    "AngleSequence = [60, 30, 30, 60, 80, 80, 60, 30, 80, 30, 80, 60]\n",
    "STEPS_PER_TARGET = 40  # environment steps (and recorded frames) per target\n",
    "FRAME_SIZE = 400       # width and height requested from the offscreen renderer\n",
    "\n",
    "# Reach the task's own attributes (targets, sim, get_obs) through the unwrapped\n",
    "# environment everywhere: wrapper classes are not guaranteed to forward\n",
    "# arbitrary attribute access to the environment they wrap.\n",
    "base_env = env.unwrapped\n",
    "\n",
    "env.reset()\n",
    "frames = []\n",
    "for ep, angle_deg in enumerate(AngleSequence, start=1):\n",
    "    # Progress is reported 1-based so the final line reads \"N of N\".\n",
    "    print(\"Ep {} of {} testing angle {}\".format(ep, len(AngleSequence), angle_deg))\n",
    "    # Pin the target to the requested angle and apply it to the task.\n",
    "    base_env.target_jnt_value = [np.deg2rad(angle_deg)]\n",
    "    base_env.target_type = 'fixed'\n",
    "    base_env.weight_range = (0, 0)\n",
    "    base_env.update_target()\n",
    "    for _ in range(STEPS_PER_TARGET):\n",
    "        # Grab the frame before stepping so the state right after the target\n",
    "        # change is part of the recording.\n",
    "        frame = base_env.sim.renderer.render_offscreen(\n",
    "                        width=FRAME_SIZE,\n",
    "                        height=FRAME_SIZE,\n",
    "                        camera_id=0)\n",
    "        frames.append(frame)\n",
    "        o = base_env.get_obs()\n",
    "        a = pi.get_action(o)[0]\n",
    "        # Take an action based on the current observation. The step result is\n",
    "        # not used: every target runs for a fixed number of steps.\n",
    "        env.step(a)\n",
    "env.close()\n",
    "\n",
    "os.makedirs('videos', exist_ok=True)\n",
    "# make a local copy\n",
    "skvideo.io.vwrite('videos/exo_arm.mp4', np.asarray(frames),outputdict={\"-pix_fmt\": \"yuv420p\"})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Play the saved clip inline; this call must remain the cell's last expression\n",
    "# so that the returned HTML object is displayed.\n",
    "show_video(video_path='videos/exo_arm.mp4')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}