MLIR  20.0.0git
BlockPackMatmul.cpp
Go to the documentation of this file.
1 //===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/PatternMatch.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 
19 #include <optional>
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
23 #include "mlir/Dialect/Linalg/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 /// Return constant range span or nullopt, otherwise.
30 static std::optional<int64_t> getConstantRange(const Range &range) {
31  std::optional<int64_t> stride = getConstantIntValue(range.stride);
32  if (!stride || *stride != 1)
33  return std::nullopt;
34  std::optional<int64_t> offset = getConstantIntValue(range.offset);
35  if (!offset)
36  return std::nullopt;
37  std::optional<int64_t> size = getConstantIntValue(range.size);
38  if (!size)
39  return std::nullopt;
40  return (*size - *offset);
41 }
42 
43 /// Return true if all dimensions are fully divisible by the respective tiles.
44 static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
46  ArrayRef<int64_t> dims) {
47  if (dims.size() != tiles.size() || tiles.empty())
48  return false;
49 
50  FailureOr<ContractionDimensions> contractDims =
51  inferContractionDims(linalgOp);
52  if (failed(contractDims))
53  return false;
54  unsigned batchDimsOffset = contractDims->batch.size();
55 
56  // Skip the batch dimension if present.
57  // Offset all dimensions accordingly.
58  SmallVector<int64_t, 3> offsetDims{dims};
59  for (size_t i = 0; i < offsetDims.size(); i++)
60  offsetDims[i] += batchDimsOffset;
61 
62  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
63  OpBuilder builder(tileOp);
64  OpBuilder::InsertionGuard guard(builder);
65  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
66 
67  for (auto dim : llvm::enumerate(offsetDims)) {
68  if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
69  return false;
70 
71  std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
72  std::optional<int64_t> rangeOnDim =
73  getConstantRange(iterationDomain[dim.value()]);
74 
75  // If the tile factor or the range are non-constant, the tile size is
76  // considered to be invalid.
77  if (!tileSize || !rangeOnDim)
78  return false;
79 
80  // The dimension must be fully divisible by the tile.
81  if (*rangeOnDim % *tileSize != 0)
82  return false;
83  }
84 
85  return true;
86 }
87 
88 /// Return failure or packed matmul with one of its operands transposed.
89 static FailureOr<PackTransposeResult>
90 transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
91  tensor::PackOp packOp, AffineMap operandMap,
92  ArrayRef<unsigned> blocksStartDimPos,
93  bool transposeOuterBlocks, bool transposeInnerBlocks) {
94  assert(operandMap.getNumDims() >= 4 &&
95  "expected at least 4D prepacked matmul");
96  assert(blocksStartDimPos.size() >= 2 &&
97  "expected starting outer and inner block positions");
98 
99  // Bias toward innermost dimensions.
100  unsigned outerBlockPos = operandMap.getNumResults() - 4;
101  unsigned innerBlockPos = operandMap.getNumResults() - 2;
102 
103  // Transpose control options define the desired block and element layout.
104  // Block transposition (outer dimensions) or element transposition (inner
105  // dimensions) may not be necessary depending on the original matmul data
106  // layout.
107  bool isOuterTransposed =
108  operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
109  bool isInnerTransposed =
110  operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();
111 
112  // Transpose only the dimensions that need that to conform to the provided
113  // transpotion settings.
114  SmallVector<int64_t> innerPerm{0, 1};
115  if (isInnerTransposed != transposeInnerBlocks)
116  innerPerm = {1, 0};
117  SmallVector<int64_t> outerPerm{0, 1};
118  if (isOuterTransposed != transposeOuterBlocks)
119  outerPerm = {1, 0};
120 
121  // Leave the outer dimensions, like batch, unchanged by offsetting all
122  // outer dimensions permutations.
123  SmallVector<int64_t> offsetPerms;
124  for (auto i : llvm::seq(0u, outerBlockPos))
125  offsetPerms.push_back(i);
126  for (auto perm : outerPerm)
127  offsetPerms.push_back(perm + outerBlockPos);
128  outerPerm = offsetPerms;
129 
130  FailureOr<PackTransposeResult> packTransposedMatmul =
131  packTranspose(rewriter, packOp, linalgOp,
132  /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
133 
134  return packTransposedMatmul;
135 }
136 
137 /// Pack a matmul operation into blocked 4D layout.
138 FailureOr<PackResult>
139 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
140  const ControlBlockPackMatmulFn &controlPackMatmul) {
141  if (linalgOp.hasPureBufferSemantics())
142  return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
143 
144  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
145  if (!options)
146  return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
147 
148  if (options->blockFactors.size() != 3)
149  return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
150 
151  SmallVector<OpFoldResult> mnkTiles =
152  getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
153 
154  // If padding is disabled, make sure that dimensions can be packed cleanly.
155  if (!options->allowPadding &&
156  !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
157  return rewriter.notifyMatchFailure(linalgOp,
158  "expect packing full tiles only");
159  }
160 
161  OpBuilder::InsertionGuard guard(rewriter);
162  // The op is replaced, we need to set the insertion point after it.
163  rewriter.setInsertionPointAfter(linalgOp);
164 
165  // Pack the matmul operation into blocked layout with two levels of
166  // subdivision:
167  // - major 2D blocks - outer dimensions, consist of minor blocks
168  // - minor 2D blocks - inner dimensions, consist of scalar elements
169  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
170  rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
171  options->mnkOrder);
172  if (failed(packedMatmul))
173  return failure();
174 
175  assert(packedMatmul->packOps.size() == 3 &&
176  "invalid number of pack ops after matmul packing");
177  assert(packedMatmul->unPackOps.size() == 1 &&
178  "invalid number of unpack ops after matmul packing");
179 
180  FailureOr<ContractionDimensions> contractDims =
181  inferContractionDims(packedMatmul->packedLinalgOp);
182  if (failed(contractDims))
183  return failure();
184 
185  auto genericOp =
186  dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
187  SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
188 
189  // Transpose LHS matrix according to the options.
190  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
191  rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
192  contractDims->m, options->lhsTransposeOuterBlocks,
193  options->lhsTransposeInnerBlocks);
194  if (failed(packedLhs))
195  return failure();
196 
197  // Update results.
198  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
199  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
200 
201  // Transpose RHS matrix according to the options.
202  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
203  rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
204  contractDims->k, options->rhsTransposeOuterBlocks,
205  options->rhsTransposeInnerBlocks);
206  if (failed(packedRhs))
207  return failure();
208 
209  // Update results.
210  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
211  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
212 
213  return packedMatmul;
214 }
215 
216 namespace {
217 template <typename OpTy>
218 struct BlockPackMatmul : public OpRewritePattern<OpTy> {
219  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
220  PatternBenefit benefit = 1)
221  : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
222 
223  LogicalResult matchAndRewrite(OpTy linalgOp,
224  PatternRewriter &rewriter) const override {
225  FailureOr<PackResult> packedMatmul =
226  blockPackMatmul(rewriter, linalgOp, controlFn);
227  if (failed(packedMatmul))
228  return failure();
229  return success();
230  }
231 
232 private:
233  ControlBlockPackMatmulFn controlFn;
234 };
235 
236 template <>
237 struct BlockPackMatmul<linalg::GenericOp>
238  : public OpRewritePattern<linalg::GenericOp> {
239  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
240  PatternBenefit benefit = 1)
241  : OpRewritePattern<linalg::GenericOp>(context, benefit),
242  controlFn(std::move(fun)) {}
243 
244  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
245  PatternRewriter &rewriter) const override {
246  // Match suitable generics.
247  if (!linalg::isaContractionOpInterface(linalgOp)) {
248  return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
249  }
250 
251  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
252  auto infer = [&](MapList m) {
253  return AffineMap::inferFromExprList(m, linalgOp.getContext());
254  };
255 
256  AffineExpr i, j, k;
257  bindDims(linalgOp->getContext(), i, j, k);
258  SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
259 
260  // For now, only match simple matmuls.
261  if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
262  maps == infer({{k, i}, {k, j}, {i, j}}) ||
263  maps == infer({{i, k}, {j, k}, {i, j}}))) {
264  return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
265  }
266 
267  FailureOr<PackResult> packedMatmul =
268  blockPackMatmul(rewriter, linalgOp, controlFn);
269  if (failed(packedMatmul))
270  return failure();
271  return success();
272  }
273 
274 private:
275  ControlBlockPackMatmulFn controlFn;
276 };
277 
278 /// Convert linalg matmul ops to block layout and back.
279 struct LinalgBlockPackMatmul
280  : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
281  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
282 
283  void runOnOperation() override {
284  Operation *op = getOperation();
286 
287  ControlBlockPackMatmulFn controlFn =
288  [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
290  options.blockFactors = SmallVector<int64_t>{*blockFactors};
291  options.allowPadding = allowPadding;
292  options.mnkPaddedSizesNextMultipleOf =
293  SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
294  if (!mnkOrder.empty())
295  options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
296  options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
297  options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
298  options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
299  options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
300  return options;
301  };
302 
304  if (failed(applyPatternsGreedily(op, std::move(patterns))))
305  return signalPassFailure();
306  }
307 };
308 } // namespace
309 
312  patterns.add<BlockPackMatmul<linalg::GenericOp>,
313  BlockPackMatmul<linalg::MatmulOp>,
314  BlockPackMatmul<linalg::BatchMatmulOp>,
315  BlockPackMatmul<linalg::MatmulTransposeAOp>,
316  BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
317  BlockPackMatmul<linalg::MatmulTransposeBOp>,
318  BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
319  patterns.getContext(), controlFn);
320 }
static FailureOr< PackTransposeResult > transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, tensor::PackOp packOp, AffineMap operandMap, ArrayRef< unsigned > blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks)
Return failure or packed matmul with one of its operands transposed.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > tiles, ArrayRef< int64_t > dims)
Return true if all dimensions are fully divisible by the respective tiles.
static std::optional< int64_t > getConstantRange(const Range &range)
Return constant range span or nullopt, otherwise.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
unsigned getNumDims() const
Definition: AffineMap.cpp:394
unsigned getNumResults() const
Definition: AffineMap.cpp:402
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:321
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:421
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn)
Patterns to block pack Linalg matmul ops.
FailureOr< PackResult > blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const ControlBlockPackMatmulFn &controlPackMatmul)
Pack a matmul operation into blocked 4D layout.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:768
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::function< std::optional< BlockPackMatmulOptions >(linalg::LinalgOp)> ControlBlockPackMatmulFn
Function type which is used to control matmul packing.
Definition: Transforms.h:1217
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:677
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.