Skip to content

Commit

Permalink
add more interesting behavior to the traffic signal scenario
Browse files Browse the repository at this point in the history
  • Loading branch information
lyg1597 committed Apr 1, 2024
1 parent 5148be3 commit 313b72d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 28 deletions.
14 changes: 10 additions & 4 deletions demo/traffic_signal/mp0.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,16 +266,19 @@ def sense(self, agent: BaseAgent, state_dict, lane_map):
cont['other.dist'] = dist
else:
cont['other.dist'] = 1000
cont['other.x'] = state_dict['tl'][0][1]
elif agent.id == 'tl':
cont['ego.timer'] = state_dict['tl'][0][5]
disc['ego.signal_mode'] = state_dict['tl'][1][0]
else:
if agent.id == 'car':
len_dict['others'] = 1
dist_min, dist_max = get_extreme(
(state_dict['car'][0][0][1],state_dict['car'][0][0][2],state_dict['car'][0][1][1],state_dict['car'][0][1][2]),
(state_dict['tl'][0][0][1],state_dict['tl'][0][0][2],state_dict['tl'][0][1][1],state_dict['tl'][0][1][2]),
)
# dist_min, dist_max = get_extreme(
# (state_dict['car'][0][0][1],state_dict['car'][0][0][2],state_dict['car'][0][1][1],state_dict['car'][0][1][2]),
# (state_dict['tl'][0][0][1],state_dict['tl'][0][0][2],state_dict['tl'][0][1][1],state_dict['tl'][0][1][2]),
# )
dist_min = min(abs(state_dict['car'][0][0][1]-state_dict['tl'][0][0][1]), abs(state_dict['car'][0][1][1]-state_dict['tl'][0][0][1]))
dist_max = max(abs(state_dict['car'][0][0][1]-state_dict['tl'][0][0][1]), abs(state_dict['car'][0][1][1]-state_dict['tl'][0][0][1]))
cont['ego.x'] = [
state_dict['car'][0][0][1], state_dict['car'][0][1][1]
]
Expand All @@ -288,6 +291,9 @@ def sense(self, agent: BaseAgent, state_dict, lane_map):
cont['ego.v'] = [
state_dict['car'][0][0][4], state_dict['car'][0][1][4]
]
cont['other.x'] = [
state_dict['tl'][0][0][1], state_dict['tl'][0][1][1]
]
cont['other.dist'] = [
dist_min, dist_max
]
Expand Down
41 changes: 21 additions & 20 deletions demo/traffic_signal/traffic_signal_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
scenario.set_sensor(TrafficSensor())

# # R1
init_car = [[-5,-5,0,5],[5,5,0,5]]
init_pedestrian = [[200,0,0,0,0],[200,0,0,0,0]]
init_car = [[0,-5,0,5],[50,5,0,5]]
init_pedestrian = [[300,0,0,0,0],[300,0,0,0,0]]

# R2
# init_car = [[-5,-5,0,5],[5,5,0,10]]
Expand All @@ -41,35 +41,36 @@
)

# # ----------- Simulate single -------------
trace = scenario.simulate(80, 0.1)
fig = go.Figure()
# fig = simulation_tree_3d(trace, fig,\
# 0,'time', 1,'x',2,'y')
# trace = scenario.simulate(80, 0.1)
# fig = go.Figure()
# # fig = simulation_tree_3d(trace, fig,\
# # 0,'time', 1,'x',2,'y')
# # fig.show()
# fig = simulation_tree(trace, None, fig, 0, 1)
# fig.show()
fig = simulation_tree(trace, None, fig, 0, 1)
fig.show()

# # ----------- Simulate multi -------------
# init_dict_list= sample_init(scenario, num_sample=50)
# trace_list = scenario.simulate_multi(50, 0.1,\
# init_dict_list=init_dict_list)
# fig = go.Figure()
# for trace in trace_list:
# fig = simulation_tree_3d(trace, fig,\
# 0,'time', 1,'x',2,'y')
# fig.show()
init_dict_list= sample_init(scenario, num_sample=50)
trace_list = scenario.simulate_multi(100, 0.1,\
init_dict_list=init_dict_list)
fig = go.Figure()
for trace in trace_list:
# fig = simulation_tree_3d(trace, fig,\
# 0,'time', 1,'x',2,'y')
fig = simulation_tree(trace, None, fig, 0, 1)
fig.show()
# avg_vel, unsafe_frac, unsafe_init = eval_velocity(trace_list)
# print(f"Average velocity {avg_vel}, Unsafe fraction {unsafe_frac}, Unsafe init {unsafe_init}")
# # -----------------------------------------

# ----------- verify old version ----------
traces = scenario.verify(80, 0.1)
traces = scenario.verify(100, 0.1)
fig = go.Figure()
fig = reachtube_tree(traces, None, fig, 0,1,[0,1],'lines', 'trace')
fig.show()
fig = go.Figure()
fig = reachtube_tree(traces, None, fig, 0,2,[0,1],'lines', 'trace')
fig.show()
# fig = go.Figure()
# fig = reachtube_tree(traces, None, fig, 0,2,[0,1],'lines', 'trace')
# fig.show()

# fig = go.Figure()
# fig = reachtube_tree_3d(traces, fig, 0,'time', 1,'x',2,'y')
Expand Down
13 changes: 9 additions & 4 deletions demo/traffic_signal/vehicle_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,20 @@ def __init__(self, x, y, theta, v, agent_mode: VehicleMode):

def decisionLogic(ego: State, other: State):
output = copy.deepcopy(ego)
if ego.agent_mode == VehicleMode.Normal and other.signal_mode == TLMode.RED and other.dist<40:
# if ego.x > 100000:
# output.agent_mode = VehicleMode.Brake
# if ego.agent_mode == VehicleMode.Normal:
# output.agent_mode = VehicleMode.Accel
if ego.agent_mode == VehicleMode.Normal and other.signal_mode == TLMode.RED and other.dist<60:
output.agent_mode = VehicleMode.Brake
elif ego.agent_mode == VehicleMode.Normal and other.signal_mode == TLMode.GREEN and other.dist < 20:
elif ego.agent_mode == VehicleMode.Normal and other.signal_mode == TLMode.YELLOW and other.dist < 60:
output.agent_mode = VehicleMode.Brake
if ego.agent_mode == VehicleMode.Brake and other.signal_mode != TLMode.RED:
if ego.agent_mode == VehicleMode.Brake and other.signal_mode == TLMode.GREEN:
output.agent_mode = VehicleMode.Accel
# if (ego.agent_mode == VehicleMode.Brake or ego.agent_mode == VehicleMode.HardBrake) and other.y>5:
# output.agent_mode = VehicleMode.Accel

assert not (other.signal_mode == TLMode.RED and (ego.x>190 and ego.x<210))
assert not (other.signal_mode == TLMode.RED and (ego.x>other.x-20 and ego.x<other.x-15)), "run red light"
assert not (other.signal_mode == TLMode.RED and (ego.x>other.x-15 and ego.x<other.x) and ego.v<1), "stop at intersection"

return output

0 comments on commit 313b72d

Please sign in to comment.