Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
codekansas committed Jan 17, 2025
1 parent 2e35eff commit e566d4d
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 47 deletions.
18 changes: 18 additions & 0 deletions kos_sim/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Client loop for KOS."""

import asyncio

import pykos

from kos_sim import logger


async def main() -> None:
kos = pykos.KOS(ip="localhost", port=50051)
sim = kos.simulation

sim.set_parameters(time_scale=1.0)


if __name__ == "__main__":
asyncio.run(main())
20 changes: 10 additions & 10 deletions kos_sim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def from_file(cls, config_path: str) -> "SimulatorConfig":
def default(cls) -> "SimulatorConfig":
"""Create default config with standard joint mappings."""
joint_name_to_id = {
"L_hip_y": 1,
"L_hip_x": 2,
"L_hip_z": 3,
"L_knee": 4,
"L_ankle_y": 5,
"R_hip_y": 6,
"R_hip_x": 7,
"R_hip_z": 8,
"R_knee": 9,
"R_ankle_y": 10,
"L_hip_y_04": 1,
"L_hip_x_03": 2,
"L_hip_z_03": 3,
"L_knee_04": 4,
"L_ankle_02": 5,
"R_hip_y_04": 6,
"R_hip_x_03": 7,
"R_hip_z_03": 8,
"R_knee_04": 9,
"R_ankle_02": 10,
}
return cls(joint_id_to_name={v: k for k, v in joint_name_to_id.items()}, joint_name_to_id=joint_name_to_id)
86 changes: 52 additions & 34 deletions kos_sim/server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Server and simulation loop for KOS."""

import argparse
import threading
import asyncio
import time
from concurrent import futures

import colorlogging
import grpc
from kos_protos import actuator_pb2_grpc, imu_pb2_grpc, sim_pb2_grpc
from kscale import K

from kos_sim import logger
from kos_sim.config import SimulatorConfig
Expand All @@ -32,79 +33,96 @@ def __init__(
self.simulator = MujocoSimulator(model_path, config=config, render=True)
self.step_controller = StepController(self.simulator, mode=step_mode)
self.port = port
self._stop_event = threading.Event()
self._grpc_thread: threading.Thread | None = None
self._stop_event = asyncio.Event()
self._server = None

def _grpc_server_loop(self) -> None:
"""Run the gRPC server in a separate thread."""
# Create gRPC server
self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
async def _grpc_server_loop(self) -> None:
"""Run the async gRPC server."""
# Create async gRPC server
self._server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10))

assert self._server is not None

# Add our services
# Add our services (these need to be modified to be async as well)
actuator_service = ActuatorService(self.simulator)
imu_service = IMUService(self.simulator)
sim_service = SimService(self.simulator)
sim_service = SimService(self.simulator, self.step_controller)

actuator_pb2_grpc.add_ActuatorServiceServicer_to_server(actuator_service, self._server)
imu_pb2_grpc.add_IMUServiceServicer_to_server(imu_service, self._server)
sim_pb2_grpc.add_SimulationServiceServicer_to_server(sim_service, self._server)

# Start the server
self._server.add_insecure_port(f"[::]:{self.port}")
self._server.start()
await self._server.start()
logger.info("Server started on port %d", self.port)
await self._server.wait_for_termination()

# Wait for termination
self._server.wait_for_termination()

def start(self) -> None:
"""Start the gRPC server and run simulation in main thread."""
self._grpc_thread = threading.Thread(target=self._grpc_server_loop)
self._grpc_thread.start()
async def simulation_loop(self) -> None:
"""Run the simulation loop asynchronously."""
last_update = time.time()

try:
while not self._stop_event.is_set():
process_start = time.time()
current_time = time.time()
sim_time = current_time - last_update
last_update = current_time

if self.step_controller.should_step():
self.simulator.step()
while sim_time > 0:
self.simulator.step()
sim_time -= self.simulator.timestep

self.simulator.render()
# Add a small sleep to prevent the loop from consuming too much CPU
await asyncio.sleep(0.001)

except Exception as e:
logger.error("Simulation loop failed: %s", e)

sleep_time = max(0, self.simulator._config.dt - (time.time() - process_start))
time.sleep(sleep_time)
finally:
await self.stop()

except KeyboardInterrupt:
self.stop()
async def start(self) -> None:
"""Start both the gRPC server and simulation loop asynchronously."""
grpc_task = asyncio.create_task(self._grpc_server_loop())
sim_task = asyncio.create_task(self.simulation_loop())

def stop(self) -> None:
"""Stop the simulation and cleanup resources."""
try:
await asyncio.gather(grpc_task, sim_task)
except asyncio.CancelledError:
await self.stop()

async def stop(self) -> None:
"""Stop the simulation and cleanup resources asynchronously."""
logger.info("Shutting down simulation...")
self._stop_event.set()
if self._server is not None:
self._server.stop(0)
if self._grpc_thread is not None:
self._grpc_thread.join()
await self._server.stop(0)
self.simulator.close()


def serve(model_path: str, config_path: str | None = None, port: int = 50051) -> None:
async def serve(model_name: str, config_path: str | None = None, port: int = 50051) -> None:
api = K()
model_dir = await api.download_and_extract_urdf(model_name)
model_path = next(model_dir.glob("*.mjcf"))

server = SimulationServer(model_path, config_path=config_path, port=port)
server.start()
await server.start()


def run_server() -> None:
async def run_server() -> None:
parser = argparse.ArgumentParser(description="Start the simulation gRPC server.")
parser.add_argument("--model-path", type=str, required=True, help="Path to MuJoCo XML model file")
parser.add_argument("model_name", type=str, help="Name of the model to simulate")
parser.add_argument("--port", type=int, default=50051, help="Port to listen on")
parser.add_argument("--config-path", type=str, default=None, help="Path to config file")

colorlogging.configure()

args = parser.parse_args()
serve(args.model_path, args.config_path, args.port)
await serve(args.model_name, args.config_path, args.port)


if __name__ == "__main__":
run_server()
# python -m kos_sim.server
asyncio.run(run_server())
6 changes: 3 additions & 3 deletions kos_sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@ async def test_simulation_adhoc(
model_name: str, duration: float = 5.0, speed: float = 1.0, render: bool = True
) -> None:
api = K()
bot_dir = await api.download_and_extract_urdf(model_name)
bot_mjcf = next(bot_dir.glob("*.mjcf"))
model_dir = await api.download_and_extract_urdf(model_name)
model_path = next(model_dir.glob("*.mjcf"))

simulator = MujocoSimulator(bot_mjcf, render=render)
simulator = MujocoSimulator(model_path, render=render)

timestep = simulator.timestep
initial_update = last_update = asyncio.get_event_loop().time()
Expand Down
1 change: 1 addition & 0 deletions kos_sim/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,5 @@ def main() -> None:


if __name__ == "__main__":
# python -m kos_sim.test_client
main()

0 comments on commit e566d4d

Please sign in to comment.