Skip to content

Commit 9e296b4

Browse files
jcai19Google-ML-Automation
authored andcommitted
[XLA][Numerics][HLO Original Value] Try to infer original value of GetTupleElement operations
If the tuple operand of an GetTupleElement instruction is a Tuple operation and the instruction does not have an original value yet, returns the original value of the corresponding tuple element instead. PiperOrigin-RevId: 827373266
1 parent 01a696b commit 9e296b4

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

xla/hlo/ir/hlo_instruction.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6043,7 +6043,13 @@ void HloInstruction::set_output_to_operand_aliasing(
60436043
}
60446044

60456045
std::shared_ptr<OriginalValue> HloInstruction::original_value() const {
6046-
return original_value_;
6046+
if (original_value_ != nullptr || opcode_ != HloOpcode::kGetTupleElement) {
6047+
return original_value_;
6048+
}
6049+
const HloInstruction* tuple = operand(0);
6050+
return tuple->opcode() == HloOpcode::kTuple
6051+
? tuple->operand(tuple_index())->original_value()
6052+
: nullptr;
60476053
}
60486054

60496055
void HloInstruction::set_original_value(

xla/hlo/ir/hlo_original_value_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,5 +435,24 @@ ENTRY main {
435435
EXPECT_EQ(p0->original_value(), p1->original_value());
436436
}
437437

438+
TEST_F(OriginalValueHloTest, InferGetTupleElementOriginalValue) {
439+
const char* hlo_string = R"(
440+
HloModule test
441+
442+
ENTRY main {
443+
p0 = f32[] parameter(0), origin={{"p0"}}
444+
p1 = f32[] parameter(1)
445+
tuple = (f32[], f32[]) tuple(p0, p1)
446+
ROOT gte = f32[] get-tuple-element(tuple), index=0
447+
}
448+
)";
449+
TF_ASSERT_OK_AND_ASSIGN(auto module,
450+
ParseAndReturnVerifiedModule(hlo_string));
451+
const HloInstruction* gte = module->entry_computation()->root_instruction();
452+
453+
EXPECT_NE(gte->original_value(), nullptr);
454+
EXPECT_EQ(gte->original_value()->ToString(), R"({"p0"})");
455+
}
456+
438457
} // namespace
439458
} // namespace xla

0 commit comments

Comments
 (0)