
Commit 47af393

Add files via upload
1 parent b1a99b1 commit 47af393

File tree

1 file changed: +312 -0 lines changed

Policy Iteration.ipynb (+312 lines)
@@ -0,0 +1,312 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pprint\n",
    "import sys\n",
    "if \"../\" not in sys.path:\n",
    "    sys.path.append(\"../\")\n",
    "from lib.envs.gridworld import GridworldEnv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "pp = pprint.PrettyPrinter(indent=2)\n",
    "env = GridworldEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Taken from the Policy Evaluation exercise!\n",
    "\n",
    "def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):\n",
    "    \"\"\"\n",
    "    Evaluate a policy given an environment and a full description of the environment's dynamics.\n",
    "\n",
    "    Args:\n",
    "        policy: [S, A] shaped matrix representing the policy.\n",
    "        env: OpenAI env. env.P represents the transition probabilities of the environment.\n",
    "            env.P[s][a] is a list of transition tuples (prob, next_state, reward, done).\n",
    "            env.nS is the number of states in the environment.\n",
    "            env.nA is the number of actions in the environment.\n",
    "        theta: We stop evaluation once our value function change is less than theta for all states.\n",
    "        discount_factor: Gamma discount factor.\n",
    "\n",
    "    Returns:\n",
    "        Vector of length env.nS representing the value function.\n",
    "    \"\"\"\n",
    "    # Start with an all-zero value function\n",
    "    V = np.zeros(env.nS)\n",
    "    while True:\n",
    "        delta = 0\n",
    "        # For each state, perform a \"full backup\"\n",
    "        for s in range(env.nS):\n",
    "            v = 0\n",
    "            # Look at the possible next actions\n",
    "            for a, action_prob in enumerate(policy[s]):\n",
    "                # For each action, look at the possible next states...\n",
    "                for prob, next_state, reward, done in env.P[s][a]:\n",
    "                    # Calculate the expected value\n",
    "                    v += action_prob * prob * (reward + discount_factor * V[next_state])\n",
    "            # How much our value function changed (across all states)\n",
    "            delta = max(delta, np.abs(v - V[s]))\n",
    "            V[s] = v\n",
    "        # Stop evaluating once our value function change is below a threshold\n",
    "        if delta < theta:\n",
    "            break\n",
    "    return np.array(V)"
   ]
  },
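  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, each sweep of the loops above applies the Bellman expectation backup in place,\n",
    "\n",
    "$$V(s) \\leftarrow \\sum_a \\pi(a \\mid s) \\sum_{s', r} p(s', r \\mid s, a)\\,\\big[r + \\gamma V(s')\\big],$$\n",
    "\n",
    "and the sweeps repeat until the largest per-state change falls below `theta`."
   ]
  },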
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):\n",
    "    \"\"\"\n",
    "    Policy Improvement Algorithm. Iteratively evaluates and improves a policy\n",
    "    until an optimal policy is found.\n",
    "\n",
    "    Args:\n",
    "        env: The OpenAI environment.\n",
    "        policy_eval_fn: Policy Evaluation function that takes 3 arguments:\n",
    "            policy, env, discount_factor.\n",
    "        discount_factor: Gamma discount factor.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (policy, V).\n",
    "        policy is the optimal policy, a matrix of shape [S, A] where each state s\n",
    "        contains a valid probability distribution over actions.\n",
    "        V is the value function for the optimal policy.\n",
    "    \"\"\"\n",
    "    # Start with a random (uniform) policy\n",
    "    policy = np.ones([env.nS, env.nA]) / env.nA\n",
    "\n",
    "    while True:\n",
    "        # First, evaluate the current policy.\n",
    "        v_pi = policy_eval_fn(policy, env, discount_factor)\n",
    "        policy_stable = True\n",
    "        # Go through all of the states, one by one.\n",
    "        for s in range(env.nS):\n",
    "            old_action = np.argmax(policy[s])\n",
    "\n",
    "            # One-step lookahead from the current state: for each action,\n",
    "            # record the expected return under the current value estimate.\n",
    "            expected_values = np.zeros(env.nA)\n",
    "            for a in range(env.nA):\n",
    "                for prob, next_state, reward, done in env.P[s][a]:\n",
    "                    expected_values[a] += prob * (reward + discount_factor * v_pi[next_state])\n",
    "            # The new action acts greedily with respect to the current value function.\n",
    "            new_action = np.argmax(expected_values)\n",
    "\n",
    "            if old_action != new_action:\n",
    "                policy_stable = False\n",
    "\n",
    "            # Replace the old policy for this state with the greedy one.\n",
    "            policy[s] = np.eye(env.nA)[new_action]\n",
    "\n",
    "        # If no action changed in any state, the policy is stable and we are done.\n",
    "        if policy_stable:\n",
    "            return policy, v_pi"
   ]
  },
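  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the one-step lookahead computes the action values of the evaluated policy, and the improvement step acts greedily on them,\n",
    "\n",
    "$$q_\\pi(s, a) = \\sum_{s', r} p(s', r \\mid s, a)\\,\\big[r + \\gamma v_\\pi(s')\\big], \\qquad \\pi'(s) = \\arg\\max_a q_\\pi(s, a).$$\n",
    "\n",
    "The outer loop stops once the greedy action no longer changes in any state."
   ]
  },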
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "5\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "# Print the next state reached by each of the four actions from state 1.\n",
    "for key, val in env.P[1].items():\n",
    "    print(val[0][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: [(1.0, 1, -1.0, False)],\n",
       " 1: [(1.0, 2, -1.0, False)],\n",
       " 2: [(1.0, 5, -1.0, False)],\n",
       " 3: [(1.0, 0, -1.0, True)]}"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "env.P[1]"
   ]
  },
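  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sketch, `pp` (configured above but otherwise unused) can pretty-print a state's transition dict, which makes the `(prob, next_state, reward, done)` structure a bit easier to scan."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretty-print the transition dict for state 1 with the pprint helper defined above.\n",
    "pp.pprint(env.P[1])"
   ]
  },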
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Policy Probability Distribution:\n",
      "[[1. 0. 0. 0.]\n",
      " [0. 0. 0. 1.]\n",
      " [0. 0. 0. 1.]\n",
      " [0. 0. 1. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 0. 1. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [0. 0. 1. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [1. 0. 0. 0.]]\n",
      "\n",
      "Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\n",
      "[[0 3 3 2]\n",
      " [0 0 0 2]\n",
      " [0 0 1 2]\n",
      " [0 1 1 0]]\n",
      "\n",
      "Value Function:\n",
      "[ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1. 0.]\n",
      "\n",
      "Reshaped Grid Value Function:\n",
      "[[ 0. -1. -2. -3.]\n",
      " [-1. -2. -3. -2.]\n",
      " [-2. -3. -2. -1.]\n",
      " [-3. -2. -1. 0.]]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "policy, v = policy_improvement(env)\n",
    "print(\"Policy Probability Distribution:\")\n",
    "print(policy)\n",
    "print(\"\")\n",
    "\n",
    "print(\"Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\")\n",
    "print(np.reshape(np.argmax(policy, axis=1), env.shape))\n",
    "print(\"\")\n",
    "\n",
    "print(\"Value Function:\")\n",
    "print(v)\n",
    "print(\"\")\n",
    "\n",
    "print(\"Reshaped Grid Value Function:\")\n",
    "print(v.reshape(env.shape))\n",
    "print(\"\")\n",
    "\n"
   ]
  },
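  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional readability sketch, the integer-encoded grid policy above (0=up, 1=right, 2=down, 3=left) can be mapped onto arrow characters; `arrows` below is just a helper array introduced for this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the action indices (0=up, 1=right, 2=down, 3=left) onto arrow characters.\n",
    "arrows = np.array([\"↑\", \"→\", \"↓\", \"←\"])\n",
    "print(arrows[np.argmax(policy, axis=1)].reshape(env.shape))"
   ]
  },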
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14.3 ms ± 270 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit policy_improvement(env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At roughly 14 ms per run, this process is already quite slow for such a small problem as the given gridworld, largely because each improvement step re-runs a full policy evaluation to convergence."
   ]
  },
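  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of where the time goes, a single full policy evaluation can be timed on its own; `random_policy` below is introduced for this sketch and mirrors the initial policy built inside `policy_improvement`. Since every improvement iteration runs `policy_eval` to convergence, these evaluation sweeps are expected to dominate the runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time one full evaluation of the uniform random policy in isolation.\n",
    "random_policy = np.ones([env.nS, env.nA]) / env.nA\n",
    "%timeit policy_eval(random_policy, env)"
   ]
  },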
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the value function\n",
    "expected_v = np.array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1, 0])\n",
    "np.testing.assert_array_almost_equal(v, expected_v, decimal=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
