Skip to content

Commit

Permalink
Enable Quantity typing with syntax of Quantity[unit] (#95)
Browse files Browse the repository at this point in the history
* enable Quantity typing with syntax of `Quantity[unit]`

* fix tests
  • Loading branch information
chaoming0625 authored Jan 15, 2025
1 parent 4423578 commit 862210b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
6 changes: 4 additions & 2 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps, partial
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict, cast
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict, cast, TypeVar, Generic

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -73,6 +73,8 @@
PyTree = Any
_all_slice = slice(None, None, None)
compat_with_equinox = False
A = TypeVar('A')



def compatible_with_equinox(mode: bool = True):
Expand Down Expand Up @@ -2135,7 +2137,7 @@ def _element_not_quantity(x):


@register_pytree_node_class
class Quantity:
class Quantity(Generic[A]):
"""
The `Quantity` class represents a physical quantity with a mantissa and a unit.
It is used to represent all physical quantities in ``BrainUnit``.
Expand Down
28 changes: 18 additions & 10 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import itertools
import os
import pickle
import sys
import tempfile

os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
import itertools
import unittest
import warnings
from copy import deepcopy
from typing import Union

import brainstate as bst
import jax
Expand All @@ -48,7 +50,6 @@
)
from brainunit._unit_common import *
from brainunit._unit_shortcuts import kHz, ms, mV, nS
import pickle


class TestDimension(unittest.TestCase):
Expand Down Expand Up @@ -900,6 +901,19 @@ def test_to(self):
print(x.to(u.volt))
print(x.to(u.uvolt))

def test_quantity_type(self):

# if sys.version_info >= (3, 11):

def f1(a: u.Quantity[u.ms]) -> u.Quantity[u.mV]:
return a

def f2(a: u.Quantity[Union[u.ms, u.mA]]) -> u.Quantity[u.mV]:
return a

def f3(a: u.Quantity[Union[u.ms, u.mA]]) -> u.Quantity[Union[u.mV, u.ms]]:
return a


class TestNumPyFunctions(unittest.TestCase):
def test_special_case_numpy_functions(self):
Expand Down Expand Up @@ -1468,12 +1482,6 @@ def test_pickle():
print(b)








def test_str_repr():
"""
Test that str representations do not raise any errors and that repr
Expand Down

0 comments on commit 862210b

Please sign in to comment.