Commit b1a99b1

Add files via upload

1 file changed: Value Iteration.ipynb (+219 lines, -0)
@@ -0,0 +1,219 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pprint\n",
    "import sys\n",
    "if \"../\" not in sys.path:\n",
    "    sys.path.append(\"../\")\n",
    "from lib.envs.gridworld import GridworldEnv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "pp = pprint.PrettyPrinter(indent=2)\n",
    "env = GridworldEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def value_iteration(env, theta=0.0001, discount_factor=1.0):\n",
    "    \"\"\"\n",
    "    Value Iteration Algorithm.\n",
    "\n",
    "    Args:\n",
    "        env: OpenAI env. env.P represents the transition probabilities of the environment.\n",
    "            env.P[s][a] is a list of transition tuples (prob, next_state, reward, done).\n",
    "            env.nS is the number of states in the environment.\n",
    "            env.nA is the number of actions in the environment.\n",
    "        theta: We stop evaluation once our value function change is less than theta for all states.\n",
    "        discount_factor: Gamma discount factor.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (policy, V) of the optimal policy and the optimal value function.\n",
    "    \"\"\"\n",
    "    # Initialize the value function and the policy.\n",
    "    V = np.zeros(env.nS)\n",
    "    policy = np.zeros([env.nS, env.nA])\n",
    "\n",
    "    # Sweep over all states until the value function changes by less than theta everywhere.\n",
    "    while True:\n",
    "        delta = 0\n",
    "        for s in range(env.nS):\n",
    "            # One-step lookahead: expected return of each action under the current V\n",
    "            # (the Bellman optimality backup; see the note after this cell).\n",
    "            expected_values = np.zeros(env.nA)\n",
    "            for a in range(env.nA):\n",
    "                for prob, next_state, reward, done in env.P[s][a]:\n",
    "                    expected_values[a] += prob * (reward + discount_factor * V[next_state])\n",
    "            # The new value of the state is the best expected return over all actions.\n",
    "            v = np.max(expected_values)\n",
    "            delta = max(delta, np.abs(v - V[s]))\n",
    "            V[s] = v\n",
    "        if delta < theta:\n",
    "            break\n",
    "\n",
    "    # For the policy, act greedily w.r.t. the converged value function.\n",
    "    for s in range(env.nS):\n",
    "        # To act greedily, we do a one-step lookahead again.\n",
    "        expected_values = np.zeros(env.nA)\n",
    "        for a in range(env.nA):\n",
    "            for prob, next_state, reward, done in env.P[s][a]:\n",
    "                expected_values[a] += prob * (reward + discount_factor * V[next_state])\n",
    "        # The policy puts probability 1 on the action that maximizes the expected return.\n",
    "        best_action = np.argmax(expected_values)\n",
    "        policy[s] = np.eye(env.nA)[best_action]\n",
    "    return policy, V"
   ]
  },
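  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference (standard textbook notation, nothing repo-specific), the one-step lookahead in `value_iteration` is the Bellman optimality backup:\n",
    "\n",
    "$$V_{k+1}(s) = \\max_a \\sum_{s', r} p(s', r \\mid s, a)\\,\\bigl[r + \\gamma V_k(s')\\bigr]$$\n",
    "\n",
    "Here the sum runs over the `(prob, next_state, reward, done)` tuples in `env.P[s][a]`, and $\\gamma$ is `discount_factor`. The policy extraction at the end replaces the $\\max$ with an $\\arg\\max$ over the same quantities."
   ]
  },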
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Policy Probability Distribution:\n",
      "[[1. 0. 0. 0.]\n",
      " [0. 0. 0. 1.]\n",
      " [0. 0. 0. 1.]\n",
      " [0. 0. 1. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 0. 1. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [0. 0. 1. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [1. 0. 0. 0.]]\n",
      "\n",
      "Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\n",
      "[[0 3 3 2]\n",
      " [0 0 0 2]\n",
      " [0 0 1 2]\n",
      " [0 1 1 0]]\n",
      "\n",
      "Value Function:\n",
      "[ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1. 0.]\n",
      "\n",
      "Reshaped Grid Value Function:\n",
      "[[ 0. -1. -2. -3.]\n",
      " [-1. -2. -3. -2.]\n",
      " [-2. -3. -2. -1.]\n",
      " [-3. -2. -1. 0.]]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "policy, v = value_iteration(env)\n",
    "\n",
    "print(\"Policy Probability Distribution:\")\n",
    "print(policy)\n",
    "print(\"\")\n",
    "\n",
    "print(\"Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\")\n",
    "print(np.reshape(np.argmax(policy, axis=1), env.shape))\n",
    "print(\"\")\n",
    "\n",
    "print(\"Value Function:\")\n",
    "print(v)\n",
    "print(\"\")\n",
    "\n",
    "print(\"Reshaped Grid Value Function:\")\n",
    "print(v.reshape(env.shape))\n",
    "print(\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "816 µs ± 59.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit value_iteration(env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, Value Iteration converges roughly an order of magnitude faster than Policy Iteration: each sweep folds a truncated policy evaluation and a greedy improvement into a single Bellman optimality backup, instead of evaluating every intermediate policy to convergence. The cell below makes this concrete by counting how many sweeps are needed on this Gridworld."
   ]
  },
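  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough sketch to back up the claim above (the helper `count_sweeps` is a name introduced here for illustration): it repeats the same backup loop as `value_iteration`, but only counts how many full sweeps over the states are needed before the change in `V` drops below `theta`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_sweeps(env, theta=0.0001, discount_factor=1.0):\n",
    "    # Same Bellman optimality backups as value_iteration, but here we only\n",
    "    # record how many full sweeps over the states are needed to converge.\n",
    "    V = np.zeros(env.nS)\n",
    "    sweeps = 0\n",
    "    while True:\n",
    "        delta = 0\n",
    "        for s in range(env.nS):\n",
    "            expected_values = np.zeros(env.nA)\n",
    "            for a in range(env.nA):\n",
    "                for prob, next_state, reward, done in env.P[s][a]:\n",
    "                    expected_values[a] += prob * (reward + discount_factor * V[next_state])\n",
    "            v = np.max(expected_values)\n",
    "            delta = max(delta, np.abs(v - V[s]))\n",
    "            V[s] = v\n",
    "        sweeps += 1\n",
    "        if delta < theta:\n",
    "            return sweeps\n",
    "\n",
    "print(\"Sweeps until convergence:\", count_sweeps(env))"
   ]
  },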
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the value function\n",
    "expected_v = np.array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1, 0])\n",
    "np.testing.assert_array_almost_equal(v, expected_v, decimal=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
