Switch to unified view

a b/examples/legacy/train.arm.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "metadata": {},
6
   "source": [
7
    "# Learning how to move a human arm\n",
8
    "\n",
9
    "In this tutorial we will show how to train a basic biomechanical model using `keras-rl`.\n",
10
    "\n",
11
    "## Installation\n",
12
    "\n",
13
    "To make it work, follow the instructions in\n",
14
    "https://github.com/stanfordnmbl/osim-rl#getting-started\n",
15
    "i.e. run\n",
16
    "\n",
17
    "    conda create -n opensim-rl -c kidzik opensim python=3.6.1\n",
18
    "    activate opensim-rl\n",
19
    "    pip install git+https://github.com/stanfordnmbl/osim-rl.git\n",
20
    "\n",
21
    "Then run\n",
22
    "\n",
23
    "    pip install keras tensorflow keras-rl jupyter\n",
24
    "    git clone https://github.com/stanfordnmbl/osim-rl.git\n",
25
    "    cd osim-rl\n",
26
    "    \n",
27
    "follow the instructions and once jupyter is installed and type\n",
28
    "\n",
29
    "    jupyter notebook\n",
30
    "\n",
31
    "This should open the browser with jupyter. Navigate to this notebook, i.e. to the file `examples/train.arm.ipynb`.\n",
32
    "\n",
33
    "## Preparing the environment\n",
34
    "\n",
35
    "The following two blocks load necessary libraries and create a simulator environment."
36
   ]
37
  },
38
  {
39
   "cell_type": "code",
40
   "execution_count": null,
41
   "metadata": {},
42
   "outputs": [],
43
   "source": [
44
    "import osim\n",
45
    "import numpy as np\n",
46
    "import sys\n",
47
    "\n",
48
    "# Keras libraries \n",
49
    "from keras.optimizers import Adam\n",
50
    "\n",
51
    "import numpy as np\n",
52
    "from helpers import *\n",
53
    "\n",
54
    "from rl.agents import DDPGAgent\n",
55
    "from rl.memory import SequentialMemory\n",
56
    "from rl.random import OrnsteinUhlenbeckProcess\n",
57
    "\n",
58
    "from keras.optimizers import RMSprop\n",
59
    "\n",
60
    "import argparse\n",
61
    "import math"
62
   ]
63
  },
64
  {
65
   "cell_type": "code",
66
   "execution_count": null,
67
   "metadata": {},
68
   "outputs": [],
69
   "source": [
70
    "# Load arm environment\n",
71
    "from osim.env import Arm2DEnv\n",
72
    "env = Arm2DEnv(True)"
73
   ]
74
  },
75
  {
76
   "cell_type": "markdown",
77
   "metadata": {},
78
   "source": [
79
    "## Creating the actor and the critic\n",
80
    "\n",
81
    "The actor serves as a brain for controlling muscles. The critic is our approximation of how good is the brain performing for achieving the goal"
82
   ]
83
  },
84
  {
85
   "cell_type": "code",
86
   "execution_count": null,
87
   "metadata": {},
88
   "outputs": [],
89
   "source": [
90
    "# Create networks for DDPG\n",
91
    "# Next, we build a very simple model.\n",
92
    "actor = policy_nn(env.observation_space.shape[0], env.action_space.shape[0], hidden_layers = 3, hidden_size = 32)\n",
93
    "print(actor.summary())"
94
   ]
95
  },
96
  {
97
   "cell_type": "code",
98
   "execution_count": null,
99
   "metadata": {},
100
   "outputs": [],
101
   "source": [
102
    "qfunc = q_nn(env.observation_space.shape[0], env.action_space.shape[0], hidden_layers = 3, hidden_size = 64)\n",
103
    "print(qfunc[0].summary())"
104
   ]
105
  },
106
  {
107
   "cell_type": "markdown",
108
   "metadata": {},
109
   "source": [
110
    "## Train the actor and the critic\n",
111
    "\n",
112
    "We will now run `keras-rl` implementation of the DDPG algorithm which trains both networks."
113
   ]
114
  },
115
  {
116
   "cell_type": "code",
117
   "execution_count": null,
118
   "metadata": {},
119
   "outputs": [],
120
   "source": [
121
    "# Set up the agent for training\n",
122
    "memory = SequentialMemory(limit=100000, window_length=1)\n",
123
    "random_process = OrnsteinUhlenbeckProcess(theta=.15, mu=0., sigma=.2, size=env.action_space.shape)\n",
124
    "agent = DDPGAgent(nb_actions=env.action_space.shape[0], actor=actor, critic=qfunc[0], critic_action_input=qfunc[1],\n",
125
    "                  memory=memory, nb_steps_warmup_critic=100, nb_steps_warmup_actor=100,\n",
126
    "                  random_process=random_process, gamma=.99, target_model_update=1e-3,\n",
127
    "                  delta_clip=1.)\n",
128
    "agent.compile(Adam(lr=.001, clipnorm=1.), metrics=['mae'])"
129
   ]
130
  },
131
  {
132
   "cell_type": "code",
133
   "execution_count": null,
134
   "metadata": {
135
    "scrolled": true
136
   },
137
   "outputs": [],
138
   "source": [
139
    "# Okay, now it's time to learn something! We visualize the training here for show, but this\n",
140
    "# slows down training quite a lot. You can always safely abort the training prematurely by\n",
141
    "# stopping the notebook\n",
142
    "agent.fit(env, nb_steps=2000, visualize=False, verbose=0, nb_max_episode_steps=200, log_interval=10000)\n",
143
    "# After training is done, we save the final weights.\n",
144
    "# agent.save_weights(args.model, overwrite=True)"
145
   ]
146
  },
147
  {
148
   "cell_type": "markdown",
149
   "metadata": {},
150
   "source": [
151
    "## Evaluate the results\n",
152
    "Check how our trained 'brain' performs. Below we will also load a pretrained model (on the larger number of episodes), which should perform better. It was trained exactly the same way, just with a larger number of steps (parameter `nb_steps` in `agent.fit`."
153
   ]
154
  },
155
  {
156
   "cell_type": "code",
157
   "execution_count": null,
158
   "metadata": {},
159
   "outputs": [],
160
   "source": [
161
    "# agent.load_weights(args.model)\n",
162
    "# Finally, evaluate our algorithm for 2 episodes.\n",
163
    "agent.test(env, nb_episodes=2, visualize=False, nb_max_episode_steps=1000)"
164
   ]
165
  }
166
 ],
167
 "metadata": {
168
  "kernelspec": {
169
   "display_name": "Python 3",
170
   "language": "python",
171
   "name": "python3"
172
  },
173
  "language_info": {
174
   "codemirror_mode": {
175
    "name": "ipython",
176
    "version": 3
177
   },
178
   "file_extension": ".py",
179
   "mimetype": "text/x-python",
180
   "name": "python",
181
   "nbconvert_exporter": "python",
182
   "pygments_lexer": "ipython3",
183
   "version": "3.6.1"
184
  }
185
 },
186
 "nbformat": 4,
187
 "nbformat_minor": 2
188
}