MLIR 23.0.0git
VectorContractToPackedTypeDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17
19#include "mlir/IR/Dominance.h"
21
22#include "mlir/Pass/Pass.h"
24
25using namespace mlir;
26using namespace mlir::vector;
27using namespace mlir::x86;
28
29namespace {
30
31// Returns true if the A or B matrix vector is packed (shuffled) to
32// VNNI layout, already.
33static bool isNonUnitDimOperandShuffled(Value nonUnitDimOperand) {
34 if (Operation *defOp = nonUnitDimOperand.getDefiningOp()) {
35 if (isa<vector::ShuffleOp>(defOp))
36 return true;
37
38 if (isa<vector::ShapeCastOp>(defOp)) {
39 Operation *defOpShpCst = defOp->getOperand(0).getDefiningOp();
40 if (isa<vector::ShuffleOp>(defOpShpCst))
41 return true;
42 }
43 }
44
45 return false;
46}
47
48static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
49 mlir::Operation *targetContract,
50 mlir::PatternRewriter &rewriter) {
51 for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
52 mlir::Operation *user = use.getOwner();
53 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
54 mlir::isa<mlir::scf::ForOp>(user)) {
55 rewriter.modifyOpInPlace(user, [&]() { use.set(newVal); });
56 }
57 }
58}
59
60// Function to convert the flat layout A or B matrix vector<32xbf16>
61// into VNNI packed layout using the vpunpack operations
62static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
63 mlir::Operation *opA,
64 mlir::Operation *opB,
65 mlir::vector::ContractionOp contractA,
66 mlir::vector::ContractionOp contractB,
67 int64_t nonUnitDimAcc,
68 mlir::VectorType Ty) {
69
70 bool opABeforeopB = opA->isBeforeInBlock(opB);
71
72 if (opABeforeopB)
73 rewriter.moveOpAfter(opB, opA);
74 else
75 rewriter.moveOpAfter(opA, opB);
76
77 mlir::Operation *insertAfter = opABeforeopB ? opB : opA;
78
79 rewriter.setInsertionPointAfter(insertAfter);
80 mlir::Location loc = insertAfter->getLoc();
81
82 auto elemTy = Ty.getElementType();
83 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
84
85 Value srcBuff;
86 SmallVector<Value> indexVals;
87
88 llvm::TypeSwitch<Operation *>(opA).Case<TransferReadOp, LoadOp>(
89 [&](auto readOp) {
90 srcBuff = readOp.getOperand(0);
91
92 auto indices = readOp.getIndices();
93 indexVals.reserve(indices.size());
94
95 llvm::transform(
96 indices, std::back_inserter(indexVals), [&](OpFoldResult ofr) {
97 return mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
98 });
99 });
100
101 int64_t srcRank = (dyn_cast<ShapedType>(srcBuff.getType())).getRank();
102 Value padding = ub::PoisonOp::create(rewriter, loc, elemTy);
103 auto map = AffineMap::getMinorIdentityMap(srcRank, flatTy.getRank(),
104 rewriter.getContext());
105 SmallVector<bool> inBounds(flatTy.getRank(), true);
106
107 auto vec1 = vector::TransferReadOp::create(rewriter, loc, flatTy, srcBuff,
108 indexVals, padding, map, inBounds);
109
110 unsigned int offset = 1;
111 if (elemTy.isSignlessInteger(8))
112 offset = 2;
113
114 Value cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
115 auto nextIndx =
116 arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(), cOffset,
117 indexVals[indexVals.size() - 2]);
118 indexVals[indexVals.size() - 2] = nextIndx;
119
120 auto vec2 = vector::TransferReadOp::create(rewriter, loc, flatTy, srcBuff,
121 indexVals, padding, map, inBounds);
122
123 static constexpr int64_t maskLo_bf16[] = {
124 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
125 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
126 static constexpr int64_t maskHi_bf16[] = {
127 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
128 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
129
130 static constexpr int64_t maskLo_int8_avx2[] = {
131 0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51,
132 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59};
133 static constexpr int64_t maskHi_int8_avx2[] = {
134 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55,
135 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63};
136
137 static constexpr int64_t maskLo_int8_avx10[] = {
138 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3, 35, 67, 99,
139 8, 40, 72, 104, 9, 41, 73, 105, 10, 42, 74, 106, 11, 43, 75, 107,
140 16, 48, 80, 112, 17, 49, 81, 113, 18, 50, 82, 114, 19, 51, 83, 115,
141 24, 56, 88, 120, 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123};
142 static constexpr int64_t maskHi_int8_avx10[] = {
143 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39, 71, 103,
144 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110, 15, 47, 79, 111,
145 20, 52, 84, 116, 21, 53, 85, 117, 22, 54, 86, 118, 23, 55, 87, 119,
146 28, 60, 92, 124, 29, 61, 93, 125, 30, 62, 94, 126, 31, 63, 95, 127};
147
148 mlir::DenseI64ArrayAttr maskLo = rewriter.getDenseI64ArrayAttr(maskLo_bf16);
149 mlir::DenseI64ArrayAttr maskHi = rewriter.getDenseI64ArrayAttr(maskHi_bf16);
150
151 if (elemTy.isSignlessInteger(8)) {
152 maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8_avx10);
153 maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8_avx10);
154
155 if (nonUnitDimAcc == 32) {
156 maskLo = rewriter.getDenseI64ArrayAttr(maskLo_int8_avx2);
157 maskHi = rewriter.getDenseI64ArrayAttr(maskHi_int8_avx2);
158 }
159 }
160
161 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
162 vec2, maskLo);
163 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, vec1,
164 vec2, maskHi);
165
166 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
167 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
168
169 rewriteUses(opA->getResult(0), newA.getResult(), contractA, rewriter);
170 rewriteUses(opB->getResult(0), newB.getResult(), contractB, rewriter);
171}
172
173// Implements packed type outer product contraction as a sequence
174// of broadcast and packed dot-product operations.
175//
176// For example - for bf16 type (VNNI):
177// ```
178// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
179// ```
180// to
181// ```
182// vector.broadcast %lhs to <32xbf16>
183// x86.avx512.dot vector<32xbf16> -> vector<16xf32>
184// ```
185//
186// For example - for bf16 type (Flat layout):
187// ```
188// %1 = vector.load -> <2x16xbf16>
189// %2 = vector.load -> <2x16xbf16>
190// vector.contract <1x2xbf16>, %1 into <1x16xf32>
191// vector.contract <1x2xbf16>, %2 into <1x16xf32>
192// ```
193// to
194// ```
195// %1 = vector.load -> <2x16xbf16>
196// %2 = vector.load -> <2x16xbf16>
197// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
198// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
199// vector.broadcast %lhs to <32xbf16>
200// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
201// vector.broadcast %lhs to <32xbf16>
202// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
203// ```
204struct VectorContractToPackedTypeDotProduct
205 : public OpRewritePattern<vector::ContractionOp> {
206 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
207
208 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
209 PatternRewriter &rewriter) const override {
210
211 if (contractOp.getKind() != vector::CombiningKind::ADD)
212 return rewriter.notifyMatchFailure(contractOp,
213 "Expects add combining kind.");
214
215 VectorType lhsTy = contractOp.getLhsType();
216 if (!lhsTy.getElementType().isBF16() &&
217 !lhsTy.getElementType().isSignlessInteger(8))
218 return rewriter.notifyMatchFailure(
219 contractOp, "Only BF16/Int8 lowering is supported.");
220
221 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
222 bool isVnni =
223 isInVnniLayout(contractOp.getOperation(),
224 contractOp.getIndexingMapsArray(), blockingFactor);
225
226 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
227 if (!accTy)
228 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
229
230 ArrayRef<int64_t> accShape = accTy.getShape();
231 llvm::SmallVector<int64_t> nonUnitDimAcc;
232 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
233 [](int64_t dim) { return dim != 1; });
234 if (nonUnitDimAcc.size() != 1)
235 return rewriter.notifyMatchFailure(
236 contractOp, "A or B should be a non-unit dim in acc.");
237
238 int64_t nonUnitDimValue = nonUnitDimAcc.front();
239 // Non-unit dimensions should match the vector length of BF16 or Int8
240 // dot-product.
241 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
242 nonUnitDimValue != 8 && nonUnitDimValue != 16)
243 return rewriter.notifyMatchFailure(
244 contractOp, "BF16 dot-product operation expects non-unit (LHR or "
245 "RHS) dim and acc dim of size 4/8/16.");
246
247 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
248 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
249 nonUnitDimAcc.front() == nonUnitDimValue)
250 return rewriter.notifyMatchFailure(
251 contractOp, "Int8 dot-product operation expects non-unit (LHR or "
252 "RHS) dim and acc dim of size 4/8/16.");
253
254 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
255 llvm::SmallVector<int64_t> nonUnitDimLhs;
256 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
257 [](int64_t dim) { return dim != 1; });
258
259 VectorType rhsTy = contractOp.getRhsType();
260 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
261 llvm::SmallVector<int64_t> nonUnitDimRhs;
262 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
263 [](int64_t dim) { return dim != 1; });
264
265 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
266 return rewriter.notifyMatchFailure(contractOp,
267 "Excepts unit dimensions for either "
268 "LHS or RHS shape.");
269
270 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
271 return rewriter.notifyMatchFailure(
272 contractOp,
273 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
274
275 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
276 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
277 : nonUnitDimRhs.front();
278
279 if (!isVnni && (extraFlatDim != blockingFactor))
280 return rewriter.notifyMatchFailure(
281 contractOp, "The K or reduction dim for flat layout should be 2/4.");
282
283 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
284 (lhsTy.getElementType().isSignlessInteger(8) &&
285 !accTy.getElementType().isSignlessInteger(32)))
286 return rewriter.notifyMatchFailure(contractOp,
287 "Only F32 for BF16 or Int32 for Int8 "
288 "accumulation type is supported.");
289
290 Value unitDimOperand =
291 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
292 Value nonUnitDimOperand =
293 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
294
295 // If the A or B matrix vector of the contact operation is not packed, then
296 // find it's pair contract operation and pack (shuffle) them to VNNI packed.
297 if (!isVnni) {
298 vector::ContractionOp pairContractOp;
299 Operation *nextOp = contractOp;
300 while ((nextOp = nextOp->getNextNode())) {
301 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
302
303 if (!contOp)
304 continue;
305
306 if (validatePairVectorContract(contractOp, contOp,
307 rhsHasMultipleNonUnitDims,
308 nonUnitDimValue)) {
309 pairContractOp = contOp;
310 break;
311 }
312 }
313
314 // If the accumulators are shuffled we get nullptr else the
315 // transfer_read or load operations.
316 Operation *accRead =
317 traceToVectorReadLikeParentOperation(contractOp.getAcc());
318
319 if (!pairContractOp &&
320 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
321 return rewriter.notifyMatchFailure(contractOp,
322 "Could not find a contract pair");
323
324 // Validate and shuffle the accumulator
325 if (accRead) {
326 // Trace back to the load or transfer_read operations of the contract
327 // accumulators.
328 Operation *accReadOp0 =
329 traceToVectorReadLikeParentOperation(contractOp.getAcc());
330 Operation *accReadOp1 =
331 traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
332
333 // Iterate down to find the users of contact operations until it is
334 // store or transfer_write.
335 Operation *resultWriteOp0 =
336 traceToVectorWriteLikeUserOperation(contractOp.getResult());
337 Operation *resultWriteOp1 =
338 traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
339
340 if (!accReadOp0 || !accReadOp1)
341 return rewriter.notifyMatchFailure(
342 contractOp,
343 "Operands doesn't have load or transfer_read as it's parent op");
344
345 if (!resultWriteOp0 || !resultWriteOp1)
346 return rewriter.notifyMatchFailure(
347 contractOp,
348 "The use of contract operations are neither vector.store "
349 "or transfer_write or has multiple users.");
350
351 if (contractOp->getBlock() == accReadOp1->getBlock() &&
352 contractOp->isBeforeInBlock(accReadOp1))
353 return rewriter.notifyMatchFailure(
354 contractOp,
355 "The load/read operation of pair contract operation is "
356 "after the contractOp");
357
358 if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
359 resultWriteOp0->isBeforeInBlock(pairContractOp))
360 return rewriter.notifyMatchFailure(
361 contractOp, "The store/write operation of contract operation is "
362 "before the pair contract operation");
363 // Shuffle the accumulators of the contract operations.
364 LogicalResult readShuffle =
365 shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
366 pairContractOp, nonUnitDimValue, accTy);
367
368 if (failed(readShuffle))
369 return rewriter.notifyMatchFailure(
370 contractOp, "Accumulator read is not by transfer_read or load");
371
372 // Shuffle the output of contract operations before it's use.
373 LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
374 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
375
376 if (failed(writeShuffle))
377 return rewriter.notifyMatchFailure(
378 contractOp,
379 "Write to accumulator is not by transfer_write or store");
380 }
381
382 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
383 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
384 ? pairContractOp.getRhs()
385 : pairContractOp.getLhs();
386
387 // Get the non-packed A or B matrix's vector<32xbf16> elements.
388 Operation *nonUnitDimReadOp =
389 traceToVectorReadLikeParentOperation(nonUnitDimOperand);
390 Operation *nonUnitDimReadOpPairContract =
391 traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
392
393 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
394 return rewriter.notifyMatchFailure(
395 contractOp, "Could not find a valid contract pair");
396
397 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
398 ? contractOp.getRhsType()
399 : contractOp.getLhsType();
400
401 packNonUnitDimOperandToVNNI(
402 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
403 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
404 nonUnitDimTy);
405
406 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
407 : contractOp.getLhs();
408 }
409 }
410
411 rewriter.setInsertionPoint(contractOp);
412 auto loc = contractOp.getLoc();
413 auto castAcc = vector::ShapeCastOp::create(
414 rewriter, loc,
415 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
416 contractOp.getAcc());
417
418 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
419 ? contractOp.getRhsType()
420 : contractOp.getLhsType();
421 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
422 : contractOp.getRhsType();
423
424 Value dp;
425
426 auto castNonUnitDim = vector::ShapeCastOp::create(
427 rewriter, loc,
428 VectorType::get(blockingFactor * nonUnitDimValue,
429 nonUnitDimTy.getElementType()),
430 nonUnitDimOperand);
431
432 auto castUnitDim = vector::ShapeCastOp::create(
433 rewriter, loc,
434 VectorType::get(blockingFactor, unitDimTy.getElementType()),
435 unitDimOperand);
436 auto bitcastUnitDim = vector::BitCastOp::create(
437 rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
438 castUnitDim);
439 auto broadcastUnitDim = vector::BroadcastOp::create(
440 rewriter, loc,
441 VectorType::get({nonUnitDimValue}, rewriter.getIntegerType(32)),
442 bitcastUnitDim);
443 auto bitcastUnitDimPkType = vector::BitCastOp::create(
444 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
445
446 if (lhsTy.getElementType().isBF16()) {
447 dp = x86::avx512::DotBF16Op::create(
448 rewriter, loc,
449 VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
450 bitcastUnitDimPkType, castNonUnitDim);
451 }
452
453 if (lhsTy.getElementType().isSignlessInteger(8)) {
454 if (nonUnitDimAcc.front() == 16) {
455 dp = x86::avx10::AVX10DotInt8Op::create(
456 rewriter, loc,
457 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
458 castAcc, bitcastUnitDimPkType, castNonUnitDim);
459 } else {
460 dp = x86::avx::DotInt8Op::create(
461 rewriter, loc,
462 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
463 castAcc, bitcastUnitDimPkType, castNonUnitDim);
464 }
465 }
466
467 if (!dp)
468 return failure();
469
470 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
471 rewriter.replaceOp(contractOp, castDp);
472 return success();
473 }
474};
475
476} // namespace
477
479 RewritePatternSet &patterns) {
480 patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
481}
return success()
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
FloatType getF32Type()
Definition Builders.cpp:47
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:292
Operation * traceToVectorWriteLikeUserOperation(Value v)
Definition X86Utils.cpp:194
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Definition X86Utils.cpp:42
Operation * traceToVectorReadLikeParentOperation(Value v)
Definition X86Utils.cpp:154
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:242
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Definition X86Utils.cpp:352
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...