16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/TypeSwitch.h"
22 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
23 #include "mlir/Dialect/Linalg/Passes.h.inc"
32 if (!stride || *stride != 1)
40 return (*size - *offset);
47 if (dims.size() != tiles.size() || tiles.empty())
50 FailureOr<ContractionDimensions> contractDims =
52 if (failed(contractDims))
54 unsigned batchDimsOffset = contractDims->batch.size();
59 for (
size_t i = 0; i < offsetDims.size(); i++)
60 offsetDims[i] += batchDimsOffset;
62 auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
68 if (dim.value() >=
static_cast<int64_t
>(iterationDomain.size()))
72 std::optional<int64_t> rangeOnDim =
77 if (!tileSize || !rangeOnDim)
81 if (*rangeOnDim % *tileSize != 0)
89 static FailureOr<PackTransposeResult>
91 linalg::PackOp packOp,
AffineMap operandMap,
93 bool transposeOuterBlocks,
bool transposeInnerBlocks) {
95 "expected at least 4D prepacked matmul");
96 assert(blocksStartDimPos.size() >= 2 &&
97 "expected starting outer and inner block positions");
107 bool isOuterTransposed =
108 operandMap.
getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
109 bool isInnerTransposed =
110 operandMap.
getDimPosition(innerBlockPos) != blocksStartDimPos.back();
115 if (isInnerTransposed != transposeInnerBlocks)
118 if (isOuterTransposed != transposeOuterBlocks)
124 for (
auto i : llvm::seq(0u, outerBlockPos))
125 offsetPerms.push_back(i);
126 for (
auto perm : outerPerm)
127 offsetPerms.push_back(perm + outerBlockPos);
128 outerPerm = offsetPerms;
130 FailureOr<PackTransposeResult> packTransposedMatmul =
132 nullptr, outerPerm, innerPerm);
134 return packTransposedMatmul;
138 FailureOr<PackResult>
143 if (
auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
144 if (batchMatmulOp->hasUserDefinedMaps()) {
147 "only batch_matmul ops with non-extended semantics are supported");
151 if (linalgOp.hasPureBufferSemantics())
154 std::optional<BlockPackMatmulOptions>
options = controlPackMatmul(linalgOp);
158 if (
options->blockFactors.size() != 3)
168 "expect packing full tiles only");
180 rewriter, linalgOp, mnkTiles,
options->mnkPaddedSizesNextMultipleOf,
182 if (failed(packedMatmul))
185 assert(packedMatmul->packOps.size() == 3 &&
186 "invalid number of pack ops after matmul packing");
187 assert(packedMatmul->unPackOps.size() == 1 &&
188 "invalid number of unpack ops after matmul packing");
190 FailureOr<ContractionDimensions> contractDims =
192 if (failed(contractDims))
196 dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
201 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
202 contractDims->m,
options->lhsTransposeOuterBlocks,
203 options->lhsTransposeInnerBlocks);
204 if (failed(packedLhs))
208 packedMatmul->packOps[0] = packedLhs->transposedPackOp;
209 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
213 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
214 contractDims->k,
options->rhsTransposeOuterBlocks,
215 options->rhsTransposeInnerBlocks);
216 if (failed(packedRhs))
220 packedMatmul->packOps[1] = packedRhs->transposedPackOp;
221 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
227 template <
typename OpTy>
233 LogicalResult matchAndRewrite(OpTy linalgOp,
235 FailureOr<PackResult> packedMatmul =
237 if (failed(packedMatmul))
247 struct BlockPackMatmul<linalg::GenericOp>
252 controlFn(std::move(fun)) {}
254 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
262 auto infer = [&](MapList m) {
267 bindDims(linalgOp->getContext(), i,
j, k);
271 if (!(maps == infer({{i, k}, {k,
j}, {i,
j}}) ||
272 maps == infer({{k, i}, {k,
j}, {i,
j}}) ||
273 maps == infer({{i, k}, {
j, k}, {i,
j}}))) {
277 FailureOr<PackResult> packedMatmul =
279 if (failed(packedMatmul))
289 struct LinalgBlockPackMatmul
290 :
public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
291 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
293 void runOnOperation()
override {
301 options.allowPadding = allowPadding;
302 options.mnkPaddedSizesNextMultipleOf =
304 if (!mnkOrder.empty())
306 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
307 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
308 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
309 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
315 return signalPassFailure();
322 patterns.add<BlockPackMatmul<linalg::GenericOp>,
323 BlockPackMatmul<linalg::MatmulOp>,
324 BlockPackMatmul<linalg::BatchMatmulOp>,
325 BlockPackMatmul<linalg::MatmulTransposeAOp>,
326 BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
327 BlockPackMatmul<linalg::MatmulTransposeBOp>,
328 BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return the constant range span, or nullopt otherwise.
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.