MLIR 23.0.0git
BlockPackMatmul.cpp
Go to the documentation of this file.
1//===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "llvm/ADT/SmallVector.h"
17
18#include <optional>
19
20namespace mlir {
21#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
22#include "mlir/Dialect/Linalg/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::linalg;
27
28/// Return constant range span or nullopt, otherwise.
29static std::optional<int64_t> getConstantRange(const Range &range) {
30 std::optional<int64_t> stride = getConstantIntValue(range.stride);
31 if (!stride || *stride != 1)
32 return std::nullopt;
33 std::optional<int64_t> offset = getConstantIntValue(range.offset);
34 if (!offset)
35 return std::nullopt;
36 std::optional<int64_t> size = getConstantIntValue(range.size);
37 if (!size)
38 return std::nullopt;
39 return (*size - *offset);
40}
41
42/// Return true if all dimensions are fully divisible by the respective tiles.
43static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
45 ArrayRef<int64_t> dims) {
46 if (dims.size() != tiles.size() || tiles.empty())
47 return false;
48
49 FailureOr<ContractionDimensions> contractDims =
50 inferContractionDims(linalgOp);
51 if (failed(contractDims))
52 return false;
53 unsigned batchDimsOffset = contractDims->batch.size();
54
55 // Skip the batch dimension if present.
56 // Offset all dimensions accordingly.
57 SmallVector<int64_t, 3> offsetDims(dims);
58 for (int64_t &offsetDim : offsetDims)
59 offsetDim += batchDimsOffset;
60
61 auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
62 OpBuilder builder(tileOp);
63 OpBuilder::InsertionGuard guard(builder);
64 SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
65
66 for (auto dim : llvm::enumerate(offsetDims)) {
67 if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
68 return false;
69
70 std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
71 std::optional<int64_t> rangeOnDim =
72 getConstantRange(iterationDomain[dim.value()]);
73
74 // If the tile factor or the range are non-constant, the tile size is
75 // considered to be invalid.
76 if (!tileSize || !rangeOnDim)
77 return false;
78
79 // The dimension must be fully divisible by the tile.
80 if (*rangeOnDim % *tileSize != 0)
81 return false;
82 }
83
84 return true;
85}
86
87/// Return failure or packed matmul with one of its operands transposed.
88static FailureOr<PackTransposeResult>
89transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
90 linalg::PackOp packOp, AffineMap operandMap,
91 ArrayRef<unsigned> blocksStartDimPos,
92 bool transposeOuterBlocks, bool transposeInnerBlocks) {
93 // TODO: Support Memref PackOp. Temporarily return failure.
94 if (!packOp.hasPureTensorSemantics())
95 return failure();
96
97 assert(operandMap.getNumDims() >= 4 &&
98 "expected at least 4D prepacked matmul");
99 assert(blocksStartDimPos.size() >= 2 &&
100 "expected starting outer and inner block positions");
101
102 // Bias toward innermost dimensions.
103 unsigned outerBlockPos = operandMap.getNumResults() - 4;
104 unsigned innerBlockPos = operandMap.getNumResults() - 2;
105
106 // Transpose control options define the desired block and element layout.
107 // Block transposition (outer dimensions) or element transposition (inner
108 // dimensions) may not be necessary depending on the original matmul data
109 // layout.
110 bool isOuterTransposed =
111 operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
112 bool isInnerTransposed =
113 operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();
114
115 // Transpose only the dimensions that need that to conform to the provided
116 // transpotion settings.
117 SmallVector<int64_t> innerPerm = {0, 1};
118 if (isInnerTransposed != transposeInnerBlocks)
119 innerPerm = {1, 0};
120 SmallVector<int64_t> outerPerm = {0, 1};
121 if (isOuterTransposed != transposeOuterBlocks)
122 outerPerm = {1, 0};
123
124 // Leave the outer dimensions, like batch, unchanged by offsetting all
125 // outer dimensions permutations.
126 SmallVector<int64_t> offsetPerms;
127 for (auto i : llvm::seq(0u, outerBlockPos))
128 offsetPerms.push_back(i);
129 for (auto perm : outerPerm)
130 offsetPerms.push_back(perm + outerBlockPos);
131 outerPerm = offsetPerms;
132
133 FailureOr<PackTransposeResult> packTransposedMatmul =
134 packTranspose(rewriter, packOp, linalgOp,
135 /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
136
137 return packTransposedMatmul;
138}
139
140/// Pack a matmul operation into blocked 4D layout.
141FailureOr<PackResult>
142linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
143 const ControlBlockPackMatmulFn &controlPackMatmul) {
144 // Check to not let go the batch_matmul with extended semantic, through this
145 // transform.
146 if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
147 if (batchMatmulOp->hasUserDefinedMaps()) {
148 return rewriter.notifyMatchFailure(
149 *batchMatmulOp,
150 "only batch_matmul ops with non-extended semantics are supported");
151 }
152 }
153
154 if (linalgOp.hasPureBufferSemantics())
155 return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
156
157 std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
158 if (!options)
159 return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
160
161 if (options->blockFactors.size() != 3)
162 return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
163
165 getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
166
167 // If padding is disabled, make sure that dimensions can be packed cleanly.
168 if (!options->allowPadding &&
169 !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
170 return rewriter.notifyMatchFailure(linalgOp,
171 "expect packing full tiles only");
172 }
173
174 OpBuilder::InsertionGuard guard(rewriter);
175 // The op is replaced, we need to set the insertion point after it.
176 rewriter.setInsertionPointAfter(linalgOp);
177
178 // Pack the matmul operation into blocked layout with two levels of
179 // subdivision:
180 // - major 2D blocks - outer dimensions, consist of minor blocks
181 // - minor 2D blocks - inner dimensions, consist of scalar elements
182 FailureOr<PackResult> packedMatmul = packMatmulGreedily(
183 rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
184 options->mnkOrder);
185 if (failed(packedMatmul))
186 return failure();
187
188 assert(packedMatmul->packOps.size() == 3 &&
189 "invalid number of pack ops after matmul packing");
190 assert(packedMatmul->unPackOps.size() == 1 &&
191 "invalid number of unpack ops after matmul packing");
192
193 FailureOr<ContractionDimensions> contractDims =
194 inferContractionDims(packedMatmul->packedLinalgOp);
195 if (failed(contractDims))
196 return failure();
197
198 auto genericOp =
199 dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
200 SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
201
202 // Transpose LHS matrix according to the options.
203 FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
204 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
205 contractDims->m, options->lhsTransposeOuterBlocks,
206 options->lhsTransposeInnerBlocks);
207 if (failed(packedLhs))
208 return failure();
209
210 // Update results.
211 packedMatmul->packOps[0] = packedLhs->transposedPackOp;
212 packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
213
214 // Transpose RHS matrix according to the options.
215 FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
216 rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
217 contractDims->k, options->rhsTransposeOuterBlocks,
218 options->rhsTransposeInnerBlocks);
219 if (failed(packedRhs))
220 return failure();
221
222 // Update results.
223 packedMatmul->packOps[1] = packedRhs->transposedPackOp;
224 packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
225
226 return packedMatmul;
227}
228
229namespace {
230template <typename OpTy>
231struct BlockPackMatmul : public OpRewritePattern<OpTy> {
232 BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
233 PatternBenefit benefit = 1)
234 : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
235
236 LogicalResult matchAndRewrite(OpTy linalgOp,
237 PatternRewriter &rewriter) const override {
238 FailureOr<PackResult> packedMatmul =
239 blockPackMatmul(rewriter, linalgOp, controlFn);
240 if (failed(packedMatmul))
241 return failure();
242 return success();
243 }
244
245private:
246 ControlBlockPackMatmulFn controlFn;
247};
248
249template <>
250struct BlockPackMatmul<linalg::GenericOp>
251 : public OpRewritePattern<linalg::GenericOp> {
252 BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
253 PatternBenefit benefit = 1)
254 : OpRewritePattern<linalg::GenericOp>(context, benefit),
255 controlFn(std::move(fun)) {}
256
257 LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
258 PatternRewriter &rewriter) const override {
259 // Match suitable generics.
260 if (!linalg::isaContractionOpInterface(linalgOp)) {
261 return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
262 }
263
264 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
265 auto infer = [&](MapList m) {
266 return AffineMap::inferFromExprList(m, linalgOp.getContext());
267 };
268
269 AffineExpr i, j, k;
270 bindDims(linalgOp->getContext(), i, j, k);
271 SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
272
273 // For now, only match simple matmuls.
274 if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
275 maps == infer({{k, i}, {k, j}, {i, j}}) ||
276 maps == infer({{i, k}, {j, k}, {i, j}}))) {
277 return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
278 }
279
280 FailureOr<PackResult> packedMatmul =
281 blockPackMatmul(rewriter, linalgOp, controlFn);
282 if (failed(packedMatmul))
283 return failure();
284 return success();
285 }
286
287private:
288 ControlBlockPackMatmulFn controlFn;
289};
290
291/// Convert linalg matmul ops to block layout and back.
292struct LinalgBlockPackMatmul
293 : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
294 using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
295
296 void runOnOperation() override {
297 Operation *op = getOperation();
298 RewritePatternSet patterns(&getContext());
299
300 ControlBlockPackMatmulFn controlFn =
301 [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
302 BlockPackMatmulOptions options;
303 options.blockFactors = SmallVector<int64_t>{*blockFactors};
304 options.allowPadding = allowPadding;
305 options.mnkPaddedSizesNextMultipleOf =
306 SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
307 if (!mnkOrder.empty())
308 options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
309 options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
310 options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
311 options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
312 options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
313 return options;
314 };
315
317 if (failed(applyPatternsGreedily(op, std::move(patterns))))
318 return signalPassFailure();
319 }
320};
321} // namespace
322
325 patterns.add<BlockPackMatmul<linalg::GenericOp>,
326 BlockPackMatmul<linalg::MatmulOp>,
327 BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(),
328 controlFn);
329}
return success()
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return constant range span or nullopt, otherwise.
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:281
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset