@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
216
216
ArgEmb += Vocab[*Op];
217
217
auto InstVector =
218
218
Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
219
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
220
+ InstVector += Vocab[IC->getPredicate ()];
219
221
InstVecMap[&I] = InstVector;
220
222
BBVector += InstVector;
221
223
}
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
250
252
// embeddings
251
253
auto InstVector =
252
254
Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
255
+ // Add compare predicate embedding as an additional operand if applicable
256
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
257
+ InstVector += Vocab[IC->getPredicate ()];
253
258
InstVecMap[&I] = InstVector;
254
259
BBVector += InstVector;
255
260
}
@@ -285,7 +290,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
285
290
unsigned Vocabulary::getSlotIndex (const Value &Op) {
286
291
unsigned Index = static_cast <unsigned >(getOperandKind (&Op));
287
292
assert (Index < MaxOperandKinds && " Invalid OperandKind" );
288
- return MaxOpcodes + MaxCanonicalTypeIDs + Index;
293
+ return OperandBaseOffset + Index;
294
+ }
295
+
296
+ unsigned Vocabulary::getSlotIndex (CmpInst::Predicate P) {
297
+ unsigned PU = static_cast <unsigned >(P);
298
+ unsigned FirstFC = static_cast <unsigned >(CmpInst::FIRST_FCMP_PREDICATE);
299
+ unsigned FirstIC = static_cast <unsigned >(CmpInst::FIRST_ICMP_PREDICATE);
300
+
301
+ unsigned PredIdx =
302
+ (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
303
+ return PredicateBaseOffset + PredIdx;
289
304
}
290
305
291
306
const Embedding &Vocabulary::operator [](unsigned Opcode) const {
@@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
300
315
return Vocab[getSlotIndex (Arg)];
301
316
}
302
317
318
+ const ir2vec::Embedding &Vocabulary::operator [](CmpInst::Predicate P) const {
319
+ return Vocab[getSlotIndex (P)];
320
+ }
321
+
303
322
StringRef Vocabulary::getVocabKeyForOpcode (unsigned Opcode) {
304
323
assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
305
324
#define HANDLE_INST (NUM, OPCODE, CLASS ) \
@@ -345,18 +364,35 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
345
364
return OperandKind::VariableID;
346
365
}
347
366
367
+ CmpInst::Predicate Vocabulary::getPredicate (unsigned Index) {
368
+ assert (Index < MaxPredicateKinds && " Invalid predicate index" );
369
+ unsigned PredEnumVal =
370
+ (Index < NumFCmpPredicates)
371
+ ? (static_cast <unsigned >(CmpInst::FIRST_FCMP_PREDICATE) + Index)
372
+ : (static_cast <unsigned >(CmpInst::FIRST_ICMP_PREDICATE) +
373
+ (Index - NumFCmpPredicates));
374
+ return static_cast <CmpInst::Predicate>(PredEnumVal);
375
+ }
376
+
377
+ StringRef Vocabulary::getVocabKeyForPredicate (CmpInst::Predicate Pred) {
378
+ return CmpInst::getPredicateName (Pred);
379
+ }
380
+
348
381
// Returns the string key of the vocabulary entry at position \p Pos.
//
// Positions are checked against the section boundaries in ascending order:
// opcodes (below MaxOpcodes), canonical types (below OperandBaseOffset),
// operand kinds (below PredicateBaseOffset), and finally compare predicates.
// \p Pos must be less than NumCanonicalEntries.
StringRef Vocabulary::getStringKey(unsigned Pos) {
  assert(Pos < NumCanonicalEntries && " Position out of bounds in vocabulary");

  if (Pos < MaxOpcodes) {
    // Opcode keys are looked up 1-based.
    const unsigned Opcode = Pos + 1;
    return getVocabKeyForOpcode(Opcode);
  } else if (Pos < OperandBaseOffset) {
    // Type section: index relative to the end of the opcode section.
    const auto TypeID = static_cast<CanonicalTypeID>(Pos - MaxOpcodes);
    return getVocabKeyForCanonicalTypeID(TypeID);
  } else if (Pos < PredicateBaseOffset) {
    // Operand section: index relative to OperandBaseOffset.
    const auto Kind = static_cast<OperandKind>(Pos - OperandBaseOffset);
    return getVocabKeyForOperandKind(Kind);
  }

  // Predicate section: index relative to PredicateBaseOffset.
  const unsigned PredIndex = Pos - PredicateBaseOffset;
  return getVocabKeyForPredicate(getPredicate(PredIndex));
}
361
397
362
398
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -370,11 +406,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
370
406
VocabVector DummyVocab;
371
407
DummyVocab.reserve (NumCanonicalEntries);
372
408
float DummyVal = 0 .1f ;
373
- // Create a dummy vocabulary with entries for all opcodes, types, and
374
- // operands
375
- for ([[maybe_unused]] unsigned _ :
376
- seq (0u , Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
377
- Vocabulary::MaxOperandKinds)) {
409
+ // Create a dummy vocabulary with entries for all opcodes, types, operands
410
+ // and predicates
411
+ for ([[maybe_unused]] unsigned _ : seq (0u , Vocabulary::NumCanonicalEntries)) {
378
412
DummyVocab.push_back (Embedding (Dim, DummyVal));
379
413
DummyVal += 0 .1f ;
380
414
}
@@ -517,6 +551,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
517
551
}
518
552
Vocab.insert (Vocab.end (), NumericArgEmbeddings.begin (),
519
553
NumericArgEmbeddings.end ());
554
+
555
+ // Handle Predicates: part of Operands section. We look up predicate keys
556
+ // in ArgVocab.
557
+ std::vector<Embedding> NumericPredEmbeddings (Vocabulary::MaxPredicateKinds,
558
+ Embedding (Dim, 0 ));
559
+ NumericPredEmbeddings.reserve (Vocabulary::MaxPredicateKinds);
560
+ for (unsigned PK : seq (0u , Vocabulary::MaxPredicateKinds)) {
561
+ StringRef VocabKey =
562
+ Vocabulary::getVocabKeyForPredicate (Vocabulary::getPredicate (PK));
563
+ auto It = ArgVocab.find (VocabKey.str ());
564
+ if (It != ArgVocab.end ()) {
565
+ NumericPredEmbeddings[PK] = It->second ;
566
+ continue ;
567
+ }
568
+ handleMissingEntity (VocabKey.str ());
569
+ }
570
+ Vocab.insert (Vocab.end (), NumericPredEmbeddings.begin (),
571
+ NumericPredEmbeddings.end ());
520
572
}
521
573
522
574
IR2VecVocabAnalysis::IR2VecVocabAnalysis (const VocabVector &Vocab)
0 commit comments