MLIR  22.0.0git
BlockPackMatmul.cpp
Go to the documentation of this file.
1 //===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/ADT/SmallVector.h"
17 
18 #include <optional>
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
22 #include "mlir/Dialect/Linalg/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 
28 /// Return constant range span or nullopt, otherwise.
29 static std::optional<int64_t> getConstantRange(const Range &range) {
30  std::optional<int64_t> stride = getConstantIntValue(range.stride);
31  if (!stride || *stride != 1)
32  return std::nullopt;
33  std::optional<int64_t> offset = getConstantIntValue(range.offset);
34  if (!offset)
35  return std::nullopt;
36  std::optional<int64_t> size = getConstantIntValue(range.size);
37  if (!size)
38  return std::nullopt;
39  return (*size - *offset);
40 }
41 
42 /// Return true if all dimensions are fully divisible by the respective tiles.
43 static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
45  ArrayRef<int64_t> dims) {
46  if (dims.size() != tiles.size() || tiles.empty())
47  return false;
48 
49  FailureOr<ContractionDimensions> contractDims =
50  inferContractionDims(linalgOp);
51  if (failed(contractDims))
52  return false;
53  unsigned batchDimsOffset = contractDims->batch.size();
54 
55  // Skip the batch dimension if present.
56  // Offset all dimensions accordingly.
57  SmallVector<int64_t, 3> offsetDims(dims);
58  for (int64_t &offsetDim : offsetDims)
59  offsetDim += batchDimsOffset;
60 
61  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
62  OpBuilder builder(tileOp);
63  OpBuilder::InsertionGuard guard(builder);
64  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
65 
66  for (auto dim : llvm::enumerate(offsetDims)) {
67  if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
68  return false;
69 
70  std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
71  std::optional<int64_t> rangeOnDim =
72  getConstantRange(iterationDomain[dim.value()]);
73 
74  // If the tile factor or the range are non-constant, the tile size is
75  // considered to be invalid.
76  if (!tileSize || !rangeOnDim)
77  return false;
78 
79  // The dimension must be fully divisible by the tile.
80  if (*rangeOnDim % *tileSize != 0)
81  return false;
82  }
83 
84  return true;
85 }
86 
87 /// Return failure or packed matmul with one of its operands transposed.
88 static FailureOr<PackTransposeResult>
89 transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
90  linalg::PackOp packOp, AffineMap operandMap,
91  ArrayRef<unsigned> blocksStartDimPos,
92  bool transposeOuterBlocks, bool transposeInnerBlocks) {
93  assert(operandMap.getNumDims() >= 4 &&
94  "expected at least 4D prepacked matmul");
95  assert(blocksStartDimPos.size() >= 2 &&
96  "expected starting outer and inner block positions");
97 
98  // Bias toward innermost dimensions.
99  unsigned outerBlockPos = operandMap.getNumResults() - 4;
100  unsigned innerBlockPos = operandMap.getNumResults() - 2;
101 
102  // Transpose control options define the desired block and element layout.
103  // Block transposition (outer dimensions) or element transposition (inner
104  // dimensions) may not be necessary depending on the original matmul data
105  // layout.
106  bool isOuterTransposed =
107  operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
108  bool isInnerTransposed =
109  operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();
110 
111  // Transpose only the dimensions that need that to conform to the provided
112  // transpotion settings.
113  SmallVector<int64_t> innerPerm = {0, 1};
114  if (isInnerTransposed != transposeInnerBlocks)
115  innerPerm = {1, 0};
116  SmallVector<int64_t> outerPerm = {0, 1};
117  if (isOuterTransposed != transposeOuterBlocks)
118  outerPerm = {1, 0};
119 
120  // Leave the outer dimensions, like batch, unchanged by offsetting all
121  // outer dimensions permutations.
122  SmallVector<int64_t> offsetPerms;
123  for (auto i : llvm::seq(0u, outerBlockPos))
124  offsetPerms.push_back(i);
125  for (auto perm : outerPerm)
126  offsetPerms.push_back(perm + outerBlockPos);
127  outerPerm = offsetPerms;
128 
129  FailureOr<PackTransposeResult> packTransposedMatmul =
130  packTranspose(rewriter, packOp, linalgOp,
131  /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
132 
133  return packTransposedMatmul;
134 }
135 
136 /// Pack a matmul operation into blocked 4D layout.
137 FailureOr<PackResult>
138 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
139  const ControlBlockPackMatmulFn &controlPackMatmul) {
140  // Check to not let go the batch_matmul with extended semantic, through this
141  // transform.
142  if (auto *batchMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&linalgOp)) {
143  if (batchMatmulOp->hasUserDefinedMaps()) {
144  return rewriter.notifyMatchFailure(
145  *batchMatmulOp,
146  "only batch_matmul ops with non-extended semantics are supported");
147  }
148  }
149 
150  if (linalgOp.hasPureBufferSemantics())
151  return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
152 
153  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
154  if (!options)
155  return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
156 
157  if (options->blockFactors.size() != 3)
158  return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
159 
160  SmallVector<OpFoldResult> mnkTiles =
161  getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
162 
163  // If padding is disabled, make sure that dimensions can be packed cleanly.
164  if (!options->allowPadding &&
165  !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
166  return rewriter.notifyMatchFailure(linalgOp,
167  "expect packing full tiles only");
168  }
169 
170  OpBuilder::InsertionGuard guard(rewriter);
171  // The op is replaced, we need to set the insertion point after it.
172  rewriter.setInsertionPointAfter(linalgOp);
173 
174  // Pack the matmul operation into blocked layout with two levels of
175  // subdivision:
176  // - major 2D blocks - outer dimensions, consist of minor blocks
177  // - minor 2D blocks - inner dimensions, consist of scalar elements
178  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
179  rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
180  options->mnkOrder);
181  if (failed(packedMatmul))
182  return failure();
183 
184  assert(packedMatmul->packOps.size() == 3 &&
185  "invalid number of pack ops after matmul packing");
186  assert(packedMatmul->unPackOps.size() == 1 &&
187  "invalid number of unpack ops after matmul packing");
188 
189  FailureOr<ContractionDimensions> contractDims =
190  inferContractionDims(packedMatmul->packedLinalgOp);
191  if (failed(contractDims))
192  return failure();
193 
194  auto genericOp =
195  dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
196  SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
197 
198  // Transpose LHS matrix according to the options.
199  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
200  rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
201  contractDims->m, options->lhsTransposeOuterBlocks,
202  options->lhsTransposeInnerBlocks);
203  if (failed(packedLhs))
204  return failure();
205 
206  // Update results.
207  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
208  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
209 
210  // Transpose RHS matrix according to the options.
211  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
212  rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
213  contractDims->k, options->rhsTransposeOuterBlocks,
214  options->rhsTransposeInnerBlocks);
215  if (failed(packedRhs))
216  return failure();
217 
218  // Update results.
219  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
220  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
221 
222  return packedMatmul;
223 }
224 
225 namespace {
226 template <typename OpTy>
227 struct BlockPackMatmul : public OpRewritePattern<OpTy> {
228  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
229  PatternBenefit benefit = 1)
230  : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
231 
232  LogicalResult matchAndRewrite(OpTy linalgOp,
233  PatternRewriter &rewriter) const override {
234  FailureOr<PackResult> packedMatmul =
235  blockPackMatmul(rewriter, linalgOp, controlFn);
236  if (failed(packedMatmul))
237  return failure();
238  return success();
239  }
240 
241 private:
242  ControlBlockPackMatmulFn controlFn;
243 };
244 
245 template <>
246 struct BlockPackMatmul<linalg::GenericOp>
247  : public OpRewritePattern<linalg::GenericOp> {
248  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
249  PatternBenefit benefit = 1)
250  : OpRewritePattern<linalg::GenericOp>(context, benefit),
251  controlFn(std::move(fun)) {}
252 
253  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
254  PatternRewriter &rewriter) const override {
255  // Match suitable generics.
256  if (!linalg::isaContractionOpInterface(linalgOp)) {
257  return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
258  }
259 
260  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
261  auto infer = [&](MapList m) {
262  return AffineMap::inferFromExprList(m, linalgOp.getContext());
263  };
264 
265  AffineExpr i, j, k;
266  bindDims(linalgOp->getContext(), i, j, k);
267  SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
268 
269  // For now, only match simple matmuls.
270  if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
271  maps == infer({{k, i}, {k, j}, {i, j}}) ||
272  maps == infer({{i, k}, {j, k}, {i, j}}))) {
273  return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
274  }
275 
276  FailureOr<PackResult> packedMatmul =
277  blockPackMatmul(rewriter, linalgOp, controlFn);
278  if (failed(packedMatmul))
279  return failure();
280  return success();
281  }
282 
283 private:
284  ControlBlockPackMatmulFn controlFn;
285 };
286 
287 /// Convert linalg matmul ops to block layout and back.
288 struct LinalgBlockPackMatmul
289  : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
290  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
291 
292  void runOnOperation() override {
293  Operation *op = getOperation();
295 
296  ControlBlockPackMatmulFn controlFn =
297  [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
299  options.blockFactors = SmallVector<int64_t>{*blockFactors};
300  options.allowPadding = allowPadding;
301  options.mnkPaddedSizesNextMultipleOf =
302  SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
303  if (!mnkOrder.empty())
304  options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
305  options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
306  options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
307  options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
308  options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
309  return options;
310  };
311 
313  if (failed(applyPatternsGreedily(op, std::move(patterns))))
314  return signalPassFailure();
315  }
316 };
317 } // namespace
318 
321  patterns.add<BlockPackMatmul<linalg::GenericOp>,
322  BlockPackMatmul<linalg::MatmulOp>,
323  BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(),
324  controlFn);
325 }
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return constant range span or nullopt, otherwise.
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:411
unsigned getNumDims() const
Definition: AffineMap.cpp:390
unsigned getNumResults() const
Definition: AffineMap.cpp:398
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:308
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:280
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:657
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:748
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
Definition: Transforms.h:1339
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.