Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a07bdbc
first version, really caothic, and doesn't work with defslot powers
Bumblebee00 Jun 14, 2025
12843da
second version, really caothic, but works with defslotpowers
Bumblebee00 Jun 14, 2025
d81145e
fix typo
Bumblebee00 Jun 18, 2025
79118fc
operation + and * are always commutative now
Bumblebee00 Jun 18, 2025
cd0cc33
added some tests of commutative operations
Bumblebee00 Jun 18, 2025
bd06d79
fixed bug on defslot functionality
Bumblebee00 Jun 19, 2025
a1da82d
added defslot on operations with multiple arguments
Bumblebee00 Jun 19, 2025
7849e7a
moved the commutativiry checks to only acrule macro
Bumblebee00 Jun 19, 2025
a7d57e9
negative exponent feature is done in a different way, more clean
Bumblebee00 Jun 20, 2025
50b5e50
fixed failing ci tests
Bumblebee00 Jun 20, 2025
3bd1282
added tests with deflost in operation call with more than two arguments
Bumblebee00 Jun 20, 2025
6825df3
now rationals can be used in rules
Bumblebee00 Jun 21, 2025
e6bce15
created smrule (sum multiplication rule) macro
Bumblebee00 Jun 22, 2025
f8c8841
enhance commutative term matcher to validate operation type
Bumblebee00 Jun 22, 2025
e742a84
fixed bug in defslot code and improved performance
Bumblebee00 Jun 22, 2025
9e4596d
improved negative exponent pattern matching. now it matches also for…
Bumblebee00 Jun 22, 2025
bdce8c4
changed order of checks in pow term matcher
Bumblebee00 Jun 24, 2025
8c8a207
added match for exp and sqrt calls
Bumblebee00 Jun 27, 2025
08e9993
removed smrule macro and added commutativity checks to the rule macro
Bumblebee00 Jun 30, 2025
80cabb1
added commutativity checks also for segment matcher
Bumblebee00 Jul 7, 2025
2dbff77
fixed predicates with defslots
Bumblebee00 Jul 7, 2025
734d1b9
now the pattern ~x^~m matches 1/x with m=-1
Bumblebee00 Aug 3, 2025
4a49b19
added tests for power match with sqrt and exp functions
Bumblebee00 Aug 6, 2025
05a5af2
refactor
Bumblebee00 Aug 6, 2025
36034e0
now ...^(1//2) matches in the rule with sqrt, and ℯ^... matches in th…
Bumblebee00 Aug 8, 2025
9142ba0
first prototype
Bumblebee00 Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 171 additions & 79 deletions src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
# 3. Callback: takes arguments Dictionary × Number of elements matched
#

function matcher(val::Any)
function matcher(val::Any, acSets, condition)
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
if iscall(val)
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val))
return defslot_term_matcher_constructor(val)
# else return a normal term matcher
else
return term_matcher_constructor(val)
# just two arguments bc defslot is only supported with operations with two args: *, ^, +
if any(x -> isa(x, DefSlot), arguments(val))
return defslot_term_matcher_constructor(val, acSets, condition)
end
# else return a normal term matcher
return term_matcher_constructor(val, acSets, condition)
end

function literal_matcher(next, data, bindings)
Expand All @@ -24,7 +24,9 @@ function matcher(val::Any)
end
end

function matcher(slot::Slot)
# acSets and condition are not used here, but they need to be present in case
# matcher(::Slot) is called directly from the macro
function matcher(slot::Slot, acSets, condition)
function slot_matcher(next, data, bindings)
!islist(data) && return nothing
val = get(bindings, slot.name, nothing)
Expand All @@ -35,6 +37,7 @@ function matcher(slot::Slot)
end
# elseif the first element of data matches the slot predicate, add it to bindings and call next
elseif slot.predicate(car(data))
# println("slot of $slot matched")
next(assoc(bindings, slot.name, car(data)), 1)
end
end
Expand All @@ -43,8 +46,8 @@ end
# this is called only when defslot_term_matcher finds the operation and tries
# to match it, so no default value used. So the same function as slot_matcher
# can be used
function matcher(defslot::DefSlot)
matcher(Slot(defslot.name, defslot.predicate))
function matcher(defslot::DefSlot, acSets, condition)
matcher(Slot(defslot.name, defslot.predicate), nothing, nothing)
end

# returns n == offset, 0 if failed
Expand Down Expand Up @@ -75,7 +78,7 @@ function trymatchexpr(data, value, n)
end
end

function matcher(segment::Segment)
function matcher(segment::Segment, acSets)
function segment_matcher(success, data, bindings)
val = get(bindings, segment.name, nothing)

Expand All @@ -90,98 +93,187 @@ function matcher(segment::Segment)
for i=length(data):-1:0
subexpr = take_n(data, i)

if segment.predicate(subexpr)
res = success(assoc(bindings, segment.name, subexpr), i)
if res !== nothing
break
end
end
!segment.predicate(subexpr) && continue
res = success(assoc(bindings, segment.name, subexpr), i)
res !== nothing && break
end

return res
end
end
end

function term_matcher_constructor(term)
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
function term_matcher_constructor(term, acSets, condition)
matchers = (
matcher(operation(term), acSets, condition),
map(x->matcher(x,acSets, condition), arguments(term))...,
)

function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return bindings′
end
return nothing
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))
end
# explanation of above 3 lines:
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
# <------ next(b,n) ---------------------------->
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
# Calls the first matcher, with the "next" function being loop again but with n terms dropped from term
# Term is a linked list (a list and an index). drop_n advances the index. When the index surpasses
# the length of the list, it is considered empty
end

function term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
# if the condition throws an error, this means that not all the bindings
# have been assigned yet, so we are not at the end of the match. In that
# case we continue to the next matchers
function check_conditions(result)
result === nothing && return false
try
tmp = condition(result)
# tmp==nothing means no conditions are present
tmp===nothing && return true
return tmp
catch e
# println("condition failed, continuing")
return true
end
end

function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return success(bindings′, 1)
end
return nothing
# if the operation is a pow, we have to match also 1/(...)^(...) with negative exponent
if operation(term) === ^
function pow_term_matcher(success, data, bindings)
# println("in ^ matcher of $term with data $data")
!islist(data) && return nothing # if data is not a list, return nothing
data = car(data) # from (..., ) to ...
!iscall(data) && return nothing # if first element is not a call, return nothing

result = loop(data, bindings, matchers)
check_conditions(result) && return success(result, 1)

frankestein = nothing
if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1)
# if data is of the alternative form (1/...)^(...)
one_over_smth = arguments(data)[1]
T = symtype(one_over_smth)
frankestein = Term{T}(^, [arguments(one_over_smth)[2], -arguments(data)[2]])
elseif (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^)
# if data is of the alternative form 1/(...)^(...)
denominator = arguments(data)[2]
T = symtype(denominator)
frankestein = Term{T}(^, [arguments(denominator)[1], -arguments(denominator)[2]])
elseif (operation(data) === /) && isequal(arguments(data)[1], 1)
# if data is of the alternative form 1/(...), it might match with exponent = -1
denominator = arguments(data)[2]
T = symtype(denominator)
frankestein = Term{T}(^, [denominator, -1])
elseif operation(data)===exp
# if data is an exp call, it might match with base e
T = symtype(arguments(data)[1])
frankestein = Term{T}(^,[ℯ, arguments(data)[1]])
elseif operation(data)===sqrt
# if data is a sqrt call, it might match with exponent 1//2
T = symtype(arguments(data)[1])
frankestein = Term{T}(^,[arguments(data)[1], 1//2])
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))

if frankestein !==nothing
result = loop(frankestein, bindings, matchers)
check_conditions(result) && return success(result, 1)
end
# explanation of above 3 lines:
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
# <------ next(b,n) ---------------------------->
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
# Calls the first matcher, with the "next" function being loop again but with n terms dropped from term
# Term is a linked list (a list and an index). drop_n advances the index. When the index surpasses
# the length of the list, it is considered empty

return nothing
end
return pow_term_matcher
# if we want to do commutative checks, i.e. call matcher with different order of the arguments
elseif acSets!==nothing && operation(term) in [+, *]
function commutative_term_matcher(success, data, bindings)
# println("in +* matcher of $term with data $data")
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try

T = symtype(car(data))
if T <: Number
f = operation(car(data))
data_args = arguments(car(data))

for inds in acSets(eachindex(data_args), length(data_args))
candidate = Term{T}(f, @views data_args[inds])

loop(car(data), bindings, matchers) # Try to eat exactly one term
result = loop(candidate, bindings, matchers)
check_conditions(result) && return success(result, 1)
end
# if car(data) does not subtype to number, it might not be commutative
else
# call the normal matcher
result = loop(car(data), bindings, matchers)
check_conditions(result) && return success(result, 1)
end
return nothing
end
return commutative_term_matcher
else
function term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing

result = loop(car(data), bindings, matchers)
check_conditions(result) && return success(result, 1)
return nothing
end
return term_matcher
end
end

# creates a matcher for a term containing a defslot, such as:
# (~x + ...complicated pattern...) * ~!y
# normal part (can be a tree) operation defslot part

# defslot_term_matcher works like this:
# checks whether data starts with the default operation.
# if yes (1): continues like term_matcher
# if no checks whether data matches the normal part
# if no returns nothing, rule is not applied
# if yes (2): adds the pair (default value name, default value) to the found bindings and
# calls the success function like term_matcher would do

function defslot_term_matcher_constructor(term)
a = arguments(term) # length two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments
matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term

function defslot_term_matcher_constructor(term, acSets, condition)
a = arguments(term)
defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
defslot = a[defslot_index]
if length(a) == 2
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets, condition)
else
others = [a[i] for i in eachindex(a) if i != defslot_index]
T = symtype(term)
f = operation(term)
other_part_matcher = term_matcher_constructor(Term{T}(f, others), acSets, condition)
end

function defslot_term_matcher(success, data, bindings)
# if data is not a list, return nothing
!islist(data) && return nothing
# if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation)
if !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation)
other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part

# checks whether it matches the normal part
# <-----------------(2)------------------------------->
bindings = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)

if bindings === nothing
return nothing
end
return success(bindings, 1)
end
normal_matcher = term_matcher_constructor(term, acSets, condition)

# (1)
function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return success(bindings′, 1)
end
return nothing
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))


function defslot_term_matcher(success, data, bindings)
# println("in defslotmatcher of $term with data $data")
!islist(data) && return nothing # if data is not a list, return nothing
# call the normal matcher, with success function foo1 that simply returns the bindings
# <--foo1-->
result = normal_matcher((b,n) -> b, data, bindings)
result !== nothing && return success(result, 1)
# println("no match, trying defslot")
# if no match, try to match with a defslot.
# checks whether it matches the normal part if yes executes foo2
# foo2: adds the pair (default value name, default value) to the found bindings
# <-------------------foo2---------------------------->
result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)
result === nothing && return nothing
# println("defslot match!")
try
tmp = condition(result)
# tmp==nothing means no conditions are present
if tmp===nothing || tmp
return success(result, 1)
end
catch e
# println("condition failed, continuing")
return success(result, 1)
end

loop(car(data), bindings, matchers) # Try to eat exactly one term
end
end
end
Loading
Loading