From 1f4d7232ba74e92369c6e6f69e522c811fd46461 Mon Sep 17 00:00:00 2001 From: xla authors Date: Fri, 31 Oct 2025 06:41:21 -0700 Subject: [PATCH] Fix tf.math.reciprocal behavior for complex128 when inf values are provided. PiperOrigin-RevId: 826467601 --- xla/hlo/builder/lib/math.cc | 14 +++++++++++++- xla/hlo/builder/lib/math_test.cc | 14 ++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/xla/hlo/builder/lib/math.cc b/xla/hlo/builder/lib/math.cc index 2c7d5c3d272b1..75147560c5472 100644 --- a/xla/hlo/builder/lib/math.cc +++ b/xla/hlo/builder/lib/math.cc @@ -213,7 +213,19 @@ XlaOp IsNegZero(XlaOp operand) { XlaOp Square(XlaOp operand) { return operand * operand; } -XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } +XlaOp Reciprocal(XlaOp operand) { + XlaBuilder* b = operand.builder(); + return b->ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(operand)); + if (primitive_util::IsComplexType(shape.element_type())) { + XlaOp is_finite = And(IsFinite(Real(operand)), IsFinite(Imag(operand))); + XlaOp is_inf = And(Not(is_finite), Eq(operand, operand)); + return Select(is_inf, ZerosLike(operand), + ScalarLike(operand, 1.0) / operand); + } + return ScalarLike(operand, 1.0) / operand; + }); +} // Computes an approximation of the error function complement (1 - erf(x)). // diff --git a/xla/hlo/builder/lib/math_test.cc b/xla/hlo/builder/lib/math_test.cc index e745aae39f99f..64b3f7cc732b2 100644 --- a/xla/hlo/builder/lib/math_test.cc +++ b/xla/hlo/builder/lib/math_test.cc @@ -364,6 +364,20 @@ TEST_F(MathTest, ReciprocalTenValues) { ComputeAndCompareR1(&builder, expected, {}, kErrorSpec); } +TEST_F(MathTest, ReciprocalComplexInfinity) { + XlaBuilder builder(TestName()); + auto x = ConstantR1>( + &builder, {{std::numeric_limits::infinity(), 0.0}, + {0.0, std::numeric_limits::infinity()}, + {std::numeric_limits::infinity(), + std::numeric_limits::infinity()}}); + Reciprocal(x); + + std::vector> expected = { + {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}; + ComputeAndCompareR1>(&builder, expected, {}, kErrorSpec); +} + TEST_F(MathTest, SqrtZeroes) { XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {0.0, -0.0});