//===- LowerContractToSVEPatterns.cpp - Contract to I8MM/BF16 ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering patterns from vector.contract to operations
// that map to instructions from the SVE FEAT_I8MM and FEAT_BF16 extensions.
//
// TODO: There may be opportunities to unify this with a similar pattern
// for Neon. See:
//   https://github.com/llvm/llvm-project/issues/145559
//   LowerContractToNeonPatterns.cpp
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"

#include <cassert>
#include <numeric>

#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"

using namespace mlir;

namespace {
// Get the operand of a `vector.contract`. This function is intended to
// abstract away the particular way a value is extended before feeding it
// into the `vector.contract` - via zero-extend or an explicit or implicit
// sign-extend (for implicit sign-extension see the `vector.contract`
// documentation).
//
// The template parameter `Op` indicates the extension operation (explicit or
// implicit) for which we are checking.
//
// Returns a value only for (implicit or explicit) extensions from `i8` to
// `i32`.
template <typename Op>
std::optional<Value> getExtOperand(Value v) {

  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
                "Must be instantiated with either sign- or zero- extension op");

  // If the operand is not defined by an explicit extend operation of the
  // accepted operation type, allow for an implicit sign-extension.
  auto extOp = v.getDefiningOp<Op>();
  if (!extOp) {
    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
      auto vTy = cast<VectorType>(v.getType());
      if (!vTy.getElementType().isSignlessInteger(8))
        return {};
      return v;
    }
    return {};
  }

  // If the operand is defined by an explicit extend operation of the accepted
  // operation type, check it's extended from `i8` to `i32`.
  auto inOp = extOp.getIn();
  auto inTy = dyn_cast<VectorType>(inOp.getType());
  if (!inTy || !inTy.getElementType().isSignlessInteger(8))
    return {};

  auto outTy = dyn_cast<VectorType>(extOp.getType());
  if (!outTy || !outTy.getElementType().isSignlessInteger(32))
    return {};

  return inOp;
}
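
// For illustration (hypothetical values, not taken from anywhere in this
// file): given
//   %e = arith.extsi %a : vector<4x8xi8> to vector<4x8xi32>
// getExtOperand<arith::ExtSIOp>(%e) returns %a, while calling it directly on
// a vector of i8 elements (no explicit extend) returns the value itself,
// relying on the implicit sign-extension convention of `vector.contract`.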

/// This class encapsulates the algorithm and parametrisation (in terms of
/// types and dimensions) of lowering a `vector.contract` to "primitive" matrix
/// multiplication operations of the SVE dialect (here "primitive" means
/// corresponding to a single target instruction).
///
/// Supported are lowerings to FEAT_I8MM `smmla`, `ummla`, and `usmmla`, and to
/// FEAT_BF16 `bfmmla`. All the transformations are very similar to each
/// other; for concreteness the description below is given for `smmla`.
///
/// The lowering triggers for a contraction operation that performs a matrix
/// multiply of two 8-bit integer matrix tiles with logical dimensions
/// <Mx8> and <8x[N]> for the left-hand side (LHS) and the right-hand side
/// (RHS), respectively, added to a 32-bit integer accumulator operand (ACC)
/// with dimensions <Mx[N]>, yielding an <Mx[N]> 32-bit integer result (OUT).
///
/// The operands' shapes are such that the operands can be evenly split into
/// sub-tiles with dimensions as expected by the targeted FEAT_I8MM
/// instructions. The intent is that M and N are chosen (by higher level
/// transforms) in such a way as to maximise register usage. The main use case
/// we envision as of now is MMT4D, thus the RHS operand is expected
/// pre-transposed.
///
/// The matrix multiplication is performed by unrolling the usual tiled matrix
/// multiplication algorithm using sub-tiles with dimensions <2x8> for the
/// LHS, <8x[2]> for the RHS, and <2x[2]> for the result and the input
/// accumulator.
///
/// One way to illustrate the operation is as follows:
///
///   RHS<8x[N]>:        <8x[2]> <8x[2]> ... <8x[2]>
///                     +-----------------------------
///   LHS<Mx8>:  <2x8>  | <2x[2]> <2x[2]> ... <2x[2]>
///              <2x8>  | <2x[2]> <2x[2]> ... <2x[2]>
///               ...   |   ...     ...   ...   ...
///              <2x8>  | <2x[2]> <2x[2]> ... <2x[2]>
///
/// The RHS operand is unpacked into N/2 values, each representing a sequence
/// of VSCALE number of sub-tiles with dimensions <8x2>.
/// The LHS operand is initially unpacked into M/2 values, each representing a
/// sub-tile with dimensions <2x8>, and then each such sub-tile is replicated
/// VSCALE times. Multiplying the thus replicated LHS sub-tile by the
/// corresponding RHS sub-tile correctly computes an entire result sub-tile.
/// The 2x2 sub-tiles of the ACC and OUT have rows that are not adjacent
/// (in memory or when imposing a row-major layout on the 2D vector value).
/// Reading the ACC is implemented as reading two consecutive rows and
/// interleaving them by pairs to obtain a vector having twice the length of
/// an ACC row. This vector is now a sequence of one-dimensional tiles with
/// the exact layout needed by the `smmla`/`bfmmla`/etc. instructions, which
/// tiles are extracted one by one. For illustration, if we have a 2x4 ACC
/// tile
///   a0 a1 b0 b1
///   a2 a3 b2 b3
/// we read the two rows as separate values and then interleave by pairs
/// to obtain
///   a0 a1 a2 a3 b0 b1 b2 b3
/// from which we extract `a0 a1 a2 a3` and `b0 b1 b2 b3`.
///
/// Writing the OUT tile is done by the reverse of the above procedure:
/// concatenate two "flattened" sub-tiles into
///   c0 c1 c2 c3 d0 d1 d2 d3
/// and deinterleave by pairs to obtain, as separate values,
///   c0 c1 d0 d1
///   c2 c3 d2 d3
/// which are then inserted into the final result.
///
/// Multiplication of a signed LHS by an unsigned RHS is performed by
/// swapping the order of the operands and emitting an `usmmla` (since there
/// isn't an `summla` instruction). Therefore each ACC sub-tile needs
/// to be transposed before the addition, and the sum, an OUT sub-tile,
/// needs to be transposed before insertion into the final result.
/// This is done very elegantly by a modification of the above to
/// interleave/deinterleave not by pairs, but by individual elements, e.g.
/// after ordinary interleave we obtain
///   a0 a2 a1 a3 b0 b2 b1 b3
/// which is exactly the desired layout of having each individual 2x2 tile
/// transposed.
///
/// All of the above readily applies to FEAT_BF16 `bfmmla`, with the
/// difference that the logical shapes of the LHS and RHS are <Mx4> and
/// <4x[N]> respectively, that is, the "K" dimension is fixed to 4 instead of
/// 8 (as in the integer case).
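///
/// For a concrete illustration (shapes and extension kinds picked arbitrarily
/// for this example; they are not the only ones accepted), a contraction
/// eligible for the integer lowering could look like:
///
///   %lhs = arith.extsi %a : vector<4x8xi8> to vector<4x8xi32>
///   %rhs = arith.extui %b : vector<[4]x8xi8> to vector<[4]x8xi32>
///   %res = vector.contract {indexing_maps = [
///              affine_map<(d0, d1, d2) -> (d0, d2)>,
///              affine_map<(d0, d1, d2) -> (d1, d2)>,
///              affine_map<(d0, d1, d2) -> (d0, d1)>],
///            iterator_types = ["parallel", "parallel", "reduction"],
///            kind = #vector.kind<add>} %lhs, %rhs, %acc
///     : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
///
/// With a sign-extended LHS and a zero-extended RHS as above, the lowering
/// would emit `usmmla` operations with swapped operands, as described above.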
class VectorContractRewriter {
protected:
  // Designate the operation (resp. instruction) used to do sub-tile matrix
  // multiplications.
  enum class MMLA {
    Nop,
    SignedInt,   // smmla
    UnsignedInt, // ummla
    MixedInt,    // usmmla
    Bfloat       // bfmmla
  };

  // Lower-level operation to be emitted.
  MMLA mmlaOp = MMLA::Nop;

  // Indicate if the operands for the ArmSVE dialect operation need to be
  // swapped. Currently this is needed in order to emulate an "summla"
  // operation.
  bool swapOperands = false;

  // The operand tiles. These are not necessarily the operands of
  // `vector.contract`; for example, they could be operands to an
  // `arith.extsi` that is in turn fed into `vector.contract`.
  Value lhs;
  Value rhs;
  Value acc;

  // Conventional names for matrix dimensions.
  int64_t m = 0;
  int64_t n = 0;
  int64_t k = 0;

  // Create the matrix multiply and accumulate operation according to
  // `mmlaOp`.
  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
                   Value lhs, Value rhs);

  // Check general preconditions for applying the transformation, common to
  // the integer and the bfloat16 case.
  LogicalResult match(vector::ContractionOp op, PatternRewriter &rewriter);

public:
  VectorContractRewriter() = default;

  // Do the actual rewrite. This member function is shared by both the integer
  // and the bfloat16 rewrites.
  Value lower(vector::ContractionOp op, PatternRewriter &rewriter);
};

Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
                                         Location loc, Value acc, Value lhs,
                                         Value rhs) {

  Type resTy = acc.getType();
  if (swapOperands)
    std::swap(lhs, rhs);

  switch (mmlaOp) {
  case MMLA::SignedInt:
    return arm_sve::SmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  case MMLA::UnsignedInt:
    return arm_sve::UmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  case MMLA::MixedInt:
    return arm_sve::UsmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  case MMLA::Bfloat:
    return arm_sve::BfmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
  default:
    llvm_unreachable("Uninitialized operation kind");
  }
}

LogicalResult VectorContractRewriter::match(vector::ContractionOp op,
                                            PatternRewriter &rewriter) {
  // Check iterator types for matrix multiplication.
  auto itTypes = op.getIteratorTypesArray();
  if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
      itTypes[1] != vector::IteratorType::parallel ||
      itTypes[2] != vector::IteratorType::reduction)
    return rewriter.notifyMatchFailure(
        op, "iterator types do not correspond to matrix multiplication");

  // Check permutation maps. For now only accept
  //   lhs: (d0, d1, d2) -> (d0, d2)
  //   rhs: (d0, d1, d2) -> (d1, d2)
  //   acc: (d0, d1, d2) -> (d0, d1)
  // This corresponds to matrix multiplication with transposed RHS.
  if (op.getIndexingMapsArray()[0] !=
          AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
                                               op.getContext()) ||
      op.getIndexingMapsArray()[1] !=
          AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
                                               op.getContext()) ||
      op.getIndexingMapsArray()[2] != AffineMap::getMultiDimMapWithTargets(
                                          3, ArrayRef{0u, 1u}, op.getContext()))
    return rewriter.notifyMatchFailure(op, "non-matching permutation maps");

  // Check the combining kind is addition.
  if (op.getKind() != vector::CombiningKind::ADD)
    return rewriter.notifyMatchFailure(op, "combining kind is not an addition");

  return success();
}
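
// In other words (a restatement of the maps checked above, for reference),
// the accepted contraction computes
//   OUT[i, j] = ACC[i, j] + sum_k LHS[i, k] * RHS[j, k]
// i.e. a matrix multiplication with the RHS stored row-major as <[N]xK>.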

Value VectorContractRewriter::lower(vector::ContractionOp op,
                                    PatternRewriter &rewriter) {

  // Initialize some helper types.
  Type operandEltType = cast<VectorType>(lhs.getType()).getElementType();
  Type resultEltType = cast<VectorType>(op.getResultType()).getElementType();

  const int64_t numOperandSubTileElts =
      128 / operandEltType.getIntOrFloatBitWidth();

  assert(resultEltType.getIntOrFloatBitWidth() == 32 &&
         "Only implemented for i32 or f32 output");
  const int64_t numResultSubTileElts = 4;

  // Single-dimensional vector types for the operands of the ArmSVE dialect
  // op.
  auto flatLhsType =
      VectorType::get(/*shape=*/numOperandSubTileElts, operandEltType,
                      /*scalableDims=*/{true});
  auto flatRhsType =
      VectorType::get(/*shape=*/numOperandSubTileElts, operandEltType,
                      /*scalableDims=*/{true});
  auto flatAccType =
      VectorType::get(/*shape=*/numResultSubTileElts, resultEltType,
                      /*scalableDims=*/{true});
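
  // For example (illustrative values only): with i8 operands
  // numOperandSubTileElts is 128/8 = 16, so the flattened operand sub-tile
  // types are vector<[16]xi8> (for bf16 operands: 128/16 = 8, i.e.
  // vector<[8]xbf16>), and with an i32 (resp. f32) result the flattened
  // accumulator sub-tile type is vector<[4]xi32> (resp. vector<[4]xf32>).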

  // Single-dimension vector type for the entire RHS tile.
  auto flatRhsTileType = VectorType::get(/*shape=*/k * n, operandEltType,
                                         /*scalableDims=*/{true});

  // Vector type having the same number of elements as a row in the
  // accumulator/output tile and the same element type.
  auto accRowTy = VectorType::get(/*shape=*/n, resultEltType,
                                  /*scalableDims=*/{true});

  // Vector type having twice the number of elements as a row in the
  // accumulator/output tile and the same element type.
  auto accRowX2Ty = VectorType::get(/*shape=*/2 * n, resultEltType,
                                    /*scalableDims=*/{true});
  // Vector type having half the number of elements as a row in the
  // accumulator/output tile and an integer element type with twice the bit
  // width.
  auto accRow64Ty = VectorType::get(/*shape=*/n / 2, rewriter.getI64Type(),
                                    /*scalableDims=*/{true});
  // Vector type having the same number of elements as a row in the
  // accumulator/output tile and an integer element type with twice the bit
  // width.
  auto accRowX264Ty = VectorType::get(/*shape=*/n, rewriter.getI64Type(),
                                      /*scalableDims=*/{true});
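
  // For illustration only: with n == 4 and an i32 result these are
  //   accRowTy     = vector<[4]xi32>
  //   accRowX2Ty   = vector<[8]xi32>
  //   accRow64Ty   = vector<[2]xi64>
  //   accRowX264Ty = vector<[4]xi64>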

  Location loc = op.getLoc();

  // Extract LHS sub-tiles with logical shape <2xK>.
  SmallVector<Value> lhsTile;
  for (int64_t i = 0; i < m; i += 2) {
    // Extract two consecutive rows of the LHS tile.
    auto r0 =
        vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i});
    auto r1 =
        vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1});
    // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
    SmallVector<int64_t> shuffleIdx(2 * k);
    std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
    auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx);
    // Turn it into a scalable vector.
    auto s = vector::ScalableInsertOp::create(
        rewriter, loc, t, ub::PoisonOp::create(rewriter, loc, flatLhsType), 0);
    // Replicate the sub-tile VSCALE times to fill the entire vector.
    auto r = arm_sve::DupQLaneOp::create(rewriter, loc, s, 0);
    lhsTile.push_back(r);
  }
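
  // To illustrate with the i8 case (k == 8): shuffleIdx is [0, 1, ..., 15],
  // so `t` is the concatenation of rows i and i+1 - 16 bytes, i.e. one
  // 128-bit quadword (for bf16, k == 4, the 8 elements likewise fill one
  // quadword). The dupq_lane with index 0 then broadcasts that quadword to
  // every 128-bit segment of the scalable vector.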

  // "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
  auto rhs = vector::ShapeCastOp::create(rewriter, this->rhs.getLoc(),
                                         flatRhsTileType, this->rhs);

  // Extract the RHS sub-tiles with logical shape <Kx[2]>.
  SmallVector<Value> rhsTile;
  for (int64_t j = 0; j < n; j += 2)
    rhsTile.push_back(vector::ScalableExtractOp::create(
        rewriter, loc, flatRhsType, rhs, j * k));
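
  // Per the class-level comment, each of these n/2 values holds a sequence of
  // VSCALE sub-tiles: every 128-bit segment contains two physical RHS rows
  // (i.e. two logical RHS columns), forming one <Kx2> sub-tile.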

  // Extract and pack the ACC sub-tiles.
  SmallVector<Value> accTile;
  for (int64_t i = 0; i < m; i += 2) {
    // Extract two consecutive rows of the accumulator tile.
    auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
                                        ArrayRef<int64_t>{i});
    auto r1 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
                                        ArrayRef<int64_t>{i + 1});
    Value accTileVec;
    if (swapOperands) {
      // Since we are performing the operation with swapped LHS and RHS, we
      // need to transpose each individual 2x2 tile of the accumulator and
      // (later) the final result.
      accTileVec = vector::InterleaveOp::create(rewriter, loc, r0, r1);
    } else {
      // Bitcast accumulator rows to double-width integer elements, so the
      // subsequent interleave/deinterleave operations work on pairs of
      // elements.
      auto r0I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r0);
      auto r1I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r1);

      // Interleave the rows, effectively flattening each 2x2 tile into 4
      // consecutive elements.
      auto intrI64 = vector::InterleaveOp::create(rewriter, loc, r0I64, r1I64);

      // Bitcast back to the original element type.
      accTileVec =
          vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64);
    }
    // Extract ACC sub-tiles.
    for (int64_t j = 0; j < n; j += 2)
      accTile.push_back(vector::ScalableExtractOp::create(
          rewriter, loc, flatAccType, accTileVec, j * 2));
  }
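
  // Illustration of the bitcast trick above (values invented for the
  // example): with rows r0 = a0 a1 b0 b1 and r1 = a2 a3 b2 b3, the 64-bit
  // views are {a0a1, b0b1} and {a2a3, b2b3}; interleaving them gives
  // {a0a1, a2a3, b0b1, b2b3}, which, viewed again as 32-bit elements, is
  // a0 a1 a2 a3 b0 b1 b2 b3 - exactly the pairwise interleave described in
  // the class-level comment.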

  // Emit sub-tile matrix multiplications.
  SmallVector<Value> outTile;
  for (int64_t i = 0; i < m / 2; ++i)
    for (int64_t j = 0; j < n / 2; ++j) {
      Value mmla = createMMLA(rewriter, loc, accTile[i * n / 2 + j], lhsTile[i],
                              rhsTile[j]);
      outTile.push_back(mmla);
    }

  // Unpack the OUT sub-tiles and insert into the result.
  Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType());
  for (int64_t i = 0; i < m / 2; ++i) {
    // Collect a number of sub-tiles in a row.
    Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty);
    for (int64_t j = 0; j < n / 2; ++j)
      row = vector::ScalableInsertOp::create(
          rewriter, loc, outTile[i * n / 2 + j], row, j * 4);

    // Unpack the row to obtain two rows of the output. If we have the OUT
    // sub-tiles transposed, we obtain two consecutive output rows by
    // separating even and odd elements, i.e. a simple deinterleave.
    // Otherwise, the deinterleave is by pairs.
    Value out0, out1;
    if (swapOperands) {
      auto tmp = vector::DeinterleaveOp::create(rewriter, loc, row);
      out0 = tmp.getRes1();
      out1 = tmp.getRes2();
    } else {
      // Deinterleave by pairs.
      auto row64 = vector::BitCastOp::create(rewriter, loc, accRowX264Ty, row);
      auto deintr64 = vector::DeinterleaveOp::create(rewriter, loc, row64);

      // Bitcast back into the original element type and insert into the
      // result.
      out0 = vector::BitCastOp::create(rewriter, loc, accRowTy,
                                       deintr64.getRes1());
      out1 = vector::BitCastOp::create(rewriter, loc, accRowTy,
                                       deintr64.getRes2());
    }
    result = vector::InsertOp::create(rewriter, loc, out0, result, i * 2);
    result = vector::InsertOp::create(rewriter, loc, out1, result, i * 2 + 1);
  }

  return result;
}

class VectorContractRewriterI8MM : public VectorContractRewriter {
public:
  // Check the specific preconditions for the integer case. Initialise
  // parametrisation types and dimensions.
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    if (failed(match(op, rewriter)))
      return failure();

    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();

    m = lhsType.getDimSize(0);
    n = rhsType.getDimSize(0);
    k = rhsType.getDimSize(1);

    // Check the operands have the expected shape:
    //  * for LHS: fixed vector MxK
    //  * for RHS: scalable vector [N]xK
    //  * K == 8
    //  * M and N even and at least 2
    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 8 ||
        m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0)
      return rewriter.notifyMatchFailure(op, "non-matching operand shape");

    // Check the output is a vector of i32 elements.
    auto outTy = dyn_cast<VectorType>(op.getResultType());
    if (!outTy || outTy.getElementType() != rewriter.getI32Type())
      return rewriter.notifyMatchFailure(op,
                                         "output type is not a vector of i32");

    // Check the inputs are sign-/zero- extensions from i8 to i32 and get the
    // values before the extension. All four signed/unsigned combinations of
    // input operands are supported, but they are lowered to different
    // operations. Determine which is the appropriate operation to lower to.
    mmlaOp = MMLA::SignedInt;
    swapOperands = false;
    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
    if (!maybeLhs) {
      mmlaOp = MMLA::UnsignedInt;
      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
    }
    if (!maybeLhs)
      return rewriter.notifyMatchFailure(
          op, "LHS is not a sign- or zero- extended i8");

    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
    if (maybeRhs) {
      if (mmlaOp == MMLA::UnsignedInt)
        mmlaOp = MMLA::MixedInt;
    } else {
      if (mmlaOp == MMLA::SignedInt) {
        mmlaOp = MMLA::MixedInt;
        swapOperands = true;
      }
      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
    }
    if (!maybeRhs)
      return rewriter.notifyMatchFailure(
          op, "RHS is not a sign- or zero- extended i8");

    // Initialise algorithm parameters.
    lhs = *maybeLhs;
    rhs = *maybeRhs;
    acc = op.getAcc();

    return success();
  }
};

class VectorContractRewriterBfloat : public VectorContractRewriter {
public:
  // Check the specific preconditions for the bfloat16 case. Initialise
  // parametrisation types and dimensions.
  LogicalResult matchAndInit(vector::ContractionOp op,
                             PatternRewriter &rewriter) {
    if (failed(match(op, rewriter)))
      return failure();

    VectorType lhsType = op.getLhsType();
    VectorType rhsType = op.getRhsType();

    m = lhsType.getDimSize(0);
    n = rhsType.getDimSize(0);
    k = rhsType.getDimSize(1);

    // Check the operands have the expected shape:
    //  * for LHS: fixed vector MxK
    //  * for RHS: scalable vector [N]xK
    //  * K == 4
    //  * M and N even and at least 2
    if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
        rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 4 ||
        m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0)
      return rewriter.notifyMatchFailure(op, "non-matching operand shape");

    // Check the output is a vector of Float32 elements.
    auto outTy = dyn_cast<VectorType>(op.getResultType());
    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
      return rewriter.notifyMatchFailure(op,
                                         "output type is not a vector of f32");

    // Check the inputs are vectors of BFloat16 elements.
    if (lhsType.getElementType() != rewriter.getBF16Type())
      return rewriter.notifyMatchFailure(op,
                                         "input type is not a vector of bf16");

    // Initialise algorithm parameters.
    mmlaOp = MMLA::Bfloat;
    swapOperands = false;
    lhs = op.getLhs();
    rhs = op.getRhs();
    acc = op.getAcc();

    return success();
  }
};

class LowerContractionToSVEI8MMPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    // Match i8xi8 -> i32 matrix multiply and accumulate.
    VectorContractRewriterI8MM vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();

    Value result = vcr.lower(op, rewriter);
    rewriter.replaceOp(op, result);

    return success();
  }
};

class LowerContractionToSVEBFMMLAPattern
    : public OpRewritePattern<vector::ContractionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {

    // Match bf16xbf16 -> f32 matrix multiply and accumulate.
    VectorContractRewriterBfloat vcr;
    if (failed(vcr.matchAndInit(op, rewriter)))
      return failure();

    Value result = vcr.lower(op, rewriter);
    rewriter.replaceOp(op, result);

    return success();
  }
};

} // namespace

void mlir::populateLowerContractionToSVEI8MMPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);
}

void mlir::populateLowerContractionToSVEBFMMLAPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LowerContractionToSVEBFMMLAPattern>(context, /*benefit=*/2);
}
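
// Example usage (an illustrative sketch only; the pass/driver code below is
// not part of this file): a transform that wants these rewrites could
// populate and apply them along the lines of
//
//   RewritePatternSet patterns(ctx);
//   populateLowerContractionToSVEI8MMPatterns(patterns);
//   populateLowerContractionToSVEBFMMLAPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));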