diff --git a/xla/hlo/ir/hlo_instruction.cc b/xla/hlo/ir/hlo_instruction.cc index 374621ba14cd4..a13fb0540a520 100644 --- a/xla/hlo/ir/hlo_instruction.cc +++ b/xla/hlo/ir/hlo_instruction.cc @@ -6043,7 +6043,13 @@ void HloInstruction::set_output_to_operand_aliasing( } std::shared_ptr HloInstruction::original_value() const { - return original_value_; + if (original_value_ != nullptr || opcode_ != HloOpcode::kGetTupleElement) { + return original_value_; + } + const HloInstruction* tuple = operand(0); + return tuple->opcode() == HloOpcode::kTuple + ? tuple->operand(tuple_index())->original_value() + : nullptr; } void HloInstruction::set_original_value( diff --git a/xla/hlo/ir/hlo_original_value_test.cc b/xla/hlo/ir/hlo_original_value_test.cc index 2cd2dc855f9e5..b70a8f1576587 100644 --- a/xla/hlo/ir/hlo_original_value_test.cc +++ b/xla/hlo/ir/hlo_original_value_test.cc @@ -435,5 +435,24 @@ ENTRY main { EXPECT_EQ(p0->original_value(), p1->original_value()); } +TEST_F(OriginalValueHloTest, InferGetTupleElementOriginalValue) { + const char* hlo_string = R"( +HloModule test + +ENTRY main { + p0 = f32[] parameter(0), origin={{"p0"}} + p1 = f32[] parameter(1) + tuple = (f32[], f32[]) tuple(p0, p1) + ROOT gte = f32[] get-tuple-element(tuple), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + const HloInstruction* gte = module->entry_computation()->root_instruction(); + + EXPECT_NE(gte->original_value(), nullptr); + EXPECT_EQ(gte->original_value()->ToString(), R"({"p0"})"); +} + } // namespace } // namespace xla