diff --git a/CHANGELOG.md b/CHANGELOG.md index 6defbb42e..010a5ef2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ Arcade [PyPi Release History](https://pypi.org/project/arcade/#history) page. - Fix `UIScrollArea.add` always returning None - Support `layer` in `UIView.add_widget()` - Text objects are now lazy and can be created before the window +- Introduce `arcade.SpriteSequence[T]` as a covariant supertype of `arcade.SpriteList[T]` + (this is similar to Python's `Sequence[T]`, which is a supertype of `list[T]`) + and various improvements to the typing of the API that leverage it ## Version 3.1.0 diff --git a/arcade/__init__.py b/arcade/__init__.py index cef66f34b..97b18275f 100644 --- a/arcade/__init__.py +++ b/arcade/__init__.py @@ -168,6 +168,7 @@ def configure_logging(level: int | None = None): from .sprite import PyMunk from .sprite import PymunkMixin from .sprite import SpriteType +from .sprite import SpriteType_co from .sprite import Sprite from .sprite import BasicSprite @@ -176,6 +177,7 @@ def configure_logging(level: int | None = None): from .sprite import SpriteSolidColor from .sprite_list import SpriteList +from .sprite_list import SpriteSequence from .sprite_list import check_for_collision from .sprite_list import check_for_collision_with_list from .sprite_list import check_for_collision_with_lists @@ -283,9 +285,11 @@ def configure_logging(level: int | None = None): "BasicSprite", "Sprite", "SpriteType", + "SpriteType_co", "PymunkMixin", "SpriteCircle", "SpriteList", + "SpriteSequence", "SpriteSolidColor", "Text", "Texture", diff --git a/arcade/future/input/input_manager_example.py b/arcade/future/input/input_manager_example.py index 0fce91380..284c31c87 100644 --- a/arcade/future/input/input_manager_example.py +++ b/arcade/future/input/input_manager_example.py @@ -26,7 +26,7 @@ class Player(arcade.Sprite): def __init__( self, texture, - walls: arcade.SpriteList, + walls: arcade.SpriteSequence[arcade.BasicSprite], input_manager_template: InputManager, controller: pyglet.input.Controller | None = None, center_x: float = 0.0, @@ -76,11 +76,11 @@ def __init__( } self.players: list[Player | None] = [] - self.player_list = arcade.SpriteList() + self.player_list: arcade.SpriteList[Player] = arcade.SpriteList() self.device_labels_batch = pyglet.graphics.Batch() self.player_device_labels: list[arcade.Text | None] = [] - self.wall_list = arcade.SpriteList(use_spatial_hash=True) + self.wall_list: arcade.SpriteList[arcade.Sprite] = arcade.SpriteList(use_spatial_hash=True) for x in range(0, self.width + 64, 64): wall = arcade.Sprite(":resources:images/tiles/grassMid.png", scale=0.5) diff --git a/arcade/future/light/light_demo.py b/arcade/future/light/light_demo.py index cfdedd3dc..dee3839ce 100644 --- a/arcade/future/light/light_demo.py +++ b/arcade/future/light/light_demo.py @@ -18,7 +18,7 @@ def __init__(self, width, height, title): super().__init__(width, height, title) self.background = arcade.load_texture(":resources:images/backgrounds/abstract_1.jpg") - self.torch_list = arcade.SpriteList() + self.torch_list: arcade.SpriteList[arcade.Sprite] = arcade.SpriteList() self.torch_list.extend( [ arcade.Sprite( diff --git a/arcade/particles/emitter.py b/arcade/particles/emitter.py index ec6b24830..6067472b1 100644 --- a/arcade/particles/emitter.py +++ b/arcade/particles/emitter.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Callable, cast +from typing import Callable import arcade from arcade import Vec2 @@ -151,7 +151,7 @@ def __init__( self.particle_factory = particle_factory self._emit_done_cb = emit_done_cb self._reap_cb = reap_cb - self._particles: arcade.SpriteList = arcade.SpriteList(use_spatial_hash=False) + self._particles: arcade.SpriteList[Particle] = arcade.SpriteList(use_spatial_hash=False) def _emit(self): """ @@ -189,7 +189,7 @@ def update(self, delta_time: float = 1 / 60): for _ in range(emit_count): self._emit() self._particles.update(delta_time) - particles_to_reap = [p for p in self._particles if cast(Particle, p).can_reap()] + particles_to_reap = [p for p in self._particles if p.can_reap()] for dead_particle in particles_to_reap: dead_particle.kill() diff --git a/arcade/paths.py b/arcade/paths.py index 8c9e65c16..2990a801e 100644 --- a/arcade/paths.py +++ b/arcade/paths.py @@ -4,14 +4,22 @@ import math -from arcade import Sprite, SpriteList, check_for_collision_with_list, get_sprites_at_point +from arcade import ( + BasicSprite, + Sprite, + SpriteSequence, + check_for_collision_with_list, + get_sprites_at_point, +) from arcade.math import get_distance, lerp_2d from arcade.types import Point2 __all__ = ["AStarBarrierList", "astar_calculate_path", "has_line_of_sight"] -def _spot_is_blocked(position: Point2, moving_sprite: Sprite, blocking_sprites: SpriteList) -> bool: +def _spot_is_blocked( + position: Point2, moving_sprite: Sprite, blocking_sprites: SpriteSequence[BasicSprite] +) -> bool: """ Return if position is blocked @@ -275,7 +283,7 @@ class AStarBarrierList: def __init__( self, moving_sprite: Sprite, - blocking_sprites: SpriteList, + blocking_sprites: SpriteSequence[BasicSprite], grid_size: int, left: int, right: int, @@ -372,7 +380,7 @@ def astar_calculate_path( def has_line_of_sight( observer: Point2, target: Point2, - walls: SpriteList, + walls: SpriteSequence[BasicSprite], max_distance: float = float("inf"), check_resolution: int = 2, ) -> bool: diff --git a/arcade/physics_engines.py b/arcade/physics_engines.py index ba2f69bd4..6f2dce0b8 100644 --- a/arcade/physics_engines.py +++ b/arcade/physics_engines.py @@ -8,7 +8,7 @@ from arcade import ( BasicSprite, Sprite, - SpriteList, + SpriteSequence, SpriteType, check_for_collision, check_for_collision_with_lists, @@ -20,7 +20,7 @@ from arcade.utils import Chain, copy_dunders_unimplemented -def _wiggle_until_free(colliding: Sprite, walls: Iterable[SpriteList]) -> None: +def _wiggle_until_free(colliding: Sprite, walls: Iterable[SpriteSequence[BasicSprite]]) -> None: """Kludge to 'guess' a colliding sprite out of a collision. It works by iterating over increasing wiggle sizes of 8 points @@ -80,7 +80,7 @@ def _wiggle_until_free(colliding: Sprite, walls: Iterable[SpriteList]) -> None: def _move_sprite( - moving_sprite: Sprite, can_collide: Iterable[SpriteList[SpriteType]], ramp_up: bool + moving_sprite: Sprite, can_collide: Iterable[SpriteSequence[SpriteType]], ramp_up: bool ) -> list[SpriteType]: """Update a sprite's angle and position, returning a list of collisions. @@ -273,11 +273,14 @@ def _move_sprite( return complete_hit_list -def _add_to_list(dest: list[SpriteList], source: SpriteList | Iterable[SpriteList] | None) -> None: - """Helper function to add a SpriteList or list of SpriteLists to a list.""" +def _add_to_list( + dest: list[SpriteSequence[SpriteType]], + source: SpriteSequence[SpriteType] | Iterable[SpriteSequence[SpriteType]] | None, +) -> None: + """Helper function to add a SpriteSequence or list of SpriteSequences to a list.""" if not source: return - elif isinstance(source, SpriteList): + elif isinstance(source, SpriteSequence): dest.append(source) else: dest.extend(source) @@ -310,17 +313,17 @@ class PhysicsEngineSimple: def __init__( self, player_sprite: Sprite, - walls: SpriteList | Iterable[SpriteList] | None = None, + walls: SpriteSequence[BasicSprite] | Iterable[SpriteSequence[BasicSprite]] | None = None, ) -> None: self.player_sprite: Sprite = player_sprite """The player-controlled :py:class:`.Sprite`.""" - self._walls: list[SpriteList] = [] + self._walls: list[SpriteSequence[BasicSprite]] = [] if walls: _add_to_list(self._walls, walls) @property - def walls(self) -> list[SpriteList]: + def walls(self) -> list[SpriteSequence[BasicSprite]]: """Which :py:class:`.SpriteList` instances block player movement. .. important:: Avoid moving sprites in these lists! @@ -334,7 +337,10 @@ def walls(self) -> list[SpriteList]: return self._walls @walls.setter - def walls(self, walls: SpriteList | Iterable[SpriteList] | None = None) -> None: + def walls( + self, + walls: SpriteSequence[BasicSprite] | Iterable[SpriteSequence[BasicSprite]] | None = None, + ) -> None: if walls: _add_to_list(self._walls, walls) else: @@ -429,17 +435,17 @@ class PhysicsEnginePlatformer: def __init__( self, player_sprite: Sprite, - platforms: SpriteList | Iterable[SpriteList] | None = None, + platforms: SpriteSequence[Sprite] | Iterable[SpriteSequence[Sprite]] | None = None, gravity_constant: float = 0.5, - ladders: SpriteList | Iterable[SpriteList] | None = None, - walls: SpriteList | Iterable[SpriteList] | None = None, + ladders: SpriteSequence[BasicSprite] | Iterable[SpriteSequence[BasicSprite]] | None = None, + walls: SpriteSequence[BasicSprite] | Iterable[SpriteSequence[BasicSprite]] | None = None, ) -> None: if not isinstance(player_sprite, Sprite): raise TypeError("player_sprite must be a Sprite, not a basic_sprite!") - self._ladders: list[SpriteList] = [] - self._platforms: list[SpriteList] = [] - self._walls: list[SpriteList] = [] + self._ladders: list[SpriteSequence[BasicSprite]] = [] + self._platforms: list[SpriteSequence[Sprite]] = [] + self._walls: list[SpriteSequence[BasicSprite]] = [] self._all_obstacles = Chain(self._walls, self._platforms) _add_to_list(self._ladders, ladders) @@ -517,7 +523,7 @@ def __init__( # TODO: figure out what do do with 15_ladders_moving_platforms.py # It's no longer used by any example or tutorial file @property - def ladders(self) -> list[SpriteList]: + def ladders(self) -> list[SpriteSequence[BasicSprite]]: """Ladders turn off gravity while touched by the player. This means that whenever the :py:attr:`player_sprite` collides @@ -533,7 +539,10 @@ def ladders(self) -> list[SpriteList]: return self._ladders @ladders.setter - def ladders(self, ladders: SpriteList | Iterable[SpriteList] | None = None) -> None: + def ladders( + self, + ladders: SpriteSequence[BasicSprite] | Iterable[SpriteSequence[BasicSprite]] | None = None, + ) -> None: if ladders: _add_to_list(self._ladders, ladders) else: @@ -544,7 +553,7 @@ def ladders(self) -> None: self._ladders.clear() @property - def platforms(self) -> list[SpriteList]: + def platforms(self) -> list[SpriteSequence[Sprite]]: """:py:class:`~arcade.sprite_list.sprite_list.SpriteList` instances containing platforms. .. important:: For best performance, put non-moving terrain in @@ -575,7 +584,9 @@ def platforms(self) -> list[SpriteList]: return self._platforms @platforms.setter - def platforms(self, platforms: SpriteList | Iterable[SpriteList] | None = None) -> None: + def platforms( + self, platforms: SpriteSequence[Sprite] | Iterable[SpriteSequence[Sprite]] | None = None + ) -> None: if platforms: _add_to_list(self._platforms, platforms) else: @@ -586,7 +597,7 @@ def platforms(self) -> None: self._platforms.clear() @property - def walls(self) -> list[SpriteList]: + def walls(self) -> list[SpriteSequence[BasicSprite]]: """Exposes the :py:class:`SpriteList` instances use as terrain. .. important:: For best performance, only add non-moving sprites! @@ -611,7 +622,10 @@ def walls(self) -> list[SpriteList]: return self._walls @walls.setter - def walls(self, walls: SpriteList | Iterable[SpriteList] | None = None) -> None: + def walls( + self, + walls: SpriteSequence[BasicSprite] | Iterable[SpriteSequence[BasicSprite]] | None = None, + ) -> None: if walls: _add_to_list(self._walls, walls) else: diff --git a/arcade/sprite/__init__.py b/arcade/sprite/__init__.py index 724680a91..611ae41c0 100644 --- a/arcade/sprite/__init__.py +++ b/arcade/sprite/__init__.py @@ -4,7 +4,7 @@ from arcade.texture import Texture from arcade.resources import resolve -from .base import BasicSprite, SpriteType +from .base import BasicSprite, SpriteType, SpriteType_co from .sprite import Sprite from .mixins import PymunkMixin, PyMunk from .animated import ( @@ -69,6 +69,7 @@ def load_animated_gif(resource_name: str | Path) -> TextureAnimationSprite: __all__ = [ "SpriteType", + "SpriteType_co", "BasicSprite", "Sprite", "PyMunk", diff --git a/arcade/sprite/base.py b/arcade/sprite/base.py index d07dc7c9c..2a785e036 100644 --- a/arcade/sprite/base.py +++ b/arcade/sprite/base.py @@ -16,6 +16,9 @@ # Type from sprite that can be any BasicSprite or any subclass of BasicSprite SpriteType = TypeVar("SpriteType", bound="BasicSprite") +# Same as SpriteType, for covariant type parameters +SpriteType_co = TypeVar("SpriteType_co", bound="BasicSprite", covariant=True) + @copy_dunders_unimplemented # See https://github.com/pythonarcade/arcade/issues/2074 class BasicSprite: @@ -70,7 +73,15 @@ def __init__( self._height = height * self._scale[1] self._visible = bool(visible) self._color: Color = WHITE - self.sprite_lists: list["SpriteList"] = [] + + # In a more powerful type system, this would be typed as + # list[SpriteList[? super Self]] + # i.e., a list of SpriteList's with varying type arguments, but where + # each of those type arguments is known to be a supertype of Self. + # All changes to this list should go through the pair of methods + # register_sprite_list, _unregister_sprite_list. + # They ensure that the above typing invariant is preserved. + self.sprite_lists: list["SpriteList[Any]"] = [] """The sprite lists this sprite is a member of""" # Core properties we don't use, but spritelist expects it @@ -747,7 +758,7 @@ def update_spatial_hash(self) -> None: if sprite_list.spatial_hash is not None: sprite_list.spatial_hash.move(self) - def register_sprite_list(self, new_list: SpriteList) -> None: + def register_sprite_list(self: SpriteType, new_list: SpriteList[SpriteType]) -> None: """ Register this sprite as belonging to a list. @@ -755,13 +766,15 @@ def register_sprite_list(self, new_list: SpriteList) -> None: """ self.sprite_lists.append(new_list) + def _unregister_sprite_list(self: SpriteType, new_list: SpriteList[SpriteType]) -> None: + """Unregister this sprite as belonging to a list.""" + self.sprite_lists.remove(new_list) + def remove_from_sprite_lists(self) -> None: """Remove the sprite from all sprite lists.""" while len(self.sprite_lists) > 0: self.sprite_lists[0].remove(self) - self.sprite_lists.clear() - # ----- Drawing Methods ----- def draw_hit_box(self, color: RGBOrA255 = BLACK, line_thickness: float = 2.0) -> None: diff --git a/arcade/sprite/sprite.py b/arcade/sprite/sprite.py index 761cf8667..e67cf5795 100644 --- a/arcade/sprite/sprite.py +++ b/arcade/sprite/sprite.py @@ -1,6 +1,6 @@ import math from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import Any import arcade from arcade import Texture @@ -11,10 +11,6 @@ from .base import BasicSprite from .mixins import PymunkMixin -if TYPE_CHECKING: # handle import cycle caused by type hinting - from arcade.sprite_list import SpriteList - - __all__ = ["Sprite"] @@ -141,7 +137,6 @@ def __init__( self.physics_engines: list[Any] = [] """List of physics engines that have registered this sprite.""" - self._sprite_list: SpriteList | None = None # Debug properties self.guid: str | None = None """A unique id for debugging purposes.""" diff --git a/arcade/sprite_list/__init__.py b/arcade/sprite_list/__init__.py index 465fe90dd..f8b93309b 100644 --- a/arcade/sprite_list/__init__.py +++ b/arcade/sprite_list/__init__.py @@ -1,4 +1,4 @@ -from .sprite_list import SpriteList +from .sprite_list import SpriteList, SpriteSequence from .spatial_hash import SpatialHash from .collision import ( get_distance_between_sprites, @@ -14,6 +14,7 @@ __all__ = [ "SpriteList", + "SpriteSequence", "SpatialHash", "get_distance_between_sprites", "get_closest_sprite", diff --git a/arcade/sprite_list/collision.py b/arcade/sprite_list/collision.py index f7955f0d3..c8526d252 100644 --- a/arcade/sprite_list/collision.py +++ b/arcade/sprite_list/collision.py @@ -17,7 +17,7 @@ from arcade.types import Point from arcade.types.rect import Rect -from .sprite_list import SpriteList +from .sprite_list import SpriteSequence def get_distance_between_sprites(sprite1: SpriteType, sprite2: SpriteType) -> float: @@ -32,7 +32,7 @@ def get_distance_between_sprites(sprite1: SpriteType, sprite2: SpriteType) -> fl def get_closest_sprite( - sprite: SpriteType, sprite_list: SpriteList + sprite: BasicSprite, sprite_list: SpriteSequence[SpriteType] ) -> Tuple[SpriteType, float] | None: """ Given a Sprite and SpriteList, returns the closest sprite, and its distance. @@ -74,7 +74,7 @@ def check_for_collision(sprite1: BasicSprite, sprite2: BasicSprite) -> bool: if __debug__: if not isinstance(sprite1, BasicSprite): raise TypeError("Parameter 1 is not an instance of a Sprite class.") - if isinstance(sprite2, SpriteList): + if isinstance(sprite2, SpriteSequence): raise TypeError( "Parameter 2 is a instance of the SpriteList instead of a required " "Sprite. See if you meant to call check_for_collision_with_list instead " @@ -133,7 +133,7 @@ def _check_for_collision(sprite1: BasicSprite, sprite2: BasicSprite) -> bool: def _get_nearby_sprites( - sprite: BasicSprite, sprite_list: SpriteList[SpriteType] + sprite: BasicSprite, sprite_list: SpriteSequence[SpriteType] ) -> List[SpriteType]: sprite_count = len(sprite_list) if sprite_count == 0: @@ -186,7 +186,7 @@ def _get_nearby_sprites( def check_for_collision_with_list( sprite: BasicSprite, - sprite_list: SpriteList[SpriteType], + sprite_list: SpriteSequence[SpriteType], method: int = 0, ) -> List[SpriteType]: """ @@ -218,7 +218,7 @@ def check_for_collision_with_list( f"Parameter 1 is not an instance of the Sprite class, " f"it is an instance of {type(sprite)}." ) - if not isinstance(sprite_list, SpriteList): + if not isinstance(sprite_list, SpriteSequence): raise TypeError(f"Parameter 2 is a {type(sprite_list)} instead of expected SpriteList.") sprites_to_check: Iterable[SpriteType] @@ -246,7 +246,7 @@ def check_for_collision_with_list( def check_for_collision_with_lists( sprite: BasicSprite, - sprite_lists: Iterable[SpriteList[SpriteType]], + sprite_lists: Iterable[SpriteSequence[SpriteType]], method=1, ) -> List[SpriteType]: """ @@ -290,7 +290,7 @@ def check_for_collision_with_lists( return sprites -def get_sprites_at_point(point: Point, sprite_list: SpriteList[SpriteType]) -> List[SpriteType]: +def get_sprites_at_point(point: Point, sprite_list: SpriteSequence[SpriteType]) -> List[SpriteType]: """ Get a list of sprites at a particular point. This function sees if any sprite overlaps the specified point. If a sprite has a different center_x/center_y but touches the point, @@ -303,7 +303,7 @@ def get_sprites_at_point(point: Point, sprite_list: SpriteList[SpriteType]) -> L :returns: List of sprites colliding, or an empty list. """ if __debug__: - if not isinstance(sprite_list, SpriteList): + if not isinstance(sprite_list, SpriteSequence): raise TypeError(f"Parameter 2 is a {type(sprite_list)} instead of expected SpriteList.") sprites_to_check: Iterable[SpriteType] @@ -321,7 +321,7 @@ def get_sprites_at_point(point: Point, sprite_list: SpriteList[SpriteType]) -> L def get_sprites_at_exact_point( - point: Point, sprite_list: SpriteList[SpriteType] + point: Point, sprite_list: SpriteSequence[SpriteType] ) -> List[SpriteType]: """ Get a list of sprites whose center_x, center_y match the given point. @@ -334,7 +334,7 @@ def get_sprites_at_exact_point( List of sprites colliding, or an empty list. """ if __debug__: - if not isinstance(sprite_list, SpriteList): + if not isinstance(sprite_list, SpriteSequence): raise TypeError(f"Parameter 2 is a {type(sprite_list)} instead of expected SpriteList.") sprites_to_check: Iterable[SpriteType] @@ -349,7 +349,7 @@ def get_sprites_at_exact_point( return [s for s in sprites_to_check if s.position == point] -def get_sprites_in_rect(rect: Rect, sprite_list: SpriteList[SpriteType]) -> List[SpriteType]: +def get_sprites_in_rect(rect: Rect, sprite_list: SpriteSequence[SpriteType]) -> List[SpriteType]: """ Get a list of sprites in a particular rectangle. This function sees if any sprite overlaps the specified rectangle. If a sprite has a different @@ -365,7 +365,7 @@ def get_sprites_in_rect(rect: Rect, sprite_list: SpriteList[SpriteType]) -> List List of sprites colliding, or an empty list. """ if __debug__: - if not isinstance(sprite_list, SpriteList): + if not isinstance(sprite_list, SpriteSequence): raise TypeError(f"Parameter 2 is a {type(sprite_list)} instead of expected SpriteList.") rect_points = rect.to_points() diff --git a/arcade/sprite_list/spatial_hash.py b/arcade/sprite_list/spatial_hash.py index 53ddb3ffd..f16778d03 100644 --- a/arcade/sprite_list/spatial_hash.py +++ b/arcade/sprite_list/spatial_hash.py @@ -1,13 +1,68 @@ +from abc import abstractmethod +from collections.abc import Set from math import trunc -from typing import Generic +from typing import Protocol -from arcade.sprite import SpriteType +from arcade.sprite import SpriteType, SpriteType_co from arcade.sprite.base import BasicSprite from arcade.types import IPoint, Point from arcade.types.rect import Rect -class SpatialHash(Generic[SpriteType]): +class ReadOnlySpatialHash(Protocol[SpriteType_co]): + """A read-only view of a :py:class:`.SpatialHash` which helps preserve safety. + + This works like the read-only views of Python's built-in :py:class:`dict` + and other types. As an every-day user, it means that the underlying + `SpatialHash` may contain subclasses of the annotated type, but not + superclasses. + + This ensures predicable behavior via type safety in cases where: + + #. A spatial hash is annotated with a specific type + #. It is then manipulated outside the original context with a broader type + + Advanced users who want more information on the specifics should see the + comments of :py:class:`~arcade.sprite_list.SpriteList`. + """ + + @abstractmethod + def get_sprites_near_sprite(self, sprite: BasicSprite) -> Set[SpriteType_co]: + """ + Get all the sprites that are in the same buckets as the given sprite. + + Args: + sprite: The sprite to check + """ + ... + + @abstractmethod + def get_sprites_near_point(self, point: Point) -> Set[SpriteType_co]: + """ + Return sprites in the same bucket as the given point. + + Args: + point: The point to check + """ + ... + + @abstractmethod + def get_sprites_near_rect(self, rect: Rect) -> Set[SpriteType_co]: + """ + Return sprites in the same buckets as the given rectangle. + + .. tip:: Use :py:mod:`arcade.types.rect`'s helper functions to create + rectangle objects! + + Args: + rect: + The rectangle to check as a :py:class:`~arcade.types.rect.Rect` + object. + """ + ... + + +class SpatialHash(ReadOnlySpatialHash[SpriteType]): """A data structure best for collision checks with non-moving sprites. It subdivides space into a grid of squares, each with sides of length @@ -104,12 +159,6 @@ def remove(self, sprite: SpriteType) -> None: del self.buckets_for_sprite[sprite] def get_sprites_near_sprite(self, sprite: BasicSprite) -> set[SpriteType]: - """ - Get all the sprites that are in the same buckets as the given sprite. - - Args: - sprite: The sprite to check - """ min_point = trunc(sprite.left), trunc(sprite.bottom) max_point = trunc(sprite.right), trunc(sprite.top) @@ -126,23 +175,11 @@ def get_sprites_near_sprite(self, sprite: BasicSprite) -> set[SpriteType]: return close_by_sprites def get_sprites_near_point(self, point: Point) -> set[SpriteType]: - """ - Return sprites in the same bucket as the given point. - - Args: - point: The point to check - """ hash_point = self.hash((trunc(point[0]), trunc(point[1]))) # Return a copy of the set. return set(self.contents.setdefault(hash_point, set())) def get_sprites_near_rect(self, rect: Rect) -> set[SpriteType]: - """ - Return sprites in the same buckets as the given rectangle. - - Args: - rect: The rectangle to check (left, right, bottom, top) - """ left, right, bottom, top = rect.lrbt min_point = trunc(left), trunc(bottom) max_point = trunc(right), trunc(top) diff --git a/arcade/sprite_list/sprite_list.py b/arcade/sprite_list/sprite_list.py index 35f642a8c..8bf44a91c 100644 --- a/arcade/sprite_list/sprite_list.py +++ b/arcade/sprite_list/sprite_list.py @@ -8,6 +8,7 @@ from __future__ import annotations import random +from abc import abstractmethod from array import array from collections import deque from typing import ( @@ -15,15 +16,15 @@ Any, Callable, ClassVar, + Collection, Deque, - Generic, Iterable, Iterator, Sized, cast, ) -from arcade import Sprite, SpriteType, get_window, gl +from arcade import Sprite, SpriteType, SpriteType_co, get_window, gl from arcade.gl import Program, Texture2D from arcade.gl.buffer import Buffer from arcade.gl.types import BlendFunction, OpenGlFilter, PyGLenum @@ -39,8 +40,108 @@ _DEFAULT_CAPACITY = 100 +class SpriteSequence(Collection[SpriteType_co]): + """A read-only view of a :py:class:`.SpriteList`. + + Like other read-only generics such as :py:class:`collections.abc.Sequence`, + a `SpriteSequence` requires sprites be of a covariant type relative to their + annotated type. + + See :py:class:`.SpriteList` for more details. + """ + + from ..sprite_list import spatial_hash as sh + + @property + @abstractmethod + def spatial_hash(self) -> sh.ReadOnlySpatialHash[SpriteType_co] | None: ... + + @abstractmethod + def __getitem__(self, index: int) -> SpriteType_co: + """Return the sprite at the given index.""" + ... + + @abstractmethod + def update(self, delta_time: float = 1 / 60, *args, **kwargs) -> None: + """ + Call the update() method on each sprite in the list. + + Args: + delta_time: Time since last update in seconds + *args: Additional positional arguments + **kwargs: Additional keyword arguments + """ + ... + + @abstractmethod + def update_animation(self, delta_time: float = 1 / 60, *args, **kwargs) -> None: + """ + Call the update_animation in every sprite in the sprite list. + + Args: + delta_time: Time since last update in seconds + *args: Additional positional arguments + **kwargs: Additional keyword arguments + """ + ... + + @abstractmethod + def draw( + self, + *, + filter: PyGLenum | OpenGlFilter | None = None, + pixelated: bool | None = None, + blend_function: BlendFunction | None = None, + ) -> None: + """ + Draw this list of sprites. + + Uninitialized sprite lists will first create OpenGL resources + before drawing. This may cause a performance stutter when the + following are true: + + 1. You created the sprite list with ``lazy=True`` + 2. You did not call :py:meth:`~SpriteList.initialize` before drawing + 3. You are initializing many sprites and/or lists at once + + See :ref:`pg_spritelist_advanced_lazy_spritelists` to learn more. + + Args: + filter: + Optional parameter to set OpenGL filter, such as + `gl.GL_NEAREST` to avoid smoothing. + pixelated: + ``True`` for pixelated and ``False`` for smooth interpolation. + Shortcut for setting filter to GL_NEAREST for a pixelated look. + The filter parameter have precedence over this. + blend_function: + Optional parameter to set the OpenGL blend function used for drawing + the sprite list, such as 'arcade.Window.ctx.BLEND_ADDITIVE' or + 'arcade.Window.ctx.BLEND_DEFAULT' + """ + ... + + @abstractmethod + def draw_hit_boxes( + self, color: RGBOrA255 = (0, 0, 0, 255), line_thickness: float = 1.0 + ) -> None: + """ + Draw all the hit boxes in this list. + + .. warning:: This method is slow and should only be used for debugging. + + Args: + color: The color of the hit boxes + line_thickness: The thickness of the lines + """ + ... + + @abstractmethod + def _write_sprite_buffers_to_gpu(self) -> None: ... + + @copy_dunders_unimplemented # Temp fixes https://github.com/pythonarcade/arcade/issues/2074 -class SpriteList(Generic[SpriteType]): +class SpriteList(SpriteSequence[SpriteType]): """ The purpose of the spriteList is to batch draw a list of sprites. Drawing single sprites will not get you anywhere performance wise @@ -100,6 +201,20 @@ class SpriteList(Generic[SpriteType]): #: arcade.SpriteList.DEFAULT_TEXTURE_FILTER = gl.NEAREST, gl.NEAREST DEFAULT_TEXTURE_FILTER: ClassVar[tuple[int, int]] = gl.LINEAR, gl.LINEAR + # Declare `special_hash` as an attribute that implements the abstract + # property from `SpriteSequence`. It needs an explicit type here because + # it is better than the inherited type. + # More subtle: it requires to be initialized as a *class* attribute with + # `= None` to "delete" the abstract property definition from the class. + # Without that trick, attempt to instantiate a SpriteList results in a + # TypeError: Can't instantiate abstract class SpriteList + # without an implementation for abstract method 'spatial_hash' + # The abstract property is actually implemented as an attribute (for + # efficiency), so it is OK to silence the issue like that. + from ..sprite_list import spatial_hash as sh + + spatial_hash: sh.SpatialHash[SpriteType] | None = None + def __init__( self, use_spatial_hash: bool = False, @@ -167,7 +282,7 @@ def __init__( from .spatial_hash import SpatialHash self._spatial_hash_cell_size = spatial_hash_cell_size - self.spatial_hash: SpatialHash[SpriteType] | None = None + self.spatial_hash = None if use_spatial_hash: self.spatial_hash = SpatialHash(cell_size=self._spatial_hash_cell_size) @@ -247,7 +362,7 @@ def __len__(self) -> int: """Return the length of the sprite list.""" return len(self.sprite_list) - def __contains__(self, sprite: Sprite) -> bool: + def __contains__(self, sprite: object) -> bool: """Return if the sprite list contains the given sprite""" return sprite in self.sprite_slot @@ -269,7 +384,7 @@ def __setitem__(self, index: int, sprite: SpriteType) -> None: pass sprite_to_be_removed = self.sprite_list[index] - sprite_to_be_removed.sprite_lists.remove(self) + sprite_to_be_removed._unregister_sprite_list(self) self.sprite_list[index] = sprite # Replace sprite sprite.register_sprite_list(self) @@ -567,7 +682,7 @@ def clear(self, *, capacity: int | None = None, deep: bool = True) -> None: # Manually remove the spritelist from all sprites if deep: for sprite in self.sprite_list: - sprite.sprite_lists.remove(self) + sprite._unregister_sprite_list(self) self.sprite_list = [] self.sprite_slot = dict() @@ -626,7 +741,7 @@ def pop(self, index: int = -1) -> SpriteType: except KeyError: raise ValueError("Sprite is not in the SpriteList") - sprite.sprite_lists.remove(self) + sprite._unregister_sprite_list(self) del self.sprite_slot[sprite] self._sprite_buffer_free_slots.append(slot) @@ -715,7 +830,7 @@ def remove(self, sprite: SpriteType) -> None: index = self.sprite_list.index(sprite) self.sprite_list.pop(index) - sprite.sprite_lists.remove(self) + sprite._unregister_sprite_list(self) del self.sprite_slot[sprite] self._sprite_buffer_free_slots.append(slot) @@ -728,7 +843,7 @@ def remove(self, sprite: SpriteType) -> None: if self.spatial_hash is not None: self.spatial_hash.remove(sprite) - def extend(self, sprites: Iterable[SpriteType] | SpriteList[SpriteType]) -> None: + def extend(self, sprites: Iterable[SpriteType]) -> None: """ Extends the current list with the given iterable @@ -874,26 +989,10 @@ def _recalculate_spatial_hashes(self) -> None: self.spatial_hash.add(sprite) def update(self, delta_time: float = 1 / 60, *args, **kwargs) -> None: - """ - Call the update() method on each sprite in the list. - - Args: - delta_time: Time since last update in seconds - *args: Additional positional arguments - **kwargs: Additional keyword arguments - """ for sprite in self.sprite_list: sprite.update(delta_time, *args, **kwargs) def update_animation(self, delta_time: float = 1 / 60, *args, **kwargs) -> None: - """ - Call the update_animation in every sprite in the sprite list. - - Args: - delta_time: Time since last update in seconds - *args: Additional positional arguments - **kwargs: Additional keyword arguments - """ for sprite in self.sprite_list: sprite.update_animation(delta_time, *args, **kwargs) @@ -1009,32 +1108,6 @@ def draw( pixelated: bool | None = None, blend_function: BlendFunction | None = None, ) -> None: - """ - Draw this list of sprites. - - Uninitialized sprite lists will first create OpenGL resources - before drawing. This may cause a performance stutter when the - following are true: - - 1. You created the sprite list with ``lazy=True`` - 2. You did not call :py:meth:`~SpriteList.initialize` before drawing - 3. You are initializing many sprites and/or lists at once - - See :ref:`pg_spritelist_advanced_lazy_spritelists` to learn more. - - Args: - filter: - Optional parameter to set OpenGL filter, such as - `gl.GL_NEAREST` to avoid smoothing. - pixelated: - ``True`` for pixelated and ``False`` for smooth interpolation. - Shortcut for setting filter to GL_NEAREST for a pixelated look. - The filter parameter have precedence over this. - blend_function: - Optional parameter to set the OpenGL blend function used for drawing - the sprite list, such as 'arcade.Window.ctx.BLEND_ADDITIVE' or - 'arcade.Window.ctx.BLEND_DEFAULT' - """ if len(self.sprite_list) == 0 or not self._visible or self.alpha_normalized == 0.0: return @@ -1105,15 +1178,6 @@ def draw( def draw_hit_boxes( self, color: RGBOrA255 = (0, 0, 0, 255), line_thickness: float = 1.0 ) -> None: - """ - Draw all the hit boxes in this list. - - .. warning:: This method is slow and should only be used for debugging. - - Args: - color: The color of the hit boxes - line_thickness: The thickness of the lines - """ import arcade converted_color = Color.from_iterable(color) diff --git a/tests/unit/spritelist/test_spritesequence.py b/tests/unit/spritelist/test_spritesequence.py new file mode 100644 index 000000000..454ba7a7c --- /dev/null +++ b/tests/unit/spritelist/test_spritesequence.py @@ -0,0 +1,26 @@ +import arcade + +class _CustomSpriteSolidColor(arcade.SpriteSolidColor): + pass + +def test_collective_draw(window: arcade.Window) -> None: + sprite_list1: arcade.SpriteList[arcade.Sprite] = arcade.SpriteList() + sprite_list1.append(arcade.SpriteSolidColor(16, 16, color=(255, 0, 0, 1))) + + sprite_list2: arcade.SpriteList[_CustomSpriteSolidColor] = arcade.SpriteList() + sprite_list2.append(_CustomSpriteSolidColor(16, 16, color=(255, 0, 0, 1))) + + # It really is a SpriteList with a good type; this would not typecheck otherwise + custom_sprite: _CustomSpriteSolidColor = sprite_list2[0] # assert_type + + # Assert that SpriteSequence is truly covariant: + # It can be used as a common type for different types of SpriteLists. + scene: list[arcade.SpriteSequence[arcade.Sprite]] = [ + sprite_list1, + sprite_list2, + ] + sprite: arcade.Sprite = scene[0][0] # assert_type + + # We can collectively draw all the SpriteSequences. + for sprite_list in scene: + sprite_list.draw()