MLIR 23.0.0git
VectorContractToAMXDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToAMXDotProduct.cpp ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15
17#include "mlir/IR/Dominance.h"
19#include "llvm/Support/Casting.h"
20
21#include "mlir/Pass/Pass.h"
23
24using namespace mlir;
25using namespace mlir::vector;
26using namespace mlir::x86;
27
28namespace {
29
30// Recursively follows single-use values through scf.yield operations
31// and returns the first non-yield user result in the contraction chain.
32static Value contractionUsersAfterYield(Value v) {
33 if (v.getNumUses() != 1)
34 return nullptr;
35
36 OpOperand &use = *v.use_begin();
37 Operation *user = use.getOwner();
38
39 if (!isa<scf::YieldOp>(user))
40 return v;
41
42 auto yield = cast<scf::YieldOp>(user);
43 Operation *parent = yield->getParentOp();
44 unsigned idx = use.getOperandNumber();
45
46 return contractionUsersAfterYield(parent->getResult(idx));
47}
48
49// Function to collapse the last two dimension (vnni and k) to help the
50// amx.tile_load to correctly load the packed element type.
52 Value input) {
53 ShapedType inputType = cast<ShapedType>(input.getType());
54 int64_t firstDimToCollapse = inputType.getRank() - 2;
55
56 if (inputType.getRank() == 1)
57 return input;
58
60 for (int64_t i = 0; i < firstDimToCollapse; ++i)
61 reassociation.push_back(ReassociationIndices{i});
62
63 ReassociationIndices collapsedIndices;
64 for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
65 collapsedIndices.push_back(i);
66
67 reassociation.push_back(collapsedIndices);
68 return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
69}
70
71// Get the MemRef source and offset index for the operands of
72// vector.contract.
73static FailureOr<std::pair<Value, SmallVector<Value>>>
74getSrcIndxValue(OpBuilder &rewriter, Location loc, Value operand,
75 bool isNotAcc) {
76 Operation *defOp = operand.getDefiningOp();
77 if (!defOp)
78 return failure();
79
80 Value srcBuff;
83 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
84 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
85 readOp.getIndices().end());
86 srcBuff = readOp.getOperand(0);
87 });
88
89 if (!srcBuff)
90 return failure();
91
92 if (isNotAcc)
93 indexVals.pop_back();
94
96 indices.reserve(indexVals.size());
97
98 for (OpFoldResult ofr : indexVals) {
99 indices.push_back(
100 mlir::getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
101 }
102
103 if (isNotAcc) {
104 srcBuff = collapseInnerDims(rewriter, loc, srcBuff);
105 }
106
107 return std::make_pair(srcBuff, indices);
108}
109
110// Function to validate the loop step value.
111static LogicalResult validateLoopStep(OpBuilder &rewriter, Value step,
112 int64_t value) {
113
114 auto cst = step.getDefiningOp<arith::ConstantIndexOp>();
115 if (!cst)
116 return failure();
117
118 if (cst.value() != value && cst.value() != 1)
119 return failure();
120
121 return success();
122}
123
124// Function to validate the vector.contract operation.
125static LogicalResult validateContractOps(OpBuilder &rewriter,
126 vector::ContractionOp contractOp,
127 unsigned int blockingFactor,
128 Value srcBuffLhs, Value srcBuffRhs,
129 bool srcValidate) {
130
131 if (srcValidate) {
132 // Get the MemRef buffer of LHS operand.
133 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
134 contractOp.getLhs(), false);
135 if (failed(srcIndxLhs))
136 return failure();
137 auto [buffLhs, indicesLhs] = *srcIndxLhs;
138
139 // Get the MemRef buffer of RHS operand.
140 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
141 contractOp.getRhs(), false);
142 if (failed(srcIndxRhs))
143 return failure();
144 auto [buffRhs, indicesRhs] = *srcIndxRhs;
145
146 // Return failure if the Memref buff didn't match.
147 if (buffLhs != srcBuffLhs)
148 return failure();
149
150 if (buffRhs != srcBuffRhs)
151 return failure();
152 }
153
154 if (!contractionUsersAfterYield(contractOp.getResult()))
155 return failure();
156
157 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
158 if (!accTy)
159 return failure();
160
161 // The Accumulator dims should be 16 or 1. Like <1x16x16> or <16x16>.
162 ArrayRef<int64_t> accShape = accTy.getShape();
163 llvm::SmallVector<int64_t> nonUnitDimAcc;
164 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
165 [](int64_t dim) { return (dim != 16 && dim != 1); });
166
167 if (nonUnitDimAcc.size() != 0)
168 return failure();
169
170 // The LHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
171 // <16x16x4>. The vnni dims should be 2 or 4.
172 VectorType lhsTy = contractOp.getLhsType();
173 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
174 llvm::SmallVector<int64_t> nonUnitDimLhs;
175 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
176 [](int64_t dim) { return (dim != 16 && dim != 1); });
177
178 if (nonUnitDimLhs.size() != 1)
179 return failure();
180
181 if (nonUnitDimLhs[0] != blockingFactor)
182 return failure();
183
184 // The RHS dims should be 16 or vnni or 1. Like <1x16x16x2> or
185 // <16x16x4>. The vnni dims should be 2 or 4.
186 VectorType rhsTy = contractOp.getRhsType();
187 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
188 llvm::SmallVector<int64_t> nonUnitDimRhs;
189 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
190 [](int64_t dim) { return (dim != 16 && dim != 1); });
191
192 if (nonUnitDimRhs.size() != 1)
193 return failure();
194
195 if (nonUnitDimRhs[0] != blockingFactor)
196 return failure();
197
198 return success();
199}
200
201// Returns the loop index position to get mapped during the
202// MemRef type clone.
203static unsigned getIndexPosition(Value operand, scf::ForOp loop) {
204 Value iv = loop.getInductionVar();
205
206 Value srcBuff;
208 .Case<TransferReadOp, LoadOp>(
209 [&](auto readOp) { srcBuff = readOp.getOperand(0); });
210
211 auto subview = srcBuff.getDefiningOp<memref::SubViewOp>();
212 if (!subview)
213 return 0;
214
215 auto offsets = subview.getOffsets();
216
217 for (auto it : llvm::enumerate(offsets)) {
218 if (it.value() == iv)
219 return it.index();
220 }
221
222 return 0;
223}
224
225// Creates amx.tile_loads.
226static amx::TileLoadOp createTileLoads(OpBuilder &rewriter, Location loc,
227 Value operand, Value mat, Type ipType,
228 bool rhs, unsigned int offset,
229 bool isVnni) {
230
231 auto srcIndx = getSrcIndxValue(rewriter, loc, operand, false);
232 auto [srcBuff, indices] = *srcIndx;
233 if (isVnni) {
234 indices.pop_back();
235 }
236
237 if (rhs && isVnni) {
238 auto cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
239 indices[indices.size() - 1] = arith::MulIOp::create(
240 rewriter, loc, indices[indices.size() - 1], cOffset);
241 }
242
243 amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
244 return amx::TileLoadOp::create(rewriter, loc, tileType, mat, indices);
245}
246
247static void performShuffle(OpBuilder &rewriter, Location loc, Value matB,
248 Type ipType, unsigned int offset, Value packedBuffer,
249 Value indxToStoreInBuffer) {
250
251 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
252 Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
253 SmallVector<Value> subviewOffset(
254 llvm::cast<MemRefType>(matB.getType()).getRank(), c0);
255
256 Value cStep = arith::ConstantIndexOp::create(rewriter, loc, offset);
257 Value cBound = arith::ConstantIndexOp::create(rewriter, loc, (16 * offset));
258 Value offsetIndx =
259 arith::ConstantIndexOp::create(rewriter, loc, (offset / 2));
260
261 scf::ForOp::create(
262 rewriter, loc, c0, cBound, cStep, ValueRange{},
263 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
264 ValueRange iterArgs) {
265 subviewOffset[subviewOffset.size() - 2] = iv;
266
267 // Retrieve two rows of vector (32) for int8 and f8 type. For bf16,
268 // retrieve one row of vector (32).
269 auto vectorType = VectorType::get({2, (16 * (offset / 2))}, ipType);
270 if (ipType.isBF16())
271 vectorType = VectorType::get((16 * offset), ipType);
272
273 int64_t srcRank = (dyn_cast<ShapedType>(matB.getType())).getRank();
274 Value padding = ub::PoisonOp::create(rewriter, loc, ipType);
275 auto map = AffineMap::getMinorIdentityMap(srcRank, vectorType.getRank(),
276 rewriter.getContext());
277 SmallVector<bool> inBounds(vectorType.getRank(), true);
278 Value vec1 = vector::TransferReadOp::create(
279 rewriter, loc, vectorType, matB, ValueRange(subviewOffset), padding,
280 map, inBounds);
281
282 if (!ipType.isBF16())
283 vec1 = vector::ShapeCastOp::create(
284 rewriter, loc, VectorType::get((16 * offset), ipType), vec1);
285
286 // Increment the iv by 1 or 2 based on the type to load the next 32/64
287 // elements
288 Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
289 subviewOffset[subviewOffset.size() - 2] = incIV;
290
291 Value vec2 = vector::TransferReadOp::create(
292 rewriter, loc, vectorType, matB, ValueRange(subviewOffset), padding,
293 map, inBounds);
294 if (!ipType.isBF16())
295 vec2 = vector::ShapeCastOp::create(
296 rewriter, loc, VectorType::get((16 * offset), ipType), vec2);
297
298 vector::ShuffleOp shuffle1;
299 vector::ShuffleOp shuffle2;
300
301 if (ipType.isBF16()) {
302
303 shuffle1 = vector::ShuffleOp::create(
304 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
305 vec2,
306 ArrayRef<int64_t>{0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9,
307 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
308 19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
309
310 shuffle2 = vector::ShuffleOp::create(
311 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
312 vec2,
313 ArrayRef<int64_t>{4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13,
314 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
315 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
316 }
317
318 if (ipType.isSignlessInteger(8) || ipType.isF8E5M2() ||
319 ipType.isF8E4M3FN()) {
320
321 shuffle1 = vector::ShuffleOp::create(
322 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
323 vec2,
325 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3,
326 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42,
327 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81,
328 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120,
329 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123});
330
331 shuffle2 = vector::ShuffleOp::create(
332 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
333 vec2,
335 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39,
336 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110,
337 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54,
338 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125,
339 30, 62, 94, 126, 31, 63, 95, 127});
340 }
341
342 // iv to store the shuffled elements
343 Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
344 Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
345
346 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
347 ValueRange{indxToStoreInBuffer, ivShuff1, c0});
348 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
349 ValueRange{indxToStoreInBuffer, ivShuff2, c0});
350
351 scf::YieldOp::create(nestedBuilder, loc);
352 });
353}
354
356packInputs(OpBuilder &rewriter, Location loc,
358 unsigned int offset, Value packedBuffer, bool pack,
359 Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
360
362 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
363 Value c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
364
365 for (size_t j = 0; j < ops.size(); j++) {
366 for (size_t i = 0; i < ops.size(); i++) {
367
368 if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16)) {
369
370 Operation *readOpRhs = ops[j].getRhs().getDefiningOp();
371 auto itRhs = readsToTileLoads.find(readOpRhs);
372 if (itRhs != readsToTileLoads.end()) {
373 continue;
374 }
375
376 if (pack) {
377 performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
378 indxToStoreInBuffer);
379 }
380
381 amx::TileType tileType =
382 amx::TileType::get({16, (16 * offset)}, ipType);
383 auto loadRow1 =
384 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
385 ValueRange{indxToLoadFromMatB, c0, c0});
386
387 auto loadRow2 =
388 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
389 ValueRange{indxToLoadFromMatB, c16, c0});
390
391 readsToTileLoads.try_emplace(readOpRhs, loadRow1);
392 readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
393 }
394 }
395 }
396
397 return readsToTileLoads;
398}
399
400// Creates tiled amx dot-products.
402createTiledDp(OpBuilder &rewriter, Location loc,
404 Type ipType, Type opType, ValueRange accIterArgs,
405 unsigned int offset, bool isVnni, Value packedBuffer, bool pack,
406 Value indxToStoreInBuffer, Value indxToLoadFromMatB) {
407
408 if (isVnni) {
409 matA = collapseInnerDims(rewriter, loc, matA);
410 matB = collapseInnerDims(rewriter, loc, matB);
411 }
412
413 SmallVector<Value> accumulators;
414 // Stores the amx.tile_load operation vs it's equivalent vector tranfer_read
415 // or load operations.
417
418 // function call to online pack the input B matrix
419 if (!isVnni) {
420 readsToTileLoads =
421 packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
422 indxToStoreInBuffer, indxToLoadFromMatB);
423 }
424
425 // Iterate over the contraction operations and compute the tiled dot-product.
426 for (size_t i = 0; i < ops.size(); i++) {
427
428 Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
429 amx::TileLoadOp tilesLhs;
430 auto itLhs = readsToTileLoads.find(readOpLhs);
431 if (itLhs != readsToTileLoads.end()) {
432 tilesLhs = itLhs->second;
433 } else {
434 tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
435 false, offset, isVnni);
436 readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
437 }
438
439 Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
440 amx::TileLoadOp tilesRhs;
441 auto itRhs = readsToTileLoads.find(readOpRhs);
442 if (itRhs != readsToTileLoads.end()) {
443 tilesRhs = itRhs->second;
444 } else {
445 tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
446 true, offset, isVnni);
447 readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
448 }
449
450 auto accTileType = amx::TileType::get({16, 16}, opType);
451
452 Value dp;
453 if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN())
454 dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
455 tilesRhs, accIterArgs[i]);
456
457 if (ipType.isSignlessInteger(8))
458 dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
459 tilesRhs, accIterArgs[i]);
460
461 accumulators.push_back(dp);
462 }
463 return accumulators;
464}
465
466static SmallVector<Value> createTileZeros(OpBuilder &rewriter, Location loc,
467 Type opType, scf::ForOp outerLoop,
468 int64_t size) {
469 rewriter.setInsertionPoint(outerLoop);
470
471 SmallVector<Value> loopItrArgs;
472 auto zeroTileType = amx::TileType::get({16, 16}, opType);
473
474 for (int i = 0; i < size; i++) {
475 auto zeroTile = amx::TileZeroOp::create(rewriter, loc, zeroTileType);
476 loopItrArgs.push_back(zeroTile);
477 }
478 return loopItrArgs;
479}
480
481static Value getIndxToLoadStoreFromPckBuffer(
482 OpBuilder &rewriter, Location loc, Value ivInnerLoop, Value ivOuterLoop,
483 bool isInnerLoopUBHasOddQuot, bool isInnerLoopUBLarger, bool pack,
484 unsigned int blockingFactor) {
485
486 Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
487 Value packOffset =
488 arith::ConstantIndexOp::create(rewriter, loc, (16 * blockingFactor));
489
490 Value quotientInnerLoop =
491 arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
492 Value remInnerLoop = arith::RemUIOp::create(
493 rewriter, loc, rewriter.getIndexType(), quotientInnerLoop, c2);
494
495 if (!isInnerLoopUBLarger && !pack) {
496 remInnerLoop = arith::RemUIOp::create(
497 rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
498 }
499
500 if (isInnerLoopUBHasOddQuot) {
501 auto remOuterLoop = arith::RemUIOp::create(
502 rewriter, loc, rewriter.getIndexType(), ivOuterLoop, c2);
503 auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
504 remInnerLoop, remOuterLoop);
505 remInnerLoop = arith::RemUIOp::create(rewriter, loc,
506 rewriter.getIndexType(), remAdd, c2);
507 }
508
509 return remInnerLoop;
510}
511
512static scf::ForOp
513createLoops(OpBuilder &rewriter, Location loc, Value lowerBound,
514 Value upperBound, Value step, SmallVector<Value> loopItrArgs,
515 Type ipType, Type opType, unsigned int blockingFactor, bool isVnni,
516 Operation *vectorOpLhs, Operation *vectorOpRhs,
517 vector::ContractionOp contractOp, scf::ForOp outerLoop,
518 scf::ForOp innerLoop, SmallVector<vector::ContractionOp> ops,
519 Value ivOuterLoop, Value packedBuffer, bool pack,
520 arith::ConstantIndexOp innerLoopIndex, bool isInnerLoopUBLarger,
521 bool isInnerLoopUBHasOddQuot) {
522
523 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
524 Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
525 Value c2 = arith::ConstantIndexOp::create(rewriter, loc, 2);
526
527 int64_t offset = 16 * blockingFactor;
528 if (auto cst = step.getDefiningOp<arith::ConstantIndexOp>())
529 offset = cst.value();
530
531 auto newLoop = scf::ForOp::create(
532 rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
533 [&](OpBuilder &rewriterNewInnerLoop, Location locNewInnerLoop,
534 Value ivNewInnerLoop, ValueRange iterArgsNewInnerLoop) {
535 IRMapping mapping;
536 if (outerLoop)
537 mapping.map(vectorOpLhs->getOperand(
538 getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
539 ivOuterLoop);
540
541 mapping.map(vectorOpLhs->getOperand(
542 getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
543 ivNewInnerLoop);
544 auto lhsClone = rewriterNewInnerLoop.clone(*vectorOpLhs, mapping);
545
546 Value indxToStoreInBuffer = c0;
547 Value indxToLoadFromBuffer = c0;
548 if (!isVnni) {
549 if (outerLoop) {
550 if (innerLoopIndex.value() == 0) {
551 if (pack) {
552 ivNewInnerLoop = c0;
553 ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
554 c1, ivOuterLoop);
555
556 if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
557 indxToStoreInBuffer = arith::RemUIOp::create(
558 rewriter, locNewInnerLoop, rewriter.getIndexType(),
559 ivOuterLoop, c2);
560 }
561
562 Value indxToLoadFromMatB = arith::AddIOp::create(
563 rewriter, loc, indxToStoreInBuffer, c1);
564 indxToLoadFromBuffer = arith::RemUIOp::create(
565 rewriter, loc, rewriter.getIndexType(), indxToLoadFromMatB,
566 c2);
567 }
568
569 } else {
571 rewriter, locNewInnerLoop, offset);
572 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
573 nLoadIndx, ivNewInnerLoop);
574 indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
575 rewriter, loc, ivNewInnerLoop, ivOuterLoop,
576 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
577 blockingFactor);
578 Value indxToLoadFromMatB =
579 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
580 indxToLoadFromBuffer =
581 arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
582 indxToLoadFromMatB, c2);
583 }
584 } else {
585 if (pack) {
587 rewriter, locNewInnerLoop, offset);
588 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
589 nLoadIndx, ivNewInnerLoop);
590 Value quotient_K = arith::DivUIOp::create(
591 rewriter, loc, ivNewInnerLoop, nLoadIndx);
592 indxToStoreInBuffer = arith::RemUIOp::create(
593 rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
594
595 Value indxToLoadFromMatB =
596 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
597 indxToLoadFromBuffer =
598 arith::RemUIOp::create(rewriter, loc, rewriter.getIndexType(),
599 indxToLoadFromMatB, c2);
600 }
601 }
602 }
603 IRMapping rhsMapping;
604
605 Value matB;
606 Operation *rhsOp = vectorOpRhs;
607
608 // Clone for the subview type operations
609 if (rhsOp->getNumOperands() > 0) {
610
611 if (outerLoop) {
612 int64_t outerPos = getIndexPosition(contractOp.getRhs(), outerLoop);
613
614 if (outerPos >= 0) {
615 unsigned operandIdx = static_cast<unsigned>(outerPos + 1);
616
617 if (operandIdx < rhsOp->getNumOperands())
618 rhsMapping.map(rhsOp->getOperand(operandIdx), ivOuterLoop);
619 }
620 }
621
622 int64_t innerPos = getIndexPosition(contractOp.getRhs(), innerLoop);
623
624 if (innerPos >= 0) {
625 unsigned operandIdx = static_cast<unsigned>(innerPos + 1);
626
627 if (operandIdx < rhsOp->getNumOperands())
628 rhsMapping.map(rhsOp->getOperand(operandIdx), ivNewInnerLoop);
629 }
630
631 auto rhsClone = rewriterNewInnerLoop.clone(*rhsOp, rhsMapping);
632 matB = rhsClone->getResult(0);
633
634 } else {
635 // The mat B is of kind 'memref.get_global @__constant'
636 matB = rhsOp->getResult(0);
637 }
638
639 if (!isVnni) {
640 if (outerLoop) {
641 if (!pack) {
643 rewriter, locNewInnerLoop, offset);
644 matB = Value();
645 indxToLoadFromBuffer = c0;
646 indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
647 rewriter, loc, nLoadIndx, ivOuterLoop,
648 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
649 blockingFactor);
650 }
651 } else {
652 if (!pack) {
654 rewriter, locNewInnerLoop, offset);
655 matB = Value();
656 Value quotient_K = arith::DivUIOp::create(
657 rewriter, loc, ivNewInnerLoop, nLoadIndx);
658 indxToLoadFromBuffer = arith::RemUIOp::create(
659 rewriter, loc, rewriter.getIndexType(), quotient_K, c2);
660 }
661 }
662 }
663 // compute tiled dot-product
664 SmallVector<Value> accumulators = createTiledDp(
665 rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
666 ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
667 packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);
668
669 scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
670 accumulators);
671 });
672
673 return newLoop;
674}
675
676// Implements tiled dot-product operation for a vector.contract operation or a
677// sequence of vector.contracts inside the reduction loops.
678//
679// For example:
680// Case 1: register blocked vector.contract with prepacked input
681// ```
682// vector.transfer_read %arg0 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
683// vector.transfer_read %arg1 {{.}*} : memref<16x32x4xi8>, vector<16x16x4xi8>
684// vector.contract <16x16x4xi8>, <16x16x4xi8> into <16x16xi32>
685// vector.transfer_write arg2 {{.}*} : vector<16x16xi32>, memref<32x32xi32>
686// ```
687// to
688// ```
689// amx.tile_load %arg0 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
690// amx.tile_load %arg1 {{.}*} : memref<16x32x4xi8> into !amx.tile<16x64xi8>
691// amx.tile_muli !amx.tile<16x64xi8> -> !amx.tile<16x16xi32>
692// amx.tile_store %arg2{{.}*} : memref<32x32xi32>, !amx.tile<16x16xi32>
693// ```
694//
695//
696// Case2: vector.contract with register blocked
697//
698// Output IR with online packing (with s/w pipeline advantage):
699// s/w pipeline: load, pack to VNNI, and store the B sub matrix
700// of the 0th batch-reduce and K iteration.
701// scf.for (0 to 31) {
702// - load 0th and 1st vector<32xbf16>, pack into VNNI, store the
703// first shuffle in 0th and 2nd shuffle in 16th index of the
704// buffer.
705// }
706// scf.for (0 to br-2) { batch-reduce loop
707// scf.for (0 to k-2) { K loop
708// - load A matrix
709// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
710// matrix for the next K loop iteration (c) load VNNI pack B matrix of K
711// iteration from the buffer (d) compute the tiled dot-product
712// }
713// Last iteration of the K Loop (k-1) {
714// - load A matrix
715// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
716// matrix for the next batch-reduce + K loop iteration (c) load VNNI pack B
717// matrix of K iteration from the buffer (d) compute the tiled dot-product
718// }
719// }
720// Last iteration of the batch-reduce loop (br-1) {
721// scf.for (0 to k-2) { K loop
722// - load A matrix
723// - scf.loop for s/w pipeline: load, pack to VNNI, and store the B sub
724// matrix for the next K loop iteration (c) load VNNI pack B matrix of K
725// iteration from the buffer (d) compute the tiled dot-product
726// }
727// Last iteration of the K Loop (k-1) {
728// - load A matrix
729// - load VNNI pack B matrix of K iteration from the buffer
730// - compute the tiled dot-product
731// }
732// }
733//
734// scf.for (0 to M)
735// scf.for (0 to N)
736// - Load the ith and i+1th acc
737// - Shuffle them as we packed using vpunpack
738// - Load C matrix and do arith.add with the shuffle
739// - Store back into C matrix
740struct VectorContractToAMXDotProduct
741 : public OpRewritePattern<vector::ContractionOp> {
742 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
743
744 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
745 PatternRewriter &rewriter) const override {
746
747 if (contractOp.getKind() != vector::CombiningKind::ADD)
748 return rewriter.notifyMatchFailure(contractOp,
749 "Expects add combining kind.");
750
751 unsigned int blockingFactor =
752 contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
753 bool isVnni =
754 isInVnniLayout(contractOp.getOperation(),
755 contractOp.getIndexingMapsArray(), blockingFactor);
756
757 VectorType lhsTy = contractOp.getLhsType();
758 if (!lhsTy.getElementType().isBF16() &&
759 !lhsTy.getElementType().isSignlessInteger(8) &&
760 !lhsTy.getElementType().isF8E4M3FN() &&
761 !lhsTy.getElementType().isF8E5M2())
762 return rewriter.notifyMatchFailure(
763 contractOp, "Only BF16/Int8/F8 lowering is supported.");
764
765 if (lhsTy.getElementType() != contractOp.getRhsType().getElementType())
766 return rewriter.notifyMatchFailure(
767 contractOp, "Contraction should have same lhs and rhs type.");
768
769 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
770 if (!accTy)
771 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
772
773 if (((lhsTy.getElementType().isBF16() ||
774 lhsTy.getElementType().isF8E4M3FN() ||
775 lhsTy.getElementType().isF8E5M2()) &&
776 !accTy.getElementType().isF32()) ||
777 (lhsTy.getElementType().isSignlessInteger(8) &&
778 !accTy.getElementType().isSignlessInteger(32)))
779 return rewriter.notifyMatchFailure(contractOp,
780 "Only F32 for BF16 or Int32 for Int8 "
781 "accumulation type is supported.");
782
783 Operation *accReadOp =
784 traceToVectorReadLikeParentOperation(contractOp.getAcc());
785
786 Operation *resultWriteOp =
787 traceToVectorWriteLikeUserOperation(contractOp.getResult());
788
789 if (!accReadOp || !resultWriteOp)
790 return rewriter.notifyMatchFailure(
791 contractOp, "The ACC operand of the vector.contract should be a "
792 "transfer_read or a load. And, the result should be "
793 "stored using transfer_write or store.");
794
795 Type ipType = rewriter.getBF16Type();
796 Type opType = rewriter.getF32Type();
797
798 if (lhsTy.getElementType().isSignlessInteger(8)) {
799 ipType = rewriter.getIntegerType(8);
800 opType = rewriter.getIntegerType(32);
801 }
802
803 if (lhsTy.getElementType().isF8E4M3FN())
804 ipType = rewriter.getF8E4M3FNType();
805
806 if (lhsTy.getElementType().isF8E5M2())
807 ipType = rewriter.getF8E5M2Type();
808
809 if (accReadOp->getBlock() == contractOp->getBlock() &&
810 resultWriteOp->getBlock() != contractOp->getBlock())
811 return rewriter.notifyMatchFailure(
812 contractOp, "The accumulator store is in different block.");
813
814 if (accReadOp->getBlock() != contractOp->getBlock() &&
815 resultWriteOp->getBlock() == contractOp->getBlock())
816 return rewriter.notifyMatchFailure(
817 contractOp, "The accumulator read is in different block.");
818
819 unsigned int dimValue = blockingFactor;
820 if (!isVnni)
821 dimValue = 16 * blockingFactor;
822
823 // Case 1: For just one VC rewrite. Where all accumulator read/write
824 // within the same block.
825 if (accReadOp->getBlock() == contractOp->getBlock() &&
826 resultWriteOp->getBlock() == contractOp->getBlock()) {
827
828 bool collapse = false;
829 if (isVnni)
830 collapse = true;
831
832 LogicalResult validate = validateContractOps(
833 rewriter, contractOp, dimValue, Value(), Value(), false);
834
835 if (failed(validate))
836 return rewriter.notifyMatchFailure(
837 contractOp, "The contract operation doesn't satisfy the operands "
838 "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
839 "The rest dims should be 1. Op should have one user.");
840
841 Location loc = contractOp.getLoc();
842
843 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
844 contractOp.getLhs(), collapse);
845 if (failed(srcIndxLhs))
846 return rewriter.notifyMatchFailure(contractOp,
847 "The LHS src is not a MemRef type.");
848 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
849
850 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
851 contractOp.getRhs(), collapse);
852 if (failed(srcIndxRhs))
853 return rewriter.notifyMatchFailure(contractOp,
854 "The RHS src is not a MemRef type.");
855 auto rhsSrc = *srcIndxRhs;
856 auto srcBuffRhs = rhsSrc.first;
857 auto indicesRhs = rhsSrc.second;
858
859 auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
860 contractOp.getAcc(), false);
861 if (failed(srcIndxAcc))
862 return rewriter.notifyMatchFailure(contractOp,
863 "The ACC src is not a MemRef type.");
864 auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
865
866 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
867
868 // amx.tile_loads
869 auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
870 auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
871 srcBuffLhs, indicesLhs);
872
873 // Create the subview and then load.
874 amx::TileLoadOp loadRhs;
875 if (!isVnni) {
876 VectorType vecTy;
877 SmallVector<OpFoldResult> indexVals;
878 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
879 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
880 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
881 readOp.getIndices().end());
882 vecTy = readOp.getType();
883 });
884 auto one = rewriter.getIndexAttr(1);
885 SmallVector<OpFoldResult> strides(indexVals.size(), one);
886 SmallVector<OpFoldResult> sizes = getAsIndexOpFoldResult(
887 contractOp.getRhs().getDefiningOp()->getContext(),
888 vecTy.getShape());
889 auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
890 indexVals, sizes, strides);
891 auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
892 auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
893
894 // create a loop that does online packing.
895 Value step =
896 arith::ConstantIndexOp::create(rewriter, loc, blockingFactor);
897 Value uBound = arith::ConstantIndexOp::create(rewriter, loc,
898 (blockingFactor * 16));
899 Value nextLoadIndx =
900 arith::ConstantIndexOp::create(rewriter, loc, (blockingFactor / 2));
901 Value nextStoreIndx = arith::ConstantIndexOp::create(
902 rewriter, loc, 16 * (blockingFactor / 2));
903
904 scf::ForOp::create(
905 rewriter, loc, c0, uBound, step, ValueRange{},
906 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
907 ValueRange iterArgs) {
908 Value i1_load =
909 arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
910
911 indicesRhs[indicesRhs.size() - 2] = iv;
912 indicesRhs[indicesRhs.size() - 1] = c0;
913 ValueRange range1(indicesRhs);
914 auto vec1 = vector::LoadOp::create(
915 rewriter, loc,
916 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
917 range1);
918
919 indicesRhs[indicesRhs.size() - 2] = i1_load;
920 ValueRange range2(indicesRhs);
921 auto vec2 = vector::LoadOp::create(
922 rewriter, loc,
923 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
924 range2);
925
926 vector::ShuffleOp shuffle1;
927 vector::ShuffleOp shuffle2;
928
929 if (blockingFactor == 2) {
930
931 shuffle1 = vector::ShuffleOp::create(
932 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
933 ArrayRef<int64_t>{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21,
934 6, 22, 7, 23});
935
936 shuffle2 = vector::ShuffleOp::create(
937 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
938 ArrayRef<int64_t>{8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
939 29, 14, 30, 15, 31});
940 }
941
942 if (blockingFactor == 4) {
943 shuffle1 = vector::ShuffleOp::create(
944 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
945 ArrayRef<int64_t>{0, 16, 32, 48, 1, 17, 33, 49,
946 2, 18, 34, 50, 3, 19, 35, 51,
947 4, 20, 36, 52, 5, 21, 37, 53,
948 6, 22, 38, 54, 7, 23, 39, 55});
949
950 shuffle2 = vector::ShuffleOp::create(
951 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
952 ArrayRef<int64_t>{8, 24, 40, 56, 9, 25, 41, 57,
953 10, 26, 42, 58, 11, 27, 43, 59,
954 12, 28, 44, 60, 13, 29, 45, 61,
955 14, 30, 46, 62, 15, 31, 47, 63});
956 }
957
958 auto rem = arith::DivUIOp::create(
959 rewriter, loc, rewriter.getIndexType(), iv, step);
960
961 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
962 ValueRange{rem, c0});
963 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
964 ValueRange{rem, nextStoreIndx});
965
966 scf::YieldOp::create(nestedBuilder, loc);
967 });
968 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
969 ValueRange{c0, c0});
970 } else {
971
972 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, srcBuffRhs,
973 indicesRhs);
974 }
975
976 auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
977 auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
978 srcBuffAcc, indicesAcc);
979
980 // Tiled dot-product.
981 Value dp;
982 if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN())
983 dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
984 loadRhs, loadAcc);
985
986 if (ipType.isSignlessInteger(8))
987 dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
988 loadRhs, loadAcc);
989
990 auto bufferType = MemRefType::get({16, 16}, opType);
991 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
992
993 amx::TileStoreOp::create(rewriter, loc, resultBuffer, ValueRange{c0, c0},
994 dp);
995
996 auto vectorType = mlir::VectorType::get({16, 16}, opType);
997 int64_t srcRank =
998 (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
999 Value padding = ub::PoisonOp::create(rewriter, loc, opType);
1000 auto map = AffineMap::getMinorIdentityMap(srcRank, vectorType.getRank(),
1001 rewriter.getContext());
1002 SmallVector<bool> inBounds(vectorType.getRank(), true);
1003
1004 Value vecRow = vector::TransferReadOp::create(
1005 rewriter, loc, vectorType, resultBuffer, ValueRange{c0, c0}, padding,
1006 map, inBounds);
1007
1008 Value resultOp = contractionUsersAfterYield(contractOp.getResult());
1009 if (auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType()))
1010 vecRow = vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
1011
1012 rewriter.replaceAllUsesWith(resultOp, vecRow);
1013 return success();
1014 }
1015
1016 // Case 2: The accumulators are passed as iter args through the
1017 // reduction loop. Reduction loop depths of up to 2 are supported.
1018 // TODO: Support n-depth reduction loops.
1019 // TODO: Refactor cases 2a and 2b.
1020 SmallVector<scf::ForOp> loopLists;
1021 Operation *current = contractOp;
1022 while (true) {
1023 Operation *parent = current->getParentOfType<scf::ForOp>();
1024
1025 if (!parent)
1026 return rewriter.notifyMatchFailure(
1027 contractOp,
1028 "Accumulator read and contract op not within scf.for op");
1029
1030 loopLists.push_back(dyn_cast<scf::ForOp>(parent));
1031
1032 if (accReadOp->getBlock() == parent->getBlock()) {
1033 break;
1034 }
1035
1036 current = parent;
1037 }
1038 if (loopLists.size() > 2 || loopLists.size() == 0)
1039 return rewriter.notifyMatchFailure(
1040 contractOp, "Rewrite is supported until reduction loop depth of 2.");
1041
1042 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
1043 contractOp.getLhs(), false);
1044 if (failed(srcIndxLhs))
1045 return rewriter.notifyMatchFailure(contractOp,
1046 "The LHS src is not a MemRef type.");
1047 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
1048
1049 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
1050 contractOp.getRhs(), false);
1051 if (failed(srcIndxRhs))
1052 return rewriter.notifyMatchFailure(contractOp,
1053 "The RHS src is not a MemRef type.");
1054 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
1055 Operation *vectorOpLhs;
1056 llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
1057 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
1058 vectorOpLhs = readOp.getBase().getDefiningOp();
1059 });
1060
1061 Operation *vectorOpRhs;
1062 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
1063 .Case<TransferReadOp, LoadOp>([&](auto readOp) {
1064 vectorOpRhs = readOp.getBase().getDefiningOp();
1065 });
1066
1067 if (!vectorOpLhs || !vectorOpRhs)
1068 return rewriter.notifyMatchFailure(
1069 contractOp, "Failed to find LHS or RHS read source operation");
1070
1071 // Retrieve all the contraction operations within the loop.
1072 SmallVector<vector::ContractionOp> ops;
1073 for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
1074
1075 if (auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
1076
1077 LogicalResult validate = validateContractOps(
1078 rewriter, contract, dimValue, srcBuffLhs, srcBuffRhs, true);
1079
1080 if (failed(validate))
1081 return rewriter.notifyMatchFailure(
1082 contractOp,
1083 "The associated contract operations doesn't satisfy "
1084 "the re-write conditions either the dimensions are "
1085 "wrong or MemRef source are different or many users.");
1086
1087 ops.push_back(contract);
1088 }
1089 }
1090
1091 if (!isVnni) {
1092 unsigned int pairCount = 0;
1093 for (size_t j = 0; j < ops.size(); j++) {
1094 for (size_t i = j; i < ops.size(); i++) {
1095 if (i != j && validatePairVectorContract(ops[j], ops[i], true, 16))
1096 pairCount = pairCount + 2;
1097 }
1098 }
1099
1100 if (pairCount != ops.size())
1101 return rewriter.notifyMatchFailure(
1102 contractOp, "Coudn't find the pair vector contract ");
1103 }
1104
1105 scf::ForOp innerLoop;
1106 scf::ForOp outerLoop;
1107
1108 scf::ForOp newLoop;
1109 // Case 2a: Reduction loop depth is 2.
1110 if (loopLists.size() == 2) {
1111 outerLoop = loopLists[1];
1112 innerLoop = loopLists[0];
1113
1114 LogicalResult validateOuterLoopStep =
1115 validateLoopStep(rewriter, outerLoop.getStep(), 1);
1116 if (failed(validateOuterLoopStep))
1117 return rewriter.notifyMatchFailure(contractOp, "Invalid loop step.");
1118
1119 int64_t stepValue = 16;
1120 if (!isVnni)
1121 stepValue = stepValue * blockingFactor;
1122 LogicalResult validateInnerLoopStep =
1123 validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
1124 if (failed(validateInnerLoopStep))
1125 return rewriter.notifyMatchFailure(
1126 contractOp, "Invalid loop step. The step should be 32 for BF16 and "
1127 "64 for Int8/F8.");
1128
1129 SmallVector<Value> loopItrArgs = createTileZeros(
1130 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
1131
1132 if (isVnni) {
1133 newLoop = scf::ForOp::create(
1134 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1135 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
1136 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1137 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
1138 auto newInnerLoop = createLoops(
1139 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1140 innerLoop.getUpperBound(), innerLoop.getStep(),
1141 iterArgsOuterLoop, ipType, opType, blockingFactor, isVnni,
1142 vectorOpLhs, vectorOpRhs, contractOp, outerLoop, innerLoop,
1143 ops, ivOuterLoop, nullptr, true, nullptr, false, false);
1144
1145 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1146 newInnerLoop.getResults());
1147 });
1148
1149 } else {
1150
1151 bool isInnerLoopUBLarger = false;
1152 bool isInnerLoopUBHasOddQuot = false;
1153
1154 int64_t ubVal = 16 * blockingFactor;
1155 mlir::Value ub = innerLoop.getUpperBound();
1156 if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
1157 if (auto intAttr =
1158 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1159 ubVal = intAttr.getInt();
1160 }
1161 }
1162
1163 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1164 isInnerLoopUBHasOddQuot =
1165 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1166
1167 rewriter.setInsertionPoint(outerLoop);
1168
1169 auto c0 =
1170 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 0);
1171 auto c1 =
1172 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 1);
1173 auto spillLoopBound = arith::ConstantIndexOp::create(
1174 rewriter, outerLoop.getLoc(), 16 * blockingFactor);
1175
1176 Value spillOuterLoop = arith::SubIOp::create(
1177 rewriter, outerLoop.getLoc(), outerLoop.getUpperBound(), c1);
1178 Value spillInnerLoop =
1179 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1180 innerLoop.getUpperBound(), spillLoopBound);
1181 auto bufferType =
1182 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1183 auto packedBuffer =
1184 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
1185
1186 // First Shuffling outside the reduction loops
1187 IRMapping rhsMapping;
1188 rhsMapping.map(
1189 vectorOpRhs->getOperand(
1190 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
1191 outerLoop.getLowerBound());
1192 rhsMapping.map(
1193 vectorOpRhs->getOperand(
1194 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1195 innerLoop.getLowerBound());
1196 auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
1197
1198 Value quotient_batch = arith::DivUIOp::create(
1199 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1200 outerLoop.getStep());
1201 Value quotient_k = arith::DivUIOp::create(rewriter, outerLoop.getLoc(),
1202 innerLoop.getLowerBound(),
1203 innerLoop.getStep());
1204
1205 Value quotient_add = arith::AddIOp::create(rewriter, outerLoop.getLoc(),
1206 quotient_batch, quotient_k);
1207 Value c2 =
1208 arith::ConstantIndexOp::create(rewriter, outerLoop.getLoc(), 2);
1209 Value rem = arith::RemUIOp::create(rewriter, outerLoop.getLoc(),
1210 quotient_add, c2);
1211
1212 performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
1213 ipType, blockingFactor, packedBuffer, rem);
1214
1215 // First Set of Loops
1216 auto newLoopNonSpill = scf::ForOp::create(
1217 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1218 spillOuterLoop, outerLoop.getStep(), loopItrArgs,
1219 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1220 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
1221 auto newInnerLoop1 = createLoops(
1222 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1223 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1224 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1225 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1226 ivOuterLoop, packedBuffer, true, spillLoopBound,
1227 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1228
1229 auto newInnerLoop = createLoops(
1230 rewriter, innerLoop.getLoc(), spillInnerLoop,
1231 innerLoop.getUpperBound(), innerLoop.getStep(),
1232 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1233 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1234 innerLoop, ops, ivOuterLoop, packedBuffer, true, c0,
1235 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1236
1237 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1238 newInnerLoop.getResults());
1239 });
1240
1241 // Last set of Loops
1242 newLoop = scf::ForOp::create(
1243 rewriter, outerLoop.getLoc(), spillOuterLoop,
1244 outerLoop.getUpperBound(), outerLoop.getStep(),
1245 newLoopNonSpill.getResults(),
1246 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1247 Value ivOuterLoop, ValueRange iterArgsOuterLoop) {
1248 auto newInnerLoop1 = createLoops(
1249 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1250 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1251 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1252 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1253 ivOuterLoop, packedBuffer, true, spillLoopBound,
1254 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1255
1256 auto newInnerLoop = createLoops(
1257 rewriter, innerLoop.getLoc(), spillInnerLoop,
1258 innerLoop.getUpperBound(), innerLoop.getStep(),
1259 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1260 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1261 innerLoop, ops, ivOuterLoop, packedBuffer, false, c0,
1262 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1263
1264 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1265 newInnerLoop.getResults());
1266 });
1267 }
1268 }
1269
1270 // Case 2b: Reduction loop depth is 1.
1271 if (loopLists.size() == 1) {
1272
1273 innerLoop = loopLists[0];
1274 int64_t stepValue = 16;
1275 if (!isVnni)
1276 stepValue = stepValue * blockingFactor;
1277
1278 LogicalResult validateInnerLoopStep =
1279 validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
1280 if (failed(validateInnerLoopStep))
1281 return rewriter.notifyMatchFailure(
1282 contractOp,
1283 "Invalid loop step. The step should be 32 for BF16 and "
1284 "64 for Int8/F8 or 1 if it is rduction loop other than K.");
1285
1286 SmallVector<Value> loopItrArgs = createTileZeros(
1287 rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
1288
1289 if (isVnni) {
1290 newLoop = createLoops(
1291 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1292 innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
1293 opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1294 contractOp, nullptr, innerLoop, ops, nullptr, nullptr, true,
1295 nullptr, false, false);
1296
1297 } else {
1298
1299 bool isInnerLoopUBLarger = false;
1300 bool isInnerLoopUBHasOddQuot = false;
1301
1302 int64_t ubVal = 16 * blockingFactor;
1303 mlir::Value ub = innerLoop.getUpperBound();
1304 if (auto constOp = ub.getDefiningOp<mlir::arith::ConstantOp>()) {
1305 if (auto intAttr =
1306 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1307 ubVal = intAttr.getInt();
1308 }
1309 }
1310
1311 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1312 isInnerLoopUBHasOddQuot =
1313 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1314
1315 rewriter.setInsertionPoint(innerLoop);
1316
1317 auto c0 =
1318 arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 0);
1319 int64_t offset = 16 * blockingFactor;
1320 if (auto cst =
1321 innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>())
1322 offset = cst.value();
1323
1324 auto spillLoopBound = arith::ConstantIndexOp::create(
1325 rewriter, innerLoop.getLoc(), offset);
1326 Value spillInnerLoop =
1327 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1328 innerLoop.getUpperBound(), spillLoopBound);
1329
1330 auto bufferType =
1331 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1332 auto packedBuffer =
1333 memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
1334
1335 // First Shuffling outside the reduction loops
1336 IRMapping rhsMapping;
1337 rhsMapping.map(
1338 vectorOpRhs->getOperand(
1339 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1340 innerLoop.getLowerBound());
1341 auto rhsClone = rewriter.clone(*vectorOpRhs, rhsMapping);
1342
1343 Value quotient_k = arith::DivUIOp::create(rewriter, innerLoop.getLoc(),
1344 innerLoop.getLowerBound(),
1345 innerLoop.getStep());
1346 Value c2 =
1347 arith::ConstantIndexOp::create(rewriter, innerLoop.getLoc(), 2);
1348 Value rem = arith::RemUIOp::create(rewriter, innerLoop.getLoc(),
1349 quotient_k, c2);
1350
1351 performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
1352 ipType, blockingFactor, packedBuffer, rem);
1353
1354 auto newLoopNonSpill = createLoops(
1355 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1356 spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
1357 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs, contractOp,
1358 nullptr, innerLoop, ops, nullptr, packedBuffer, true,
1359 spillLoopBound, isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1360
1361 newLoop = createLoops(rewriter, innerLoop.getLoc(), spillInnerLoop,
1362 innerLoop.getUpperBound(), innerLoop.getStep(),
1363 newLoopNonSpill.getResults(), ipType, opType,
1364 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1365 contractOp, nullptr, innerLoop, ops, nullptr,
1366 packedBuffer, false, c0, isInnerLoopUBLarger,
1367 isInnerLoopUBHasOddQuot);
1368 }
1369
1370 // This lets the final store back to the acc use the same code for
1371 // both reduction loop depths (1 and 2).
1372 outerLoop = innerLoop;
1373 }
1374
1375 // Copy the amx tile accumulation results to a MemRef buffer, add the
1376 // initial accumulation value, and store back to the C-Matrix
1377 Location loc = outerLoop.getLoc();
1378 Value srcBuffAcc;
1379 SmallVector<Value> indicesAcc;
1380
1381 llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
1382 [&](auto readOp) {
1383 srcBuffAcc = readOp.getOperand(0);
1384
1385 auto indices = readOp.getIndices();
1386 indicesAcc.reserve(indices.size());
1387
1388 llvm::transform(indices, std::back_inserter(indicesAcc),
1389 [&](OpFoldResult ofr) {
1391 rewriter, loc, ofr);
1392 });
1393 });
1394
1395 auto outputShapes =
1396 mlir::cast<mlir::MemRefType>(srcBuffAcc.getType()).getShape();
1397 unsigned int M = outputShapes[outputShapes.size() - 2];
1398 unsigned int N = outputShapes[outputShapes.size() - 1];
1399
1400 SmallVector<Value> dps = newLoop.getResults();
1401 auto bufferType = MemRefType::get({M, N}, opType);
1402 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
1403
1404 // Store the amx tiled-dot product output into an MxN memref.
1405 for (unsigned int i = 0, k = 0; i < M; i = i + 16) {
1406 for (unsigned int j = 0; j < N; j = j + 16) {
1407 Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
1408 Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
1409 amx::TileStoreOp::create(rewriter, loc, resultBuffer,
1410 ValueRange{indexOp_i, indexOp_j}, dps[k]);
1411 k++;
1412 }
1413 }
1414 auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
1415 auto c16 = arith::ConstantIndexOp::create(rewriter, loc, 16);
1416 auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
1417 auto nBound = arith::ConstantIndexOp::create(rewriter, loc, N);
1418
1419 // Create a loop that iterates over the MxN memref, retrieves two rows,
1420 // shuffles them, adds the C element values, and stores to the temp buffer.
1421 scf::ForOp::create(
1422 rewriter, loc, c0, nBound, one, ValueRange{},
1423 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
1424 ValueRange iterArgs) {
1425 auto row =
1426 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1427 resultBuffer, ValueRange{iv, c0});
1428
1429 auto row2 =
1430 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1431 resultBuffer, ValueRange{iv, c16});
1432
1433 Value shuffle1 = row;
1434 Value shuffle2 = row2;
1435
1436 if (!isVnni) {
1437 shuffle1 = vector::ShuffleOp::create(
1438 rewriter, loc, VectorType::get(16, opType), row, row2,
1439 ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
1440 21, 22, 23});
1441
1442 shuffle2 = vector::ShuffleOp::create(
1443 rewriter, loc, VectorType::get(16, opType), row, row2,
1444 ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
1445 28, 29, 30, 31});
1446 }
1447 indicesAcc[indicesAcc.size() - 2] = iv;
1448 indicesAcc[indicesAcc.size() - 1] = c0;
1449
1450 Value valueCRow1 =
1451 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1452 srcBuffAcc, indicesAcc);
1453 indicesAcc[indicesAcc.size() - 1] = c16;
1454
1455 Value valueCRow2 =
1456 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1457 srcBuffAcc, indicesAcc);
1458
1459 Value addOp;
1460 Value addOp2;
1461
1462 if (ipType.isBF16() || ipType.isF8E5M2() || ipType.isF8E4M3FN()) {
1463 addOp = arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
1464
1465 addOp2 = arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
1466 }
1467
1468 if (ipType.isSignlessInteger(8)) {
1469 addOp = arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
1470
1471 addOp2 = arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
1472 }
1473
1474 vector::StoreOp::create(rewriter, loc, addOp, resultBuffer,
1475 ValueRange{iv, c0});
1476 vector::StoreOp::create(rewriter, loc, addOp2, resultBuffer,
1477 ValueRange{iv, c16});
1478
1479 scf::YieldOp::create(nestedBuilder, loc);
1480 });
1481
1482 SmallVector<Value> writeResults;
1483 for (unsigned int i = 0; i < M; i = i + 16) {
1484 for (unsigned int j = 0; j < N; j = j + 16) {
1485 Value indexOp_i = arith::ConstantIndexOp::create(rewriter, loc, i);
1486 Value indexOp_j = arith::ConstantIndexOp::create(rewriter, loc, j);
1487
1488 auto vectorType = mlir::VectorType::get({16, 16}, opType);
1489
1490 int64_t srcRank =
1491 (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
1492 Value padding = ub::PoisonOp::create(rewriter, loc, opType);
1493 auto map = AffineMap::getMinorIdentityMap(srcRank, vectorType.getRank(),
1494 rewriter.getContext());
1495 SmallVector<bool> inBounds(vectorType.getRank(), true);
1496
1497 auto vec1 = vector::TransferReadOp::create(
1498 rewriter, loc, vectorType, resultBuffer,
1499 ValueRange{indexOp_i, indexOp_j}, padding, map, inBounds);
1500 writeResults.push_back(vec1);
1501 }
1502 }
1503
1504 // Replace use of vector.contract with dot-products.
1505 for (size_t i = 0; i < ops.size(); i++) {
1506 vector::ContractionOp contOp = ops[i];
1507 Value vecRow = writeResults[i];
1508
1509 Value resultWriteOp = contractionUsersAfterYield(contOp.getResult());
1510 if (auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.getType()))
1511 vecRow = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
1512 writeResults[i]);
1513
1514 rewriter.replaceAllUsesWith(resultWriteOp, vecRow);
1515 }
1516
1517 return success();
1518 }
1519};
1520
1521} // namespace
1522
1524 RewritePatternSet &patterns) {
1525 patterns.add<VectorContractToAMXDotProduct>(patterns.getContext());
1526}
return success()
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc, Value input, int64_t firstDimToCollapse)
Creates a memref.collapse_shape collapsing all inner dimensions of the input starting at firstDimToCo...
#define rem(a, b)
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
FloatType getF32Type()
Definition Builders.cpp:47
FloatType getF8E5M2Type()
Definition Builders.cpp:39
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
FloatType getBF16Type()
Definition Builders.cpp:41
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
FloatType getF8E4M3FNType()
Definition Builders.cpp:37
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:567
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
unsigned getNumOperands()
Definition Operation.h:371
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF8E5M2() const
Definition Types.cpp:45
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isF8E4M3FN() const
Definition Types.cpp:44
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition Value.cpp:52
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
use_iterator use_begin() const
Definition Value.h:184
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:114
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
mlir::x86::AMXTileType TileType
Definition X86Dialect.h:40
Operation * traceToVectorWriteLikeUserOperation(Value v)
Definition X86Utils.cpp:194
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Definition X86Utils.cpp:42
Operation * traceToVectorReadLikeParentOperation(Value v)
Definition X86Utils.cpp:154
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Definition X86Utils.cpp:352
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.