diff --git a/demo/dryvr_demo/rendezvous_controller.py b/demo/dryvr_demo/rendezvous_controller.py index acc6fdc7..d1b9e0f0 100644 --- a/demo/dryvr_demo/rendezvous_controller.py +++ b/demo/dryvr_demo/rendezvous_controller.py @@ -46,7 +46,7 @@ def decisionLogic(ego: State): # assert (ego.craft_mode!=CraftMode.ProxB or\ # (ego.xp>=-105 and ego.yp>=0.57735*ego.xp and -ego.yp>=0.57735*ego.xp)), "Line-of-sight" - assert ego.craft_mode != CraftMode.Passive or ( - ego.xp <= -0.2 or ego.xp >= 0.2 or ego.yp <= -0.2 or ego.yp >= 0.2 - ), "Collision avoidance" + # assert ego.craft_mode != CraftMode.Passive or ( + # ego.xp <= -0.2 or ego.xp >= 0.2 or ego.yp <= -0.2 or ego.yp >= 0.2 + # ), "Collision avoidance" return output diff --git a/demo/dryvr_demo/rendezvous_demo.py b/demo/dryvr_demo/rendezvous_demo.py index c5ac1621..70474c04 100644 --- a/demo/dryvr_demo/rendezvous_demo.py +++ b/demo/dryvr_demo/rendezvous_demo.py @@ -1,5 +1,5 @@ from origin_agent import craft_agent -from verse import Scenario +from verse import Scenario, ScenarioConfig from verse.plotter.plotter2D import * from verse.sensor.example_sensor.craft_sensor import CraftSensor @@ -15,7 +15,7 @@ class CraftMode(Enum): if __name__ == "__main__": input_code_name = "./demo/dryvr_demo/rendezvous_controller.py" - scenario = Scenario() + scenario = Scenario(ScenarioConfig(parallel=False)) car = craft_agent("test", file_name=input_code_name) scenario.add_agent(car) diff --git a/demo/traffic_signal/mp0.py b/demo/traffic_signal/mp0.py new file mode 100644 index 00000000..4ea691a7 --- /dev/null +++ b/demo/traffic_signal/mp0.py @@ -0,0 +1,384 @@ +from typing import Tuple, List + +import numpy as np +from scipy.integrate import ode + +from verse import BaseAgent, Scenario +from verse.analysis.utils import wrap_to_pi +from verse.analysis.analysis_tree import TraceType, AnalysisTree +from verse.parser import ControllerIR +from vehicle_controller import VehicleMode +from verse.analysis import AnalysisTreeNode, AnalysisTree, AnalysisTreeNodeType + +import copy + +refine_profile = { + 'R1': [0], + 'R2': [3,3,3,0], + 'R3': [3,3,3,0] +} + +def tree_safe(tree: AnalysisTree): + for node in tree.nodes: + if node.assert_hits is not None: + return False + return True + +def verify_refine(scenario: Scenario, time_horizon, time_step): + refine_depth = 10 + init_car = scenario.init_dict['car'] + init_ped = scenario.init_dict['pedestrian'] + partition_depth = 0 + if init_ped[1][0] - init_ped[0][0]>0.1: + exp = 'R3' + elif init_car[1][3] - init_car[0][3] > 0.1: + exp = 'R2' + else: + exp = 'R1' + res_list = [] + init_queue = [] + if init_car[1][3]-init_car[0][3] > 0.05: + car_v_init_range = np.linspace(init_car[0][3], init_car[1][3], 33) + else: + car_v_init_range = [init_car[0][3], init_car[1][3]] + if init_car[1][0]-init_car[0][0] > 0.1: + car_x_init_range = np.linspace(init_car[0][0], init_car[1][0], 5) + else: + car_x_init_range = [init_car[0][0], init_car[1][0]] + for i in range(len(car_x_init_range)-1): + for j in range(len(car_v_init_range)-1): + tmp = copy.deepcopy(init_car) + tmp[0][0] = car_x_init_range[i] + tmp[1][0] = car_x_init_range[i+1] + tmp[0][3] = car_v_init_range[j] + tmp[1][3] = car_v_init_range[j+1] + init_queue.append((tmp, init_ped, partition_depth)) + # init_queue = [(init_car, init_ped, partition_depth)] + while init_queue!=[] and partition_depth < refine_depth: + car_init, ped_init, partition_depth = init_queue.pop(0) + print(f"######## {partition_depth}, car x, {car_init[0][0]}, {car_init[1][0]}, car v, {car_init[0][3]}, {car_init[1][3]}, ped x, {ped_init[0][0]}, {ped_init[1][0]}, ped y, {ped_init[0][1]}, {ped_init[1][1]}") + scenario.set_init_single('car', car_init, (VehicleMode.Normal,)) + scenario.set_init_single('pedestrian', ped_init, (PedestrianMode.Normal,)) + traces = scenario.verify(time_horizon, time_step) + if not tree_safe(traces): + # Partition car and pedestrian initial state + idx = refine_profile[exp][partition_depth%len(refine_profile[exp])] + if car_init[1][idx] - car_init[0][idx] < 0.01: + print(f"Stop refine car state {idx}") + init_queue.append((car_init, ped_init, partition_depth+1)) + elif partition_depth >= refine_depth: + print('Threshold Reached. Scenario is UNSAFE.') + res_list.append(traces) + break + car_v_init = (car_init[0][idx] + car_init[1][idx])/2 + car_init1 = copy.deepcopy(car_init) + car_init1[1][idx] = car_v_init + init_queue.append((car_init1, ped_init, partition_depth+1)) + car_init2 = copy.deepcopy(car_init) + car_init2[0][idx] = car_v_init + init_queue.append((car_init2, ped_init, partition_depth+1)) + else: + res_list.append(traces) + com_traces = combine_tree(res_list) + + return com_traces + +class TrafficSignalAgent(BaseAgent): + def __init__( + self, + id, + file_name + ): + super().__init__(id, code = None, file_name = file_name) + + def TC_simulate( + self, mode: List[str], init, time_bound, time_step, lane_map = None + ) -> TraceType: + time_bound = float(time_bound) + num_points = int(np.ceil(time_bound / time_step)) + trace = np.zeros((num_points + 1, 1 + len(init))) + trace[1:, 0] = [round((i+1) * time_step, 10) for i in range(num_points)] + trace[:,-1] = trace[:,0] + trace[:, 1:-1] = init[:-1] + return trace + +class VehicleAgent(BaseAgent): + def __init__( + self, + id, + code = None, + file_name = None, + accel_brake = 5, + accel_notbrake = 5, + accel_hardbrake = 20, + speed = 10 + ): + super().__init__( + id, code, file_name + ) + self.accel_brake = accel_brake + self.accel_notbrake = accel_notbrake + self.accel_hardbrake = accel_hardbrake + self.speed = speed + self.vmax = 20 + + @staticmethod + def dynamic(t, state, u): + x, y, theta, v = state + delta, a = u + x_dot = v * np.cos(theta + delta) + y_dot = v * np.sin(theta + delta) + theta_dot = v / 1.75 * np.sin(delta) + v_dot = a + return [x_dot, y_dot, theta_dot, v_dot] + + def action_handler(self, mode: List[str], state) -> Tuple[float, float]: + x, y, theta, v = state + vehicle_mode, = mode + vehicle_pos = np.array([x, y]) + a = 0 + lane_width = 3 + d = -y + if vehicle_mode == "Normal" or vehicle_mode == "Stop": + pass + elif vehicle_mode == "SwitchLeft": + d += lane_width + elif vehicle_mode == "SwitchRight": + d -= lane_width + elif vehicle_mode == "Brake": + a = max(-self.accel_brake, -v) + # a = -50 + elif vehicle_mode == "HardBrake": + a = max(-self.accel_hardbrake, -v) + # a = -50 + elif vehicle_mode == "Accel": + a = min(self.accel_notbrake, self.speed-v) + else: + raise ValueError(f"Invalid mode: {vehicle_mode}") + + heading = 0 + psi = wrap_to_pi(heading - theta) + steering = psi + np.arctan2(0.45 * d, v) + steering = np.clip(steering, -0.61, 0.61) + return steering, a + + def TC_simulate( + self, mode: List[str], init, time_bound, time_step, lane_map = None + ) -> TraceType: + time_bound = float(time_bound) + num_points = int(np.ceil(time_bound / time_step)) + trace = np.zeros((num_points + 1, 1 + len(init))) + trace[1:, 0] = [round(i * time_step, 10) for i in range(num_points)] + trace[0, 1:] = init + for i in range(num_points): + steering, a = self.action_handler(mode, init) + r = ode(self.dynamic) + r.set_initial_value(init).set_f_params([steering, a]) + res: np.ndarray = r.integrate(r.t + time_step) + init = res.flatten() + if init[3] < 0: + init[3] = 0 + trace[i + 1, 0] = time_step * (i + 1) + trace[i + 1, 1:] = init + return trace + +def dist(pnt1, pnt2): + return np.linalg.norm( + np.array(pnt1) - np.array(pnt2) + ) + +def get_extreme(rect1, rect2): + lb11 = rect1[0] + lb12 = rect1[1] + ub11 = rect1[2] + ub12 = rect1[3] + + lb21 = rect2[0] + lb22 = rect2[1] + ub21 = rect2[2] + ub22 = rect2[3] + + # Using rect 2 as reference + left = lb21 > ub11 + right = ub21 < lb11 + bottom = lb22 > ub12 + top = ub22 < lb12 + + if top and left: + dist_min = dist((ub11, lb12),(lb21, ub22)) + dist_max = dist((lb11, ub12),(ub21, lb22)) + elif bottom and left: + dist_min = dist((ub11, ub12),(lb21, lb22)) + dist_max = dist((lb11, lb12),(ub21, ub22)) + elif top and right: + dist_min = dist((lb11, lb12), (ub21, ub22)) + dist_max = dist((ub11, ub12), (lb21, lb22)) + elif bottom and right: + dist_min = dist((lb11, ub12),(ub21, lb22)) + dist_max = dist((ub11, lb12),(lb21, ub22)) + elif left: + dist_min = lb21 - ub11 + dist_max = np.sqrt((lb21 - ub11)**2 + max((ub22-lb12)**2, (ub12-lb22)**2)) + elif right: + dist_min = lb11 - ub21 + dist_max = np.sqrt((lb21 - ub11)**2 + max((ub22-lb12)**2, (ub12-lb22)**2)) + elif top: + dist_min = lb12 - ub22 + dist_max = np.sqrt((ub12 - lb22)**2 + max((ub21-lb11)**2, (ub11-lb21)**2)) + elif bottom: + dist_min = lb22 - ub12 + dist_max = np.sqrt((ub22 - lb12)**2 + max((ub21-lb11)**2, (ub11-lb21)**2)) + else: + dist_min = 0 + dist_max = max( + dist((lb11, lb12), (ub21, ub22)), + dist((lb11, ub12), (ub21, lb22)), + dist((ub11, lb12), (lb21, ub12)), + dist((ub11, ub12), (lb21, lb22)) + ) + return dist_min, dist_max + +class TrafficSensor: + def __init__(self): + self.sensor_distance = 60 + + # The baseline sensor is omniscient. Each agent can get the state of all other agents + def sense(self, agent: BaseAgent, state_dict, lane_map): + len_dict = {} + cont = {} + disc = {} + len_dict = {"others": len(state_dict) - 1} + tmp = np.array(list(state_dict.values())[0][0]) + if tmp.ndim < 2: + if agent.id == 'car': + len_dict['others'] = 1 + cont['ego.x'] = state_dict['car'][0][1] + cont['ego.y'] = state_dict['car'][0][2] + cont['ego.theta'] = state_dict['car'][0][3] + cont['ego.v'] = state_dict['car'][0][4] + disc['ego.agent_mode'] = state_dict['car'][1][0] + disc['other.signal_mode'] = state_dict['tl'][1][0] + dist = np.sqrt( + (state_dict['car'][0][1]-state_dict['tl'][0][1])**2+\ + (state_dict['car'][0][2]-state_dict['tl'][0][2])**2 + ) + if dist < self.sensor_distance: + cont['other.dist'] = dist + else: + cont['other.dist'] = 1000 + elif agent.id == 'tl': + cont['ego.timer'] = state_dict['tl'][0][5] + disc['ego.signal_mode'] = state_dict['tl'][1][0] + else: + if agent.id == 'car': + len_dict['others'] = 1 + dist_min, dist_max = get_extreme( + (state_dict['car'][0][0][1],state_dict['car'][0][0][2],state_dict['car'][0][1][1],state_dict['car'][0][1][2]), + (state_dict['tl'][0][0][1],state_dict['tl'][0][0][2],state_dict['tl'][0][1][1],state_dict['tl'][0][1][2]), + ) + cont['ego.x'] = [ + state_dict['car'][0][0][1], state_dict['car'][0][1][1] + ] + cont['ego.y'] = [ + state_dict['car'][0][0][2], state_dict['car'][0][1][2] + ] + cont['ego.theta'] = [ + state_dict['car'][0][0][3], state_dict['car'][0][1][3] + ] + cont['ego.v'] = [ + state_dict['car'][0][0][4], state_dict['car'][0][1][4] + ] + cont['other.dist'] = [ + dist_min, dist_max + ] + disc['ego.agent_mode'] = state_dict['car'][1][0] + disc['other.signal_mode'] = state_dict['tl'][1][0] + if dist_min 20: + output.signal_mode = TLMode.YELLOW + output.timer = 0 + if ego.signal_mode == TLMode.YELLOW and ego.timer > 5: + output.signal_mode = TLMode.RED + output.timer = 0 + if ego.signal_mode == TLMode.RED and ego.timer > 10: + output.signal_mode = TLMode.GREEN + output.timer = 0 + + # assert True + return output \ No newline at end of file diff --git a/demo/traffic_signal/traffic_signal_scenario.py b/demo/traffic_signal/traffic_signal_scenario.py new file mode 100644 index 00000000..c66549b6 --- /dev/null +++ b/demo/traffic_signal/traffic_signal_scenario.py @@ -0,0 +1,86 @@ +from mp0 import VehicleAgent, TrafficSignalAgent, TrafficSensor, verify_refine, eval_velocity, sample_init +from verse import Scenario, ScenarioConfig +from vehicle_controller import VehicleMode, TLMode + +from verse.plotter.plotter2D import * +from verse.plotter.plotter3D_new import * +import plotly.graph_objects as go +import copy + +if __name__ == "__main__": + import os + script_dir = os.path.realpath(os.path.dirname(__file__)) + input_code_name = os.path.join(script_dir, "vehicle_controller.py") + vehicle = VehicleAgent('car', file_name=input_code_name) + input_code_name = os.path.join(script_dir, "traffic_controller.py") + tl = TrafficSignalAgent('tl', file_name=input_code_name) + + scenario = Scenario(ScenarioConfig(init_seg_length=1, parallel=False)) + + scenario.add_agent(vehicle) + scenario.add_agent(tl) + scenario.set_sensor(TrafficSensor()) + + # # R1 + init_car = [[-5,-5,0,8],[5,5,0,8]] + init_pedestrian = [[140,0,0,0,0],[140,0,0,0,0]] + + # R2 + # init_car = [[-5,-5,0,5],[5,5,0,10]] + # init_pedestrian = [[140,0,0,0,0],[140,0,0,0,0]] + + # # R3 + # init_car = [[-5,-5,0,5],[5,5,0,10]] + # init_pedestrian = [[140,-55,0,3],[150,-50,0,3]] + + scenario.set_init_single( + 'car', init_car,(VehicleMode.Normal,) + ) + scenario.set_init_single( + 'tl', init_pedestrian, (TLMode.GREEN,) + ) + + # # ----------- Simulate single ------------- + trace = scenario.simulate(50, 0.1) + fig = go.Figure() + # fig = simulation_tree_3d(trace, fig,\ + # 0,'time', 1,'x',2,'y') + # fig.show() + fig = simulation_tree(trace, None, fig, 0, 1) + fig.show() + + # # ----------- Simulate multi ------------- + # init_dict_list= sample_init(scenario, num_sample=50) + # trace_list = scenario.simulate_multi(50, 0.1,\ + # init_dict_list=init_dict_list) + # fig = go.Figure() + # for trace in trace_list: + # fig = simulation_tree_3d(trace, fig,\ + # 0,'time', 1,'x',2,'y') + # fig.show() + # avg_vel, unsafe_frac, unsafe_init = eval_velocity(trace_list) + # print(f"Average velocity {avg_vel}, Unsafe fraction {unsafe_frac}, Unsafe init {unsafe_init}") + # # ----------------------------------------- + + # ----------- verify old version ---------- + # traces = scenario.verify(30, 1) + # # fig = go.Figure() + # # fig = reachtube_tree(traces, fig, 0,1,[0,1],'lines', 'trace') + # # fig.show() + # # fig = go.Figure() + # # fig = reachtube_tree(traces, fig, 0,2,[0,1],'lines', 'trace') + # # fig.show() + + # fig = go.Figure() + # fig = reachtube_tree_3d(traces, fig, 0,'time', 1,'x',2,'y') + # fig.show() + + # ----------------------------------------- + + # # ------------- Verify refine ------------- + # com_traces = verify_refine(scenario, 50, 0.1) + # fig = go.Figure() + # fig = reachtube_tree_3d(com_traces, fig,\ + # 0,'time', 1,'x',2,'y') + # fig.show() + # # ----------------------------------------- diff --git a/demo/traffic_signal/vehicle_controller.py b/demo/traffic_signal/vehicle_controller.py new file mode 100644 index 00000000..60a63f79 --- /dev/null +++ b/demo/traffic_signal/vehicle_controller.py @@ -0,0 +1,37 @@ +from enum import Enum, auto +import copy +from typing import List + +class TLMode(Enum): + GREEN=auto() + YELLOW=auto() + RED=auto() + +class VehicleMode(Enum): + Normal = auto() + Brake = auto() + Accel = auto() + HardBrake = auto() + +class State: + x: float + y: float + theta: float + v: float + agent_mode: VehicleMode + + def __init__(self, x, y, theta, v, agent_mode: VehicleMode): + pass + +def decisionLogic(ego: State, other: State): + output = copy.deepcopy(ego) + if ego.agent_mode == VehicleMode.Normal and other.signal_mode == TLMode.RED: + output.agent_mode = VehicleMode.Brake + if ego.agent_mode == VehicleMode.Brake and other.signal_mode != TLMode.RED: + output.agent_mode = VehicleMode.Accel + # if (ego.agent_mode == VehicleMode.Brake or ego.agent_mode == VehicleMode.HardBrake) and other.y>5: + # output.agent_mode = VehicleMode.Accel + + # assert other.dist > 2.0 + + return output \ No newline at end of file diff --git a/verse/analysis/verifier.py b/verse/analysis/verifier.py index 6a97e4ea..5e33c653 100644 --- a/verse/analysis/verifier.py +++ b/verse/analysis/verifier.py @@ -340,7 +340,7 @@ def compute_full_reachtube_step( # pp(("to sim", new_cache.keys(), len(paths_to_sim))) # Get all possible transitions to next mode - asserts, all_possible_transitions = Verifier.get_transition_verify_opt( + asserts, all_possible_transitions = Verifier.get_transition_verify( config, new_cache, paths_to_sim, node, consts.lane_map, consts.sensor ) node.assert_hits = asserts @@ -675,6 +675,215 @@ def compute_full_reachtube( return self.reachtube_tree + @staticmethod + def get_transition_verify( + config: "ScenarioConfig", cache: Dict[str, CachedRTTrans], paths: PathDiffs, node: AnalysisTreeNode, track_map, sensor + ) -> Tuple[ + Optional[Dict[str, List[str]]], + Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]] + ]: + # For each agent + agent_guard_dict = defaultdict(list) + cached_guards = defaultdict(list) + min_trans_ind = None + cached_trans = defaultdict(list) + agent_dict = node.agent + + if not cache: + paths = [(agent, p) for agent in node.agent.values() for p in agent.decision_logic.paths] + else: + + # _transitions = [trans.transition for seg in cache.values() for trans in seg.transitions] + _transitions = [(aid, trans) for aid, seg in cache.items() for trans in seg.transitions if reach_trans_suit(trans.inits, node.init)] + # pp(("cached trans", len(_transitions))) + if len(_transitions) > 0: + min_trans_ind = min([t.transition for _, t in _transitions]) + # TODO: check for asserts + cached_trans = [(aid, tran.mode, tran.dest, tran.reset, tran.reset_idx, tran.paths) for aid, tran in dedup(_transitions, lambda p: (p[0], p[1].mode, p[1].dest)) if tran.transition == min_trans_ind] + if len(paths) == 0: + # print(red("full cache")) + return None, cached_trans + + path_transitions = defaultdict(int) + for seg in cache.values(): + for tran in seg.transitions: + for p in tran.paths: + path_transitions[p.cond] = max(path_transitions[p.cond], tran.transition) + for agent_id, segment in cache.items(): + agent = node.agent[agent_id] + if len(agent.decision_logic.args) == 0: + continue + state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent} + + agent_paths = dedup([p for tran in segment.transitions for p in tran.paths], lambda i: (i.var, i.cond, i.val)) + for path in agent_paths: + cont_var_dict_template, discrete_variable_dict, length_dict = sensor.sense( + agent, state_dict, track_map) + reset = (path.var, path.val_veri) + guard_expression = GuardExpressionAst([path.cond_veri]) + + cont_var_updater = guard_expression.parse_any_all_new( + cont_var_dict_template, discrete_variable_dict, length_dict) + Verifier.apply_cont_var_updater( + cont_var_dict_template, cont_var_updater) + guard_can_satisfied = guard_expression.evaluate_guard_disc( + agent, discrete_variable_dict, cont_var_dict_template, track_map) + if not guard_can_satisfied: + continue + cached_guards[agent_id].append((path, guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset, path_transitions[path.cond])) + + # for aid, trace in node.trace.items(): + # if len(trace) < 2: + # pp(("weird state", aid, trace)) + for agent, path in paths: + if len(agent.decision_logic.args) == 0: + continue + agent_id = agent.id + state_dict = {aid: (node.trace[aid][0:2], node.mode[aid], node.static[aid]) for aid in node.agent} + cont_var_dict_template, discrete_variable_dict, length_dict = sensor.sense( + agent, state_dict, track_map) + # TODO-PARSER: Get equivalent for this function + # Construct the guard expression + guard_expression = GuardExpressionAst([path.cond_veri]) + + cont_var_updater = guard_expression.parse_any_all_new( + cont_var_dict_template, discrete_variable_dict, length_dict) + Verifier.apply_cont_var_updater( + cont_var_dict_template, cont_var_updater) + guard_can_satisfied = guard_expression.evaluate_guard_disc( + agent, discrete_variable_dict, cont_var_dict_template, track_map) + if not guard_can_satisfied: + continue + agent_guard_dict[agent_id].append( + (guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), path)) + + trace_length = int(min(len(v) for v in node.trace.values()) // 2) + # pp(("trace len", trace_length, {a: len(t) for a, t in node.trace.items()})) + guard_hits = [] + guard_hit = False + for idx in range(trace_length): + if min_trans_ind != None and idx >= min_trans_ind: + return None, cached_trans + any_contained = False + hits = [] + state_dict = {aid: (node.trace[aid][idx*2:idx*2+2], node.mode[aid], node.static[aid]) for aid in node.agent} + + asserts = defaultdict(list) + for agent_id in agent_dict.keys(): + agent: BaseAgent = agent_dict[agent_id] + if len(agent.decision_logic.args) == 0: + continue + agent_state, agent_mode, agent_static = state_dict[agent_id] + # if np.array(agent_state).ndim != 2: + # pp(("weird state", agent_id, agent_state)) + agent_state = agent_state[1:] + cont_vars, disc_vars, len_dict = sensor.sense(agent, state_dict, track_map) + resets = defaultdict(list) + # Check safety conditions + for i, a in enumerate(agent.decision_logic.asserts_veri): + pre_expr = a.pre + + def eval_expr(expr): + ge = GuardExpressionAst([copy.deepcopy(expr)]) + cont_var_updater = ge.parse_any_all_new(cont_vars, disc_vars, len_dict) + Verifier.apply_cont_var_updater(cont_vars, cont_var_updater) + sat = ge.evaluate_guard_disc(agent, disc_vars, cont_vars, track_map) + if sat: + sat = ge.evaluate_guard_hybrid(agent, disc_vars, cont_vars, track_map) + if sat: + sat, contained = ge.evaluate_guard_cont(agent, cont_vars, track_map) + sat = sat and contained + return sat + if eval_expr(pre_expr): + if not eval_expr(a.cond): + label = a.label if a.label != None else f"" + print(f"assert hit for {agent_id}: \"{label}\"") + print(idx) + asserts[agent_id].append(label) + if agent_id in asserts: + continue + if agent_id not in agent_guard_dict: + continue + + unchecked_cache_guards = [g[:-1] for g in cached_guards[agent_id] if g[-1] < idx] # FIXME: off by 1? + for guard_expression, continuous_variable_updater, discrete_variable_dict, path in agent_guard_dict[agent_id] + unchecked_cache_guards: + assert isinstance(path, ModePath) + new_cont_var_dict = copy.deepcopy(cont_vars) + one_step_guard: GuardExpressionAst = copy.deepcopy(guard_expression) + + Verifier.apply_cont_var_updater(new_cont_var_dict, continuous_variable_updater) + guard_can_satisfied = one_step_guard.evaluate_guard_hybrid( + agent, discrete_variable_dict, new_cont_var_dict, track_map) + if not guard_can_satisfied: + continue + guard_satisfied, is_contained = one_step_guard.evaluate_guard_cont( + agent, new_cont_var_dict, track_map) + any_contained = any_contained or is_contained + # TODO: Can we also store the cont and disc var dict so we don't have to call sensor again? + if guard_satisfied: + reset_expr = ResetExpression((path.var, path.val_veri)) + resets[reset_expr.var].append( + (reset_expr, discrete_variable_dict, + new_cont_var_dict, guard_expression.guard_idx, path) + ) + # Perform combination over all possible resets to generate all possible real resets + combined_reset_list = list(itertools.product(*resets.values())) + if len(combined_reset_list) == 1 and combined_reset_list[0] == (): + continue + for i in range(len(combined_reset_list)): + # Compute reset_idx + reset_idx = [] + for reset_info in combined_reset_list[i]: + reset_idx.append(reset_info[3]) + # a list of reset expression + hits.append((agent_id, tuple(reset_idx), combined_reset_list[i])) + if len(asserts) > 0: + return (asserts, idx), None + if hits != []: + guard_hits.append((hits, state_dict, idx)) + guard_hit = True + elif guard_hit: + break + if any_contained: + break + + reset_dict = {} # defaultdict(lambda: defaultdict(list)) + for hits, all_agent_state, hit_idx in guard_hits: + for agent_id, reset_idx, reset_list in hits: + # TODO: Need to change this function to handle the new reset expression and then I am done + dest_list, reset_rect = Verifier.apply_reset(node.agent[agent_id], reset_list, all_agent_state, track_map) + # pp(("dests", dest_list, *[astunparser.unparse(reset[-1].val_veri) for reset in reset_list])) + if agent_id not in reset_dict: + reset_dict[agent_id] = {} + if not dest_list: + warnings.warn( + f"Guard hit for mode {node.mode[agent_id]} for agent {agent_id} without available next mode") + dest_list.append(None) + if reset_idx not in reset_dict[agent_id]: + reset_dict[agent_id][reset_idx] = {} + for dest in dest_list: + if dest not in reset_dict[agent_id][reset_idx]: + reset_dict[agent_id][reset_idx][dest] = [] + reset_dict[agent_id][reset_idx][dest].append((reset_rect, hit_idx, reset_list[-1])) + + possible_transitions = [] + # Combine reset rects and construct transitions + for agent in reset_dict: + for reset_idx in reset_dict[agent]: + for dest in reset_dict[agent][reset_idx]: + reset_data = tuple(map(list, zip(*reset_dict[agent][reset_idx][dest]))) + paths = [r[-1] for r in reset_data[-1]] + transition = (agent, node.mode[agent],dest, *reset_data[:-1], paths) + src_mode = node.get_mode(agent, node.mode[agent]) + src_track = node.get_track(agent, node.mode[agent]) + dest_mode = node.get_mode(agent, dest) + dest_track = node.get_track(agent, dest) + if dest_track == track_map.h(src_track, src_mode, dest_mode): + possible_transitions.append(transition) + print(transition[4]) + # Return result + return None, possible_transitions + @staticmethod def get_transition_verify_opt( config: "ScenarioConfig", cache: Dict[str, CachedRTTrans], paths: PathDiffs, node: AnalysisTreeNode, track_map, sensor