19#include "llvm/Support/Casting.h"
34 ShapedType inputType = cast<ShapedType>(input.
getType());
35 int64_t firstDimToCollapse = inputType.getRank() - 2;
37 if (inputType.getRank() == 1)
41 for (
int64_t i = 0; i < firstDimToCollapse; ++i)
45 for (
int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
46 collapsedIndices.push_back(i);
48 reassociation.push_back(collapsedIndices);
49 return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
54static FailureOr<std::pair<Value, SmallVector<Value>>>
64 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
66 readOp.getIndices().end());
67 srcBuff = readOp.getOperand(0);
77 indices.reserve(indexVals.size());
88 return std::make_pair(srcBuff,
indices);
92static LogicalResult validateContractOps(
OpBuilder &rewriter,
93 vector::ContractionOp contractOp,
94 unsigned int blockingFactor,
100 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
101 contractOp.getLhs(),
false);
104 auto [buffLhs, indicesLhs] = *srcIndxLhs;
107 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
108 contractOp.getRhs(),
false);
111 auto [buffRhs, indicesRhs] = *srcIndxRhs;
114 if (buffLhs != srcBuffLhs)
117 if (buffRhs != srcBuffRhs)
121 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
128 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
129 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
131 if (nonUnitDimAcc.size() != 0)
136 VectorType lhsTy = contractOp.getLhsType();
139 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
140 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
142 if (nonUnitDimLhs.size() != 1)
145 if (nonUnitDimLhs[0] != blockingFactor)
150 VectorType rhsTy = contractOp.getRhsType();
153 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
154 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
156 if (nonUnitDimRhs.size() != 1)
159 if (nonUnitDimRhs[0] != blockingFactor)
167static unsigned getIndexPosition(
Value operand, scf::ForOp loop) {
168 Value iv = loop.getInductionVar();
172 .Case<TransferReadOp, LoadOp>(
173 [&](
auto readOp) { srcBuff = readOp.getOperand(0); });
179 auto offsets = subview.getOffsets();
181 for (
auto it : llvm::enumerate(offsets)) {
182 if (it.value() == iv)
192 bool rhs,
unsigned int offset,
195 auto srcIndx = getSrcIndxValue(rewriter, loc, operand,
false);
196 auto [srcBuff,
indices] = *srcIndx;
207 amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
208 return amx::TileLoadOp::create(rewriter, loc, tileType, mat,
indices);
212 Type ipType,
unsigned int offset,
Value packedBuffer,
213 Value indxToStoreInBuffer) {
218 auto subview = matB.
getDefiningOp<mlir::memref::SubViewOp>();
227 rewriter, loc, c0, cBound, cStep,
ValueRange{},
230 subviewOffset[subviewOffset.size() - 2] = iv;
231 auto vec1 = vector::LoadOp::create(
232 rewriter, loc, VectorType::get((16 * offset), ipType), matB,
237 Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
238 subviewOffset[subviewOffset.size() - 2] = incIV;
239 auto vec2 = vector::LoadOp::create(
240 rewriter, loc, VectorType::get((16 * offset), ipType), matB,
243 vector::ShuffleOp shuffle1;
244 vector::ShuffleOp shuffle2;
248 shuffle1 = vector::ShuffleOp::create(
249 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
251 ArrayRef<int64_t>{0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9,
252 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
253 19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
255 shuffle2 = vector::ShuffleOp::create(
256 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
258 ArrayRef<int64_t>{4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13,
259 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
260 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
265 shuffle1 = vector::ShuffleOp::create(
266 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
269 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3,
270 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42,
271 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81,
272 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120,
273 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123});
275 shuffle2 = vector::ShuffleOp::create(
276 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
279 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39,
280 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110,
281 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54,
282 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125,
283 30, 62, 94, 126, 31, 63, 95, 127});
287 Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
288 Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
290 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
291 ValueRange{indxToStoreInBuffer, ivShuff1, c0});
292 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
293 ValueRange{indxToStoreInBuffer, ivShuff2, c0});
295 scf::YieldOp::create(nestedBuilder, loc);
302 unsigned int offset,
Value packedBuffer,
bool pack,
303 Value indxToStoreInBuffer,
Value indxToLoadFromMatB) {
309 for (
size_t j = 0;
j < ops.size();
j++) {
310 for (
size_t i = 0; i < ops.size(); i++) {
314 Operation *readOpRhs = ops[
j].getRhs().getDefiningOp();
315 auto itRhs = readsToTileLoads.find(readOpRhs);
316 if (itRhs != readsToTileLoads.end()) {
321 performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
322 indxToStoreInBuffer);
326 amx::TileType::get({16, (16 * offset)}, ipType);
328 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
332 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
335 readsToTileLoads.try_emplace(readOpRhs, loadRow1);
336 readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
341 return readsToTileLoads;
349 unsigned int offset,
bool isVnni,
Value packedBuffer,
bool pack,
350 Value indxToStoreInBuffer,
Value indxToLoadFromMatB) {
365 packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
366 indxToStoreInBuffer, indxToLoadFromMatB);
370 for (
size_t i = 0; i < ops.size(); i++) {
372 Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
373 amx::TileLoadOp tilesLhs;
374 auto itLhs = readsToTileLoads.find(readOpLhs);
375 if (itLhs != readsToTileLoads.end()) {
376 tilesLhs = itLhs->second;
378 tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
379 false, offset, isVnni);
380 readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
383 Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
384 amx::TileLoadOp tilesRhs;
385 auto itRhs = readsToTileLoads.find(readOpRhs);
386 if (itRhs != readsToTileLoads.end()) {
387 tilesRhs = itRhs->second;
389 tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
390 true, offset, isVnni);
391 readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
394 auto accTileType = amx::TileType::get({16, 16}, opType);
398 dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
399 tilesRhs, accIterArgs[i]);
402 dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
403 tilesRhs, accIterArgs[i]);
405 accumulators.push_back(dp);
411 Type opType, scf::ForOp outerLoop,
416 auto zeroTileType = amx::TileType::get({16, 16}, opType);
418 for (
int i = 0; i < size; i++) {
419 auto zeroTile = amx::TileZeroOp::create(rewriter, loc, zeroTileType);
420 loopItrArgs.push_back(zeroTile);
425static Value getIndxToLoadStoreFromPckBuffer(
427 bool isInnerLoopUBHasOddQuot,
bool isInnerLoopUBLarger,
bool pack,
428 unsigned int blockingFactor) {
434 Value quotientInnerLoop =
435 arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
436 Value remInnerLoop = arith::RemUIOp::create(
437 rewriter, loc, rewriter.
getIndexType(), quotientInnerLoop, c2);
439 if (!isInnerLoopUBLarger && !pack) {
440 remInnerLoop = arith::RemUIOp::create(
441 rewriter, loc, rewriter.
getIndexType(), ivOuterLoop, c2);
444 if (isInnerLoopUBHasOddQuot) {
445 auto remOuterLoop = arith::RemUIOp::create(
446 rewriter, loc, rewriter.
getIndexType(), ivOuterLoop, c2);
447 auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.
getIndexType(),
448 remInnerLoop, remOuterLoop);
449 remInnerLoop = arith::RemUIOp::create(rewriter, loc,
459 Type ipType,
Type opType,
unsigned int blockingFactor,
bool isVnni,
461 vector::ContractionOp contractOp, scf::ForOp outerLoop,
463 Value ivOuterLoop,
Value packedBuffer,
bool pack,
465 bool isInnerLoopUBHasOddQuot) {
471 auto newLoop = scf::ForOp::create(
472 rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
478 getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
482 getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
484 auto lhsClone = rewriterNewInnerLoop.
clone(*vectorOpLhs, mapping);
486 Value indxToStoreInBuffer = c0;
487 Value indxToLoadFromBuffer = c0;
491 if (innerLoopIndex.
value() == 0) {
494 ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
497 if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
498 indxToStoreInBuffer = arith::RemUIOp::create(
503 Value indxToLoadFromMatB = arith::AddIOp::create(
504 rewriter, loc, indxToStoreInBuffer, c1);
505 indxToLoadFromBuffer = arith::RemUIOp::create(
506 rewriter, loc, rewriter.
getIndexType(), indxToLoadFromMatB,
512 rewriter, locNewInnerLoop, (16 * blockingFactor));
513 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
514 nLoadIndx, ivNewInnerLoop);
515 indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
516 rewriter, loc, ivNewInnerLoop, ivOuterLoop,
517 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
519 Value indxToLoadFromMatB =
520 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
521 indxToLoadFromBuffer =
522 arith::RemUIOp::create(rewriter, loc, rewriter.
getIndexType(),
523 indxToLoadFromMatB, c2);
528 rewriter, locNewInnerLoop, (16 * blockingFactor));
529 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
530 nLoadIndx, ivNewInnerLoop);
531 Value quotient_K = arith::DivUIOp::create(
532 rewriter, loc, ivNewInnerLoop, nLoadIndx);
533 indxToStoreInBuffer = arith::RemUIOp::create(
534 rewriter, loc, rewriter.
getIndexType(), quotient_K, c2);
536 Value indxToLoadFromMatB =
537 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
538 indxToLoadFromBuffer =
539 arith::RemUIOp::create(rewriter, loc, rewriter.
getIndexType(),
540 indxToLoadFromMatB, c2);
549 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
554 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
556 auto rhsClone = rewriterNewInnerLoop.
clone(*vectorOpRhs, rhsMapping);
558 Value matB = rhsClone->getResult(0);
564 rewriter, locNewInnerLoop, (16 * blockingFactor));
566 indxToLoadFromBuffer = c0;
567 indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
568 rewriter, loc, nLoadIndx, ivOuterLoop,
569 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
575 rewriter, locNewInnerLoop, (16 * blockingFactor));
577 Value quotient_K = arith::DivUIOp::create(
578 rewriter, loc, ivNewInnerLoop, nLoadIndx);
579 indxToLoadFromBuffer = arith::RemUIOp::create(
580 rewriter, loc, rewriter.
getIndexType(), quotient_K, c2);
587 rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
588 ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
589 packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);
591 scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
662struct VectorContractToAMXDotProduct
664 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
666 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
667 PatternRewriter &rewriter)
const override {
669 if (contractOp.getKind() != vector::CombiningKind::ADD)
671 "Expects add combining kind.");
673 unsigned int blockingFactor =
674 contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
677 contractOp.getIndexingMapsArray(), blockingFactor);
679 VectorType lhsTy = contractOp.getLhsType();
680 if (!lhsTy.getElementType().isBF16() &&
681 !lhsTy.getElementType().isSignlessInteger(8))
683 contractOp,
"Only BF16/Int8 lowering is supported.");
685 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
689 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
690 (lhsTy.getElementType().isSignlessInteger(8) &&
691 !accTy.getElementType().isSignlessInteger(32)))
693 "Only F32 for BF16 or Int32 for Int8 "
694 "accumulation type is supported.");
696 Operation *accReadOp =
699 Operation *resultWriteOp =
702 if (!accReadOp || !resultWriteOp)
704 contractOp,
"The ACC operand of the vector.contract should be a "
705 "transfer_read or a load. And, the result should be "
706 "stored using transfer_write or store.");
711 if (lhsTy.getElementType().isSignlessInteger(8)) {
716 if (accReadOp->
getBlock() == contractOp->getBlock() &&
717 resultWriteOp->
getBlock() != contractOp->getBlock())
719 contractOp,
"The accumulator store is in different block.");
721 if (accReadOp->
getBlock() != contractOp->getBlock() &&
722 resultWriteOp->
getBlock() == contractOp->getBlock())
724 contractOp,
"The accumulator read is in different block.");
726 unsigned int dimValue = blockingFactor;
728 dimValue = 16 * blockingFactor;
732 if (accReadOp->
getBlock() == contractOp->getBlock() &&
733 resultWriteOp->
getBlock() == contractOp->getBlock()) {
735 bool collapse =
false;
739 LogicalResult validate = validateContractOps(
740 rewriter, contractOp, dimValue, Value(), Value(),
false);
744 contractOp,
"The contract operation doesn't satisfy the operands "
745 "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
746 "The rest dims should be 1.");
748 Location loc = contractOp.getLoc();
750 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
751 contractOp.getLhs(), collapse);
754 "The LHS src is not a MemRef type.");
755 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
757 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
758 contractOp.getRhs(), collapse);
761 "The RHS src is not a MemRef type.");
762 auto rhsSrc = *srcIndxRhs;
763 auto srcBuffRhs = rhsSrc.first;
764 auto indicesRhs = rhsSrc.second;
766 auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
767 contractOp.getAcc(),
false);
770 "The ACC src is not a MemRef type.");
771 auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
774 auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
775 auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
776 srcBuffLhs, indicesLhs);
779 amx::TileLoadOp loadRhs;
782 SmallVector<OpFoldResult> indexVals;
783 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
784 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
785 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
786 readOp.getIndices().end());
787 vecTy = readOp.getType();
790 SmallVector<OpFoldResult> strides(indexVals.size(), one);
792 contractOp.getRhs().getDefiningOp()->getContext(),
794 auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
795 indexVals, sizes, strides);
796 auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
797 auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
804 (blockingFactor * 16));
808 rewriter, loc, 16 * (blockingFactor / 2));
811 rewriter, loc, c0, uBound, step,
ValueRange{},
812 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
815 arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
817 indicesRhs[indicesRhs.size() - 2] = iv;
819 auto vec1 = vector::LoadOp::create(
821 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
824 indicesRhs[indicesRhs.size() - 2] = i1_load;
826 auto vec2 = vector::LoadOp::create(
828 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
831 vector::ShuffleOp shuffle1;
832 vector::ShuffleOp shuffle2;
834 if (blockingFactor == 2) {
836 shuffle1 = vector::ShuffleOp::create(
837 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
838 ArrayRef<int64_t>{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21,
841 shuffle2 = vector::ShuffleOp::create(
842 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
843 ArrayRef<int64_t>{8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
844 29, 14, 30, 15, 31});
847 if (blockingFactor == 4) {
848 shuffle1 = vector::ShuffleOp::create(
849 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
850 ArrayRef<int64_t>{0, 16, 32, 48, 1, 17, 33, 49,
851 2, 18, 34, 50, 3, 19, 35, 51,
852 4, 20, 36, 52, 5, 21, 37, 53,
853 6, 22, 38, 54, 7, 23, 39, 55});
855 shuffle2 = vector::ShuffleOp::create(
856 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
857 ArrayRef<int64_t>{8, 24, 40, 56, 9, 25, 41, 57,
858 10, 26, 42, 58, 11, 27, 43, 59,
859 12, 28, 44, 60, 13, 29, 45, 61,
860 14, 30, 46, 62, 15, 31, 47, 63});
863 auto rem = arith::RemUIOp::create(
866 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
868 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
871 scf::YieldOp::create(nestedBuilder, loc);
873 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
877 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, srcBuffRhs,
881 auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
882 auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
883 srcBuffAcc, indicesAcc);
888 dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
892 dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
895 amx::TileStoreOp::create(rewriter, loc, srcBuffAcc, indicesAcc, dp);
897 rewriter.
eraseOp(resultWriteOp);
905 SmallVector<scf::ForOp> loopLists;
906 Operation *current = contractOp;
913 "Accumulator read and contract op not within scf.for op");
915 loopLists.push_back(dyn_cast<scf::ForOp>(parent));
923 if (loopLists.size() > 2 || loopLists.size() == 0)
925 contractOp,
"Rewrite is supported until reduction loop depth of 2.");
927 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
928 contractOp.getLhs(),
false);
931 "The LHS src is not a MemRef type.");
932 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
934 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
935 contractOp.getRhs(),
false);
938 "The RHS src is not a MemRef type.");
939 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
940 Operation *vectorOpLhs;
941 llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
942 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
943 vectorOpLhs = readOp.getBase().getDefiningOp();
946 Operation *vectorOpRhs;
947 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
948 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
949 vectorOpRhs = readOp.getBase().getDefiningOp();
953 SmallVector<vector::ContractionOp> ops;
954 for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
956 if (
auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
958 LogicalResult validate = validateContractOps(
959 rewriter,
contract, dimValue, srcBuffLhs, srcBuffRhs,
true);
963 contractOp,
"The associated contract operations doesn't satisfy "
964 "the re-write conditions either the dimensions are "
965 "wrong or MemRef source are different.");
972 unsigned int pairCount = 0;
973 for (
size_t j = 0; j < ops.size(); j++) {
974 for (
size_t i = j; i < ops.size(); i++) {
976 pairCount = pairCount + 2;
980 if (pairCount != ops.size())
982 contractOp,
"Coudn't find the pair vector contract ");
985 scf::ForOp innerLoop;
986 scf::ForOp outerLoop;
990 if (loopLists.size() == 2) {
991 outerLoop = loopLists[1];
992 innerLoop = loopLists[0];
994 SmallVector<Value> loopItrArgs = createTileZeros(
995 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
998 newLoop = scf::ForOp::create(
999 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1000 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
1001 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1002 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
1003 auto newInnerLoop = createLoops(
1004 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1005 innerLoop.getUpperBound(), innerLoop.getStep(),
1006 iterArgsOuterLoop, ipType, opType, blockingFactor, isVnni,
1007 vectorOpLhs, vectorOpRhs, contractOp, outerLoop, innerLoop,
1008 ops, ivOuterLoop, nullptr, true, nullptr, false, false);
1010 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1011 newInnerLoop.getResults());
1016 bool isInnerLoopUBLarger =
false;
1017 bool isInnerLoopUBHasOddQuot =
false;
1019 int64_t ubVal = 16 * blockingFactor;
1020 mlir::Value ub = innerLoop.getUpperBound();
1021 if (
auto constOp = ub.
getDefiningOp<mlir::arith::ConstantOp>()) {
1023 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1024 ubVal = intAttr.getInt();
1028 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1029 isInnerLoopUBHasOddQuot =
1030 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1039 rewriter, outerLoop.getLoc(), 16 * blockingFactor);
1041 Value spillOuterLoop = arith::SubIOp::create(
1042 rewriter, outerLoop.getLoc(), outerLoop.getUpperBound(), c1);
1043 Value spillInnerLoop =
1044 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1045 innerLoop.getUpperBound(), spillLoopBound);
1047 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1049 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
1052 IRMapping rhsMapping;
1054 vectorOpRhs->getOperand(
1055 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
1058 vectorOpRhs->getOperand(
1059 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1061 auto rhsClone = rewriter.
clone(*vectorOpRhs, rhsMapping);
1063 performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
1064 ipType, blockingFactor, packedBuffer, c0);
1067 auto newLoopNonSpill = scf::ForOp::create(
1068 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1069 spillOuterLoop, outerLoop.getStep(), loopItrArgs,
1070 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1071 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
1072 auto newInnerLoop1 = createLoops(
1073 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1074 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1075 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1076 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1077 ivOuterLoop, packedBuffer, true, spillLoopBound,
1078 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1080 auto newInnerLoop = createLoops(
1081 rewriter, innerLoop.getLoc(), spillInnerLoop,
1082 innerLoop.getUpperBound(), innerLoop.getStep(),
1083 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1084 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1085 innerLoop, ops, ivOuterLoop, packedBuffer, true, c0,
1086 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1088 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1089 newInnerLoop.getResults());
1093 newLoop = scf::ForOp::create(
1094 rewriter, outerLoop.getLoc(), spillOuterLoop,
1095 outerLoop.getUpperBound(), outerLoop.getStep(),
1096 newLoopNonSpill.getResults(),
1097 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1098 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
1099 auto newInnerLoop1 = createLoops(
1100 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1101 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1102 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1103 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1104 ivOuterLoop, packedBuffer, true, spillLoopBound,
1105 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1107 auto newInnerLoop = createLoops(
1108 rewriter, innerLoop.getLoc(), spillInnerLoop,
1109 innerLoop.getUpperBound(), innerLoop.getStep(),
1110 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1111 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1112 innerLoop, ops, ivOuterLoop, packedBuffer, false, c0,
1113 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1115 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1116 newInnerLoop.getResults());
1122 if (loopLists.size() == 1) {
1123 innerLoop = loopLists[0];
1125 SmallVector<Value> loopItrArgs = createTileZeros(
1126 rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
1130 newLoop = createLoops(
1131 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1132 innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
1133 opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1134 contractOp,
nullptr, innerLoop, ops,
nullptr,
nullptr,
true,
1135 nullptr,
false,
false);
1138 bool isInnerLoopUBLarger =
false;
1139 bool isInnerLoopUBHasOddQuot =
false;
1141 int64_t ubVal = 16 * blockingFactor;
1142 mlir::Value ub = innerLoop.getUpperBound();
1143 if (
auto constOp = ub.
getDefiningOp<mlir::arith::ConstantOp>()) {
1145 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1146 ubVal = intAttr.getInt();
1150 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1151 isInnerLoopUBHasOddQuot =
1152 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1158 rewriter, innerLoop.getLoc(), 16 * blockingFactor);
1160 Value spillInnerLoop =
1161 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1162 innerLoop.getUpperBound(), spillLoopBound);
1165 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1167 memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
1170 IRMapping rhsMapping;
1172 vectorOpRhs->getOperand(
1173 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1175 auto rhsClone = rewriter.
clone(*vectorOpRhs, rhsMapping);
1177 performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
1178 ipType, blockingFactor, packedBuffer, c0);
1180 auto newLoopNonSpill = createLoops(
1181 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1182 spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
1183 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs, contractOp,
1184 nullptr, innerLoop, ops,
nullptr, packedBuffer,
true,
1185 spillLoopBound, isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1187 newLoop = createLoops(rewriter, innerLoop.getLoc(), spillInnerLoop,
1188 innerLoop.getUpperBound(), innerLoop.getStep(),
1189 newLoopNonSpill.getResults(), ipType, opType,
1190 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1191 contractOp,
nullptr, innerLoop, ops,
nullptr,
1192 packedBuffer,
false, c0, isInnerLoopUBLarger,
1193 isInnerLoopUBHasOddQuot);
1198 outerLoop = innerLoop;
1205 Location loc = outerLoop.getLoc();
1206 Operation *accReadOp =
1210 SmallVector<Value> indicesAcc;
1212 llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
1214 srcBuffAcc = readOp.getOperand(0);
1216 auto indices = readOp.getIndices();
1217 indicesAcc.reserve(
indices.size());
1219 llvm::transform(
indices, std::back_inserter(indicesAcc),
1220 [&](OpFoldResult ofr) {
1222 rewriter, loc, ofr);
1227 mlir::cast<mlir::MemRefType>(srcBuffAcc.
getType()).getShape();
1228 unsigned int M = outputShapes[outputShapes.size() - 2];
1229 unsigned int N = outputShapes[outputShapes.size() - 1];
1231 SmallVector<Value> dps = newLoop.getResults();
1232 auto bufferType = MemRefType::get({M, N}, opType);
1233 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
1236 for (
unsigned int i = 0, k = 0; i < M; i = i + 16) {
1237 for (
unsigned int j = 0; j < N; j = j + 16) {
1240 amx::TileStoreOp::create(rewriter, loc, resultBuffer,
1253 rewriter, loc, c0, mBound, one,
ValueRange{},
1254 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
1256 auto row = vector::LoadOp::create(rewriter, loc,
1257 VectorType::get(16, opType),
1260 auto row2 = vector::LoadOp::create(
1261 rewriter, loc, VectorType::get(16, opType), resultBuffer,
1264 auto shuffle1 = vector::ShuffleOp::create(
1265 rewriter, loc, VectorType::get(16, opType), row, row2,
1266 ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
1269 auto shuffle2 = vector::ShuffleOp::create(
1270 rewriter, loc, VectorType::get(16, opType), row, row2,
1271 ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
1274 indicesAcc[indicesAcc.size() - 2] = iv;
1275 indicesAcc[indicesAcc.size() - 1] = c0;
1277 Value valueCRow1 = vector::LoadOp::create(
1278 rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
1280 indicesAcc[indicesAcc.size() - 1] = c16;
1282 Value valueCRow2 = vector::LoadOp::create(
1283 rewriter, loc, VectorType::get(16, opType), srcBuffAcc,
1291 arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
1294 arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
1299 arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
1302 arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
1304 indicesAcc[indicesAcc.size() - 1] = c0;
1305 vector::StoreOp::create(rewriter, loc, addOp, srcBuffAcc,
1307 indicesAcc[indicesAcc.size() - 1] = c16;
1308 vector::StoreOp::create(rewriter, loc, addOp2, srcBuffAcc,
1311 scf::YieldOp::create(nestedBuilder, loc);
1315 auto bufferType = MemRefType::get({16, 16}, opType);
1317 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
1318 SmallVector<Value> dps = newLoop.getResults();
1320 for (
size_t i = 0; i < ops.size(); i++) {
1321 vector::ContractionOp contOp = ops[i];
1322 Operation *resultWriteOp =
1330 amx::TileStoreOp::create(rewriter, outerLoop.getLoc(), resultBuffer,
1341 rewriter, outerLoop.getLoc(), c0, mBound, one,
ValueRange{},
1342 [&](OpBuilder &builder, Location loc, Value iv,
1344 auto resultAcc = vector::LoadOp::create(
1345 rewriter, loc, VectorType::get(16, opType), resultBuffer,
1346 ValueRange{iv, c0});
1348 Operation *accReadOp =
1352 SmallVector<Value> indicesAcc;
1354 llvm::TypeSwitch<Operation *>(accReadOp)
1355 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
1356 srcBuffAcc = readOp.getOperand(0);
1358 auto indices = readOp.getIndices();
1359 indicesAcc.reserve(
indices.size());
1362 indices, std::back_inserter(indicesAcc),
1363 [&](OpFoldResult ofr) {
1365 rewriter, loc, ofr);
1370 arith::AddIOp::create(builder, loc, iv, indicesAcc[0]);
1371 indicesAcc[indicesAcc.size() - 2] = sum;
1373 auto acc = vector::LoadOp::create(rewriter, loc,
1374 VectorType::get(16, opType),
1375 srcBuffAcc, indicesAcc);
1378 addition = arith::AddFOp::create(rewriter, loc, resultAcc, acc);
1381 addition = arith::AddIOp::create(rewriter, loc, resultAcc, acc);
1383 vector::StoreOp::create(builder, loc, addition, srcBuffAcc,
1386 scf::YieldOp::create(builder, outerLoop.getLoc());
1390 rewriter.
eraseOp(resultWriteOp);
1401 patterns.
add<VectorContractToAMXDotProduct>(patterns.
getContext());
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
mlir::x86::AMXTileType TileType
Operation * traceToVectorWriteLikeUserOperation(Value v)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Operation * traceToVectorReadLikeParentOperation(Value v)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< int64_t, 2 > ReassociationIndices
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.