[87e8bf]: / docs / source / tutorials / 4_Train_policy.ipynb

Download this file

156 lines (155 with data), 4.7 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from myosuite.utils import gym\n",
    "import skvideo.io\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML\n",
    "from base64 import b64encode\n",
    "\n",
    "def show_video(video_path, video_width = 400):\n",
    "  \"\"\"Return an inline HTML5 player for the video file at video_path.\n",
    "\n",
    "  The whole file is read and embedded in the returned markup as a\n",
    "  base64 data URL with the video/mp4 MIME type.\n",
    "\n",
    "  Args:\n",
    "    video_path: path of the video file to embed.\n",
    "    video_width: value of the width attribute of the <video> element.\n",
    "\n",
    "  Returns:\n",
    "    An IPython.display.HTML object wrapping the <video> tag.\n",
    "  \"\"\"\n",
    "  # Read-only binary mode is sufficient (the original r+b mode also required\n",
    "  # write permission); the context manager closes the handle after reading.\n",
    "  with open(video_path, \"rb\") as video_fh:\n",
    "    video_file = video_fh.read()\n",
    "\n",
    "  video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n",
    "  return HTML(f\"\"\"<video autoplay width={video_width} controls><source src=\"{video_url}\"></video>\"\"\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Identifier of the MyoSuite task used in this tutorial.\n",
    "ENV_ID = 'myoElbowPose1D6MRandom-v0'\n",
    "\n",
    "# Create the environment that the later cells train on and render.\n",
    "env = gym.make(ENV_ID)\n",
    "\n",
    "# The trailing ';' keeps the value returned by reset out of the cell output.\n",
    "env.reset();\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "from mjrl.utils.gym_env import GymEnv\n",
    "from mjrl.policies.gaussian_mlp import MLP\n",
    "from mjrl.baselines.mlp_baseline import MLPBaseline\n",
    "from mjrl.algos.npg_cg import NPG\n",
    "from mjrl.utils.train_agent import train_agent\n",
    "import myosuite\n",
    "\n",
    "# Hyper-parameters: RNG seed, NPG step size and hidden-layer widths.\n",
    "seed = 123\n",
    "rl_step_size = 0.1\n",
    "policy_size = (32, 32)\n",
    "vf_hidden_size = (128, 128)\n",
    "\n",
    "# Wrap the environment created earlier in the notebook for mjrl.\n",
    "e = GymEnv(env)\n",
    "\n",
    "# Gaussian MLP policy, MLP value baseline and the NPG agent that uses both.\n",
    "policy = MLP(\n",
    "    e.spec,\n",
    "    hidden_sizes=policy_size,\n",
    "    seed=seed,\n",
    "    init_log_std=-0.25,\n",
    "    min_log_std=-1.0,\n",
    ")\n",
    "baseline = MLPBaseline(\n",
    "    e.spec,\n",
    "    reg_coef=1e-3,\n",
    "    batch_size=64,\n",
    "    hidden_sizes=vf_hidden_size,\n",
    "    epochs=2,\n",
    "    learn_rate=1e-3,\n",
    ")\n",
    "agent = NPG(\n",
    "    e,\n",
    "    policy,\n",
    "    baseline,\n",
    "    normalized_step_size=rl_step_size,\n",
    "    seed=seed,\n",
    "    save_logs=True,\n",
    ")\n",
    "\n",
    "# Separator line printed around the status messages.\n",
    "rule = \"========================================\"\n",
    "\n",
    "print(rule)\n",
    "print(\"Starting policy learning\")\n",
    "print(rule)\n",
    "\n",
    "# Keyword arguments forwarded unchanged to train_agent.\n",
    "train_settings = dict(\n",
    "    seed=seed,\n",
    "    niter=200,\n",
    "    gamma=0.995,\n",
    "    gae_lambda=0.97,\n",
    "    num_cpu=8,\n",
    "    sample_mode=\"trajectories\",\n",
    "    num_traj=96,\n",
    "    num_samples=0,\n",
    "    save_freq=100,\n",
    "    evaluation_rollouts=10,\n",
    ")\n",
    "train_agent(job_name='.', agent=agent, **train_settings)\n",
    "\n",
    "print(rule)\n",
    "print(\"Job Finished.\")\n",
    "print(rule)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Use a dedicated name for the path so the trained policy object bound to\n",
    "# the name policy in the training cell is not overwritten by a string.\n",
    "policy_path = \"iterations/best_policy.pickle\"\n",
    "\n",
    "# NOTE: unpickling can execute arbitrary code; only load policy files that\n",
    "# were written by the training cell or come from a source you trust.\n",
    "with open(policy_path, 'rb') as policy_file:\n",
    "    pi = pickle.load(policy_file)\n",
    "\n",
    "# Target values, converted with np.deg2rad below, tested one after the other.\n",
    "AngleSequence = [60, 30, 30, 60, 80, 80, 60, 30, 80, 30, 80, 60]\n",
    "\n",
    "# Reach task attributes through env.unwrapped everywhere instead of relying\n",
    "# on the wrapper forwarding unknown attributes such as sim and get_obs.\n",
    "task = env.unwrapped\n",
    "\n",
    "env.reset()\n",
    "frames = []\n",
    "for ep, angle in enumerate(AngleSequence, start=1):\n",
    "    print(\"Ep {} of {} testing angle {}\".format(ep, len(AngleSequence), angle))\n",
    "    task.target_jnt_value = [np.deg2rad(angle)]\n",
    "    task.target_type = 'fixed'\n",
    "    task.weight_range = (0, 0)\n",
    "    task.update_target()\n",
    "    for _ in range(40):\n",
    "        frame = task.sim.render(width=400, height=400, mode='offscreen', camera_name=None)\n",
    "        # Store the frame with its first axis reversed.\n",
    "        frames.append(frame[::-1, :, :])\n",
    "        o = task.get_obs()\n",
    "        a = pi.get_action(o)[0]\n",
    "        env.step(a)  # take an action based on the current observation\n",
    "env.close()\n",
    "\n",
    "os.makedirs('videos', exist_ok=True)\n",
    "# make a local copy\n",
    "skvideo.io.vwrite('videos/arm.mp4', np.asarray(frames), outputdict={\"-pix_fmt\": \"yuv420p\"})\n",
    "show_video('videos/arm.mp4')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}