diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index f439a345eb19..36a0f8963cc2 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -124,18 +124,28 @@ Optional> InferBinaryBroadcastShape(const Call& call, const Bloc const auto* int_dim0 = dim0.as(); const auto* int_dim1 = dim1.as(); if (int_dim0 != nullptr && int_dim0->value == 1) { + // static dim(1) output_shape.push_back(dim1); } else if (int_dim1 != nullptr && int_dim1->value == 1) { + // static dim(1) output_shape.push_back(dim0); } else if (analyzer->CanProveEqual(dim0, dim1)) { + // equal static dims or equal symbolic dims output_shape.push_back(dim0); } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { + // different static dims ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", the first input shape at dim " << x1_ndim - i << " is " << dim0 << " and the second input shape at dim " << x2_ndim - i << " is " << dim1 << ", which are not broadcastable."); + } else if (int_dim0 == nullptr && int_dim1) { + // symbolic dim and static dim + output_shape.push_back(dim1); + } else if (int_dim0 && int_dim1 == nullptr) { + // static dim and symbolic dim + output_shape.push_back(dim0); } else { - // Use simple fallback when shape mismatch. + // Use simple fallback when shapes mismatch. return std::nullopt; } } diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py index 20c111495d6a..58854d50a3a3 100644 --- a/tests/python/relax/test_op_binary.py +++ b/tests/python/relax/test_op_binary.py @@ -224,8 +224,10 @@ def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable): y2 = relax.Var("y", R.Tensor((4, k, m, 1), "float32")) y3 = relax.Var("y", R.Tensor("float32", ndim=2)) y4 = relax.Var("y", R.Tensor("float32", ndim=-1)) + y5 = relax.Var("y", R.Tensor((m, 3), "float32")) _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((m, n), "float32")) _check_inference(bb, binary_arith_op(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x0, y5), relax.TensorStructInfo((m, 3), "float32")) _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((m, n), "float32")) _check_inference(bb, binary_arith_op(x1, y2), relax.TensorStructInfo((4, k, m, n), "float32")) _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo(dtype="float32", ndim=4))