Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/PTO/Transforms/InsertSync/PTOIRTranslator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ class PTOIRTranslator {
syncIR_(syncIR),
buffer2MemInfoMap_(buffer2MemInfoMap),
memAnalyzer_(memDepAnalyzer),
mode_(syncAnalysisMode) { };
mode_(syncAnalysisMode) {
(void)memAnalyzer_;
(void)mode_;
};

// 核心入口:执行 IR 分析和转换
void Build();
Expand Down
3 changes: 1 addition & 2 deletions lib/PTO/IR/PTOSyncUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ PIPE mlir::pto::mapSyncOpTypeToPipe(SyncOpType opType) {
case SyncOpType::TVEC:
case SyncOpType::TVECWAIT_EVENT:
return PIPE::PIPE_V;
default:
return PIPE::PIPE_UNASSIGNED;
}
return PIPE::PIPE_UNASSIGNED;
}

bool mlir::pto::isConcreteSyncPipe(PIPE pipe) {
Expand Down
11 changes: 6 additions & 5 deletions lib/PTO/IR/PTOTypeDefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,15 +530,18 @@ Type TileBufType::parse(AsmParser &parser) {
static llvm::StringRef stringifyLocFromMemorySpace(mlir::Attribute memorySpace) {
auto asAttr = llvm::dyn_cast_or_null<AddressSpaceAttr>(memorySpace);
switch (asAttr.getAddressSpace()) {
case AddressSpace::Zero:
case AddressSpace::GM:
return "illegal";
case AddressSpace::MAT: return "mat";
case AddressSpace::LEFT: return "left";
case AddressSpace::RIGHT: return "right";
case AddressSpace::ACC: return "acc";
case AddressSpace::VEC: return "vec";
case AddressSpace::BIAS: return "bias";
case AddressSpace::SCALING: return "scaling";
default: return "illegal";
}
return "illegal";
}

static llvm::StringRef stringifyLocFromPad(mlir::Attribute pad) {
Expand All @@ -550,9 +553,8 @@ static llvm::StringRef stringifyLocFromPad(mlir::Attribute pad) {
case PadValue::Zero: return "1";
case PadValue::Max: return "2";
case PadValue::Min: return "3";
default:
return "9999";
}
return "9999";
}

static llvm::StringRef stringifyCompactModeInt(mlir::Attribute compactMode) {
Expand All @@ -567,9 +569,8 @@ static llvm::StringRef stringifyCompactModeInt(mlir::Attribute compactMode) {
return "1";
case CompactMode::RowPlusOne:
return "2";
default:
return "9999";
}
return "9999";
}

static void printTileBufDim(AsmPrinter &printer, int64_t dim) {
Expand Down
26 changes: 14 additions & 12 deletions lib/PTO/Transforms/PTOToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ static const char *addrSpaceQualifier(pto::AddressSpace as) {
return "__gm__";
}

static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName =
[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeAttrName =
"__pto.lowered_set_validshape";
static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName =
[[maybe_unused]] static constexpr llvm::StringLiteral kLoweredSetValidShapeConfigAttrName =
"__pto.lowered_set_validshape_config";
static constexpr llvm::StringLiteral kForceDynamicValidShapeAttrName =
"__pto.force_dynamic_valid_shape";
Expand Down Expand Up @@ -479,8 +479,8 @@ class PTOToEmitCTypeConverter : public TypeConverter {
// ---------------------------------------------------------
// 2. PTO 特殊类型 (透传或转换)
// ---------------------------------------------------------
addConversion([Ctx](emitc::OpaqueType type) { return type; });
addConversion([Ctx](emitc::PointerType type) { return type; });
addConversion([](emitc::OpaqueType type) { return type; });
addConversion([](emitc::PointerType type) { return type; });

// ---------------------------------------------------------
// 2.5 PtrType 转换 (指针类型)
Expand Down Expand Up @@ -3124,7 +3124,8 @@ struct SubviewToEmitCPattern : public OpConversionPattern<memref::SubViewOp> {
}
}

auto typedPtrTy = emitc::OpaqueType::get(ctx, qualifier + " " + castElemTypeStr + "*");
auto typedPtrTy = emitc::PointerType::get(
emitc::OpaqueType::get(ctx, qualifier + " " + castElemTypeStr));
Value typedSourcePtr = rewriter.create<emitc::CastOp>(loc, typedPtrTy, sourcePtr);
newPtr = rewriter.create<emitc::AddOp>(loc, typedPtrTy, typedSourcePtr, totalOffset);
} else {
Expand Down Expand Up @@ -3659,7 +3660,8 @@ static Value maybeWrapGlobalMemrefAsGlobalTensor(
static Value castToGMBytePointer(ConversionPatternRewriter &rewriter,
Location loc, Value value) {
auto *ctx = rewriter.getContext();
auto targetTy = emitc::OpaqueType::get(ctx, "__gm__ uint8_t*");
auto targetTy =
emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ uint8_t"));
if (value.getType() == targetTy)
return value;

Expand Down Expand Up @@ -4379,9 +4381,9 @@ static ArrayAttr buildAccPhaseTemplateArgs(ConversionPatternRewriter &rewriter,
case pto::AccPhase::Final:
tmpl = "AccPhase::Final";
break;
default:
llvm_unreachable("unknown AccPhase");
}
if (tmpl.empty())
return ArrayAttr{};
return rewriter.getArrayAttr(
{emitc::OpaqueAttr::get(rewriter.getContext(), tmpl)});
}
Expand Down Expand Up @@ -4806,7 +4808,7 @@ static LogicalResult extractSyncTripletTokens(Operation *op,
static inline std::string pipeTokFromPipeEnum(mlir::pto::PIPE p) {
return mlir::pto::stringifyPIPE(p).str();
}
static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) {
[[maybe_unused]] static inline std::string evtTokFromEventEnum(mlir::pto::EVENT e) {
return mlir::pto::stringifyEVENT(e).str();
}
static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a) {
Expand Down Expand Up @@ -5719,7 +5721,8 @@ struct PTOInitializeL2LPipeToEmitC
auto emitPipeTy =
cast<Type>(getTypeConverter()->convertType(op.getPipe().getType()));

auto gmPtrTy = emitc::OpaqueType::get(ctx, "__gm__ void *");
auto gmPtrTy =
emitc::PointerType::get(emitc::OpaqueType::get(ctx, "__gm__ void"));
Value nullGm =
makeEmitCOpaqueConstant(rewriter, op.getLoc(), gmPtrTy, "nullptr");
auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t");
Expand Down Expand Up @@ -11077,7 +11080,7 @@ class ArithCmpIToEmitC : public OpConversionPattern<arith::CmpIOp> {

// 将 arith.cmpi 转换为 emitc.cmp
// 映射 Predicate: eq -> equal, slt -> less, etc.
emitc::CmpPredicate emitcPred;
emitc::CmpPredicate emitcPred = emitc::CmpPredicate::eq;
const bool isUnsignedPred =
op.getPredicate() == arith::CmpIPredicate::ult ||
op.getPredicate() == arith::CmpIPredicate::ule ||
Expand All @@ -11095,7 +11098,6 @@ class ArithCmpIToEmitC : public OpConversionPattern<arith::CmpIOp> {
case arith::CmpIPredicate::ule: emitcPred = emitc::CmpPredicate::le; break;
case arith::CmpIPredicate::ugt: emitcPred = emitc::CmpPredicate::gt; break;
case arith::CmpIPredicate::uge: emitcPred = emitc::CmpPredicate::ge; break;
default: return failure();
}

Type resTy = getTypeConverter()->convertType(op.getType());
Expand Down
Loading