feat: ruff formatter
sash-a committed Jun 14, 2024
1 parent 9156666 commit 4971ce0
Showing 28 changed files with 1,671 additions and 1,437 deletions.
63 changes: 37 additions & 26 deletions .pre-commit-config.yaml
@@ -2,21 +2,32 @@ default_stages: [ "commit", "commit-msg", "push" ]
default_language_version:
python: python3


repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.13.2
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.8
hooks:
- id: isort
# Run the linter.
- id: ruff
types_or: [ python ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi, jupyter ]

- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
name: "Code formatter"
# - repo: https://github.com/timothycrosley/isort
# rev: 5.13.2
# hooks:
# - id: isort

# - repo: https://github.com/psf/black
# rev: 24.2.0
# hooks:
# - id: black
# name: "Code formatter"

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: end-of-file-fixer
name: "End of file fixer"
@@ -32,20 +43,20 @@ repos:
- id: trailing-whitespace
name: "Trailing whitespace fixer"

- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
name: "Linter"
additional_dependencies:
- pep8-naming
- flake8-builtins
- flake8-comprehensions
- flake8-bugbear
- flake8-pytest-style
- flake8-cognitive-complexity
- flake8-pyproject
- importlib-metadata<5.0
# - repo: https://github.com/PyCQA/flake8
# rev: 7.0.0
# hooks:
# - id: flake8
# name: "Linter"
# additional_dependencies:
# - pep8-naming
# - flake8-builtins
# - flake8-comprehensions
# - flake8-bugbear
# - flake8-pytest-style
# - flake8-cognitive-complexity
# - flake8-pyproject
# - importlib-metadata<5.0

- repo: local
hooks:
@@ -57,15 +68,15 @@ repos:
pass_filenames: false

- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
rev: v9.11.0
rev: v9.16.0
hooks:
- id: commitlint
name: "Commit linter"
stages: [ commit-msg ]
additional_dependencies: [ '@commitlint/config-conventional' ]

- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.3.0
rev: v1.5.5
hooks:
- id: insert-license
name: "License inserter"
2,392 changes: 1,205 additions & 1,187 deletions examples/Quickstart.ipynb

Large diffs are not rendered by default.

37 changes: 26 additions & 11 deletions mava/advanced_usage/ff_ippo_store_experience.py
@@ -61,7 +61,6 @@ def get_learner_fn(
config: DictConfig,
) -> StoreExpLearnerFn[LearnerState]:
"""Get the learner function."""

# Get apply and update functions for actor and critic networks.
actor_apply_fn, critic_apply_fn = apply_fns
actor_update_fn, critic_update_fn = update_fns
@@ -75,13 +74,15 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup
losses.
Args:
----
learner_state (NamedTuple):
- params (Params): The current model parameters.
- opt_states (OptStates): The current optimizer states.
- key (PRNGKey): The random number generator state.
- env_state (State): The environment state.
- last_timestep (TimeStep): The last timestep in the current trajectory.
_ (Any): The current metrics info.
"""

def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
@@ -106,7 +107,13 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
info = timestep.extras["episode_metrics"]

transition = PPOTransition(
done, action, value, timestep.reward, log_prob, last_timestep.observation, info
done,
action,
value,
timestep.reward,
log_prob,
last_timestep.observation,
info,
)

learner_state = LearnerState(params, opt_states, key, env_state, timestep)
@@ -155,7 +162,6 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:

def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
"""Update the network for a single minibatch."""

# UNPACK TRAIN STATE AND BATCH INFO
params, opt_states = train_state
traj_batch, advantages, targets = batch_info
@@ -214,13 +220,19 @@ def _critic_loss_fn(
# CALCULATE ACTOR LOSS
actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True)
actor_loss_info, actor_grads = actor_grad_fn(
params.actor_params, opt_states.actor_opt_state, traj_batch, advantages
params.actor_params,
opt_states.actor_opt_state,
traj_batch,
advantages,
)

# CALCULATE CRITIC LOSS
critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True)
critic_loss_info, critic_grads = critic_grad_fn(
params.critic_params, opt_states.critic_opt_state, traj_batch, targets
params.critic_params,
opt_states.critic_opt_state,
traj_batch,
targets,
)

# Compute the parallel mean (pmean) over the batch.
@@ -285,7 +297,7 @@ def _critic_loss_fn(
lambda x: jnp.take(x, permutation, axis=0), batch
)
minibatches = jax.tree_util.tree_map(
lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])),
lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])),
shuffled_batch,
)

@@ -319,14 +331,15 @@ def learner_fn(
updates. The `_update_step` function is vectorized over a batch of inputs.
Args:
----
learner_state (NamedTuple):
- params (Params): The initial model parameters.
- opt_states (OptStates): The initial optimizer state.
- key (chex.PRNGKey): The random number generator state.
- env_state (LogEnvState): The environment state.
- timesteps (TimeStep): The initial timestep in the initial trajectory.
"""
"""
batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch")

learner_state, (episode_info, loss_info, traj_batch) = jax.lax.scan(
@@ -400,7 +413,10 @@ def learner_setup(
input_params=Params(actor_params, critic_params)
)
# Update the params
actor_params, critic_params = restored_params.actor_params, restored_params.critic_params
actor_params, critic_params = (
restored_params.actor_params,
restored_params.critic_params,
)

# Pack apply and update functions.
apply_fns = (actor_network.apply, critic_network.apply)
@@ -412,7 +428,7 @@ def learner_setup(

# Broadcast params and optimiser state to cores and batch.
broadcast = lambda x: jnp.broadcast_to(
x, (n_devices, config.system.update_batch_size) + x.shape
x, (n_devices, config.system.update_batch_size, *x.shape)
)

actor_params = jax.tree_map(broadcast, actor_params)
@@ -450,7 +466,7 @@ def learner_setup(


# TODO: fix cognitive complexity
def run_experiment(_config: DictConfig) -> None: # noqa: CCR001
def run_experiment(_config: DictConfig) -> None:
"""Runs experiment."""
# Logger setup
config = copy.deepcopy(_config)
@@ -547,7 +563,6 @@ def run_experiment(_config: DictConfig) -> None: # noqa: CCR001
@jax.jit
def _reshape_experience(experience: Dict[str, chex.Array]) -> Dict[str, chex.Array]:
"""Reshape experience to match buffer."""

# Swap the T and NE axes (D, NU, UB, T, NE, ...) -> (D, NU, UB, NE, T, ...)
experience: Dict[str, chex.Array] = jax.tree_map(lambda x: x.swapaxes(3, 4), experience)
# Merge 4 leading dimensions into 1. (D, NU, UB, NE, T ...) -> (D * NU * UB * NE, T, ...)
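Most of the hunks in this file are mechanical re-wraps from the new formatter, but the reshape and broadcast changes also swap list concatenation for tuple unpacking when building shapes. A minimal standalone sketch, with invented shapes, showing the two spellings are equivalent:

import jax.numpy as jnp

# Hypothetical shapes, purely for illustration.
x = jnp.ones((8, 3, 2))
num_minibatches = 4

old_shape = [num_minibatches, -1] + list(x.shape[1:])  # list concatenation (old style)
new_shape = (num_minibatches, -1, *x.shape[1:])        # tuple unpacking (new style)

# Both spell the same target shape, so the reshaped arrays match: (4, 2, 3, 2).
assert jnp.reshape(x, old_shape).shape == jnp.reshape(x, new_shape).shape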
15 changes: 6 additions & 9 deletions mava/distributions.py
Expand Up @@ -22,8 +22,7 @@


class TanhTransformedDistribution(tfd.TransformedDistribution):
"""
A distribution transformed using the `tanh` function.
"""A distribution transformed using the `tanh` function.
This transformation was adapted to acme's implementation.
For details, please see: http://tinyurl.com/2x5xea57
@@ -35,14 +34,15 @@ def __init__(
threshold: float = 0.999,
validate_args: bool = False,
) -> None:
"""
Initialises the TanhTransformedDistribution.
"""Initialises the TanhTransformedDistribution.
Args:
----
distribution: The base distribution to be transformed.
bijector: The bijective transformation applied to the distribution.
threshold: Clipping value for the action when computing the log_prob.
validate_args: Whether to validate input with respect to distribution parameters.
"""
super().__init__(
distribution=distribution, bijector=tfb.Tanh(), validate_args=validate_args
@@ -64,7 +64,6 @@ def __init__(

def log_prob(self, event: chex.Array) -> chex.Array:
"""Computes the log probability of the event under the transformed distribution."""

# Without this clip, there would be NaNs in the internal tf.where.
event = jnp.clip(event, -self._threshold, self._threshold)
# The inverse image of {threshold} is the interval [atanh(threshold), inf]
@@ -93,8 +92,7 @@ def _parameter_properties(cls, dtype: Optional[Any], num_classes: Any = None) ->


class MaskedEpsGreedyDistribution(tfd.Categorical):
"""
Computes an epsilon-greedy distribution for each action choice. There are two
"""Computes an epsilon-greedy distribution for each action choice. There are two
components in the distribution:
1. A uniform component, where every action that is NOT masked out gets an even weighting.
@@ -146,8 +144,7 @@ def _parameter_properties(cls, dtype: Optional[Any], num_classes: Any = None) ->


class IdentityTransformation(tfd.TransformedDistribution):
"""
A distribution transformed using the `Identity()` bijector.
"""A distribution transformed using the `Identity()` bijector.
We transform this distribution with the `Identity()` bijector to enable us to call
`pi.entropy(seed)` and keep the API identical to the TanhTransformedDistribution.
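The log_prob clipping kept intact by the docstring reflow above is easy to motivate with a tiny self-contained sketch; the 0.999 value mirrors the class's default threshold, and the numbers are otherwise arbitrary:

import jax.numpy as jnp

# An action exactly on the tanh boundary makes atanh blow up, which is what the
# clip inside TanhTransformedDistribution.log_prob guards against.
event = jnp.array([1.0])
print(jnp.arctanh(event))                            # inf -> NaNs downstream in log_prob
print(jnp.arctanh(jnp.clip(event, -0.999, 0.999)))   # finite, safe to evaluate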
22 changes: 17 additions & 5 deletions mava/evaluator.py
@@ -42,6 +42,7 @@ def get_ff_evaluator_fn(
"""Get the evaluator function for feedforward networks.
Args:
----
env (Environment): An evironment instance for evaluation.
apply_fn (callable): Network forward pass method.
config (dict): Experiment configuration.
@@ -51,6 +52,7 @@ def get_ff_evaluator_fn(
of training by rolling out the policy which obtained the greatest evaluation
performance during training for 10 times more episodes than were used at a
single evaluation step.
"""

def eval_one_episode(params: FrozenDict, init_eval_state: EvalState) -> Dict:
@@ -65,7 +67,8 @@ def _env_step(eval_state: EvalState) -> EvalState:
key, policy_key = jax.random.split(key)
# Add a batch dimension to the observation.
pi = apply_fn(
params, jax.tree_map(lambda x: x[jnp.newaxis, ...], last_timestep.observation)
params,
jax.tree_map(lambda x: x[jnp.newaxis, ...], last_timestep.observation),
)

if config.arch.evaluation_greedy:
@@ -106,7 +109,6 @@ def not_done(carry: Tuple) -> bool:

def evaluator_fn(trained_params: FrozenDict, key: chex.PRNGKey) -> ExperimentOutput[EvalState]:
"""Evaluator function."""

# Initialise environment states and timesteps.
n_devices = len(jax.devices())

@@ -226,7 +228,6 @@ def evaluator_fn(
trained_params: FrozenDict, key: chex.PRNGKey
) -> ExperimentOutput[RNNEvalState]:
"""Evaluator function."""

# Initialise environment states and timesteps.
n_devices = len(jax.devices())

@@ -292,6 +293,7 @@ def make_eval_fns(
"""Initialize evaluator functions for reinforcement learning.
Args:
----
eval_env (Environment): The environment used for evaluation.
network_apply_fn (Union[ActorApply,RecActorApply]): Creates a policy to sample.
config (DictConfig): The configuration settings for the evaluation.
@@ -300,11 +302,14 @@ def make_eval_fns(
Required if `use_recurrent_net` is True. Defaults to None.
Returns:
-------
Tuple[EvalFn, EvalFn]: A tuple of two evaluation functions:
one for use during training and one for absolute metrics.
Raises:
------
AssertionError: If `use_recurrent_net` is True but `scanned_rnn` is not provided.
"""
# Check if win rate is required for evaluation.
log_win_rate = config.env.log_win_rate
@@ -328,10 +333,17 @@ def make_eval_fns(
)
else:
evaluator = get_ff_evaluator_fn(
eval_env, network_apply_fn, config, log_win_rate # type: ignore
eval_env,
network_apply_fn, # type: ignore
config,
log_win_rate, # type: ignore
)
absolute_metric_evaluator = get_ff_evaluator_fn(
eval_env, network_apply_fn, config, log_win_rate, 10 # type: ignore
eval_env,
network_apply_fn, # type: ignore
config,
log_win_rate,
10, # type: ignore
)

evaluator = jax.pmap(evaluator, axis_name="device")
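One behaviour the reflowed evaluator call keeps is adding a leading batch axis to a single observation pytree before it reaches the feedforward apply_fn. A small standalone sketch of that pattern; the observation fields and shapes are invented for illustration:

import jax
import jax.numpy as jnp

# Invented observation pytree; only the leading-axis behaviour matters here.
obs = {"agents_view": jnp.zeros((4, 10)), "action_mask": jnp.ones((4, 5))}

# Mirrors the evaluator's jax.tree_map(lambda x: x[jnp.newaxis, ...], ...) call.
batched_obs = jax.tree_map(lambda x: x[jnp.newaxis, ...], obs)
print(jax.tree_map(jnp.shape, batched_obs))
# {'agents_view': (1, 4, 10), 'action_mask': (1, 4, 5)}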