Skip to content

Commit

Permalink
enable Quantity typing with syntax of Quantity[unit]
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jan 15, 2025
1 parent 4423578 commit e9fb3e3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
6 changes: 4 additions & 2 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps, partial
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict, cast
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict, cast, TypeVar, Generic

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -73,6 +73,8 @@
PyTree = Any
_all_slice = slice(None, None, None)
compat_with_equinox = False
A = TypeVar('A')



def compatible_with_equinox(mode: bool = True):
Expand Down Expand Up @@ -2135,7 +2137,7 @@ def _element_not_quantity(x):


@register_pytree_node_class
class Quantity:
class Quantity(Generic[A]):
"""
The `Quantity` class represents a physical quantity with a mantissa and a unit.
It is used to represent all physical quantities in ``BrainUnit``.
Expand Down
17 changes: 11 additions & 6 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import numpy as np
import pytest
from numpy.testing import assert_equal
from typing import Union

import brainunit as u
from brainunit._base import (
Expand Down Expand Up @@ -900,6 +901,16 @@ def test_to(self):
print(x.to(u.volt))
print(x.to(u.uvolt))

def test_quantity_type(self):
def f1(a: u.Quantity[u.ms]) -> u.Quantity[u.mV]:
return a

def f2(a: u.Quantity[Union[u.ms, u.mA]]) -> u.Quantity[u.mV]:
return a

def f3(a: u.Quantity[Union[u.ms, u.mA]]) -> u.Quantity[Union[u.mV, u.ms]]:
return a


class TestNumPyFunctions(unittest.TestCase):
def test_special_case_numpy_functions(self):
Expand Down Expand Up @@ -1468,12 +1479,6 @@ def test_pickle():
print(b)








def test_str_repr():
"""
Test that str representations do not raise any errors and that repr
Expand Down

0 comments on commit e9fb3e3

Please sign in to comment.