Skip to content

Commit caa93a0

Browse files
committed
Handle Operands
1 parent ae13506 commit caa93a0

File tree

13 files changed

+1406
-166
lines changed

13 files changed

+1406
-166
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 119 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "llvm/CodeGen/MachineFunctionPass.h"
3636
#include "llvm/CodeGen/MachineInstr.h"
3737
#include "llvm/CodeGen/MachineModuleInfo.h"
38+
#include "llvm/CodeGen/MachineOperand.h"
39+
#include "llvm/CodeGen/MachineRegisterInfo.h"
3840
#include "llvm/IR/PassManager.h"
3941
#include "llvm/Pass.h"
4042
#include "llvm/Support/CommandLine.h"
@@ -61,7 +63,7 @@ class MIREmbedder;
6163
class SymbolicMIREmbedder;
6264

6365
extern llvm::cl::OptionCategory MIR2VecCategory;
64-
extern cl::opt<float> OpcWeight;
66+
extern cl::opt<float> OpcWeight, CommonOperandWeight, RegOperandWeight;
6567

6668
using Embedding = ir2vec::Embedding;
6769
using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
@@ -74,31 +76,114 @@ class MIRVocabulary {
7476
friend class llvm::MIR2VecVocabLegacyAnalysis;
7577
using VocabMap = std::map<std::string, ir2vec::Embedding>;
7678

77-
private:
78-
// Define vocabulary layout - adapted for MIR
79+
// MIRVocabulary Layout:
80+
// +-------------------+-----------------------------------------------------+
81+
// | Entity Type | Description |
82+
// +-------------------+-----------------------------------------------------+
83+
// | 1. Opcodes | Target specific opcodes derived from TII, grouped |
84+
// | | by instruction semantics. |
85+
// | 2. Common Operands| All common operand types, except register operands, |
86+
// | | defined by MachineOperand::MachineOperandType enum. |
87+
// | 3. Physical | Register classes defined by the target, specialized |
88+
// | Reg classes | by physical registers. |
89+
// | 4. Virtual | Register classes defined by the target, specialized |
90+
// | Reg classes | by virtual and physical registers. |
91+
// +-------------------+-----------------------------------------------------+
92+
93+
/// Layout information for the MIR vocabulary. Defines the starting index
94+
/// and size of each section in the vocabulary.
7995
struct {
8096
size_t OpcodeBase = 0;
81-
size_t OperandBase = 0;
97+
size_t CommonOperandBase = 0;
98+
size_t PhyRegBase = 0;
99+
size_t VirtRegBase = 0;
82100
size_t TotalEntries = 0;
83101
} Layout;
84102

85-
enum class Section : unsigned { Opcodes = 0, MaxSections };
103+
enum class Section : unsigned {
104+
Opcodes = 0,
105+
CommonOperands = 1,
106+
PhyRegisters = 2,
107+
VirtRegisters = 3,
108+
MaxSections
109+
};
86110

87111
ir2vec::VocabStorage Storage;
88112
mutable std::set<std::string> UniqueBaseOpcodeNames;
113+
mutable SmallVector<std::string, 24> RegisterOperandNames;
114+
115+
// Some instructions have optional register operands that may be NoRegister.
116+
// We return a zero vector in such cases.
117+
mutable Embedding ZeroEmbedding;
118+
119+
// We have specialized MO_Register handling in the Register operand section,
120+
// so we don't include it here. Also, no MO_DbgInstrRef for now.
121+
static constexpr StringLiteral CommonOperandNames[] = {
122+
"Immediate", "CImmediate", "FPImmediate", "MBB",
123+
"FrameIndex", "ConstantPoolIndex", "TargetIndex", "JumpTableIndex",
124+
"ExternalSymbol", "GlobalAddress", "BlockAddress", "RegisterMask",
125+
"RegisterLiveOut", "Metadata", "MCSymbol", "CFIIndex",
126+
"IntrinsicID", "Predicate", "ShuffleMask"};
127+
static_assert(std::size(CommonOperandNames) == MachineOperand::MO_Last - 1 &&
128+
"Common operand names size changed, update accordingly");
129+
89130
const TargetInstrInfo &TII;
90-
void generateStorage(const VocabMap &OpcodeMap);
131+
const TargetRegisterInfo &TRI;
132+
const MachineRegisterInfo &MRI;
133+
134+
void generateStorage(const VocabMap &OpcodeMap,
135+
const VocabMap &CommonOperandMap,
136+
const VocabMap &PhyRegMap, const VocabMap &VirtRegMap);
91137
void buildCanonicalOpcodeMapping();
138+
void buildRegisterOperandMapping();
92139

93140
/// Get canonical index for a machine opcode
94141
unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
95142

143+
/// Get index for a common (non-register) machine operand
144+
unsigned
145+
getCommonOperandIndex(MachineOperand::MachineOperandType OperandType) const;
146+
147+
/// Get index for a register machine operand
148+
unsigned getRegisterOperandIndex(Register Reg) const;
149+
150+
// Accessors for operand types
151+
const Embedding &
152+
operator[](MachineOperand::MachineOperandType OperandType) const {
153+
unsigned LocalIndex = getCommonOperandIndex(OperandType);
154+
return Storage[static_cast<unsigned>(Section::CommonOperands)][LocalIndex];
155+
}
156+
157+
const Embedding &operator[](Register Reg) const {
158+
// Reg is sometimes NoRegister (0) for optional operands. We return a zero
159+
// vector in this case.
160+
if (!Reg.isValid())
161+
return ZeroEmbedding;
162+
// TODO: Implement proper stack slot handling for MIR2Vec embeddings.
163+
// Stack slots represent frame indices and should have their own
164+
// embedding strategy rather than defaulting to register class 0.
165+
// Consider: 1) Separate vocabulary section for stack slots
166+
// 2) Stack slot size/alignment based embeddings
167+
// 3) Frame index based categorization
168+
if (Reg.isStack())
169+
return ZeroEmbedding;
170+
171+
unsigned LocalIndex = getRegisterOperandIndex(Reg);
172+
auto SectionID =
173+
Reg.isPhysical() ? Section::PhyRegisters : Section::VirtRegisters;
174+
return Storage[static_cast<unsigned>(SectionID)][LocalIndex];
175+
}
176+
96177
public:
97178
/// Static method for extracting base opcode names (public for testing)
98179
static std::string extractBaseOpcodeName(StringRef InstrName);
99180

100-
/// Get canonical index for base name (public for testing)
181+
/// Get indices from opcode or operand names. These are public for testing.
182+
/// String based lookups are inefficient and should be avoided in general.
101183
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
184+
unsigned getCanonicalIndexForOperandName(StringRef OperandName) const;
185+
unsigned getCanonicalIndexForRegisterClass(StringRef RegName,
186+
bool IsPhysical = true) const;
102187

103188
/// Get the string key for a vocabulary entry at the given position
104189
std::string getStringKey(unsigned Pos) const;
@@ -111,6 +196,14 @@ class MIRVocabulary {
111196
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
112197
}
113198

199+
const Embedding &operator[](MachineOperand Operand) const {
200+
auto OperandType = Operand.getType();
201+
if (OperandType == MachineOperand::MO_Register)
202+
return operator[](Operand.getReg());
203+
else
204+
return operator[](OperandType);
205+
}
206+
114207
// Iterator access
115208
using const_iterator = ir2vec::VocabStorage::const_iterator;
116209
const_iterator begin() const { return Storage.begin(); }
@@ -120,18 +213,25 @@ class MIRVocabulary {
120213
MIRVocabulary() = delete;
121214

122215
/// Factory method to create MIRVocabulary from vocabulary map
123-
static Expected<MIRVocabulary> create(VocabMap &&Entries,
124-
const TargetInstrInfo &TII);
216+
static Expected<MIRVocabulary>
217+
create(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap, VocabMap &&PhyRegMap,
218+
VocabMap &&VirtRegMap, const TargetInstrInfo &TII,
219+
const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI);
125220

126221
/// Create a dummy vocabulary for testing purposes.
127222
static Expected<MIRVocabulary>
128-
createDummyVocabForTest(const TargetInstrInfo &TII, unsigned Dim = 1);
223+
createDummyVocabForTest(const TargetInstrInfo &TII,
224+
const TargetRegisterInfo &TRI,
225+
const MachineRegisterInfo &MRI, unsigned Dim = 1);
129226

130227
/// Total number of entries in the vocabulary
131228
size_t getCanonicalSize() const { return Storage.size(); }
132229

133230
private:
134-
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo &TII);
231+
MIRVocabulary(VocabMap &&OpcMap, VocabMap &&CommonOperandsMap,
232+
VocabMap &&PhyRegMap, VocabMap &&VirtRegMap,
233+
const TargetInstrInfo &TII, const TargetRegisterInfo &TRI,
234+
const MachineRegisterInfo &MRI);
135235
};
136236

137237
/// Base class for MIR embedders
@@ -144,11 +244,13 @@ class MIREmbedder {
144244
const unsigned Dimension;
145245

146246
/// Weight for opcode embeddings
147-
const float OpcWeight;
247+
const float OpcWeight, CommonOperandWeight, RegOperandWeight;
148248

149249
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
150250
: MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
151-
OpcWeight(mir2vec::OpcWeight) {}
251+
OpcWeight(mir2vec::OpcWeight),
252+
CommonOperandWeight(mir2vec::CommonOperandWeight),
253+
RegOperandWeight(mir2vec::RegOperandWeight) {}
152254

153255
/// Function to compute embeddings.
154256
Embedding computeEmbeddings() const;
@@ -208,11 +310,11 @@ class SymbolicMIREmbedder : public MIREmbedder {
208310
class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
209311
using VocabVector = std::vector<mir2vec::Embedding>;
210312
using VocabMap = std::map<std::string, mir2vec::Embedding>;
211-
VocabMap StrVocabMap;
212-
VocabVector Vocab;
313+
std::optional<mir2vec::MIRVocabulary> Vocab;
213314

214315
StringRef getPassName() const override;
215-
Error readVocabulary();
316+
Error readVocabulary(VocabMap &OpcVocab, VocabMap &CommonOperandVocab,
317+
VocabMap &PhyRegVocabMap, VocabMap &VirtRegVocabMap);
216318

217319
protected:
218320
void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -275,4 +377,4 @@ MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
275377

276378
} // namespace llvm
277379

278-
#endif // LLVM_CODEGEN_MIR2VEC_H
380+
#endif // LLVM_CODEGEN_MIR2VEC_H

0 commit comments

Comments
 (0)