Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exercise notebooks with no outputs. #207

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
181 changes: 151 additions & 30 deletions DP/Gamblers Problem.ipynb

Large diffs are not rendered by default.

83 changes: 47 additions & 36 deletions DP/Policy Evaluation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
"cells": [
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true
},
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
Expand All @@ -17,21 +15,17 @@
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": true
},
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"env = GridworldEnv()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": true
},
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):\n",
Expand All @@ -53,17 +47,27 @@
" # Start with a random (all 0) value function\n",
" V = np.zeros(env.nS)\n",
" while True:\n",
" # TODO: Implement!\n",
" break\n",
" delta = .0\n",
" for s in range(env.nS):\n",
" v = V[s]\n",
" v_new = 0\n",
" \n",
" for a, p_as in enumerate(policy[s]):\n",
" for (p_srsa, next_state, reward, done) in env.P[s][a]:\n",
" v_new += p_as * p_srsa * (reward + discount_factor * V[next_state])\n",
" V[s] = v_new\n",
" \n",
" delta = max(delta, np.abs(v - V[s]))\n",
"\n",
" if delta < theta:\n",
" break\n",
" return np.array(V)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"random_policy = np.ones([env.nS, env.nA]) / env.nA\n",
Expand All @@ -72,35 +76,42 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Test: Make sure the evaluated policy is what we expected\n",
"expected_v = np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22, -20, -14, 0])\n",
"np.testing.assert_array_almost_equal(v, expected_v, decimal=2)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "\nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22,\n -20, -14, 0])",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-22-235f39fb115c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Test: Make sure the evaluated policy is what we expected\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mexpected_v\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m22\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m18\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m18\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m22\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m14\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtesting\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0massert_array_almost_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpected_v\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_almost_equal\u001b[0;34m(x, y, decimal, err_msg, verbose)\u001b[0m\n\u001b[1;32m 914\u001b[0m assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,\n\u001b[1;32m 915\u001b[0m \u001b[0mheader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Arrays are not almost equal to %d decimals'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 916\u001b[0;31m precision=decimal)\n\u001b[0m\u001b[1;32m 917\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 918\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision)\u001b[0m\n\u001b[1;32m 735\u001b[0m names=('x', 'y'), precision=precision)\n\u001b[1;32m 736\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 737\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 738\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 739\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: \nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22,\n -20, -14, 0])"
]
"data": {
"text/plain": [
"array([ 0. , -13.99993529, -19.99990698, -21.99989761,\n",
" -13.99993529, -17.9999206 , -19.99991379, -19.99991477,\n",
" -19.99990698, -19.99991379, -17.99992725, -13.99994569,\n",
" -21.99989761, -19.99991477, -13.99994569, 0. ])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test: Make sure the evaluated policy is what we expected\n",
"expected_v = np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22, -20, -14, 0])\n",
"np.testing.assert_array_almost_equal(v, expected_v, decimal=2)"
"v"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": []
}
Expand Down
135 changes: 71 additions & 64 deletions DP/Policy Iteration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
Expand All @@ -18,10 +16,8 @@
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"pp = pprint.PrettyPrinter(indent=2)\n",
Expand All @@ -30,10 +26,8 @@
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Taken from Policy Evaluation Exercise!\n",
Expand Down Expand Up @@ -78,10 +72,8 @@
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": true
},
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):\n",
Expand All @@ -102,57 +94,81 @@
" V is the value function for the optimal policy.\n",
" \n",
" \"\"\"\n",
" \n",
" def greedy_action(s, V, env):\n",
" A = np.zeros(env.nA)\n",
" for a in range(env.nA):\n",
" for prob, next_state, reward, done in env.P[s][a]:\n",
" A[a] += prob * (reward + discount_factor * V[next_state])\n",
" \n",
" return A.argmax()\n",
" \n",
" # Start with a random policy\n",
" policy = np.ones([env.nS, env.nA]) / env.nA\n",
" Vstar = None\n",
" \n",
" while True:\n",
" # Implement this!\n",
" break\n",
" V = policy_eval_fn(policy, env, discount_factor)\n",
" \n",
" policy_stable = True\n",
" for s in range(env.nS):\n",
" # given the initial policy is a random policy we should instead\n",
" # sample from the corresponding probability distribution\n",
" action = policy[s].argmax()\n",
" best_action = greedy_action(s, V, env)\n",
" \n",
" policy[s] = np.eye(env.nA)[best_action]\n",
" if action != best_action:\n",
" policy_stable = False\n",
" \n",
" if policy_stable:\n",
" Vstar = V\n",
" break\n",
" \n",
" return policy, np.zeros(env.nS)"
" return policy, Vstar"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Policy Probability Distribution:\n",
"[[ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]\n",
" [ 0.25 0.25 0.25 0.25]]\n",
"[[1. 0. 0. 0.]\n",
" [0. 0. 0. 1.]\n",
" [0. 0. 0. 1.]\n",
" [0. 0. 1. 0.]\n",
" [1. 0. 0. 0.]\n",
" [1. 0. 0. 0.]\n",
" [1. 0. 0. 0.]\n",
" [0. 0. 1. 0.]\n",
" [1. 0. 0. 0.]\n",
" [1. 0. 0. 0.]\n",
" [0. 1. 0. 0.]\n",
" [0. 0. 1. 0.]\n",
" [1. 0. 0. 0.]\n",
" [0. 1. 0. 0.]\n",
" [0. 1. 0. 0.]\n",
" [1. 0. 0. 0.]]\n",
"\n",
"Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\n",
"[[0 0 0 0]\n",
" [0 0 0 0]\n",
" [0 0 0 0]\n",
" [0 0 0 0]]\n",
"[[0 3 3 2]\n",
" [0 0 0 2]\n",
" [0 0 1 2]\n",
" [0 1 1 0]]\n",
"\n",
"Value Function:\n",
"[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
"[ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1. 0.]\n",
"\n",
"Reshaped Grid Value Function:\n",
"[[ 0. 0. 0. 0.]\n",
" [ 0. 0. 0. 0.]\n",
" [ 0. 0. 0. 0.]\n",
" [ 0. 0. 0. 0.]]\n",
"[[ 0. -1. -2. -3.]\n",
" [-1. -2. -3. -2.]\n",
" [-2. -3. -2. -1.]\n",
" [-3. -2. -1. 0.]]\n",
"\n"
]
}
Expand All @@ -179,23 +195,9 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"ename": "AssertionError",
"evalue": "\nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1, 0])",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-15-55581f8eb5c9>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Test the value function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mexpected_v\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtesting\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0massert_array_almost_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexpected_v\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_almost_equal\u001b[0;34m(x, y, decimal, err_msg, verbose)\u001b[0m\n\u001b[1;32m 914\u001b[0m assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,\n\u001b[1;32m 915\u001b[0m \u001b[0mheader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Arrays are not almost equal to %d decimals'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 916\u001b[0;31m precision=decimal)\n\u001b[0m\u001b[1;32m 917\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 918\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/dennybritz/venvs/tf/lib/python3.5/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision)\u001b[0m\n\u001b[1;32m 735\u001b[0m names=('x', 'y'), precision=precision)\n\u001b[1;32m 736\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcond\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 737\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 738\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 739\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAssertionError\u001b[0m: \nArrays are not almost equal to 2 decimals\n\n(mismatch 87.5%)\n x: array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n 0., 0., 0.])\n y: array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1, 0])"
]
}
],
"outputs": [],
"source": [
"# Test the value function\n",
"expected_v = np.array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1, 0])\n",
Expand All @@ -205,9 +207,14 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
Expand Down
Loading