Skip to content

Commit

Permalink
Fix UT
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Mar 3, 2025
1 parent 1b9801f commit 2505bbb
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 220 deletions.
124 changes: 0 additions & 124 deletions src/Native/src/kernels/stackvm/shape_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,6 @@ using namespace nncase::kernels::stackvm;
using namespace nncase::runtime;
using namespace nncase::runtime::stackvm;

// Shape inference for Conv2D: computes the 4-D output shape from the input
// and weights shapes, strides, paddings and dilations, and writes it into
// `output` as four int64 values. `groups` does not affect the inferred
// shape here and is unused.
result<value_t> nncase::kernels::stackvm::conv2d_shape(
    value_t input, value_t weights, value_t padding, value_t stride,
    value_t dilation, [[maybe_unused]] value_t groups, value_t output,
    kernel_context &) {
    try_dims(in_shape, input);
    try_dims(w_shape, weights);
    try_strides(strides_value, stride);
    try_paddings(pads, padding);
    // NOTE(review): the original parsed `stride` a second time into an
    // unused local `strides`; that redundant binding has been removed.
    try_strides(dilations, dilation);
    try_output(out_mem, output, dt_int64, dims_t{4});
    auto out_shape =
        conv2d_infer_shape(in_shape, w_shape, strides_value, dilations, pads);
    for (int i = 0; i < 4; ++i) {
        OUT_CAST(int64_t, out_mem)[i] = out_shape[i];
    }
    KERNEL_FINISH;
}

size_t compute_out_size(int input_size, int weights_size,
const strides_t &strides, dims_t out_paddings,
paddings_t paddings, const strides_t &dilations,
Expand Down Expand Up @@ -56,73 +37,12 @@ dims_t conv2d_transpose_infer_shape(std::span<const size_t> in_shape,
return out_shape;
}

// Shape inference for Conv2DTranspose: computes the 4-D output shape from
// the input/weights shapes, strides, paddings, output paddings, dilations
// and group count, and writes it into `output` as four int64 values.
result<value_t> nncase::kernels::stackvm::conv2d_transpose_shape(
    value_t input, value_t weights, value_t stride, value_t dilation,
    value_t padding, value_t output_padding, value_t groups, value_t output,
    kernel_context &) {
    try_dims(input_shape, input);
    try_dims(weights_shape, weights);
    // NOTE(review): the original parsed `stride` twice (an unused
    // `strides_value` plus `strides`); the unused binding has been removed.
    try_paddings(pads, padding);
    try_dims(out_padding, output_padding);
    try_to_integer(groups_value, groups);
    try_strides(strides, stride);
    try_strides(dilations, dilation);

    auto out_shape =
        conv2d_transpose_infer_shape(input_shape, weights_shape, strides, pads,
                                     out_padding, dilations, groups_value);
    try_output(out_mem, output, dt_int64, dims_t{4});
    for (int i = 0; i < 4; ++i) {
        OUT_CAST(int64_t, out_mem)[i] = out_shape[i];
    }
    KERNEL_FINISH;
}

// Reads a shape tensor and returns its contents as a dims_t, propagating
// any extraction failure through the result<> channel.
result<dims_t> to_dims(tensor shape) {
    try_dims(dims, shape);
    return ok(dims);
}

// Shape inference for Broadcast over a tuple of shape tensors: folds the
// pairwise binary-broadcast rule across all tuple fields and writes the
// combined shape into `output` as int64 values.
result<value_t> nncase::kernels::stackvm::broadcast_shape(value_t inputs,
                                                          value_t output,
                                                          kernel_context &) {
    try_tuple_input(tuple_mem, inputs);
    auto begin = inputs_tuple->fields().begin();
    // NOTE(review): the .unwrap() calls below abort instead of propagating
    // an error result when a field is not a valid shape tensor — kept as-is
    // to preserve behavior; consider converting to try_var if this path can
    // ever see malformed inputs.
    auto out_shape = std::accumulate(
        std::next(begin), inputs_tuple->fields().end(),
        to_dims(begin->as<tensor>().unwrap()).unwrap(),
        [&](auto sum, auto field) {
            auto shape = to_dims(field.template as<tensor>().unwrap()).unwrap();
            auto result = kernels::detail::get_binary_output_shape(shape, sum);

            return dims_t(result.begin(), result.end());
        });
    try_output(out_mem, output, dt_int64, dims_t{out_shape.size()});
    // size_t index avoids the signed/unsigned comparison the original
    // `int i < out_shape.size()` loop produced.
    for (size_t i = 0; i < out_shape.size(); ++i) {
        OUT_CAST(int64_t, out_mem)[i] = out_shape[i];
    }

    KERNEL_FINISH;
}

// Writes `out_shape` into the kernel output `output` as a rank-1 int64
// tensor. Must be expanded where `output` (value_t) and `out_shape` (an
// indexable dims container) are in scope; declares `out_mem` at the
// expansion site via try_output, which returns early on failure.
// Loop index is size_t to avoid a signed/unsigned comparison with size().
#define WRITE_OUT_SHAPE                                                        \
    try_output(out_mem, output, dt_int64, dims_t{out_shape.size()});           \
    for (size_t i = 0; i < out_shape.size(); ++i) {                            \
        OUT_CAST(int64_t, out_mem)[i] = out_shape[i];                          \
    }

// Shape inference for MatMul: derives the output shape from the lhs/rhs
// dimension vectors via matmul_infer_shape and writes it into `output`
// as int64 values (WRITE_OUT_SHAPE allocates the int64 output buffer).
result<value_t> nncase::kernels::stackvm::mat_mul_shape(value_t lhs,
                                                        value_t rhs,
                                                        value_t output,
                                                        kernel_context &) {
    try_dims(lhs_shape, lhs);
    try_dims(rhs_shape, rhs);
    // try_var propagates any inference error (e.g. incompatible shapes).
    try_var(out_shape, matmul_infer_shape(lhs_shape, rhs_shape));
    WRITE_OUT_SHAPE;
    KERNEL_FINISH;
}

inline int get_windowed_output_size(int size, int filter, int stride,
int dilation, bool same, bool ceilMode) {
auto effectiveFilterSize = ((filter - 1) * dilation) + 1;
Expand Down Expand Up @@ -177,47 +97,3 @@ result<value_t> nncase::kernels::stackvm::get_paddings(
OUT_CAST(int64_t, output_mem)[3] = pad_w.after;
KERNEL_FINISH;
}

// Shape inference for Reshape: resolves the requested `shape` (read via
// try_axes — presumably allowing negative/wildcard entries; confirm against
// reshape_shape_infer) against the input shape and writes the result into
// `output` as int64 values.
result<value_t> nncase::kernels::stackvm::reshape_shape(value_t input_shape,
                                                        value_t shape,
                                                        value_t output,
                                                        kernel_context &) {
    try_dims(in_shape, input_shape);
    try_axes(shape_value, shape);
    auto out_shape = reshape_shape_infer(in_shape, shape_value);
    WRITE_OUT_SHAPE;
    KERNEL_FINISH;
}

// Shape inference for Transpose: permutes the input dimensions according
// to `perm` and writes the resulting shape into `output` as int64 values.
// (`out_shape` keeps its name — WRITE_OUT_SHAPE refers to it.)
result<value_t>
nncase::kernels::stackvm::transpose_shape(value_t input_shape, value_t perm,
                                          value_t output,
                                          [[maybe_unused]] kernel_context &) {
    try_dims(src_dims, input_shape);
    try_dims(perm_dims, perm);
    auto out_shape = transpose_infer_shape(src_dims, perm_dims);
    WRITE_OUT_SHAPE;
    KERNEL_FINISH;
}

// Shape inference for Squeeze: normalizes the axes in `dim` against the
// input rank (try_positive_axes), drops them via squeeze_infer_shape, and
// writes the reduced shape into `output` as int64 values.
result<value_t>
nncase::kernels::stackvm::squeeze_shape(value_t input_shape, value_t dim,
                                        value_t output,
                                        [[maybe_unused]] kernel_context &) {
    try_dims(src_dims, input_shape);
    try_positive_axes(axes, dim, src_dims.size());
    auto out_shape = squeeze_infer_shape(src_dims, axes);
    WRITE_OUT_SHAPE;
    KERNEL_FINISH;
}

// Shape inference for Unsqueeze: inserts size-1 dimensions at the axes
// given by `dim` (read via try_axes, so negative axes are presumably
// allowed — confirm against unsqueeze_infer_shape) and writes the expanded
// shape into `output` as int64 values.
result<value_t>
nncase::kernels::stackvm::unsqueeze_shape(value_t input_shape, value_t dim,
                                          value_t output,
                                          [[maybe_unused]] kernel_context &) {
    try_dims(in_shape, input_shape);
    try_axes(dim_value, dim);
    auto out_shape = unsqueeze_infer_shape(in_shape, dim_value);
    WRITE_OUT_SHAPE;
    KERNEL_FINISH;
}
45 changes: 21 additions & 24 deletions src/Native/src/kernels/stackvm/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@ using namespace nncase::runtime::stackvm;
result<value_t> nncase::kernels::stackvm::allocate(
[[maybe_unused]] typecode_t elem_type,
[[maybe_unused]] runtime::stackvm::memory_location_t location,
[[maybe_unused]] bool malloc, [[maybe_unused]] bool can_fold_const_call,
[[maybe_unused]] value_t size, [[maybe_unused]] value_t output,
[[maybe_unused]] kernel_context &context) {
[[maybe_unused]] bool malloc, [[maybe_unused]] value_t size,
[[maybe_unused]] value_t output, [[maybe_unused]] kernel_context &context) {
return err(std::errc::not_supported);
}

result<value_t> nncase::kernels::stackvm::allocate_buffer_view(
[[maybe_unused]] bool can_fold_const_call, [[maybe_unused]] value_t buffer,
[[maybe_unused]] value_t output, [[maybe_unused]] kernel_context &context) {
[[maybe_unused]] value_t buffer, [[maybe_unused]] value_t output,
[[maybe_unused]] kernel_context &context) {
return err(std::errc::not_supported);
}

Expand Down Expand Up @@ -138,23 +137,22 @@ result<value_t> nncase::kernels::stackvm::buffer_index_of(
}

result<value_t> nncase::kernels::stackvm::buffer_load(
[[maybe_unused]] bool can_fold_const_call, [[maybe_unused]] value_t input,
[[maybe_unused]] value_t indices, [[maybe_unused]] value_t output,
[[maybe_unused]] kernel_context &context) {
[[maybe_unused]] value_t input, [[maybe_unused]] value_t indices,
[[maybe_unused]] value_t output, [[maybe_unused]] kernel_context &context) {
return err(std::errc::not_supported);
}

result<value_t> nncase::kernels::stackvm::buffer_store(
[[maybe_unused]] bool can_fold_const_call, [[maybe_unused]] value_t input,
[[maybe_unused]] value_t indices, [[maybe_unused]] value_t value,
[[maybe_unused]] value_t output, [[maybe_unused]] kernel_context &context) {
[[maybe_unused]] value_t input, [[maybe_unused]] value_t indices,
[[maybe_unused]] value_t value, [[maybe_unused]] value_t output,
[[maybe_unused]] kernel_context &context) {
return err(std::errc::not_supported);
}

result<value_t> nncase::kernels::stackvm::buffer_subview(
[[maybe_unused]] bool can_fold_const_call, [[maybe_unused]] value_t buffer,
[[maybe_unused]] value_t offset, [[maybe_unused]] value_t shape,
[[maybe_unused]] value_t output, [[maybe_unused]] kernel_context &context) {
[[maybe_unused]] value_t buffer, [[maybe_unused]] value_t offset,
[[maybe_unused]] value_t shape, [[maybe_unused]] value_t output,
[[maybe_unused]] kernel_context &context) {
return err(std::errc::not_supported);
}

Expand Down Expand Up @@ -208,7 +206,6 @@ result<value_t> nncase::kernels::stackvm::concat(int32_t axis, value_t input,
}

result<value_t> nncase::kernels::stackvm::condition(
[[maybe_unused]] bool can_fold_const_call,
[[maybe_unused]] value_t predicate, [[maybe_unused]] value_t value,
[[maybe_unused]] value_t output, [[maybe_unused]] kernel_context &context) {
return err(std::errc::not_supported);
Expand Down Expand Up @@ -296,9 +293,10 @@ result<value_t> nncase::kernels::stackvm::conv2d_transpose(
return ok(output);
}

result<value_t> nncase::kernels::stackvm::ddr_of(
[[maybe_unused]] bool can_fold_const_call, [[maybe_unused]] value_t input,
[[maybe_unused]] value_t output, [[maybe_unused]] kernel_context &context) {
result<value_t>
nncase::kernels::stackvm::ddr_of([[maybe_unused]] value_t input,
[[maybe_unused]] value_t output,
[[maybe_unused]] kernel_context &context) {
return err(std::errc::not_supported);
}

Expand Down Expand Up @@ -625,8 +623,8 @@ nncase::kernels::stackvm::mat_mul(value_t lhs, value_t rhs, value_t output,
}

result<value_t> nncase::kernels::stackvm::match_buffer(
[[maybe_unused]] bool can_fold_const_call, [[maybe_unused]] value_t input,
[[maybe_unused]] value_t output, [[maybe_unused]] kernel_context &context) {
[[maybe_unused]] value_t input, [[maybe_unused]] value_t output,
[[maybe_unused]] kernel_context &context) {
return err(std::errc::not_supported);
}

Expand Down Expand Up @@ -830,10 +828,9 @@ nncase::kernels::stackvm::relu6([[maybe_unused]] value_t input,
}

result<value_t> nncase::kernels::stackvm::require(
[[maybe_unused]] std::string message,
[[maybe_unused]] bool can_fold_const_call,
[[maybe_unused]] value_t predicate, [[maybe_unused]] value_t value,
[[maybe_unused]] value_t output, [[maybe_unused]] kernel_context &context) {
[[maybe_unused]] std::string message, [[maybe_unused]] value_t predicate,
[[maybe_unused]] value_t value, [[maybe_unused]] value_t output,
[[maybe_unused]] kernel_context &context) {
try_to_scalar(cond, predicate, bool);
if (!cond) {
printf("%s\n", message.data());
Expand Down
11 changes: 10 additions & 1 deletion src/Nncase.Core/CompilerServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -512,10 +512,19 @@ public static Expr FastSimplifyForDimension(Expr value)
{
return tc.Value.ElementType == DataTypes.Int64 ? tc : new TensorConst(tc.Value.Cast<long>());
}
else if (value is Const or Var or None)
else if (value is None)
{
return value;
}
else if (value is Var)
{
return value.CheckedType is TensorType tt && tt.DType == DataTypes.Int64 ? value : IR.F.Tensors.Cast(value, DataTypes.Int64);
}
else if ((value.CheckedType is TensorType tt && tt.DType != DataTypes.Int64)
|| (value.CheckedType is DistributedType dt && dt.TensorType.DType != DataTypes.Int64))
{
return SimplifyForDimension(IR.F.Tensors.Cast(value, DataTypes.Int64));
}
else if ((value is Call call && call.Arguments.AsValueEnumerable().All(x => x is Const))
|| value.CheckedType is DistributedType)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Core/IR/Shape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ public Shape(params Expr[] dimensions)
{
if (DumpScope.Current.IsEnabled(DumpFlags.Compile))
{
DumpScope.Current.DumpIR(this, "InvalidDimension");
DumpScope.Current.DumpIR(dim, "InvalidDimension");
}

throw new ArgumentException($"Invalid dimension type: {dim.CheckedType}");
Expand Down
32 changes: 0 additions & 32 deletions src/Nncase.Core/IR/TypePattern.cs
Original file line number Diff line number Diff line change
Expand Up @@ -279,36 +279,4 @@ public static TypePattern HasRank(Func<int, bool> cond, string reason) => HasSha
/// is scalar quant param.
/// </summary>
public static TypePattern IsQuantParamType() => IsScalar() & HasDataType(new QuantParamType());

/// <summary>
/// Gets the output size of a sliding window along one dimension.
/// </summary>
/// <param name="size">Input extent.</param>
/// <param name="filter">Filter extent.</param>
/// <param name="stride">Window stride.</param>
/// <param name="dilation">Filter dilation.</param>
/// <param name="same">Use "same" padding (output covers every stride position).</param>
/// <param name="ceilMode">Round the valid-padding division up instead of down.</param>
public static int GetWindowedOutputSize(int size, int filter, int stride, int dilation, bool same, bool ceilMode = false)
{
    // Dilation widens the filter's effective receptive field.
    var effectiveFilter = ((filter - 1) * dilation) + 1;
    if (same)
    {
        return (size + stride - 1) / stride;
    }

    var span = size - effectiveFilter + stride;
    return ceilMode
        ? (int)System.Math.Ceiling((float)span / stride)
        : span / stride;
}

/// <summary>
/// Gets the output size of a sliding window along one dimension with
/// explicit before/after padding.
/// </summary>
public static int GetWindowedOutputSize(int size, int filter, int stride, int dilation, (int Before, int After) padding)
{
    // Dilation widens the filter's effective receptive field.
    var dilatedFilter = ((filter - 1) * dilation) + 1;
    var paddedSize = size + padding.Before + padding.After;
    return (paddedSize - dilatedFilter + stride) / stride;
}
}
5 changes: 5 additions & 0 deletions src/Nncase.EGraph/Passes/RewriteProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ from newExpr in candidates
foreach (var (oldExpr, oldEClass, newExpr) in replacedExprs)
{
var typeInferSuccess = CompilerServices.InferenceType(newExpr);
if (!typeInferSuccess && DumpScope.Current.IsEnabled(DumpFlags.Rewrite))
{
DumpScope.Current.DumpIR(newExpr, $"{rule}_InferFailed", "Rewrite");
}

if (!typeInferSuccess)
{
throw new InvalidOperationException($"Type inference failed for {newExpr}");
Expand Down
12 changes: 5 additions & 7 deletions src/Nncase.Evaluator/Tensors/Concat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,17 @@ public Cost Visit(ICostEvaluateContext context, Concat target)

private IRType? CheckType(TupleType inputs)
{
bool? allScalar = null;
DataType? allDType = null;
foreach (var (i, input) in Enumerable.Range(0, inputs.Count).Select(i => (i, inputs[i])))
{
TensorType type;
if (input is TensorType a)
{
if (a.IsScalar)
{
return new InvalidType($"Scalar tensor (at {i}) cannot be concatenated");
}

type = a;
}
else if (input is DistributedType { TensorType: TensorType b })
Expand All @@ -76,7 +80,6 @@ public Cost Visit(ICostEvaluateContext context, Concat target)
return new TensorType(type.DType, Shape.Unranked);
}

allScalar = (allScalar ?? type.IsScalar) & type.IsScalar;
allDType ??= type.DType;
if (allDType != type.DType)
{
Expand All @@ -85,11 +88,6 @@ public Cost Visit(ICostEvaluateContext context, Concat target)
}
}

if (allScalar == true && allDType is not null)
{
return new TensorType(allDType, new[] { inputs.Count });
}

return null;
}

Expand Down
Loading

0 comments on commit 2505bbb

Please sign in to comment.