
Commit dcbe264

Cleanups
1 parent 9cd3d95 commit dcbe264

21 files changed: +55 −46 lines changed
File renamed without changes.
File renamed without changes.

actsafe/la_mbda/la_mbda.py → actsafe/actsafe/actsafe.py (+21 −13)

@@ -8,13 +8,13 @@
 from omegaconf import DictConfig

 from actsafe.common.learner import Learner
-from actsafe.la_mbda import rssm
-from actsafe.la_mbda.exploration import make_exploration
-from actsafe.la_mbda.make_actor_critic import make_actor_critic
-from actsafe.la_mbda.multi_reward import MultiRewardBridge
-from actsafe.la_mbda.replay_buffer import ReplayBuffer
-from actsafe.la_mbda.sentiment import make_sentiment
-from actsafe.la_mbda.world_model import WorldModel, evaluate_model, variational_step
+from actsafe.actsafe import rssm
+from actsafe.actsafe.exploration import UniformExploration, make_exploration
+from actsafe.actsafe.make_actor_critic import make_actor_critic
+from actsafe.actsafe.multi_reward import MultiRewardBridge
+from actsafe.actsafe.replay_buffer import ReplayBuffer
+from actsafe.actsafe.sentiment import make_sentiment
+from actsafe.actsafe.world_model import WorldModel, evaluate_model, variational_step
 from actsafe.rl.epoch_summary import EpochSummary
 from actsafe.rl.metrics import MetricsMonitor
 from actsafe.rl.trajectory import TrajectoryData, Transition
@@ -53,7 +53,7 @@ def init(cls, batch_size: int, cell: rssm.RSSM, action_dim: int) -> "AgentState"
         return self


-class LaMBDA:
+class ActSafe:
     def __init__(
         self,
         observation_space: Box,
@@ -99,6 +99,7 @@ def __init__(
             action_dim,
             next(self.prng),
         )
+        self.offline = UniformExploration(action_dim)
         self.state = AgentState.init(
             config.training.parallel_envs, self.model.cell, action_dim
         )
@@ -112,6 +113,9 @@ def __init__(
         self.should_explore = Until(
             config.agent.exploration_steps, environment_steps_per_agent_step
         )
+        self.should_collect_offline = Until(
+            config.agent.offline_steps, environment_steps_per_agent_step
+        )
         learn_model_steps = (
             config.agent.learn_model_steps
             if config.agent.learn_model_steps is not None
@@ -128,12 +132,16 @@ def __call__(
     ) -> FloatArray:
         if train and self.should_train() and not self.replay_buffer.empty:
             self.update()
-        policy_fn = (
-            self.exploration.get_policy()
-            if self.should_explore()
-            else self.actor_critic.actor.act
-        )
+        if self.should_collect_offline():
+            policy_fn = self.offline.get_policy()
+        else:
+            policy_fn = (
+                self.exploration.get_policy()
+                if self.should_explore()
+                else self.actor_critic.actor.act
+            )
         self.should_explore.tick()
+        self.should_collect_offline.tick()
         self.learn_model.tick()
         actions, self.state = policy(
             policy_fn,
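The new offline-collection phase leans on the repository's existing Until counter and the UniformExploration policy that is now imported above. As a rough illustration only (the real classes live elsewhere in the repository and their signatures may differ), the three-phase policy selection added to ActSafe.__call__ behaves roughly like this sketch:

# Hypothetical, minimal stand-ins for Until and UniformExploration, written only to
# illustrate the offline -> explore -> exploit selection added above; the real
# classes in the repository may differ in detail.
import numpy as np


class Until:
    """True until `limit` environment steps have elapsed, counted in ticks."""

    def __init__(self, limit: int, steps_per_tick: int):
        self.limit = limit
        self.steps_per_tick = steps_per_tick
        self.count = 0

    def __call__(self) -> bool:
        return self.count < self.limit

    def tick(self) -> None:
        self.count += self.steps_per_tick


class UniformExploration:
    """Samples actions uniformly in [-1, 1]; used while collecting offline data."""

    def __init__(self, action_dim: int):
        self.action_dim = action_dim

    def get_policy(self):
        return lambda *_: np.random.uniform(-1.0, 1.0, self.action_dim)


def select_policy(agent):
    # Mirrors the branching in ActSafe.__call__: uniform actions while the offline
    # budget lasts, then the exploration policy, then the learned actor.
    if agent.should_collect_offline():
        return agent.offline.get_policy()
    if agent.should_explore():
        return agent.exploration.get_policy()
    return agent.actor_critic.actor.act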

actsafe/la_mbda/augmented_lagrangian.py → actsafe/actsafe/augmented_lagrangian.py (+2 −2)

@@ -6,8 +6,8 @@
 import jax.numpy as jnp
 from jaxtyping import PyTree

-from actsafe.la_mbda.actor_critic import ContinuousActor
-from actsafe.la_mbda.safe_actor_critic import ActorEvaluation
+from actsafe.actsafe.actor_critic import ContinuousActor
+from actsafe.actsafe.safe_actor_critic import ActorEvaluation


 class AugmentedLagrangianUpdate(NamedTuple):

actsafe/la_mbda/dummy_penalizer.py → actsafe/actsafe/dummy_penalizer.py (+2 −2)

@@ -4,8 +4,8 @@
 import jax
 from jaxtyping import PyTree

-from actsafe.la_mbda.actor_critic import ContinuousActor
-from actsafe.la_mbda.safe_actor_critic import ActorEvaluation
+from actsafe.actsafe.actor_critic import ContinuousActor
+from actsafe.actsafe.safe_actor_critic import ActorEvaluation


 class DummyPenalizer:

actsafe/la_mbda/exploration.py → actsafe/actsafe/exploration.py (+3 −3)

@@ -1,9 +1,9 @@
 import jax
 from omegaconf import DictConfig

-from actsafe.la_mbda.opax_bridge import OpaxBridge
-from actsafe.la_mbda.make_actor_critic import make_actor_critic
-from actsafe.la_mbda.sentiment import identity, make_sentiment
+from actsafe.actsafe.opax_bridge import OpaxBridge
+from actsafe.actsafe.make_actor_critic import make_actor_critic
+from actsafe.actsafe.sentiment import identity, make_sentiment
 from actsafe.rl.types import Model, Policy

actsafe/la_mbda/lbsgd.py → actsafe/actsafe/lbsgd.py (+2 −2)

@@ -9,8 +9,8 @@

 from actsafe.common.mixed_precision import apply_dtype
 from actsafe.common.pytree_utils import pytrees_unstack
-from actsafe.la_mbda.actor_critic import ContinuousActor
-from actsafe.la_mbda.safe_actor_critic import ActorEvaluation
+from actsafe.actsafe.actor_critic import ContinuousActor
+from actsafe.actsafe.safe_actor_critic import ActorEvaluation

 _EPS = 1e-8

actsafe/la_mbda/make_actor_critic.py → actsafe/actsafe/make_actor_critic.py (+5 −5)

@@ -1,10 +1,10 @@
 import logging
 import numpy as np
-from actsafe.la_mbda.augmented_lagrangian import AugmentedLagrangianPenalizer
-from actsafe.la_mbda.dummy_penalizer import DummyPenalizer
-from actsafe.la_mbda.lbsgd import LBSGDPenalizer
-from actsafe.la_mbda.safe_actor_critic import SafeModelBasedActorCritic
-from actsafe.la_mbda.sentiment import bayes
+from actsafe.actsafe.augmented_lagrangian import AugmentedLagrangianPenalizer
+from actsafe.actsafe.dummy_penalizer import DummyPenalizer
+from actsafe.actsafe.lbsgd import LBSGDPenalizer
+from actsafe.actsafe.safe_actor_critic import SafeModelBasedActorCritic
+from actsafe.actsafe.sentiment import bayes


 _LOG = logging.getLogger(__name__)

actsafe/la_mbda/multi_reward.py → actsafe/actsafe/multi_reward.py (+2 −2)

@@ -1,7 +1,7 @@
 import jax
 import equinox as eqx
-from actsafe.la_mbda.rssm import ShiftScale, State
-from actsafe.la_mbda.world_model import WorldModel
+from actsafe.actsafe.rssm import ShiftScale, State
+from actsafe.actsafe.world_model import WorldModel
 from actsafe.rl.types import Policy, Prediction

actsafe/la_mbda/opax_bridge.py → actsafe/actsafe/opax_bridge.py (+2 −2)

@@ -1,8 +1,8 @@
 import jax
 import equinox as eqx
 from actsafe import opax
-from actsafe.la_mbda.rssm import ShiftScale, State
-from actsafe.la_mbda.world_model import WorldModel
+from actsafe.actsafe.rssm import ShiftScale, State
+from actsafe.actsafe.world_model import WorldModel
 from actsafe.rl.types import Policy, Prediction

File renamed without changes.
File renamed without changes.

actsafe/la_mbda/safe_actor_critic.py → actsafe/actsafe/safe_actor_critic.py (+3 −3)

@@ -9,9 +9,9 @@

 from actsafe.common.learner import Learner
 from actsafe.common.mixed_precision import apply_mixed_precision
-from actsafe.la_mbda.rssm import ShiftScale
-from actsafe.la_mbda.sentiment import Sentiment
-from actsafe.la_mbda.actor_critic import ContinuousActor, Critic, actor_entropy
+from actsafe.actsafe.rssm import ShiftScale
+from actsafe.actsafe.sentiment import Sentiment
+from actsafe.actsafe.actor_critic import ContinuousActor, Critic, actor_entropy
 from actsafe.opax import normalized_epistemic_uncertainty
 from actsafe.rl.types import Model, RolloutFn
 from actsafe.rl.utils import nest_vmap

actsafe/la_mbda/sentiment.py → actsafe/actsafe/sentiment.py (+1 −1)

@@ -1,7 +1,7 @@
 from typing import Protocol
 import jax

-from actsafe.la_mbda.rssm import ShiftScale
+from actsafe.actsafe.rssm import ShiftScale
 from actsafe.opax import normalized_epistemic_uncertainty

File renamed without changes.
File renamed without changes.

actsafe/la_mbda/world_model.py → actsafe/actsafe/world_model.py (+2 −2)

@@ -8,9 +8,9 @@

 from actsafe.common.learner import Learner
 from actsafe.common.mixed_precision import apply_mixed_precision
-from actsafe.la_mbda.rssm import RSSM, Features, ShiftScale, State
+from actsafe.actsafe.rssm import RSSM, Features, ShiftScale, State
 from actsafe.rl.types import Prediction
-from actsafe.la_mbda.utils import marginalize_prediction
+from actsafe.actsafe.utils import marginalize_prediction
 from actsafe.rl.types import Policy
 from actsafe.rl.utils import nest_vmap

actsafe/configs/agent/la_mbda.yaml → actsafe/configs/agent/actsafe.yaml (+2 −1)

@@ -1,7 +1,7 @@
 defaults:
   - penalizer: lbsgd

-name: lambda
+name: actsafe
 replay_buffer:
   batch_size: 16
   sequence_length: 50
@@ -52,6 +52,7 @@ safety_slack: 0.
 evaluate_model: false
 exploration_strategy: uniform
 exploration_steps: 5000
+offline_steps: 200000
 learn_model_steps: null
 exploration_reward_scale: 10.0
 exploration_epistemic_scale: 1.
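The new offline_steps key is consumed as config.agent.offline_steps by the Until counter added in actsafe.py. A minimal sketch of where the value ends up, assuming a hand-built OmegaConf config in place of the project's full Hydra composition:

# Sketch only: builds a tiny config by hand instead of composing it with Hydra,
# just to show the path under which the new key is read.
from omegaconf import OmegaConf

config = OmegaConf.create(
    {"agent": {"name": "actsafe", "exploration_steps": 5000, "offline_steps": 200_000}}
)
# ActSafe.__init__ wraps this value in an Until counter to time offline collection.
assert config.agent.offline_steps == 200_000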

actsafe/configs/config.yaml (+1 −1)

@@ -1,6 +1,6 @@
 defaults:
   - _self_
-  - agent: la_mbda
+  - agent: actsafe
   - environment: safe_adaptation_gym

 hydra:

actsafe/opax.py (+1 −1)

@@ -1,6 +1,6 @@
 import jax
 import jax.numpy as jnp
-from actsafe.la_mbda.rssm import ShiftScale
+from actsafe.actsafe.rssm import ShiftScale
 from actsafe.rl.types import Prediction

 _EPS = 1e-5

actsafe/rl/trainer.py (+6 −6)

@@ -8,7 +8,7 @@
 import numpy as np

 from actsafe import benchmark_suites
-from actsafe.la_mbda.la_mbda import LaMBDA
+from actsafe.actsafe.actsafe import ActSafe
 from actsafe.rl import acting, episodic_async_env
 from actsafe.rl.epoch_summary import EpochSummary
 from actsafe.rl.logging import StateWriter, TrainingLogger
@@ -57,7 +57,7 @@ def __init__(
         self,
         config: DictConfig,
         make_env: EnvironmentFactory,
-        agent: LaMBDA | None = None,
+        agent: ActSafe | None = None,
         start_epoch: int = 0,
         step: int = 0,
         seeds: PRNGSequence | None = None,
@@ -88,10 +88,10 @@ def __enter__(self):
         self.agent = self.make_agent()
         return self

-    def make_agent(self) -> LaMBDA:
+    def make_agent(self) -> ActSafe:
         assert self.env is not None
-        if self.config.agent.name == "lambda":
-            agent = LaMBDA(
+        if self.config.agent.name == "actsafe":
+            agent = ActSafe(
                 self.env.observation_space,
                 self.env.action_space,
                 self.config,
@@ -198,7 +198,7 @@ def __init__(
         self,
         config: DictConfig,
         make_env: EnvironmentFactory,
-        agent: LaMBDA | None = None,
+        agent: ActSafe | None = None,
         start_epoch: int = 0,
         step: int = 0,
         seeds: PRNGSequence | None = None,
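The renamed config entry (name: actsafe), the check in make_agent, and the new import path all have to agree. Purely as an illustration, a free-function rewrite of that dispatch; the real make_agent lives on the trainer and the constructor takes further arguments:

# Illustrative only: a stripped-down version of the name-based dispatch performed
# by Trainer.make_agent. `agent_cls` stands in for ActSafe; the real constructor
# also receives additional arguments beyond the three shown here.
from omegaconf import DictConfig


def build_agent(config: DictConfig, env, agent_cls):
    if config.agent.name == "actsafe":
        return agent_cls(env.observation_space, env.action_space, config)
    raise NotImplementedError(f"Unknown agent: {config.agent.name}")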
