MLIR 22.0.0git
BlockPackMatmul.cpp
Go to the documentation of this file.
1//===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "llvm/ADT/SmallVector.h"
17
18#include <optional>
19
20namespace mlir {
21#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
22#include "mlir/Dialect/Linalg/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::linalg;
27
28/// Return constant range span or nullopt, otherwise.
29static std::optional<int64_t> getConstantRange(const Range &range) {
30 std::optional<int64_t> stride = getConstantIntValue(range.stride);
31 if (!stride || *stride != 1)
32 return std::nullopt;
33 std::optional<int64_t> offset = getConstantIntValue(range.offset);
34 if (!offset)
35 return std::nullopt;
36 std::optional<int64_t> size = getConstantIntValue(range.size);
37 if (!size)
38 return std::nullopt;
39 return (*size - *offset);
40}
41
42/// Return true if all dimensions are fully divisible by the respective tiles.
43static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
45 ArrayRef<int64_t> dims) {
46 if (dims.size() != tiles.size() || tiles.empty())
47 return false;
48
49 FailureOr<ContractionDimensions> contractDims =
50 inferContractionDims(linalgOp);
51 if (failed(contractDims))
52 return false;
53 unsigned batchDimsOffset = contractDims->batch.size();
54
55 // Skip the batch dimension if present.
56 // Offset all dimensions accordingly.
57 SmallVector<int64_t, 3> offsetDims(dims);
58 for (int64_t &offsetDim : offsetDims)
59 offsetDim += batchDimsOffset;
60
61 auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
62 OpBuilder builder(tileOp);
63 OpBuilder::InsertionGuard guard(builder);
64 SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
65
66 for (auto dim : llvm::enumerate(offsetDims)) {
67 if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
68 return false;
69
70 std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
71 std::optional<int64_t> rangeOnDim =
72 getConstantRange(iterationDomain[dim.value()]);
73
74 // If the tile factor or the range are non-constant, the tile size is
75 // considered to be invalid.
76 if (!tileSize || !rangeOnDim)
77 return false;
78
79 // The dimension must be fully divisible by the tile.
80 if (*rangeOnDim % *tileSize != 0)
81 return false;
82 }
83
84 return true;
85}
86
87/// Return failure or packed matmul with one of its operands transposed.
88static FailureOr<PackTransposeResult>
89transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
90 linalg::PackOp packOp, AffineMap operandMap,
91 ArrayRef<unsigned> blocksStartDimPos,
92 bool transposeOuterBlocks, bool transposeInnerBlocks) {
93 assert(operandMap.getNumDims() >= 4 &&
94 "expected at least 4D prepacked matmul");
95 assert(blocksStartDimPos.size() >= 2 &&
96 "expected starting outer and inner block positions");
97
98 // Bias toward innermost dimensions.
99 unsigned outerBlockPos = operandMap.getNumResults() - 4;
100 unsigned innerBlockPos = operandMap.getNumResults() - 2;
101
102 // Transpose control options define the desired block and element layout.
103 // Block transposition (outer dimensions) or element transposition (inner
104 // dimensions) may not be necessary depending on the original matmul data
105 // layout.
106 bool isOuterTransposed =
107 operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
108 bool isInnerTransposed =
109 operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();
110
111 // Transpose only the dimensions that need that to conform to the provided
112 // transpotion settings.
113 SmallVector<int64_t> innerPerm = {0, 1};
114 if (isInnerTransposed != transposeInnerBlocks)
115 innerPerm = {1, 0};
116 SmallVector<int64_t> outerPerm = {0, 1};
117 if (isOuterTransposed != transposeOuterBlocks)
118 outerPerm = {1, 0};
119
120 // Leave the outer dimensions, like batch, unchanged by offsetting all
121 // outer dimensions permutations.
122 SmallVector<int64_t> offsetPerms;
123 for (auto i : llvm::seq(0u, outerBlockPos))
124 offsetPerms.push_back(i);
125 for (auto perm : outerPerm)
126 offsetPerms.push_back(perm + outerBlockPos);
127 outerPerm = offsetPerms;
128
129 FailureOr<PackTransposeResult> packTransposedMatmul =
130 packTranspose(rewriter, packOp, linalgOp,
131 /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
132
133 return packTransposedMatmul;
134}
135
136/// Pack a matmul operation into blocked 4D layout.
137FailureOr<PackResult>
138linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
139 const ControlBlockPackMatmulFn &controlPackMatmul) {
140 // Check to not let go the batch_matmul with extended semantic, through this
141 // transform.
142 if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
143 if (batchMatmulOp->hasUserDefinedMaps()) {
144 return rewriter.notifyMatchFailure(
145 *batchMatmulOp,
146 "only batch_matmul ops with non-extended semantics are supported");
147 }
148 }
149
150 if (linalgOp.hasPureBufferSemantics())
151 return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
152
153 std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
154 if (!options)
155 return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
156
157 if (options->blockFactors.size() != 3)
158 return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
159
161 getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
162
163 // If padding is disabled, make sure that dimensions can be packed cleanly.
164 if (!options->allowPadding &&
165 !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
166 return rewriter.notifyMatchFailure(linalgOp,
167 "expect packing full tiles only");
168 }
169
170 OpBuilder::InsertionGuard guard(rewriter);
171 // The op is replaced, we need to set the insertion point after it.
172 rewriter.setInsertionPointAfter(linalgOp);
173
174 // Pack the matmul operation into blocked layout with two levels of
175 // subdivision:
176 // - major 2D blocks - outer dimensions, consist of minor blocks
177 // - minor 2D blocks - inner dimensions, consist of scalar elements
178 FailureOr<PackResult> packedMatmul = packMatmulGreedily(
179 rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
180 options->mnkOrder);
181 if (failed(packedMatmul))
182 return failure();
183
184 assert(packedMatmul->packOps.size() == 3 &&
185 "invalid number of pack ops after matmul packing");
186 assert(packedMatmul->unPackOps.size() == 1 &&
187 "invalid number of unpack ops after matmul packing");
188
189 FailureOr<ContractionDimensions> contractDims =
190 inferContractionDims(packedMatmul->packedLinalgOp);
191 if (failed(contractDims))
192 return failure();
193
194 auto genericOp =
195 dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
196 SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
197
198 // Transpose LHS matrix according to the options.
199 FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
200 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
201 contractDims->m, options->lhsTransposeOuterBlocks,
202 options->lhsTransposeInnerBlocks);
203 if (failed(packedLhs))
204 return failure();
205
206 // Update results.
207 packedMatmul->packOps[0] = packedLhs->transposedPackOp;
208 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
209
210 // Transpose RHS matrix according to the options.
211 FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
212 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
213 contractDims->k, options->rhsTransposeOuterBlocks,
214 options->rhsTransposeInnerBlocks);
215 if (failed(packedRhs))
216 return failure();
217
218 // Update results.
219 packedMatmul->packOps[1] = packedRhs->transposedPackOp;
220 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
221
222 return packedMatmul;
223}
224
225namespace {
226template <typename OpTy>
227struct BlockPackMatmul : public OpRewritePattern<OpTy> {
228 BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
229 PatternBenefit benefit = 1)
230 : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
231
232 LogicalResult matchAndRewrite(OpTy linalgOp,
233 PatternRewriter &rewriter) const override {
234 FailureOr<PackResult> packedMatmul =
235 blockPackMatmul(rewriter, linalgOp, controlFn);
236 if (failed(packedMatmul))
237 return failure();
238 return success();
239 }
240
241private:
242 ControlBlockPackMatmulFn controlFn;
243};
244
245template <>
246struct BlockPackMatmul<linalg::GenericOp>
247 : public OpRewritePattern<linalg::GenericOp> {
248 BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
249 PatternBenefit benefit = 1)
250 : OpRewritePattern<linalg::GenericOp>(context, benefit),
251 controlFn(std::move(fun)) {}
252
253 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
254 PatternRewriter &rewriter) const override {
255 // Match suitable generics.
256 if (!linalg::isaContractionOpInterface(linalgOp)) {
257 return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
258 }
259
260 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
261 auto infer = [&](MapList m) {
262 return AffineMap::inferFromExprList(m, linalgOp.getContext());
263 };
264
265 AffineExpr i, j, k;
266 bindDims(linalgOp->getContext(), i, j, k);
267 SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
268
269 // For now, only match simple matmuls.
270 if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
271 maps == infer({{k, i}, {k, j}, {i, j}}) ||
272 maps == infer({{i, k}, {j, k}, {i, j}}))) {
273 return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
274 }
275
276 FailureOr<PackResult> packedMatmul =
277 blockPackMatmul(rewriter, linalgOp, controlFn);
278 if (failed(packedMatmul))
279 return failure();
280 return success();
281 }
282
283private:
284 ControlBlockPackMatmulFn controlFn;
285};
286
287/// Convert linalg matmul ops to block layout and back.
288struct LinalgBlockPackMatmul
289 : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
290 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
291
292 void runOnOperation() override {
293 Operation *op = getOperation();
294 RewritePatternSet patterns(&getContext());
295
296 ControlBlockPackMatmulFn controlFn =
297 [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
298 BlockPackMatmulOptions options;
299 options.blockFactors = SmallVector<int64_t>{*blockFactors};
300 options.allowPadding = allowPadding;
301 options.mnkPaddedSizesNextMultipleOf =
302 SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
303 if (!mnkOrder.empty())
304 options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
305 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
306 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
307 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
308 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
309 return options;
310 };
311
313 if (failed(applyPatternsGreedily(op, std::move(patterns))))
314 return signalPassFailure();
315 }
316};
317} // namespace
318
321 patterns.add<BlockPackMatmul<linalg::GenericOp>,
322 BlockPackMatmul<linalg::MatmulOp>,
323 BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(),
324 controlFn);
325}
return success()
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return constant range span or nullopt, otherwise.
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:281
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset