Skip to content

Commit

Permalink
Merge branch 'main' into autograd-docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 6, 2025
2 parents 88cf060 + 48b6ddb commit adc0eab
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,4 @@ cython_debug/
/docs/apis/changelog.md
/dist-hist/
/dist-hist/
/examples/
11 changes: 8 additions & 3 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2187,7 +2187,8 @@ def __init__(
# skip 'asarray' if dtype is not provided

elif isinstance(mantissa, (jnp.number, numbers.Number)):
mantissa = jnp.array(mantissa, dtype=dtype)
# mantissa = jnp.array(mantissa, dtype=dtype)
mantissa = mantissa

else:
mantissa = mantissa
Expand Down Expand Up @@ -3631,7 +3632,11 @@ def tolist(self):
If ``a.ndim`` is 0, then since the depth of the nested list is 0, it will
not be a list at all, but a simple Python scalar.
"""
return _replace_with_array(self.mantissa.tolist(), self.unit)
if isinstance(self.mantissa, numbers.Number):
list_mantissa = self.mantissa
else:
list_mantissa = self.mantissa.tolist()
return _replace_with_array(list_mantissa, self.unit)

def transpose(self, *axes) -> 'Quantity':
"""Returns a view of the array with axes transposed.
Expand Down Expand Up @@ -4054,7 +4059,7 @@ def get(
arguments ``indices_are_sorted`` and ``unique_indices`` to be passed.
"""
if fill_value is not None:
fill_value = Quantity(fill_value).in_unit(self.unit).mantissa.item()
fill_value = Quantity(fill_value).in_unit(self.unit).mantissa
return Quantity(
self.mantissa_at[self.index].get(
indices_are_sorted=indices_are_sorted,
Expand Down
12 changes: 6 additions & 6 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def test_display(self):
assert_equal(str(u.kmeter / u.meter), 'Unit(10.0^3)')

def test_unit_with_factor(self):
self.assertTrue(1. * u.eV / u.joule == 1.6021766e-19)
self.assertTrue(1. * u.joule / u.eV == 6.241509074460762e18)
self.assertTrue(u.math.isclose(1. * u.eV / u.joule, 1.6021766e-19))
self.assertTrue(u.math.isclose(1. * u.joule / u.eV, 6.241509074460762e18))


class TestQuantity(unittest.TestCase):
Expand Down Expand Up @@ -261,10 +261,10 @@ def test_display(self):
assert_equal(display_in_unit(10. * u.mV), '10. * mvolt')
assert_equal(display_in_unit(10. * u.ohm * u.amp), '10. * volt')
assert_equal(display_in_unit(120. * (u.mS / u.cm ** 2)), '120. * msiemens / cmeter2')
assert_equal(display_in_unit(3.0 * u.kmeter / 130.51 * u.meter), '0.02298675 * 10.0^3 * meter2')
assert_equal(display_in_unit(3.0 * u.kmeter / (130.51 * u.meter)), 'Quantity(22.986746)')
assert_equal(display_in_unit(3.0 * u.kmeter / 130.51 * u.meter * u.cm ** -2), 'Quantity(229867.45)')
assert_equal(display_in_unit(3.0 * u.kmeter / 130.51 * u.meter * u.cm ** -1), '0.02298675 * 10.0^5 * meter')
assert_equal(display_in_unit(3.0 * u.kmeter / 130.51 * u.meter), '0.02298674 * 10.0^3 * meter2')
assert_equal(display_in_unit(3.0 * u.kmeter / (130.51 * u.meter)), 'Quantity(22.986744)')
assert_equal(display_in_unit(3.0 * u.kmeter / 130.51 * u.meter * u.cm ** -2), 'Quantity(229867.44)')
assert_equal(display_in_unit(3.0 * u.kmeter / 130.51 * u.meter * u.cm ** -1), '0.02298674 * 10.0^5 * meter')
assert_equal(display_in_unit(1. * u.joule / u.kelvin), '1. * joule / kelvin')

assert_equal(str(1. * u.metre / ((3.0 * u.ms) / (1. * u.second))), '333.33334 * meter')
Expand Down
6 changes: 3 additions & 3 deletions brainunit/_celsius_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np

import brainunit as u

Expand All @@ -22,12 +22,12 @@ def test1():
assert a == 273.15 * u.kelvin

b = u.celsius2kelvin(-100)
assert b == 173.15 * u.kelvin
assert u.math.allclose(b, 173.15 * u.kelvin)


def test2():
a = u.kelvin2celsius(273.15 * u.kelvin)
assert a == 0

b = u.kelvin2celsius(173.15 * u.kelvin)
assert b == -100
assert np.isclose(b, -100)
14 changes: 14 additions & 0 deletions brainunit/autograd/_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ def hessian(
Physical unit-aware version of `jax.hessian <https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html>`_,
computing Hessian of ``fun`` as a dense array.
Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def scalar_function1(x):
... return x ** 2 + 3 * x * u.ms + 2 * u.msecond2
>>> hess_fn = u.autograd.hessian(scalar_function1)
>>> hess_fn(jnp.array(1.0) * u.ms)
[2]
>>> def scalar_function2(x):
... return x ** 3 + 3 * x * u.msecond2 + 2 * u.msecond3
>>> hess_fn = u.autograd.hessian(scalar_function2)
>>> hess_fn(jnp.array(1.0) * u.ms)
[6] * ms
Args:
fun: Function whose Hessian is to be computed. Its arguments at positions
specified by ``argnums`` should be arrays, scalars, or standard Python
Expand Down
38 changes: 38 additions & 0 deletions brainunit/autograd/_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,25 @@ def jacrev(
"""
Physical unit-aware version of `jax.jacrev <https://jax.readthedocs.io/en/latest/_autosummary/jax.jacrev.html>`_.
Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def simple_function1(x):
... return x ** 2
>>> jac_fn = u.autograd.jacrev(simple_function)
>>> jac_fn(jnp.array(3.0) * u.ms)
6.0 * ms
>>> def simple_function2(x, y):
... return x * y
>>> jac_fn = u.autograd.jacrev(simple_function2, argnums=(0, 1))
>>> x = jnp.array([3.0, 4.0]) * u.ohm
>>> y = jnp.array([5.0, 6.0]) * u.mA
>>> jac_fn(x, y)
([[5., 0.],
[0., 6.]] * mA,
[[3., 0.],
[0., 4.]] * ohm)
Args:
fun: Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers. Specifies which
Expand Down Expand Up @@ -240,6 +259,25 @@ def jacfwd(
"""
Physical unit-aware version of `jax.jacfwd <https://jax.readthedocs.io/en/latest/_autosummary/jax.jacfwd.html>`_.
Example::
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> def simple_function(x):
... return x ** 2
>>> jac_fn = u.autograd.jacfwd(simple_function)
>>> jac_fn(jnp.array(3.0) * u.ms)
6.0 * ms
>>> def simple_function(x, y):
... return x * y
>>> jac_fn = u.autograd.jacfwd(simple_function, argnums=(0, 1))
>>> x = jnp.array([3.0, 4.0]) * u.ohm
>>> y = jnp.array([5.0, 6.0]) * u.mA
>>> jac_fn(x, y)
([[5., 0.],
[0., 6.]] * mA,
[[3., 0.],
[0., 4.]] * ohm)
Args:
fun: Function whose Jacobian is to be computed.
argnums: Optional, integer or sequence of integers. Specifies which
Expand Down
6 changes: 4 additions & 2 deletions brainunit/constants_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,7 @@ def test_quantity_constants_and_unit_constants(self):
for c in constants_list:
q_c = getattr(quantity_constants, c)
u_c = getattr(unit_constants, c)
assert q_c.to_decimal(q_c.unit) == (1. * u_c).to_decimal(
q_c.unit), f"Mismatch between {c} in quantity_constants and unit_constants"
assert u.math.isclose(
q_c.to_decimal(q_c.unit),
(1. * u_c).to_decimal(q_c.unit)
), f"Mismatch between {c} in quantity_constants and unit_constants"

0 comments on commit adc0eab

Please sign in to comment.