Skip to content

Commit

Permalink
freq
Browse files Browse the repository at this point in the history
  • Loading branch information
WT-MM committed Jan 14, 2025
1 parent 0518211 commit b8a820b
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 28 deletions.
24 changes: 17 additions & 7 deletions kos_sim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,34 @@
class SimulatorConfig:
joint_id_to_name: dict[int, str]
joint_name_to_id: dict[str, int]
kp: float = 32.0
kp: float = 80.0
kd: float = 10.0
physics_freq: float = 1000.0 # Hz
command_freq: float = 50.0 # Hz

@property
def sim_decimation(self) -> int:
"""Calculate decimation factor to achieve desired command frequency."""
return max(1, int(self.physics_freq / self.command_freq))

@classmethod
def from_file(cls, config_path: str) -> "SimulatorConfig":
"""Load config from YAML file."""
with open(config_path, "r") as f:
config_data = yaml.safe_load(f)

# Load joint mappings
joint_name_to_id = config_data.get("joint_mappings", {})
joint_id_to_name = {v: k for k, v in joint_name_to_id.items()}

# Load control parameters
kp = config_data.get("control", {}).get("kp", 1.0)
kd = config_data.get("control", {}).get("kd", 0.1)

return cls(joint_id_to_name=joint_id_to_name, joint_name_to_id=joint_name_to_id, kp=kp, kd=kd)
control_config = config_data.get("control", {})
return cls(
joint_id_to_name=joint_id_to_name,
joint_name_to_id=joint_name_to_id,
kp=control_config.get("kp", 80.0),
kd=control_config.get("kd", 10.0),
physics_freq=control_config.get("physics_freq", 1000.0),
command_freq=control_config.get("command_freq", 50.0),
)

@classmethod
def default(cls) -> "SimulatorConfig":
Expand Down
28 changes: 20 additions & 8 deletions kos_sim/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,24 @@

from kos_sim.services import ActuatorService, IMUService
from kos_sim.simulator import MujocoSimulator
from kos_sim.config import SimulatorConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SimulationServer:
def __init__(self, model_path: str, port: int = 50051, dt: float = 0.002) -> None:
self.simulator = MujocoSimulator(model_path, dt=dt, render=True, suspended=False)
def __init__(
self,
model_path: str,
config_path: str | None = None,
port: int = 50051
) -> None:
if config_path:
config = SimulatorConfig.from_file(config_path)
else:
config = SimulatorConfig.default()
self.simulator = MujocoSimulator(model_path, config=config, render=True)
self.port = port
self._stop_event = threading.Event()
self._grpc_thread: threading.Thread | None = None
Expand Down Expand Up @@ -48,15 +58,16 @@ def _grpc_server_loop(self) -> None:

def start(self) -> None:
"""Start the gRPC server and run simulation in main thread."""
# Start gRPC server in separate thread
self._grpc_thread = threading.Thread(target=self._grpc_server_loop)
self._grpc_thread.start()

# Run simulation in main thread
physics_dt = 1.0 / self.simulator._config.physics_freq
try:
while not self._stop_event.is_set():
process_start = time.time()
self.simulator.step()
time.sleep(self.simulator._model.opt.timestep)
# Sleep to maintain physics frequency
time.sleep(max(0, physics_dt - (time.time() - process_start)))
except KeyboardInterrupt:
self.stop()

Expand All @@ -71,15 +82,16 @@ def stop(self) -> None:
self.simulator.close()


def serve(model_path: str, port: int = 50051) -> None:
server = SimulationServer(model_path, port)
def serve(model_path: str, config_path: str | None = None, port: int = 50051) -> None:
server = SimulationServer(model_path, config_path=config_path, port=port)
server.start()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Start the simulation gRPC server.")
parser.add_argument("--model-path", type=str, required=True, help="Path to MuJoCo XML model file")
parser.add_argument("--port", type=int, default=50051, help="Port to listen on")
parser.add_argument("--config-path", type=str, default=None, help="Path to config file")

args = parser.parse_args()
serve(args.model_path, args.port)
serve(args.model_path, args.config_path, args.port)
10 changes: 8 additions & 2 deletions kos_sim/services.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Service implementations for MuJoCo simulation."""

import math

import grpc
from google.protobuf import empty_pb2
from kos_protos import actuator_pb2, actuator_pb2_grpc, imu_pb2, imu_pb2_grpc
Expand All @@ -18,7 +20,7 @@ def CommandActuators( # noqa: N802
) -> actuator_pb2.CommandActuatorsResponse:
"""Implements CommandActuators by forwarding to simulator."""
try:
commands = {cmd.actuator_id: cmd.position for cmd in request.commands}
commands = {cmd.actuator_id: math.radians(cmd.position) for cmd in request.commands}
self.simulator.command_actuators(commands)
return actuator_pb2.CommandActuatorsResponse()
except Exception as e:
Expand All @@ -37,7 +39,11 @@ def GetActuatorsState( # noqa: N802
}
return actuator_pb2.GetActuatorsStateResponse(
states=[
actuator_pb2.ActuatorStateResponse(actuator_id=joint_id, position=float(state), online=True)
actuator_pb2.ActuatorStateResponse(
actuator_id=joint_id,
position=math.degrees(float(state)),
online=True
)
for joint_id, state in states.items()
]
)
Expand Down
25 changes: 16 additions & 9 deletions kos_sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,18 @@ def __init__(
self._kp = np.array([self._config.kp] * self._model.nu)
self._kd = np.array([self._config.kd] * self._model.nu)

self._count_lowlevel = 0
self._target_positions: dict[str, float] = {} # Store target positions between updates

def step(self) -> None:
"""Execute one step of the simulation."""
with self._lock:
# Only update commands every sim_decimation steps
if self._count_lowlevel % self._config.sim_decimation == 0:
self._target_positions = self._current_commands.copy()

# Apply actuator commands using PD control
for name, target_pos in self._current_commands.items():
for name, target_pos in self._target_positions.items():
actuator_id = self._actuator_ids[name]
current_pos = self._data.qpos[actuator_id]
current_vel = self._data.qvel[actuator_id]
Expand All @@ -99,16 +106,16 @@ def step(self) -> None:
# Step physics
mujoco.mj_step(self._model, self._data)

# If suspended, maintain position and orientation
if self._suspended and self._initial_pos is not None:
# Find the free joint
# Increment counter
self._count_lowlevel += 1

if self._suspended:
# Find the root joint (floating_base)
for i in range(self._model.njnt):
if self._model.jnt_type[i] == mujoco.mjtJoint.mjJNT_FREE:
# Reset position and orientation
self._data.qpos[i : i + 3] = self._initial_pos
self._data.qpos[i + 3 : i + 7] = self._initial_quat
# Zero out velocities
self._data.qvel[i : i + 6] = 0
print(f"Joint name: {self._model.joint(i).name}")
self._data.qpos[i:i + 7] = self._model.keyframe("default").qpos[i:i + 7]
self._data.qvel[i:i + 6] = 0
break

if self._render_enabled:
Expand Down
2 changes: 1 addition & 1 deletion kos_sim/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_actuator_commands(host: str = "localhost", port: int = 50051) -> None:
frequency = 1.0 # Hz
amplitude = 45.0 # degrees
duration = 5.0 # seconds
actuator_id = 1
actuator_id = 2

logger.info("Starting actuator command test...")

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ ignore = [
"D101", "D102", "D103", "D104", "D105", "D106", "D107",
"N812", "N817",
"PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR2004",
"PLW0603", "PLW2901", "ANN101", "ANN102",
"PLW0603", "PLW2901",
]

dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
Expand Down

0 comments on commit b8a820b

Please sign in to comment.