diff --git a/.gitignore b/.gitignore index e5073fe..b3dbaa3 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ out*/ # dev ref/ +.kos-sim/ diff --git a/kos_sim/client.py b/kos_sim/client.py deleted file mode 100644 index 04e8de2..0000000 --- a/kos_sim/client.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Client loop for KOS.""" - -import asyncio - -import pykos - -from kos_sim import logger - - -async def main() -> None: - kos = pykos.KOS(ip="localhost", port=50051) - sim = kos.simulation - - sim.set_parameters(time_scale=1.0) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/kos_sim/config.py b/kos_sim/config.py deleted file mode 100644 index 551cb31..0000000 --- a/kos_sim/config.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Configuration for the simulator.""" - -import logging -from dataclasses import dataclass - -import yaml - -logger = logging.getLogger(__name__) - - -@dataclass -class SimulatorConfig: - joint_id_to_name: dict[int, str] - joint_name_to_id: dict[str, int] - kp: float = 80.0 - kd: float = 10.0 - dt: float = 0.01 - command_freq: float = 50.0 # Hz - - @property - def physics_freq(self) -> float: - """Calculate physics frequency from timestep.""" - return 1.0 / self.dt - - @property - def sim_decimation(self) -> int: - """Calculate decimation factor to achieve desired command frequency.""" - return max(1, int(self.physics_freq / self.command_freq)) - - @classmethod - def from_file(cls, config_path: str) -> "SimulatorConfig": - """Load config from YAML file.""" - with open(config_path, "r") as f: - config_data = yaml.safe_load(f) - - joint_name_to_id = config_data.get("joint_mappings", {}) - joint_id_to_name = {v: k for k, v in joint_name_to_id.items()} - - control_config = config_data.get("control", {}) - return cls( - joint_id_to_name=joint_id_to_name, - joint_name_to_id=joint_name_to_id, - kp=control_config.get("kp", 80.0), - kd=control_config.get("kd", 10.0), - dt=control_config.get("dt", 0.001), - command_freq=control_config.get("command_freq", 50.0), - ) - - @classmethod - def default(cls) -> "SimulatorConfig": - """Create default config with standard joint mappings.""" - joint_name_to_id = { - "L_hip_y_04": 1, - "L_hip_x_03": 2, - "L_hip_z_03": 3, - "L_knee_04": 4, - "L_ankle_02": 5, - "R_hip_y_04": 6, - "R_hip_x_03": 7, - "R_hip_z_03": 8, - "R_knee_04": 9, - "R_ankle_02": 10, - } - return cls(joint_id_to_name={v: k for k, v in joint_name_to_id.items()}, joint_name_to_id=joint_name_to_id) diff --git a/kos_sim/server.py b/kos_sim/server.py index de95091..0069402 100644 --- a/kos_sim/server.py +++ b/kos_sim/server.py @@ -4,33 +4,30 @@ import asyncio import time from concurrent import futures +from pathlib import Path import colorlogging import grpc from kos_protos import actuator_pb2_grpc, imu_pb2_grpc, sim_pb2_grpc from kscale import K +from kscale.web.gen.api import RobotURDFMetadataOutput from kos_sim import logger -from kos_sim.config import SimulatorConfig from kos_sim.services import ActuatorService, IMUService, SimService from kos_sim.simulator import MujocoSimulator from kos_sim.stepping import StepController, StepMode +from kos_sim.utils import get_sim_artifacts_path class SimulationServer: def __init__( self, - model_path: str, - config_path: str | None = None, + model_path: str | Path, + model_metadata: RobotURDFMetadataOutput, port: int = 50051, step_mode: StepMode = StepMode.CONTINUOUS, ) -> None: - if config_path: - config = SimulatorConfig.from_file(config_path) - else: - config = SimulatorConfig.default() - - self.simulator = MujocoSimulator(model_path, config=config, render=True) + self.simulator = MujocoSimulator(model_path, model_metadata, render=True) self.step_controller = StepController(self.simulator, mode=step_mode) self.port = port self._stop_event = asyncio.Event() @@ -102,9 +99,25 @@ async def stop(self) -> None: self.simulator.close() +async def get_model_metadata(api: K, model_name: str) -> RobotURDFMetadataOutput: + model_path = get_sim_artifacts_path() / model_name / "metadata.json" + if model_path.exists(): + return RobotURDFMetadataOutput.model_validate_json(model_path.read_text()) + model_path.parent.mkdir(parents=True, exist_ok=True) + model_metadata = await api.get_robot_class(model_name) + model_path.write_text(model_metadata.model_dump_json()) + return model_metadata + + async def serve(model_name: str, config_path: str | None = None, port: int = 50051) -> None: api = K() - model_dir = await api.download_and_extract_urdf(model_name) + model_dir, model_metadata = await asyncio.gather( + api.download_and_extract_urdf(model_name), + get_model_metadata(api, model_name), + ) + + breakpoint() + model_path = next(model_dir.glob("*.mjcf")) server = SimulationServer(model_path, config_path=config_path, port=port) @@ -123,6 +136,10 @@ async def run_server() -> None: await serve(args.model_name, args.config_path, args.port) +def main() -> None: + asyncio.run(run_server()) + + if __name__ == "__main__": # python -m kos_sim.server - asyncio.run(run_server()) + main() diff --git a/kos_sim/simulator.py b/kos_sim/simulator.py index db96c65..b157e8d 100644 --- a/kos_sim/simulator.py +++ b/kos_sim/simulator.py @@ -10,22 +10,22 @@ import mujoco_viewer import numpy as np from kscale import K +from kscale.web.gen.api import RobotURDFMetadataOutput from kos_sim import logger -from kos_sim.config import SimulatorConfig class MujocoSimulator: def __init__( self, model_path: str | Path, - config: SimulatorConfig | None = None, + model_metadata: RobotURDFMetadataOutput, render: bool = False, gravity: bool = True, suspended: bool = False, ) -> None: # Load config or use default - self._config = config or SimulatorConfig.default() + self._metadata = model_metadata # Load MuJoCo model and initialize data self._model = mujoco.MjModel.from_xml_path(str(model_path)) diff --git a/kos_sim/utils.py b/kos_sim/utils.py new file mode 100644 index 0000000..e4a2e90 --- /dev/null +++ b/kos_sim/utils.py @@ -0,0 +1,9 @@ +"""Defines some utility functions.""" + +import os +from pathlib import Path + + +def get_sim_artifacts_path() -> Path: + base_path = os.getenv("KOS_SIM_CACHE_PATH", ".kos-sim") + return Path(base_path).expanduser().resolve() diff --git a/setup.py b/setup.py index f59bac6..cfe83ca 100644 --- a/setup.py +++ b/setup.py @@ -37,9 +37,9 @@ tests_require=requirements_dev, extras_require={"dev": requirements_dev}, packages=["kos_sim"], - # entry_points={ - # "console_scripts": [ - # "kos-sim.cli:main", - # ], - # }, + entry_points={ + "console_scripts": [ + "kos-sim=kos_sim.server:main", + ], + }, )