Skip to content

Commit

Permalink
Fix race condition of vargs pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Mar 4, 2025
1 parent d682811 commit 27849d8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/Nncase.Core/PatternMatch/VArgsPattern.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ public VArgsPattern(IEnumerable<Expr> fields, string? name)

public IEnumerator<Pattern> GetEnumerator() => Fields.GetEnumerator();

public bool MatchLeaf(ReadOnlySpan<Expr> input)
public bool MatchLeaf(ReadOnlySpan<Expr> input, out IReadOnlyList<Pattern> fields)
{
Fields = FieldsGenerator(input);
Fields = fields = FieldsGenerator(input);

if (input.Length != Fields.Count)
{
Expand Down
33 changes: 17 additions & 16 deletions src/Nncase.EGraph/PatternMatch/EGraphMatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Threading.Tasks;
using Nncase.IR;
using Nncase.Passes;
using Nncase.Utilities;

namespace Nncase.PatternMatch;

Expand Down Expand Up @@ -111,12 +112,13 @@ private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, F
var context = new MatchContext(matchScopes, pattern, expr);

if (context.HasCandidates
&& pattern.MatchLeaf(expr))
&& pattern.MatchLeaf(expr)
&& pattern.Parameters.MatchLeaf(SpanUtility.UnsafeCast<Var, Expr>(expr.Parameters), out var paramsPattern))
{
var newScopes = Visit(context.Candidates, pattern.Body, enode.Children[0]);
if (newScopes.Count > 0)
{
newScopes = Visit(newScopes, pattern.Parameters, enode.Children.Skip(1));
newScopes = Visit(newScopes, pattern.Parameters, paramsPattern, enode.Children.Skip(1));
if (newScopes.Count > 0)
{
context.NewScopes.AddRange(newScopes);
Expand Down Expand Up @@ -149,12 +151,12 @@ private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, C
if (context.HasCandidates
&& pattern.MatchLeaf(expr)
&& pattern.Target.MatchLeaf(expr.Target)
&& pattern.Arguments.MatchLeaf(expr.Arguments))
&& pattern.Arguments.MatchLeaf(expr.Arguments, out var argsPattern))
{
var newScopes = Visit(context.Candidates, pattern.Target, enode.Children[0]);
if (newScopes.Count > 0)
{
newScopes = Visit(newScopes, pattern.Arguments, enode.Children.Skip(1));
newScopes = Visit(newScopes, pattern.Arguments, argsPattern, enode.Children.Skip(1));
if (newScopes.Count > 0)
{
context.NewScopes.AddRange(newScopes);
Expand All @@ -174,15 +176,15 @@ private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, I
&& pattern.MatchLeaf(expr)
&& pattern.Then.MatchLeaf(expr.Then)
&& pattern.Else.MatchLeaf(expr.Else)
&& pattern.Arguments.MatchLeaf(expr.Arguments))
&& pattern.Arguments.MatchLeaf(expr.Arguments, out var argsPattern))
{
var newScopes = Visit(context.Candidates, pattern.Then, enode.Children[0]);
if (newScopes.Count > 0)
{
newScopes = Visit(newScopes, pattern.Else, enode.Children[1]);
if (newScopes.Count > 0)
{
newScopes = Visit(newScopes, pattern.Arguments, enode.Children.Skip(2));
newScopes = Visit(newScopes, pattern.Arguments, argsPattern, enode.Children.Skip(2));
if (newScopes.Count > 0)
{
context.NewScopes.AddRange(newScopes);
Expand Down Expand Up @@ -223,9 +225,9 @@ private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, T

if (context.HasCandidates
&& pattern.MatchLeaf(expr)
&& pattern.Fields.MatchLeaf(expr.Fields))
&& pattern.Fields.MatchLeaf(expr.Fields, out var fieldsPattern))
{
var newScopes = Visit(context.Candidates, pattern.Fields, enode.Children);
var newScopes = Visit(context.Candidates, pattern.Fields, fieldsPattern, enode.Children);
if (newScopes.Count > 0)
{
context.NewScopes.AddRange(newScopes);
Expand Down Expand Up @@ -265,18 +267,17 @@ private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, O
return context.NewScopes;
}

private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, VArgsPattern pattern, IReadOnlyList<ENode> enodes)
private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, VArgsPattern pattern, IReadOnlyList<Pattern> argsPattern, IReadOnlyList<ENode> enodes)
{
var exprs = enodes.Select(x => x.Expr).ToArray();
var context = new MatchContext(matchScopes, pattern, exprs);

if (context.HasCandidates
&& pattern.MatchLeaf(exprs))
if (context.HasCandidates)
{
IReadOnlyList<MatchScope> scopes = context.Candidates;
for (int i = 0; i < pattern.Count; i++)
for (int i = 0; i < argsPattern.Count; i++)
{
scopes = Visit(scopes, pattern[i], enodes[i]);
scopes = Visit(scopes, argsPattern[i], enodes[i]);
if (scopes.Count == 0)
{
break;
Expand All @@ -293,9 +294,9 @@ private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, V
return context.NewScopes;
}

private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, VArgsPattern pattern, IEnumerable<EClass> eClasses)
private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, VArgsPattern pattern, IReadOnlyList<Pattern> argsPattern, IEnumerable<EClass> eClasses)
{
if (pattern.Count == 0 || eClasses.Count() != pattern.Count)
if (argsPattern.Count == 0 || eClasses.Count() != argsPattern.Count)
{
return Array.Empty<MatchScope>();
}
Expand All @@ -307,7 +308,7 @@ private IReadOnlyList<MatchScope> Visit(IReadOnlyList<MatchScope> matchScopes, V
select from en in ec.Nodes
select en).CartesianProduct())
{
var scopes = Visit(matchScopes, pattern, enodes.ToList());
var scopes = Visit(matchScopes, pattern, argsPattern, enodes.ToList());
if (scopes.Count > 0)
{
newScopes.AddRange(scopes);
Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Graph/PatternMatch/Matcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ private bool VisitExprPattern(Expr expr, ExprPattern exprPattern)
private bool VisitVArgsPattern<T>(ReadOnlySpan<T> exprs, VArgsPattern vArgsPattern)
where T : Expr
{
bool isMatch = vArgsPattern.MatchLeaf(SpanUtility.UnsafeCast<T, Expr>(exprs));
bool isMatch = vArgsPattern.MatchLeaf(SpanUtility.UnsafeCast<T, Expr>(exprs), out var fieldsPattern);
if (isMatch)
{
for (int i = 0; i < exprs.Length; i++)
{
isMatch = Visit(exprs[i], vArgsPattern.Fields[i]);
isMatch = Visit(exprs[i], fieldsPattern[i]);
if (!isMatch)
{
break;
Expand Down

0 comments on commit 27849d8

Please sign in to comment.