16 #include "llvm/ADT/SmallVector.h"
21 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
22 #include "mlir/Dialect/Linalg/Passes.h.inc"
31 if (!stride || *stride != 1)
39 return (*size - *offset);
46 if (dims.size() != tiles.size() || tiles.empty())
49 FailureOr<ContractionDimensions> contractDims =
53 unsigned batchDimsOffset = contractDims->batch.size();
58 for (int64_t &offsetDim : offsetDims)
59 offsetDim += batchDimsOffset;
61 auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
67 if (dim.value() >=
static_cast<int64_t
>(iterationDomain.size()))
71 std::optional<int64_t> rangeOnDim =
76 if (!tileSize || !rangeOnDim)
80 if (*rangeOnDim % *tileSize != 0)
88 static FailureOr<PackTransposeResult>
90 linalg::PackOp packOp,
AffineMap operandMap,
92 bool transposeOuterBlocks,
bool transposeInnerBlocks) {
94 "expected at least 4D prepacked matmul");
95 assert(blocksStartDimPos.size() >= 2 &&
96 "expected starting outer and inner block positions");
106 bool isOuterTransposed =
107 operandMap.
getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
108 bool isInnerTransposed =
109 operandMap.
getDimPosition(innerBlockPos) != blocksStartDimPos.back();
114 if (isInnerTransposed != transposeInnerBlocks)
117 if (isOuterTransposed != transposeOuterBlocks)
123 for (
auto i : llvm::seq(0u, outerBlockPos))
124 offsetPerms.push_back(i);
125 for (
auto perm : outerPerm)
126 offsetPerms.push_back(perm + outerBlockPos);
127 outerPerm = offsetPerms;
129 FailureOr<PackTransposeResult> packTransposedMatmul =
131 nullptr, outerPerm, innerPerm);
133 return packTransposedMatmul;
137 FailureOr<PackResult>
142 if (
auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
143 if (batchMatmulOp->hasUserDefinedMaps()) {
146 "only batch_matmul ops with non-extended semantics are supported");
150 if (linalgOp.hasPureBufferSemantics())
153 std::optional<BlockPackMatmulOptions>
options = controlPackMatmul(linalgOp);
157 if (
options->blockFactors.size() != 3)
167 "expect packing full tiles only");
179 rewriter, linalgOp, mnkTiles,
options->mnkPaddedSizesNextMultipleOf,
184 assert(packedMatmul->packOps.size() == 3 &&
185 "invalid number of pack ops after matmul packing");
186 assert(packedMatmul->unPackOps.size() == 1 &&
187 "invalid number of unpack ops after matmul packing");
189 FailureOr<ContractionDimensions> contractDims =
195 dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
200 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
201 contractDims->m,
options->lhsTransposeOuterBlocks,
202 options->lhsTransposeInnerBlocks);
207 packedMatmul->packOps[0] = packedLhs->transposedPackOp;
208 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
212 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
213 contractDims->k,
options->rhsTransposeOuterBlocks,
214 options->rhsTransposeInnerBlocks);
219 packedMatmul->packOps[1] = packedRhs->transposedPackOp;
220 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
226 template <
typename OpTy>
232 LogicalResult matchAndRewrite(OpTy linalgOp,
234 FailureOr<PackResult> packedMatmul =
246 struct BlockPackMatmul<linalg::GenericOp>
251 controlFn(std::move(fun)) {}
253 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
261 auto infer = [&](MapList m) {
266 bindDims(linalgOp->getContext(), i,
j, k);
270 if (!(maps == infer({{i, k}, {k,
j}, {i,
j}}) ||
271 maps == infer({{k, i}, {k,
j}, {i,
j}}) ||
272 maps == infer({{i, k}, {
j, k}, {i,
j}}))) {
276 FailureOr<PackResult> packedMatmul =
288 struct LinalgBlockPackMatmul
289 :
public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
290 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
292 void runOnOperation()
override {
300 options.allowPadding = allowPadding;
301 options.mnkPaddedSizesNextMultipleOf =
303 if (!mnkOrder.empty())
305 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
306 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
307 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
308 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
314 return signalPassFailure();
321 patterns.add<BlockPackMatmul<linalg::GenericOp>,
322 BlockPackMatmul<linalg::MatmulOp>,
323 BlockPackMatmul<linalg::BatchMatmulOp>>(
patterns.getContext(),
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return constant range span or nullopt, otherwise.
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.