From b8a820be1a4e2f83973a2f9d3a847be70be9f712 Mon Sep 17 00:00:00 2001 From: Wesley Maa Date: Tue, 14 Jan 2025 10:26:56 -0800 Subject: [PATCH] freq --- kos_sim/config.py | 24 +++++++++++++++++------- kos_sim/server.py | 28 ++++++++++++++++++++-------- kos_sim/services.py | 10 ++++++++-- kos_sim/simulator.py | 25 ++++++++++++++++--------- kos_sim/test_client.py | 2 +- pyproject.toml | 2 +- 6 files changed, 63 insertions(+), 28 deletions(-) diff --git a/kos_sim/config.py b/kos_sim/config.py index 87d9768..fb79d1b 100644 --- a/kos_sim/config.py +++ b/kos_sim/config.py @@ -12,8 +12,15 @@ class SimulatorConfig: joint_id_to_name: dict[int, str] joint_name_to_id: dict[str, int] - kp: float = 32.0 + kp: float = 80.0 kd: float = 10.0 + physics_freq: float = 1000.0 # Hz + command_freq: float = 50.0 # Hz + + @property + def sim_decimation(self) -> int: + """Calculate decimation factor to achieve desired command frequency.""" + return max(1, int(self.physics_freq / self.command_freq)) @classmethod def from_file(cls, config_path: str) -> "SimulatorConfig": @@ -21,15 +28,18 @@ def from_file(cls, config_path: str) -> "SimulatorConfig": with open(config_path, "r") as f: config_data = yaml.safe_load(f) - # Load joint mappings joint_name_to_id = config_data.get("joint_mappings", {}) joint_id_to_name = {v: k for k, v in joint_name_to_id.items()} - # Load control parameters - kp = config_data.get("control", {}).get("kp", 1.0) - kd = config_data.get("control", {}).get("kd", 0.1) - - return cls(joint_id_to_name=joint_id_to_name, joint_name_to_id=joint_name_to_id, kp=kp, kd=kd) + control_config = config_data.get("control", {}) + return cls( + joint_id_to_name=joint_id_to_name, + joint_name_to_id=joint_name_to_id, + kp=control_config.get("kp", 80.0), + kd=control_config.get("kd", 10.0), + physics_freq=control_config.get("physics_freq", 1000.0), + command_freq=control_config.get("command_freq", 50.0), + ) @classmethod def default(cls) -> "SimulatorConfig": diff --git a/kos_sim/server.py b/kos_sim/server.py index 0f10935..fb62006 100644 --- a/kos_sim/server.py +++ b/kos_sim/server.py @@ -11,14 +11,24 @@ from kos_sim.services import ActuatorService, IMUService from kos_sim.simulator import MujocoSimulator +from kos_sim.config import SimulatorConfig logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class SimulationServer: - def __init__(self, model_path: str, port: int = 50051, dt: float = 0.002) -> None: - self.simulator = MujocoSimulator(model_path, dt=dt, render=True, suspended=False) + def __init__( + self, + model_path: str, + config_path: str | None = None, + port: int = 50051 + ) -> None: + if config_path: + config = SimulatorConfig.from_file(config_path) + else: + config = SimulatorConfig.default() + self.simulator = MujocoSimulator(model_path, config=config, render=True) self.port = port self._stop_event = threading.Event() self._grpc_thread: threading.Thread | None = None @@ -48,15 +58,16 @@ def _grpc_server_loop(self) -> None: def start(self) -> None: """Start the gRPC server and run simulation in main thread.""" - # Start gRPC server in separate thread self._grpc_thread = threading.Thread(target=self._grpc_server_loop) self._grpc_thread.start() - # Run simulation in main thread + physics_dt = 1.0 / self.simulator._config.physics_freq try: while not self._stop_event.is_set(): + process_start = time.time() self.simulator.step() - time.sleep(self.simulator._model.opt.timestep) + # Sleep to maintain physics frequency + time.sleep(max(0, physics_dt - (time.time() - process_start))) except KeyboardInterrupt: self.stop() @@ -71,8 +82,8 @@ def stop(self) -> None: self.simulator.close() -def serve(model_path: str, port: int = 50051) -> None: - server = SimulationServer(model_path, port) +def serve(model_path: str, config_path: str | None = None, port: int = 50051) -> None: + server = SimulationServer(model_path, config_path=config_path, port=port) server.start() @@ -80,6 +91,7 @@ def serve(model_path: str, port: int = 50051) -> None: parser = argparse.ArgumentParser(description="Start the simulation gRPC server.") parser.add_argument("--model-path", type=str, required=True, help="Path to MuJoCo XML model file") parser.add_argument("--port", type=int, default=50051, help="Port to listen on") + parser.add_argument("--config-path", type=str, default=None, help="Path to config file") args = parser.parse_args() - serve(args.model_path, args.port) + serve(args.model_path, args.config_path, args.port) diff --git a/kos_sim/services.py b/kos_sim/services.py index 13e5327..5bb6afc 100644 --- a/kos_sim/services.py +++ b/kos_sim/services.py @@ -1,5 +1,7 @@ """Service implementations for MuJoCo simulation.""" +import math + import grpc from google.protobuf import empty_pb2 from kos_protos import actuator_pb2, actuator_pb2_grpc, imu_pb2, imu_pb2_grpc @@ -18,7 +20,7 @@ def CommandActuators( # noqa: N802 ) -> actuator_pb2.CommandActuatorsResponse: """Implements CommandActuators by forwarding to simulator.""" try: - commands = {cmd.actuator_id: cmd.position for cmd in request.commands} + commands = {cmd.actuator_id: math.radians(cmd.position) for cmd in request.commands} self.simulator.command_actuators(commands) return actuator_pb2.CommandActuatorsResponse() except Exception as e: @@ -37,7 +39,11 @@ def GetActuatorsState( # noqa: N802 } return actuator_pb2.GetActuatorsStateResponse( states=[ - actuator_pb2.ActuatorStateResponse(actuator_id=joint_id, position=float(state), online=True) + actuator_pb2.ActuatorStateResponse( + actuator_id=joint_id, + position=math.degrees(float(state)), + online=True + ) for joint_id, state in states.items() ] ) diff --git a/kos_sim/simulator.py b/kos_sim/simulator.py index 513609b..f4c6985 100644 --- a/kos_sim/simulator.py +++ b/kos_sim/simulator.py @@ -83,11 +83,18 @@ def __init__( self._kp = np.array([self._config.kp] * self._model.nu) self._kd = np.array([self._config.kd] * self._model.nu) + self._count_lowlevel = 0 + self._target_positions: dict[str, float] = {} # Store target positions between updates + def step(self) -> None: """Execute one step of the simulation.""" with self._lock: + # Only update commands every sim_decimation steps + if self._count_lowlevel % self._config.sim_decimation == 0: + self._target_positions = self._current_commands.copy() + # Apply actuator commands using PD control - for name, target_pos in self._current_commands.items(): + for name, target_pos in self._target_positions.items(): actuator_id = self._actuator_ids[name] current_pos = self._data.qpos[actuator_id] current_vel = self._data.qvel[actuator_id] @@ -99,16 +106,16 @@ def step(self) -> None: # Step physics mujoco.mj_step(self._model, self._data) - # If suspended, maintain position and orientation - if self._suspended and self._initial_pos is not None: - # Find the free joint + # Increment counter + self._count_lowlevel += 1 + + if self._suspended: + # Find the root joint (floating_base) for i in range(self._model.njnt): if self._model.jnt_type[i] == mujoco.mjtJoint.mjJNT_FREE: - # Reset position and orientation - self._data.qpos[i : i + 3] = self._initial_pos - self._data.qpos[i + 3 : i + 7] = self._initial_quat - # Zero out velocities - self._data.qvel[i : i + 6] = 0 + print(f"Joint name: {self._model.joint(i).name}") + self._data.qpos[i:i + 7] = self._model.keyframe("default").qpos[i:i + 7] + self._data.qvel[i:i + 6] = 0 break if self._render_enabled: diff --git a/kos_sim/test_client.py b/kos_sim/test_client.py index 795449c..167d6cd 100644 --- a/kos_sim/test_client.py +++ b/kos_sim/test_client.py @@ -21,7 +21,7 @@ def test_actuator_commands(host: str = "localhost", port: int = 50051) -> None: frequency = 1.0 # Hz amplitude = 45.0 # degrees duration = 5.0 # seconds - actuator_id = 1 + actuator_id = 2 logger.info("Starting actuator command test...") diff --git a/pyproject.toml b/pyproject.toml index 1d9c39e..24c26f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ ignore = [ "D101", "D102", "D103", "D104", "D105", "D106", "D107", "N812", "N817", "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR2004", - "PLW0603", "PLW2901", "ANN101", "ANN102", + "PLW0603", "PLW2901", ] dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"