MLIR  22.0.0git
LowerContractToNeonPatterns.cpp
Go to the documentation of this file.
1 //===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements lowering patterns from vector.contract to operations
// that map to instructions from the Neon FEAT_I8MM and FEAT_BF16 extensions.
11 //
12 // TODO: There may be opportunities to unify this with a similar pattern
13 // for SVE. See:
14 // https://github.com/llvm/llvm-project/issues/145559
15 // LowerContractToSVEPatterns.cpp
16 //
17 //===----------------------------------------------------------------------===//
18 
25 #include "mlir/IR/AffineMap.h"
26 #include "mlir/IR/PatternMatch.h"
27 
28 #define DEBUG_TYPE "lower-contract-to-arm-neon"
29 
30 using namespace mlir;
31 using namespace mlir::arm_neon;
32 
33 namespace {
34 /// Get the operand of a `vector.contract`. This function is intended to
35 /// abstract away from the particular way a value is extended before feeding it
36 /// into the `vector.contract` - via zero-extend or an explicit or implicit
37 /// sign-extend (for implicit sign-extension see `vector.contract`
38 /// documentation).
39 ///
40 /// The template parameter `Op` indicates the extension operation (explicit or
41 /// implicit) for which we are checking.
42 ///
43 // Return success only for extensions from `iN` (N <= 8) to `i32`.
44 template <typename Op>
45 std::optional<Value> getExtOperand(Value v) {
46 
47  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
48  "Must be instantiated with either sign- or zero- extension op");
49 
50  // If the operand is not defined by an explicit extend operation of the
51  // accepted operation type allow for an implicit sign-extension.
52  auto extOp = v.getDefiningOp<Op>();
53  if (!extOp) {
54  if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
55  auto eltTy = cast<VectorType>(v.getType()).getElementType();
56  if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
57  return {};
58  return v;
59  }
60  return {};
61  }
62 
63  // If the operand is defined by an explicit extend operation of the accepted
64  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
65  auto inOp = extOp.getIn();
66  auto inTy = dyn_cast<VectorType>(inOp.getType());
67  if (!inTy)
68  return {};
69  auto inEltTy = inTy.getElementType();
70  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
71  return {};
72 
73  auto outTy = dyn_cast<VectorType>(extOp.getType());
74  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
75  return {};
76 
77  return inOp;
78 }
79 
80 /// Helper function to extend a vector with elements iN, N < 8 to
81 /// a vector of i8. Do sign extension if the parameter `signExt` is true,
82 /// zero extension otherwise.
83 Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
84  bool signExt, PatternRewriter &rewriter) {
85  Type targetTy = srcTy.clone(rewriter.getI8Type());
86  return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
87  : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
88 }
89 
90 class VectorContractRewriter {
91 protected:
92  // Designate the operation (resp. instruction) used to do sub-tile matrix
93  // multiplications.
94  enum class MMLA {
95  Nop,
96  SignedInt, // smmla
97  UnsignedInt, // ummla
98  MixedInt, // usmmla
99  Bfloat // bfmmla
100  };
101 
102  // Lower-level operation to be emitted.
103  MMLA mmlaOp = MMLA::Nop;
104 
105  // Indicate if the operands for the ArmNeon dialect operation need to be
106  // swapped. Currently this is needed in order to emulate an "summla"
107  // operation.
108  bool swapOperands = false;
109 
110  // The operand tiles. These are not necessarily the operands of
111  // `vector.contract`, for example they could be operands to `arith.extsi`
112  // that is in turn fed into `vector.contract`.
113  Value lhs;
114  Value rhs;
115  Value acc;
116 
117  // The dimensions logically corresponding to matrix multiplication of
118  // MxK * KxN -> MxN. The operands and the result do not necessarily have these
119  // shapes, for example RHS could be NxK with a transposing indexing map.
120  int64_t dimM = 0;
121  int64_t dimN = 0;
122  int64_t dimK = 0;
123 
124  // Unroll iteration bounds. See documentaiton for `StaticTileOffsetRange`.
125  SmallVector<int64_t> iterationBounds;
126 
127  // Sub-tile shape. The algorithm handles operand shapes, which are multiples
128  // of this shape.
129  SmallVector<int64_t> subTileShape;
130 
131  // Create the matrix multiply and accumulate operation according to `mmlaOp`.
132  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
133  Value lhs, Value rhs) {
134 
135  if (swapOperands)
136  std::swap(lhs, rhs);
137  switch (mmlaOp) {
138  case MMLA::SignedInt:
139  return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
140  lhs, rhs);
141  case MMLA::UnsignedInt:
142  return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
143  lhs, rhs);
144  case MMLA::MixedInt:
145  return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
146  lhs, rhs);
147  case MMLA::Bfloat:
148  return arm_neon::BfmmlaOp::create(rewriter, loc, acc.getType(), acc, lhs,
149  rhs);
150  case MMLA::Nop:
151  llvm_unreachable("Uninitialized operation type");
152  }
153  }
154 
155  // Check common preconditions for applying the patterns and initialize
156  // logical dimensions.
157  LogicalResult matchAndInit(vector::ContractionOp op,
158  PatternRewriter &rewriter) {
159  // Check iterator types for matrix multiplication.
160  SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
161  if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
162  itTypes[1] != vector::IteratorType::parallel ||
163  itTypes[2] != vector::IteratorType::reduction) &&
164  (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel ||
165  itTypes[1] != vector::IteratorType::reduction))
166  return rewriter.notifyMatchFailure(
167  op, "iterator types do not correspond to matrix multiplication");
168 
169  // Avoid 0-D vectors and 1-D rhs:
170  VectorType lhsType = op.getLhsType();
171  VectorType rhsType = op.getRhsType();
172  if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
173  rhsType.getRank() != 2)
174  return rewriter.notifyMatchFailure(op, "Invalid operand rank");
175 
176  // This codegen does not work for scalable vectors. Return failure so this
177  // pattern is not accidentally chosen over patterns that lower to ArmSVE.
178  if (lhsType.isScalable() || rhsType.isScalable())
179  return rewriter.notifyMatchFailure(op,
180  "Not applicable to scalable vectors");
181 
182  // Initialize dimensions and check for a matching K dimension.
183  dimM = lhsType.getDimSize(0);
184  dimN = rhsType.getDimSize(0);
185  dimK = rhsType.getDimSize(1);
186 
187  int64_t lhsDimK;
188  if (lhsType.getRank() == 1) {
189  dimM = 1;
190  lhsDimK = lhsType.getDimSize(0);
191  } else {
192  lhsDimK = lhsType.getDimSize(1);
193  }
194 
195  if (lhsDimK != dimK)
196  return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
197 
198  return success();
199  }
200 
201 public:
202  void lower(vector::ContractionOp op, PatternRewriter &rewriter) {
203  // Create some convenience types.
204  auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
205  auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
206  auto inputExpandedType =
207  VectorType::get({2, subTileShape.back()}, inputElementType);
208  auto outputExpandedType = VectorType::get({2, 2}, accElementType);
209 
210  // One-dimensional representation of logical sub-tiles as required by the
211  // ArmNeon ops.
212  auto collapsedInputType =
213  VectorType::get(inputExpandedType.getNumElements(), inputElementType);
214  auto collapsedOutputType =
215  VectorType::get(outputExpandedType.getNumElements(), accElementType);
216 
217  // Get indexing maps for a more concise/convenient access.
218  auto indexingMaps = op.getIndexingMapsArray();
219  AffineMap &lhsPermutationMap = indexingMaps[0];
220  AffineMap &rhsPermutationMap = indexingMaps[1];
221  AffineMap &accPermutationMap = indexingMaps[2];
222 
223  Location loc = op.getLoc();
224 
225  // Initial accumulator for the final result. This is the un-tiled result if
226  // tiling is done.
227  Value result =
228  arith::ConstantOp::create(rewriter, loc, op.getResultType(),
229  rewriter.getZeroAttr(op.getResultType()));
230 
231  SmallVector<int64_t, 3> loopOrder = {0, 1};
232  if (iterationBounds.size() == 3)
233  loopOrder.push_back(2);
234 
235  // Keep track of the previous accumulator when tiling over K.
236  Value kAcc;
237  for (SmallVector<int64_t> offsets :
238  StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
239  // Helper to compute the new shape of each operand and extract the slice.
240  auto extractOperand = [&](Value operand, AffineMap permutationMap,
241  ArrayRef<int64_t> operandOffsets) {
243  permutationMap, ArrayRef<int64_t>(subTileShape));
244  SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
245  return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
246  loc, operand, operandOffsets, operandShape, operandStrides);
247  };
248 
249  // Extract tiled lhs, rhs, and acc
250  SmallVector<int64_t> lhsOffsets =
251  applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
252  Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
253  SmallVector<int64_t> rhsOffsets =
254  applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
255  Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
256  SmallVector<int64_t> accOffsets =
257  applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
258  Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
259 
260  // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
261  // rows along dimM. Expand their shapes to match the ArmNeon op.
262  if (dimM == 1) {
263  auto expandRowVector = [&](Value tiledOperand,
264  VectorType expandedTypeType) {
265  auto emptyOperand =
266  arith::ConstantOp::create(rewriter, loc, expandedTypeType,
267  rewriter.getZeroAttr(expandedTypeType));
268  SmallVector<int64_t> offsets(
269  cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
270  SmallVector<int64_t> strides(
271  cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
272  return rewriter.createOrFold<vector::InsertStridedSliceOp>(
273  loc, tiledOperand, emptyOperand, offsets, strides);
274  };
275  tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
276  tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
277  }
278 
279  // Transpose ACC if doing signed by unsigned multiplication, because we're
280  // using the instruction for unsigned by signed multiplication with
281  // reversed operands.
282  if (swapOperands)
283  tiledAcc = vector::TransposeOp::create(rewriter, loc, tiledAcc,
284  ArrayRef<int64_t>({1, 0}));
285 
286  // Collapse tiled operands to 1D vectors required by the ArmNeon ops
287  auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
288  tiledLhs.getLoc(), collapsedInputType, tiledLhs);
289  auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
290  tiledRhs.getLoc(), collapsedInputType, tiledRhs);
291 
292  bool initialKAcc = offsets.back() == 0;
293  Value collapsedRes;
294  if (!initialKAcc) {
295  collapsedRes = kAcc;
296  } else {
297  collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
298  tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
299  }
300 
301  // Insert contract op
302  kAcc =
303  createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
304 
305  // Reshape output back to 2D
306  Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
307  kAcc.getLoc(), tiledAcc.getType(), kAcc);
308 
309  // Because of the reversed operands the result is obtained transposed.
310  // Transpose it back,
311  if (swapOperands)
312  tiledRes = vector::TransposeOp::create(rewriter, loc, tiledRes,
313  ArrayRef<int64_t>({1, 0}));
314 
315  // With vecmat, only one row of tiled ACC can be inserted into the final
316  // result
317  if (dimM == 1)
318  tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
319 
320  // Insert the tiled result back into the non tiled result of the
321  // contract op.
322  SmallVector<int64_t> strides(
323  cast<ShapedType>(tiledRes.getType()).getRank(), 1);
324  result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
325  loc, tiledRes, result, accOffsets, strides);
326  }
327 
328  rewriter.replaceOp(op, result);
329  }
330 };
331 
332 class VectorContractRewriterI8MM : public VectorContractRewriter {
333 public:
334  LogicalResult matchAndInit(vector::ContractionOp op,
335  PatternRewriter &rewriter) {
336  if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
337  return failure();
338 
339  // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
340  // tiling.
341  if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
342  return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
343 
344  // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
345  // values before the extension. All four signed/unsigned combinations for
346  // input operands are supported, but they are lowered to different
347  // operations. Determine which is the appropriate operation to lower to.
348  mmlaOp = MMLA::SignedInt;
349  auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
350  if (!maybeLhs) {
351  mmlaOp = MMLA::UnsignedInt;
352  maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
353  }
354  if (!maybeLhs)
355  return rewriter.notifyMatchFailure(
356  op, "LHS is not a sign- or zero- extended iN, N <= 8");
357 
358  auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
359  if (maybeRhs) {
360  if (mmlaOp == MMLA::UnsignedInt)
361  mmlaOp = MMLA::MixedInt;
362  } else {
363  if (mmlaOp == MMLA::SignedInt) {
364  mmlaOp = MMLA::MixedInt;
365  swapOperands = true;
366  }
367  maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
368  }
369 
370  if (!maybeRhs)
371  return rewriter.notifyMatchFailure(
372  op, "RHS is not a sign- or zero- extended iN, N <= 8");
373 
374  lhs = *maybeLhs;
375  rhs = *maybeRhs;
376  acc = op.getAcc();
377 
378  // Extend inputs from iN, N < 8 to i8.
379  Location loc = op.getLoc();
380  auto lhsExtInType = cast<VectorType>(lhs.getType());
381  if (lhsExtInType.getElementTypeBitWidth() < 8)
382  lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
383  /* signExt */
384  (mmlaOp == MMLA::SignedInt ||
385  (mmlaOp == MMLA::MixedInt && !swapOperands)),
386  rewriter);
387 
388  auto rhsExtInType = cast<VectorType>(rhs.getType());
389  if (rhsExtInType.getElementTypeBitWidth() < 8)
390  rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
391  /* signExt */
392  (mmlaOp == MMLA::SignedInt ||
393  (mmlaOp == MMLA::MixedInt && swapOperands)),
394  rewriter);
395 
396  // Initialize parameters for unrolling.
397  iterationBounds = *op.getShapeForUnroll();
398  if (iterationBounds.size() == 3)
399  subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 8});
400  else
401  subTileShape = SmallVector<int64_t>({2, 8});
402 
403  return success();
404  }
405 };
406 
407 class VectorContractRewriterBFMMLA : public VectorContractRewriter {
408 public:
409  LogicalResult matchAndInit(vector::ContractionOp op,
410  PatternRewriter &rewriter) {
411 
412  if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
413  return failure();
414 
415  // Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs for
416  // tiling.
417  if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
418  return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
419 
420  // Check the output is a vector of Float32 elements.
421  auto outTy = dyn_cast<VectorType>(op.getResultType());
422  if (!outTy || outTy.getElementType() != rewriter.getF32Type())
423  return rewriter.notifyMatchFailure(op,
424  "output type is not a vector of f32");
425 
426  // Check the inputs are vectors of BFloat16 elements.
427  if (op.getLhsType().getElementType() != rewriter.getBF16Type())
428  return rewriter.notifyMatchFailure(op,
429  "input type is not a vector of bf16");
430 
431  mmlaOp = MMLA::Bfloat;
432  swapOperands = false;
433  lhs = op.getLhs();
434  rhs = op.getRhs();
435  acc = op.getAcc();
436 
437  // Initialize parameters for unrolling.
438  iterationBounds = *op.getShapeForUnroll();
439  if (iterationBounds.size() == 3)
440  subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
441  else
442  subTileShape = SmallVector<int64_t>({2, 4});
443 
444  return success();
445  }
446 };
447 
448 /// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
449 /// any vector.contract into multiple smmla instructions with unrolling so long
450 /// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
451 /// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
452 /// necessary, a single smmla instruction is emitted.
453 class LowerContractionToNeonI8MMPattern
454  : public OpRewritePattern<vector::ContractionOp> {
455 public:
457  LogicalResult matchAndRewrite(vector::ContractionOp op,
458  PatternRewriter &rewriter) const override {
459 
460  VectorContractRewriterI8MM vcr;
461  if (failed(vcr.matchAndInit(op, rewriter)))
462  return failure();
463  vcr.lower(op, rewriter);
464 
465  return success();
466  }
467 };
468 
469 class LowerContractionToNeonBFMMLAPattern
470  : public OpRewritePattern<vector::ContractionOp> {
471 public:
473  LogicalResult matchAndRewrite(vector::ContractionOp op,
474  PatternRewriter &rewriter) const override {
475 
476  VectorContractRewriterBFMMLA vcr;
477  if (failed(vcr.matchAndInit(op, rewriter)))
478  return failure();
479  vcr.lower(op, rewriter);
480 
481  return success();
482  }
483 };
484 
485 } // namespace
486 
489  MLIRContext *context = patterns.getContext();
490  patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
491 }
492 
495  MLIRContext *context = patterns.getContext();
496  patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
497 }
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
FloatType getF32Type()
Definition: Builders.cpp:42
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
FloatType getBF16Type()
Definition: Builders.cpp:36
IntegerType getI8Type()
Definition: Builders.cpp:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns)
void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition: AffineMap.h:675
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319