Skip to content

Commit

Permalink
decimation
Browse files Browse the repository at this point in the history
  • Loading branch information
WT-MM committed Jan 14, 2025
1 parent b8a820b commit 7cf316a
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 35 deletions.
12 changes: 12 additions & 0 deletions kos_sim/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
import logging

# Configure root logger
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
)

# Get logger for this package
logger = logging.getLogger("kos_sim")

__version__ = "0.0.1"
11 changes: 8 additions & 3 deletions kos_sim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@ class SimulatorConfig:
joint_name_to_id: dict[str, int]
kp: float = 80.0
kd: float = 10.0
physics_freq: float = 1000.0 # Hz
command_freq: float = 50.0 # Hz
dt: float = 0.02
command_freq: float = 50.0 # Hz

@property
def physics_freq(self) -> float:
"""Calculate physics frequency from timestep."""
return 1.0 / self.dt

@property
def sim_decimation(self) -> int:
Expand All @@ -37,7 +42,7 @@ def from_file(cls, config_path: str) -> "SimulatorConfig":
joint_name_to_id=joint_name_to_id,
kp=control_config.get("kp", 80.0),
kd=control_config.get("kd", 10.0),
physics_freq=control_config.get("physics_freq", 1000.0),
dt=control_config.get("dt", 0.001),
command_freq=control_config.get("command_freq", 50.0),
)

Expand Down
20 changes: 6 additions & 14 deletions kos_sim/server.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,21 @@
"""Server and simulation loop for KOS."""

import argparse
import logging
import threading
import time
from concurrent import futures

import grpc
from kos_protos import actuator_pb2_grpc, imu_pb2_grpc

from kos_sim import logger
from kos_sim.config import SimulatorConfig
from kos_sim.services import ActuatorService, IMUService
from kos_sim.simulator import MujocoSimulator
from kos_sim.config import SimulatorConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SimulationServer:
def __init__(
self,
model_path: str,
config_path: str | None = None,
port: int = 50051
) -> None:
def __init__(self, model_path: str, config_path: str | None = None, port: int = 50051) -> None:
if config_path:
config = SimulatorConfig.from_file(config_path)
else:
Expand Down Expand Up @@ -61,13 +53,13 @@ def start(self) -> None:
self._grpc_thread = threading.Thread(target=self._grpc_server_loop)
self._grpc_thread.start()

physics_dt = 1.0 / self.simulator._config.physics_freq
try:
while not self._stop_event.is_set():
process_start = time.time()
self.simulator.step()
# Sleep to maintain physics frequency
time.sleep(max(0, physics_dt - (time.time() - process_start)))
sleep_time = max(0, self.simulator._config.dt - (time.time() - process_start))
logger.debug("Sleeping for %f seconds", sleep_time)
time.sleep(sleep_time)
except KeyboardInterrupt:
self.stop()

Expand Down
4 changes: 1 addition & 3 deletions kos_sim/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ def GetActuatorsState( # noqa: N802
return actuator_pb2.GetActuatorsStateResponse(
states=[
actuator_pb2.ActuatorStateResponse(
actuator_id=joint_id,
position=math.degrees(float(state)),
online=True
actuator_id=joint_id, position=math.degrees(float(state)), online=True
)
for joint_id, state in states.items()
]
Expand Down
12 changes: 4 additions & 8 deletions kos_sim/simulator.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
"""Wrapper around MuJoCo simulation."""

import logging
import threading

import mujoco
import mujoco_viewer
import numpy as np

from kos_sim import logger
from kos_sim.config import SimulatorConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MujocoSimulator:
def __init__(
self,
model_path: str,
config: SimulatorConfig | None = None,
render: bool = False,
dt: float = 0.002,
gravity: bool = True,
suspended: bool = False,
) -> None:
Expand All @@ -28,7 +24,7 @@ def __init__(

# Load MuJoCo model and initialize data
self._model = mujoco.MjModel.from_xml_path(model_path)
self._model.opt.timestep = dt
self._model.opt.timestep = self._config.dt # Use dt from config
self._data = mujoco.MjData(self._model)

self._gravity = gravity
Expand Down Expand Up @@ -114,8 +110,8 @@ def step(self) -> None:
for i in range(self._model.njnt):
if self._model.jnt_type[i] == mujoco.mjtJoint.mjJNT_FREE:
print(f"Joint name: {self._model.joint(i).name}")
self._data.qpos[i:i + 7] = self._model.keyframe("default").qpos[i:i + 7]
self._data.qvel[i:i + 6] = 0
self._data.qpos[i : i + 7] = self._model.keyframe("default").qpos[i : i + 7]
self._data.qvel[i : i + 6] = 0
break

if self._render_enabled:
Expand Down
4 changes: 1 addition & 3 deletions kos_sim/test_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""Test script for the simulation server."""

import argparse
import logging
import math
import time

import pykos

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from kos_sim import logger


def test_actuator_commands(host: str = "localhost", port: int = 50051) -> None:
Expand Down
4 changes: 0 additions & 4 deletions kos_sim/test_simulator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
"""Test script for the simulator."""

import argparse
import logging
import time

from kos_sim.simulator import MujocoSimulator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_simulation(model_path: str, duration: float = 5.0, speed: float = 1.0, render: bool = True) -> None:
simulator = MujocoSimulator(model_path, render=render)
Expand Down

0 comments on commit 7cf316a

Please sign in to comment.