@@ -8,13 +8,13 @@
 from omegaconf import DictConfig

 from actsafe.common.learner import Learner
-from actsafe.la_mbda import rssm
-from actsafe.la_mbda.exploration import make_exploration
-from actsafe.la_mbda.make_actor_critic import make_actor_critic
-from actsafe.la_mbda.multi_reward import MultiRewardBridge
-from actsafe.la_mbda.replay_buffer import ReplayBuffer
-from actsafe.la_mbda.sentiment import make_sentiment
-from actsafe.la_mbda.world_model import WorldModel, evaluate_model, variational_step
+from actsafe.actsafe import rssm
+from actsafe.actsafe.exploration import UniformExploration, make_exploration
+from actsafe.actsafe.make_actor_critic import make_actor_critic
+from actsafe.actsafe.multi_reward import MultiRewardBridge
+from actsafe.actsafe.replay_buffer import ReplayBuffer
+from actsafe.actsafe.sentiment import make_sentiment
+from actsafe.actsafe.world_model import WorldModel, evaluate_model, variational_step
 from actsafe.rl.epoch_summary import EpochSummary
 from actsafe.rl.metrics import MetricsMonitor
 from actsafe.rl.trajectory import TrajectoryData, Transition
@@ -53,7 +53,7 @@ def init(cls, batch_size: int, cell: rssm.RSSM, action_dim: int) -> "AgentState"
         return self


-class LaMBDA:
+class ActSafe:
     def __init__(
         self,
         observation_space: Box,
@@ -99,6 +99,7 @@ def __init__(
             action_dim,
             next(self.prng),
         )
+        self.offline = UniformExploration(action_dim)
         self.state = AgentState.init(
             config.training.parallel_envs, self.model.cell, action_dim
         )
@@ -112,6 +113,9 @@ def __init__(
         self.should_explore = Until(
             config.agent.exploration_steps, environment_steps_per_agent_step
         )
+        self.should_collect_offline = Until(
+            config.agent.offline_steps, environment_steps_per_agent_step
+        )
         learn_model_steps = (
             config.agent.learn_model_steps
             if config.agent.learn_model_steps is not None
@@ -128,12 +132,16 @@ def __call__(
     ) -> FloatArray:
         if train and self.should_train() and not self.replay_buffer.empty:
             self.update()
-        policy_fn = (
-            self.exploration.get_policy()
-            if self.should_explore()
-            else self.actor_critic.actor.act
-        )
+        if self.should_collect_offline():
+            policy_fn = self.offline.get_policy()
+        else:
+            policy_fn = (
+                self.exploration.get_policy()
+                if self.should_explore()
+                else self.actor_critic.actor.act
+            )
         self.should_explore.tick()
+        self.should_collect_offline.tick()
         self.learn_model.tick()
         actions, self.state = policy(
             policy_fn,
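
For readers skimming the commit: the new should_collect_offline gate is checked before the existing exploration gate, so the agent first collects data with the uniform offline policy, then with the exploration policy, and only afterwards acts with the learned actor; both counters are ticked once per agent step. A minimal sketch of that hand-off follows. The Until re-implementation here is an assumption made only for illustration (the repository's real class may differ); the attribute and config names are the ones introduced in this commit.

# Assumed sketch of a step-budget gate, inferred from how Until is used above;
# the repository's actual Until may be implemented differently.
class Until:
    def __init__(self, budget_steps: int, steps_per_tick: int):
        self.budget_steps = budget_steps
        self.steps_per_tick = steps_per_tick
        self.elapsed = 0

    def __call__(self) -> bool:
        # True while the environment-step budget has not been spent.
        return self.elapsed < self.budget_steps

    def tick(self) -> None:
        self.elapsed += self.steps_per_tick


def select_policy(agent):
    # Mirrors the policy selection added to __call__ in this commit:
    # uniform offline collection -> exploration policy -> learned actor.
    if agent.should_collect_offline():
        return agent.offline.get_policy()
    if agent.should_explore():
        return agent.exploration.get_policy()
    return agent.actor_critic.actor.act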