This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 034a112

Author: Jules Pondard

Add vanilla MCTS/UCT algorithm for options

Generate options using the UCT algorithm. Useful for benchmarking.

1 parent 842ac12 · commit 034a112
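For reference, the child-selection rule implemented in getBestChild below is the standard UCT score: a child's mean reward plus an exploration bonus that shrinks as the child is visited more. A minimal, self-contained sketch of that formula (illustration only, not part of the diff):

    import numpy as np

    def uct_score(child_value, child_visits, parent_visits, C=1.0):
        # Exploitation (mean reward) plus an exploration bonus; C trades the two off.
        return child_value / child_visits + C * np.sqrt(2.0 * np.log(parent_visits) / child_visits)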

File tree: 1 file changed, +178 −0 lines changed

Diff for: python/experimental/options_search/mcts.py

@@ -0,0 +1,178 @@
import tensor_comprehensions as tc
import torch
import utils
import numpy as np
#from tqdm import tqdm
from visdom import Visdom

# Live plotting; assumes a Visdom server is running.
viz = Visdom()

class Node:
    def __init__(self, father=None, new_act=0):
        self.value = 0
        self.values = []
        self.nbVisits = 0
        self.nbChildrenSeen = 0
        self.pos = 0
        #self.hasSeen = {} #todo
        self.children = []
        self.parent = father
        self.stateVector = [0] * utils.NB_HYPERPARAMS
        if father is not None:
            self.pos = father.pos + 1
            self.stateVector = father.stateVector[:]
            self.stateVector[self.pos - 1] = new_act

    def getRoot(self):
        # Walk up to the root of the tree.
        node = self
        while node.notRoot():
            node = node.getParent()
        return node

    def getParent(self):
        return self.parent

    def notRoot(self):
        return self.parent is not None

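# The search below alternates the four classic MCTS phases:
#   selection  - getLeaf() descends the tree via the UCT rule,
#   expansion  - take_action() adds one untried child,
#   evaluation - evaluate() averages a few random rollouts from the leaf,
#   backup     - backup() propagates the rollout score back to the root.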
class MCTS:
    def __init__(self):
        self.C = 1  # exploration constant, to tune

        # Benchmark kernel definition; the returned values are not used directly below.
        (tc_code, tc_name, inp, _) = utils.get_convolution_example(size_type="input", inp_sz_list=[8,2,28,28,8,1,1])

        self.nbActions = utils.cat_sz  # number of choices per hyperparameter
        self.tree = Node()

        self.best_rewards = []
        self.rws = []

        self.curIter = 0
        self.curr_best = 0
        self.running_reward = 0
        self.win0 = viz.line(X=np.arange(5), Y=np.random.rand(5))

    def main_search(self, starting_pos):
        node = starting_pos
        ttNbIters = 10  #2*self.nbActions[node.pos]
        for _ in range(max(ttNbIters, self.nbActions[node.pos])):
            leaf = self.getLeaf(node)
            val = self.evaluate(leaf)
            self.backup(leaf, val)
        # Commit greedily (no exploration bonus) once the search budget is spent.
        _, action = self.getBestChild2(node)
        return action

    def take_action(self, node, act):
        # Return the existing child for this action, or expand a new one.
        if node.nbChildrenSeen > act:
            return node.children[act]
        new_child = Node(father=node, new_act=act)
        node.children.append(new_child)
        node.nbChildrenSeen += 1
        return new_child

    def getLeaf(self, node):
        # Descend until we reach an unexpanded or never-visited node.
        first = True
        while node.pos < utils.NB_HYPERPARAMS and (first or node.nbVisits != 0):
            first = False
            pos = node.pos
            if node.nbChildrenSeen == self.nbActions[pos]:
                # Fully expanded: follow the UCT-best child.
                node, _ = self.getBestChild(node)
            else:
                # Expand the next untried action and roll out from there.
                return self.take_action(node, node.nbChildrenSeen)
        return node

    def getBestChild2(self, node):
        # Greedy variant: highest mean value, no exploration term.
        bestIndic = 0.
        bestAction = 0
        first = True
        for act in range(self.nbActions[node.pos]):
            child = node.children[act]
            #indic = np.percentile(child.values, 20)
            indic = child.value / child.nbVisits
            if first or indic > bestIndic:
                bestIndic = indic
                bestAction = act
                first = False
        return node.children[bestAction], bestAction

    def getBestChild(self, node):
        # UCT: mean value plus an exploration bonus scaled by self.C.
        bestIndic = 0.
        bestAction = 0
        first = True
        for act in range(self.nbActions[node.pos]):
            child = node.children[act]
            #indic = np.percentile(child.values, 20) + self.C * np.sqrt(2*np.log(node.nbVisits) / child.nbVisits)
            indic = child.value / child.nbVisits + self.C * np.sqrt(2*np.log(node.nbVisits) / child.nbVisits)
            if first or indic > bestIndic:
                bestIndic = indic
                bestAction = act
                first = False
        return node.children[bestAction], bestAction

    def saveReward(self, reward, opts):
        INTER_DISP = 20  # plot refresh period, in iterations
        if self.curIter == 0:
            self.running_reward = reward
            self.curr_best = reward
        if self.curIter == 0 or reward > self.curr_best:
            # reward is -log(time), so -reward is the log runtime.
            print(-reward)
            print(opts)
        self.curIter += 1
        self.running_reward = self.running_reward * 0.99 + reward * 0.01
        self.curr_best = max(self.curr_best, reward)
        self.best_rewards.append(-self.curr_best)
        self.rws.append(-self.running_reward)
        if self.curIter % INTER_DISP == 0:
            viz.line(X=np.column_stack((np.arange(self.curIter), np.arange(self.curIter))),
                     Y=np.column_stack((np.array(self.rws), np.array(self.best_rewards))),
                     win=self.win0, opts=dict(legend=["Geometric run", "Best time"]))

    def randomSampleScoreFrom(self, node):
        # Complete the partial configuration at random and benchmark it.
        pos = node.pos
        optsVector = node.stateVector[:]  # copy, so the rollout does not mutate the node
        for i in range(utils.NB_HYPERPARAMS - pos):
            optsVector[i + pos] = np.random.randint(self.nbActions[i + pos])
        # Maximizing -log(time) minimizes the kernel runtime.
        reward = -np.log(utils.evalTime(optsVector))
        self.saveReward(reward, optsVector)
        return reward

    def evaluate(self, leaf):
        # Score a leaf by averaging a few random rollouts.
        score = 0
        nb_iters = 5
        for _ in range(nb_iters):
            score += self.randomSampleScoreFrom(leaf)
        return score / nb_iters

    def backup(self, leaf, val):
        # Propagate the rollout value from the leaf up to the root.
        node = leaf
        while node.notRoot():
            node.nbVisits += 1
            node.value += val
            node = node.getParent()
        node.nbVisits += 1
        node.value += val
        node.values.append(val)

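# Driver: fix one hyperparameter per outer iteration, running a fresh
# UCT search from the current node before committing to each action.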
mcts = MCTS()

opts = []
curr_node = mcts.tree
for i in range(utils.NB_HYPERPARAMS):
    opts.append(mcts.main_search(curr_node))
    curr_node = mcts.take_action(curr_node, opts[-1])
print(opts)
opts = np.array(opts).astype(int)
print(utils.evalTime(opts.tolist()))
utils.print_opt(opts)
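The script leans on a sibling utils module that is not part of this diff. From the call sites above it must expose NB_HYPERPARAMS, cat_sz, get_convolution_example, evalTime, and print_opt. A hypothetical stub sketching that interface, with illustrative values only (the real module lives next to mcts.py):

    # Hypothetical stand-in for utils; interface inferred from mcts.py.
    import numpy as np

    NB_HYPERPARAMS = 26                   # illustrative option count
    cat_sz = np.full(NB_HYPERPARAMS, 4)   # illustrative choices per option

    def get_convolution_example(size_type="input", inp_sz_list=None):
        # Returns (tc_code, tc_name, inputs, _) for the benchmark kernel.
        raise NotImplementedError

    def evalTime(optsVector):
        # Compiles and times the kernel under the given options.
        raise NotImplementedError

    def print_opt(opts):
        # Pretty-prints an options vector.
        raise NotImplementedError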
