
Commit a3aa10a

Add RL example: cart pole (#106)
* Add example: Cart pole reinforcement learning
1 parent cc4a746 commit a3aa10a

10 files changed: +5630 -0 lines changed

cart-pole/.babelrc

+18
@@ -0,0 +1,18 @@
{
  "presets": [
    [
      "env",
      {
        "esmodules": false,
        "targets": {
          "browsers": [
            "> 3%"
          ]
        }
      }
    ]
  ],
  "plugins": [
    "transform-runtime"
  ]
}

cart-pole/README.md

+56
@@ -0,0 +1,56 @@
# TensorFlow.js Example: Reinforcement Learning with Cart-Pole Simulation

## Overview

This example illustrates how to use TensorFlow.js to perform simple
reinforcement learning (RL). Specifically, it showcases an implementation
of the policy-gradient method in TensorFlow.js, using a combination of the
Layers API and the gradients API. This implementation is used to solve the
classic cart-pole control problem, which was originally proposed in:

- Barto, Sutton, and Anderson, "Neuronlike Adaptive Elements That Can Solve
  Difficult Learning Control Problems," IEEE Trans. Syst., Man, Cybern.,
  Vol. SMC-13, pp. 834--846, Sept.--Oct. 1983.
- Sutton, "Temporal Aspects of Credit Assignment in Reinforcement Learning",
  Ph.D. Dissertation, Department of Computer and Information Science,
  University of Massachusetts, Amherst, 1984.

It later became one of OpenAI's gym environments:
https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py

The gist of the RL algorithm in this example (see [index.js](../index.js)) is:

1. Define a policy network that decides between leftward and rightward force
   given the observed state of the system. The decision is not completely
   deterministic: the network outputs a probability, which is converted to an
   actual action by drawing random samples from a binomial probability
   distribution.
2. For each "game", calculate reward values in such a way that longer-lasting
   games are assigned positive reward values, while shorter-lasting ones
   are assigned negative reward values.
3. Calculate the gradients of the policy network's weights with respect to the
   actual actions and scale the gradients with the reward values from step 2.
   The scaled gradients are added to the policy network's weights, the effect
   of which is to make the policy network more likely to select actions that
   lead to longer-lasting games given the same system states.

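The numerical pieces of steps 1 and 2 above can be sketched in plain JavaScript. `sampleAction` and `discountRewards` below are hypothetical helpers written for illustration; the actual code in index.js operates on TensorFlow.js tensors.

```javascript
// Step 1 (sketch): convert the policy network's output probability of
// choosing the leftward action into a concrete action by drawing a
// Bernoulli (single-trial binomial) sample.
function sampleAction(leftProb) {
  return Math.random() < leftProb ? 0 : 1;  // 0 = leftward, 1 = rightward
}

// Step 2 (sketch): convert the per-step rewards of one game into
// discounted returns, walking backwards so that each step is credited
// with its own reward plus the discounted rewards of all later steps.
function discountRewards(rewards, discountRate) {
  const discounted = new Array(rewards.length).fill(0);
  let running = 0;
  for (let i = rewards.length - 1; i >= 0; --i) {
    running = rewards[i] + discountRate * running;
    discounted[i] = running;
  }
  return discounted;
}

// A 4-step game with a reward of 1 per surviving step: earlier steps
// receive larger returns, so longer games yield larger positive credit.
console.log(discountRewards([1, 1, 1, 1], 0.95));
```

In index.js the discounted returns are additionally normalized before they scale the gradients, but the backwards-accumulation idea is the same.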
For a more detailed overview of policy-gradient methods, see:
http://www.scholarpedia.org/article/Policy_gradient_methods

For a more graphical illustration of the cart-pole problem, see:
http://gym.openai.com/envs/CartPole-v1/

### Features:

- Allows the user to specify the architecture of the policy network, in
  particular, the number of the neural network's layers and their sizes
  (number of units).
- Allows training of the policy network in the browser, optionally with
  simultaneous visualization of the cart-pole system.
- Allows testing in the browser, with visualization.
- Allows saving the policy network to the browser's IndexedDB. The saved policy
  network can later be loaded back for testing and/or further training.

## Usage

```sh
yarn && yarn watch
```

cart-pole/cart_pole.js

+126
@@ -0,0 +1,126 @@
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * Implementation based on: http://incompleteideas.net/book/code/pole.c
 */

import * as tf from '@tensorflow/tfjs';

/**
 * Cart-pole system simulator.
 *
 * In the control-theory sense, there are four state variables in this system:
 *
 *   - x: The 1D location of the cart.
 *   - xDot: The velocity of the cart.
 *   - theta: The angle of the pole (in radians). A value of 0 corresponds to
 *     a vertical position.
 *   - thetaDot: The angular velocity of the pole.
 *
 * The system is controlled through a single action:
 *
 *   - leftward or rightward force.
 */
export class CartPole {
  /**
   * Constructor of CartPole.
   */
  constructor() {
    // Constants that characterize the system.
    this.gravity = 9.8;
    this.massCart = 1.0;
    this.massPole = 0.1;
    this.totalMass = this.massCart + this.massPole;
    this.cartWidth = 0.2;
    this.cartHeight = 0.1;
    this.length = 0.5;
    this.poleMoment = this.massPole * this.length;
    this.forceMag = 10.0;
    this.tau = 0.02;  // Seconds between state updates.

    // Threshold values, beyond which a simulation will be marked as failed.
    this.xThreshold = 2.4;
    this.thetaThreshold = 12 / 360 * 2 * Math.PI;

    this.setRandomState();
  }

  /**
   * Set the state of the cart-pole system randomly.
   */
  setRandomState() {
    // The control-theory state variables of the cart-pole system.
    // Cart position, meters.
    this.x = Math.random() - 0.5;
    // Cart velocity.
    this.xDot = (Math.random() - 0.5) * 1;
    // Pole angle, radians.
    this.theta = (Math.random() - 0.5) * 2 * (6 / 360 * 2 * Math.PI);
    // Pole angular velocity.
    this.thetaDot = (Math.random() - 0.5) * 0.5;
  }

  /**
   * Get current state as a tf.Tensor of shape [1, 4].
   */
  getStateTensor() {
    return tf.tensor2d([[this.x, this.xDot, this.theta, this.thetaDot]]);
  }

  /**
   * Update the cart-pole system using an action.
   * @param {number} action Only the sign of `action` matters.
   *   A value > 0 leads to a rightward force of a fixed magnitude.
   *   A value <= 0 leads to a leftward force of the same fixed magnitude.
   */
  update(action) {
    const force = action > 0 ? this.forceMag : -this.forceMag;

    const cosTheta = Math.cos(this.theta);
    const sinTheta = Math.sin(this.theta);

    const temp =
        (force + this.poleMoment * this.thetaDot * this.thetaDot * sinTheta) /
        this.totalMass;
    const thetaAcc = (this.gravity * sinTheta - cosTheta * temp) /
        (this.length *
         (4 / 3 - this.massPole * cosTheta * cosTheta / this.totalMass));
    const xAcc = temp - this.poleMoment * thetaAcc * cosTheta / this.totalMass;

    // Update the four state variables, using Euler's method.
    this.x += this.tau * this.xDot;
    this.xDot += this.tau * xAcc;
    this.theta += this.tau * this.thetaDot;
    this.thetaDot += this.tau * thetaAcc;

    return this.isDone();
  }

  /**
   * Determine whether this simulation is done.
   *
   * A simulation is done when `x` (position of the cart) goes out of bound
   * or when `theta` (angle of the pole) goes out of bound.
   *
   * @returns {bool} Whether the simulation is done.
   */
  isDone() {
    return this.x < -this.xThreshold || this.x > this.xThreshold ||
        this.theta < -this.thetaThreshold || this.theta > this.thetaThreshold;
  }
}
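The physics in `update()` can be sketched as a standalone function without the tfjs dependency. `cartPoleStep` below is a hypothetical helper written for illustration, using the same constants and Euler integration as the class:

```javascript
// Standalone sketch of the cart-pole physics step: given the four state
// variables and an applied force, return the state one tick (tau) later.
function cartPoleStep({x, xDot, theta, thetaDot}, force) {
  const gravity = 9.8, massCart = 1.0, massPole = 0.1;
  const totalMass = massCart + massPole;
  const length = 0.5, poleMoment = massPole * length, tau = 0.02;

  const cosTheta = Math.cos(theta);
  const sinTheta = Math.sin(theta);
  const temp =
      (force + poleMoment * thetaDot * thetaDot * sinTheta) / totalMass;
  const thetaAcc = (gravity * sinTheta - cosTheta * temp) /
      (length * (4 / 3 - massPole * cosTheta * cosTheta / totalMass));
  const xAcc = temp - poleMoment * thetaAcc * cosTheta / totalMass;

  // Euler's method: advance each variable by tau times its derivative.
  return {
    x: x + tau * xDot,
    xDot: xDot + tau * xAcc,
    theta: theta + tau * thetaDot,
    thetaDot: thetaDot + tau * thetaAcc,
  };
}

// Starting at rest with the pole exactly vertical, a rightward force
// accelerates the cart to the right and tips the pole backwards
// (theta goes negative):
let s = {x: 0, xDot: 0, theta: 0, thetaDot: 0};
for (let i = 0; i < 5; ++i) s = cartPoleStep(s, 10);
console.log(s);
```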

cart-pole/index.html

+147
@@ -0,0 +1,147 @@
<!--
Copyright 2018 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
-->

<html>
<head>
  <link rel="stylesheet" href="https://code.getmdl.io/1.3.0/material.cyan-teal.min.css" />
</head>

<body>

<style>
  #app-status {
    color: blue;
    font-size: 150%;
    padding-bottom: 1em;
  }
  button {
    font-size: 105%;
    min-width: 120px;
  }
  input {
    font-family: monospace;
    width: 200px;
  }
  .input-div {
    padding: 5px;
    font-family: monospace;
  }
  .input-label {
    display: inline-block;
    width: 15em;
  }
  .canvases {
    display: inline-block;
  }
  .horizontal-sections {
    display: inline-block;
    padding-left: 0px;
    padding-right: 10px;
    vertical-align: top;
    border: 1px #AAA solid;
  }
  .status-span {
    display: inline-block;
    width: 150px;
  }
  .buttons-section {
    float: right;
  }
  input:disabled {
    background-color: #AAA;
  }
</style>

<h1>TensorFlow.js Example:<br/>Reinforcement Learning: Cart Pole</h1>

<div>
  <div>
    <span id="app-status">Standing by.</span>
  </div>

  <div>
    <div class="horizontal-sections">
      <div class="input-div">
        <span class="input-label">Locally-stored network</span>
        <input id="stored-model-status" value="N/A" disabled="true" readonly="true">
        <button id="delete-stored-model" disabled="true">Delete</button>
      </div>

      <div class="horizontal-sections">
        <div class="input-div">
          <span class="input-label">Hidden layer size(s) (e.g.: "5", "8,6"):</span>
          <input id="hidden-layer-sizes" value="4">
          <button id="create-model" disabled="true">Create model</button>
        </div>
      </div>

      <div class="input-div">
        <span class="input-label">Number of iterations:</span>
        <input id="num-iterations" value="20">
      </div>
      <div class="input-div">
        <span class="input-label">Games per iteration:</span>
        <input id="games-per-iteration" value="20">
      </div>
      <div class="input-div">
        <span class="input-label">Max. steps per game:</span>
        <input id="max-steps-per-game" value="500">
      </div>
      <div class="input-div">
        <span class="input-label">Reward discount rate:</span>
        <input id="discount-rate" value="0.95">
      </div>
      <div class="input-div">
        <span class="input-label">Learning rate:</span>
        <input id="learning-rate" value="0.05">
      </div>
      <div class="input-div">
        <span class="input-label">Render during training:</span>
        <input type="checkbox" id="render-during-training" />
      </div>
      <div class="buttons-section">
        <button id="train" disabled="true">Train</button>
        <button id="test" disabled="true">Test</button>
      </div>
    </div>

    <div class="horizontal-sections">
      <div>
        <span id="iteration-status" class="status-span"></span>
        <progress value="0" max="100" id="iteration-progress"></progress>
      </div>
      <div>
        <span id="train-status" class="status-span"></span>
        <progress value="0" max="100" id="train-progress"></progress>
      </div>
      <div>
        <span class="status-span">Training speed:</span>
        <span id="train-speed" class="status-span"></span>
      </div>
      <div class="canvases" id="steps-canvas"></div>
    </div>
  </div>

  <div>
    <canvas id="cart-pole-canvas" height="150" width="500"></canvas>
  </div>
</div>

<script src="index.js"></script>
</body>
</html>
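The `hidden-layer-sizes` field above accepts strings like "5" or "8,6", one integer per hidden layer. A hypothetical parser for that input (the actual parsing in index.js may differ):

```javascript
// Parse a comma-separated hidden-layer-sizes string, e.g. "8,6" -> [8, 6].
// Rejects anything that is not a list of positive integers.
function parseLayerSizes(text) {
  const sizes = text.trim().split(',').map(s => Number.parseInt(s.trim(), 10));
  if (sizes.some(n => !Number.isInteger(n) || n <= 0)) {
    throw new Error(`Invalid hidden layer sizes: "${text}"`);
  }
  return sizes;
}

console.log(parseLayerSizes('8, 6'));  // two hidden layers, of 8 and 6 units
```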
