Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ with one *slight* but **important** difference:
- Fix anchor link to JAX's documentation (by <gh-user:jeertmans>, in <gh-pr:346>).
- Simplified {func}`deepmimo.export<differt.plugins.deepmimo.export>` to reduce redundant code (by <gh-user:jeertmans>, in <gh-pr:356>).
- Changed type checker from `pyright` to `ty` (by <gh-user:jeertmans>, in <gh-pr:292>).
- Slightly improved code coverage (by <gh-user:jeertmans>, in <gh-pr:362>).

### Fixed

- Restricted `ipykernel` version to `<7` to avoid compatibility issues with `jupyter_rfb`, see <ext-gh-issue:vispy/jupyter_rfb#121> (by <gh-user:jeertmans>, in <gh-pr:347>).
- Pinned `sphinx` to `<9` to avoid breakage with `sphinx-autodoc-typehints` and the Sphinx v9 release (by <gh-user:jeertmans>, in <gh-pr:352>).
- Fixed `get` method when indexing mesh with {meth}`TriangleMesh.at<differt.geometry.TriangleMesh.at>` to **not** drop duplicate indices (by <gh-user:jeertmans>, in <gh-pr:362>).

<!-- start changelog -->

Expand Down
19 changes: 7 additions & 12 deletions differt/src/differt/geometry/_triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,21 @@
if TYPE_CHECKING or hasattr(typing, "GENERATING_DOCS"):
from typing import Self
else:
Self = Any # Because runtime type checking from 'beartype' will fail when combined with 'jaxtyping
Self = Any # Because runtime type checking from 'beartype' will fail when combined with 'jaxtyping'


class _AtIndexingKwargs(TypedDict):
indices_are_sorted: bool
unique_indices: bool
wrap_negative_indices: NotRequired[bool]
wrap_negative_indices: bool


_AT_INDEXING_KWARGS: _AtIndexingKwargs = {
"indices_are_sorted": True,
"unique_indices": True,
"wrap_negative_indices": False,
}

if jax.__version_info__ >= (0, 7, 0):
_AT_INDEXING_KWARGS["wrap_negative_indices"] = False


@jax.jit
def triangles_contain_vertices_assuming_inside_same_plane(
Expand Down Expand Up @@ -138,10 +136,6 @@ def __repr__(self) -> str:
return f"_TriangleMeshVerticesUpdateRef({self.mesh!r}, {self.index!r})"

def _triangles_index(self, **kwargs: Any) -> _Index:
if self.index == slice(None):
# TODO: check if we can use fast path but avoid updating vertices
# that are not referenced by any triangle
return self.index # Fast path
index = self.mesh.triangles.at[self.index, :].get(**kwargs).reshape(-1)
return jnp.unique(
index, size=len(index), fill_value=self.mesh.vertices.shape[0]
Expand All @@ -156,8 +150,9 @@ def set(self, values: Any, **kwargs: Any) -> _T:
)

def get(self, **kwargs: Any) -> Float[ArrayLike, "num_indexed_triangles 3"]:
index = self._triangles_index(**kwargs)
return self.mesh.vertices.at[index, :].get(**_AT_INDEXING_KWARGS)
# get() is allowed to return duplicates, so we do not use _triangles_index()
index = self.mesh.triangles.at[self.index, :].get(**kwargs).reshape(-1)
return self.mesh.vertices.at[index, :].get(wrap_negative_indices=False)

def apply(
self,
Expand Down Expand Up @@ -503,7 +498,7 @@ def bounding_box(self) -> Float[Array, "2 3"]:
@property
def at(self): # noqa: ANN202
"""
Helper property for updating a subset of triangle vertices.
Helper property for updating or indexing a subset of triangle vertices.

This ``at`` property is used to update vertices of a triangle mesh,
based on triangles indices,
Expand Down
2 changes: 1 addition & 1 deletion differt/src/differt/rt/_image_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def image_method(
mirror_vertices = jnp.asarray(mirror_vertices)
mirror_normals = jnp.asarray(mirror_normals)

if mirror_vertices.shape[0] == 0:
if mirror_vertices.shape[-2] == 0:
# If there are no mirrors, return empty array.
batch = jnp.broadcast_shapes(
from_vertices.shape[:-1],
Expand Down
43 changes: 34 additions & 9 deletions differt/tests/geometry/test_triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def test_box(
"index",
[
slice(None),
slice(0, None, 2),
jnp.arange(24),
jnp.array([0, 1, 2]),
jnp.ones(24, dtype=bool),
Expand All @@ -312,25 +313,33 @@ def test_box(
)
@pytest.mark.parametrize(
("method", "func_or_values"),
[("apply", lambda x: 1 / x), ("add", [1.0, 3.0, 6.0]), ("mul", 2.0)],
[
("set", (0,)),
("get", ()),
("apply", (lambda x: 1 / x,)),
("add", ([1.0, 3.0, 6.0],)),
("mul", (2.0,)),
],
)
def test_at_update(
self,
index: slice | Array,
method: Literal["set", "apply", "add", "mul", "get"],
func_or_values: Any,
func_or_values: tuple[Any, ...],
two_buildings_mesh: TriangleMesh,
) -> None:
got = getattr(two_buildings_mesh.at[index], method)(func_or_values)
got = getattr(two_buildings_mesh.at[index], method)(*func_or_values)

if index != slice(None):
if isinstance(index, Array) and index.dtype != jnp.bool:
index = jnp.unique(index)
index = two_buildings_mesh.triangles[index, :].reshape(-1)
if method != "get" and isinstance(index, Array) and index.dtype != jnp.bool:
# This should be a no-op, because duplicate indices are dropped before updating
index = jnp.unique(index)
index = two_buildings_mesh.triangles[index, :].reshape(-1)
if method != "get":
# Duplicate indices are dropped before updating
index = jnp.unique(index)

vertices = getattr(two_buildings_mesh.vertices.at[index, :], method)(
func_or_values
*func_or_values
)
if method == "get":
expected = vertices
Expand Down Expand Up @@ -408,6 +417,14 @@ def test_not_empty(self, two_buildings_mesh: TriangleMesh) -> None:
"other_colors",
[False, True],
)
@pytest.mark.parametrize(
"self_face_materials",
[False, True],
)
@pytest.mark.parametrize(
"other_face_materials",
[False, True],
)
@pytest.mark.parametrize(
"self_mask",
[False, True],
Expand All @@ -424,12 +441,13 @@ def test_append(
other_assume_quads: bool,
self_colors: bool,
other_colors: bool,
self_face_materials: bool,
other_face_materials: bool,
self_mask: bool,
other_mask: bool,
two_buildings_mesh: TriangleMesh,
key: PRNGKeyArray,
) -> None:
# TODO: Test merging material names.
s = (
TriangleMesh.empty() if self_empty else two_buildings_mesh
).set_assume_quads(self_assume_quads)
Expand All @@ -444,6 +462,11 @@ def test_append(
if other_colors and not other_empty:
o = o.set_face_colors(key=key_o) # type: ignore[reportCallIssue]

if self_face_materials and not self_empty:
s = s.set_materials("material_a")
if other_face_materials and not other_empty:
o = o.set_materials("material_b")

if self_mask and not self_empty:
s = eqx.tree_at(
lambda m: m.mask,
Expand Down Expand Up @@ -480,6 +503,8 @@ def test_append(
else:
assert mesh.face_colors is None

# TODO: Test merging material names and indices.

if (self_mask and not self_empty) or (other_mask and not other_empty):
assert mesh.mask is not None
else:
Expand Down
1 change: 1 addition & 0 deletions differt/tests/rt/test_image_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_image_of_vertices_with_respect_to_mirrors() -> None:
((10, 3), (1, 3), (1, 3), does_not_raise()),
((10, 3), (10, 1, 3), (10, 1, 3), does_not_raise()),
((10, 3), (10, 1, 3), (1, 1, 3), does_not_raise()),
((0, 3), (10, 0, 3), (1, 0, 3), does_not_raise()),
((1, 3), (10, 1, 3), (1, 1, 3), does_not_raise()),
((3,), (1, 3), (1, 3), does_not_raise()),
pytest.param(
Expand Down
38 changes: 27 additions & 11 deletions differt/tests/scene/test_triangle_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,21 @@ def test_compute_paths_on_simple_street_canyon(

@pytest.mark.xfail(reason="Not yet (correctly) implemented.")
@pytest.mark.parametrize("order", [0, 1, 2, 3])
@pytest.mark.parametrize(
"method",
[
"exhaustive",
"sbr",
"hybrid",
],
)
@pytest.mark.parametrize("chunk_size", [None, 1000])
@pytest.mark.parametrize("assume_quads", [False, True])
@pytest.mark.parametrize("mesh_mask", [False, True])
def test_compute_paths_with_smoothing(
self,
order: int | None,
method: Literal["exhaustive", "sbr", "hybrid"],
chunk_size: int | None,
assume_quads: bool,
mesh_mask: bool,
Expand All @@ -317,18 +326,27 @@ def test_compute_paths_with_smoothing(
is_leaf=lambda x: x is None,
)

expected = scene.compute_paths(
expected = scene.compute_paths( # ty: ignore[no-matching-overload]
order=order,
chunk_size=chunk_size,
method="exhaustive",
method=method,
)

got = scene.compute_paths(
order=order,
chunk_size=chunk_size,
method="exhaustive",
smoothing_factor=1000.0,
)
if method != "exhaustive":
expectation = pytest.warns(
UserWarning,
match="Argument 'smoothing' is currently ignored when 'method' is not set to 'exhaustive'",
)
else:
expectation = does_not_raise()

with expectation:
got = scene.compute_paths( # ty: ignore[no-matching-overload]
order=order,
method=method,
chunk_size=chunk_size,
smoothing_factor=1000.0,
)

assert type(got) is type(expected)

Expand Down Expand Up @@ -473,9 +491,7 @@ def test_plot(
[
"exhaustive",
"sbr",
pytest.param(
"hybrid", marks=pytest.mark.xfail(reason="Not yet implemented.")
),
"hybrid",
],
)
def test_compute_paths_with_mesh_mask_matches_sub_mesh_without_mask(
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ skip = "docs/source/conf.py,docs/source/references.bib,pyproject.toml,uv.lock"
exclude_lines = [
'pragma: no cover',
'raise NotImplementedError',
'if TYPE_CHECKING:',
'if typing.TYPE_CHECKING:',
'if TYPE_CHECKING',
'if typing.TYPE_CHECKING',
]
precision = 2

Expand Down
Loading