MLIR  21.0.0git
BlockPackMatmul.cpp
Go to the documentation of this file.
1 //===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 
19 #include <optional>
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
23 #include "mlir/Dialect/Linalg/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 /// Return constant range span or nullopt, otherwise.
30 static std::optional<int64_t> getConstantRange(const Range &range) {
31  std::optional<int64_t> stride = getConstantIntValue(range.stride);
32  if (!stride || *stride != 1)
33  return std::nullopt;
34  std::optional<int64_t> offset = getConstantIntValue(range.offset);
35  if (!offset)
36  return std::nullopt;
37  std::optional<int64_t> size = getConstantIntValue(range.size);
38  if (!size)
39  return std::nullopt;
40  return (*size - *offset);
41 }
42 
43 /// Return true if all dimensions are fully divisible by the respective tiles.
44 static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
46  ArrayRef<int64_t> dims) {
47  if (dims.size() != tiles.size() || tiles.empty())
48  return false;
49 
50  FailureOr<ContractionDimensions> contractDims =
51  inferContractionDims(linalgOp);
52  if (failed(contractDims))
53  return false;
54  unsigned batchDimsOffset = contractDims->batch.size();
55 
56  // Skip the batch dimension if present.
57  // Offset all dimensions accordingly.
58  SmallVector<int64_t, 3> offsetDims(dims);
59  for (size_t i = 0; i < offsetDims.size(); i++)
60  offsetDims[i] += batchDimsOffset;
61 
62  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
63  OpBuilder builder(tileOp);
64  OpBuilder::InsertionGuard guard(builder);
65  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
66 
67  for (auto dim : llvm::enumerate(offsetDims)) {
68  if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
69  return false;
70 
71  std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
72  std::optional<int64_t> rangeOnDim =
73  getConstantRange(iterationDomain[dim.value()]);
74 
75  // If the tile factor or the range are non-constant, the tile size is
76  // considered to be invalid.
77  if (!tileSize || !rangeOnDim)
78  return false;
79 
80  // The dimension must be fully divisible by the tile.
81  if (*rangeOnDim % *tileSize != 0)
82  return false;
83  }
84 
85  return true;
86 }
87 
88 /// Return failure or packed matmul with one of its operands transposed.
89 static FailureOr<PackTransposeResult>
90 transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
91  linalg::PackOp packOp, AffineMap operandMap,
92  ArrayRef<unsigned> blocksStartDimPos,
93  bool transposeOuterBlocks, bool transposeInnerBlocks) {
94  assert(operandMap.getNumDims() >= 4 &&
95  "expected at least 4D prepacked matmul");
96  assert(blocksStartDimPos.size() >= 2 &&
97  "expected starting outer and inner block positions");
98 
99  // Bias toward innermost dimensions.
100  unsigned outerBlockPos = operandMap.getNumResults() - 4;
101  unsigned innerBlockPos = operandMap.getNumResults() - 2;
102 
103  // Transpose control options define the desired block and element layout.
104  // Block transposition (outer dimensions) or element transposition (inner
105  // dimensions) may not be necessary depending on the original matmul data
106  // layout.
107  bool isOuterTransposed =
108  operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
109  bool isInnerTransposed =
110  operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();
111 
112  // Transpose only the dimensions that need that to conform to the provided
113  // transpotion settings.
114  SmallVector<int64_t> innerPerm = {0, 1};
115  if (isInnerTransposed != transposeInnerBlocks)
116  innerPerm = {1, 0};
117  SmallVector<int64_t> outerPerm = {0, 1};
118  if (isOuterTransposed != transposeOuterBlocks)
119  outerPerm = {1, 0};
120 
121  // Leave the outer dimensions, like batch, unchanged by offsetting all
122  // outer dimensions permutations.
123  SmallVector<int64_t> offsetPerms;
124  for (auto i : llvm::seq(0u, outerBlockPos))
125  offsetPerms.push_back(i);
126  for (auto perm : outerPerm)
127  offsetPerms.push_back(perm + outerBlockPos);
128  outerPerm = offsetPerms;
129 
130  FailureOr<PackTransposeResult> packTransposedMatmul =
131  packTranspose(rewriter, packOp, linalgOp,
132  /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
133 
134  return packTransposedMatmul;
135 }
136 
137 /// Pack a matmul operation into blocked 4D layout.
138 FailureOr<PackResult>
139 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
140  const ControlBlockPackMatmulFn &controlPackMatmul) {
141  // Check to not let go the batch_matmul with extended semantic, through this
142  // transform.
143  if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
144  if (batchMatmulOp->hasUserDefinedMaps()) {
145  return rewriter.notifyMatchFailure(
146  *batchMatmulOp,
147  "only batch_matmul ops with non-extended semantics are supported");
148  }
149  }
150 
151  if (linalgOp.hasPureBufferSemantics())
152  return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
153 
154  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
155  if (!options)
156  return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
157 
158  if (options->blockFactors.size() != 3)
159  return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
160 
161  SmallVector<OpFoldResult> mnkTiles =
162  getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
163 
164  // If padding is disabled, make sure that dimensions can be packed cleanly.
165  if (!options->allowPadding &&
166  !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
167  return rewriter.notifyMatchFailure(linalgOp,
168  "expect packing full tiles only");
169  }
170 
171  OpBuilder::InsertionGuard guard(rewriter);
172  // The op is replaced, we need to set the insertion point after it.
173  rewriter.setInsertionPointAfter(linalgOp);
174 
175  // Pack the matmul operation into blocked layout with two levels of
176  // subdivision:
177  // - major 2D blocks - outer dimensions, consist of minor blocks
178  // - minor 2D blocks - inner dimensions, consist of scalar elements
179  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
180  rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
181  options->mnkOrder);
182  if (failed(packedMatmul))
183  return failure();
184 
185  assert(packedMatmul->packOps.size() == 3 &&
186  "invalid number of pack ops after matmul packing");
187  assert(packedMatmul->unPackOps.size() == 1 &&
188  "invalid number of unpack ops after matmul packing");
189 
190  FailureOr<ContractionDimensions> contractDims =
191  inferContractionDims(packedMatmul->packedLinalgOp);
192  if (failed(contractDims))
193  return failure();
194 
195  auto genericOp =
196  dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
197  SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
198 
199  // Transpose LHS matrix according to the options.
200  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
201  rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
202  contractDims->m, options->lhsTransposeOuterBlocks,
203  options->lhsTransposeInnerBlocks);
204  if (failed(packedLhs))
205  return failure();
206 
207  // Update results.
208  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
209  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
210 
211  // Transpose RHS matrix according to the options.
212  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
213  rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
214  contractDims->k, options->rhsTransposeOuterBlocks,
215  options->rhsTransposeInnerBlocks);
216  if (failed(packedRhs))
217  return failure();
218 
219  // Update results.
220  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
221  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
222 
223  return packedMatmul;
224 }
225 
226 namespace {
227 template <typename OpTy>
228 struct BlockPackMatmul : public OpRewritePattern<OpTy> {
229  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
230  PatternBenefit benefit = 1)
231  : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
232 
233  LogicalResult matchAndRewrite(OpTy linalgOp,
234  PatternRewriter &rewriter) const override {
235  FailureOr<PackResult> packedMatmul =
236  blockPackMatmul(rewriter, linalgOp, controlFn);
237  if (failed(packedMatmul))
238  return failure();
239  return success();
240  }
241 
242 private:
243  ControlBlockPackMatmulFn controlFn;
244 };
245 
246 template <>
247 struct BlockPackMatmul<linalg::GenericOp>
248  : public OpRewritePattern<linalg::GenericOp> {
249  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
250  PatternBenefit benefit = 1)
251  : OpRewritePattern<linalg::GenericOp>(context, benefit),
252  controlFn(std::move(fun)) {}
253 
254  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
255  PatternRewriter &rewriter) const override {
256  // Match suitable generics.
257  if (!linalg::isaContractionOpInterface(linalgOp)) {
258  return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
259  }
260 
261  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
262  auto infer = [&](MapList m) {
263  return AffineMap::inferFromExprList(m, linalgOp.getContext());
264  };
265 
266  AffineExpr i, j, k;
267  bindDims(linalgOp->getContext(), i, j, k);
268  SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
269 
270  // For now, only match simple matmuls.
271  if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
272  maps == infer({{k, i}, {k, j}, {i, j}}) ||
273  maps == infer({{i, k}, {j, k}, {i, j}}))) {
274  return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
275  }
276 
277  FailureOr<PackResult> packedMatmul =
278  blockPackMatmul(rewriter, linalgOp, controlFn);
279  if (failed(packedMatmul))
280  return failure();
281  return success();
282  }
283 
284 private:
285  ControlBlockPackMatmulFn controlFn;
286 };
287 
288 /// Convert linalg matmul ops to block layout and back.
289 struct LinalgBlockPackMatmul
290  : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
291  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
292 
293  void runOnOperation() override {
294  Operation *op = getOperation();
296 
297  ControlBlockPackMatmulFn controlFn =
298  [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
300  options.blockFactors = SmallVector<int64_t>{*blockFactors};
301  options.allowPadding = allowPadding;
302  options.mnkPaddedSizesNextMultipleOf =
303  SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
304  if (!mnkOrder.empty())
305  options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
306  options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
307  options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
308  options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
309  options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
310  return options;
311  };
312 
314  if (failed(applyPatternsGreedily(op, std::move(patterns))))
315  return signalPassFailure();
316  }
317 };
318 } // namespace
319 
322  patterns.add<BlockPackMatmul<linalg::GenericOp>,
323  BlockPackMatmul<linalg::MatmulOp>,
324  BlockPackMatmul<linalg::BatchMatmulOp>,
325  BlockPackMatmul<linalg::MatmulTransposeAOp>,
326  BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
327  BlockPackMatmul<linalg::MatmulTransposeBOp>,
328  BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
329  patterns.getContext(), controlFn);
330 }
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return constant range span or nullopt, otherwise.
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
unsigned getNumDims() const
Definition: AffineMap.cpp:394
unsigned getNumResults() const
Definition: AffineMap.cpp:402
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:677
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:768
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
Definition: Transforms.h:1218
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.