From a03c4e0057290bd6796a549623461b932fc0a979 Mon Sep 17 00:00:00 2001 From: patorrad Date: Tue, 8 Nov 2022 14:58:25 -0800 Subject: [PATCH] Added collision sphere visualization --- examples/franka_reacher.py | 19 +++++++++++++++++++ storm_kit/gym/core.py | 20 ++++++++++++++++++++ storm_kit/mpc/rollout/arm_base.py | 5 ++++- 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/examples/franka_reacher.py b/examples/franka_reacher.py index aeaabef3..cfbc016b 100644 --- a/examples/franka_reacher.py +++ b/examples/franka_reacher.py @@ -301,6 +301,25 @@ def mpc_robot_interactive(args, gym_instance): color[1] = 1.0 - float(k) / float(top_trajs.shape[0]) gym_instance.draw_lines(pts, color=color) + link_pos_seq = copy.deepcopy(mpc_control.controller.rollout_fn.link_pos_seq) + link_rot_seq = copy.deepcopy(mpc_control.controller.rollout_fn.link_rot_seq) + batch_size = link_pos_seq.shape[0] + horizon = link_pos_seq.shape[1] + n_links = link_pos_seq.shape[2] + link_pos = link_pos_seq.view(batch_size * horizon, n_links, 3) + link_rot = link_rot_seq.view(batch_size * horizon, n_links, 3, 3) + mpc_control.controller.rollout_fn.robot_self_collision_cost.coll.update_batch_robot_collision_objs(link_pos, link_rot) + + spheres = mpc_control.controller.rollout_fn.get_spheres() + arr = None + for sphere in spheres: + if arr is None: + arr = np.array(sphere[1:,:,:4].cpu().numpy().squeeze()) + else: + arr = np.vstack((arr,sphere[1:,:,:4].cpu().numpy().squeeze())) + + [gym_instance.draw_collision_spheres(sphere,w_T_r) for sphere in arr] + robot_sim.command_robot_position(q_des, env_ptr, robot_ptr) #robot_sim.set_robot_state(q_des, qd_des, env_ptr, robot_ptr) current_state = command diff --git a/storm_kit/gym/core.py b/storm_kit/gym/core.py index fc2b3092..54819967 100644 --- a/storm_kit/gym/core.py +++ b/storm_kit/gym/core.py @@ -114,6 +114,26 @@ def draw_lines(self, pts, color=[0.5,0.0,0.0], env_idx=0, w_T_l=None): self.gym.add_lines(self.viewer,self.env_list[env_idx],pts.shape[0] - 1,verts, colors) #self.gym.add_lines(self.viewer,self.env_list[env_idx],pts.shape[0] - 1,verts, colors) + def draw_sphere(self, pose, diameter, edges): + sphere_rot = gymapi.Quat.from_euler_zyx(0.5 * np.pi, 0, 0) + sphere_pose = gymapi.Transform(r=sphere_rot) + sphere_geom = gymutil.WireframeSphereGeometry(diameter, edges, edges, sphere_pose, color=(1, 0, 0)) # Should be radius? + + verts = sphere_geom.instance_verts(pose) + colors = np.empty(1, dtype=gymapi.Vec3.dtype) + # colors = (1, 1, 0) + self.gym.add_lines(self.viewer,self.env_list[0], sphere_geom.num_lines(), verts, sphere_geom.colors()) + + def draw_collision_spheres(self, sub_sphere, w_T_r): + print(sub_sphere) + goal_pose = gymapi.Transform() + goal_pose.p = gymapi.Vec3(sub_sphere[0], sub_sphere[1], sub_sphere[2]) + goal_pose.r = gymapi.Quat(0, 0.707, 0, 0.707) + w_T_r_pose = w_T_r * goal_pose + radius = sub_sphere[3] + edges = 6 + self.draw_sphere(w_T_r_pose, radius, edges) + class World(object): def __init__(self, gym_instance, sim_instance, env_ptr, world_params=None, w_T_r=None): self.gym = gym_instance diff --git a/storm_kit/mpc/rollout/arm_base.py b/storm_kit/mpc/rollout/arm_base.py index 71f13311..54e6d917 100644 --- a/storm_kit/mpc/rollout/arm_base.py +++ b/storm_kit/mpc/rollout/arm_base.py @@ -209,7 +209,10 @@ def cost_fn(self, state_dict, action_batch, no_coll=False, horizon_cost=True): return cost - + + def get_spheres(self): + return self.robot_self_collision_cost.coll.w_batch_link_spheres + def rollout_fn(self, start_state, act_seq): """ Return sequence of costs and states encountered