Skip to content

Commit 6185e40

Browse files
committed
Support predicates
1 parent 8c8500c commit 6185e40

File tree

3 files changed

+98
-17
lines changed

3 files changed

+98
-17
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#define LLVM_ANALYSIS_IR2VEC_H
3737

3838
#include "llvm/ADT/DenseMap.h"
39+
#include "llvm/IR/Instructions.h"
3940
#include "llvm/IR/PassManager.h"
4041
#include "llvm/IR/Type.h"
4142
#include "llvm/Support/CommandLine.h"
@@ -162,16 +163,25 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
162163
/// embeddings.
163164
class Vocabulary {
164165
friend class llvm::IR2VecVocabAnalysis;
166+
// Slot layout:
167+
//   [0 .. MaxOpcodes-1] => Instruction opcodes
168+
//   [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
169+
//   [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operands
170+
//   Within Operands: the OperandKind entries come first, followed by the
171+
//   compare predicates (all FCmp predicates, then all ICmp predicates), so
172+
//   the predicate slots stay dense despite gaps in CmpInst::Predicate.
165173
using VocabVector = std::vector<ir2vec::Embedding>;
166174
VocabVector Vocab;
175+
167176
bool Valid = false;
177+
static constexpr unsigned NumICmpPredicates =
178+
static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
179+
static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
180+
static constexpr unsigned NumFCmpPredicates =
181+
static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
182+
static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
168183

169184
public:
170-
// Slot layout:
171-
// [0 .. MaxOpcodes-1] => Instruction opcodes
172-
// [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
173-
// [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
174-
175185
/// Canonical type IDs supported by IR2Vec Vocabulary
176186
enum class CanonicalTypeID : unsigned {
177187
FloatTy,
@@ -208,13 +218,18 @@ class Vocabulary {
208218
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
209219
static constexpr unsigned MaxOperandKinds =
210220
static_cast<unsigned>(OperandKind::MaxOperandKind);
221+
// CmpInst::Predicate has gaps. We want the vocabulary to be dense without
222+
// empty slots.
223+
static constexpr unsigned MaxPredicateKinds =
224+
NumICmpPredicates + NumFCmpPredicates;
211225

212226
Vocabulary() = default;
213227
LLVM_ABI Vocabulary(VocabVector &&Vocab);
214228

215229
LLVM_ABI bool isValid() const;
216230
LLVM_ABI unsigned getDimension() const;
217-
/// Total number of entries (opcodes + canonicalized types + operand kinds)
231+
/// Total number of entries (opcodes + canonicalized types + operand kinds +
232+
/// predicates)
218233
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
219234

220235
/// Function to get vocabulary key for a given Opcode
@@ -229,16 +244,21 @@ class Vocabulary {
229244
/// Function to classify an operand into OperandKind
230245
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
231246

247+
/// Function to get vocabulary key for a given predicate
248+
LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
249+
232250
/// Functions to return the slot index or position of a given Opcode, TypeID,
233251
/// OperandKind, or compare predicate in the vocabulary.
234252
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
235253
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
236254
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
255+
LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
237256

238257
/// Accessors to get the embedding for a given entity.
239258
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
240259
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
241260
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
261+
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
242262

243263
/// Const Iterator type aliases
244264
using const_iterator = VocabVector::const_iterator;
@@ -275,7 +295,13 @@ class Vocabulary {
275295

276296
private:
277297
constexpr static unsigned NumCanonicalEntries =
278-
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
298+
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
299+
300+
// Base offsets for slot layout to simplify index computation
301+
constexpr static unsigned OperandBaseOffset =
302+
MaxOpcodes + MaxCanonicalTypeIDs;
303+
constexpr static unsigned PredicateBaseOffset =
304+
OperandBaseOffset + MaxOperandKinds;
279305

280306
/// String mappings for CanonicalTypeID values
281307
static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -327,6 +353,9 @@ class Vocabulary {
327353

328354
/// Function to convert TypeID to CanonicalTypeID
329355
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
356+
357+
/// Function to get the predicate enum value for a given index
358+
LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
330359
};
331360

332361
/// Embedder provides the interface to generate embeddings (vector

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
216216
ArgEmb += Vocab[*Op];
217217
auto InstVector =
218218
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
219+
if (const auto *IC = dyn_cast<CmpInst>(&I))
220+
InstVector += Vocab[IC->getPredicate()];
219221
InstVecMap[&I] = InstVector;
220222
BBVector += InstVector;
221223
}
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
250252
// embeddings
251253
auto InstVector =
252254
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
255+
// Add compare predicate embedding as an additional operand if applicable
256+
if (const auto *IC = dyn_cast<CmpInst>(&I))
257+
InstVector += Vocab[IC->getPredicate()];
253258
InstVecMap[&I] = InstVector;
254259
BBVector += InstVector;
255260
}
@@ -285,7 +290,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
285290
unsigned Vocabulary::getSlotIndex(const Value &Op) {
286291
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
287292
assert(Index < MaxOperandKinds && "Invalid OperandKind");
288-
return MaxOpcodes + MaxCanonicalTypeIDs + Index;
293+
return OperandBaseOffset + Index;
294+
}
295+
296+
/// Returns the vocabulary slot for the compare predicate \p P.
///
/// CmpInst::Predicate has a gap between the FCmp and ICmp ranges, so the
/// predicate is first mapped to a dense local index (FCmp predicates first,
/// then ICmp predicates) and then shifted by PredicateBaseOffset.
///
/// \p P must be a real FCmp or ICmp predicate. The BAD_* sentinels and values
/// inside the gap are rejected, because they would alias another predicate's
/// slot or index past the predicate section.
unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
  assert((CmpInst::isFPPredicate(P) || CmpInst::isIntPredicate(P)) &&
         "Invalid compare predicate");

  const unsigned PredVal = static_cast<unsigned>(P);
  const unsigned FirstFCmp =
      static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
  const unsigned FirstICmp =
      static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);

  // ICmp predicates have the larger enum values; they are placed after all
  // FCmp predicates in the dense numbering.
  const unsigned LocalIdx = (PredVal >= FirstICmp)
                                ? (NumFCmpPredicates + (PredVal - FirstICmp))
                                : (PredVal - FirstFCmp);
  return PredicateBaseOffset + LocalIdx;
}
290305

291306
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
300315
return Vocab[getSlotIndex(Arg)];
301316
}
302317

318+
/// Accessor for the embedding stored in the slot of compare predicate \p P.
const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
  const unsigned Slot = getSlotIndex(P);
  return Vocab[Slot];
}
321+
303322
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
304323
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
305324
#define HANDLE_INST(NUM, OPCODE, CLASS) \
@@ -345,18 +364,35 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
345364
return OperandKind::VariableID;
346365
}
347366

367+
/// Converts a dense predicate index in [0, MaxPredicateKinds) back to the
/// corresponding CmpInst::Predicate enum value.
CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
  assert(Index < MaxPredicateKinds && "Invalid predicate index");

  // The first NumFCmpPredicates indices denote FCmp predicates.
  if (Index < NumFCmpPredicates)
    return static_cast<CmpInst::Predicate>(
        static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index);

  // Everything after that is an ICmp predicate, counted from the first one.
  const unsigned ICmpOffset = Index - NumFCmpPredicates;
  return static_cast<CmpInst::Predicate>(
      static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + ICmpOffset);
}
376+
377+
/// Returns the vocabulary key for \p Pred, which is the name reported by
/// CmpInst::getPredicateName().
StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
  StringRef Key = CmpInst::getPredicateName(Pred);
  return Key;
}
380+
348381
/// Returns the string key of the vocabulary entry at position \p Pos.
///
/// The position is resolved against the slot layout section by section:
/// opcodes, canonical types, operand kinds, and finally compare predicates.
StringRef Vocabulary::getStringKey(unsigned Pos) {
  assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");

  // Opcode section; opcode numbers are 1-based, hence the +1.
  if (Pos < MaxOpcodes) {
    const unsigned Opcode = Pos + 1;
    return getVocabKeyForOpcode(Opcode);
  }

  // Canonical type section.
  if (Pos < OperandBaseOffset) {
    const auto TypeID = static_cast<CanonicalTypeID>(Pos - MaxOpcodes);
    return getVocabKeyForCanonicalTypeID(TypeID);
  }

  // Operand kind section.
  if (Pos < PredicateBaseOffset) {
    const auto Kind = static_cast<OperandKind>(Pos - OperandBaseOffset);
    return getVocabKeyForOperandKind(Kind);
  }

  // Whatever remains belongs to the compare predicate section.
  const CmpInst::Predicate Pred = getPredicate(Pos - PredicateBaseOffset);
  return getVocabKeyForPredicate(Pred);
}
361397

362398
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -370,11 +406,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
370406
VocabVector DummyVocab;
371407
DummyVocab.reserve(NumCanonicalEntries);
372408
float DummyVal = 0.1f;
373-
// Create a dummy vocabulary with entries for all opcodes, types, and
374-
// operands
375-
for ([[maybe_unused]] unsigned _ :
376-
seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
377-
Vocabulary::MaxOperandKinds)) {
409+
// Create a dummy vocabulary with entries for all opcodes, types, operands
410+
// and predicates
411+
for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
378412
DummyVocab.push_back(Embedding(Dim, DummyVal));
379413
DummyVal += 0.1f;
380414
}
@@ -517,6 +551,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
517551
}
518552
Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
519553
NumericArgEmbeddings.end());
554+
555+
// Handle Predicates: part of Operands section. We look up predicate keys
556+
// in ArgVocab.
557+
std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
558+
Embedding(Dim, 0));
559+
NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
560+
for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
561+
StringRef VocabKey =
562+
Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
563+
auto It = ArgVocab.find(VocabKey.str());
564+
if (It != ArgVocab.end()) {
565+
NumericPredEmbeddings[PK] = It->second;
566+
continue;
567+
}
568+
handleMissingEntity(VocabKey.str());
569+
}
570+
Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
571+
NumericPredEmbeddings.end());
520572
}
521573

522574
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)

llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ class IR2VecTool {
184184
// Add "Arg" relationships
185185
unsigned ArgIndex = 0;
186186
for (const Use &U : I.operands()) {
187-
unsigned OperandID = Vocabulary::getSlotIndex(*U);
187+
unsigned OperandID = Vocabulary::getSlotIndex(*U.get());
188188
unsigned RelationID = ArgRelation + ArgIndex;
189189
OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
190190

0 commit comments

Comments
 (0)