[LAYOUTS] Use least squares solution in invertAndCompose (triton-lang…
…#5309)

In this PR, we remove the need for a few hacks in `invertAndCompose`,
namely the need for `getInjectiveMat`, which did not work in cases where
the input and the output had a different number of registers (the same
holds for a different number of blocks) and which led to further hacks on
top of it, like the gymnastics with `getFreeVariable`.

We now just compute `invertAndCompose` as the matrix `X` that solves the
system `AX = B`. We add enough asserts to check that this system has at
least one solution (i.e. `A` is surjective), and we make explicit the
heuristic we use to minimise data movement: we do not consider dimensions
that are the same, and otherwise we incentivise broadcasting by choosing
the solution of minimal norm. For an explanation of how to solve this
system, see
https://github.com/triton-lang/triton/pull/5309/files/a9069c73637a6b4735cdc39d1c7f338cfdd17a8f#r1869084111
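
To make the idea of solving `AX = B` concrete, here is a minimal sketch of Gaussian elimination over F_2 (entries are 0/1, addition is XOR). It is not the code from this PR: `F2Matrix`, `solveF2` and `mulF2` are hypothetical names introduced for illustration, the dense `std::vector` representation is an assumption, and setting every free variable to 0 is just one simple way to pick a solution, not necessarily the selection rule used in the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Dense 0/1 matrix over F_2, row-major. Hypothetical helper type for this
// sketch only; it is not Triton's representation.
using F2Matrix = std::vector<std::vector<int>>;

// Solve A * X = B over F_2, with A of shape m x n and B of shape m x k.
// Returns an n x k matrix X. Free variables are set to 0 (an assumption of
// this sketch). Asserts that the system has at least one solution.
F2Matrix solveF2(F2Matrix A, F2Matrix B) {
  const std::size_t m = A.size();
  assert(m > 0 && B.size() == m);
  const std::size_t n = A[0].size();
  const std::size_t k = B[0].size();

  std::vector<std::size_t> pivotCol; // pivotCol[r] = pivot column of row r
  std::size_t row = 0;
  for (std::size_t col = 0; col < n && row < m; ++col) {
    // Find a row at or below `row` with a 1 in this column.
    std::size_t sel = row;
    while (sel < m && A[sel][col] == 0)
      ++sel;
    if (sel == m)
      continue; // no pivot here: `col` is a free variable
    std::swap(A[sel], A[row]);
    std::swap(B[sel], B[row]);
    // Clear this column in every other row (XOR is addition in F_2).
    for (std::size_t r = 0; r < m; ++r) {
      if (r != row && A[r][col] == 1) {
        for (std::size_t c = 0; c < n; ++c)
          A[r][c] ^= A[row][c];
        for (std::size_t c = 0; c < k; ++c)
          B[r][c] ^= B[row][c];
      }
    }
    pivotCol.push_back(col);
    ++row;
  }

  // Rows past the rank are all-zero in A, so B must be zero there too.
  for (std::size_t r = row; r < m; ++r)
    for (std::size_t c = 0; c < k; ++c)
      assert(B[r][c] == 0 && "AX = B has no solution");

  // Pivot variables take the reduced right-hand side; free ones stay 0.
  F2Matrix X(n, std::vector<int>(k, 0));
  for (std::size_t r = 0; r < pivotCol.size(); ++r)
    X[pivotCol[r]] = B[r];
  return X;
}

// Matrix product over F_2, used only to check a solution.
F2Matrix mulF2(const F2Matrix &A, const F2Matrix &X) {
  F2Matrix C(A.size(), std::vector<int>(X[0].size(), 0));
  for (std::size_t i = 0; i < A.size(); ++i)
    for (std::size_t l = 0; l < X.size(); ++l)
      for (std::size_t j = 0; j < X[0].size(); ++j)
        C[i][j] ^= A[i][l] & X[l][j];
  return C;
}
```

When `A` is surjective (full row rank), every row receives a pivot, so the consistency assert can never fire; this is the sense in which surjectivity guarantees at least one solution. The result can be checked by confirming that `mulF2(A, solveF2(A, B))` equals `B`.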

In the future, it would be better for this function to return the compact
form of the system, in which a dimension is absent exactly when the
conversion is uniform over that dimension, but for that we need to adapt
our lowering algorithms.
lezcano authored Dec 5, 2024
1 parent 390e27f commit 67ea999
Showing 6 changed files with 168 additions and 250 deletions.
7 changes: 0 additions & 7 deletions include/triton/Tools/LinearLayout.h
@@ -681,13 +681,6 @@ class LinearLayout {
// (i.e. every input bit affects the output).
llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;

- // Increase an input dimension without affecting the output dimension. The
- // added free variables are mapped to 0, ensuring that the new input
- // dimensions correspond directly to the existing output space. The function
- // errors out if `newInDimSize` is less than the current size or the new size
- // is not a power of 2.
- LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;
-
std::string toString() const;

friend bool operator==(LinearLayout lhs, LinearLayout rhs);
38 changes: 2 additions & 36 deletions lib/Analysis/Utility.cpp
@@ -663,42 +663,8 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
StringAttr kLane = StringAttr::get(ctx, "lane");
StringAttr kWarp = StringAttr::get(ctx, "warp");
StringAttr kBlock = StringAttr::get(ctx, "block");
- auto numSrcRegs = srcLayout->getInDimSize(kRegister);
- auto numDstRegs = dstLayout->getInDimSize(kRegister);
- // The `invertAndCompose` function will generate a layout that is injective
- // by assigning new output dimensions to free variables. For instance,
- // consider a scenario where `srcLayout` has a free variable in the lane
- // dimension, while `dstLayout` has two free variables in the lane
- // dimension and also a larger number of registers.
- // The injective form of `srcLayout` will add only a single additional row
- // to the transformation matrix, whereas the injective form of `dstLayout`
- // will add two additional rows. This discrepancy causes misleading results
- // because the matrices end up with a different number of rows.
- //
- // Take `dstLayout ⋅ srcLayout^-1` as an example:
- //
- // - `injective(dstLayout)`: [n, m] → [n + 2, m]
- // - `injective(srcLayout)`: [n, m] → [n + 1, m]
- // - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
- // - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
- // 1] → [n + 2, n + 1]
- //
- // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
- // variable in registers, and the `(n + 2)`-th row represents the free
- // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
- // represents the free variable in lanes. As a result, the `(n + 1)`-th row
- // in two layouts do not correspond to the same free variable.
- //
- // To address this issue, we pad the free variables in `srcLayout` and
- // `dstLayout` to ensure they have the same number of registers. This
- // guarantees that the resulting matrices have the same number of rows,
- // ensuring consistency in the composition process.
- auto numRegs = std::max(numSrcRegs, numDstRegs);
- auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
- auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
- // comp describes the layout function to create dst from src.
- LinearLayout comp =
- dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
+
+ auto comp = dstLayout->invertAndCompose(*srcLayout);
// We try to quotient by the largest subspace first
auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
for (auto dim : dims) {
25 changes: 7 additions & 18 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -315,14 +315,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
// TODO(Keren): implement warp shuffle instead of using the general
// approach that uses shared memory
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
- } else if (llvm::is_contained(dims, kRegister) ||
- dstLayout.getInDimSize(kRegister) !=
- srcLayout.getInDimSize(kRegister)) {
+ } else if (llvm::is_contained(dims, kRegister)) {
// Case 4. Transfer between values in the same thread, in which case we
// simply reorder the elements of adaptor.getSrc().
- return transferWithinThread(
- op, dstLayout.getFreeVariableMasks()[kRegister],
- dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
+ return transferWithinThread(op, *conversion, adaptor, rewriter);
} else {
// Cast 5. The two layouts are equivalent. We should probably remove
// these in RemoveLayoutConversion.
@@ -332,8 +328,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}

LogicalResult
- transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
- const LinearLayout &conversion, OpAdaptor adaptor,
+ transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
+ OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
auto loc = op.getLoc();
@@ -343,16 +339,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
- SmallVector<Value> outVals(numRegs);
- for (int i = 0; i < numRegs; i++) {
- // Remove free masks from the register index
- // For example, if idx = 0b00111, and masks = 0b00100, then we get
- // 0b00011. It means that register 7 (0b111) has the same value as
- // register 3 (0b011).
- auto idx = i & (~regMasks);
- auto srcIdx = conversion.hasInDim(kRegister)
- ? conversion.apply({{kRegister, idx}}).begin()->second
- : idx;
+ SmallVector<Value> outVals(conversion.getInDimSize(kRegister));
+ for (int i = 0; i < outVals.size(); i++) {
+ auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
outVals[i] = inVals[srcIdx];
}
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
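
The simplified loop in `transferWithinThread` is a gather: each destination register reads the source register that the conversion maps it to. The sketch below illustrates that pattern with plain integers standing in for `LinearLayout` and `Value`. The names `applyLinear` and `reorderRegisters` are hypothetical, and the encoding of the map as one basis image per input bit is an assumption made for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Apply an F_2-linear map given by its basis images: if bit b of `idx` is
// set, bases[b] is XORed into the result. Hypothetical stand-in for applying
// a linear layout to a register index.
std::uint32_t applyLinear(const std::vector<std::uint32_t> &bases,
                          std::uint32_t idx) {
  std::uint32_t out = 0;
  for (std::size_t b = 0; b < bases.size(); ++b)
    if ((idx >> b) & 1u)
      out ^= bases[b];
  return out;
}

// Gather: outVals[i] = inVals[conversion(i)]. The number of destination
// registers is 2^(number of input bits) and may differ from inVals.size().
std::vector<int> reorderRegisters(const std::vector<std::uint32_t> &bases,
                                  const std::vector<int> &inVals) {
  std::vector<int> outVals(std::size_t{1} << bases.size());
  for (std::size_t i = 0; i < outVals.size(); ++i)
    outVals[i] = inVals.at(applyLinear(bases, static_cast<std::uint32_t>(i)));
  return outVals;
}
```

With bases `{1, 0}` the second input bit maps to zero, so four destination registers read from only two source registers and each source value is broadcast twice. In this picture no separate free-variable mask is needed to handle differing register counts.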