Commit dcbe264: Cleanups

yardenas committed Oct 8, 2024
1 parent 9cd3d95 commit dcbe264
Showing 21 changed files with 55 additions and 46 deletions.
File renamed without changes.
File renamed without changes.
34 changes: 21 additions & 13 deletions actsafe/la_mbda/la_mbda.py → actsafe/actsafe/actsafe.py

@@ -8,13 +8,13 @@
 from omegaconf import DictConfig

 from actsafe.common.learner import Learner
-from actsafe.la_mbda import rssm
-from actsafe.la_mbda.exploration import make_exploration
-from actsafe.la_mbda.make_actor_critic import make_actor_critic
-from actsafe.la_mbda.multi_reward import MultiRewardBridge
-from actsafe.la_mbda.replay_buffer import ReplayBuffer
-from actsafe.la_mbda.sentiment import make_sentiment
-from actsafe.la_mbda.world_model import WorldModel, evaluate_model, variational_step
+from actsafe.actsafe import rssm
+from actsafe.actsafe.exploration import UniformExploration, make_exploration
+from actsafe.actsafe.make_actor_critic import make_actor_critic
+from actsafe.actsafe.multi_reward import MultiRewardBridge
+from actsafe.actsafe.replay_buffer import ReplayBuffer
+from actsafe.actsafe.sentiment import make_sentiment
+from actsafe.actsafe.world_model import WorldModel, evaluate_model, variational_step
 from actsafe.rl.epoch_summary import EpochSummary
 from actsafe.rl.metrics import MetricsMonitor
 from actsafe.rl.trajectory import TrajectoryData, Transition
@@ -53,7 +53,7 @@ def init(cls, batch_size: int, cell: rssm.RSSM, action_dim: int) -> "AgentState"
         return self


-class LaMBDA:
+class ActSafe:
     def __init__(
         self,
         observation_space: Box,
@@ -99,6 +99,7 @@ def __init__(
             action_dim,
             next(self.prng),
         )
+        self.offline = UniformExploration(action_dim)
         self.state = AgentState.init(
             config.training.parallel_envs, self.model.cell, action_dim
        )
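The `UniformExploration` policy assigned to `self.offline` is defined elsewhere in the package and is not part of this diff. A minimal sketch of what such a policy plausibly looks like, assuming actions are normalized to [-1, 1] and that `get_policy()` returns a sampling function taking a JAX PRNG key (both assumptions, not the repository's actual code):

```python
import jax


class UniformExploration:
    """Hypothetical sketch: a policy that ignores observations and
    samples actions uniformly from a Box space normalized to [-1, 1]."""

    def __init__(self, action_dim: int):
        self.action_dim = action_dim

    def get_policy(self):
        def policy(observation, key):
            # Draw one uniform action per call; the observation is unused.
            return jax.random.uniform(
                key, shape=(self.action_dim,), minval=-1.0, maxval=1.0
            )

        return policy
```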
@@ -112,6 +113,9 @@ def __init__(
         self.should_explore = Until(
             config.agent.exploration_steps, environment_steps_per_agent_step
         )
+        self.should_collect_offline = Until(
+            config.agent.offline_steps, environment_steps_per_agent_step
+        )
         learn_model_steps = (
             config.agent.learn_model_steps
             if config.agent.learn_model_steps is not None
@@ -128,12 +132,16 @@ def __call__(
     ) -> FloatArray:
         if train and self.should_train() and not self.replay_buffer.empty:
             self.update()
-        policy_fn = (
-            self.exploration.get_policy()
-            if self.should_explore()
-            else self.actor_critic.actor.act
-        )
+        if self.should_collect_offline():
+            policy_fn = self.offline.get_policy()
+        else:
+            policy_fn = (
+                self.exploration.get_policy()
+                if self.should_explore()
+                else self.actor_critic.actor.act
+            )
         self.should_explore.tick()
+        self.should_collect_offline.tick()
         self.learn_model.tick()
         actions, self.state = policy(
             policy_fn,
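The `Until` gates used in `__call__` come from the repository's utilities and are unchanged by this commit. A minimal sketch consistent with the call sites above (truthy check via `__call__`, advanced via `tick()`; the attribute names are assumptions):

```python
class Until:
    """Hypothetical sketch of a step-budget gate: truthy until `steps`
    environment steps have elapsed, advanced once per agent step."""

    def __init__(self, steps: int, steps_per_tick: int):
        self.steps = steps
        self.steps_per_tick = steps_per_tick
        self.elapsed = 0

    def __call__(self) -> bool:
        return self.elapsed < self.steps

    def tick(self) -> None:
        self.elapsed += self.steps_per_tick
```

Under this reading, the agent acts with the uniform offline policy for the first `offline_steps` environment steps, then falls through to the exploration/exploitation branch.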
@@ -6,8 +6,8 @@
 import jax.numpy as jnp
 from jaxtyping import PyTree

-from actsafe.la_mbda.actor_critic import ContinuousActor
-from actsafe.la_mbda.safe_actor_critic import ActorEvaluation
+from actsafe.actsafe.actor_critic import ContinuousActor
+from actsafe.actsafe.safe_actor_critic import ActorEvaluation


 class AugmentedLagrangianUpdate(NamedTuple):
@@ -4,8 +4,8 @@
 import jax
 from jaxtyping import PyTree

-from actsafe.la_mbda.actor_critic import ContinuousActor
-from actsafe.la_mbda.safe_actor_critic import ActorEvaluation
+from actsafe.actsafe.actor_critic import ContinuousActor
+from actsafe.actsafe.safe_actor_critic import ActorEvaluation


 class DummyPenalizer:
@@ -1,9 +1,9 @@
 import jax
 from omegaconf import DictConfig

-from actsafe.la_mbda.opax_bridge import OpaxBridge
-from actsafe.la_mbda.make_actor_critic import make_actor_critic
-from actsafe.la_mbda.sentiment import identity, make_sentiment
+from actsafe.actsafe.opax_bridge import OpaxBridge
+from actsafe.actsafe.make_actor_critic import make_actor_critic
+from actsafe.actsafe.sentiment import identity, make_sentiment
 from actsafe.rl.types import Model, Policy

4 changes: 2 additions & 2 deletions actsafe/la_mbda/lbsgd.py → actsafe/actsafe/lbsgd.py

@@ -9,8 +9,8 @@

 from actsafe.common.mixed_precision import apply_dtype
 from actsafe.common.pytree_utils import pytrees_unstack
-from actsafe.la_mbda.actor_critic import ContinuousActor
-from actsafe.la_mbda.safe_actor_critic import ActorEvaluation
+from actsafe.actsafe.actor_critic import ContinuousActor
+from actsafe.actsafe.safe_actor_critic import ActorEvaluation

 _EPS = 1e-8

@@ -1,10 +1,10 @@
 import logging
 import numpy as np
-from actsafe.la_mbda.augmented_lagrangian import AugmentedLagrangianPenalizer
-from actsafe.la_mbda.dummy_penalizer import DummyPenalizer
-from actsafe.la_mbda.lbsgd import LBSGDPenalizer
-from actsafe.la_mbda.safe_actor_critic import SafeModelBasedActorCritic
-from actsafe.la_mbda.sentiment import bayes
+from actsafe.actsafe.augmented_lagrangian import AugmentedLagrangianPenalizer
+from actsafe.actsafe.dummy_penalizer import DummyPenalizer
+from actsafe.actsafe.lbsgd import LBSGDPenalizer
+from actsafe.actsafe.safe_actor_critic import SafeModelBasedActorCritic
+from actsafe.actsafe.sentiment import bayes


 _LOG = logging.getLogger(__name__)
@@ -1,7 +1,7 @@
 import jax
 import equinox as eqx
-from actsafe.la_mbda.rssm import ShiftScale, State
-from actsafe.la_mbda.world_model import WorldModel
+from actsafe.actsafe.rssm import ShiftScale, State
+from actsafe.actsafe.world_model import WorldModel
 from actsafe.rl.types import Policy, Prediction

@@ -1,8 +1,8 @@
 import jax
 import equinox as eqx
 from actsafe import opax
-from actsafe.la_mbda.rssm import ShiftScale, State
-from actsafe.la_mbda.world_model import WorldModel
+from actsafe.actsafe.rssm import ShiftScale, State
+from actsafe.actsafe.world_model import WorldModel
 from actsafe.rl.types import Policy, Prediction

File renamed without changes.
File renamed without changes.
@@ -9,9 +9,9 @@

 from actsafe.common.learner import Learner
 from actsafe.common.mixed_precision import apply_mixed_precision
-from actsafe.la_mbda.rssm import ShiftScale
-from actsafe.la_mbda.sentiment import Sentiment
-from actsafe.la_mbda.actor_critic import ContinuousActor, Critic, actor_entropy
+from actsafe.actsafe.rssm import ShiftScale
+from actsafe.actsafe.sentiment import Sentiment
+from actsafe.actsafe.actor_critic import ContinuousActor, Critic, actor_entropy
 from actsafe.opax import normalized_epistemic_uncertainty
 from actsafe.rl.types import Model, RolloutFn
 from actsafe.rl.utils import nest_vmap
@@ -1,7 +1,7 @@
 from typing import Protocol
 import jax

-from actsafe.la_mbda.rssm import ShiftScale
+from actsafe.actsafe.rssm import ShiftScale
 from actsafe.opax import normalized_epistemic_uncertainty

File renamed without changes.
File renamed without changes.
@@ -8,9 +8,9 @@

 from actsafe.common.learner import Learner
 from actsafe.common.mixed_precision import apply_mixed_precision
-from actsafe.la_mbda.rssm import RSSM, Features, ShiftScale, State
+from actsafe.actsafe.rssm import RSSM, Features, ShiftScale, State
 from actsafe.rl.types import Prediction
-from actsafe.la_mbda.utils import marginalize_prediction
+from actsafe.actsafe.utils import marginalize_prediction
 from actsafe.rl.types import Policy
 from actsafe.rl.utils import nest_vmap

@@ -1,7 +1,7 @@
 defaults:
   - penalizer: lbsgd

-name: lambda
+name: actsafe
 replay_buffer:
   batch_size: 16
   sequence_length: 50

@@ -52,6 +52,7 @@ safety_slack: 0.
 evaluate_model: false
 exploration_strategy: uniform
 exploration_steps: 5000
+offline_steps: 200000
 learn_model_steps: null
 exploration_reward_scale: 10.0
 exploration_epistemic_scale: 1.
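Note how the new `offline_steps` value interacts with `exploration_steps` in `__call__` above: the offline check takes precedence, so with these defaults the uniform offline phase fully covers the exploration window. An illustrative check (values copied from this config):

```python
offline_steps = 200_000   # added in this commit
exploration_steps = 5_000
# The offline branch in __call__ is evaluated first, so the dedicated
# exploration policy is only reachable after the offline phase ends.
assert offline_steps > exploration_steps  # exploration window fully covered
```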
2 changes: 1 addition & 1 deletion actsafe/configs/config.yaml

@@ -1,6 +1,6 @@
 defaults:
   - _self_
-  - agent: la_mbda
+  - agent: actsafe
   - environment: safe_adaptation_gym

 hydra:
2 changes: 1 addition & 1 deletion actsafe/opax.py

@@ -1,6 +1,6 @@
 import jax
 import jax.numpy as jnp
-from actsafe.la_mbda.rssm import ShiftScale
+from actsafe.actsafe.rssm import ShiftScale
 from actsafe.rl.types import Prediction

 _EPS = 1e-5
12 changes: 6 additions & 6 deletions actsafe/rl/trainer.py

@@ -8,7 +8,7 @@
 import numpy as np

 from actsafe import benchmark_suites
-from actsafe.la_mbda.la_mbda import LaMBDA
+from actsafe.actsafe.actsafe import ActSafe
 from actsafe.rl import acting, episodic_async_env
 from actsafe.rl.epoch_summary import EpochSummary
 from actsafe.rl.logging import StateWriter, TrainingLogger

@@ -57,7 +57,7 @@ def __init__(
         self,
         config: DictConfig,
         make_env: EnvironmentFactory,
-        agent: LaMBDA | None = None,
+        agent: ActSafe | None = None,
         start_epoch: int = 0,
         step: int = 0,
         seeds: PRNGSequence | None = None,

@@ -88,10 +88,10 @@ def __enter__(self):
             self.agent = self.make_agent()
         return self

-    def make_agent(self) -> LaMBDA:
+    def make_agent(self) -> ActSafe:
         assert self.env is not None
-        if self.config.agent.name == "lambda":
-            agent = LaMBDA(
+        if self.config.agent.name == "actsafe":
+            agent = ActSafe(
                 self.env.observation_space,
                 self.env.action_space,
                 self.config,

@@ -198,7 +198,7 @@ def __init__(
         self,
         config: DictConfig,
         make_env: EnvironmentFactory,
-        agent: LaMBDA | None = None,
+        agent: ActSafe | None = None,
         start_epoch: int = 0,
         step: int = 0,
         seeds: PRNGSequence | None = None,
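For downstream code, the visible change is the module path and agent name. A hypothetical construction under the new layout, mirroring `make_agent` above (the helper itself is illustrative, not part of the repository):

```python
from actsafe.actsafe.actsafe import ActSafe


def build_agent(env, config):
    # Mirrors Trainer.make_agent: configs must now set agent.name to
    # "actsafe" (previously "lambda") to select this agent.
    if config.agent.name != "actsafe":
        raise NotImplementedError(f"Unknown agent: {config.agent.name}")
    return ActSafe(env.observation_space, env.action_space, config)
```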
