Skip to content

Commit 39662c9

Browse files
committed
Linting fixes for entire library.
1 parent 38f177c commit 39662c9

24 files changed

+695
-315
lines changed

pixi.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ sparse = "*"
2424
numba = ">=0.60"
2525
scipy = "*"
2626
numpy = "==2.*"
27+
mypy = ">=1.15.0,<2"
2728

2829
[feature.test.tasks]
2930
test = { cmd = "pytest", depends-on = ["compile"] }

pyproject.toml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,23 @@ numba = "^0.61.0"
2121
[build-system]
2222
requires = ["poetry-core>=1.0.8"]
2323
build-backend = "poetry.core.masonry.api"
24+
25+
[tool.ruff.lint]
26+
select = ["F", "E", "W", "I", "B", "UP", "YTT", "BLE", "C4", "T10", "ISC", "ICN", "PIE", "PYI", "RSE", "RET", "SIM", "PGH", "FLY", "NPY", "PERF"]
27+
28+
[tool.ruff.lint.isort.sections]
29+
numpy = ["numpy", "numpy.*", "scipy", "scipy.*"]
30+
31+
[tool.ruff.format]
32+
quote-style = "double"
33+
docstring-code-format = true
34+
35+
[tool.ruff.lint.isort]
36+
section-order = [
37+
"future",
38+
"standard-library",
39+
"first-party",
40+
"numpy",
41+
"third-party",
42+
"local-folder",
43+
]

src/finch/__init__.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,31 @@
11
from . import finch_logic
2-
from .interface import *
2+
from .interface import (
3+
compute,
4+
elementwise,
5+
expand_dims,
6+
fuse,
7+
fused,
8+
identify,
9+
lazy,
10+
multiply,
11+
permute_dims,
12+
prod,
13+
reduce,
14+
squeeze,
15+
)
16+
17+
__all__ = [
18+
"lazy",
19+
"compute",
20+
"finch_logic",
21+
"fuse",
22+
"fused",
23+
"permute_dims",
24+
"expand_dims",
25+
"squeeze",
26+
"identify",
27+
"reduce",
28+
"elementwise",
29+
"prod",
30+
"multiply",
31+
]

src/finch/algebra/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
from .algebra import *
1+
from .algebra import (
2+
element_type,
3+
fill_value,
4+
fixpoint_type,
5+
init_value,
6+
is_associative,
7+
query_property,
8+
register_property,
9+
return_type,
10+
)
211

312
__all__ = [
413
"fill_value",
@@ -9,4 +18,4 @@
918
"is_associative",
1019
"query_property",
1120
"register_property",
12-
]
21+
]

src/finch/algebra/algebra.py

Lines changed: 101 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from typing import Any, Type
2-
from collections.abc import Hashable
3-
41
"""
52
Finch performs extensive rewriting and defining of functions. The Finch
63
compiler is designed to inspect objects and functions defined by other
@@ -22,33 +19,40 @@
2219
2320
```python
2421
from finch import register_property
25-
register_property(complex, '__add__', 'is_associative', lambda obj: True)
22+
23+
register_property(complex, "__add__", "is_associative", lambda obj: True)
2624
```
2725
2826
Finch includes a convenience functions to query each property as well,
2927
for example:
3028
```python
3129
from finch import query_property
3230
from operator import add
33-
query_property(complex, '__add__', 'is_associative')
31+
32+
query_property(complex, "__add__", "is_associative")
3433
# True
3534
is_associative(add, complex, complex)
3635
# True
3736
```
3837
39-
Properties can be inherited in the same way as methods. First we check whether properties have been defined for the object itself (in the case of functions), then we check For example, if you
40-
register a property for a class, all subclasses of that class will inherit
41-
that property. This allows you to define properties for a class and have
42-
them automatically apply to all subclasses, without having to register the
43-
property for each subclass individually.
38+
Properties can be inherited in the same way as methods. First we check whether
39+
properties have been defined for the object itself (in the case of functions), then we
40+
check For example, if you register a property for a class, all subclasses of that class
41+
will inherit that property. This allows you to define properties for a class and have
42+
them automatically apply to all subclasses, without having to register the property for
43+
each subclass individually.
4444
"""
45+
4546
import operator
46-
from typing import Union
47+
from collections.abc import Hashable
48+
from typing import Any, Callable
49+
4750
import numpy as np
4851

49-
_properties = {}
52+
_properties: dict[tuple[Hashable, str, str], Any] = {}
5053

51-
def query_property(obj, attr, prop, *args):
54+
55+
def query_property(obj: Hashable, attr: str, prop: Hashable, *args: Any) -> Any:
5256
"""Queries a property of an attribute of an object or class. Properties can
5357
be overridden by calling register_property on the object or it's class.
5458
@@ -64,22 +68,28 @@ def query_property(obj, attr, prop, *args):
6468
Raises:
6569
NotImplementedError: If the property is not implemented for the given type.
6670
"""
67-
if isinstance(obj, type):
68-
T = obj
69-
else:
70-
if isinstance(obj, Hashable):
71-
if (obj, attr, prop) in _properties:
72-
return _properties[(obj, attr, prop)](obj, *args)
71+
T = obj
72+
if not isinstance(obj, Hashable):
7373
T = type(obj)
74-
while True:
75-
if (T, attr, prop) in _properties:
76-
return _properties[(T, attr, prop)](obj, *args)
77-
if T is object:
78-
break
79-
T = T.__base__
74+
to_query = {T}
75+
queried: set[type] = set()
76+
while len(to_query) != len(queried):
77+
to_query_new = to_query.copy()
78+
for o in to_query:
79+
method = _properties.get((o, attr, prop), None)
80+
if method is not None:
81+
return method(obj, *args)
82+
queried.add(o)
83+
if not isinstance(o, type):
84+
to_query_new.update(type(o))
85+
continue
86+
to_query_new.update(o.__mro__)
87+
to_query = to_query_new
88+
8089
raise NotImplementedError(f"Property {prop} not implemented for {type(obj)}")
8190

82-
def register_property(cls, attr, prop, f):
91+
92+
def register_property(cls: type, attr: str, prop: str, f: Callable) -> None:
8393
"""Registers a property for a class or object.
8494
8595
Args:
@@ -90,6 +100,7 @@ def register_property(cls, attr, prop, f):
90100
"""
91101
_properties[(cls, attr, prop)] = f
92102

103+
93104
def fill_value(arg: Any) -> Any:
94105
"""The fill value for the given argument. The fill value is the
95106
default value for a tensor when it is created with a given shape and dtype,
@@ -104,11 +115,15 @@ def fill_value(arg: Any) -> Any:
104115
Raises:
105116
NotImplementedError: If the fill value is not implemented for the given type.
106117
"""
107-
return query_property(arg, '__self__', 'fill_value')
118+
return query_property(arg, "__self__", "fill_value")
119+
120+
121+
register_property(
122+
np.ndarray, "__self__", "fill_value", lambda x: np.zeros((), dtype=x.dtype)[()]
123+
)
108124

109-
register_property(np.ndarray, '__self__', 'fill_value', lambda x: np.zeros((), dtype=x.dtype)[()])
110125

111-
def element_type(arg: Any) -> Type:
126+
def element_type(arg: Any) -> type:
112127
"""The element type of the given argument. The element type is the scalar type of
113128
the elements in a tensor, which may be different from the data type of the
114129
tensor.
@@ -122,9 +137,16 @@ def element_type(arg: Any) -> Type:
122137
Raises:
123138
NotImplementedError: If the element type is not implemented for the given type.
124139
"""
125-
return query_property(arg, '__self__', 'element_type')
140+
return query_property(arg, "__self__", "element_type")
141+
142+
143+
register_property(
144+
np.ndarray,
145+
"__self__",
146+
"element_type",
147+
lambda x: type(np.zeros((), dtype=x.dtype)[()]),
148+
)
126149

127-
register_property(np.ndarray, '__self__', 'element_type', lambda x: type(np.zeros((), dtype=x.dtype)[()]))
128150

129151
def return_type(op: Any, *args: Any) -> Any:
130152
"""The return type of the given function on the given argument types.
@@ -136,7 +158,8 @@ def return_type(op: Any, *args: Any) -> Any:
136158
Returns:
137159
The return type of op(*args: arg_types)
138160
"""
139-
return query_property(op, '__call__', 'return_type', *args)
161+
return query_property(op, "__call__", "return_type", *args)
162+
140163

141164
StableNumber = (np.number, bool, int, float, complex)
142165

@@ -158,34 +181,52 @@ def return_type(op: Any, *args: Any) -> Any:
158181
}
159182

160183
for op, (meth, rmeth) in _reflexive_operators.items():
161-
register_property(op, '__call__', 'return_type', lambda op, a, b: query_property(a, meth, 'return_type', b) if hasattr(a, meth) else query_property(b, rmeth, 'return_type', a)),
184+
(
185+
register_property(
186+
op,
187+
"__call__",
188+
"return_type",
189+
lambda op, a, b, meth=meth, rmeth=rmeth: query_property(
190+
a, meth, "return_type", b
191+
)
192+
if hasattr(a, meth)
193+
else query_property(b, rmeth, "return_type", a),
194+
),
195+
)
196+
162197
def _return_type(meth):
163198
def _return_type_closure(a, b):
164199
if issubclass(b, StableNumber):
165200
return type(getattr(a(True), meth)(b(True)))
166-
else:
167-
raise TypeError(f"Unsupported operand type for {type(a)}.{meth}: {type(b)}")
201+
raise TypeError(
202+
f"Unsupported operand type for {type(a)}.{meth}: {type(b)}"
203+
)
204+
168205
return _return_type_closure
206+
169207
for T in StableNumber:
170-
register_property(T, meth, 'return_type', _return_type(meth))
171-
register_property(T, rmeth, 'return_type', _return_type(rmeth))
208+
register_property(T, meth, "return_type", _return_type(meth))
209+
register_property(T, rmeth, "return_type", _return_type(rmeth))
210+
172211

173212
def is_associative(op: Any) -> bool:
174213
"""Returns whether the given function is associative, that is, whether the
175214
op(op(a, b), c) == op(a, op(b, c)) for all a, b, c.
176215
177216
Args:
178217
op: The function to check.
179-
218+
180219
Returns:
181220
True if the function can be proven to be associative, False otherwise.
182221
"""
183-
return query_property(op, '__call__', 'is_associative')
222+
return query_property(op, "__call__", "is_associative")
223+
184224

185225
for op in [operator.add, operator.mul, operator.and_, operator.xor, operator.or_]:
186-
register_property(op, '__call__', 'is_associative', lambda op: True)
226+
register_property(op, "__call__", "is_associative", lambda op: True)
187227

188-
def fixpoint_type(op: Any, z: Any, T: Type) -> Type:
228+
229+
def fixpoint_type(op: Any, z: Any, T: type) -> type:
189230
"""Determines the fixpoint type after repeated calling the given operation.
190231
191232
Args:
@@ -200,9 +241,12 @@ def fixpoint_type(op: Any, z: Any, T: Type) -> Type:
200241
R = type(z)
201242
while R not in S:
202243
S.add(R)
203-
R = return_type(op, type(z), T) # Assuming `op` is a callable that takes `z` and `T` as arguments
244+
R = return_type(
245+
op, type(z), T
246+
) # Assuming `op` is a callable that takes `z` and `T` as arguments
204247
return R
205248

249+
206250
def init_value(op, arg) -> Any:
207251
"""Returns the initial value for a reduction operation on the given type.
208252
@@ -214,17 +258,24 @@ def init_value(op, arg) -> Any:
214258
The initial value for the given operation and type.
215259
216260
Raises:
217-
NotImplementedError: If the initial value is not implemented for the given type and operation.
261+
NotImplementedError: If the initial value is not implemented for the given type
262+
and operation.
218263
"""
219-
return query_property(op, '__call__', 'init_value', arg)
264+
return query_property(op, "__call__", "init_value", arg)
265+
220266

221267
for op in [operator.add, operator.mul, operator.and_, operator.xor, operator.or_]:
222268
(meth, rmeth) = _reflexive_operators[op]
223-
register_property(op, '__call__', 'init_value', lambda op, arg: query_property(arg, meth, 'init_value', arg))
269+
register_property(
270+
op,
271+
"__call__",
272+
"init_value",
273+
lambda op, arg, meth=meth: query_property(arg, meth, "init_value", arg),
274+
)
224275

225276
for T in StableNumber:
226-
register_property(T, '__add__', 'init_value', lambda a, b: a(False))
227-
register_property(T, '__mul__', 'init_value', lambda a, b: a(True))
228-
register_property(T, '__and__', 'init_value', lambda a, b: a(True))
229-
register_property(T, '__xor__', 'init_value', lambda a, b: a(False))
230-
register_property(T, '__or__', 'init_value', lambda a, b: a(False))
277+
register_property(T, "__add__", "init_value", lambda a, b: a(False))
278+
register_property(T, "__mul__", "init_value", lambda a, b: a(True))
279+
register_property(T, "__and__", "init_value", lambda a, b: a(True))
280+
register_property(T, "__xor__", "init_value", lambda a, b: a(False))
281+
register_property(T, "__or__", "init_value", lambda a, b: a(False))

src/finch/autoschedule/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
Subquery,
1515
Table,
1616
)
17+
from ..symbolic import PostOrderDFS, PostWalk, PreWalk
1718
from .optimize import (
19+
lift_subqueries,
1820
optimize,
1921
propagate_map_queries,
20-
lift_subqueries,
2122
)
22-
from ..symbolic import PostOrderDFS, PostWalk, PreWalk
2323

2424
__all__ = [
2525
"Aggregate",

0 commit comments

Comments
 (0)