Skip to content

TEST: Test propagate_map_queries pass #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ readme = "README.md"
packages = [{include = "finch", from = "src"}]

[tool.poetry.dependencies]
python = "^3.10"
python = "^3.11"
numpy = ">=1.19"

[tool.poetry.group.test.dependencies]
Expand Down
4 changes: 2 additions & 2 deletions src/finch/autoschedule/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .finch_logic import (
from ..finch_logic import (
Aggregate,
Alias,
Deferred,
Expand All @@ -15,7 +15,7 @@
Table,
)
from .optimize import optimize, propagate_map_queries
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk
from ..symbolic import PostOrderDFS, PostWalk, PreWalk

__all__ = [
"Aggregate",
Expand Down
2 changes: 1 addition & 1 deletion src/finch/autoschedule/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from textwrap import dedent
from typing import Any

from .finch_logic import (
from ..finch_logic import (
Alias,
Deferred,
Field,
Expand Down
2 changes: 1 addition & 1 deletion src/finch/autoschedule/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .compiler import LogicCompiler
from .rewrite_tools import gensym
from ..symbolic import gensym


class LogicExecutor:
Expand Down
6 changes: 3 additions & 3 deletions src/finch/autoschedule/optimize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .compiler import LogicCompiler
from .finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
from .rewrite_tools import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite
from ..finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite


def optimize(prgm: LogicNode) -> LogicNode:
Expand Down Expand Up @@ -52,4 +52,4 @@ def __init__(self, ctx: LogicCompiler):

def __call__(self, prgm: LogicNode):
prgm = optimize(prgm)
return self.ctx(prgm)
return self.ctx(prgm)
8 changes: 8 additions & 0 deletions src/finch/finch_logic/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ def children(self):
"""Returns the children of the node."""
return [self.op, *self.args]

@classmethod
def make_term(cls, head, op, *args):
return head(op, args)


@dataclass(eq=True, frozen=True)
class Aggregate(LogicNode):
Expand Down Expand Up @@ -412,3 +416,7 @@ def is_stateful():
def children(self):
"""Returns the children of the node."""
return [*self.bodies]

@classmethod
def make_term(cls, head, *val):
return head(val)
2 changes: 1 addition & 1 deletion src/finch/symbolic/rewriters.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __call__(self, x: Term) -> Term | None:
new_args = list(map(self, args))
if all(arg is None for arg in new_args):
return self.rw(x)
y = x.make_term(*map(lambda x1, x2: default_rewrite(x1, x2), new_args, args))
y = x.make_term(x.head(), *map(lambda x1, x2: default_rewrite(x1, x2), new_args, args))
return default_rewrite(self.rw(y), y)
return self.rw(x)

Expand Down
7 changes: 5 additions & 2 deletions src/finch/symbolic/term.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ def is_expr(self) -> bool:

@abstractmethod
def make_term(self, head: Any, children: List[Term]) -> Term:
"""Construct a new term in the same family of terms with the given head type and children."""
"""
Construct a new term in the same family of terms with the given head type and children.
This function should satisfy `x == x.make_term(x.head(), *x.children())`
"""
pass

def __hash__(self) -> int:
Expand All @@ -52,4 +55,4 @@ def PreOrderDFS(node: Term) -> Iterator[Term]:
yield node
if node.is_expr():
for arg in node.children():
yield from PreOrderDFS(arg)
yield from PreOrderDFS(arg)
21 changes: 21 additions & 0 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from finch.autoschedule import propagate_map_queries
from finch.finch_logic import *


def test_propagate_map_queries_simple():
plan = Plan(
(
Query(Alias("A10"), Aggregate(Immediate("+"), Immediate(0), Immediate("[1,2,3]"), ())),
Query(Alias("A11"), Alias("A10")),
Produces((Alias("11"),)),
)
)
expected = Plan(
(
Query(Alias("A11"), MapJoin(Immediate("+"), (Immediate(0), Immediate("[1,2,3]")))),
Produces((Alias("11"),)),
)
)

result = propagate_map_queries(plan)
assert result == expected