16#include "llvm/ADT/SmallVector.h"
21#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
22#include "mlir/Dialect/Linalg/Passes.h.inc"
31 if (!stride || *stride != 1)
39 return (*size - *offset);
46 if (dims.size() != tiles.size() || tiles.empty())
49 FailureOr<ContractionDimensions> contractDims =
51 if (failed(contractDims))
53 unsigned batchDimsOffset = contractDims->batch.size();
58 for (
int64_t &offsetDim : offsetDims)
59 offsetDim += batchDimsOffset;
61 auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
66 for (
auto dim : llvm::enumerate(offsetDims)) {
67 if (dim.value() >=
static_cast<int64_t>(iterationDomain.size()))
71 std::optional<int64_t> rangeOnDim =
76 if (!tileSize || !rangeOnDim)
80 if (*rangeOnDim % *tileSize != 0)
88static FailureOr<PackTransposeResult>
90 linalg::PackOp packOp,
AffineMap operandMap,
92 bool transposeOuterBlocks,
bool transposeInnerBlocks) {
94 "expected at least 4D prepacked matmul");
95 assert(blocksStartDimPos.size() >= 2 &&
96 "expected starting outer and inner block positions");
106 bool isOuterTransposed =
107 operandMap.
getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
108 bool isInnerTransposed =
109 operandMap.
getDimPosition(innerBlockPos) != blocksStartDimPos.back();
114 if (isInnerTransposed != transposeInnerBlocks)
117 if (isOuterTransposed != transposeOuterBlocks)
123 for (
auto i : llvm::seq(0u, outerBlockPos))
124 offsetPerms.push_back(i);
125 for (
auto perm : outerPerm)
126 offsetPerms.push_back(perm + outerBlockPos);
127 outerPerm = offsetPerms;
129 FailureOr<PackTransposeResult> packTransposedMatmul =
131 nullptr, outerPerm, innerPerm);
133 return packTransposedMatmul;
142 if (
auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
143 if (batchMatmulOp->hasUserDefinedMaps()) {
146 "only batch_matmul ops with non-extended semantics are supported");
150 if (linalgOp.hasPureBufferSemantics())
153 std::optional<BlockPackMatmulOptions>
options = controlPackMatmul(linalgOp);
157 if (
options->blockFactors.size() != 3)
167 "expect packing full tiles only");
179 rewriter, linalgOp, mnkTiles,
options->mnkPaddedSizesNextMultipleOf,
181 if (failed(packedMatmul))
184 assert(packedMatmul->packOps.size() == 3 &&
185 "invalid number of pack ops after matmul packing");
186 assert(packedMatmul->unPackOps.size() == 1 &&
187 "invalid number of unpack ops after matmul packing");
189 FailureOr<ContractionDimensions> contractDims =
191 if (failed(contractDims))
195 dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
200 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
201 contractDims->m,
options->lhsTransposeOuterBlocks,
202 options->lhsTransposeInnerBlocks);
203 if (failed(packedLhs))
207 packedMatmul->packOps[0] = packedLhs->transposedPackOp;
208 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
212 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
213 contractDims->k,
options->rhsTransposeOuterBlocks,
214 options->rhsTransposeInnerBlocks);
215 if (failed(packedRhs))
219 packedMatmul->packOps[1] = packedRhs->transposedPackOp;
220 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
226template <
typename OpTy>
232 LogicalResult matchAndRewrite(OpTy linalgOp,
234 FailureOr<PackResult> packedMatmul =
236 if (failed(packedMatmul))
246struct BlockPackMatmul<
linalg::GenericOp>
249 PatternBenefit benefit = 1)
250 : OpRewritePattern<linalg::GenericOp>(context, benefit),
251 controlFn(std::move(fun)) {}
253 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
254 PatternRewriter &rewriter)
const override {
260 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
261 auto infer = [&](MapList m) {
266 bindDims(linalgOp->getContext(), i, j, k);
267 SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
270 if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
271 maps == infer({{k, i}, {k, j}, {i, j}}) ||
272 maps == infer({{i, k}, {j, k}, {i, j}}))) {
276 FailureOr<PackResult> packedMatmul =
288struct LinalgBlockPackMatmul
290 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
292 void runOnOperation()
override {
293 Operation *op = getOperation();
297 [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
298 BlockPackMatmulOptions
options;
299 options.blockFactors = SmallVector<int64_t>{*blockFactors};
300 options.allowPadding = allowPadding;
301 options.mnkPaddedSizesNextMultipleOf =
302 SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
303 if (!mnkOrder.empty())
304 options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
305 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
306 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
307 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
308 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
314 return signalPassFailure();
321 patterns.add<BlockPackMatmul<linalg::GenericOp>,
322 BlockPackMatmul<linalg::MatmulOp>,
323 BlockPackMatmul<linalg::BatchMatmulOp>>(
patterns.getContext(),
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return constant range span or nullopt, otherwise.
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...