diff --git a/src/mdptoolbox/example.py b/src/mdptoolbox/example.py
index bccf795..e469517 100644
--- a/src/mdptoolbox/example.py
+++ b/src/mdptoolbox/example.py
@@ -402,3 +402,81 @@ def small():
     P = _np.array([[[0.5, 0.5], [0.8, 0.2]], [[0, 1], [0.1, 0.9]]])
     R = _np.array([[5, 10], [-1, 2]])
     return(P, R)
+
+
+def gridworld():
+    """4x4 gridworld example.
+
+    Example 4.1 of `Reinforcement Learning: An Introduction
+    `_,
+    by Richard S. Sutton and Andrew G. Barto.
+
+    Returns
+    -------
+    out : tuple
+        ``out[0]`` contains the transition probability matrix P,
+        and ``out[1]`` contains the reward matrix R. The non-terminal
+        states correspond to the indices 0-13 in both matrices,
+        and the terminal state to the index 14.
+
+    Examples
+    --------
+    >>> import mdptoolbox.example
+    >>> P, R = mdptoolbox.example.gridworld()
+    >>> P.shape
+    (4, 15, 15)
+    >>> R.shape
+    (4, 15, 15)
+    """
+    # States: labelled 1, 2, ..., 14 in the figure, plus the terminal
+    # state associated with the two terminal positions.
+    S = 15  # number of states
+    terminal_state = 14  # terminal state index
+
+    # Actions: up, down, right, left.
+    A = 4  # number of actions
+    up, down, right, left = range(A)  # indices of the actions
+
+    # Transitions.
+    P = _np.zeros((A, S, S))
+
+    # Grid transitions.
+    grid_transitions = {
+        # from_state: ((action, to_state), ...)
+        1: ((down, 5), (right, 2), (left, 15)),
+        2: ((down, 6), (right, 3), (left, 1)),
+        3: ((down, 7), (left, 2)),
+        4: ((up, 15), (down, 8), (right, 5)),
+        5: ((up, 1), (down, 9), (right, 6), (left, 4)),
+        6: ((up, 2), (down, 10), (right, 7), (left, 5)),
+        7: ((up, 3), (down, 11), (left, 6)),
+        8: ((up, 4), (down, 12), (right, 9)),
+        9: ((up, 5), (down, 13), (right, 10), (left, 8)),
+        10: ((up, 6), (down, 14), (right, 11), (left, 9)),
+        11: ((up, 7), (down, 15), (left, 10)),
+        12: ((up, 8), (right, 13)),
+        13: ((up, 9), (right, 14), (left, 12)),
+        14: ((up, 10), (right, 15), (left, 13))
+    }
+    for i, moves in grid_transitions.items():
+        for a, j in moves:
+            P[a, i - 1, j - 1] = 1.0
+
+    # Border transitions.
+    for i in (1, 2, 3):
+        P[up, i - 1, i - 1] = 1.0
+    for i in (12, 13, 14):
+        P[down, i - 1, i - 1] = 1.0
+    for i in (3, 7, 11):
+        P[right, i - 1, i - 1] = 1.0
+    for i in (4, 8, 12):
+        P[left, i - 1, i - 1] = 1.0
+
+    # The terminal state should be an absorbing state.
+    P[:, terminal_state, terminal_state] = 1.0
+
+    # Rewards.
+    R = -1 * _np.ones((A, S, S))
+    R[:, terminal_state, :] = 0
+
+    return P, R
diff --git a/src/tests/test_PolicyIteration.py b/src/tests/test_PolicyIteration.py
index 43a2e72..3473770 100644
--- a/src/tests/test_PolicyIteration.py
+++ b/src/tests/test_PolicyIteration.py
@@ -12,6 +12,7 @@
 
 from .utils import SMALLNUM, P_forest, R_forest, P_small, R_small, P_sparse, \
     P_forest_sparse, R_forest_sparse, \
+    P_gridworld, R_gridworld, policy_gridworld, \
     assert_sequence_almost_equal
 
 def test_PolicyIteration_init_policy0():
@@ -123,6 +124,11 @@ def test_PolicyIterative_forest_sparse():
     assert (np.absolute(np.array(sdp.V) - v) < SMALLNUM).all()
     assert sdp.iter == itr
 
+def test_PolicyIteration_gridworld():
+    pi = mdptoolbox.mdp.PolicyIteration(P_gridworld, R_gridworld, 1.0)
+    pi.run()
+    assert pi.policy == policy_gridworld
+
 def test_goggle_code_issue_5():
     P = [sp.csr_matrix([[0.5, 0.5], [0.8, 0.2]]),
          sp.csr_matrix([[0.0, 1.0], [0.1, 0.9]])]
diff --git a/src/tests/test_PolicyIterationModified.py b/src/tests/test_PolicyIterationModified.py
index 3b22b03..fdcd582 100644
--- a/src/tests/test_PolicyIterationModified.py
+++ b/src/tests/test_PolicyIterationModified.py
@@ -11,7 +11,8 @@
 
 from mdptoolbox import mdp
 
-from .utils import BaseTestMDP, assert_sequence_almost_equal
+from .utils import BaseTestMDP, P_gridworld, R_gridworld, policy_gridworld, \
+    assert_sequence_almost_equal
 
 class TestPolicyIterationModified(BaseTestMDP):
     def test_small(self):
@@ -25,3 +26,8 @@ def test_small_undiscounted(self):
         pim = mdp.PolicyIterationModified(self.small_P, self.small_R, 1)
         pim.run()
         assert_equal(pim.policy, (1, 0))
+
+    def test_gridworld(self):
+        pim = mdp.PolicyIterationModified(P_gridworld, R_gridworld, 1.0)
+        pim.run()
+        assert pim.policy == policy_gridworld
diff --git a/src/tests/test_QLearning.py b/src/tests/test_QLearning.py
index cb42033..a07dace 100644
--- a/src/tests/test_QLearning.py
+++ b/src/tests/test_QLearning.py
@@ -10,7 +10,8 @@
 import mdptoolbox
 
 from .utils import SMALLNUM, P_forest, R_forest, P_forest_sparse
-from .utils import R_forest_sparse, P_small, R_small, P_sparse
+from .utils import R_forest_sparse, P_small, R_small, P_sparse, \
+    P_gridworld, R_gridworld, policy_gridworld
 
 def test_QLearning_small():
     np.random.seed(0)
@@ -55,3 +56,8 @@ def test_QLearning_forest_sparse():
     sdp.run()
     p = (0, 1, 1, 1, 1, 1, 0, 0, 0, 0)
     assert sdp.policy == p
+
+def test_QLearning_gridworld():
+    qlearning = mdptoolbox.mdp.QLearning(P_gridworld, R_gridworld, 1.0)
+    qlearning.run()
+    assert qlearning.policy == policy_gridworld
diff --git a/src/tests/test_ValueIteration.py b/src/tests/test_ValueIteration.py
index 5e634d8..ae605bb 100644
--- a/src/tests/test_ValueIteration.py
+++ b/src/tests/test_ValueIteration.py
@@ -12,6 +12,7 @@
 from .utils import SMALLNUM, P_forest, R_forest, P_forest_sparse
 from .utils import R_forest_sparse, P_rand, R_rand, P_rand_sparse, R_rand_sparse
 from .utils import P_small, R_small, P_sparse
+from .utils import P_gridworld, R_gridworld, policy_gridworld
 
 def test_ValueIteration_small():
     sdp = mdptoolbox.mdp.ValueIteration(P_small, R_small, 0.9)
@@ -58,3 +59,8 @@ def test_ValueIteration_rand_sparse():
     sdp = mdptoolbox.mdp.ValueIteration(P_rand_sparse, R_rand_sparse, 0.9)
     sdp.run()
     assert sdp.policy
+
+def test_ValueIteration_gridworld():
+    vi = mdptoolbox.mdp.ValueIteration(P_gridworld, R_gridworld, 1.0)
+    vi.run()
+    assert vi.policy == policy_gridworld
diff --git a/src/tests/test_ValueIterationGS.py b/src/tests/test_ValueIterationGS.py
index 6f37d24..2fc1bfd 100644
--- a/src/tests/test_ValueIterationGS.py
+++ b/src/tests/test_ValueIterationGS.py
@@ -11,6 +11,7 @@
 
 from .utils import SMALLNUM, P_forest, R_forest, P_small, R_small, P_sparse
 from .utils import P_forest_sparse, R_forest_sparse
+from .utils import P_gridworld, R_gridworld, policy_gridworld
 
 def test_ValueIterationGS_small():
     sdp = mdptoolbox.mdp.ValueIterationGS(P_small, R_small, 0.9)
@@ -51,3 +52,8 @@ def test_ValueIterationGS_forest_sparse():
     itr = 16  # from Octave MDPtoolbox
     assert sdp.policy == p
     assert sdp.iter == itr
+
+def test_ValueIterationGS_gridworld():
+    vigs = mdptoolbox.mdp.ValueIterationGS(P_gridworld, R_gridworld, 1.0)
+    vigs.run()
+    assert vigs.policy == policy_gridworld
diff --git a/src/tests/utils.py b/src/tests/utils.py
index 298163e..629f6ed 100644
--- a/src/tests/utils.py
+++ b/src/tests/utils.py
@@ -39,3 +39,13 @@ def assert_sequence_almost_equal(a, b, spacing=10e-12):
 np.random.seed(0)
 P_rand_sparse, R_rand_sparse = mdptoolbox.example.rand(STATES, ACTIONS,
                                                        is_sparse=True)
+
+
+P_gridworld, R_gridworld = mdptoolbox.example.gridworld()
+up, down, right, left = range(4)
+policy_gridworld = (left, left, down,
+                    up, up, up, down,
+                    up, up, down, down,
+                    up, right, right,
+                    up)  # for the terminal state
+del up, down, right, left
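
Not part of the patch: a minimal usage sketch of the new example, assuming the patched package is importable. `mdptoolbox.example.gridworld` is the function added above and `mdptoolbox.mdp.ValueIteration` is the existing solver; the `arrows` mapping and the final print are purely illustrative.

    # Minimal sketch: solve the gridworld and show a greedy action per cell.
    import mdptoolbox.example
    import mdptoolbox.mdp

    P, R = mdptoolbox.example.gridworld()
    vi = mdptoolbox.mdp.ValueIteration(P, R, 1.0)
    vi.run()

    # Actions are indexed up, down, right, left = 0, 1, 2, 3, as in gridworld().
    arrows = {0: "^", 1: "v", 2: ">", 3: "<"}
    # Indices 0-13 are the non-terminal cells; index 14 is the terminal state.
    print([arrows[a] for a in vi.policy[:14]])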