Skip to content

Commit

Permalink
Fix type infer of Expand
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Mar 6, 2025
1 parent 73dcee7 commit 0273c62
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
22 changes: 22 additions & 0 deletions src/Nncase.Core/Evaluator/TypeInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -590,4 +590,26 @@ public static IRType[] BroadcastDistributeTypes(params IRType[] types)

return types;
}

public static Shape ExpandShape(Shape inShape, Shape expandShape)
{
if (inShape.IsUnranked || expandShape.IsUnranked)
{
return Shape.Unranked;
}

var dimExtends = expandShape.Rank - inShape.Rank;
var newDims = expandShape.ToArray();

// dimsExtends may be negative
for (int i = Math.Max(0, dimExtends); i < newDims.Length; i++)
{
var inDimIndex = i - dimExtends;
ref var dimValue = ref newDims[i];
dimValue = Dimension.Select(dimValue, 1L, inShape[inDimIndex], dimValue);
}

newDims = inShape.Take(dimExtends < 0 ? -dimExtends : 0).Concat(newDims).ToArray();
return new Shape(newDims);
}
}
8 changes: 5 additions & 3 deletions src/Nncase.Evaluator/Tensors/Expand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,17 @@ public IRType Visit(ITypeInferenceContext context, Expand target)
private IRType Visit(ITypeInferenceContext context, Expand target, TensorType input, TensorType shape)
{
var shapeExpr = context.GetArgument(target, Expand.Shape);
return input with { Shape = Shape.FromExpr(shapeExpr) };
var newShape = TypeInference.ExpandShape(input.Shape, Shape.FromExpr(shapeExpr));
return input with { Shape = newShape };
}

private IRType Visit(ITypeInferenceContext context, Expand target, DistributedType input, TensorType shape)
{
var invalid = new InvalidType(input.ToString());
var newShape = Shape.FromExpr(context.GetArgument(target, Expand.Shape));
if (newShape.IsRanked)
var shapeExpr = Shape.FromExpr(context.GetArgument(target, Expand.Shape));
if (input.TensorType.Shape.IsRanked && shapeExpr.IsRanked)
{
var newShape = TypeInference.ExpandShape(input.TensorType.Shape, shapeExpr);
var ndsbp = new SBP[input.Placement.Rank];
for (int i = 0; i < input.Placement.Rank; i++)
{
Expand Down

0 comments on commit 0273c62

Please sign in to comment.