[AscendNPU-IR][A5] Support cmp and select ops. #908
Code Review
This pull request introduces support for comparison and selection operations on the A5 architecture within the NPUIR codegen. Key changes include the implementation of VselectCodegenA5, helper functions for mapping string modes to MLIR predicates, and updates to CreateHIVMBinaryVectorOp to handle arith::CmpFOp and arith::CmpIOp with appropriate type casting. Feedback highlights a potential MLIR verification failure due to mismatched element types during condition broadcasting and a logic error in the template dispatch for comparison operations. Additionally, improvements were suggested regarding the removal of unused variables and the implementation of stricter error handling for unsupported comparison modes.
if constexpr (std::is_same_v<T, mlir::arith::CmpFOp>) {
  cmpOp = builder.create<mlir::arith::CmpFOp>(loc, GetFPredicate(mode), src0, src1);
} else {
  cmpOp = builder.create<mlir::arith::CmpIOp>(loc, GetIPredicate(mode), src0, src1);
}
The current logic uses if constexpr on the template parameter T to choose between CmpFOp and CmpIOp. However, CreateHIVMBinaryVectorOp is called with both CmpFOp and CmpIOp as template arguments (line 3572), so std::is_same_v<T, mlir::arith::CmpFOp> is always true at that call site and a CmpFOp is created even when the operands are integers. Dispatch on the actual type of the operands (srcType) instead:
mlir::Value cmpOp;
if (srcType.isa<mlir::FloatType>()) {
cmpOp = builder.create<mlir::arith::CmpFOp>(loc, GetFPredicate(mode), src0, src1);
} else {
cmpOp = builder.create<mlir::arith::CmpIOp>(loc, GetIPredicate(mode), src0, src1);
}
880852a to 9ccfc70
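The difference between the two dispatch styles can be shown without MLIR. The sketch below is only an illustration: `FloatCmpTag`, `IntCmpTag`, `pickByTemplate`, and `pickByOperand` are hypothetical stand-ins, and it assumes T is the first of the two template arguments at the call site.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-ins for mlir::arith::CmpFOp / CmpIOp (not the real MLIR classes).
struct FloatCmpTag {};
struct IntCmpTag {};

// Mirrors the flagged pattern: the branch is fixed by T at compile time,
// so the operand type known only at run time is never consulted.
template <typename T, typename U>
std::string pickByTemplate(bool /*operandsAreFloat*/) {
  if constexpr (std::is_same_v<T, FloatCmpTag>) {
    return "cmpf";
  } else {
    return "cmpi";
  }
}

// Mirrors the suggested fix: the branch follows the operand type seen at run time.
std::string pickByOperand(bool operandsAreFloat) {
  return operandsAreFloat ? "cmpf" : "cmpi";
}
```

With these definitions, `pickByTemplate<FloatCmpTag, IntCmpTag>(false)` still selects the float path for integer operands, while `pickByOperand(false)` selects the integer path.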
mlir::Value dst_data_name = GetVarValue(npuirop.dst);

if (!dst_data_name.getType().isa<mlir::TensorType>()) {
  return;
Check that all three operands (src0, src1, dst) are of TensorType, not only dst.
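A minimal sketch of such a guard, written without MLIR so it stands alone. `TypeKind` and `allTensors` are hypothetical names; in the real codegen the check would be made on each value's MLIR type.

```cpp
#include <algorithm>
#include <array>

// Hypothetical stand-in for the category of an MLIR value's type.
enum class TypeKind { Tensor, MemRef, Scalar };

// Returns true only when src0, src1 and dst are all tensors,
// so the caller can bail out early if any one of them is not.
bool allTensors(TypeKind src0, TypeKind src1, TypeKind dst) {
  const std::array<TypeKind, 3> kinds = {src0, src1, dst};
  return std::all_of(kinds.begin(), kinds.end(),
                     [](TypeKind k) { return k == TypeKind::Tensor; });
}
```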
if (mode == "lt") return mlir::arith::CmpIPredicate::slt;
if (mode == "le") return mlir::arith::CmpIPredicate::sle;
if (mode == "gt") return mlir::arith::CmpIPredicate::sgt;
if (mode == "ge") return mlir::arith::CmpIPredicate::sge;
For unsigned integer types (e.g., uint8), ult/ule/ugt/uge should be used. The current implementation always uses signed comparisons, which produces incorrect results for unsigned values whose high bit is set.
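One possible shape for a signedness-aware mapping is sketched below. `IPred` is a hypothetical mirror of the `mlir::arith::CmpIPredicate` names, and `getIPredicate`, `signedLess`, and `unsignedLess` are illustrative helpers, not code from this PR. Returning `std::nullopt` for an unknown mode is one assumed way to surface unsupported modes instead of silently falling through.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical mirror of the mlir::arith::CmpIPredicate names.
enum class IPred { eq, ne, slt, sle, sgt, sge, ult, ule, ugt, uge };

// Maps a mode string to a predicate, choosing the unsigned variant when asked.
// Unknown modes yield std::nullopt so the caller must handle them explicitly.
std::optional<IPred> getIPredicate(const std::string& mode, bool isUnsigned) {
  if (mode == "eq") return IPred::eq;
  if (mode == "ne") return IPred::ne;
  if (mode == "lt") return isUnsigned ? IPred::ult : IPred::slt;
  if (mode == "le") return isUnsigned ? IPred::ule : IPred::sle;
  if (mode == "gt") return isUnsigned ? IPred::ugt : IPred::sgt;
  if (mode == "ge") return isUnsigned ? IPred::uge : IPred::sge;
  return std::nullopt;
}

// The same 8-bit patterns order differently under signed and unsigned views.
bool signedLess(std::uint8_t a, std::uint8_t b) {
  return static_cast<std::int8_t>(a) < static_cast<std::int8_t>(b);
}

bool unsignedLess(std::uint8_t a, std::uint8_t b) { return a < b; }
```

For example, the uint8 values 200 and 100 compare as 200 < 100 under the signed view (200 reinterprets as -56) but not under the unsigned view, which is the failure the comment describes.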
mlir::arith::CmpFPredicate GetFPredicate(std::string mode) {
  if (mode == "eq") return mlir::arith::CmpFPredicate::OEQ;
  if (mode == "ne") return mlir::arith::CmpFPredicate::UNE;
UNE means "unordered or not equal", i.e., it returns true when either operand is NaN. The NE mode of the original hivm::VCmpOp might instead mean "ordered and not equal" (ONE). For floating-point comparisons involving NaN, UNE and ONE behave differently.
Please confirm the NE semantics of VCmpOp. If it is an ordered comparison, use ONE; if unordered semantics are indeed intended, add a comment explaining why.
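The two predicates differ only when a NaN is involved. A small host-side sketch of their semantics follows; the function names are illustrative and are not part of MLIR or of this PR.

```cpp
#include <cmath>

// ONE, "ordered and not equal": false whenever either operand is NaN.
bool orderedNotEqual(float a, float b) {
  return !std::isnan(a) && !std::isnan(b) && a != b;
}

// UNE, "unordered or not equal": true whenever either operand is NaN.
bool unorderedNotEqual(float a, float b) {
  return std::isnan(a) || std::isnan(b) || a != b;
}
```

For ordinary values both functions agree. With a NaN operand, `unorderedNotEqual` returns true and `orderedNotEqual` returns false, so the choice is observable in any test that feeds NaN inputs.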
A_full = torch.randn((M, N), dtype=torch.float16).npu()
B_full = torch.randn((M, N), dtype=torch.float16).npu()
C_full = torch.empty((M, N), dtype=torch.float16).npu()
Since the core CmpOp change involves float vs. integer dispatch logic, consider adding test cases for float32 and int32/int8 types, especially end-to-end verification of integer comparisons.
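One way to obtain expected values for such cases is a host-side reference templated on the element type. The sketch below is an assumption about how a golden model could look, not the project's test harness: `selectLessRef` is a hypothetical helper that fixes the mode to "lt", takes select to mean `cond ? a : b`, and assumes both inputs have the same length.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side reference for "compare then select": out[i] = (a[i] < b[i]) ? a[i] : b[i].
// Assumes a and b have the same length.
template <typename T>
std::vector<T> selectLessRef(const std::vector<T>& a, const std::vector<T>& b) {
  std::vector<T> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = (a[i] < b[i]) ? a[i] : b[i];
  }
  return out;
}
```

Instantiating it with `float`, `std::int32_t`, and `std::int8_t` gives per-dtype expectations, including negative integer inputs that exercise the signed comparison path.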
Support compare and select ops.