Skip to content

Commit

Permalink
update codes
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 8, 2024
1 parent d85ce40 commit 64283c0
Show file tree
Hide file tree
Showing 5 changed files with 839 additions and 780 deletions.
21 changes: 13 additions & 8 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
'Quantity',
'Unit',
'UnitRegistry',
'DIMENSIONLESS',
'DimensionMismatchError',
'get_or_create_dimension',
'get_unit',
Expand All @@ -49,7 +50,6 @@
]

_all_slice = slice(None, None, None)

random = None
_unit_checking = True
_automatically_register_units = True
Expand Down Expand Up @@ -242,6 +242,7 @@ class Dimension:
indices, allowing for a very fast dimensionality check with ``is``.
"""

__module__ = "brainunit"
__slots__ = ["_dims"]
__array_priority__ = 1000

Expand Down Expand Up @@ -748,9 +749,9 @@ def in_best_unit(x, precision=None):
return x.repr_in_unit(u, precision=precision)


def array_with_units(
def array_with_unit(
floatval,
units: Dimension,
unit: Dimension,
dtype: bst.typing.DTypeLike = None
) -> 'Quantity':
"""
Expand All @@ -764,7 +765,7 @@ def array_with_units(
----------
floatval : `float`
The floating point value of the array.
units: Dimension
unit: Dimension
The unit dimensions of the array.
dtype: `dtype`, optional
The data type of the array.
Expand All @@ -777,10 +778,10 @@ def array_with_units(
Examples
--------
>>> from brainunit import *
>>> array_with_units(0.001, volt.unit)
>>> array_with_unit(0.001, volt.unit)
1. * mvolt
"""
return Quantity(floatval, unit=get_or_create_dimension(units._dims), dtype=dtype)
return Quantity(floatval, unit=get_or_create_dimension(unit._dims), dtype=dtype)


def is_unitless(obj) -> bool:
Expand Down Expand Up @@ -936,6 +937,7 @@ class Quantity(object):
unit. It is used to represent all physical quantities in ``BrainCore``.
"""
__module__ = "brainunit"
__slots__ = ('_value', '_unit')
_value: Union[jax.Array, numbers.Number]
_unit: Dimension
Expand Down Expand Up @@ -1701,7 +1703,7 @@ def __round__(self, ndigits: int = None) -> 'Quantity':
return Quantity(self.value.__round__(ndigits), unit=self.unit)

def __reduce__(self):
return array_with_units, (self.value, self.unit, self.value.dtype)
return array_with_unit, (self.value, self.unit, self.value.dtype)

# ----------------------- #
# NumPy methods #
Expand Down Expand Up @@ -2438,8 +2440,9 @@ class Unit(Quantity):
3. * joule
"""
__slots__ = ["_value", "_unit", "scale", "_dispname", "_name", "iscompound"]

__module__ = "brainunit"
__slots__ = ["_value", "_unit", "scale", "_dispname", "_name", "iscompound"]
__array_priority__ = 1000

def __init__(
Expand Down Expand Up @@ -2706,6 +2709,8 @@ class UnitRegistry:
__getitem__
"""

__module__ = "brainunit"

def __init__(self):
self.units = collections.OrderedDict()
self.units_for_dimensions = collections.defaultdict(dict)
Expand Down
59 changes: 42 additions & 17 deletions brainunit/math/_compat_numpy.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,38 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from collections.abc import Sequence
from functools import wraps
from typing import (Callable, Union)
from typing import (Callable, Union, Optional)

import brainstate as bst
import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum
from braincore._common import set_module_as
from jax import lax
from brainstate._utils import set_module_as
from jax._src.numpy.lax_numpy import _einsum

from brainunit._base import (
DIMENSIONLESS,
Quantity,
fail_for_dimension_mismatch,
is_unitless,
_return_check_unitless,
get_unit,
)
from brainunit.math._utils import _compatible_with_quantity
from ._utils import _compatible_with_quantity
from .._base import (DIMENSIONLESS,
Quantity,
Unit,
fail_for_dimension_mismatch,
is_unitless,
get_unit, )
from .._base import _return_check_unitless

__all__ = [
# array creation
Expand Down Expand Up @@ -126,8 +140,12 @@
# --------------

def wrap_array_creation_function(func):
def f(*args, **kwargs):
return Quantity(func(*args, **kwargs))
def f(*args, unit: Unit = None, **kwargs):
if unit is not None:
assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}'
return func(*args, **kwargs) * unit
else:
return func(*args, **kwargs)

f.__module__ = 'brainunit.math'
return f
Expand All @@ -143,7 +161,6 @@ def f(*args, **kwargs):
empty = wrap_array_creation_function(jnp.empty)
ones = wrap_array_creation_function(jnp.ones)
zeros = wrap_array_creation_function(jnp.zeros)
array = wrap_array_creation_function(jnp.array)


@set_module_as('brainunit.math')
Expand Down Expand Up @@ -218,7 +235,12 @@ def zeros_like(a, dtype=None, shape=None):


@set_module_as('brainunit.math')
def asarray(a, dtype=None, order=None):
def asarray(
a,
dtype: Optional[bst.typing.DTypeLike] = None,
order: Optional[str] = None,
unit: Optional[Unit] = None,
):
from builtins import all as origin_all
from builtins import any as origin_any
if isinstance(a, Quantity):
Expand All @@ -238,6 +260,9 @@ def asarray(a, dtype=None, order=None):
return jnp.asarray(a, dtype=dtype, order=order)


array = asarray


@set_module_as('brainunit.math')
def arange(*args, **kwargs):
# arange has a bit of a complicated argument structure unfortunately
Expand Down Expand Up @@ -1309,7 +1334,7 @@ def einsum(
optimize: Union[str, bool] = "optimal",
precision: jax.lax.PrecisionLike = None,
preferred_element_type: Union[jax.typing.DTypeLike, None] = None,
_dot_general: Callable[..., jax.Array] = lax.dot_general,
_dot_general: Callable[..., jax.Array] = jax.lax.dot_general,
) -> Union[jax.Array, Quantity]:
operands = (subscripts, *operands)
if out is not None:
Expand Down
Loading

0 comments on commit 64283c0

Please sign in to comment.