Skip to content

Commit

Permalink
Reset Bug Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Nitish Bhupathi Raju authored and Nitish Bhupathi Raju committed Aug 6, 2024
1 parent 4cf54dc commit 8492cb8
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 302 deletions.
58 changes: 36 additions & 22 deletions verse/analysis/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,10 @@ def apply_reset_stars(
ego_type = find(agent.decision_logic.args, lambda a: a.name == EGO).typ


new_state = agent_state[1].starcopy() #copy.deepcopy([agent_state[0][1:], agent_state[1][1:]])
#Modified
old_state = agent_state[1].starcopy() #copy.deepcopy([agent_state[0][1:], agent_state[1][1:]])
reset_vars = {}
expr_list = {}

# The reset_list here are all the resets for a single transition. Need to evaluate each of them
# and then combine them together
Expand Down Expand Up @@ -1388,6 +1391,7 @@ def apply_reset_stars(
# Assume linear function for continuous variables
else:
#agent_state.continuous_reset(reset_variable, expr, agent, ego_type,cont_var_dict, rect)
#breakpoint()
lhs = reset_variable
rhs = expr
found = False
Expand All @@ -1396,6 +1400,8 @@ def apply_reset_stars(
):
if cts_variable == lhs:
found = True
expr_list[lhs_idx] = rhs
reset_vars[lhs_idx] = lhs
break
if not found:
raise ValueError(f"Reset continuous variable {cts_variable} not found")
Expand All @@ -1405,39 +1411,47 @@ def apply_reset_stars(
#TODO: check that this only gets run on ego?
if 'ego' in var:
statevec.append(var)




#print(statevec)

#concern: how to handle the case where you need other agents state. for now: assume you do not
def reset_func(state): #[ego.x, ego.y, ...]
#breakpoint()
output = np.copy(state)
val_dict = {}
tmp_exp = copy.deepcopy(expr)
for i in range(0, len(state)):
if statevec[i] in tmp_exp:
tmp_exp = tmp_exp.replace(statevec[i], str(state[i]))
#print(tmp_exp)
result = eval(tmp_exp, {}, val_dict)
for i in range(0, len(state)):
if lhs in statevec[i]:
output[i] = result
return output
#print("TODO: find where/when this gets set elsewhere")
#breakpoint()
#print("foo")
#breakpoint()
new_state = new_state.apply_reset(reset_func)
#print("bar")
#breakpoint()
def reset_func(state, expr_list, reset_vars): #[ego.x, ego.y, ...]
#breakpoint()
output = np.copy(state)
idxs = list(reset_vars.keys())
for idx in idxs:
val_dict = {}
tmp_exp = copy.deepcopy(expr_list[idx])
for i in range(0, len(state)):
if statevec[i] in tmp_exp:
tmp_exp = tmp_exp.replace(statevec[i], str(state[i]))
#print(tmp_exp)
result = eval(tmp_exp, {}, val_dict)
for i in range(0, len(state)):
if reset_vars[idx] == statevec[i].split('.',1)[1]:
output[i] = result
return output
#print("TODO: find where/when this gets set elsewhere")
#breakpoint()
#print("foo")
#breakpoint()
new_state = old_state.apply_reset(reset_func, expr_list, reset_vars)
#print("bar")
#breakpoint()


all_dest = itertools.product(*possible_dest)
dest = []
for tmp in all_dest:
dest.append(tmp)
#breakpoint()

# print("apply_reset")
# print(dest)
# print(new_state.overapprox_rectangle())
return dest, new_state


Expand Down
69 changes: 51 additions & 18 deletions verse/stars/reach_at_star.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,61 @@
import cvxpy as cp
from z3 import *
from verse.analysis import AnalysisTreeNode, AnalysisTree, AnalysisTreeNodeType
from verse.utils.star_manhattan import *

def time_step_diameter(trace, time_horizon, time_step):
time_steps = np.linspace(0, time_horizon, int(time_horizon/time_step) + 1, endpoint= True)
time_steps = np.append(time_steps, [time_steps[-1] + time_step])
diameters = []
#breakpoint()
for i in range(len(time_steps) - 1):
curr_diam = []
reach_tubes = reach_star(trace, time_steps[i] - time_step*.05, time_steps[i+1] - time_step*.05)
agents = list(reach_tubes.keys())
#breakpoint()
for agent in agents:
if reach_tubes[agent] is not None:
if len(reach_tubes[agent]) > 0:
for j in range(0, len(reach_tubes[agent])):
star = reach_tubes[agent][j]
curr_diam.append(star_manhattan_distance(star.center, star.basis, star.C, star.g))
# breakpoint()
# print(time_steps[i])
# print(time_steps[i+1])
# print(reach_tubes)
# print(star.overapprox_rectangle())
if len(curr_diam) > 0:
#print(curr_diam)
diameters.append(max(curr_diam))
#breakpoint()

#print(diameters)
return diameters

### assuming mode is not a parameter
def reach_star(trace: AnalysisTree, t_l: float = 0, t_u: float = None) -> Dict[str, List[StarSet]]:
def reach_star(traces, t_l: float = 0, t_u: float = None) -> Dict[str, List[StarSet]]:
reach = {}
nodes: List[AnalysisTreeNode] = trace.nodes
agents = nodes[0].agent.keys() # list of agents

if t_u is None:

nodes = traces.nodes
agents = list(nodes[0].trace.keys())

# if t_u is None:
# for agent in trace:
# t_u = trace[agent][-1][0] # T
# break

for node in nodes:
trace = node.trace
for agent in agents:
t_u = trace[agent][-1][0] # T
break

for agent in trace:
for i in range(len(trace[agent])):
cur = trace[agent][i]
if cur[0]<t_l:
continue
if cur[0]>t_u:
break
if agent not in reach:
reach[agent] = []
reach[agent].append(cur[1]) # just store the star set
for i in range(len(trace[agent])):
cur = trace[agent][i]
if cur[0]<t_l:
continue
if cur[0]>t_u:
break
if agent not in reach:
reach[agent] = []
reach[agent].append(cur[1]) # just store the star set

return reach

Expand Down
Loading

0 comments on commit 8492cb8

Please sign in to comment.