16#include "llvm/ADT/SmallVector.h"
21#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
22#include "mlir/Dialect/Linalg/Passes.h.inc"
31 if (!stride || *stride != 1)
39 return (*size - *offset);
46 if (dims.size() != tiles.size() || tiles.empty())
49 FailureOr<ContractionDimensions> contractDims =
51 if (failed(contractDims))
53 unsigned batchDimsOffset = contractDims->batch.size();
58 for (
int64_t &offsetDim : offsetDims)
59 offsetDim += batchDimsOffset;
61 auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
66 for (
auto dim : llvm::enumerate(offsetDims)) {
67 if (dim.value() >=
static_cast<int64_t>(iterationDomain.size()))
71 std::optional<int64_t> rangeOnDim =
76 if (!tileSize || !rangeOnDim)
80 if (*rangeOnDim % *tileSize != 0)
88static FailureOr<PackTransposeResult>
90 linalg::PackOp packOp,
AffineMap operandMap,
92 bool transposeOuterBlocks,
bool transposeInnerBlocks) {
94 if (!packOp.hasPureTensorSemantics())
98 "expected at least 4D prepacked matmul");
99 assert(blocksStartDimPos.size() >= 2 &&
100 "expected starting outer and inner block positions");
110 bool isOuterTransposed =
111 operandMap.
getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
112 bool isInnerTransposed =
113 operandMap.
getDimPosition(innerBlockPos) != blocksStartDimPos.back();
118 if (isInnerTransposed != transposeInnerBlocks)
121 if (isOuterTransposed != transposeOuterBlocks)
127 for (
auto i : llvm::seq(0u, outerBlockPos))
128 offsetPerms.push_back(i);
129 for (
auto perm : outerPerm)
130 offsetPerms.push_back(perm + outerBlockPos);
131 outerPerm = offsetPerms;
133 FailureOr<PackTransposeResult> packTransposedMatmul =
135 nullptr, outerPerm, innerPerm);
137 return packTransposedMatmul;
146 if (
auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
147 if (batchMatmulOp->hasUserDefinedMaps()) {
150 "only batch_matmul ops with non-extended semantics are supported");
154 if (linalgOp.hasPureBufferSemantics())
157 std::optional<BlockPackMatmulOptions>
options = controlPackMatmul(linalgOp);
161 if (
options->blockFactors.size() != 3)
171 "expect packing full tiles only");
183 rewriter, linalgOp, mnkTiles,
options->mnkPaddedSizesNextMultipleOf,
185 if (failed(packedMatmul))
188 assert(packedMatmul->packOps.size() == 3 &&
189 "invalid number of pack ops after matmul packing");
190 assert(packedMatmul->unPackOps.size() == 1 &&
191 "invalid number of unpack ops after matmul packing");
193 FailureOr<ContractionDimensions> contractDims =
195 if (failed(contractDims))
199 dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
204 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
205 contractDims->m,
options->lhsTransposeOuterBlocks,
206 options->lhsTransposeInnerBlocks);
207 if (failed(packedLhs))
211 packedMatmul->packOps[0] = packedLhs->transposedPackOp;
212 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
216 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
217 contractDims->k,
options->rhsTransposeOuterBlocks,
218 options->rhsTransposeInnerBlocks);
219 if (failed(packedRhs))
223 packedMatmul->packOps[1] = packedRhs->transposedPackOp;
224 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
230template <
typename OpTy>
236 LogicalResult matchAndRewrite(OpTy linalgOp,
238 FailureOr<PackResult> packedMatmul =
240 if (failed(packedMatmul))
250struct BlockPackMatmul<
linalg::GenericOp>
253 PatternBenefit benefit = 1)
254 : OpRewritePattern<linalg::GenericOp>(context, benefit),
255 controlFn(std::move(fun)) {}
257 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
258 PatternRewriter &rewriter)
const override {
264 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
265 auto infer = [&](MapList m) {
270 bindDims(linalgOp->getContext(), i, j, k);
271 SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
274 if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
275 maps == infer({{k, i}, {k, j}, {i, j}}) ||
276 maps == infer({{i, k}, {j, k}, {i, j}}))) {
280 FailureOr<PackResult> packedMatmul =
292struct LinalgBlockPackMatmul
293 :
public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
294 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
296 void runOnOperation()
override {
297 Operation *op = getOperation();
301 [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
302 BlockPackMatmulOptions
options;
303 options.blockFactors = SmallVector<int64_t>{*blockFactors};
304 options.allowPadding = allowPadding;
305 options.mnkPaddedSizesNextMultipleOf =
306 SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
307 if (!mnkOrder.empty())
308 options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
309 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
310 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
311 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
312 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
318 return signalPassFailure();
325 patterns.add<BlockPackMatmul<linalg::GenericOp>,
326 BlockPackMatmul<linalg::MatmulOp>,
327 BlockPackMatmul<linalg::BatchMatmulOp>>(
patterns.getContext(),
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return constant range span or nullopt, otherwise.
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...