From 3bf7660c7f498b0aa2632d9fd679551b2cc8551f Mon Sep 17 00:00:00 2001
From: Jiayi Zhou <108712610+Gaiejj@users.noreply.github.com>
Date: Mon, 27 Mar 2023 01:31:16 +0800
Subject: [PATCH] refactor(wrapper): refactor the cuda setting (#176)

* refactor(wrapper): refactor the cuda setting

* chore: revert train_policy.py

* chore: set device in safety_gymnasium_env.py

* fix: [pre-commit.ci] auto fixes [...]

* fix(safety_gymnasium_env.py): fix device interface

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 omnisafe/adapter/offpolicy_adapter.py         | 16 ++---
 omnisafe/adapter/online_adapter.py            | 33 ++++------
 omnisafe/adapter/saute_adapter.py             | 10 ++--
 .../algorithms/on_policy/second_order/cpo.py  | 41 +++++++------
 .../algorithms/on_policy/second_order/pcpo.py |  9 ++-
 omnisafe/envs/core.py                         |  3 +-
 omnisafe/envs/safety_gymnasium_env.py         | 19 ++++--
 omnisafe/envs/wrapper.py                      | 60 +++++++++++--------
 omnisafe/evaluator.py                         |  6 +-
 9 files changed, 102 insertions(+), 95 deletions(-)

diff --git a/omnisafe/adapter/offpolicy_adapter.py b/omnisafe/adapter/offpolicy_adapter.py
index 0b9e3b03e..8bde287e8 100644
--- a/omnisafe/adapter/offpolicy_adapter.py
+++ b/omnisafe/adapter/offpolicy_adapter.py
@@ -14,11 +14,9 @@
 # ==============================================================================
 """OffPolicy Adapter for OmniSafe."""
 
-from functools import partial
 from typing import Dict, Optional
 
 import torch
-from gymnasium import spaces
 
 from omnisafe.adapter.online_adapter import OnlineAdapter
 from omnisafe.common.buffer import VectorOffPolicyBuffer
@@ -60,16 +58,14 @@ def roll_out(  # pylint: disable=too-many-locals
             logger (Logger): Logger.
             use_rand_action (bool): Whether to use random action.
         """
-        if use_rand_action:
-            if isinstance(self._env.action_space, spaces.Box):
-                act_fn = partial(
-                    torch.rand, size=(self._env.num_envs, *self._env.action_space.shape)
+        for _ in range(roll_out_step):
+            if use_rand_action:
+                act = torch.as_tensor(self._env.sample_action(), dtype=torch.float32).to(
+                    self._device
                 )
-            else:
-                act_fn = partial(agent.step, self._current_obs, deterministic=False)
+            else:
+                act = agent.step(self._current_obs, deterministic=False)
 
-        for _ in range(roll_out_step):
-            act = act_fn()
             next_obs, reward, cost, terminated, truncated, info = self.step(act)
 
             self._log_value(reward=reward, cost=cost, info=info)
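The off-policy hunk above replaces the pre-built act_fn with a per-step branch: random warm-up actions now come from the environment itself and are moved to the training device before stepping. A minimal, self-contained sketch of that pattern; ToyEnv, roll_out, and the constants are stand-ins for illustration, not the real OmniSafe adapter or actor interfaces:

    import torch


    class ToyEnv:
        """Stand-in for the wrapped environment; only the pieces the sketch needs."""

        def sample_action(self) -> torch.Tensor:
            # the real adapter forwards to the wrapped env's action space
            return torch.rand(1, 2)

        def step(self, act: torch.Tensor):
            obs = torch.zeros(1, 4, device=act.device)
            zero = torch.zeros(1, device=act.device)
            done = torch.zeros(1, dtype=torch.bool, device=act.device)
            return obs, zero, zero, done, done, {}


    def roll_out(env: ToyEnv, steps: int, device: torch.device, use_rand_action: bool) -> None:
        for _ in range(steps):
            if use_rand_action:
                # draw the exploration action from the env, then move it to the training device
                act = torch.as_tensor(env.sample_action(), dtype=torch.float32).to(device)
            else:
                act = torch.zeros(1, 2, device=device)  # placeholder for agent.step(obs)
            env.step(act)


    roll_out(ToyEnv(), steps=4, device=torch.device('cpu'), use_rand_action=True)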
diff --git a/omnisafe/adapter/online_adapter.py b/omnisafe/adapter/online_adapter.py
index 40be14825..fc6b1e828 100644
--- a/omnisafe/adapter/online_adapter.py
+++ b/omnisafe/adapter/online_adapter.py
@@ -45,7 +45,9 @@ def __init__(  # pylint: disable=too-many-arguments
         assert env_id in support_envs(), f'Env {env_id} is not supported.'
 
         self._env_id = env_id
-        self._env = make(env_id, num_envs=num_envs)
+        self._env = make(env_id, num_envs=num_envs, device=cfgs.train_cfgs.device)
+        self._cfgs = cfgs
+        self._device = cfgs.train_cfgs.device
         self._wrapper(
             obs_normalize=cfgs.algo_cfgs.obs_normalize,
             reward_normalize=cfgs.algo_cfgs.reward_normalize,
@@ -53,9 +55,6 @@ def __init__(  # pylint: disable=too-many-arguments
         )
         self._env.set_seed(seed)
 
-        self._cfgs = cfgs
-        self._device = cfgs.train_cfgs.device
-
     def _wrapper(
         self,
         obs_normalize: bool = True,
@@ -63,18 +62,18 @@ def _wrapper(
         cost_normalize: bool = True,
     ):
         if self._env.need_time_limit_wrapper:
-            self._env = TimeLimit(self._env, time_limit=1000)
+            self._env = TimeLimit(self._env, device=self._device, time_limit=1000)
         if self._env.need_auto_reset_wrapper:
-            self._env = AutoReset(self._env)
+            self._env = AutoReset(self._env, device=self._device)
         if obs_normalize:
-            self._env = ObsNormalize(self._env)
+            self._env = ObsNormalize(self._env, device=self._device)
         if reward_normalize:
-            self._env = RewardNormalize(self._env)
+            self._env = RewardNormalize(self._env, device=self._device)
         if cost_normalize:
-            self._env = CostNormalize(self._env)
-        self._env = ActionScale(self._env, low=-1.0, high=1.0)
+            self._env = CostNormalize(self._env, device=self._device)
+        self._env = ActionScale(self._env, device=self._device, low=-1.0, high=1.0)
         if self._env.num_envs == 1:
-            self._env = Unsqueeze(self._env)
+            self._env = Unsqueeze(self._env, device=self._device)
 
     @property
     def action_space(self) -> OmnisafeSpace:
@@ -111,14 +110,7 @@ def step(
             truncated (torch.Tensor): whether the episode has been truncated due to a time limit.
             info (Dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning).
         """
-        obs, reward, cost, terminated, truncated, info = self._env.step(action)
-        obs, reward, cost, terminated, truncated = map(
-            lambda x: x.to(self._device),
-            (obs, reward, cost, terminated, truncated),
-        )
-        if info.get('final_observation') is not None:
-            info['final_observation'] = info['final_observation'].to(self._device)
-        return obs, reward, cost, terminated, truncated, info
+        return self._env.step(action)
 
     def reset(self) -> Tuple[torch.Tensor, Dict]:
         """Resets the environment and returns an initial observation.
@@ -130,8 +122,7 @@ def reset(self) -> Tuple[torch.Tensor, Dict]:
             observation (torch.Tensor): the initial observation of the space.
             info (Dict): contains auxiliary diagnostic information (helpful for debugging, and sometimes learning).
         """
-        obs, info = self._env.reset()
-        return obs.to(self._device), info
+        return self._env.reset()
 
     def save(self) -> Dict[str, torch.nn.Module]:
         """Save the environment.
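The online adapter now passes the target device into make() and into every wrapper, so step() and reset() can simply delegate instead of moving tensors afterwards; the saute adapter below follows the same pattern. A toy sketch of that construction-time device plumbing, with DummyEnv and DeviceWrapper standing in for the real CMDP and Wrapper classes:

    import torch


    class DummyEnv:
        """Stand-in for the innermost environment; returns a fixed observation."""

        num_envs = 1

        def step(self, act: torch.Tensor):
            return torch.zeros(4, device=act.device), {}


    class DeviceWrapper:
        """Each layer stores the device it was built with, like the patched wrappers."""

        def __init__(self, env, device: torch.device) -> None:
            self._env = env
            self._device = device

        def step(self, act: torch.Tensor):
            # any tensor a wrapper creates can use self._device directly
            return self._env.step(act.to(self._device))


    def build_chain(env: DummyEnv, device: torch.device) -> DeviceWrapper:
        # mirrors the shape of OnlineAdapter._wrapper: every layer gets the same device handle
        return DeviceWrapper(DeviceWrapper(env, device=device), device=device)


    chain = build_chain(DummyEnv(), device=torch.device('cpu'))
    obs, info = chain.step(torch.zeros(2))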
diff --git a/omnisafe/adapter/saute_adapter.py b/omnisafe/adapter/saute_adapter.py
index 1b65d60a8..5f48b57b0 100644
--- a/omnisafe/adapter/saute_adapter.py
+++ b/omnisafe/adapter/saute_adapter.py
@@ -66,16 +66,16 @@ def _wrapper(
         cost_normalize: bool = False,
     ):
         if self._env.need_time_limit_wrapper:
-            self._env = TimeLimit(self._env, time_limit=1000)
+            self._env = TimeLimit(self._env, device=self._device, time_limit=1000)
         if self._env.need_auto_reset_wrapper:
-            self._env = AutoReset(self._env)
+            self._env = AutoReset(self._env, device=self._device)
         if obs_normalize:
-            self._env = ObsNormalize(self._env)
+            self._env = ObsNormalize(self._env, device=self._device)
         assert reward_normalize is False, 'Reward normalization is not supported'
         assert cost_normalize is False, 'Cost normalization is not supported'
-        self._env = ActionScale(self._env, low=-1.0, high=1.0)
+        self._env = ActionScale(self._env, device=self._device, low=-1.0, high=1.0)
         if self._env.num_envs == 1:
-            self._env = Unsqueeze(self._env)
+            self._env = Unsqueeze(self._env, device=self._device)
 
     def reset(self) -> Tuple[torch.Tensor, Dict]:
         obs, info = self._env.reset()
diff --git a/omnisafe/algorithms/on_policy/second_order/cpo.py b/omnisafe/algorithms/on_policy/second_order/cpo.py
index 3969305e6..8e96d30ad 100644
--- a/omnisafe/algorithms/on_policy/second_order/cpo.py
+++ b/omnisafe/algorithms/on_policy/second_order/cpo.py
@@ -16,7 +16,6 @@
 
 from typing import Tuple
 
-import numpy as np
 import torch
 
 from omnisafe.algorithms import registry
@@ -116,8 +115,12 @@ def _cpo_search_step(
             acceptance_step = step + 1
 
             with torch.no_grad():
-                # loss of policy reward from target/expected reward
-                loss_reward, _ = self._loss_pi(obs=obs, act=act, logp=logp, adv=adv_r)
+                try:
+                    # loss of policy reward from target/expected reward
+                    loss_reward, _ = self._loss_pi(obs=obs, act=act, logp=logp, adv=adv_r)
+                except ValueError:
+                    step_frac *= decay
+                    continue
                 # loss of cost of policy cost from real/expected reward
                 loss_cost = self._loss_pi_cost(obs=obs, act=act, logp=logp, adv_c=adv_c)
                 # compute KL distance between new and old policy
@@ -139,7 +142,10 @@ def _cpo_search_step(
             # check whether there are nan.
             if not torch.isfinite(loss_reward) and not torch.isfinite(loss_cost):
                 self._logger.log('WARNING: loss_pi not finite')
-            elif loss_reward_improve < 0 if optim_case > 1 else False:
+            if not torch.isfinite(kl):
+                self._logger.log('WARNING: KL not finite')
+                continue
+            if loss_reward_improve < 0 if optim_case > 1 else False:
                 self._logger.log('INFO: did not improve improve <0')
             # change of cost's range
             elif loss_cost_diff > max(-violation_c, 0):
@@ -236,14 +242,13 @@ def _update_actor(
         b_grad = get_flat_gradients_from(self._actor_critic.actor)
 
         ep_costs = self._logger.get_stats('Metrics/EpCost')[0] - self._cfgs.algo_cfgs.cost_limit
-        cost = ep_costs / (self._logger.get_stats('Metrics/EpLen')[0] + 1e-8)
 
         p = conjugate_gradients(self._fvp, b_grad, self._cfgs.algo_cfgs.cg_iters)
         q = xHx
         r = grad.dot(p)
         s = b_grad.dot(p)
 
-        if b_grad.dot(b_grad) <= 1e-6 and cost < 0:
+        if b_grad.dot(b_grad) <= 1e-6 and ep_costs < 0:
             # feasible step and cost grad is zero: use plain TRPO update...
             A = torch.zeros(1)
             B = torch.zeros(1)
@@ -253,17 +258,17 @@ def _update_actor(
             assert torch.isfinite(s).all(), 's is not finite'
 
             A = q - r**2 / (s + 1e-8)
-            B = 2 * self._cfgs.algo_cfgs.target_kl - cost**2 / (s + 1e-8)
+            B = 2 * self._cfgs.algo_cfgs.target_kl - ep_costs**2 / (s + 1e-8)
 
-            if cost < 0 and B < 0:
+            if ep_costs < 0 and B < 0:
                 # point in trust region is feasible and safety boundary doesn't intersect
                 # ==> entire trust region is feasible
                 optim_case = 3
-            elif cost < 0 <= B:
+            elif ep_costs < 0 <= B:
                 # point in trust region is feasible but safety boundary intersects
                 # ==> only part of trust region is feasible
                 optim_case = 2
-            elif cost >= 0 and B >= 0:
+            elif ep_costs >= 0 and B >= 0:
                 # point in trust region is infeasible and cost boundary doesn't intersect
                 # ==> entire trust region is infeasible
                 optim_case = 1
@@ -296,16 +301,16 @@ def project(data: torch.Tensor, low: float, high: float) -> torch.Tensor:
             # where projection(str,b,c)=max(b,min(str,c))
             # may be regarded as a projection from effective region towards safety region
             r_num = r.item()
-            eps_cost = cost + 1e-8
-            if cost < 0:
+            eps_cost = ep_costs + 1e-8
+            if ep_costs < 0:
                 lambda_a_star = project(lambda_a, 0.0, r_num / eps_cost)
-                lambda_b_star = project(lambda_b, r_num / eps_cost, np.inf)
+                lambda_b_star = project(lambda_b, r_num / eps_cost, torch.inf)
             else:
-                lambda_a_star = project(lambda_a, r_num / eps_cost, np.inf)
+                lambda_a_star = project(lambda_a, r_num / eps_cost, torch.inf)
                 lambda_b_star = project(lambda_b, 0.0, r_num / eps_cost)
 
             def f_a(lam):
-                return -0.5 * (A / (lam + 1e-8) + B * lam) - r * cost / (s + 1e-8)
+                return -0.5 * (A / (lam + 1e-8) + B * lam) - r * ep_costs / (s + 1e-8)
 
             def f_b(lam):
                 return -0.5 * (q / (lam + 1e-8) + 2 * self._cfgs.algo_cfgs.target_kl * lam)
@@ -316,7 +321,7 @@ def f_b(lam):
 
             # discard all negative values with torch.clamp(x, min=0)
             # Nu_star = (lambda_star * - r)/s
-            nu_star = torch.clamp(lambda_star * cost - r, min=0) / (s + 1e-8)
+            nu_star = torch.clamp(lambda_star * ep_costs - r, min=0) / (s + 1e-8)
             # final x_star as final direction played as policy's loss to backward and update
             step_direction = 1.0 / (lambda_star + 1e-8) * (x - nu_star * p)
 
@@ -324,7 +329,7 @@ def f_b(lam):
             # purely decrease costs
             # without further check
             lambda_star = torch.zeros(1)
-            nu_star = np.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (s + 1e-8))
+            nu_star = torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (s + 1e-8))
             step_direction = -nu_star * p
 
         step_direction, accept_step = self._cpo_search_step(
@@ -339,7 +344,7 @@ def f_b(lam):
             loss_reward_before=loss_reward_before,
             loss_cost_before=loss_cost_before,
             total_steps=20,
-            violation_c=cost,
+            violation_c=ep_costs,
             optim_case=optim_case,
         )
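For context on the cpo.py hunks above, which swap the per-step cost for ep_costs (episode cost minus the cost limit) and np.inf for torch.inf, here is a small self-contained sketch of the dual-variable projection they touch. The scalar values are made up; only the clamping logic mirrors the patch:

    import torch


    def project(data: torch.Tensor, low: float, high: float) -> torch.Tensor:
        # max(low, min(data, high)) written with torch.clamp
        return torch.clamp(data, min=low, max=high)


    # made-up scalars standing in for the conjugate-gradient byproducts r, s
    # and for the two dual candidates lambda_a, lambda_b
    r = torch.tensor(0.3)
    s = torch.tensor(2.0)
    ep_costs = torch.tensor(2.5)  # EpCost minus the cost limit: positive, so infeasible
    lambda_a = torch.tensor(0.8)
    lambda_b = torch.tensor(4.0)

    eps_cost = float(ep_costs) + 1e-8
    r_num = r.item()
    if ep_costs < 0:
        lambda_a_star = project(lambda_a, 0.0, r_num / eps_cost)
        lambda_b_star = project(lambda_b, r_num / eps_cost, torch.inf)
    else:
        lambda_a_star = project(lambda_a, r_num / eps_cost, torch.inf)
        lambda_b_star = project(lambda_b, 0.0, r_num / eps_cost)

    print(lambda_a_star.item(), lambda_b_star.item())  # 0.8 and 0.12 (clamped to r/eps_cost)

The same ep_costs quantity is what the pcpo.py change below feeds into its step direction and line search.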
diff --git a/omnisafe/algorithms/on_policy/second_order/pcpo.py b/omnisafe/algorithms/on_policy/second_order/pcpo.py
index 2ec1bc1bb..8b3a49753 100644
--- a/omnisafe/algorithms/on_policy/second_order/pcpo.py
+++ b/omnisafe/algorithms/on_policy/second_order/pcpo.py
@@ -91,9 +91,8 @@ def _update_actor(
         b_grad = get_flat_gradients_from(self._actor_critic.actor)
 
         ep_costs = self._logger.get_stats('Metrics/EpCost')[0] - self._cfgs.algo_cfgs.cost_limit
-        cost = ep_costs / (self._logger.get_stats('Metrics/EpLen')[0] + 1e-8)
 
-        self._logger.log(f'c = {cost}')
+        self._logger.log(f'c = {ep_costs}')
         self._logger.log(f'b^T b = {b_grad.dot(b_grad).item()}')
 
         p = conjugate_gradients(self._fvp, b_grad, self._cfgs.algo_cfgs.cg_iters)
@@ -104,7 +103,7 @@ def _update_actor(
         step_direction = (
             torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / (q + 1e-8)) * H_inv_g
             - torch.clamp_min(
-                (torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / q) * r + cost) / s,
+                (torch.sqrt(2 * self._cfgs.algo_cfgs.target_kl / q) * r + ep_costs) / s,
                 torch.tensor(0.0, device=self._device),
             )
             * p
@@ -121,8 +120,8 @@ def _update_actor(
             adv_c=adv_c,
             loss_reward_before=loss_reward_before,
             loss_cost_before=loss_cost_before,
-            total_steps=20,
-            violation_c=cost,
+            total_steps=200,
+            violation_c=ep_costs,
         )
         theta_new = theta_old + step_direction
         set_param_values_to_model(self._actor_critic.actor, theta_new)
diff --git a/omnisafe/envs/core.py b/omnisafe/envs/core.py
index 4f082c218..56c87065b 100644
--- a/omnisafe/envs/core.py
+++ b/omnisafe/envs/core.py
@@ -193,13 +193,14 @@ class Wrapper(CMDP):
 
     """
 
-    def __init__(self, env: CMDP) -> None:
+    def __init__(self, env: CMDP, device: torch.device) -> None:
        """Initialize the wrapper.
 
         Args:
             env (CMDP): the environment.
         """
         self._env = env
+        self._device = device
 
     def __getattr__(self, name: str) -> Any:
         """Get the attribute of the environment.
diff --git a/omnisafe/envs/safety_gymnasium_env.py b/omnisafe/envs/safety_gymnasium_env.py
index 7a0860031..521fbccd6 100644
--- a/omnisafe/envs/safety_gymnasium_env.py
+++ b/omnisafe/envs/safety_gymnasium_env.py
@@ -75,7 +75,9 @@ class SafetyGymnasiumEnv(CMDP):
     need_auto_reset_wrapper = False
     need_time_limit_wrapper = False
 
-    def __init__(self, env_id: str, num_envs: int = 1, **kwargs) -> None:
+    def __init__(
+        self, env_id: str, num_envs: int = 1, device: torch.device = torch.device('cpu'), **kwargs
+    ) -> None:
         super().__init__(env_id)
         if num_envs > 1:
             self._env = safety_gymnasium.vector.make(env_id=env_id, num_envs=num_envs, **kwargs)
@@ -88,13 +90,16 @@ def __init__(self, env_id: str, num_envs: int = 1, **kwargs) -> None:
 
         self._num_envs = num_envs
         self._metadata = self._env.metadata
+        self._device = device
 
     def step(
         self, action: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]:
-        obs, reward, cost, terminated, truncated, info = self._env.step(action)
+        obs, reward, cost, terminated, truncated, info = self._env.step(
+            action.detach().cpu().numpy()
+        )
         obs, reward, cost, terminated, truncated = map(
-            lambda x: torch.as_tensor(x, dtype=torch.float32),
+            lambda x: torch.as_tensor(x, dtype=torch.float32, device=self._device),
             (obs, reward, cost, terminated, truncated),
         )
         if 'final_observation' in info:
@@ -105,20 +110,22 @@ def step(
                 ]
             )
             info['final_observation'] = torch.as_tensor(
-                info['final_observation'], dtype=torch.float32
+                info['final_observation'], dtype=torch.float32, device=self._device
             )
 
         return obs, reward, cost, terminated, truncated, info
 
     def reset(self, seed: Optional[int] = None) -> Tuple[torch.Tensor, Dict]:
         obs, info = self._env.reset(seed=seed)
-        return torch.as_tensor(obs, dtype=torch.float32), info
+        return torch.as_tensor(obs, dtype=torch.float32, device=self._device), info
 
     def set_seed(self, seed: int) -> None:
         self.reset(seed=seed)
 
     def sample_action(self) -> torch.Tensor:
-        return torch.as_tensor(self._env.action_space.sample(), dtype=torch.float32)
+        return torch.as_tensor(
+            self._env.action_space.sample(), dtype=torch.float32, device=self._device
+        )
 
     def render(self) -> Any:
         return self._env.render()
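safety_gymnasium_env.py now owns the torch/NumPy boundary: actions are detached to NumPy for the simulator, and the simulator's NumPy outputs come back as float32 tensors on the configured device. A minimal sketch of that bridge, with NumpySim as a hypothetical stand-in for the real safety_gymnasium environment:

    import numpy as np
    import torch


    class NumpySim:
        """Stand-in NumPy-based simulator with a gym-style step signature."""

        def step(self, action: np.ndarray):
            obs = np.zeros(3, dtype=np.float64)
            reward, cost, terminated, truncated = 1.0, 0.0, False, False
            return obs, reward, cost, terminated, truncated, {}


    class TorchBridge:
        def __init__(self, sim: NumpySim, device: torch.device) -> None:
            self._sim = sim
            self._device = device

        def step(self, action: torch.Tensor):
            # torch -> numpy for the simulator
            outs = self._sim.step(action.detach().cpu().numpy())
            obs, reward, cost, terminated, truncated, info = outs
            # numpy/python scalars -> torch tensors on the configured device
            obs, reward, cost, terminated, truncated = (
                torch.as_tensor(x, dtype=torch.float32, device=self._device)
                for x in (obs, reward, cost, terminated, truncated)
            )
            return obs, reward, cost, terminated, truncated, info


    env = TorchBridge(NumpySim(), device=torch.device('cpu'))
    print(env.step(torch.zeros(2))[0])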
diff --git a/omnisafe/envs/wrapper.py b/omnisafe/envs/wrapper.py
index c4413a746..a27a43c52 100644
--- a/omnisafe/envs/wrapper.py
+++ b/omnisafe/envs/wrapper.py
@@ -32,14 +32,14 @@ class TimeLimit(Wrapper):
         >>> env = TimeLimit(env, time_limit=100)
 
     """
 
-    def __init__(self, env: CMDP, time_limit: int) -> None:
+    def __init__(self, env: CMDP, time_limit: int, device: torch.device) -> None:
         """Initialize the time limit wrapper.
 
         Args:
             env (CMDP): The environment to wrap.
             time_limit (int): The time limit for each episode.
         """
-        super().__init__(env)
+        super().__init__(env=env, device=device)
 
         assert self.num_envs == 1, 'TimeLimit only supports single environment'
 
@@ -56,7 +56,9 @@ def step(
         obs, reward, cost, terminated, truncated, info = super().step(action)
 
         self._time += 1
-        truncated = torch.tensor(self._time >= self._time_limit, dtype=torch.bool)
+        truncated = torch.tensor(
+            self._time >= self._time_limit, dtype=torch.bool, device=self._device
+        )
 
         return obs, reward, cost, terminated, truncated, info
 
@@ -69,8 +71,8 @@ class AutoReset(Wrapper):
 
     """
 
-    def __init__(self, env: CMDP) -> None:
-        super().__init__(env)
+    def __init__(self, env: CMDP, device: torch.device) -> None:
+        super().__init__(env=env, device=device)
 
         assert self.num_envs == 1, 'AutoReset only supports single environment'
 
@@ -106,14 +108,14 @@ class ObsNormalize(Wrapper):
 
     """
 
-    def __init__(self, env: CMDP, norm: Optional[Normalizer] = None) -> None:
-        super().__init__(env)
+    def __init__(self, env: CMDP, device: torch.device, norm: Optional[Normalizer] = None) -> None:
+        super().__init__(env=env, device=device)
         assert isinstance(self.observation_space, spaces.Box), 'Observation space must be Box'
 
         if norm is not None:
-            self._obs_normalizer = norm
+            self._obs_normalizer = norm.to(self._device)
         else:
-            self._obs_normalizer = Normalizer(self.observation_space.shape, clip=5)
+            self._obs_normalizer = Normalizer(self.observation_space.shape, clip=5).to(self._device)
 
     def step(
         self, action: torch.Tensor
@@ -124,6 +126,7 @@ def step(
             final_obs_slice = info['_final_observation']
         else:
             final_obs_slice = slice(None)
+        info['final_observation'] = info['final_observation'].to(self._device)
         info['original_final_observation'] = info['final_observation']
         info['final_observation'][final_obs_slice] = self._obs_normalizer.normalize(
             info['final_observation'][final_obs_slice]
@@ -155,7 +158,7 @@ class RewardNormalize(Wrapper):
 
     """
 
-    def __init__(self, env: CMDP, norm: Optional[Normalizer] = None) -> None:
+    def __init__(self, env: CMDP, device: torch.device, norm: Optional[Normalizer] = None) -> None:
         """Initialize the reward normalizer.
 
         Args:
@@ -163,11 +166,11 @@ def __init__(self, env: CMDP, norm: Optional[Normalizer] = None) -> None:
             norm (Optional[Normalizer], optional): The normalizer to use. Defaults to None.
 
         """
-        super().__init__(env)
+        super().__init__(env=env, device=device)
         if norm is not None:
-            self._reward_normalizer = norm
+            self._reward_normalizer = norm.to(self._device)
         else:
-            self._reward_normalizer = Normalizer((), clip=5)
+            self._reward_normalizer = Normalizer((), clip=5).to(self._device)
 
     def step(
         self, action: torch.Tensor
@@ -193,18 +196,18 @@ class CostNormalize(Wrapper):
         >>> env = CostNormalize(env, norm)
 
     """
 
-    def __init__(self, env: CMDP, norm: Optional[Normalizer] = None) -> None:
+    def __init__(self, env: CMDP, device: torch.device, norm: Optional[Normalizer] = None) -> None:
         """Initialize the cost normalizer.
 
         Args:
             env (CMDP): The environment to wrap.
             norm (Normalizer, optional): The normalizer to use. Defaults to None.
""" - super().__init__(env) + super().__init__(env=env, device=device) if norm is not None: - self._obs_normalizer = norm + self._obs_normalizer = norm.to(self._device) else: - self._cost_normalizer = Normalizer((), clip=5) + self._cost_normalizer = Normalizer((), clip=5).to(self._device) def step( self, action: torch.Tensor @@ -232,6 +235,7 @@ class ActionScale(Wrapper): def __init__( self, env: CMDP, + device: torch.device, low: Union[int, float], high: Union[int, float], ) -> None: @@ -242,11 +246,15 @@ def __init__( low: The lower bound of the action space. high: The upper bound of the action space. """ - super().__init__(env) + super().__init__(env=env, device=device) assert isinstance(self.action_space, spaces.Box), 'Action space must be Box' - self._old_min_action = torch.tensor(self.action_space.low, dtype=torch.float32) - self._old_max_action = torch.tensor(self.action_space.high, dtype=torch.float32) + self._old_min_action = torch.tensor( + self.action_space.low, dtype=torch.float32, device=self._device + ) + self._old_max_action = torch.tensor( + self.action_space.high, dtype=torch.float32, device=self._device + ) min_action = np.zeros(self.action_space.shape, dtype=self.action_space.dtype) + low max_action = np.zeros(self.action_space.shape, dtype=self.action_space.dtype) + high @@ -257,16 +265,16 @@ def __init__( dtype=self.action_space.dtype, # type: ignore ) - self._min_action = torch.tensor(min_action, dtype=torch.float32) - self._max_action = torch.tensor(max_action, dtype=torch.float32) + self._min_action = torch.tensor(min_action, dtype=torch.float32, device=self._device) + self._max_action = torch.tensor(max_action, dtype=torch.float32, device=self._device) def step( self, action: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict]: action = self._old_min_action + (self._old_max_action - self._old_min_action) * ( - action.cpu() - self._min_action + action - self._min_action ) / (self._max_action - self._min_action) - return super().step(action.numpy()) + return super().step(action) class Unsqueeze(Wrapper): @@ -276,13 +284,13 @@ class Unsqueeze(Wrapper): >>> env = Unsqueeze(env) """ - def __init__(self, env: CMDP) -> None: + def __init__(self, env: CMDP, device: torch.device) -> None: """Initialize the wrapper. Args: env: The environment to wrap. """ - super().__init__(env) + super().__init__(env=env, device=device) assert self.num_envs == 1, 'Unsqueeze only works with single environment' assert isinstance(self.observation_space, spaces.Box), 'Observation space must be Box' diff --git a/omnisafe/evaluator.py b/omnisafe/evaluator.py index 19139ba4d..d5895d33c 100644 --- a/omnisafe/evaluator.py +++ b/omnisafe/evaluator.py @@ -116,10 +116,10 @@ def __load_model_and_env(self, save_dir: str, model_name: str, env_kwargs: Dict[ if self._cfgs['algo_cfgs']['obs_normalize']: obs_normalizer = Normalizer(shape=observation_space.shape, clip=5) obs_normalizer.load_state_dict(model_params['obs_normalizer']) - self._env = ObsNormalize(self._env, obs_normalizer) + self._env = ObsNormalize(self._env, device='cpu', norm=obs_normalizer) if self._env.need_time_limit_wrapper: - self._env = TimeLimit(self._env, time_limit=1000) - self._env = ActionScale(self._env, low=-1.0, high=1.0) + self._env = TimeLimit(self._env, device='cpu', time_limit=1000) + self._env = ActionScale(self._env, device='cpu', low=-1.0, high=1.0) actor_type = self._cfgs['model_cfgs']['actor_type'] pi_cfg = self._cfgs['model_cfgs']['actor']