Skip to content

Commit 52b1850

Browse files
authored
[IR2Vec] Add support for Cmp predicates in vocabulary and embeddings (#156952)
Comparison predicates (equal, not equal, greater than, etc.) provide important semantic information about program behavior. Previously, IR2Vec only captured that a comparison was happening but not what kind of comparison it was. This PR extends the IR2Vec vocabulary to include comparison predicates (ICmp and FCmp) as part of the embedding space. Following are the changes: 1. Expand the vocabulary slot layout to include predicate entries after opcodes, types, and operands 2. Add methods to handle predicate embedding lookups and conversions 3. Update the embedder implementations to include predicate information when processing CmpInst instructions 4. Update test files to include the new predicate entries in the vocabulary (Tracking issues: #141817, #141833)
1 parent aeffd36 commit 52b1850

File tree

13 files changed

+351
-24
lines changed

13 files changed

+351
-24
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#define LLVM_ANALYSIS_IR2VEC_H
3737

3838
#include "llvm/ADT/DenseMap.h"
39+
#include "llvm/IR/Instructions.h"
3940
#include "llvm/IR/PassManager.h"
4041
#include "llvm/IR/Type.h"
4142
#include "llvm/Support/CommandLine.h"
@@ -162,15 +163,34 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
162163
/// embeddings.
163164
class Vocabulary {
164165
friend class llvm::IR2VecVocabAnalysis;
166+
167+
// Vocabulary Slot Layout:
168+
// +----------------+------------------------------------------------------+
169+
// | Entity Type | Index Range |
170+
// +----------------+------------------------------------------------------+
171+
// | Opcodes | [0 .. (MaxOpcodes-1)] |
172+
// | Canonical Types| [MaxOpcodes .. (MaxOpcodes+MaxCanonicalTypeIDs-1)] |
173+
// | Operands | [(MaxOpcodes+MaxCanonicalTypeIDs) .. (NumCanonicalEntries-1)] |
174+
// +----------------+------------------------------------------------------+
175+
// Note: MaxOpcodes is the number of unique opcodes supported by LLVM IR.
176+
// MaxCanonicalTypeIDs is the number of canonicalized type IDs.
177+
// "Similar" LLVM Types are grouped/canonicalized together. E.g., all
178+
// float variants (FloatTy, DoubleTy, HalfTy, etc.) map to
179+
// CanonicalTypeID::FloatTy. This helps reduce the vocabulary size
180+
// and improves learning. Operands include Comparison predicates
181+
// (ICmp/FCmp) along with other operand types. This can be extended to
182+
// include other specializations in future.
165183
using VocabVector = std::vector<ir2vec::Embedding>;
166184
VocabVector Vocab;
167185

168-
public:
169-
// Slot layout:
170-
// [0 .. MaxOpcodes-1] => Instruction opcodes
171-
// [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
172-
// [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
186+
static constexpr unsigned NumICmpPredicates =
187+
static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
188+
static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
189+
static constexpr unsigned NumFCmpPredicates =
190+
static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
191+
static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
173192

193+
public:
174194
/// Canonical type IDs supported by IR2Vec Vocabulary
175195
enum class CanonicalTypeID : unsigned {
176196
FloatTy,
@@ -207,13 +227,18 @@ class Vocabulary {
207227
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
208228
static constexpr unsigned MaxOperandKinds =
209229
static_cast<unsigned>(OperandKind::MaxOperandKind);
230+
// CmpInst::Predicate has gaps. We want the vocabulary to be dense without
231+
// empty slots.
232+
static constexpr unsigned MaxPredicateKinds =
233+
NumICmpPredicates + NumFCmpPredicates;
210234

211235
Vocabulary() = default;
212236
LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
213237

214238
LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; };
215239
LLVM_ABI unsigned getDimension() const;
216-
/// Total number of entries (opcodes + canonicalized types + operand kinds)
240+
/// Total number of entries (opcodes + canonicalized types + operand kinds +
241+
/// predicates)
217242
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
218243

219244
/// Function to get vocabulary key for a given Opcode
@@ -228,16 +253,21 @@ class Vocabulary {
228253
/// Function to classify an operand into OperandKind
229254
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
230255

256+
/// Function to get vocabulary key for a given predicate
257+
LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
258+
231259
/// Functions to return the slot index or position of a given Opcode, TypeID,
232260
/// OperandKind, or comparison predicate in the vocabulary.
233261
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
234262
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
235263
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
264+
LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
236265

237266
/// Accessors to get the embedding for a given entity.
238267
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
239268
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
240269
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
270+
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
241271

242272
/// Const Iterator type aliases
243273
using const_iterator = VocabVector::const_iterator;
@@ -274,7 +304,13 @@ class Vocabulary {
274304

275305
private:
276306
constexpr static unsigned NumCanonicalEntries =
277-
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
307+
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
308+
309+
// Base offsets for slot layout to simplify index computation
310+
constexpr static unsigned OperandBaseOffset =
311+
MaxOpcodes + MaxCanonicalTypeIDs;
312+
constexpr static unsigned PredicateBaseOffset =
313+
OperandBaseOffset + MaxOperandKinds;
278314

279315
/// String mappings for CanonicalTypeID values
280316
static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -326,6 +362,11 @@ class Vocabulary {
326362

327363
/// Function to convert TypeID to CanonicalTypeID
328364
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
365+
366+
/// Function to get the predicate enum value for a given index. Index is
367+
/// relative to the predicates section of the vocabulary. E.g., Index 0
368+
/// corresponds to the first predicate.
369+
LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
329370
};
330371

331372
/// Embedder provides the interface to generate embeddings (vector

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
216216
ArgEmb += Vocab[*Op];
217217
auto InstVector =
218218
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
219+
if (const auto *IC = dyn_cast<CmpInst>(&I))
220+
InstVector += Vocab[IC->getPredicate()];
219221
InstVecMap[&I] = InstVector;
220222
BBVector += InstVector;
221223
}
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
250252
// embeddings
251253
auto InstVector =
252254
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
255+
// Add compare predicate embedding as an additional operand if applicable
256+
if (const auto *IC = dyn_cast<CmpInst>(&I))
257+
InstVector += Vocab[IC->getPredicate()];
253258
InstVecMap[&I] = InstVector;
254259
BBVector += InstVector;
255260
}
@@ -278,7 +283,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
278283
unsigned Vocabulary::getSlotIndex(const Value &Op) {
279284
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
280285
assert(Index < MaxOperandKinds && "Invalid OperandKind");
281-
return MaxOpcodes + MaxCanonicalTypeIDs + Index;
286+
return OperandBaseOffset + Index;
287+
}
288+
289+
/// Returns the vocabulary slot of comparison predicate \p P.
///
/// CmpInst::Predicate numbers FCmp and ICmp predicates in two separate ranges
/// with a gap in between, while the predicate section of the vocabulary is
/// dense: the FCmp predicates come first, immediately followed by the ICmp
/// predicates.
unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
  // A value outside both predicate ranges would either alias the slot of an
  // unrelated ICmp predicate or index past the predicate section, so reject
  // it up front (mirrors the assert in the Value overload).
  assert((CmpInst::isFPPredicate(P) || CmpInst::isIntPredicate(P)) &&
         "Invalid predicate");

  const unsigned PU = static_cast<unsigned>(P);
  const unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
  const unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);

  // ICmp predicates are placed right after the NumFCmpPredicates FCmp slots.
  const unsigned PredIdx =
      (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
  return PredicateBaseOffset + PredIdx;
}
283298

284299
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -293,6 +308,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
293308
return Vocab[getSlotIndex(Arg)];
294309
}
295310

311+
/// Looks up the embedding stored for comparison predicate \p P.
const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
  const unsigned Slot = getSlotIndex(P);
  return Vocab[Slot];
}
314+
296315
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
297316
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
298317
#define HANDLE_INST(NUM, OPCODE, CLASS) \
@@ -338,18 +357,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
338357
return OperandKind::VariableID;
339358
}
340359

360+
/// Maps a dense index within the predicate section of the vocabulary back to
/// its CmpInst::Predicate. Indices below NumFCmpPredicates denote FCmp
/// predicates; the remaining indices denote ICmp predicates.
CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
  assert(Index < MaxPredicateKinds && "Invalid predicate index");

  // Leading indices: offset directly from the first FCmp predicate.
  if (Index < NumFCmpPredicates) {
    const unsigned FCmpBase =
        static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
    return static_cast<CmpInst::Predicate>(FCmpBase + Index);
  }

  // Trailing indices: rebase past the FCmp slots, then offset from the first
  // ICmp predicate.
  const unsigned ICmpBase =
      static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
  const unsigned ICmpOffset = Index - NumFCmpPredicates;
  return static_cast<CmpInst::Predicate>(ICmpBase + ICmpOffset);
}
369+
370+
/// Returns the vocabulary key for \p Pred: "FCMP_" or "ICMP_" followed by the
/// name reported by CmpInst::getPredicateName().
///
/// The keys for all predicates are built once, on first use, and are never
/// modified afterwards, so a returned StringRef stays valid across later
/// calls. (Returning a view into one mutable static buffer that each call
/// overwrites would invalidate every previously returned key.)
StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
  assert((CmpInst::isFPPredicate(Pred) || CmpInst::isIntPredicate(Pred)) &&
         "Invalid predicate");

  // Indexed by the dense predicate index (FCmp predicates first, then ICmp),
  // i.e. the same order getPredicate() enumerates them in.
  static const std::vector<std::string> PredKeys = [] {
    std::vector<std::string> Keys;
    Keys.reserve(MaxPredicateKinds);
    for (unsigned Idx = 0; Idx < MaxPredicateKinds; ++Idx) {
      const CmpInst::Predicate P = getPredicate(Idx);
      Keys.emplace_back(P < CmpInst::FIRST_ICMP_PREDICATE ? "FCMP_" : "ICMP_");
      Keys.back() += CmpInst::getPredicateName(P).str();
    }
    return Keys;
  }();

  // getSlotIndex() is absolute within the vocabulary; rebase it to the start
  // of the predicate section to index the table.
  return PredKeys[getSlotIndex(Pred) - PredicateBaseOffset];
}
379+
341380
StringRef Vocabulary::getStringKey(unsigned Pos) {
342381
assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
343382
// Opcode section: slots [0, MaxOpcodes). Opcode numbers are 1-based, hence Pos + 1.
344383
if (Pos < MaxOpcodes)
345384
return getVocabKeyForOpcode(Pos + 1);
346385
// Canonical type section: slots [MaxOpcodes, OperandBaseOffset).
347-
if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
386+
if (Pos < OperandBaseOffset)
348387
return getVocabKeyForCanonicalTypeID(
349388
static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
350389
// Operand kind section: slots [OperandBaseOffset, PredicateBaseOffset).
351-
return getVocabKeyForOperandKind(
352-
static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
390+
if (Pos < PredicateBaseOffset)
391+
return getVocabKeyForOperandKind(
392+
static_cast<OperandKind>(Pos - OperandBaseOffset));
393+
// Predicate section: all remaining slots, indexed relative to PredicateBaseOffset.
394+
return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
353395
}
354396

355397
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -363,11 +405,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
363405
VocabVector DummyVocab;
364406
DummyVocab.reserve(NumCanonicalEntries);
365407
float DummyVal = 0.1f;
366-
// Create a dummy vocabulary with entries for all opcodes, types, and
367-
// operands
368-
for ([[maybe_unused]] unsigned _ :
369-
seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
370-
Vocabulary::MaxOperandKinds)) {
408+
// Create a dummy vocabulary with entries for all opcodes, types, operands
409+
// and predicates
410+
for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
371411
DummyVocab.push_back(Embedding(Dim, DummyVal));
372412
DummyVal += 0.1f;
373413
}
@@ -510,6 +550,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
510550
}
511551
Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
512552
NumericArgEmbeddings.end());
553+
554+
// Handle Predicates: part of Operands section. We look up predicate keys
555+
// in ArgVocab.
556+
std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
557+
Embedding(Dim, 0));
558+
NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
559+
for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
560+
StringRef VocabKey =
561+
Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
562+
auto It = ArgVocab.find(VocabKey.str());
563+
if (It != ArgVocab.end()) {
564+
NumericPredEmbeddings[PK] = It->second;
565+
continue;
566+
}
567+
handleMissingEntity(VocabKey.str());
568+
}
569+
Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
570+
NumericPredEmbeddings.end());
513571
}
514572

515573
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)

llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,32 @@
8787
"Function": [1, 2],
8888
"Pointer": [3, 4],
8989
"Constant": [5, 6],
90-
"Variable": [7, 8]
90+
"Variable": [7, 8],
91+
"FCMP_false": [9, 10],
92+
"FCMP_oeq": [11, 12],
93+
"FCMP_ogt": [13, 14],
94+
"FCMP_oge": [15, 16],
95+
"FCMP_olt": [17, 18],
96+
"FCMP_ole": [19, 20],
97+
"FCMP_one": [21, 22],
98+
"FCMP_ord": [23, 24],
99+
"FCMP_uno": [25, 26],
100+
"FCMP_ueq": [27, 28],
101+
"FCMP_ugt": [29, 30],
102+
"FCMP_uge": [31, 32],
103+
"FCMP_ult": [33, 34],
104+
"FCMP_ule": [35, 36],
105+
"FCMP_une": [37, 38],
106+
"FCMP_true": [39, 40],
107+
"ICMP_eq": [41, 42],
108+
"ICMP_ne": [43, 44],
109+
"ICMP_ugt": [45, 46],
110+
"ICMP_uge": [47, 48],
111+
"ICMP_ult": [49, 50],
112+
"ICMP_ule": [51, 52],
113+
"ICMP_sgt": [53, 54],
114+
"ICMP_sge": [55, 56],
115+
"ICMP_slt": [57, 58],
116+
"ICMP_sle": [59, 60]
91117
}
92118
}

llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,32 @@
8686
"Function": [1, 2, 3],
8787
"Pointer": [4, 5, 6],
8888
"Constant": [7, 8, 9],
89-
"Variable": [10, 11, 12]
89+
"Variable": [10, 11, 12],
90+
"FCMP_false": [13, 14, 15],
91+
"FCMP_oeq": [16, 17, 18],
92+
"FCMP_ogt": [19, 20, 21],
93+
"FCMP_oge": [22, 23, 24],
94+
"FCMP_olt": [25, 26, 27],
95+
"FCMP_ole": [28, 29, 30],
96+
"FCMP_one": [31, 32, 33],
97+
"FCMP_ord": [34, 35, 36],
98+
"FCMP_uno": [37, 38, 39],
99+
"FCMP_ueq": [40, 41, 42],
100+
"FCMP_ugt": [43, 44, 45],
101+
"FCMP_uge": [46, 47, 48],
102+
"FCMP_ult": [49, 50, 51],
103+
"FCMP_ule": [52, 53, 54],
104+
"FCMP_une": [55, 56, 57],
105+
"FCMP_true": [58, 59, 60],
106+
"ICMP_eq": [61, 62, 63],
107+
"ICMP_ne": [64, 65, 66],
108+
"ICMP_ugt": [67, 68, 69],
109+
"ICMP_uge": [70, 71, 72],
110+
"ICMP_ult": [73, 74, 75],
111+
"ICMP_ule": [76, 77, 78],
112+
"ICMP_sgt": [79, 80, 81],
113+
"ICMP_sge": [82, 83, 84],
114+
"ICMP_slt": [85, 86, 87],
115+
"ICMP_sle": [88, 89, 90]
90116
}
91117
}

llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"FPTrunc": [133, 134, 135],
4848
"FPExt": [136, 137, 138],
4949
"PtrToInt": [139, 140, 141],
50+
"PtrToAddr": [202, 203, 204],
5051
"IntToPtr": [142, 143, 144],
5152
"BitCast": [145, 146, 147],
5253
"AddrSpaceCast": [148, 149, 150],
@@ -86,6 +87,32 @@
8687
"Function": [0, 0, 0],
8788
"Pointer": [0, 0, 0],
8889
"Constant": [0, 0, 0],
89-
"Variable": [0, 0, 0]
90+
"Variable": [0, 0, 0],
91+
"FCMP_false": [0, 0, 0],
92+
"FCMP_oeq": [0, 0, 0],
93+
"FCMP_ogt": [0, 0, 0],
94+
"FCMP_oge": [0, 0, 0],
95+
"FCMP_olt": [0, 0, 0],
96+
"FCMP_ole": [0, 0, 0],
97+
"FCMP_one": [0, 0, 0],
98+
"FCMP_ord": [0, 0, 0],
99+
"FCMP_uno": [0, 0, 0],
100+
"FCMP_ueq": [0, 0, 0],
101+
"FCMP_ugt": [0, 0, 0],
102+
"FCMP_uge": [0, 0, 0],
103+
"FCMP_ult": [0, 0, 0],
104+
"FCMP_ule": [0, 0, 0],
105+
"FCMP_une": [0, 0, 0],
106+
"FCMP_true": [0, 0, 0],
107+
"ICMP_eq": [0, 0, 0],
108+
"ICMP_ne": [0, 0, 0],
109+
"ICMP_ugt": [0, 0, 0],
110+
"ICMP_uge": [0, 0, 0],
111+
"ICMP_ult": [0, 0, 0],
112+
"ICMP_ule": [0, 0, 0],
113+
"ICMP_sgt": [1, 1, 1],
114+
"ICMP_sge": [0, 0, 0],
115+
"ICMP_slt": [0, 0, 0],
116+
"ICMP_sle": [0, 0, 0]
90117
}
91118
}

llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,29 @@ Key: Function: [ 0.20 0.40 ]
8282
Key: Pointer: [ 0.60 0.80 ]
8383
Key: Constant: [ 1.00 1.20 ]
8484
Key: Variable: [ 1.40 1.60 ]
85+
Key: FCMP_false: [ 1.80 2.00 ]
86+
Key: FCMP_oeq: [ 2.20 2.40 ]
87+
Key: FCMP_ogt: [ 2.60 2.80 ]
88+
Key: FCMP_oge: [ 3.00 3.20 ]
89+
Key: FCMP_olt: [ 3.40 3.60 ]
90+
Key: FCMP_ole: [ 3.80 4.00 ]
91+
Key: FCMP_one: [ 4.20 4.40 ]
92+
Key: FCMP_ord: [ 4.60 4.80 ]
93+
Key: FCMP_uno: [ 5.00 5.20 ]
94+
Key: FCMP_ueq: [ 5.40 5.60 ]
95+
Key: FCMP_ugt: [ 5.80 6.00 ]
96+
Key: FCMP_uge: [ 6.20 6.40 ]
97+
Key: FCMP_ult: [ 6.60 6.80 ]
98+
Key: FCMP_ule: [ 7.00 7.20 ]
99+
Key: FCMP_une: [ 7.40 7.60 ]
100+
Key: FCMP_true: [ 7.80 8.00 ]
101+
Key: ICMP_eq: [ 8.20 8.40 ]
102+
Key: ICMP_ne: [ 8.60 8.80 ]
103+
Key: ICMP_ugt: [ 9.00 9.20 ]
104+
Key: ICMP_uge: [ 9.40 9.60 ]
105+
Key: ICMP_ult: [ 9.80 10.00 ]
106+
Key: ICMP_ule: [ 10.20 10.40 ]
107+
Key: ICMP_sgt: [ 10.60 10.80 ]
108+
Key: ICMP_sge: [ 11.00 11.20 ]
109+
Key: ICMP_slt: [ 11.40 11.60 ]
110+
Key: ICMP_sle: [ 11.80 12.00 ]

0 commit comments

Comments
 (0)