//===- LowerContractToSVEPatterns.cpp - Contract to I8MM/BF16 ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to operations
// that map to instructions from the SVE FEAT_I8MM and FEAT_BF16 extensions.
//
// TODO: There may be opportunities to unify this with a similar pattern
// for Neon. See:
//   https://github.com/llvm/llvm-project/issues/145559
//   LowerContractToNeonPatterns.cpp
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"

#include <cassert>
#include <numeric>

#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"

using namespace mlir;

namespace {
// Get the operand of a `vector.contract`. This function is intended to
// abstract away the particular way a value is extended before feeding it into
// the `vector.contract` - via zero-extend or an explicit or implicit
// sign-extend (for implicit sign-extension see the `vector.contract`
// documentation).
//
// The template parameter `Op` indicates the extension operation (explicit or
// implicit) for which we are checking.
//
// Return success only for extensions from `i8` to `i32`.
template <typename Op>
std::optional<Value> getExtOperand(Value v) {

  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
                "Must be instantiated with either sign- or zero-extension op");

  // If the operand is not defined by an explicit extend operation of the
  // accepted operation type, allow for an implicit sign-extension.
  auto extOp = v.getDefiningOp<Op>();
  if (!extOp) {
    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
      auto vTy = cast<VectorType>(v.getType());
      if (!vTy.getElementType().isSignlessInteger(8))
        return {};
      return v;
    }
    return {};
  }

  // If the operand is defined by an explicit extend operation of the accepted
  // operation type, check it's extended from `i8` to `i32`.
  auto inOp = extOp.getIn();
  auto inTy = dyn_cast<VectorType>(inOp.getType());
  if (!inTy || !inTy.getElementType().isSignlessInteger(8))
    return {};

  auto outTy = dyn_cast<VectorType>(extOp.getType());
  if (!outTy || !outTy.getElementType().isSignlessInteger(32))
    return {};

  return inOp;
}
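// For illustration of `getExtOperand` (value names are hypothetical): given
//   %e = arith.extsi %v : vector<2x8xi8> to vector<2x8xi32>
// `getExtOperand<arith::ExtSIOp>(%e)` yields `%v`. Called directly on a
// `vector<2x8xi8>` value, it yields that value itself (the implicit
// sign-extension case), while any other element type without an explicit
// extend yields `std::nullopt`.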

/// This class encapsulates the algorithm and parametrisation (in terms of
/// types and dimensions) of lowering a `vector.contract` to "primitive" matrix
/// multiplication operations of the SVE dialect (here "primitive" means
/// corresponding to a single target instruction).
///
/// Supported are lowerings to FEAT_I8MM `smmla`, `ummla`, and `usmmla`, and to
/// FEAT_BF16 `bfmmla`. All the transformations are very similar to each
/// other; for concreteness, the description below is given for `smmla`.
///
/// The lowering triggers for a contraction operation that performs a matrix
/// multiply of two 8-bit integer matrix tiles with logical dimensions
/// <Mx8> and <8x[N]> for the left-hand side (LHS) and the right-hand side
/// (RHS), respectively, added to a 32-bit integer accumulator operand (ACC)
/// with dimensions <Mx[N]>, yielding a <Mx[N]> 32-bit integer result (OUT).
///
/// The operands' shapes are such that the operands can be evenly split into
/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
/// instructions. The intent is that M and N are chosen (by higher-level
/// transforms) in such a way as to maximise register usage. The main use case
/// we envision as of now is MMT4D, thus the RHS operand is expected to be
/// pre-transposed.
///
/// The matrix multiplication is performed by unrolling the usual tiled matrix
/// multiplication algorithm using sub-tiles with dimensions <2x8> for the
/// LHS, <8x[2]> for the RHS, and <2x[2]> for the result and the input
/// accumulator.
///
/// One way to illustrate the operation is as follows:
///
///   RHS<8x[N]>:       <8x[2]> <8x[2]> ... <8x[2]>
///                   +-----------------------------
///   LHS<Mx8>: <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
///             <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
///              ...  |   ...     ...   ...   ...
///             <2x8> | <2x[2]> <2x[2]> ... <2x[2]>
///
/// The RHS operand is unpacked into N/2 values, each representing a sequence
/// of VSCALE number of sub-tiles with dimensions <8x2>.
/// The LHS operand is initially unpacked into M/2 values, each representing a
/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
/// VSCALE times. Multiplying the thus replicated LHS sub-tile by the
/// corresponding RHS sub-tile correctly computes an entire result sub-tile.
/// The 2x2 sub-tiles of the ACC and OUT have rows that are not adjacent
/// (in memory or when imposing a row-major layout on the 2D vector value).
/// Reading the ACC is implemented as reading two consecutive rows and
/// interleaving them by pairs to obtain a vector having twice the length of
/// an ACC row. This vector now is a sequence of one-dimensional tiles with
/// the exact layout needed by the `smmla`/`bfmmla`/etc. instructions, which
/// tiles are extracted one by one. For illustration, if we have a 2x4 ACC
/// tile
///   a0 a1 b0 b1
///   a2 a3 b2 b3
/// we read the two rows as separate values and then interleave by pairs
/// to obtain
///   a0 a1 a2 a3 b0 b1 b2 b3
/// from which we extract `a0 a1 a2 a3` and `b0 b1 b2 b3`.
///
/// Writing the OUT tile is done by the reverse of the above procedure:
/// concatenate two "flattened" sub-tiles into
///   c0 c1 c2 c3 d0 d1 d2 d3
/// and deinterleave by pairs to obtain as separate values
///   c0 c1 d0 d1
///   c2 c3 d2 d3
/// which are then inserted into the final result.
///
/// Multiplication of a signed LHS by an unsigned RHS is performed by
/// swapping the order of the operands and emitting an `usmmla` (since there
/// isn't an `summla` instruction). Therefore each ACC sub-tile needs
/// to be transposed before the addition, and each sum, an OUT sub-tile,
/// needs to be transposed before insertion into the final result.
/// This is done very elegantly by a modification of the above procedure to
/// interleave/deinterleave not by pairs, but by individual elements, e.g.
/// after the ordinary interleave we obtain
///   a0 a2 a1 a3 b0 b2 b1 b3
/// which is exactly the desired layout of having each individual 2x2 tile
/// transposed.
///
/// All of the above readily applies to FEAT_BF16 `bfmmla`, with the
/// difference that the shapes of the LHS and RHS are <Mx4> and <4x[N]>,
/// respectively; that is, the "K" dimension is fixed to 4, instead of 8 (as
/// in the integer case).
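/// As a concrete illustration (the shapes are hypothetical; here M = 4,
/// K = 8, and N = [4]), a contraction this lowering matches may look like:
///
///   %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
///   %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
///   %2 = vector.contract {
///          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
///                           affine_map<(d0, d1, d2) -> (d1, d2)>,
///                           affine_map<(d0, d1, d2) -> (d0, d1)>],
///          iterator_types = ["parallel", "parallel", "reduction"],
///          kind = #vector.kind<add>}
///        %0, %1, %acc : vector<4x8xi32>, vector<[4]x8xi32> into
///        vector<4x[4]xi32>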
class VectorContractRewriter {
protected:
  // Designate the operation (resp. instruction) used to do sub-tile matrix
  // multiplications.
  enum class MMLA {
    Nop,
    SignedInt,   // smmla
    UnsignedInt, // ummla
    MixedInt,    // usmmla
    Bfloat       // bfmmla
  };

  // Lower-level operation to be emitted.
  MMLA mmlaOp = MMLA::Nop;

  // Indicate if the operands for the ArmSVE dialect operation need to be
  // swapped. Currently this is needed in order to emulate an "summla"
  // operation.
  bool swapOperands = false;

  // The operand tiles. These are not necessarily the operands of the
  // `vector.contract`; for example they could be operands to an `arith.extsi`
  // that is in turn fed into the `vector.contract`.
  Value lhs;
  Value rhs;
  Value acc;

  // Conventional names for matrix dimensions.
  int64_t m = 0;
  int64_t n = 0;
  int64_t k = 0;

  // Create the matrix multiply and accumulate operation according to
  // `mmlaOp`.
  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
                   Value lhs, Value rhs);

  // Check general preconditions for applying the transformation, common to
  // the integer and the bfloat16 case.
  LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter);

public:
  VectorContractRewriter() = default;

  // Do the actual rewrite. This member function is shared by both the integer
  // and the bfloat16 rewrites.
  Value lower(vector::ContractionOp op, PatternRewriter &rewriter);
};

Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
                                         Location loc, Value acc, Value lhs,
                                         Value rhs) {

  Type resTy = acc.getType();
  if (swapOperands)
    std::swap(lhs, rhs);

  switch (mmlaOp) {
  case MMLA::SignedInt:
    return arm_sve::SmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  case MMLA::UnsignedInt:
    return arm_sve::UmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  case MMLA::MixedInt:
    return arm_sve::UsmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  case MMLA::Bfloat:
    return arm_sve::BfmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  default:
    llvm_unreachable("Uninitialized operation kind");
  }
}

LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
                                            PatternRewriter &rewriter) {
  // Check iterator types for matrix multiplication.
  auto itTypes = op.getIteratorTypesArray();
  if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
      itTypes[1] != vector::IteratorType::parallel ||
      itTypes[2] != vector::IteratorType::reduction)
    return rewriter.notifyMatchFailure(
        op, "iterator types do not correspond to matrix multiplication");

  // Check permutation maps. For now only accept
  //   lhs: (d0, d1, d2) -> (d0, d2)
  //   rhs: (d0, d1, d2) -> (d1, d2)
  //   acc: (d0, d1, d2) -> (d0, d1)
  // This corresponds to matrix multiplication with transposed RHS.
  if (op.getIndexingMapsArray()[0] !=
          AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
                                               op.getContext()) ||
      op.getIndexingMapsArray()[1] !=
          AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
                                               op.getContext()) ||
      op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets(
                                          3, ArrayRef{0u, 1u}, op.getContext()))
    return rewriter.notifyMatchFailure(op, "non-matching permutation maps");

  // Check the combining kind is addition.
  if (op.getKind() != vector::CombiningKind::ADD)
    return rewriter.notifyMatchFailure(op, "combining kind is not an addition");

  return success();
}

Value VectorContractRewriter::lower(vector::ContractionOp op,
                                    PatternRewriter &rewriter) {

  // Initialize some helper types.
  Type operandEltType = cast<VectorType>(lhs.getType()).getElementType();
  Type resultEltType = cast<VectorType>(op.getResultType()).getElementType();

  const int64_t numOperandSubTileElts =
      128 / operandEltType.getIntOrFloatBitWidth();

  assert(resultEltType.getIntOrFloatBitWidth() == 32 &&
         "Only implemented for i32 or f32 output");
  const int64_t numResultSubTileElts = 4;

  // Single-dimensional vector types for the operands of the ArmSVE dialect
  // op.
  auto flatLhsType =
      VectorType::get(/*shape=*/numOperandSubTileElts, operandEltType,
                      /*scalableDims=*/{true});
  auto flatRhsType =
      VectorType::get(/*shape=*/numOperandSubTileElts, operandEltType,
                      /*scalableDims=*/{true});
  auto flatAccType =
      VectorType::get(/*shape=*/numResultSubTileElts, resultEltType,
                      /*scalableDims=*/{true});

  // Single-dimensional vector type for the entire RHS tile.
  auto flatRhsTileType = VectorType::get(/*shape=*/k * n, operandEltType,
                                         /*scalableDims=*/{true});

  // Vector type having the same number of elements as a row in the
  // accumulator/output tile and the same element type.
  auto accRowTy = VectorType::get(/*shape=*/n, resultEltType,
                                  /*scalableDims=*/{true});

  // Vector type having twice the number of elements as a row in the
  // accumulator/output tile and the same element type.
  auto accRowX2Ty = VectorType::get(/*shape=*/2 * n, resultEltType,
                                    /*scalableDims=*/{true});

  // Vector type having half the number of elements as a row in the
  // accumulator/output tile and an integer element type with twice the bit
  // width.
  auto accRow64Ty = VectorType::get(/*shape=*/n / 2, rewriter.getI64Type(),
                                    /*scalableDims=*/{true});

  // Vector type having the same number of elements as a row in the
  // accumulator/output tile and an integer element type with twice the bit
  // width.
  auto accRowX264Ty = VectorType::get(/*shape=*/n, rewriter.getI64Type(),
                                      /*scalableDims=*/{true});

  Location loc = op.getLoc();

  // Extract LHS sub-tiles with logical shape <2xK>.
  SmallVector<Value> lhsTile;
  for (int64_t i = 0; i < m; i += 2) {
    // Extract two consecutive rows of the LHS tile.
    auto r0 =
        vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i});
    auto r1 =
        vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1});
    // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
    SmallVector<int64_t> shuffleIdx(2 * k);
    std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
    auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx);
    // Turn it into a scalable vector.
    auto s = vector::ScalableInsertOp::create(
        rewriter, loc, t, ub::PoisonOp::create(rewriter, loc, flatLhsType), 0);
    // Replicate the sub-tile VSCALE times to fill the entire vector.
    auto r = arm_sve::DupQLaneOp::create(rewriter, loc, s, 0);
    lhsTile.push_back(r);
  }

  // "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
  auto rhs = vector::ShapeCastOp::create(rewriter, this->rhs.getLoc(),
                                         flatRhsTileType, this->rhs);

  // Extract the RHS sub-tiles with logical shape <Kx[2]>.
  SmallVector<Value> rhsTile;
  for (int64_t j = 0; j < n; j += 2)
    rhsTile.push_back(vector::ScalableExtractOp::create(
        rewriter, loc, flatRhsType, rhs, j * k));

  // Extract and pack the ACC sub-tiles.
  SmallVector<Value> accTile;
  for (int64_t i = 0; i < m; i += 2) {
    // Extract two consecutive rows of the accumulator tile.
    auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
                                        ArrayRef<int64_t>{i});
    auto r1 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
                                        ArrayRef<int64_t>{i + 1});
    Value accTileVec;
    if (swapOperands) {
      // Since we are performing the operation with swapped LHS and RHS, we
      // need to transpose each individual 2x2 tile of the accumulator and
      // (later) the final result.
      accTileVec = vector::InterleaveOp::create(rewriter, loc, r0, r1);
    } else {
      // Bitcast accumulator rows to double-width integer elements, so
      // subsequent interleave/deinterleave operations work on pairs of
      // elements.
      auto r0I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r0);
      auto r1I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r1);

      // Interleave the rows, effectively flattening each 2x2 tile into 4
      // consecutive elements.
      auto intrI64 = vector::InterleaveOp::create(rewriter, loc, r0I64, r1I64);

      // Bitcast back to the original element type.
      accTileVec =
          vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64);
    }
    // Extract ACC sub-tiles.
    for (int64_t j = 0; j < n; j += 2)
      accTile.push_back(vector::ScalableExtractOp::create(
          rewriter, loc, flatAccType, accTileVec, j * 2));
  }

  // Emit sub-tile matrix multiplications.
  SmallVector<Value> outTile;
  for (int64_t i = 0; i < m / 2; ++i)
    for (int64_t j = 0; j < n / 2; ++j) {
      Value mmla = createMMLA(rewriter, loc, accTile[i * n / 2 + j], lhsTile[i],
                              rhsTile[j]);
      outTile.push_back(mmla);
    }

  // Unpack the OUT sub-tiles and insert into the result.
  Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType());
  for (int64_t i = 0; i < m / 2; ++i) {
    // Collect a number of sub-tiles in a row.
    Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty);
    for (int64_t j = 0; j < n / 2; ++j)
      row = vector::ScalableInsertOp::create(
          rewriter, loc, outTile[i * n / 2 + j], row, j * 4);

    // Unpack the row to obtain two rows of the output. If we have the OUT
    // sub-tiles transposed, we obtain two consecutive output rows by
    // separating even and odd elements, i.e. a simple deinterleave.
    // Otherwise, the deinterleave is by pairs.
    Value out0, out1;
    if (swapOperands) {
      auto tmp = vector::DeinterleaveOp::create(rewriter, loc, row);
      out0 = tmp.getRes1();
      out1 = tmp.getRes2();
    } else {
      // Deinterleave by pairs.
      auto row64 = vector::BitCastOp::create(rewriter, loc, accRowX264Ty, row);
      auto deintr64 = vector::DeinterleaveOp::create(rewriter, loc, row64);

      // Bitcast back into the original element type and insert into the
      // result.
      out0 = vector::BitCastOp::create(rewriter, loc, accRowTy,
                                       deintr64.getRes1());
      out1 = vector::BitCastOp::create(rewriter, loc, accRowTy,
                                       deintr64.getRes2());
    }
    result = vector::InsertOp::create(rewriter, loc, out0, result, i * 2);
    result = vector::InsertOp::create(rewriter, loc, out1, result, i * 2 + 1);
  }

  return result;
}

class VectorContractRewriterI8MM : public VectorContractRewriter {
public:
  // Check the specific preconditions for the integer case. Initialise the
  // parametrisation types and dimensions.
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    if (failed(match(op, rewriter)))
      return failure();

    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();

    m = lhsType.getDimSize(0);
    n = rhsType.getDimSize(0);
    k = rhsType.getDimSize(1);

    // Check the operands have the expected shape:
    //  * for LHS: fixed vector MxK
    //  * for RHS: scalable vector [N]xK
    //  * K == 8
    //  * M and N even and at least 2
    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 8 ||
        m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0)
      return rewriter.notifyMatchFailure(op, "non-matching operand shape");

    // Check the output is a vector of i32 elements.
    auto outTy = dyn_cast<VectorType>(op.getResultType());
    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
      return rewriter.notifyMatchFailure(op,
                                         "output type is not a vector of i32");

    // Check the inputs are sign-/zero-extensions from i8 to i32 and get the
    // values before the extension. All four signed/unsigned combinations for
    // the input operands are supported, but they are lowered to different
    // operations. Determine which is the appropriate operation to lower to.
    mmlaOp = MMLA::SignedInt;
    swapOperands = false;
    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
    if (!maybeLhs) {
      mmlaOp = MMLA::UnsignedInt;
      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
    }
    if (!maybeLhs)
      return rewriter.notifyMatchFailure(
          op, "LHS is not a sign- or zero-extended i8");

    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
    if (maybeRhs) {
      if (mmlaOp == MMLA::UnsignedInt)
        mmlaOp = MMLA::MixedInt;
    } else {
      if (mmlaOp == MMLA::SignedInt) {
        mmlaOp = MMLA::MixedInt;
        swapOperands = true;
      }
      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
    }
    if (!maybeRhs)
      return rewriter.notifyMatchFailure(
          op, "RHS is not a sign- or zero-extended i8");

    // Initialise the algorithm parameters.
    lhs = *maybeLhs;
    rhs = *maybeRhs;
    acc = op.getAcc();

    return success();
  }
};

class VectorContractRewriterBfloat : public VectorContractRewriter {
public:
  // Check the specific preconditions for the bfloat16 case. Initialise the
  // parametrisation types and dimensions.
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    if (failed(match(op, rewriter)))
      return failure();

    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();

    m = lhsType.getDimSize(0);
    n = rhsType.getDimSize(0);
    k = rhsType.getDimSize(1);

    // Check the operands have the expected shape:
    //  * for LHS: fixed vector MxK
    //  * for RHS: scalable vector [N]xK
    //  * K == 4
    //  * M and N even and at least 2
    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 4 ||
        m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0)
      return rewriter.notifyMatchFailure(op, "non-matching operand shape");

    // Check the output is a vector of f32 elements.
    auto outTy = dyn_cast<VectorType>(op.getResultType());
    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
      return rewriter.notifyMatchFailure(op,
                                         "output type is not a vector of f32");

    // Check the inputs are vectors of bf16 elements.
    if (lhsType.getElementType() != rewriter.getBF16Type())
      return rewriter.notifyMatchFailure(op,
                                         "input type is not a vector of bf16");

    // Initialise the algorithm parameters.
    mmlaOp = MMLA::Bfloat;
    swapOperands = false;
    lhs = op.getLhs();
    rhs = op.getRhs();
    acc = op.getAcc();

    return success();
  }
};

class LowerContractionToSVEI8MMPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    // Match i8xi8 -> i32 matrix multiply and accumulate.
    VectorContractRewriterI8MM vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();

    Value result = vcr.lower(op, rewriter);
    rewriter.replaceOp(op, result);

    return success();
  }
};

class LowerContractionToSVEBFMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    // Match bf16xbf16 -> f32 matrix multiply and accumulate.
    VectorContractRewriterBfloat vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();

    Value result = vcr.lower(op, rewriter);
    rewriter.replaceOp(op, result);

    return success();
  }
};

} // namespace

void mlir::populateLowerContractionToSVEI8MMPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
}

void mlir::populateLowerContractionToSVEBFMMLAPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSVEBFMMLAPattern>(context, /*benefit=*/2);
}