diff --git a/pyproject.toml b/pyproject.toml index a780edd..5adfce5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ readme = "README.md" packages = [{include = "finch", from = "src"}] [tool.poetry.dependencies] -python = "^3.10" +python = "^3.11" numpy = ">=1.19" [tool.poetry.group.test.dependencies] diff --git a/src/finch/autoschedule/__init__.py b/src/finch/autoschedule/__init__.py index 1836890..5e18633 100644 --- a/src/finch/autoschedule/__init__.py +++ b/src/finch/autoschedule/__init__.py @@ -1,4 +1,4 @@ -from .finch_logic import ( +from ..finch_logic import ( Aggregate, Alias, Deferred, @@ -15,7 +15,7 @@ Table, ) from .optimize import optimize, propagate_map_queries -from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk +from ..symbolic import PostOrderDFS, PostWalk, PreWalk __all__ = [ "Aggregate", diff --git a/src/finch/autoschedule/compiler.py b/src/finch/autoschedule/compiler.py index 2005183..a97f9c4 100644 --- a/src/finch/autoschedule/compiler.py +++ b/src/finch/autoschedule/compiler.py @@ -2,7 +2,7 @@ from textwrap import dedent from typing import Any -from .finch_logic import ( +from ..finch_logic import ( Alias, Deferred, Field, diff --git a/src/finch/autoschedule/executor.py b/src/finch/autoschedule/executor.py index 356765f..148aed2 100644 --- a/src/finch/autoschedule/executor.py +++ b/src/finch/autoschedule/executor.py @@ -1,5 +1,5 @@ from .compiler import LogicCompiler -from .rewrite_tools import gensym +from ..symbolic import gensym class LogicExecutor: diff --git a/src/finch/autoschedule/optimize.py b/src/finch/autoschedule/optimize.py index 982bf53..0b1e91f 100644 --- a/src/finch/autoschedule/optimize.py +++ b/src/finch/autoschedule/optimize.py @@ -1,6 +1,6 @@ from .compiler import LogicCompiler -from .finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query -from .rewrite_tools import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite +from ..finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query +from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite def optimize(prgm: LogicNode) -> LogicNode: @@ -52,4 +52,4 @@ def __init__(self, ctx: LogicCompiler): def __call__(self, prgm: LogicNode): prgm = optimize(prgm) - return self.ctx(prgm) \ No newline at end of file + return self.ctx(prgm) diff --git a/src/finch/finch_logic/nodes.py b/src/finch/finch_logic/nodes.py index eeaca95..2977679 100644 --- a/src/finch/finch_logic/nodes.py +++ b/src/finch/finch_logic/nodes.py @@ -191,6 +191,10 @@ def children(self): """Returns the children of the node.""" return [self.op, *self.args] + @classmethod + def make_term(cls, head, op, *args): + return head(op, args) + @dataclass(eq=True, frozen=True) class Aggregate(LogicNode): @@ -412,3 +416,7 @@ def is_stateful(): def children(self): """Returns the children of the node.""" return [*self.bodies] + + @classmethod + def make_term(cls, head, *val): + return head(val) diff --git a/src/finch/symbolic/rewriters.py b/src/finch/symbolic/rewriters.py index 7f8057f..8086f64 100644 --- a/src/finch/symbolic/rewriters.py +++ b/src/finch/symbolic/rewriters.py @@ -87,7 +87,7 @@ def __call__(self, x: Term) -> Term | None: new_args = list(map(self, args)) if all(arg is None for arg in new_args): return self.rw(x) - y = x.make_term(*map(lambda x1, x2: default_rewrite(x1, x2), new_args, args)) + y = x.make_term(x.head(), *map(lambda x1, x2: default_rewrite(x1, x2), new_args, args)) return default_rewrite(self.rw(y), y) return self.rw(x) diff --git a/src/finch/symbolic/term.py b/src/finch/symbolic/term.py index a908c1f..06cab10 100644 --- a/src/finch/symbolic/term.py +++ b/src/finch/symbolic/term.py @@ -30,7 +30,10 @@ def is_expr(self) -> bool: @abstractmethod def make_term(self, head: Any, children: List[Term]) -> Term: - """Construct a new term in the same family of terms with the given head type and children.""" + """ + Construct a new term in the same family of terms with the given head type and children. + This function should satisfy `x == x.make_term(x.head(), *x.children())` + """ pass def __hash__(self) -> int: @@ -52,4 +55,4 @@ def PreOrderDFS(node: Term) -> Iterator[Term]: yield node if node.is_expr(): for arg in node.children(): - yield from PreOrderDFS(arg) \ No newline at end of file + yield from PreOrderDFS(arg) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 0000000..b5bc285 --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,21 @@ +from finch.autoschedule import propagate_map_queries +from finch.finch_logic import * + + +def test_propagate_map_queries_simple(): + plan = Plan( + ( + Query(Alias("A10"), Aggregate(Immediate("+"), Immediate(0), Immediate("[1,2,3]"), ())), + Query(Alias("A11"), Alias("A10")), + Produces((Alias("11"),)), + ) + ) + expected = Plan( + ( + Query(Alias("A11"), MapJoin(Immediate("+"), (Immediate(0), Immediate("[1,2,3]")))), + Produces((Alias("11"),)), + ) + ) + + result = propagate_map_queries(plan) + assert result == expected