Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class Vocabulary {
/// or OperandKind in the vocabulary.
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
LLVM_ABI static unsigned getSlotIndex(const Value *Op);
LLVM_ABI static unsigned getSlotIndex(const Value &Op);

/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
}

unsigned Vocabulary::getSlotIndex(const Value *Op) {
unsigned Index = static_cast<unsigned>(getOperandKind(Op));
unsigned Vocabulary::getSlotIndex(const Value &Op) {
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
assert(Index < MaxOperandKinds && "Invalid OperandKind");
return MaxOpcodes + MaxCanonicalTypeIDs + Index;
}
Expand All @@ -297,7 +297,7 @@ const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const {
}

const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
return Vocab[getSlotIndex(&Arg)];
return Vocab[getSlotIndex(Arg)];
}

StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
Expand Down
2 changes: 1 addition & 1 deletion llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class IR2VecTool {
// Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
unsigned OperandID = Vocabulary::getSlotIndex(U.get());
unsigned OperandID = Vocabulary::getSlotIndex(*U);
unsigned RelationID = ArgRelation + ArgIndex;
OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';

Expand Down
8 changes: 4 additions & 4 deletions llvm/unittests/Analysis/IR2VecTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,23 +507,23 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
#define EXPECTED_VOCAB_OPERAND_SLOT(X) \
MaxOpcodes + MaxCanonicalTypeIDs + static_cast<unsigned>(X)
// Test Function operand
EXPECT_EQ(Vocabulary::getSlotIndex(F),
EXPECT_EQ(Vocabulary::getSlotIndex(*F),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::FunctionID));

// Test Constant operand
Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
EXPECT_EQ(Vocabulary::getSlotIndex(C),
EXPECT_EQ(Vocabulary::getSlotIndex(*C),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::ConstantID));

// Test Pointer operand
BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
EXPECT_EQ(Vocabulary::getSlotIndex(PtrVal),
EXPECT_EQ(Vocabulary::getSlotIndex(*PtrVal),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::PointerID));

// Test Variable operand (function argument)
Argument *Arg = F->getArg(0);
EXPECT_EQ(Vocabulary::getSlotIndex(Arg),
EXPECT_EQ(Vocabulary::getSlotIndex(*Arg),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
#undef EXPECTED_VOCAB_OPERAND_SLOT
}
Expand Down