MLIR 23.0.0git
VectorContractBF16ToFMA.cpp
Go to the documentation of this file.
1//===- VectorContractBF16ToFMA.cpp-----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17
19#include "mlir/IR/Dominance.h"
21
22#include "mlir/Pass/Pass.h"
24#include "llvm/Support/Casting.h"
25
26using namespace mlir;
27using namespace mlir::vector;
28using namespace mlir::x86vector;
29
30// Verifies that the LHS and RHS operands of a vector.contract are load or
31// vector.transfer_read operations on a memref source buffer, and checks
32// their bounds, dimensions, offsets, and strides.
33static bool validateVectorContractOperands(Value prodOp, bool isVnni) {
34 Operation *defOp = prodOp.getDefiningOp();
35 if (!defOp)
36 return false;
37
38 if (auto readOp = prodOp.getDefiningOp<mlir::vector::TransferReadOp>()) {
39 if (readOp.hasOutOfBoundsDim())
40 return false;
41
42 if (!readOp.getPermutationMap().isMinorIdentity())
43 return false;
44 }
45
46 Value srcBuff;
48 llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
49 [&](auto readOp) {
50 srcBuff = readOp.getOperand(0);
51 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
52 readOp.getIndices().end());
53 });
54
55 if (!srcBuff)
56 return false;
57
58 // Return false, if the source is not a memref type
59 Type srcType = srcBuff.getType();
60 if (!llvm::isa<MemRefType>(srcType))
61 return false;
62
63 // Return false if the two innermost strides of the memref are not contiguous.
64 // The x86vector.avx.cvt.packed.even/odd.indexed_to_f32 operations require
65 // an eight-element tuple of bf16 values to be contiguous.
66 int dimsToCheck = isVnni ? 2 : 1;
67 if (!cast<mlir::MemRefType>(srcType).areTrailingDimsContiguous(dimsToCheck))
68 return false;
69
70 // Return false if the vnni offset of load or transfer_read is not zero.
71 if (isVnni && getConstantIntValue(indexVals.back()) != 0)
72 return false;
73
74 return true;
75}
76
77// This function retrieves the source operation of the load or transfer
78// reads and creates subviews for the BF16 packed-operations to
79// broadcast or load BF16 elements as F32 packed elements.
80//
81// Example(1) Unit Dim:
82// ```
83// vector.load %arg0[%c0, %c0, %c0]:memref<4x1x2xbf16>,vector<1x1x2xbf16>
84// ```
85// to
86// ```
87// memref.subview %arg0[%c0,%c0,%c1]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
88// memref.subview %arg0[%c0,%c0,%c0]:memref<4x1x2xbf16> to memref<1x1x1xbf16>
89// ```
90//
91// Example(2) Non-unit Dim:
92// ```
93// vector.load %arg1[%c0, %c0, %c0]:memref<1x32x2xbf16>,vector<1x8x2xbf16>
94// ```
95// to
96// ```
97// memref.subview %arg1[%c0,%c0,%c0]:memref<1x32x2xbf16> to memref<1x8x2xbf16>
98// ```
101 ArrayRef<int64_t> nonUnitDimShape, bool isUnitDim,
102 bool isVNNI) {
103
104 Operation *defOp = prodOp.getDefiningOp();
105
106 Value srcBuff;
108 llvm::TypeSwitch<Operation *>(defOp).Case<TransferReadOp, LoadOp>(
109 [&](auto readOp) {
110 srcBuff = readOp.getOperand(0);
111 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
112 readOp.getIndices().end());
113 });
114
115 int64_t mnDimSize = 1;
116 unsigned mnDimIdx = 0;
117
118 if (!isUnitDim) {
119 for (auto it : llvm::enumerate(nonUnitDimShape)) {
120 if (it.value() != 1) {
121 mnDimSize = it.value();
122 mnDimIdx = it.index();
123 break;
124 }
125 }
126 }
127
128 auto one = rewriter.getIndexAttr(1);
130
131 if (!isVNNI) {
132 SmallVector<OpFoldResult> strides(indexVals.size(), one);
133 SmallVector<OpFoldResult> sizes(indexVals.size(), one);
134 // Retrive twice the nonUnit dim BF16 element for both even and odd
135 // index elements.
136 if (!isUnitDim)
137 mnDimSize = 2 * mnDimSize;
138 sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
139 auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
140 sizes, strides);
141 subviews.push_back(subview);
142 return subviews;
143 }
144
145 int vnniDimSize = isUnitDim ? 1 : 2;
146 auto nonVNNIDimSize = indexVals.size() - 1;
147 // Create the size and stride offsets.
148 SmallVector<OpFoldResult> strides(indexVals.size(), one);
149 SmallVector<OpFoldResult> sizes(nonVNNIDimSize, one);
150
151 sizes.push_back(rewriter.getIndexAttr(vnniDimSize));
152
153 // update the unit/nonUnit Dim size either it is A(LHS) or B(RHS).
154 sizes[mnDimIdx] = rewriter.getIndexAttr(mnDimSize);
155
156 // for unitDim, first broadcast odd element, so index is set to 1.
157 if (isUnitDim)
158 indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(1);
159
160 auto subview = memref::SubViewOp::create(rewriter, loc, srcBuff, indexVals,
161 sizes, strides);
162 subviews.push_back(subview);
163
164 // For unit-dims, two subviews should be created for the odd and even
165 // element in the VNNI tuple (2xbf16) because x86vector.avx.bcst_to_f32.packed
166 // op loads and broadcast the first BF16 element into packed F32. It
167 // cannot distinguish between even and odd BF16 elements within a
168 // packed pair.
169 //
170 // Example:
171 // memref.subview %arg0[%c0,%c1]:memref<1x2xbf16> to memref<1x1xbf16> // Odd
172 // memref.subview %arg0[%c0,%c0]:memref<1x2xbf16> to memref<1x1xbf16> // Even
173 if (mnDimSize == 1) {
174 indexVals[indexVals.size() - 1] = rewriter.getIndexAttr(0);
175 sizes[indexVals.size() - 1] = rewriter.getIndexAttr(1);
176
177 auto unitDimEvenIdxSubview = memref::SubViewOp::create(
178 rewriter, loc, srcBuff, indexVals, sizes, strides);
179 subviews.push_back(unitDimEvenIdxSubview);
180 }
181
182 return subviews;
183}
184
185// Implements outer product contraction as a sequence of BF16-packed
186// operation even/odd loads and FMA operations.
187//
188// For example (VNNI packed):
189// ```
190// %1 = vector.load from memref (%m1) -> vector<1x1x2xbf16>
191// %2 = vector.load from memref (%m2) -> vector<1x8x2xbf16>
192// return vector.contract %1, %2, %arg1
193// ```
194// to
195// ```
196// %1 = x86vector.avx.bcst_to_f32.packed %m1[c1] -> vector<8xf32>
197// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
198// %3 = vector.fma %1, %2, %arg1
199// %4 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
200// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
201// return vector.fma %4, %5, %3
202// ```
203//
204// For example (Flat layout):
205// ```
206// %1 = vector.load from memref (%m1) -> vector<1x1xbf16>
207// %2 = vector.load from memref (%m2) -> vector<1x8xbf16>
208// %3 = vector.contract %1, %2, %arg1
209// %4 = vector.load from memref (%m2) -> vector<1x8xbf16>
210// %5 = vector.contract %1, %4, %arg2
// scf.yield %3, %5
212// ```
213// to
214// ```
215// %1 = x86vector.avx.bcst_to_f32.packed %m1[c0] -> vector<8xf32>
216// %2 = x86vector.avx.cvt.packed.even.indexed_to_f32 %m2 -> vector<8xf32>
217// %3 = vector.fma %1, %2, %arg1
218// %4 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %m2 -> vector<8xf32>
219// %5 = vector.fma %1, %4, %arg2
220// scf.yield %3, %5
222 : public OpRewritePattern<vector::ContractionOp> {
223 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
224
225 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
226 PatternRewriter &rewriter) const override {
227
228 if (contractOp.getKind() != vector::CombiningKind::ADD)
229 return rewriter.notifyMatchFailure(contractOp,
230 "Expects add combining kind.");
231
232 // TODO: Move this validation to a common utility folder. Planned to
233 // do once (code refactoring), all architecture specific nanokernel
234 // passes are merged into the repo.
235 VectorType lhsTy = contractOp.getLhsType();
236 if (!lhsTy.getElementType().isBF16())
237 return rewriter.notifyMatchFailure(contractOp,
238 "Only BF16 lowering is supported.");
239
240 bool isVnni = isInVnniLayout(contractOp.getOperation(),
241 contractOp.getIndexingMapsArray(),
242 /*blockingFactor=*/2);
243
244 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
245 if (!accTy)
246 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
247
248 if (!accTy.getElementType().isF32())
249 return rewriter.notifyMatchFailure(
250 contractOp, "Only F32 acumulation supported for BF16 type.");
251
252 ArrayRef<int64_t> accShape = accTy.getShape();
253 llvm::SmallVector<int64_t> nonUnitDimAcc;
254 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
255 [](int64_t dim) { return dim != 1; });
256 if (nonUnitDimAcc.size() != 1)
257 return rewriter.notifyMatchFailure(
258 contractOp, "A or B should be a non-unit dim in acc.");
259
260 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
261 llvm::SmallVector<int64_t> nonUnitDimLhs;
262 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
263 [](int64_t dim) { return dim != 1; });
264
265 VectorType rhsTy = contractOp.getRhsType();
266 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
267 llvm::SmallVector<int64_t> nonUnitDimRhs;
268 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
269 [](int64_t dim) { return dim != 1; });
270
271 if (isVnni && (nonUnitDimLhs.size() - 1) > 0 &&
272 (nonUnitDimRhs.size() - 1) > 0)
273 return rewriter.notifyMatchFailure(contractOp,
274 "Excepts unit dimensions for either "
275 "LHS or RHS shape other than VNNI.");
276
277 if (isVnni && (nonUnitDimLhs.size() - 1) != 1 &&
278 (nonUnitDimRhs.size() - 1) != 1)
279 return rewriter.notifyMatchFailure(
280 contractOp,
281 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
282
283 if (!isVnni && nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
284 return rewriter.notifyMatchFailure(contractOp,
285 "Excepts unit dimensions for either "
286 "LHS or RHS shape.");
287
288 if (!isVnni && nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
289 return rewriter.notifyMatchFailure(
290 contractOp,
291 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
292
293 // Non-unit dimensions should match the vector length of BF16.
294 unsigned int nonUnitDim = nonUnitDimAcc.front();
295 if (nonUnitDim != 4 && nonUnitDim != 8)
296 return rewriter.notifyMatchFailure(
297 contractOp, "BF16 packed load operation expects non-unit (LHR or "
298 "RHS) dim and acc dim of size 4/8.");
299
300 if (!validateVectorContractOperands(contractOp.getLhs(), isVnni) ||
301 !validateVectorContractOperands(contractOp.getRhs(), isVnni)) {
302 return rewriter.notifyMatchFailure(
303 contractOp, "The LHS or RHS is in an invalid format. Either it has "
304 "false in-bounds, "
305 "a non-identity permutation map, a non-zero VNNI offset, "
306 "a non-memref "
307 "source, or a non-unit VNNI stride");
308 }
309
310 // Lower vector.contract to FMAs with help of BF16 packed ops.
311 auto loc = contractOp.getLoc();
312
313 // create the unit-dimension LHS or RHS subview and the
314 // corresponding non-unit dimension LHS or RHS subview on the other-side.
315 // For example, if LHS has type vector<1x1x2xbf16> and RHS has type
316 // vector<1x8x2xbf16>, we create two subview for the LHS and one subview
317 // for the RHS. In the opposite case (non-unit dimension on the LHS), we
318 // do vice-versa.
319
320 bool rhsHasMultipleNonUnitDims = nonUnitDimRhs.size() > 0;
321 if (isVnni) {
322 rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
323 }
324
325 // Select which operand is "unit" and which is "non-unit".
326 Value unitSrc =
327 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
328 Value nonUnitSrc =
329 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
330
331 ArrayRef<int64_t> nonUnitDimShape =
332 rhsHasMultipleNonUnitDims ? rhsShape : lhsShape;
333
334 // Get the pair vector.contract operation. The pair is decided on:
335 // (1) - the unitDim operand Lhs or Rhs should be same,
336 // (2) - the defining source memref should be same for nonUnitDim
337 // operation, (3) - the nonUnit dim offset difference between the
338 // vector.contracts should be 8.
339 vector::ContractionOp pairContractOp;
340 if (!isVnni) {
341 Operation *nextOp = contractOp;
342 while ((nextOp = nextOp->getNextNode())) {
343 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
344
345 if (!contOp)
346 continue;
347
348 if (validatePairVectorContract(contractOp, contOp,
349 rhsHasMultipleNonUnitDims,
350 nonUnitDimAcc.front())) {
351 pairContractOp = contOp;
352 break;
353 }
354 }
355
356 if (!pairContractOp)
357 return failure();
358
359 Operation *accReadOp0 =
360 traceToVectorReadLikeParentOperation(contractOp.getAcc());
361 Operation *accReadOp1 =
362 traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
363
364 // Iterate down to find the users of contact operations until it is store
365 // or transfer_write.
366 Operation *resultWriteOp0 =
367 traceToVectorWriteLikeUserOperation(contractOp.getResult());
368 Operation *resultWriteOp1 =
369 traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
370
371 if (!accReadOp0 || !accReadOp1)
372 return rewriter.notifyMatchFailure(
373 contractOp,
374 "Operand doesn't have load or transfer_read as its parent op");
375
376 if (!resultWriteOp0 || !resultWriteOp1)
377 return rewriter.notifyMatchFailure(
378 contractOp,
379 "The use of contract operations are neither vector.store "
380 "or transfer_write or has multiple users");
381
382 if (contractOp->getBlock() == accReadOp1->getBlock() &&
383 contractOp->isBeforeInBlock(accReadOp1))
384 return rewriter.notifyMatchFailure(
385 contractOp, "The load/read operation of pair contract operation is "
386 "after the contractOp");
387
388 if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
389 resultWriteOp0->isBeforeInBlock(pairContractOp)) {
390 return rewriter.notifyMatchFailure(
391 contractOp, "The store/write operation of contract operation is "
392 "before the pair contract operation");
393 }
394 }
395
396 // Build subviews.
397 auto unitDimSubview = getSubviewFromVectorInput(
398 loc, rewriter, unitSrc, nonUnitDimShape, true, isVnni);
399
400 auto nonUnitDimSubview = getSubviewFromVectorInput(
401 loc, rewriter, nonUnitSrc, nonUnitDimShape, false, isVnni);
402
403 auto castAcc = vector::ShapeCastOp::create(
404 rewriter, loc,
405 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
406 contractOp.getAcc());
407 VectorType dstType =
408 VectorType::get(nonUnitDimAcc.front(), rewriter.getF32Type());
409
410 if (!isVnni) {
411
412 // Validate and shuffle the accumulator
413 Operation *accReadOp0 =
414 traceToVectorReadLikeParentOperation(contractOp.getAcc());
415 Operation *accReadOp1 =
416 traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
417
418 // Iterate down to find the users of contact operations until it is store
419 // or transfer_write.
420 Operation *resultWriteOp0 =
421 traceToVectorWriteLikeUserOperation(contractOp.getResult());
422 Operation *resultWriteOp1 =
423 traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
424
425 // Shuffle the accumulators of the contract operations.
426 LogicalResult readShuffle =
427 shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
428 pairContractOp, nonUnitDim, accTy);
429
430 if (failed(readShuffle))
431 return rewriter.notifyMatchFailure(
432 contractOp, "Accumulator read is not by transfer_read or load");
433
434 // Shuffle the output of contract operations before its use.
435 LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
436 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDim, accTy);
437
438 if (failed(writeShuffle))
439 return rewriter.notifyMatchFailure(
440 contractOp,
441 "Write to accumulator is not by transfer_write or store");
442
443 rewriter.setInsertionPoint(contractOp);
444 castAcc = vector::ShapeCastOp::create(
445 rewriter, loc,
446 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
447 contractOp.getAcc());
448
449 auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
450 rewriter, loc, dstType, unitDimSubview[0]);
451 auto loadEvenIdxElementF32 =
452 x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
453 nonUnitDimSubview[0]);
454 auto evenIdxFMA =
455 vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
456 loadEvenIdxElementF32, castAcc);
457 auto castEvenFma =
458 vector::ShapeCastOp::create(rewriter, loc, accTy, evenIdxFMA);
459 rewriter.replaceOp(contractOp, castEvenFma);
460
461 rewriter.setInsertionPoint(pairContractOp);
462 auto pairContOpLoc = pairContractOp.getLoc();
463 VectorType accTyPairCont =
464 dyn_cast<VectorType>(pairContractOp.getAccType());
465 auto castAccPairCont = vector::ShapeCastOp::create(
466 rewriter, pairContOpLoc,
467 VectorType::get(nonUnitDimAcc.front(),
468 accTyPairCont.getElementType()),
469 pairContractOp.getAcc());
470
471 auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
472 rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
473 auto oddIdxFMA = vector::FMAOp::create(
474 rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
475 loadOddIdxElementF32, castAccPairCont);
476 auto castOddFma = vector::ShapeCastOp::create(rewriter, pairContOpLoc,
477 accTyPairCont, oddIdxFMA);
478 rewriter.replaceOp(pairContractOp, castOddFma);
479
480 return success();
481 }
482
483 // Load, broadcast, and do FMA for odd indexed BF16 elements.
484 auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
485 rewriter, loc, dstType, unitDimSubview[0]);
486 auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
487 rewriter, loc, dstType, nonUnitDimSubview[0]);
488 auto oddIdxFMA =
489 vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
490 loadOddIdxElementF32, castAcc);
491
492 // Load, broadcast, and do FMA for even indexed BF16 elements.
493 auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
494 rewriter, loc, dstType, unitDimSubview[1]);
495 auto loadEvenIdxElementF32 = x86vector::CvtPackedEvenIndexedToF32Op::create(
496 rewriter, loc, dstType, nonUnitDimSubview[0]);
497 vector::FMAOp fma =
498 vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
499 loadEvenIdxElementF32, oddIdxFMA);
500
501 auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
502 rewriter.replaceOp(contractOp, castFma);
503 return success();
504 }
505};
506
return success()
static SmallVector< memref::SubViewOp > getSubviewFromVectorInput(Location loc, PatternRewriter &rewriter, Value prodOp, ArrayRef< int64_t > nonUnitDimShape, bool isUnitDim, bool isVNNI)
static bool validateVectorContractOperands(Value prodOp, bool isVnni)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
FloatType getF32Type()
Definition Builders.cpp:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns)
Operation * traceToVectorReadLikeParentOperation(Value v)
Operation * traceToVectorWriteLikeUserOperation(Value v)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet & patterns
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})