16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/TypeSwitch.h"
22 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
23 #include "mlir/Dialect/Linalg/Passes.h.inc"
32 if (!stride || *stride != 1)
40 return (*size - *offset);
47 if (dims.size() != tiles.size() || tiles.empty())
50 FailureOr<ContractionDimensions> contractDims =
52 if (failed(contractDims))
54 unsigned batchDimsOffset = contractDims->batch.size();
59 for (
size_t i = 0; i < offsetDims.size(); i++)
60 offsetDims[i] += batchDimsOffset;
62 auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
68 if (dim.value() >=
static_cast<int64_t
>(iterationDomain.size()))
72 std::optional<int64_t> rangeOnDim =
77 if (!tileSize || !rangeOnDim)
81 if (*rangeOnDim % *tileSize != 0)
89 static FailureOr<PackTransposeResult>
91 tensor::PackOp packOp,
AffineMap operandMap,
93 bool transposeOuterBlocks,
bool transposeInnerBlocks) {
95 "expected at least 4D prepacked matmul");
96 assert(blocksStartDimPos.size() >= 2 &&
97 "expected starting outer and inner block positions");
107 bool isOuterTransposed =
108 operandMap.
getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
109 bool isInnerTransposed =
110 operandMap.
getDimPosition(innerBlockPos) != blocksStartDimPos.back();
115 if (isInnerTransposed != transposeInnerBlocks)
118 if (isOuterTransposed != transposeOuterBlocks)
124 for (
auto i : llvm::seq(0u, outerBlockPos))
125 offsetPerms.push_back(i);
126 for (
auto perm : outerPerm)
127 offsetPerms.push_back(perm + outerBlockPos);
128 outerPerm = offsetPerms;
130 FailureOr<PackTransposeResult> packTransposedMatmul =
132 nullptr, outerPerm, innerPerm);
134 return packTransposedMatmul;
138 FailureOr<PackResult>
141 if (linalgOp.hasPureBufferSemantics())
144 std::optional<BlockPackMatmulOptions>
options = controlPackMatmul(linalgOp);
148 if (
options->blockFactors.size() != 3)
158 "expect packing full tiles only");
170 rewriter, linalgOp, mnkTiles,
options->mnkPaddedSizesNextMultipleOf,
172 if (failed(packedMatmul))
175 assert(packedMatmul->packOps.size() == 3 &&
176 "invalid number of pack ops after matmul packing");
177 assert(packedMatmul->unPackOps.size() == 1 &&
178 "invalid number of unpack ops after matmul packing");
180 FailureOr<ContractionDimensions> contractDims =
182 if (failed(contractDims))
186 dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
191 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
192 contractDims->m,
options->lhsTransposeOuterBlocks,
193 options->lhsTransposeInnerBlocks);
194 if (failed(packedLhs))
198 packedMatmul->packOps[0] = packedLhs->transposedPackOp;
199 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
203 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
204 contractDims->k,
options->rhsTransposeOuterBlocks,
205 options->rhsTransposeInnerBlocks);
206 if (failed(packedRhs))
210 packedMatmul->packOps[1] = packedRhs->transposedPackOp;
211 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
217 template <
typename OpTy>
223 LogicalResult matchAndRewrite(OpTy linalgOp,
225 FailureOr<PackResult> packedMatmul =
227 if (failed(packedMatmul))
237 struct BlockPackMatmul<linalg::GenericOp>
242 controlFn(std::move(fun)) {}
244 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
252 auto infer = [&](MapList m) {
257 bindDims(linalgOp->getContext(), i,
j, k);
261 if (!(maps == infer({{i, k}, {k,
j}, {i,
j}}) ||
262 maps == infer({{k, i}, {k,
j}, {i,
j}}) ||
263 maps == infer({{i, k}, {
j, k}, {i,
j}}))) {
267 FailureOr<PackResult> packedMatmul =
269 if (failed(packedMatmul))
279 struct LinalgBlockPackMatmul
280 :
public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
281 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
283 void runOnOperation()
override {
291 options.allowPadding = allowPadding;
292 options.mnkPaddedSizesNextMultipleOf =
294 if (!mnkOrder.empty())
296 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
297 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
298 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
299 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
305 return signalPassFailure();
312 patterns.add<BlockPackMatmul<linalg::GenericOp>,
313 BlockPackMatmul<linalg::MatmulOp>,
314 BlockPackMatmul<linalg::BatchMatmulOp>,
315 BlockPackMatmul<linalg::MatmulTransposeAOp>,
316 BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
317 BlockPackMatmul<linalg::MatmulTransposeBOp>,
318 BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, tensor::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return the constant range span, or nullopt otherwise.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.