3535#include " llvm/CodeGen/MachineFunctionPass.h"
3636#include " llvm/CodeGen/MachineInstr.h"
3737#include " llvm/CodeGen/MachineModuleInfo.h"
38+ #include " llvm/CodeGen/MachineOperand.h"
39+ #include " llvm/CodeGen/MachineRegisterInfo.h"
3840#include " llvm/IR/PassManager.h"
3941#include " llvm/Pass.h"
4042#include " llvm/Support/CommandLine.h"
@@ -61,7 +63,7 @@ class MIREmbedder;
6163class SymbolicMIREmbedder ;
6264
6365extern llvm::cl::OptionCategory MIR2VecCategory;
64- extern cl::opt<float > OpcWeight;
66+ extern cl::opt<float > OpcWeight, CommonOperandWeight, RegOperandWeight ;
6567
6668using Embedding = ir2vec::Embedding;
6769using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
@@ -74,31 +76,114 @@ class MIRVocabulary {
7476 friend class llvm ::MIR2VecVocabLegacyAnalysis;
7577 using VocabMap = std::map<std::string, ir2vec::Embedding>;
7678
77- private:
78- // Define vocabulary layout - adapted for MIR
79+ // MIRVocabulary Layout:
80+ // +-------------------+-----------------------------------------------------+
81+ // | Entity Type | Description |
82+ // +-------------------+-----------------------------------------------------+
83+ // | 1. Opcodes | Target specific opcodes derived from TII, grouped |
84+ // | | by instruction semantics. |
85+ // | 2. Common Operands| All common operand types, except register operands, |
86+ // | | defined by MachineOperand::MachineOperandType enum. |
87+ // | 3. Physical | Register classes defined by the target, specialized |
88+ // | Reg classes | by physical registers. |
89+ // | 4. Virtual | Register classes defined by the target, specialized |
90+ // | Reg classes | by virtual and physical registers. |
91+ // +-------------------+-----------------------------------------------------+
92+
93+ // / Layout information for the MIR vocabulary. Defines the starting index
94+ // / and size of each section in the vocabulary.
7995 struct {
8096 size_t OpcodeBase = 0 ;
81- size_t OperandBase = 0 ;
97+ size_t CommonOperandBase = 0 ;
98+ size_t PhyRegBase = 0 ;
99+ size_t VirtRegBase = 0 ;
82100 size_t TotalEntries = 0 ;
83101 } Layout;
84102
85- enum class Section : unsigned { Opcodes = 0 , MaxSections };
103+ enum class Section : unsigned {
104+ Opcodes = 0 ,
105+ CommonOperands = 1 ,
106+ PhyRegisters = 2 ,
107+ VirtRegisters = 3 ,
108+ MaxSections
109+ };
86110
87111 ir2vec::VocabStorage Storage;
88112 mutable std::set<std::string> UniqueBaseOpcodeNames;
113+ mutable SmallVector<std::string, 24 > RegisterOperandNames;
114+
115+ // Some instructions have optional register operands that may be NoRegister.
116+ // We return a zero vector in such cases.
117+ mutable Embedding ZeroEmbedding;
118+
119+ // We have specialized MO_Register handling in the Register operand section,
120+ // so we don't include it here. Also, no MO_DbgInstrRef for now.
121+ static constexpr StringLiteral CommonOperandNames[] = {
122+ " Immediate" , " CImmediate" , " FPImmediate" , " MBB" ,
123+ " FrameIndex" , " ConstantPoolIndex" , " TargetIndex" , " JumpTableIndex" ,
124+ " ExternalSymbol" , " GlobalAddress" , " BlockAddress" , " RegisterMask" ,
125+ " RegisterLiveOut" , " Metadata" , " MCSymbol" , " CFIIndex" ,
126+ " IntrinsicID" , " Predicate" , " ShuffleMask" };
127+ static_assert (std::size(CommonOperandNames) == MachineOperand::MO_Last - 1 &&
128+ " Common operand names size changed, update accordingly" );
129+
89130 const TargetInstrInfo &TII;
90- void generateStorage (const VocabMap &OpcodeMap);
131+ const TargetRegisterInfo &TRI;
132+ const MachineRegisterInfo &MRI;
133+
134+ void generateStorage (const VocabMap &OpcodeMap,
135+ const VocabMap &CommonOperandMap,
136+ const VocabMap &PhyRegMap, const VocabMap &VirtRegMap);
91137 void buildCanonicalOpcodeMapping ();
138+ void buildRegisterOperandMapping ();
92139
93140 // / Get canonical index for a machine opcode
94141 unsigned getCanonicalOpcodeIndex (unsigned Opcode) const ;
95142
143+ // / Get index for a common (non-register) machine operand
144+ unsigned
145+ getCommonOperandIndex (MachineOperand::MachineOperandType OperandType) const ;
146+
147+ // / Get index for a register machine operand
148+ unsigned getRegisterOperandIndex (Register Reg) const ;
149+
150+ // Accessors for operand types
151+ const Embedding &
152+ operator [](MachineOperand::MachineOperandType OperandType) const {
153+ unsigned LocalIndex = getCommonOperandIndex (OperandType);
154+ return Storage[static_cast <unsigned >(Section::CommonOperands)][LocalIndex];
155+ }
156+
157+ const Embedding &operator [](Register Reg) const {
158+ // Reg is sometimes NoRegister (0) for optional operands. We return a zero
159+ // vector in this case.
160+ if (!Reg.isValid ())
161+ return ZeroEmbedding;
162+ // TODO: Implement proper stack slot handling for MIR2Vec embeddings.
163+ // Stack slots represent frame indices and should have their own
164+ // embedding strategy rather than defaulting to register class 0.
165+ // Consider: 1) Separate vocabulary section for stack slots
166+ // 2) Stack slot size/alignment based embeddings
167+ // 3) Frame index based categorization
168+ if (Reg.isStack ())
169+ return ZeroEmbedding;
170+
171+ unsigned LocalIndex = getRegisterOperandIndex (Reg);
172+ auto SectionID =
173+ Reg.isPhysical () ? Section::PhyRegisters : Section::VirtRegisters;
174+ return Storage[static_cast <unsigned >(SectionID)][LocalIndex];
175+ }
176+
96177public:
97178 // / Static method for extracting base opcode names (public for testing)
98179 static std::string extractBaseOpcodeName (StringRef InstrName);
99180
100- // / Get canonical index for base name (public for testing)
181+ // / Get indices from opcode or operand names. These are public for testing.
182+ // / String based lookups are inefficient and should be avoided in general.
101183 unsigned getCanonicalIndexForBaseName (StringRef BaseName) const ;
184+ unsigned getCanonicalIndexForOperandName (StringRef OperandName) const ;
185+ unsigned getCanonicalIndexForRegisterClass (StringRef RegName,
186+ bool IsPhysical = true ) const ;
102187
103188 // / Get the string key for a vocabulary entry at the given position
104189 std::string getStringKey (unsigned Pos) const ;
@@ -111,6 +196,14 @@ class MIRVocabulary {
111196 return Storage[static_cast <unsigned >(Section::Opcodes)][LocalIndex];
112197 }
113198
199+ const Embedding &operator [](MachineOperand Operand) const {
200+ auto OperandType = Operand.getType ();
201+ if (OperandType == MachineOperand::MO_Register)
202+ return operator [](Operand.getReg ());
203+ else
204+ return operator [](OperandType);
205+ }
206+
114207 // Iterator access
115208 using const_iterator = ir2vec::VocabStorage::const_iterator;
116209 const_iterator begin () const { return Storage.begin (); }
@@ -120,18 +213,25 @@ class MIRVocabulary {
120213 MIRVocabulary () = delete ;
121214
122215 // / Factory method to create MIRVocabulary from vocabulary map
123- static Expected<MIRVocabulary> create (VocabMap &&Entries,
124- const TargetInstrInfo &TII);
216+ static Expected<MIRVocabulary>
217+ create (VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap,
218+ VocabMap &&VirtRegMap, const TargetInstrInfo &TII,
219+ const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI);
125220
126221 // / Create a dummy vocabulary for testing purposes.
127222 static Expected<MIRVocabulary>
128- createDummyVocabForTest (const TargetInstrInfo &TII, unsigned Dim = 1 );
223+ createDummyVocabForTest (const TargetInstrInfo &TII,
224+ const TargetRegisterInfo &TRI,
225+ const MachineRegisterInfo &MRI, unsigned Dim = 1 );
129226
130227 // / Total number of entries in the vocabulary
131228 size_t getCanonicalSize () const { return Storage.size (); }
132229
133230private:
134- MIRVocabulary (VocabMap &&Entries, const TargetInstrInfo &TII);
231+ MIRVocabulary (VocabMap &&OpcMap, VocabMap &&CommonOperandsMap,
232+ VocabMap &&PhyRegMap, VocabMap &&VirtRegMap,
233+ const TargetInstrInfo &TII, const TargetRegisterInfo &TRI,
234+ const MachineRegisterInfo &MRI);
135235};
136236
137237// / Base class for MIR embedders
@@ -144,11 +244,13 @@ class MIREmbedder {
144244 const unsigned Dimension;
145245
146246 // / Weight for opcode embeddings
147- const float OpcWeight;
247+ const float OpcWeight, CommonOperandWeight, RegOperandWeight ;
148248
149249 MIREmbedder (const MachineFunction &MF, const MIRVocabulary &Vocab)
150250 : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
151- OpcWeight (mir2vec::OpcWeight) {}
251+ OpcWeight (mir2vec::OpcWeight),
252+ CommonOperandWeight(mir2vec::CommonOperandWeight),
253+ RegOperandWeight(mir2vec::RegOperandWeight) {}
152254
153255 // / Function to compute embeddings.
154256 Embedding computeEmbeddings () const ;
@@ -208,11 +310,11 @@ class SymbolicMIREmbedder : public MIREmbedder {
208310class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
209311 using VocabVector = std::vector<mir2vec::Embedding>;
210312 using VocabMap = std::map<std::string, mir2vec::Embedding>;
211- VocabMap StrVocabMap;
212- VocabVector Vocab;
313+ std::optional<mir2vec::MIRVocabulary> Vocab;
213314
214315 StringRef getPassName () const override ;
215- Error readVocabulary ();
316+ Error readVocabulary (VocabMap &OpcVocab, VocabMap &CommonOperandVocab,
317+ VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap);
216318
217319protected:
218320 void getAnalysisUsage (AnalysisUsage &AU) const override {
@@ -275,4 +377,4 @@ MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
275377
276378} // namespace llvm
277379
278- #endif // LLVM_CODEGEN_MIR2VEC_H
380+ #endif // LLVM_CODEGEN_MIR2VEC_H
0 commit comments