MLIR 23.0.0git
VectorContractToPackedTypeDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17
19#include "mlir/IR/Dominance.h"
21
22#include "mlir/Pass/Pass.h"
24
25using namespace mlir;
26using namespace mlir::vector;
27using namespace mlir::x86;
28
29namespace {
30
31// Returns true if the A or B matrix vector is packed (shuffled) to
32// VNNI layout, already.
33static bool isNonUnitDimOperandShuffled(Value nonUnitDimOperand) {
34 if (Operation *defOp = nonUnitDimOperand.getDefiningOp()) {
35 if (isa<vector::ShuffleOp>(defOp))
36 return true;
37
38 if (isa<vector::ShapeCastOp>(defOp)) {
39 Operation *defOpShpCst = defOp->getOperand(0).getDefiningOp();
40 if (isa<vector::ShuffleOp>(defOpShpCst))
41 return true;
42 }
43 }
44
45 return false;
46}
47
48static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
49 mlir::Operation *targetContract,
50 mlir::PatternRewriter &rewriter) {
51 for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
52 mlir::Operation *user = use.getOwner();
53 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
54 mlir::isa<mlir::scf::ForOp>(user)) {
55 rewriter.modifyOpInPlace(user, [&]() { use.set(newVal); });
56 }
57 }
58}
59
60// Function to convert the flat layout A or B matrix vector<32xbf16>
61// into VNNI packed layout using the vpunpack operations
62static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
63 mlir::Operation *opA,
64 mlir::Operation *opB,
65 mlir::vector::ContractionOp contractA,
66 mlir::vector::ContractionOp contractB,
67 int64_t nonUnitDimAcc,
68 mlir::VectorType Ty) {
69
70 bool opABeforeopB = opA->isBeforeInBlock(opB);
71
72 if (opABeforeopB)
73 rewriter.moveOpAfter(opB, opA);
74 else
75 rewriter.moveOpAfter(opA, opB);
76
77 mlir::Operation *insertAfter = opABeforeopB ? opB : opA;
78
79 rewriter.setInsertionPointAfter(insertAfter);
80 mlir::Location loc = insertAfter->getLoc();
81
82 auto elemTy = Ty.getElementType();
83 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
84
85 if (elemTy.isSignlessInteger(8))
86 flatTy = mlir::VectorType::get({2, nonUnitDimAcc / 2}, elemTy);
87
88 Value srcBuff;
89 SmallVector<Value> indexVals;
90
91 llvm::TypeSwitch<Operation *>(opA).Case<TransferReadOp, LoadOp>(
92 [&](auto readOp) {
93 srcBuff = readOp.getOperand(0);
94
95 auto indices = readOp.getIndices();
96 indexVals.reserve(indices.size());
97
98 llvm::transform(
99 indices, std::back_inserter(indexVals), [&](OpFoldResult ofr) {
100 return mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
101 });
102 });
103
104 int64_t srcRank = (dyn_cast<ShapedType>(srcBuff.getType())).getRank();
105 Value padding = ub::PoisonOp::create(rewriter, loc, elemTy);
106 auto map = AffineMap::getMinorIdentityMap(srcRank, flatTy.getRank(),
107 rewriter.getContext());
108 SmallVector<bool> inBounds(flatTy.getRank(), true);
109
110 Value vec1 = vector::TransferReadOp::create(
111 rewriter, loc, flatTy, srcBuff, indexVals, padding, map, inBounds);
112
113 if (elemTy.isSignlessInteger(8))
114 vec1 = vector::ShapeCastOp::create(
115 rewriter, loc, VectorType::get(nonUnitDimAcc, elemTy), vec1);
116
117 unsigned int offset = 1;
118 if (elemTy.isSignlessInteger(8))
119 offset = 2;
120
121 Value cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
122 auto nextIndx =
123 arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), cOffset,
124 indexVals[indexVals.size() - 2]);
125 indexVals[indexVals.size() - 2] = nextIndx;
126
127 Value vec2 = vector::TransferReadOp::create(
128 rewriter, loc, flatTy, srcBuff, indexVals, padding, map, inBounds);
129
130 if (elemTy.isSignlessInteger(8))
131 vec2 = vector::ShapeCastOp::create(
132 rewriter, loc, VectorType::get(nonUnitDimAcc, elemTy), vec2);
133
134 flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
135
136 static constexpr int64_t maskLo_bf16[] = {
137 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
138 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
139 static constexpr int64_t maskHi_bf16[] = {
140 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
141 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
142
143 static constexpr int64_t maskLo_int8_avx2[] = {
144 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51,
145 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59};
146 static constexpr int64_t maskHi_int8_avx2[] = {
147 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55,
148 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63};
149
150 static constexpr int64_t maskLo_int8_avx10[] = {
151 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99,
152 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107,
153 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115,
154 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123};
155 static constexpr int64_t maskHi_int8_avx10[] = {
156 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103,
157 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111,
158 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119,
159 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127};
160
161 mlir::DenseI64ArrayAttr maskLo = rewriter.getDenseI64ArrayAttr(maskLo_bf16);
162 mlir::DenseI64ArrayAttr maskHi = rewriter.getDenseI64ArrayAttr(maskHi_bf16);
163
164 if (elemTy.isSignlessInteger(8)) {
165 maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8_avx10);
166 maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8_avx10);
167
168 if (nonUnitDimAcc == 32) {
169 maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8_avx2);
170 maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8_avx2);
171 }
172 }
173
174 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
175 vec2, maskLo);
176 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
177 vec2, maskHi);
178
179 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
180 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
181
182 rewriteUses(opA->getResult(0), newA.getResult(), contractA, rewriter);
183 rewriteUses(opB->getResult(0), newB.getResult(), contractB, rewriter);
184}
185
186// Implements packed type outer product contraction as a sequence
187// of broadcast and packed dot-product operations.
188//
189// For example - for bf16 type (VNNI):
190// ```
191// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
192// ```
193// to
194// ```
195// vector.broadcast %lhs to <32xbf16>
196// x86.avx512.dot vector<32xbf16> -> vector<16xf32>
197// ```
198//
199// For example - for bf16 type (Flat layout):
200// ```
201// %1 = vector.load -> <2x16xbf16>
202// %2 = vector.load -> <2x16xbf16>
203// vector.contract <1x2xbf16>, %1 into <1x16xf32>
204// vector.contract <1x2xbf16>, %2 into <1x16xf32>
205// ```
206// to
207// ```
208// %1 = vector.load -> <2x16xbf16>
209// %2 = vector.load -> <2x16xbf16>
210// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
211// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
212// vector.broadcast %lhs to <32xbf16>
213// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
214// vector.broadcast %lhs to <32xbf16>
215// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
216// ```
217struct VectorContractToPackedTypeDotProduct
218 : public OpRewritePattern<vector::ContractionOp> {
219 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
220
221 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
222 PatternRewriter &rewriter) const override {
223
224 if (contractOp.getKind() != vector::CombiningKind::ADD)
225 return rewriter.notifyMatchFailure(contractOp,
226 "Expects add combining kind.");
227
228 VectorType lhsTy = contractOp.getLhsType();
229 if (!lhsTy.getElementType().isBF16() &&
230 !lhsTy.getElementType().isSignlessInteger(8))
231 return rewriter.notifyMatchFailure(
232 contractOp, "Only BF16/Int8 lowering is supported.");
233
234 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
235 bool isVnni =
236 isInVnniLayout(contractOp.getOperation(),
237 contractOp.getIndexingMapsArray(), blockingFactor);
238
239 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
240 if (!accTy)
241 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
242
243 ArrayRef<int64_t> accShape = accTy.getShape();
244 llvm::SmallVector<int64_t> nonUnitDimAcc;
245 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
246 [](int64_t dim) { return dim != 1; });
247 if (nonUnitDimAcc.size() != 1)
248 return rewriter.notifyMatchFailure(
249 contractOp, "A or B should be a non-unit dim in acc.");
250
251 int64_t nonUnitDimValue = nonUnitDimAcc.front();
252 // Non-unit dimensions should match the vector length of BF16 or Int8
253 // dot-product.
254 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
255 nonUnitDimValue != 8 && nonUnitDimValue != 16)
256 return rewriter.notifyMatchFailure(
257 contractOp, "BF16 dot-product operation expects non-unit (LHR or "
258 "RHS) dim and acc dim of size 4/8/16.");
259
260 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
261 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
262 nonUnitDimAcc.front() == nonUnitDimValue)
263 return rewriter.notifyMatchFailure(
264 contractOp, "Int8 dot-product operation expects non-unit (LHR or "
265 "RHS) dim and acc dim of size 4/8/16.");
266
267 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
268 llvm::SmallVector<int64_t> nonUnitDimLhs;
269 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
270 [](int64_t dim) { return dim != 1; });
271
272 VectorType rhsTy = contractOp.getRhsType();
273 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
274 llvm::SmallVector<int64_t> nonUnitDimRhs;
275 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
276 [](int64_t dim) { return dim != 1; });
277
278 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
279 return rewriter.notifyMatchFailure(contractOp,
280 "Excepts unit dimensions for either "
281 "LHS or RHS shape.");
282
283 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
284 return rewriter.notifyMatchFailure(
285 contractOp,
286 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
287
288 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
289 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
290 : nonUnitDimRhs.front();
291
292 if (!isVnni && (extraFlatDim != blockingFactor))
293 return rewriter.notifyMatchFailure(
294 contractOp, "The K or reduction dim for flat layout should be 2/4.");
295
296 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
297 (lhsTy.getElementType().isSignlessInteger(8) &&
298 !accTy.getElementType().isSignlessInteger(32)))
299 return rewriter.notifyMatchFailure(contractOp,
300 "Only F32 for BF16 or Int32 for Int8 "
301 "accumulation type is supported.");
302
303 Value unitDimOperand =
304 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
305 Value nonUnitDimOperand =
306 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
307
308 // If the A or B matrix vector of the contact operation is not packed, then
309 // find it's pair contract operation and pack (shuffle) them to VNNI packed.
310 if (!isVnni) {
311 vector::ContractionOp pairContractOp;
312 Operation *nextOp = contractOp;
313 while ((nextOp = nextOp->getNextNode())) {
314 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
315
316 if (!contOp)
317 continue;
318
319 if (validatePairVectorContract(contractOp, contOp,
320 rhsHasMultipleNonUnitDims,
321 nonUnitDimValue)) {
322 pairContractOp = contOp;
323 break;
324 }
325 }
326
327 // If the accumulators are shuffled we get nullptr else the
328 // transfer_read or load operations.
329 Operation *accRead =
330 traceToVectorReadLikeParentOperation(contractOp.getAcc());
331
332 if (!pairContractOp &&
333 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
334 return rewriter.notifyMatchFailure(contractOp,
335 "Could not find a contract pair");
336
337 // Validate and shuffle the accumulator
338 if (accRead) {
339 // Trace back to the load or transfer_read operations of the contract
340 // accumulators.
341 Operation *accReadOp0 =
342 traceToVectorReadLikeParentOperation(contractOp.getAcc());
343 Operation *accReadOp1 =
344 traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
345
346 // Iterate down to find the users of contact operations until it is
347 // store or transfer_write.
348 Operation *resultWriteOp0 =
349 traceToVectorWriteLikeUserOperation(contractOp.getResult());
350 Operation *resultWriteOp1 =
351 traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
352
353 if (!accReadOp0 || !accReadOp1)
354 return rewriter.notifyMatchFailure(
355 contractOp,
356 "Operands doesn't have load or transfer_read as it's parent op");
357
358 if (!resultWriteOp0 || !resultWriteOp1)
359 return rewriter.notifyMatchFailure(
360 contractOp,
361 "The use of contract operations are neither vector.store "
362 "or transfer_write or has multiple users.");
363
364 if (contractOp->getBlock() == accReadOp1->getBlock() &&
365 contractOp->isBeforeInBlock(accReadOp1))
366 return rewriter.notifyMatchFailure(
367 contractOp,
368 "The load/read operation of pair contract operation is "
369 "after the contractOp");
370
371 if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
372 resultWriteOp0->isBeforeInBlock(pairContractOp))
373 return rewriter.notifyMatchFailure(
374 contractOp, "The store/write operation of contract operation is "
375 "before the pair contract operation");
376 // Shuffle the accumulators of the contract operations.
377 LogicalResult readShuffle =
378 shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
379 pairContractOp, nonUnitDimValue, accTy);
380
381 if (failed(readShuffle))
382 return rewriter.notifyMatchFailure(
383 contractOp, "Accumulator read is not by transfer_read or load");
384
385 // Shuffle the output of contract operations before it's use.
386 LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
387 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
388
389 if (failed(writeShuffle))
390 return rewriter.notifyMatchFailure(
391 contractOp,
392 "Write to accumulator is not by transfer_write or store");
393 }
394
395 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
396 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
397 ? pairContractOp.getRhs()
398 : pairContractOp.getLhs();
399
400 // Get the non-packed A or B matrix's vector<32xbf16> elements.
401 Operation *nonUnitDimReadOp =
402 traceToVectorReadLikeParentOperation(nonUnitDimOperand);
403 Operation *nonUnitDimReadOpPairContract =
404 traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
405
406 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
407 return rewriter.notifyMatchFailure(
408 contractOp, "Could not find a valid contract pair");
409
410 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
411 ? contractOp.getRhsType()
412 : contractOp.getLhsType();
413
414 packNonUnitDimOperandToVNNI(
415 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
416 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
417 nonUnitDimTy);
418
419 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
420 : contractOp.getLhs();
421 }
422 }
423
424 rewriter.setInsertionPoint(contractOp);
425 auto loc = contractOp.getLoc();
426 auto castAcc = vector::ShapeCastOp::create(
427 rewriter, loc,
428 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
429 contractOp.getAcc());
430
431 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
432 ? contractOp.getRhsType()
433 : contractOp.getLhsType();
434 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
435 : contractOp.getRhsType();
436
437 Value dp;
438
439 auto castNonUnitDim = vector::ShapeCastOp::create(
440 rewriter, loc,
441 VectorType::get(blockingFactor * nonUnitDimValue,
442 nonUnitDimTy.getElementType()),
443 nonUnitDimOperand);
444
445 auto castUnitDim = vector::ShapeCastOp::create(
446 rewriter, loc,
447 VectorType::get(blockingFactor, unitDimTy.getElementType()),
448 unitDimOperand);
449 auto bitcastUnitDim = vector::BitCastOp::create(
450 rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
451 castUnitDim);
452 auto broadcastUnitDim = vector::BroadcastOp::create(
453 rewriter, loc,
454 VectorType::get({nonUnitDimValue}, rewriter.getIntegerType(32)),
455 bitcastUnitDim);
456 auto bitcastUnitDimPkType = vector::BitCastOp::create(
457 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
458
459 if (lhsTy.getElementType().isBF16()) {
460 dp = x86::avx512::DotBF16Op::create(
461 rewriter, loc,
462 VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
463 bitcastUnitDimPkType, castNonUnitDim);
464 }
465
466 if (lhsTy.getElementType().isSignlessInteger(8)) {
467 if (nonUnitDimAcc.front() == 16) {
468 dp = x86::avx10::AVX10DotInt8Op::create(
469 rewriter, loc,
470 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
471 castAcc, bitcastUnitDimPkType, castNonUnitDim);
472 } else {
473 dp = x86::avx::DotInt8Op::create(
474 rewriter, loc,
475 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
476 castAcc, bitcastUnitDimPkType, castNonUnitDim);
477 }
478 }
479
480 if (!dp)
481 return failure();
482
483 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
484 rewriter.replaceOp(contractOp, castDp);
485 return success();
486 }
487};
488
489} // namespace
490
492 RewritePatternSet &patterns) {
493 patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
494}
return success()
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
FloatType getF32Type()
Definition Builders.cpp:47
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:292
Operation * traceToVectorWriteLikeUserOperation(Value v)
Definition X86Utils.cpp:194
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Definition X86Utils.cpp:42
Operation * traceToVectorReadLikeParentOperation(Value v)
Definition X86Utils.cpp:154
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:242
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Definition X86Utils.cpp:352
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...