diff --git a/kos_sim/client.py b/kos_sim/client.py new file mode 100644 index 0000000..04e8de2 --- /dev/null +++ b/kos_sim/client.py @@ -0,0 +1,18 @@ +"""Client loop for KOS.""" + +import asyncio + +import pykos + +from kos_sim import logger + + +async def main() -> None: + kos = pykos.KOS(ip="localhost", port=50051) + sim = kos.simulation + + sim.set_parameters(time_scale=1.0) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/kos_sim/config.py b/kos_sim/config.py index b1fecb9..551cb31 100644 --- a/kos_sim/config.py +++ b/kos_sim/config.py @@ -50,15 +50,15 @@ def from_file(cls, config_path: str) -> "SimulatorConfig": def default(cls) -> "SimulatorConfig": """Create default config with standard joint mappings.""" joint_name_to_id = { - "L_hip_y": 1, - "L_hip_x": 2, - "L_hip_z": 3, - "L_knee": 4, - "L_ankle_y": 5, - "R_hip_y": 6, - "R_hip_x": 7, - "R_hip_z": 8, - "R_knee": 9, - "R_ankle_y": 10, + "L_hip_y_04": 1, + "L_hip_x_03": 2, + "L_hip_z_03": 3, + "L_knee_04": 4, + "L_ankle_02": 5, + "R_hip_y_04": 6, + "R_hip_x_03": 7, + "R_hip_z_03": 8, + "R_knee_04": 9, + "R_ankle_02": 10, } return cls(joint_id_to_name={v: k for k, v in joint_name_to_id.items()}, joint_name_to_id=joint_name_to_id) diff --git a/kos_sim/server.py b/kos_sim/server.py index a4fa7f6..de95091 100644 --- a/kos_sim/server.py +++ b/kos_sim/server.py @@ -1,13 +1,14 @@ """Server and simulation loop for KOS.""" import argparse -import threading +import asyncio import time from concurrent import futures import colorlogging import grpc from kos_protos import actuator_pb2_grpc, imu_pb2_grpc, sim_pb2_grpc +from kscale import K from kos_sim import logger from kos_sim.config import SimulatorConfig @@ -32,21 +33,20 @@ def __init__( self.simulator = MujocoSimulator(model_path, config=config, render=True) self.step_controller = StepController(self.simulator, mode=step_mode) self.port = port - self._stop_event = threading.Event() - self._grpc_thread: threading.Thread | None = None + self._stop_event = asyncio.Event() self._server = None - def _grpc_server_loop(self) -> None: - """Run the gRPC server in a separate thread.""" - # Create gRPC server - self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + async def _grpc_server_loop(self) -> None: + """Run the async gRPC server.""" + # Create async gRPC server + self._server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10)) assert self._server is not None - # Add our services + # Add our services (these need to be modified to be async as well) actuator_service = ActuatorService(self.simulator) imu_service = IMUService(self.simulator) - sim_service = SimService(self.simulator) + sim_service = SimService(self.simulator, self.step_controller) actuator_pb2_grpc.add_ActuatorServiceServicer_to_server(actuator_service, self._server) imu_pb2_grpc.add_IMUServiceServicer_to_server(imu_service, self._server) @@ -54,57 +54,75 @@ def _grpc_server_loop(self) -> None: # Start the server self._server.add_insecure_port(f"[::]:{self.port}") - self._server.start() + await self._server.start() logger.info("Server started on port %d", self.port) + await self._server.wait_for_termination() - # Wait for termination - self._server.wait_for_termination() - - def start(self) -> None: - """Start the gRPC server and run simulation in main thread.""" - self._grpc_thread = threading.Thread(target=self._grpc_server_loop) - self._grpc_thread.start() + async def simulation_loop(self) -> None: + """Run the simulation loop asynchronously.""" + last_update = time.time() try: while not self._stop_event.is_set(): - process_start = time.time() + current_time = time.time() + sim_time = current_time - last_update + last_update = current_time if self.step_controller.should_step(): - self.simulator.step() + while sim_time > 0: + self.simulator.step() + sim_time -= self.simulator.timestep + + self.simulator.render() + # Add a small sleep to prevent the loop from consuming too much CPU + await asyncio.sleep(0.001) + + except Exception as e: + logger.error("Simulation loop failed: %s", e) - sleep_time = max(0, self.simulator._config.dt - (time.time() - process_start)) - time.sleep(sleep_time) + finally: + await self.stop() - except KeyboardInterrupt: - self.stop() + async def start(self) -> None: + """Start both the gRPC server and simulation loop asynchronously.""" + grpc_task = asyncio.create_task(self._grpc_server_loop()) + sim_task = asyncio.create_task(self.simulation_loop()) - def stop(self) -> None: - """Stop the simulation and cleanup resources.""" + try: + await asyncio.gather(grpc_task, sim_task) + except asyncio.CancelledError: + await self.stop() + + async def stop(self) -> None: + """Stop the simulation and cleanup resources asynchronously.""" logger.info("Shutting down simulation...") self._stop_event.set() if self._server is not None: - self._server.stop(0) - if self._grpc_thread is not None: - self._grpc_thread.join() + await self._server.stop(0) self.simulator.close() -def serve(model_path: str, config_path: str | None = None, port: int = 50051) -> None: +async def serve(model_name: str, config_path: str | None = None, port: int = 50051) -> None: + api = K() + model_dir = await api.download_and_extract_urdf(model_name) + model_path = next(model_dir.glob("*.mjcf")) + server = SimulationServer(model_path, config_path=config_path, port=port) - server.start() + await server.start() -def run_server() -> None: +async def run_server() -> None: parser = argparse.ArgumentParser(description="Start the simulation gRPC server.") - parser.add_argument("--model-path", type=str, required=True, help="Path to MuJoCo XML model file") + parser.add_argument("model_name", type=str, help="Name of the model to simulate") parser.add_argument("--port", type=int, default=50051, help="Port to listen on") parser.add_argument("--config-path", type=str, default=None, help="Path to config file") colorlogging.configure() args = parser.parse_args() - serve(args.model_path, args.config_path, args.port) + await serve(args.model_name, args.config_path, args.port) if __name__ == "__main__": - run_server() + # python -m kos_sim.server + asyncio.run(run_server()) diff --git a/kos_sim/simulator.py b/kos_sim/simulator.py index 92768b5..db96c65 100644 --- a/kos_sim/simulator.py +++ b/kos_sim/simulator.py @@ -207,10 +207,10 @@ async def test_simulation_adhoc( model_name: str, duration: float = 5.0, speed: float = 1.0, render: bool = True ) -> None: api = K() - bot_dir = await api.download_and_extract_urdf(model_name) - bot_mjcf = next(bot_dir.glob("*.mjcf")) + model_dir = await api.download_and_extract_urdf(model_name) + model_path = next(model_dir.glob("*.mjcf")) - simulator = MujocoSimulator(bot_mjcf, render=render) + simulator = MujocoSimulator(model_path, render=render) timestep = simulator.timestep initial_update = last_update = asyncio.get_event_loop().time() diff --git a/kos_sim/test_client.py b/kos_sim/test_client.py index f6e7c72..b3aa24f 100644 --- a/kos_sim/test_client.py +++ b/kos_sim/test_client.py @@ -108,4 +108,5 @@ def main() -> None: if __name__ == "__main__": + # python -m kos_sim.test_client main()