Skip to content

Conversation

svkeerthy
Copy link
Contributor

@svkeerthy svkeerthy commented Oct 1, 2025

This PR introduces the initial infrastructure and vocabulary necessary for generating embeddings for MIR (discussed briefly in the earlier IR2Vec RFC - https://discourse.llvm.org/t/rfc-enhancing-mlgo-inlining-with-ir2vec-embeddings). The MIR2Vec embeddings are useful in driving target specific optimizations that work on MIR like register allocation.

(Tracking issue - #141817)

@svkeerthy svkeerthy changed the title Introducing MIR2Vec Adding initial infrastructure for supporting MIR2Vec Oct 1, 2025
@svkeerthy svkeerthy changed the title Adding initial infrastructure for supporting MIR2Vec Initial infrastructure for supporting MIR2Vec Oct 1, 2025
@svkeerthy svkeerthy changed the title Initial infrastructure for supporting MIR2Vec Initial infrastructure for MIR2Vec Oct 1, 2025
@svkeerthy svkeerthy changed the title Initial infrastructure for MIR2Vec [IR2Vec] Initial infrastructure for MIR2Vec Oct 1, 2025
@svkeerthy svkeerthy marked this pull request as ready for review October 1, 2025 00:23
@llvmbot llvmbot added llvm:codegen mlgo llvm:analysis Includes value tracking, cost tables and constant folding labels Oct 1, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 1, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

This PR introduces the initial infrastructure and vocabulary necessary for generating embeddings for MIR (discussed briefly in the earlier IR2Vec RFC - https://discourse.llvm.org/t/rfc-enhancing-mlgo-inlining-with-ir2vec-embeddings). The MIR2Vec embeddings are useful in driving target specific optimizations that work on MIR like register allocation.

(Tracking issue - #141817)


Patch is 1.46 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/161463.diff

12 Files Affected:

  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+4-2)
  • (added) llvm/include/llvm/CodeGen/MIR2Vec.h (+191)
  • (modified) llvm/include/llvm/CodeGen/Passes.h (+4)
  • (modified) llvm/include/llvm/InitializePasses.h (+2)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+43-43)
  • (added) llvm/lib/Analysis/models/x86SeedEmbeddingVocab100D.json (+677)
  • (modified) llvm/lib/CodeGen/CMakeLists.txt (+1)
  • (modified) llvm/lib/CodeGen/CodeGen.cpp (+2)
  • (added) llvm/lib/CodeGen/MIR2Vec.cpp (+325)
  • (modified) llvm/tools/llc/llc.cpp (+23-5)
  • (modified) llvm/unittests/CodeGen/CMakeLists.txt (+1)
  • (added) llvm/unittests/CodeGen/MIR2VecTest.cpp (+197)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 4d02f8e05ace0..d3fea1e0980c7 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -210,6 +210,10 @@ class VocabStorage {
   const_iterator end() const {
     return const_iterator(this, getNumSections(), 0);
   }
+  using VocabMap = std::map<std::string, ir2vec::Embedding>;
+  static Error parseVocabSection(StringRef Key,
+                                 const json::Value &ParsedVocabValue,
+                                 VocabMap &TargetVocab, unsigned &Dim);
 };
 
 /// Class for storing and accessing the IR2Vec vocabulary.
@@ -593,8 +597,6 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
 
   Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
                        VocabMap &ArgVocab);
-  Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
-                          VocabMap &TargetVocab, unsigned &Dim);
   void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
                             VocabMap &ArgVocab);
   void emitError(Error Err, LLVMContext &Ctx);
diff --git a/llvm/include/llvm/CodeGen/MIR2Vec.h b/llvm/include/llvm/CodeGen/MIR2Vec.h
new file mode 100644
index 0000000000000..dc97e1c616112
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/MIR2Vec.h
@@ -0,0 +1,191 @@
+//===- MIR2Vec.h - Implementation of MIR2Vec ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions. See the LICENSE file for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines the MIR2Vec vocabulary analysis(MIR2VecVocabAnalysis),
+/// the core mir2vec::Embedder interface for generating Machine IR embeddings,
+/// and related utilities.
+///
+/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
+/// LLVM Machine IR as embeddings which can be used as input to machine learning
+/// algorithms.
+///
+/// The original idea of MIR2Vec is described in the following paper:
+///
+/// RL4ReAl: Reinforcement Learning for Register Allocation. S. VenkataKeerthy,
+/// Siddharth Jain, Anilava Kundu, Rohit Aggarwal, Albert Cohen, and Ramakrishna
+/// Upadrasta. 2023. RL4ReAl: Reinforcement Learning for Register Allocation.
+/// Proceedings of the 32nd ACM SIGPLAN International Conference on Compiler
+/// Construction (CC 2023). https://doi.org/10.1145/3578360.3580273.
+/// https://arxiv.org/abs/2204.02013
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_MIR2VEC_H
+#define LLVM_CODEGEN_MIR2VEC_H
+
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/CodeGen/MachineBasicBlock.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorOr.h"
+#include <map>
+#include <set>
+#include <string>
+
+namespace llvm {
+
+class Module;
+class raw_ostream;
+class LLVMContext;
+class MIR2VecVocabAnalysis;
+class TargetInstrInfo;
+
+namespace mir2vec {
+
+// Forward declarations
+class Embedder;
+class SymbolicEmbedder;
+class FlowAwareEmbedder;
+
+extern llvm::cl::OptionCategory MIR2VecCategory;
+extern cl::opt<float> OpcWeight;
+
+using Embedding = ir2vec::Embedding;
+
+/// Class for storing and accessing the MIR2Vec vocabulary.
+/// The Vocabulary class manages seed embeddings for LLVM Machine IR
+class Vocabulary {
+  friend class llvm::MIR2VecVocabAnalysis;
+  using VocabMap = std::map<std::string, ir2vec::Embedding>;
+
+public:
+  // Define vocabulary layout - adapted for MIR
+  struct {
+    unsigned OpcodeBase = 0;
+    unsigned OperandBase = 0;
+    unsigned TotalEntries = 0;
+  } Layout;
+
+private:
+  ir2vec::VocabStorage Storage;
+  mutable std::set<std::string> UniqueBaseOpcodeNames;
+  void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
+  void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
+
+public:
+  /// Static helper method for extracting base opcode names (public for testing)
+  static std::string extractBaseOpcodeName(StringRef InstrName);
+
+  /// Helper method for getting canonical index for base name (public for
+  /// testing)
+  unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
+
+  /// Get the string key for a vocabulary entry at the given position
+  std::string getStringKey(unsigned Pos) const;
+
+  Vocabulary() = default;
+  Vocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
+  Vocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
+
+  bool isValid() const;
+  unsigned getDimension() const;
+
+  // Accessor methods
+  const Embedding &operator[](unsigned Index) const;
+
+  // Iterator access
+  using const_iterator = ir2vec::VocabStorage::const_iterator;
+  const_iterator begin() const;
+  const_iterator end() const;
+};
+
+} // namespace mir2vec
+
+/// Pass to analyze and populate MIR2Vec vocabulary from a module
+class MIR2VecVocabAnalysis : public ImmutablePass {
+  using VocabVector = std::vector<mir2vec::Embedding>;
+  using VocabMap = std::map<std::string, mir2vec::Embedding>;
+  VocabMap StrVocabMap;
+  VocabVector Vocab;
+
+  StringRef getPassName() const override;
+  Error readVocabulary();
+  void emitError(Error Err, LLVMContext &Ctx);
+
+protected:
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<MachineModuleInfoWrapperPass>();
+    AU.setPreservesAll();
+  }
+
+public:
+  static char ID;
+  MIR2VecVocabAnalysis() : ImmutablePass(ID) {}
+  mir2vec::Vocabulary getMIR2VecVocabulary(const Module &M);
+};
+
+/// This pass prints the MIR2Vec embeddings for instructions, basic blocks, and
+/// functions.
+class MIR2VecPrinterPass : public PassInfoMixin<MIR2VecPrinterPass> {
+  raw_ostream &OS;
+
+public:
+  explicit MIR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+  static bool isRequired() { return true; }
+};
+
+/// This pass prints the embeddings in the MIR2Vec vocabulary
+class MIR2VecVocabPrinterPass : public MachineFunctionPass {
+  raw_ostream &OS;
+
+public:
+  static char ID;
+  explicit MIR2VecVocabPrinterPass(raw_ostream &OS)
+      : MachineFunctionPass(ID), OS(OS) {}
+
+  bool runOnMachineFunction(MachineFunction &MF) override;
+  bool doFinalization(Module &M) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<MIR2VecVocabAnalysis>();
+    AU.setPreservesAll();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
+
+  StringRef getPassName() const override {
+    return "MIR2Vec Vocabulary Printer Pass";
+  }
+};
+
+/// Old PM version of the printer pass
+class MIR2VecPrinterLegacyPass : public ModulePass {
+  raw_ostream &OS;
+
+public:
+  static char ID;
+  explicit MIR2VecPrinterLegacyPass(raw_ostream &OS) : ModulePass(ID), OS(OS) {}
+
+  bool runOnModule(Module &M) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+    AU.addRequired<MIR2VecVocabAnalysis>();
+    AU.addRequired<MachineModuleInfoWrapperPass>();
+  }
+
+  StringRef getPassName() const override { return "MIR2Vec Printer Pass"; }
+};
+
+} // namespace llvm
+
+#endif // LLVM_CODEGEN_MIR2VEC_H
\ No newline at end of file
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index 593308150dc82..a9c608f4a17fc 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -87,6 +87,10 @@ LLVM_ABI MachineFunctionPass *
 createMachineFunctionPrinterPass(raw_ostream &OS,
                                  const std::string &Banner = "");
 
+/// MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary
+/// contents to the given stream as a debugging tool.
+LLVM_ABI MachineFunctionPass *createMIR2VecVocabPrinterPass(raw_ostream &OS);
+
 /// StackFramePrinter pass - This pass prints out the machine function's
 /// stack frame to the given stream as a debugging tool.
 LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index a99b0ff188ea8..a4b0f31d3cd49 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -220,6 +220,8 @@ LLVM_ABI void initializeMachinePostDominatorTreeWrapperPassPass(PassRegistry &);
 LLVM_ABI void initializeMachineRegionInfoPassPass(PassRegistry &);
 LLVM_ABI void
 initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
+LLVM_ABI void initializeMIR2VecVocabAnalysisPass(PassRegistry &);
+LLVM_ABI void initializeMIR2VecVocabPrinterPassPass(PassRegistry &);
 LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
 LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 271f004b0a787..eeefc1e2709c2 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -329,6 +329,43 @@ bool VocabStorage::const_iterator::operator!=(
   return !(*this == Other);
 }
 
+Error VocabStorage::parseVocabSection(StringRef Key,
+                                      const json::Value &ParsedVocabValue,
+                                      VocabMap &TargetVocab, unsigned &Dim) {
+  json::Path::Root Path("");
+  const json::Object *RootObj = ParsedVocabValue.getAsObject();
+  if (!RootObj)
+    return createStringError(errc::invalid_argument,
+                             "JSON root is not an object");
+
+  const json::Value *SectionValue = RootObj->get(Key);
+  if (!SectionValue)
+    return createStringError(errc::invalid_argument,
+                             "Missing '" + std::string(Key) +
+                                 "' section in vocabulary file");
+  if (!json::fromJSON(*SectionValue, TargetVocab, Path))
+    return createStringError(errc::illegal_byte_sequence,
+                             "Unable to parse '" + std::string(Key) +
+                                 "' section from vocabulary");
+
+  Dim = TargetVocab.begin()->second.size();
+  if (Dim == 0)
+    return createStringError(errc::illegal_byte_sequence,
+                             "Dimension of '" + std::string(Key) +
+                                 "' section of the vocabulary is zero");
+
+  if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
+                   [Dim](const std::pair<StringRef, Embedding> &Entry) {
+                     return Entry.second.size() == Dim;
+                   }))
+    return createStringError(
+        errc::illegal_byte_sequence,
+        "All vectors in the '" + std::string(Key) +
+            "' section of the vocabulary are not of the same dimension");
+
+  return Error::success();
+}
+
 // ==----------------------------------------------------------------------===//
 // Vocabulary
 //===----------------------------------------------------------------------===//
@@ -459,43 +496,6 @@ VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
 // IR2VecVocabAnalysis
 //===----------------------------------------------------------------------===//
 
-Error IR2VecVocabAnalysis::parseVocabSection(
-    StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab,
-    unsigned &Dim) {
-  json::Path::Root Path("");
-  const json::Object *RootObj = ParsedVocabValue.getAsObject();
-  if (!RootObj)
-    return createStringError(errc::invalid_argument,
-                             "JSON root is not an object");
-
-  const json::Value *SectionValue = RootObj->get(Key);
-  if (!SectionValue)
-    return createStringError(errc::invalid_argument,
-                             "Missing '" + std::string(Key) +
-                                 "' section in vocabulary file");
-  if (!json::fromJSON(*SectionValue, TargetVocab, Path))
-    return createStringError(errc::illegal_byte_sequence,
-                             "Unable to parse '" + std::string(Key) +
-                                 "' section from vocabulary");
-
-  Dim = TargetVocab.begin()->second.size();
-  if (Dim == 0)
-    return createStringError(errc::illegal_byte_sequence,
-                             "Dimension of '" + std::string(Key) +
-                                 "' section of the vocabulary is zero");
-
-  if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
-                   [Dim](const std::pair<StringRef, Embedding> &Entry) {
-                     return Entry.second.size() == Dim;
-                   }))
-    return createStringError(
-        errc::illegal_byte_sequence,
-        "All vectors in the '" + std::string(Key) +
-            "' section of the vocabulary are not of the same dimension");
-
-  return Error::success();
-}
-
 // FIXME: Make this optional. We can avoid file reads
 // by auto-generating a default vocabulary during the build time.
 Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
@@ -512,16 +512,16 @@ Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
     return ParsedVocabValue.takeError();
 
   unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
-  if (auto Err =
-          parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))
+  if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue,
+                                                 OpcVocab, OpcodeDim))
     return Err;
 
-  if (auto Err =
-          parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
+  if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue,
+                                                 TypeVocab, TypeDim))
     return Err;
 
-  if (auto Err =
-          parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
+  if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue,
+                                                 ArgVocab, ArgDim))
     return Err;
 
   if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
diff --git a/llvm/lib/Analysis/models/x86SeedEmbeddingVocab100D.json b/llvm/lib/Analysis/models/x86SeedEmbeddingVocab100D.json
new file mode 100644
index 0000000000000..0afe5c77d0461
--- /dev/null
+++ b/llvm/lib/Analysis/models/x86SeedEmbeddingVocab100D.json
@@ -0,0 +1,677 @@
+{
+    "entities" : {
+        "ABS_Fp":[0.07323841750621796, -0.006006906274706125, 0.09751169383525848, -0.011089739389717579, 0.06642112135887146, -0.015824640169739723, -0.021592319011688232, -0.0035401300992816687, 0.06047678738832474, -0.007392085622996092, 0.07134906202554703, -0.019624482840299606, -0.10975595563650131, -0.007685789838433266, 0.07451746612787247, 0.06384266912937164, -0.08230067789554596, 0.050922468304634094, 0.013724055141210556, 0.015687907114624977, -0.018451329320669174, 0.046987198293209076, -0.037734340876340866, -0.07235030829906464, 0.10218106210231781, 0.08037368208169937, -0.029537858441472054, -0.047520823776721954, -0.022125739604234695, -0.03125226870179176, -0.02882847562432289, 0.013811410404741764, 0.0023568253964185715, 0.017958490177989006, -0.05359291657805443, -0.03606243059039116, 0.07840022444725037, -0.016711654141545296, -0.038644544780254364, 0.05886651948094368, -0.011418955400586128, -0.04882095381617546, 0.04027162492275238, 0.001088760793209076, 0.03045983798801899, -0.10998888313770294, -0.0097441291436553, 0.015445191413164139, 0.030951637774705887, -0.06309321522712708, -0.019475746899843216, -0.029662512242794037, 0.05312168970704079, 0.05355998873710632, 0.05060160160064697, -0.053278811275959015, -0.01803833432495594, 0.010853713378310204, -0.053911495953798294, 0.06630647927522659, -0.08671313524246216, 0.0699775293469429, -0.08346731215715408, -0.045348167419433594, 0.06779918074607849, 0.008865933865308762, 0.05460203066468239, 0.007126103155314922, 0.0012282058596611023, 0.06817980855703354, 0.0216530654579401, 0.03552381321787834, 0.015414077788591385, -0.06002715229988098, 0.05233345925807953, 0.0782286673784256, 0.04220856353640556, -0.005762201733887196, 0.004772072657942772, 0.004578332882374525, 0.002619141712784767, 0.024511393159627914, -0.10089710354804993, 0.018322769552469254, 0.020811809226870537, -0.03358744457364082, -0.06896928697824478, -0.007399350870400667, -0.044467780739068985, -0.08094192296266556, -0.09795571863651276, 0.08391229063272476, -0.04749457910656929, 0.0029586481396108866, -5.354872337193228e-05, 0.005788655485957861, 0.015252145007252693, 0.06928747892379761, 0.041780371218919754, 0.016391364857554436],
+        "ADC":[-0.07533542811870575, -0.01729339174926281, 0.04298720881342888, 0.015697332099080086, -0.04403507336974144, -0.059322185814380646, -0.050977922976017, 0.027526788413524628, -0.07009710371494293, -0.025621667504310608, 0.0352291613817215, -0.011538374237716198, 0.03682859241962433, -0.09788215160369873, -0.07216927409172058, -0.03659192472696304, 0.05676230415701866, -0.06369645893573761, -0.04756825789809227, 0.005865555722266436, 0.022270306944847107, -0.042112063616514206, 0.07008901983499527, 0.07748222351074219, -0.1020870953798294, -0.008511601015925407, -0.05725255608558655, -0.07881367206573486, 0.05627593398094177, -0.0005361076910048723, 0.03351512551307678, 0.04348289221525192, -0.08322969079017639, -0.02161242999136448, -0.07805898040533066, 0.04819482937455177, -0.061123576015233994, -0.010114834643900394, -0.04676959663629532, -0.008176938630640507, 0.010575453750789165, -0.04312445595860481, 0.00376943894661963, -0.0691257119178772, 0.03553615137934685, 0.10397598147392273, 0.009375158697366714, 0.001147320494055748, 0.026351911947131157, -0.0194610096514225, -0.05202522128820419, 0.014047946780920029, -0.040036872029304504, 0.06963572651147842, 0.04827437922358513, -0.06908547878265381, 0.024857567623257637, -0.03304143249988556, 0.02291242778301239, 0.07687342166900635, -0.05110599845647812, -0.00873416755348444, 0.026205750182271004, 0.045064594596624374, -0.03565925359725952, 0.09580051153898239, -0.02518773265182972, 0.047807395458221436, -0.03548192232847214, 0.08286304026842117, -0.053511787205934525, 0.02892065793275833, -0.0495525486767292, 0.02590095065534115, -0.006982128601521254, 0.006042638327926397, -0.07269058376550674, 0.02401554025709629, -0.05660006031394005, -0.026029467582702637, 0.05318204686045647, 0.06714116781949997, -0.0023821850772947073, 0.05028798058629036, -0.005811943672597408, -0.003296421840786934, -0.005409242119640112, -0.10150349885225296, -0.06406981498003006, 0.02553202211856842, -0.002790689468383789, 0.0663856491446495, 0.09109167754650116, -0.04678672179579735, 0.022019781172275543, 0.007821275852620602, 0.022490357980132103, -0.058503177016973495, 0.08841150254011154, -0.00892670825123787],
+        "ADD":[-0.037626221776008606, 0.006784931290894747, 0.10051396489143372, -0.0014993306249380112, -0.0323498398065567, -0.03148593008518219, -0.014100957661867142, -0.020252650603652, 0.014126972295343876, -0.1295478343963623, 0.08520576357841492, -0.02513248659670353, 0.03539956361055374, -0.07019674777984619, -0.019069846719503403, 0.016678515821695328, -0.009174983017146587, -0.019034702330827713, -0.024083402007818222, -0.07829779386520386, -0.007908892817795277, -0.07924024760723114, -0.034599609673023224, 0.05271153524518013, 0.0016642026603221893, -0.03938138112425804, 0.0019624519627541304, 0.03562740981578827, 0.07340876758098602, 0.09457183629274368, -0.06507840752601624, 0.00246993126347661, -0.004548616707324982, 0.058226197957992554, -0.021043049171566963, -0.0599520243704319, -0.03138553351163864, 0.03265950828790665, 0.004963710438460112, -0.003248866181820631, -0.04021746292710304, 0.038208190351724625, -0.02256007120013237, 0.10770396143198013, 0.013757425360381603, 0.040707558393478394, -0.00694271270185709, -0.012331271544098854, 0.004992029629647732, -0.032236646860837936, 0.01055158581584692, 0.04604483023285866, 0.09973260760307312, 0.07322807610034943, 0.06853726506233215, 0.004230210557579994, -0.04007832333445549, 0.16341225802898407, -0.01683313027024269, -0.01998194307088852, -0.035919081419706345, -0.055...
[truncated]

const_iterator end() const {
return const_iterator(this, getNumSections(), 0);
}
using VocabMap = std::map<std::string, ir2vec::Embedding>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can do this as a NFC?


/// Class for storing and accessing the MIR2Vec vocabulary.
/// The Vocabulary class manages seed embeddings for LLVM Machine IR
class Vocabulary {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you rename it to MIRVocabulary - I realize it's in a different namespace, but someone importing both may get confused.

public:
// Define vocabulary layout - adapted for MIR
struct {
unsigned OpcodeBase = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: why not size_t (because they are size-ish?)

} // namespace mir2vec

/// Pass to analyze and populate MIR2Vec vocabulary from a module
class MIR2VecVocabAnalysis : public ImmutablePass {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wanna name it MIR2VecVocabLegacyAnalysis or something like that, so you don't take the good name for the NPM?

Error VocabStorage::parseVocabSection(StringRef Key,
const json::Value &ParsedVocabValue,
VocabMap &TargetVocab, unsigned &Dim) {
json::Path::Root Path("");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can't be reused with the IR one?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding llvm:codegen mlgo
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants