-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpacmanMdp.py
219 lines (171 loc) · 6.94 KB
/
pacmanMdp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# pacmanMdp.py
# IA UC3M 2016
# -----------------------
##
from game import GameStateData
from game import Game
from game import Actions
from util import nearestPoint
import util, layout
import sys, types, time, random, os
import mdp
from featureExtractors import *
class PacmanMdp(mdp.MarkovDecisionProcess):
    """
    Pacman MDP estimated from observed transitions.

    Low level game states are mapped by a feature extractor to a simplified
    ("high level") representation: a tuple of feature values. The transition
    function and the rewards are stored over that high level representation.
    """

    def __init__(self, extractor='StateExtractor'):
        """
        Creates an empty MDP.

        extractor: name of the feature extractor class. It is resolved with
        util.lookup against this module's globals and instantiated without
        arguments.
        """
        # Feature extractor
        self.featExtractor = util.lookup(extractor, globals())()

        # Transition function (data structure required for the transition function)
        #*** YOUR CODE STARTS HERE ***
        # Code to remove ---------- from here
        # Maps (stateH, action) -> util.Counter {nextStateH: times observed}
        self.frequencies = dict()
        # Code to remove ---------- to here
        #*** YOUR CODE FINISHES HERE ***

        # Dictionary with examples of a Low state for each High state: it serves to get possible actions
        # and to check terminal states (though it is not required if the high level representation
        # capture them)
        self.states = util.Counter()

        # Reward for each state at the high level representation, stored as
        # [number of observations, [score of each observed low level state]]
        self.reward = util.Counter()

        # Set on the first call to updateTransitionFunction; None until then
        self.startState = None

    def stateToHigh(self, stateL):
        """
        Returns the high level representation of a low level state: a tuple
        with the values of the features computed by the feature extractor.
        """
        # NOTE(review): the tuple follows the iteration order of the mapping
        # returned by getFeatures -- confirm that order is stable across calls.
        return tuple(self.featExtractor.getFeatures(stateL).values())

    def addStateLow(self, stateH, stateL):
        """
        Records an observation of the low level state stateL for the high
        level state stateH.

        The first low level state seen for stateH is kept as its example in
        self.states; every observation adds its score to self.reward[stateH].
        """
        score = stateL.getScore()
        if stateH not in self.states:
            self.states[stateH] = stateL
            self.reward[stateH] = [1, [score]]
        else:
            self.reward[stateH][0] += 1
            self.reward[stateH][1].append(score)

    def updateTransitionFunction(self, stateL, action, nextStateL):
        """
        Updates the transition function with a new case stateL, action, nextStateL
        The states received as parameters have a low level representation. The transition function
        is stored over the high level (simplified) representation
        """
        # Change the representation to the simplified one
        state = self.stateToHigh(stateL)
        nextState = self.stateToHigh(nextStateL)

        # Set the start state in the first call
        if not self.states:
            self.setStartState(state)

        # Add the received states to self.states
        self.addStateLow(state, stateL)
        self.addStateLow(nextState, nextStateL)

        #*** YOUR CODE STARTS HERE ***
        # Code to remove ---------- from here
        stateAction = (state, action)
        counts = self.frequencies.get(stateAction)
        if counts is None:
            counts = self.frequencies[stateAction] = util.Counter()
        # Relies on util.Counter giving 0 for a key that was never seen,
        # exactly as the "+= 1" on an existing counter always did here.
        counts[nextState] += 1
        # Code to remove ---------- to here
        #*** YOUR CODE FINISHES HERE ***

    def getPossibleActions(self, state):
        """
        Returns the list of valid actions for the high level 'state'.

        The actions are the legal actions of agent 0 in the low level example
        stored for 'state'. An unknown state has no actions (empty list).
        """
        if state not in self.states:
            return []
        return self.states[state].getLegalActions(0)

    def getStates(self):
        """
        Return list of all known high level states.
        """
        return list(self.states.keys())

    def isKnownState(self, state):
        """
        True if the state is in the dict of states.
        """
        return state in self.states

    def getAverageReward(self, state):
        """
        Return average score of the known low level states represented by a high level state
        """
        count, scores = self.reward[state]
        # float() keeps Python 2 from truncating the average of integer scores
        return float(sum(scores)) / count

    def getReward(self, state, action, nextState):
        """
        Get reward for state, action, nextState transition: the difference
        between the average scores of nextState and state. 'action' is not
        used.
        """
        return self.getAverageReward(nextState) - self.getAverageReward(state)

    def setStartState(self, state):
        """
        set for start state
        """
        self.startState = state

    def getStartState(self):
        """
        get for start state (None if no transition has been recorded yet)
        """
        return self.startState

    def isTerminal(self, state):
        """
        Pacman terminal states.

        A known state is terminal when its low level example is a win or a
        lose; an unknown state is delegated to the feature extractor.
        """
        if state not in self.states:
            return self.featExtractor.isTerminalFeatures(state)
        example = self.states[state]
        return example.isLose() or example.isWin()

    def printMdp(self):
        """
        Shows the transition function of the MDP
        """
        for state in self.states.keys():
            for action in self.getPossibleActions(state):
                transitions = self.getTransitionStatesAndProbabilities(state, action)
                # Single-argument form prints the same text on Python 2 and 3
                print("%s %s %s" % (state, action, transitions))

    def getTransitionStatesAndProbabilities(self, state, action):
        """
        Returns list of (nextState, prob) pairs
        representing the states reachable
        from 'state' by taking 'action' along
        with their transition probabilities.

        The probabilities are the relative frequencies of the observed
        transitions. Terminal states and (state, action) pairs that were
        never observed give an empty list.

        Raises ValueError if 'action' is not a possible action in 'state'.
        """
        if action not in self.getPossibleActions(state):
            raise ValueError("Illegal action!")

        if self.isTerminal(state):
            return []

        #*** YOUR CODE STARTS HERE ***
        # Code to remove --- from here
        counts = self.frequencies.get((state, action))
        if not counts:
            return []
        total = float(sum(counts.values()))
        successors = [(nextState, count / total)
                      for nextState, count in counts.items()]
        # Code to remove --- to here
        #*** YOUR CODE FINISHES HERE ***

        return successors