78 changes: 78 additions & 0 deletions src/mdptoolbox/example.py
@@ -402,3 +402,81 @@ def small():
P = _np.array([[[0.5, 0.5], [0.8, 0.2]], [[0, 1], [0.1, 0.9]]])
R = _np.array([[5, 10], [-1, 2]])
return(P, R)


def gridworld():
"""4x4 gridworld example.

Example 4.1 of `Reinforcement Learning: An Introduction
<http://webdocs.cs.ualberta.ca/~sutton/book/the-book.html>`_,
by Richard S. Sutton and Andrew G. Barto.

Returns
-------
out : tuple
``out[0]`` contains the transition probability matrix P,
and ``out[1]`` contains the reward matrix R. The non-terminal
        states correspond to indices 0-13 in both matrices, and the
        terminal state to index 14.

Examples
--------
>>> import mdptoolbox.example
>>> P, R = mdptoolbox.example.gridworld()
>>> P.shape
(4, 15, 15)
>>> R.shape
(4, 15, 15)
"""
# States: labelled 1, 2, ..., 14 in the figure, plus the terminal
# state associated with the two terminal positions.
S = 15 # number of states
terminal_state = 14 # terminal state index

# Actions: up, down, right, left.
A = 4 # number of actions
up, down, right, left = range(A) # indices of the actions

# Transitions.
P = _np.zeros((A, S, S))

    # Grid transitions, keyed by the 1-based state labels used in the
    # figure; the label 15 denotes the terminal state.
    grid_transitions = {
        # from_state: ((action, to_state), ...)
1: ((down, 5), (right, 2), (left, 15)),
2: ((down, 6), (right, 3), (left, 1)),
3: ((down, 7), (left, 2)),
4: ((up, 15), (down, 8), (right, 5)),
5: ((up, 1), (down, 9), (right, 6), (left, 4)),
6: ((up, 2), (down, 10), (right, 7), (left, 5)),
7: ((up, 3), (down, 11), (left, 6)),
8: ((up, 4), (down, 12), (right, 9)),
9: ((up, 5), (down, 13), (right, 10), (left, 8)),
10: ((up, 6), (down, 14), (right, 11), (left, 9)),
11: ((up, 7), (down, 15), (left, 10)),
12: ((up, 8), (right, 13)),
13: ((up, 9), (right, 14), (left, 12)),
14: ((up, 10), (right, 15), (left, 13))
}
for i, moves in grid_transitions.items():
for a, j in moves:
P[a, i - 1, j - 1] = 1.0

    # Border transitions: actions that would move off the grid leave the
    # state unchanged.
for i in (1, 2, 3):
P[up, i - 1, i - 1] = 1.0
for i in (12, 13, 14):
P[down, i - 1, i - 1] = 1.0
for i in (3, 7, 11):
P[right, i - 1, i - 1] = 1.0
for i in (4, 8, 12):
P[left, i - 1, i - 1] = 1.0

# The terminal state should be an absorbing state.
P[:, terminal_state, terminal_state] = 1.0

    # Rewards: -1 on every transition, except from the absorbing terminal
    # state, which yields 0.
R = -1 * _np.ones((A, S, S))
R[:, terminal_state, :] = 0

return P, R
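For context, a minimal usage sketch of the new example, mirroring the tests added below (the gridworld is an undiscounted episodic task, so the solvers are run with a discount of 1.0):

import mdptoolbox.example
import mdptoolbox.mdp

# Build the 4x4 gridworld of Sutton & Barto, Example 4.1.
P, R = mdptoolbox.example.gridworld()

# Solve the undiscounted problem; the policy holds one action index
# (0=up, 1=down, 2=right, 3=left) per state, with the terminal state last.
vi = mdptoolbox.mdp.ValueIteration(P, R, 1.0)
vi.run()
print(vi.policy)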
6 changes: 6 additions & 0 deletions src/tests/test_PolicyIteration.py
@@ -12,6 +12,7 @@

from .utils import SMALLNUM, P_forest, R_forest, P_small, R_small, P_sparse, \
P_forest_sparse, R_forest_sparse, \
P_gridworld, R_gridworld, policy_gridworld, \
assert_sequence_almost_equal

def test_PolicyIteration_init_policy0():
@@ -123,6 +124,11 @@ def test_PolicyIterative_forest_sparse():
assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
assert sdp.iter == itr

def test_PolicyIteration_gridworld():
pi = mdptoolbox.mdp.PolicyIteration(P_gridworld, R_gridworld, 1.0)
pi.run()
assert pi.policy == policy_gridworld

def test_goggle_code_issue_5():
P = [sp.csr_matrix([[0.5, 0.5], [0.8, 0.2]]),
sp.csr_matrix([[0.0, 1.0], [0.1, 0.9]])]
8 changes: 7 additions & 1 deletion src/tests/test_PolicyIterationModified.py
@@ -11,7 +11,8 @@

from mdptoolbox import mdp

from .utils import BaseTestMDP, assert_sequence_almost_equal
from .utils import BaseTestMDP, P_gridworld, R_gridworld, policy_gridworld, \
assert_sequence_almost_equal

class TestPolicyIterationModified(BaseTestMDP):
def test_small(self):
@@ -25,3 +26,8 @@ def test_small_undiscounted(self):
pim = mdp.PolicyIterationModified(self.small_P, self.small_R, 1)
pim.run()
assert_equal(pim.policy, (1, 0))

def test_gridworld(self):
pim = mdp.PolicyIterationModified(P_gridworld, R_gridworld, 1.0)
pim.run()
assert pim.policy == policy_gridworld
8 changes: 7 additions & 1 deletion src/tests/test_QLearning.py
@@ -10,7 +10,8 @@
import mdptoolbox

from .utils import SMALLNUM, P_forest, R_forest, P_forest_sparse
from .utils import R_forest_sparse, P_small, R_small, P_sparse
from .utils import R_forest_sparse, P_small, R_small, P_sparse, \
P_gridworld, R_gridworld, policy_gridworld

def test_QLearning_small():
np.random.seed(0)
@@ -55,3 +56,8 @@ def test_QLearning_forest_sparse():
sdp.run()
p = (0, 1, 1, 1, 1, 1, 0, 0, 0, 0)
assert sdp.policy == p

def test_QLearning_gridworld():
qlearning = mdptoolbox.mdp.QLearning(P_gridworld, R_gridworld, 1.0)
qlearning.run()
assert qlearning.policy == policy_gridworld
6 changes: 6 additions & 0 deletions src/tests/test_ValueIteration.py
@@ -12,6 +12,7 @@
from .utils import SMALLNUM, P_forest, R_forest, P_forest_sparse
from .utils import R_forest_sparse, P_rand, R_rand, P_rand_sparse, R_rand_sparse
from .utils import P_small, R_small, P_sparse
from .utils import P_gridworld, R_gridworld, policy_gridworld

def test_ValueIteration_small():
sdp = mdptoolbox.mdp.ValueIteration(P_small, R_small, 0.9)
@@ -58,3 +59,8 @@ def test_ValueIteration_rand_sparse():
sdp = mdptoolbox.mdp.ValueIteration(P_rand_sparse, R_rand_sparse, 0.9)
sdp.run()
assert sdp.policy

def test_ValueIteration_gridworld():
vi = mdptoolbox.mdp.ValueIteration(P_gridworld, R_gridworld, 1.0)
vi.run()
assert vi.policy == policy_gridworld
6 changes: 6 additions & 0 deletions src/tests/test_ValueIterationGS.py
@@ -11,6 +11,7 @@

from .utils import SMALLNUM, P_forest, R_forest, P_small, R_small, P_sparse
from .utils import P_forest_sparse, R_forest_sparse
from .utils import P_gridworld, R_gridworld, policy_gridworld

def test_ValueIterationGS_small():
sdp = mdptoolbox.mdp.ValueIterationGS(P_small, R_small, 0.9)
@@ -51,3 +52,8 @@ def test_ValueIterationGS_forest_sparse():
itr = 16 # from Octave MDPtoolbox
assert sdp.policy == p
assert sdp.iter == itr

def test_ValueIterationGS_gridworld():
vigs = mdptoolbox.mdp.ValueIterationGS(P_gridworld, R_gridworld, 1.0)
vigs.run()
assert vigs.policy == policy_gridworld
10 changes: 10 additions & 0 deletions src/tests/utils.py
@@ -39,3 +39,13 @@ def assert_sequence_almost_equal(a, b, spacing=10e-12):
np.random.seed(0)
P_rand_sparse, R_rand_sparse = mdptoolbox.example.rand(STATES, ACTIONS,
is_sparse=True)


P_gridworld, R_gridworld = mdptoolbox.example.gridworld()
up, down, right, left = range(4)
# Expected optimal policy for the gridworld, written one grid row per line;
# the final entry is the action assigned to the absorbing terminal state.
policy_gridworld = (left, left, down,
                    up, up, up, down,
                    up, up, down, down,
                    up, right, right,
                    up)
del up, down, right, left
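As a sanity check, the expected policy can be laid out on the 4x4 grid from the figure. The small helper below is illustrative only; its name and arrow mapping are not part of this change set:

# Illustrative helper (not part of the test suite): print a 15-state
# gridworld policy as arrows on the 4x4 grid, terminal corners marked 'T'.
ARROWS = {0: "^", 1: "v", 2: ">", 3: "<"}  # up, down, right, left

def print_gridworld_policy(policy):
    # Grid positions 0 and 15 are the terminal corners; states 1-14 sit in
    # between, so policy[:14] supplies one arrow per non-terminal cell.
    cells = ["T"] + [ARROWS[a] for a in policy[:14]] + ["T"]
    for row in range(4):
        print(" ".join(cells[4 * row:4 * row + 4]))

# print_gridworld_policy(policy_gridworld) prints, row by row:
# T < < v
# ^ ^ ^ v
# ^ ^ v v
# ^ > > T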