Skip to content

Commit

Permalink
Add pyright + fix caught typos
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 19, 2024
1 parent 0c1718c commit c33ce8e
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 18 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: pyright

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
pyright:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
uv pip install --system -e .
uv pip install --system jax
uv pip install --system git+https://github.com/brentyi/jaxls.git
uv pip install --system git+https://github.com/brentyi/hamer_helper.git
uv pip install --system pyright
- name: Run pyright
run: |
pyright .
2 changes: 1 addition & 1 deletion 3_aria_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def main(args: Args) -> None:
hamer_detections,
aria_detections,
points_data=points_data,
splat_path=splat_path,
splat_path=traj_paths.splat_path,
floor_z=floor_z,
)
while True:
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"torch>2.2",
"viser>=0.2.11",
"typeguard",
"jaxtyping",
"jaxtyping>=0.2.29",
"einops",
"rotary-embedding-torch",
"h5py",
Expand All @@ -28,7 +28,8 @@ dependencies = [
"tensorboardX",
"loguru",
"projectaria-tools[all]",
"opencv-python"
"opencv-python",
"gdown",
]

[tool.setuptools.package-data]
Expand Down
3 changes: 1 addition & 2 deletions src/egoallo/data/amass.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from pathlib import Path
from typing import Any, Literal, cast
from typing import Any, Literal, assert_never, cast

import h5py
import numpy as np
import torch
import torch.utils
import torch.utils.data
from typing_extensions import assert_never

from .dataclass import EgoTrainingData

Expand Down
2 changes: 1 addition & 1 deletion src/egoallo/data/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def joints_wrt_world(self) -> Tensor:

@staticmethod
def load_from_npz(
body_model: fncsmpl.SmplModel,
body_model: fncsmpl.SmplhModel,
path: Path,
include_hands: bool,
) -> EgoTrainingData:
Expand Down
3 changes: 1 addition & 2 deletions src/egoallo/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from functools import cache, cached_property
from typing import Literal
from typing import Literal, assert_never

import numpy as np
import torch
Expand All @@ -11,7 +11,6 @@
from loguru import logger
from rotary_embedding_torch import RotaryEmbedding
from torch import Tensor, nn
from typing_extensions import assert_never

from .fncsmpl import SmplhModel, SmplhShapedAndPosed
from .tensor_dataclass import TensorDataclass
Expand Down
3 changes: 1 addition & 2 deletions src/egoallo/tensor_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import dataclasses
from typing import Any, Callable
from typing import Any, Callable, Self, dataclass_transform

import torch
from typing_extensions import Self, dataclass_transform


@dataclass_transform()
Expand Down
12 changes: 10 additions & 2 deletions src/egoallo/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,18 @@
import time
import traceback as tb
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, Protocol, Sized, overload
from typing import (
Any,
Dict,
Generator,
Iterable,
Protocol,
Sized,
get_type_hints,
overload,
)

import torch
from typing_extensions import get_type_hints


def flattened_hparam_dict_from_dataclass(
Expand Down
14 changes: 12 additions & 2 deletions src/egoallo/transforms/_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import abc
from typing import ClassVar, Generic, Tuple, Type, TypeVar, Union, overload
from typing import (
ClassVar,
Generic,
Self,
Tuple,
Type,
TypeVar,
Union,
final,
overload,
override,
)

import numpy as onp
import torch
from torch import Tensor
from typing_extensions import Self, final, override

GroupType = TypeVar("GroupType", bound="MatrixLieGroup")
SEGroupType = TypeVar("SEGroupType", bound="SEBase")
Expand Down
3 changes: 1 addition & 2 deletions src/egoallo/transforms/_se3.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import cast
from typing import Union, cast, override

import numpy as np
import torch
from torch import Tensor
from typing_extensions import Union, override

from . import _base
from ._so3 import SO3
Expand Down
3 changes: 1 addition & 2 deletions src/egoallo/transforms/_so3.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Union
from typing import Union, override

import numpy as np
import torch
from torch import Tensor
from typing_extensions import override

from . import _base
from .utils import get_epsilon, register_lie_group
Expand Down

0 comments on commit c33ce8e

Please sign in to comment.