Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge fork to support other lib #146

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/guide/user/3-distributed.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ workflow = workflows.StdWorkflow(cso, ackley)

key = jax.random.PRNGKey(42)
state = workflow.init(key)
state = workflow.enable_multi_devices(state, devices)
state = workflow.enable_multi_devices(state)

for i in range(10):
train_info, state = workflow.step(state)
Expand Down
2 changes: 1 addition & 1 deletion src/evox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .core.workflow import Workflow
from .core.algorithm import Algorithm, has_init_ask, has_init_tell
from .core.module import use_state, jit_class, jit_method, Stateful
from .core.module import Stateful, StatefulWrapper, use_state, jit_class, jit_cls_method
from .core.problem import Problem
from .core.state import State, get_state_sharding
from .core.monitor import Monitor
Expand Down
22 changes: 11 additions & 11 deletions src/evox/algorithms/so/es_variants/cma_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def setup(self, key):
B=B,
D=D,
C=C,
count_eigen=0,
count_iter=0,
count_eigen=jnp.zeros(()),
count_iter=jnp.zeros((), dtype=jnp.int32),
invsqrtC=C,
mean=self.center_init,
sigma=self.init_stdev,
mean=jnp.asarray(self.center_init, dtype=jnp.float32),
sigma=jnp.asarray(self.init_stdev, dtype=jnp.float32),
key=key,
population=jnp.empty((self.pop_size, self.dim)),
)
Expand Down Expand Up @@ -285,17 +285,17 @@ def setup(self, key):
B=B,
D=D,
C=C,
count_eigen=0,
count_iter=0,
count_eigen=jnp.zeros(()),
count_iter=jnp.zeros((), dtype=jnp.int32),
invsqrtC=C,
mean=self.center_init,
sigma=self.init_stdev,
mean=jnp.asarray(self.center_init, dtype=jnp.float32),
sigma=jnp.asarray(self.init_stdev, dtype=jnp.float32),
key=key,
population=jnp.empty((self.pop_size, self.dim)),
best_fitness=float("inf"),
restarts=0,
stagnation_count=0,
pop_size=self.original_pop_size,
restarts=jnp.zeros((), dtype=jnp.int32),
stagnation_count=jnp.zeros((), dtype=jnp.int32),
pop_size=jnp.int32(self.original_pop_size),
)

def _update_best_fitness(self, current_best_fitness, best_fitness):
Expand Down
39 changes: 0 additions & 39 deletions src/evox/core/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,47 +50,8 @@ def tell(self, state: State, fitness: jax.Array) -> State:


def has_init_ask(algorithm):
    """Return whether ``algorithm`` provides a callable ``init_ask``.

    ``init_ask`` is an optional hook of the form
    ``init_ask(self, state) -> (population, state)`` that asks the algorithm
    for its initial population. An algorithm defines it when the initial
    population has to be produced in a special way. For example, a genetic
    algorithm needs to evaluate the fitness of an initial population of size
    N, but afterwards only the fitness of the offspring of size M, with
    N != M. Since JAX requires a function's return value to have a static
    shape, two different functions are needed: the normal ``ask`` and
    ``init_ask``.

    Parameters
    ----------
    algorithm
        The object to inspect.

    Returns
    -------
    bool
        ``True`` if ``algorithm`` has an ``init_ask`` attribute that is
        callable, ``False`` otherwise.
    """
    # A missing attribute falls back to None, which is not callable, so a
    # single lookup covers both the "absent" and the "not callable" cases.
    return callable(getattr(algorithm, "init_ask", None))


def has_init_tell(algorithm):
    """Return whether ``algorithm`` provides a callable ``init_tell``.

    ``init_tell`` is an optional hook of the form
    ``init_tell(self, state, fitness) -> state`` that tells the algorithm the
    fitness of the initial population. It is used in pair with ``init_ask``.

    Parameters
    ----------
    algorithm
        The object to inspect.

    Returns
    -------
    bool
        ``True`` if ``algorithm`` has an ``init_tell`` attribute that is
        callable, ``False`` otherwise.
    """
    # A missing attribute falls back to None, which is not callable, so a
    # single lookup covers both the "absent" and the "not callable" cases.
    return callable(getattr(algorithm, "init_tell", None))
192 changes: 84 additions & 108 deletions src/evox/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .state import State


def use_state(func: Callable, index: int = None):
def use_state(func: Callable):
"""Decorator for easy state management.

This decorator will try to extract the sub-state belong to the module from current state
Expand All @@ -28,7 +28,7 @@ def use_state(func: Callable, index: int = None):
Typically used to handle batched states created from `State.batch`.
"""

err_msg = "Expect last return value must be State, got {}"
err_msg = "Expect last return value must be State, but get {}"

def wrapper(self, state: State, *args, **kwargs):
assert isinstance(
Expand All @@ -40,46 +40,30 @@ def wrapper(self, state: State, *args, **kwargs):
)

# find the state that match the current module
path, matched_state = state.find_path_to(self._node_id, self._module_name)

if index is not None:
extracted_state = tree_map(lambda x: x[index], matched_state)
this_module = tree_map(lambda x: x[index], self)
else:
extracted_state = matched_state
this_module = self
path, extracted_state = state._query_state_by_id(
self._node_id, self._module_name
)

if hasattr(func, "__self__"):
# bounded method, don't pass self
return_value = func(extracted_state, *args, **kwargs)
else:
# unbounded method (class method), pass self
return_value = func(this_module, extracted_state, *args, **kwargs)
return_value = func(self, extracted_state, *args, **kwargs)

# single return value, the value must be a State
if not isinstance(return_value, tuple):
assert isinstance(return_value, State), err_msg.format(type(return_value))
aux, new_state = None, return_value
aux, new_extracted_state = None, return_value
state = state.replace_state(path, new_extracted_state)
return state
else:
# unpack the return value first
assert isinstance(return_value[-1], State), err_msg.format(
type(return_value[-1])
)
aux, new_state = return_value[:-1], return_value[-1]

# if index is specified, apply the index to the state
if index is not None:
new_state = tree_map(
lambda batch_arr, new_arr: batch_arr.at[index].set(new_arr),
matched_state,
new_state,
)

state = state.replace_by_path(path, new_state)

if aux is None:
return state
else:
aux, new_extracted_state = return_value[:-1], return_value[-1]
state = state.replace_state(path, new_extracted_state)
return (*aux, state)

if hasattr(func, "__self__"):
Expand All @@ -88,7 +72,7 @@ def wrapper(self, state: State, *args, **kwargs):
return wraps(func)(wrapper)


def jit_method(method: Callable):
def jit_cls_method(method: Callable):
"""Decorator for methods, wrapper the method with jax.jit, and set self as static argument.

Parameters
Expand All @@ -101,12 +85,7 @@ def jit_method(method: Callable):
function
A jit wrapped version of this method
"""
return jax.jit(
method,
static_argnums=[
0,
],
)
return jax.jit(method, static_argnums=(0,))


def default_jit_func(name: str):
Expand All @@ -133,11 +112,14 @@ def jit_class(cls):
if dataclasses.is_dataclass(cls):
wrapped = jax.jit(func)
else:
wrapped = jit_method(func)
wrapped = jit_cls_method(func)
setattr(cls, attr_name, wrapped)
return cls


SubmoduleInfo = namedtuple("SubmoduleInfo", ["name", "module", "metadata"])


class Stateful:
"""Base class for all evox modules.

Expand Down Expand Up @@ -174,79 +156,49 @@ def setup(self, key: jax.Array) -> State:
return State()

def _recursive_init(
self, key: jax.Array, node_id: int, module_name: str, no_state: bool
) -> Tuple[State, int]:
self, key: jax.Array, node_id: int, module_name: str
) -> tuple[State, int]:
# the unique id of this module, matching its state._state_id
object.__setattr__(self, "_node_id", node_id)
object.__setattr__(self, "_module_name", module_name)

if not no_state:
child_states = {}

# Find all submodules and sort them according to their name.
# Sorting is important because it makes sure that the node_id
# is deterministic across different runs.
SubmoduleInfo = namedtuple("Submodule", ["name", "module", "metadata"])

submodules = []
# preprocess and sort to make sure the order is deterministic
# otherwise the node_id will be different across different runs
# making save/load impossible
if dataclasses.is_dataclass(self):
submodule_infos = []
if dataclasses.is_dataclass(self): # TODO: use robust check
for field in dataclasses.fields(self):
attr = getattr(self, field.name)

if isinstance(attr, Stateful):
submodules.append(SubmoduleInfo(field.name, attr, field.metadata))
submodule_infos.append(
SubmoduleInfo(field.name, attr, field.metadata)
)
else:
for attr_name in vars(self):
attr = getattr(self, attr_name)
if not attr_name.startswith("_") and isinstance(attr, Stateful):
submodules.append(SubmoduleInfo(attr_name, attr, {}))

submodules.sort()
if isinstance(attr, Stateful):
submodule_infos.append(SubmoduleInfo(attr_name, attr, {}))

for attr_name, attr, metadata in submodules:
if key is None:
subkey = None
else:
key, subkey = jax.random.split(key)

# handle "StackAnnotation"
# attr should be a list, or tuple of modules
if metadata.get("stack", False):
num_copies = len(attr)
subkeys = jax.random.split(subkey, num_copies)
current_node_id = node_id
_, node_id = attr._recursive_init(None, node_id + 1, attr_name, True)
submodule_state, _node_id = jax.vmap(
partial(
Stateful._recursive_init,
node_id=current_node_id + 1,
module_name=attr_name,
no_state=no_state,
)
)(attr, subkeys)
else:
submodule_state, node_id = attr._recursive_init(
subkey, node_id + 1, attr_name, no_state
)

if not no_state:
assert isinstance(
submodule_state, State
), "setup method must return a State"
child_states[attr_name] = submodule_state
if no_state:
return None, node_id
else:
return (
self.setup(key)
._set_state_id_mut(self._node_id)
._set_child_states_mut(child_states),
node_id,
# Find all submodules and sort them according to their name.
# Sorting is important because it makes sure that the node_id
# is deterministic across different runs.
submodule_infos.sort()
child_states = {}
for attr_name, attr, metadata in submodule_infos:
key, subkey = jax.random.split(key)
submodule_state, node_id = attr._recursive_init(
subkey, node_id + 1, attr_name
)
child_states[attr_name] = submodule_state

return (
self.setup(key)
._set_state_id_mut(self._node_id)
._set_child_states_mut(child_states),
node_id,
)

def init(self, key: jax.Array = None, no_state: bool = False) -> State:
def init(self, key: jax.Array) -> State:
"""Initialize this module and all submodules

This method should not be overwritten.
Expand All @@ -261,24 +213,48 @@ def init(self, key: jax.Array = None, no_state: bool = False) -> State:
State
The state of this module and all submodules combined.
"""
state, _node_id = self._recursive_init(key, 0, None, no_state)
state, _ = self._recursive_init(key, 0, self.__class__.__name__)
return state

@classmethod
def stack(cls, stateful_objs, axis=0):
for obj in stateful_objs:
assert dataclasses.is_dataclass(obj), "All objects must be dataclasses"
# @classmethod
# def stack(cls, stateful_objs, axis=0):
# for obj in stateful_objs:
# assert dataclasses.is_dataclass(obj), "All objects must be dataclasses"

def stack_arrays(array, *arrays):
return jnp.stack((array, *arrays), axis=axis)
# def stack_arrays(array, *arrays):
# return jnp.stack((array, *arrays), axis=axis)

return tree_map(stack_arrays, stateful_objs[0], *stateful_objs[1:])
# return tree_map(stack_arrays, stateful_objs[0], *stateful_objs[1:])

# def __len__(self) -> int:
# """
# Inspect the length of the first element in the state,
# usually paired with `Stateful.stack` to read the batch size
# """
# assert dataclasses.is_dataclass(self), "Length is only supported for dataclass"

# return len(tree_leaves(self)[0])

def __len__(self) -> int:
"""
Inspect the length of the first element in the state,
usually paired with `Stateful.stack` to read the batch size
"""
assert dataclasses.is_dataclass(self), "Length is only supported for dataclass"

return len(tree_leaves(self)[0])
class StatefulWrapper(Stateful):
    """
    Thin wrapper around another Stateful module.

    The wrapper holds the wrapped module in ``_module`` and has no state of
    its own: initialization is forwarded to the wrapped module, and
    ``setup`` raises if it is ever called.
    """

    def __init__(self, module: Stateful):
        super().__init__()
        self._module = module

    def setup(self, key: jax.Array) -> State:
        # Initialization is forwarded by `_recursive_init`, so the wrapper's
        # own setup is never expected to run.
        raise NotImplementedError("This method should not be called")

    def _recursive_init(
        self, key: jax.Array, node_id: int, module_name: str
    ) -> tuple[State, int]:
        """Skip the wrapper during init and initialize the wrapped module.

        The wrapper records ``node_id`` and ``module_name`` on itself, then
        forwards the same ``key``, ``node_id`` and ``module_name`` to the
        wrapped module and returns whatever that call returns.
        """
        # object.__setattr__ is used for the bookkeeping attributes, the same
        # way the original assigns them.
        for attr_name, attr_value in (
            ("_node_id", node_id),
            ("_module_name", module_name),
        ):
            object.__setattr__(self, attr_name, attr_value)

        wrapped = self._module
        return wrapped._recursive_init(key, node_id, module_name)
1 change: 0 additions & 1 deletion src/evox/core/pytree_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from jax.tree_util import register_pytree_node
import dataclasses
from typing import Annotated, Any, Callable, Optional, Tuple, TypeVar, get_type_hints

from typing_extensions import (
dataclass_transform, # pytype: disable=not-supported-yet
Expand Down
Loading