Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22,244 changes: 6,606 additions & 15,638 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
"obj_array.ObjectArray1D": "obj:pytools.obj_array.ObjectArray1D",
# sympy
"sp.Matrix": "class:sympy.matrices.dense.DenseMatrix",
"sym.Expr": "class:sympy.core.expr.Expr",
"sym.Symbol": "class:sympy.core.symbol.Symbol",
"sym.Matrix": "class:sympy.matrices.dense.DenseMatrix",
# pytools
"ObjectArray1D": "obj:pytools.obj_array.ObjectArray1D",
# pymbolic
Expand Down
37 changes: 25 additions & 12 deletions sumpy/assignment_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@
THE SOFTWARE.
"""


import logging
from typing import TYPE_CHECKING

from typing_extensions import override

import sumpy.symbolic as sym


if TYPE_CHECKING:
from collections.abc import Sequence


logger = logging.getLogger(__name__)

__doc__ = """
Expand Down Expand Up @@ -100,7 +106,12 @@ class SymbolicAssignmentCollection:
by this class, but not expressions using names defined in this collection.
"""

def __init__(self, assignments=None):
assignments: dict[str, sym.Expr]
reversed_assignments: dict[sym.Expr, str]
symbol_generator: _SymbolGenerator
all_dependencies_cache: dict[str, set[sym.Symbol]]

def __init__(self, assignments: dict[str, sym.Expr] | None = None):
"""
:arg assignments: mapping from *var_name* to expression
"""
Expand All @@ -114,12 +125,13 @@ def __init__(self, assignments=None):
self.symbol_generator = _SymbolGenerator(self.assignments)
self.all_dependencies_cache = {}

@override
def __str__(self):
return "\n".join(
f"{name} <- {expr}"
for name, expr in self.assignments.items())

def get_all_dependencies(self, var_name):
def get_all_dependencies(self, var_name: str):
"""Including recursive dependencies."""
try:
return self.all_dependencies_cache[var_name]
Expand All @@ -129,7 +141,7 @@ def get_all_dependencies(self, var_name):
if var_name not in self.assignments:
return set()

result = set()
result: set[sym.Symbol] = set()
for dep in self.assignments[var_name].atoms():
if not isinstance(dep, sym.Symbol):
continue
Expand All @@ -143,13 +155,14 @@ def get_all_dependencies(self, var_name):
self.all_dependencies_cache[var_name] = result
return result

def add_assignment(self, name, expr, root_name=None, wrt_set=None,
retain_name=True):
def add_assignment(self,
name: str,
expr: sym.Expr,
root_name: str | None = None,
retain_name: bool = True):
assert isinstance(name, str)
assert name not in self.assignments

if wrt_set is None:
wrt_set = frozenset()
if root_name is None:
root_name = name

Expand All @@ -163,23 +176,23 @@ def add_assignment(self, name, expr, root_name=None, wrt_set=None,

return name

def assign_unique(self, name_base, expr):
def assign_unique(self, name_base: str, expr: sym.Expr):
"""Assign *expr* to a new variable whose name is based on *name_base*.
Return the new variable name.
"""
new_name = self.symbol_generator(name_base).name

return self.add_assignment(new_name, expr)

def assign_temp(self, name_base, expr):
def assign_temp(self, name_base: str, expr: sym.Expr):
"""If *expr* is mapped to a existing variable, then return the existing
variable or assign *expr* to a new variable whose name is based on
*name_base*. Return the variable name *expr* is mapped to in either case.
"""
new_name = self.symbol_generator(name_base).name
return self.add_assignment(new_name, expr, retain_name=False)

def run_global_cse(self, extra_exprs=None):
def run_global_cse(self, extra_exprs: Sequence[sym.Expr] | None = None):
if extra_exprs is None:
extra_exprs = []

Expand All @@ -199,7 +212,7 @@ def run_global_cse(self, extra_exprs=None):
# from sumpy.symbolic import checked_cse

from sumpy.cse import cse
new_assignments, new_exprs = cse(assign_exprs + extra_exprs,
new_assignments, new_exprs = cse([*assign_exprs, *extra_exprs],
symbols=self.symbol_generator)

new_assign_exprs = new_exprs[:len(assign_exprs)]
Expand Down
17 changes: 9 additions & 8 deletions sumpy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,16 +510,17 @@ def map_constant(self, expr, *args):

# {{{ convert complex to np.complex

class ComplexRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
class ComplexRewriter(CSECachingIdentityMapper[[]], CallExternalRecMapper):
complex_dtype: np.dtype[np.complexfloating] | None

def __init__(self, complex_dtype=None):
def __init__(self, complex_dtype: np.dtype[np.complexfloating] | None = None):
super().__init__()
self.complex_dtype = complex_dtype

def map_constant(self, expr, rec_self=None):
def map_constant(self, expr: object, rec_self=None):
"""Convert complex values to numpy types
"""
if not isinstance(expr, complex | np.complex64 | np.complex128):
if not isinstance(expr, (complex, np.complex64, np.complex128)):
return IdentityMapper.map_constant(rec_self or self, expr,
rec_self=rec_self)

Expand All @@ -544,7 +545,7 @@ def map_constant(self, expr, rec_self=None):
INDEXED_VAR_RE = re.compile(r"^([a-zA-Z_]+)([0-9]+)$")


class VectorComponentRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
class VectorComponentRewriter(CSECachingIdentityMapper[[]], CallExternalRecMapper):
"""For names in name_whitelist, turn ``a3`` into ``a[3]``."""

name_whitelist: frozenset[str]
Expand Down Expand Up @@ -574,7 +575,7 @@ def map_variable(self, expr, *args):

# {{{ sum sign grouper

class SumSignGrouper(CSECachingIdentityMapper, CallExternalRecMapper):
class SumSignGrouper(CSECachingIdentityMapper[[]], CallExternalRecMapper):
"""Anti-cancellation cargo-cultism."""

def map_sum(self, expr, *args):
Expand Down Expand Up @@ -613,7 +614,7 @@ def map_sum(self, expr, *args):
# }}}


class MathConstantRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
class MathConstantRewriter(CSECachingIdentityMapper[[]], CallExternalRecMapper):
def map_variable(self, expr, *args):
if expr.name == "pi":
return prim.Variable("M_PI")
Expand All @@ -625,7 +626,7 @@ def map_variable(self, expr, *args):

# {{{ combine mappers

def combine_mappers(*mappers):
def combine_mappers(*mappers: CallExternalRecMapper):
"""Returns a mapper that combines the work of several other mappers. For
this to work, the mappers need to be instances of
:class:`sumpy.codegen.CallExternalRecMapper`. When calling parent class
Expand Down
Loading
Loading