diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml new file mode 100644 index 0000000..58b29af --- /dev/null +++ b/.github/workflows/pyright.yml @@ -0,0 +1,32 @@ +name: pyright + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + pyright: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install uv + uv pip install --system -e . + uv pip install --system jax + uv pip install --system git+https://github.com/brentyi/jaxls.git + uv pip install --system git+https://github.com/brentyi/hamer_helper.git + uv pip install --system pyright + - name: Run pyright + run: | + pyright . diff --git a/3_aria_inference.py b/3_aria_inference.py index ed9c0d2..c489011 100644 --- a/3_aria_inference.py +++ b/3_aria_inference.py @@ -188,7 +188,7 @@ def main(args: Args) -> None: hamer_detections, aria_detections, points_data=points_data, - splat_path=splat_path, + splat_path=traj_paths.splat_path, floor_z=floor_z, ) while True: diff --git a/pyproject.toml b/pyproject.toml index eee4434..e94089e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "torch>2.2", "viser>=0.2.11", "typeguard", - "jaxtyping", + "jaxtyping>=0.2.29", "einops", "rotary-embedding-torch", "h5py", @@ -28,7 +28,8 @@ dependencies = [ "tensorboardX", "loguru", "projectaria-tools[all]", - "opencv-python" + "opencv-python", + "gdown", ] [tool.setuptools.package-data] diff --git a/src/egoallo/data/amass.py b/src/egoallo/data/amass.py index a3adfb3..8f8bda4 100644 --- a/src/egoallo/data/amass.py +++ b/src/egoallo/data/amass.py @@ -1,12 +1,11 @@ from pathlib import Path -from typing import Any, Literal, cast +from typing import Any, Literal, assert_never, cast import h5py import numpy as np import torch import torch.utils import torch.utils.data -from typing_extensions import assert_never from .dataclass import EgoTrainingData diff --git a/src/egoallo/data/dataclass.py b/src/egoallo/data/dataclass.py index bc79523..f552285 100644 --- a/src/egoallo/data/dataclass.py +++ b/src/egoallo/data/dataclass.py @@ -58,7 +58,7 @@ def joints_wrt_world(self) -> Tensor: @staticmethod def load_from_npz( - body_model: fncsmpl.SmplModel, + body_model: fncsmpl.SmplhModel, path: Path, include_hands: bool, ) -> EgoTrainingData: diff --git a/src/egoallo/network.py b/src/egoallo/network.py index 7574c3d..64c3150 100644 --- a/src/egoallo/network.py +++ b/src/egoallo/network.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from functools import cache, cached_property -from typing import Literal +from typing import Literal, assert_never import numpy as np import torch @@ -11,7 +11,6 @@ from loguru import logger from rotary_embedding_torch import RotaryEmbedding from torch import Tensor, nn -from typing_extensions import assert_never from .fncsmpl import SmplhModel, SmplhShapedAndPosed from .tensor_dataclass import TensorDataclass diff --git a/src/egoallo/tensor_dataclass.py b/src/egoallo/tensor_dataclass.py index 92925a7..edc438b 100644 --- a/src/egoallo/tensor_dataclass.py +++ b/src/egoallo/tensor_dataclass.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Any, Callable +from typing import Any, Callable, Self, dataclass_transform import torch -from typing_extensions import Self, dataclass_transform @dataclass_transform() diff --git a/src/egoallo/training_utils.py b/src/egoallo/training_utils.py index e17c378..d0a128d 100644 --- a/src/egoallo/training_utils.py +++ b/src/egoallo/training_utils.py @@ -8,10 +8,18 @@ import time import traceback as tb from pathlib import Path -from typing import Any, Dict, Generator, Iterable, Protocol, Sized, overload +from typing import ( + Any, + Dict, + Generator, + Iterable, + Protocol, + Sized, + get_type_hints, + overload, +) import torch -from typing_extensions import get_type_hints def flattened_hparam_dict_from_dataclass( diff --git a/src/egoallo/transforms/_base.py b/src/egoallo/transforms/_base.py index 7c321ed..8847397 100644 --- a/src/egoallo/transforms/_base.py +++ b/src/egoallo/transforms/_base.py @@ -1,10 +1,20 @@ import abc -from typing import ClassVar, Generic, Tuple, Type, TypeVar, Union, overload +from typing import ( + ClassVar, + Generic, + Self, + Tuple, + Type, + TypeVar, + Union, + final, + overload, + override, +) import numpy as onp import torch from torch import Tensor -from typing_extensions import Self, final, override GroupType = TypeVar("GroupType", bound="MatrixLieGroup") SEGroupType = TypeVar("SEGroupType", bound="SEBase") diff --git a/src/egoallo/transforms/_se3.py b/src/egoallo/transforms/_se3.py index 660871b..104fbbf 100644 --- a/src/egoallo/transforms/_se3.py +++ b/src/egoallo/transforms/_se3.py @@ -1,12 +1,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import cast +from typing import Union, cast, override import numpy as np import torch from torch import Tensor -from typing_extensions import Union, override from . import _base from ._so3 import SO3 diff --git a/src/egoallo/transforms/_so3.py b/src/egoallo/transforms/_so3.py index cf37ff8..39e1dfb 100644 --- a/src/egoallo/transforms/_so3.py +++ b/src/egoallo/transforms/_so3.py @@ -1,12 +1,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Union +from typing import Union, override import numpy as np import torch from torch import Tensor -from typing_extensions import override from . import _base from .utils import get_epsilon, register_lie_group