From 0273c626f6d9c3b5c6262f9f0566a8138d4a6d20 Mon Sep 17 00:00:00 2001 From: sunnycase Date: Thu, 6 Mar 2025 05:07:49 +0000 Subject: [PATCH] Fix type infer of Expand --- src/Nncase.Core/Evaluator/TypeInference.cs | 22 ++++++++++++++++++++++ src/Nncase.Evaluator/Tensors/Expand.cs | 8 +++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/Nncase.Core/Evaluator/TypeInference.cs b/src/Nncase.Core/Evaluator/TypeInference.cs index 6086cf787..bf0bd7166 100644 --- a/src/Nncase.Core/Evaluator/TypeInference.cs +++ b/src/Nncase.Core/Evaluator/TypeInference.cs @@ -590,4 +590,26 @@ public static IRType[] BroadcastDistributeTypes(params IRType[] types) return types; } + + public static Shape ExpandShape(Shape inShape, Shape expandShape) + { + if (inShape.IsUnranked || expandShape.IsUnranked) + { + return Shape.Unranked; + } + + var dimExtends = expandShape.Rank - inShape.Rank; + var newDims = expandShape.ToArray(); + + // dimsExtends may be negative + for (int i = Math.Max(0, dimExtends); i < newDims.Length; i++) + { + var inDimIndex = i - dimExtends; + ref var dimValue = ref newDims[i]; + dimValue = Dimension.Select(dimValue, 1L, inShape[inDimIndex], dimValue); + } + + newDims = inShape.Take(dimExtends < 0 ? -dimExtends : 0).Concat(newDims).ToArray(); + return new Shape(newDims); + } } diff --git a/src/Nncase.Evaluator/Tensors/Expand.cs b/src/Nncase.Evaluator/Tensors/Expand.cs index cfc430fef..e3cc94274 100644 --- a/src/Nncase.Evaluator/Tensors/Expand.cs +++ b/src/Nncase.Evaluator/Tensors/Expand.cs @@ -62,15 +62,17 @@ public IRType Visit(ITypeInferenceContext context, Expand target) private IRType Visit(ITypeInferenceContext context, Expand target, TensorType input, TensorType shape) { var shapeExpr = context.GetArgument(target, Expand.Shape); - return input with { Shape = Shape.FromExpr(shapeExpr) }; + var newShape = TypeInference.ExpandShape(input.Shape, Shape.FromExpr(shapeExpr)); + return input with { Shape = newShape }; } private IRType Visit(ITypeInferenceContext context, Expand target, DistributedType input, TensorType shape) { var invalid = new InvalidType(input.ToString()); - var newShape = Shape.FromExpr(context.GetArgument(target, Expand.Shape)); - if (newShape.IsRanked) + var shapeExpr = Shape.FromExpr(context.GetArgument(target, Expand.Shape)); + if (input.TensorType.Shape.IsRanked && shapeExpr.IsRanked) { + var newShape = TypeInference.ExpandShape(input.TensorType.Shape, shapeExpr); var ndsbp = new SBP[input.Placement.Rank]; for (int i = 0; i < input.Placement.Rank; i++) {