19#include "llvm/Support/Casting.h"
// Follows the value `v` through an scf.yield user to a result of `parent`
// and recurses on that result. How `user`, `idx` and `parent` are obtained,
// and what is returned on the non-yield path, is on lines that are not
// visible in this excerpt. Callers test the returned Value for null and read
// its type.
32static Value contractionUsersAfterYield(
Value v) {
// A user that is not an scf.yield takes a separate path (body not shown).
39 if (!isa<scf::YieldOp>(user))
42 auto yield = cast<scf::YieldOp>(user);
// Continue the search from result number `idx` of `parent`.
46 return contractionUsersAfterYield(parent->
getResult(idx));
// Body fragment of a helper whose signature is not visible in this excerpt.
// From the visible lines: it builds a reassociation list for `input` and
// returns a memref::CollapseShapeOp created from `input` and that list.
53 ShapedType inputType = cast<ShapedType>(input.
getType());
// Collapsing starts at the second-to-last dimension of the input.
54 int64_t firstDimToCollapse = inputType.getRank() - 2;
// Rank-1 inputs are special-cased (branch body not shown).
56 if (inputType.getRank() == 1)
// Dimensions in [0, firstDimToCollapse) are visited first (loop body not shown).
60 for (
int64_t i = 0; i < firstDimToCollapse; ++i)
// Dimensions from firstDimToCollapse up to the rank are gathered into one
// group, which is appended to the reassociation as a single entry.
64 for (
int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
65 collapsedIndices.push_back(i);
67 reassociation.push_back(collapsedIndices);
68 return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
// Returns, on success, a pair of (source buffer, index values).
// NOTE(review): the function name is on a line missing from this excerpt;
// callers of getSrcIndxValue elsewhere in the file dereference the result
// and destructure exactly such a pair, so this is presumably that helper --
// confirm against the full file.
73static FailureOr<std::pair<Value, SmallVector<Value>>>
// Case taken when the matched op is a TransferReadOp or a LoadOp.
83 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
85 readOp.getIndices().end());
// Operand 0 of the read op is taken as the source buffer.
86 srcBuff = readOp.getOperand(0);
// Pre-size the output to the number of collected index values.
96 indices.reserve(indexVals.size());
107 return std::make_pair(srcBuff,
indices);
// Checks a loop `step` against an expected constant `value`. Only the
// comparison below is visible in this excerpt; how `cst` is derived from
// `step` and what each outcome returns is on lines not shown.
111static LogicalResult validateLoopStep(
OpBuilder &rewriter,
Value step,
// Branch taken when the constant step is neither `value` nor 1
// (presumably the failure path -- confirm against the full file).
118 if (cst.value() != value && cst.value() != 1)
// Validates a vector.contract op before rewriting. The visible checks cover:
// the source buffers behind LHS/RHS, the users of the result, and the shapes
// of ACC/LHS/RHS. The remaining parameters (callers pass two buffer Values
// and a bool) and the body of every `if` below are on lines not visible in
// this excerpt.
125static LogicalResult validateContractOps(
OpBuilder &rewriter,
126 vector::ContractionOp contractOp,
127 unsigned int blockingFactor,
// Resolve the buffer and indices behind the LHS operand.
133 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
134 contractOp.getLhs(),
false);
137 auto [buffLhs, indicesLhs] = *srcIndxLhs;
// Resolve the buffer and indices behind the RHS operand.
140 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
141 contractOp.getRhs(),
false);
144 auto [buffRhs, indicesRhs] = *srcIndxRhs;
// Compare the resolved buffers against the expected ones (any guard around
// these comparisons is not visible here).
147 if (buffLhs != srcBuffLhs)
150 if (buffRhs != srcBuffRhs)
// Branch taken when the user lookup on the contract result yields a null Value.
154 if (!contractionUsersAfterYield(contractOp.getResult()))
// NOTE(review): dyn_cast may yield null; any null check is on lines not shown.
157 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
// Collect the ACC dimensions that are neither 16 nor 1.
164 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
165 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
// Branch taken when ACC has any such dimension.
167 if (nonUnitDimAcc.size() != 0)
172 VectorType lhsTy = contractOp.getLhsType();
// Collect the LHS dimensions that are neither 16 nor 1.
175 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
176 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
// Branch taken unless exactly one such dimension exists ...
178 if (nonUnitDimLhs.size() != 1)
// ... and that dimension is then compared against blockingFactor.
181 if (nonUnitDimLhs[0] != blockingFactor)
186 VectorType rhsTy = contractOp.getRhsType();
// Same pair of checks for the RHS shape.
189 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
190 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
192 if (nonUnitDimRhs.size() != 1)
195 if (nonUnitDimRhs[0] != blockingFactor)
// Searches the offsets of `subview` for the induction variable of `loop`.
// How `subview` is obtained from `operand`, and what is returned on a match
// or when nothing matches, is on lines not visible in this excerpt.
203static unsigned getIndexPosition(
Value operand, scf::ForOp loop) {
204 Value iv = loop.getInductionVar();
// For a TransferReadOp or LoadOp, operand 0 is taken as the source buffer.
208 .Case<TransferReadOp, LoadOp>(
209 [&](
auto readOp) { srcBuff = readOp.getOperand(0); });
215 auto offsets = subview.getOffsets();
// Walk the offsets together with their positions, looking for `iv`.
217 for (
auto it : llvm::enumerate(offsets)) {
218 if (it.value() == iv)
// Tail of a helper whose name is not visible in this excerpt.
// NOTE(review): the visible parameters and the TileLoadOp result match the
// createTileLoads(...) calls elsewhere in the file, so this is presumably
// that function -- confirm against the full file.
228 bool rhs,
unsigned int offset,
// Resolve the source buffer and indices behind `operand`.
231 auto srcIndx = getSrcIndxValue(rewriter, loc, operand,
false);
// NOTE(review): the result is dereferenced here; no failure check is visible
// in this excerpt.
232 auto [srcBuff,
indices] = *srcIndx;
// The tile type has shape {16, 16 * offset} with element type `ipType`.
243 amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
// Load the tile from `mat` at the resolved indices.
244 return amx::TileLoadOp::create(rewriter, loc, tileType, mat,
indices);
// Interior of a helper whose name is not visible in this excerpt.
// NOTE(review): the visible parameter tail matches the performShuffle(...)
// calls elsewhere in the file, so this is presumably that function --
// confirm against the full file.
// From the visible lines: inside a loop built from (c0, cBound, cStep) it
// reads two vectors from `matB`, produces two shuffled vectors, and stores
// both into `packedBuffer`.
248 Type ipType,
unsigned int offset,
Value packedBuffer,
249 Value indxToStoreInBuffer) {
254 llvm::cast<MemRefType>(matB.
getType()).getRank(), c0);
// Loop over [c0, cBound) with step cStep and no iteration arguments (the
// op being created is on a line not shown; the body ends in scf.yield).
262 rewriter, loc, c0, cBound, cStep,
ValueRange{},
// The second-to-last read offset follows the loop induction variable.
265 subviewOffset[subviewOffset.size() - 2] = iv;
// Two candidate vector types for the reads; the condition that selects
// between them is not visible here.
269 auto vectorType = VectorType::get({2, (16 * (offset / 2))}, ipType);
271 vectorType = VectorType::get((16 * offset), ipType);
// NOTE(review): the dyn_cast result is used without a null check.
273 int64_t srcRank = (dyn_cast<ShapedType>(matB.
getType())).getRank();
// Poison value used as the transfer_read padding operand.
274 Value padding = ub::PoisonOp::create(rewriter, loc, ipType);
// First read, at the current offsets.
278 Value vec1 = vector::TransferReadOp::create(
279 rewriter, loc, vectorType, matB,
ValueRange(subviewOffset), padding,
// Shape-cast the first read to a 1-D vector of 16 * offset elements.
283 vec1 = vector::ShapeCastOp::create(
284 rewriter, loc, VectorType::get((16 * offset), ipType), vec1);
// Second read: same offsets, except the second-to-last is offsetIndx + iv.
288 Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
289 subviewOffset[subviewOffset.size() - 2] = incIV;
291 Value vec2 = vector::TransferReadOp::create(
292 rewriter, loc, vectorType, matB,
ValueRange(subviewOffset), padding,
295 vec2 = vector::ShapeCastOp::create(
296 rewriter, loc, VectorType::get((16 * offset), ipType), vec2);
// The two shuffle results are assigned below; the conditions selecting the
// 32-entry or the 64-entry masks are on lines not shown.
298 vector::ShuffleOp shuffle1;
299 vector::ShuffleOp shuffle2;
// 32-entry masks: each index k is paired with k + 32. shuffle1 takes
// k in 0-3, 8-11, 16-19, 24-27.
303 shuffle1 = vector::ShuffleOp::create(
304 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
306 ArrayRef<int64_t>{0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9,
307 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
308 19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
// shuffle2 takes the remaining k: 4-7, 12-15, 20-23, 28-31.
310 shuffle2 = vector::ShuffleOp::create(
311 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
313 ArrayRef<int64_t>{4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13,
314 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
315 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
// 64-entry masks: each index k is grouped with k + 32, k + 64, k + 96.
// shuffle1 takes k in 0-3, 8-11, 16-19, 24-27.
321 shuffle1 = vector::ShuffleOp::create(
322 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
325 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3,
326 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42,
327 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81,
328 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120,
329 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123});
// shuffle2 takes the remaining k: 4-7, 12-15, 20-23, 28-31.
331 shuffle2 = vector::ShuffleOp::create(
332 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
335 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39,
336 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110,
337 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54,
338 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125,
339 30, 62, 94, 126, 31, 63, 95, 127});
// Row indices for the two stores: iv / cStep, and that value plus c16.
343 Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
344 Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
// Store both shuffled vectors into slot `indxToStoreInBuffer` of the buffer.
346 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
347 ValueRange{indxToStoreInBuffer, ivShuff1, c0});
348 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
349 ValueRange{indxToStoreInBuffer, ivShuff2, c0});
// Terminate the loop body.
351 scf::YieldOp::create(nestedBuilder, loc);
// Tail of a helper whose name is not visible in this excerpt.
// NOTE(review): the visible parameters match the packInputs(...) call
// elsewhere in the file, so this is presumably that function -- confirm
// against the full file.
// From the visible lines: it fills and returns `readsToTileLoads`, a map
// from RHS read operations to tile loads taken from `packedBuffer`.
358 unsigned int offset,
Value packedBuffer,
bool pack,
359 Value indxToStoreInBuffer,
Value indxToLoadFromMatB) {
// Visit pairs (j, i) of contract ops; the condition that selects a pair is
// on lines not shown.
365 for (
size_t j = 0;
j < ops.size();
j++) {
366 for (
size_t i = 0; i < ops.size(); i++) {
370 Operation *readOpRhs = ops[
j].getRhs().getDefiningOp();
// Check whether this RHS read already has an entry in the map.
371 auto itRhs = readsToTileLoads.find(readOpRhs);
372 if (itRhs != readsToTileLoads.end()) {
// Shuffle data from matB into the packed buffer (any guard on this call is
// not visible here).
377 performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
378 indxToStoreInBuffer);
// Tile type of shape {16, 16 * offset} used for both loads below.
382 amx::TileType::get({16, (16 * offset)}, ipType);
384 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
388 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
// Record the two loads for the RHS reads of ops[j] and ops[i]; try_emplace
// leaves an existing entry untouched.
391 readsToTileLoads.try_emplace(readOpRhs, loadRow1);
392 readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
397 return readsToTileLoads;
// Tail of a helper whose name and leading parameters are not visible in
// this excerpt. From the visible lines: for every contract op in `ops` it
// obtains an LHS and an RHS tile load (reusing entries of
// `readsToTileLoads` when present), emits a tile multiply that takes
// accIterArgs[i] as accumulator, and appends the result to `accumulators`.
405 unsigned int offset,
bool isVnni,
Value packedBuffer,
bool pack,
406 Value indxToStoreInBuffer,
Value indxToLoadFromMatB) {
// Call into the packing helper (any guard on this call and the use of its
// result are on lines not shown).
421 packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
422 indxToStoreInBuffer, indxToLoadFromMatB);
426 for (
size_t i = 0; i < ops.size(); i++) {
// LHS: reuse the cached tile load for this read op, otherwise create one
// and cache it.
428 Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
429 amx::TileLoadOp tilesLhs;
430 auto itLhs = readsToTileLoads.find(readOpLhs);
431 if (itLhs != readsToTileLoads.end()) {
432 tilesLhs = itLhs->second;
434 tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
435 false, offset, isVnni);
436 readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
// RHS: same lookup-or-create scheme, with the `rhs` flag set to true.
439 Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
440 amx::TileLoadOp tilesRhs;
441 auto itRhs = readsToTileLoads.find(readOpRhs);
442 if (itRhs != readsToTileLoads.end()) {
443 tilesRhs = itRhs->second;
445 tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
446 true, offset, isVnni);
447 readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
// Accumulator tiles have shape {16, 16} with element type `opType`.
450 auto accTileType = amx::TileType::get({16, 16}, opType);
// Either a TileMulFOp or a TileMulIOp is emitted; the condition choosing
// between them is not visible here.
454 dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
455 tilesRhs, accIterArgs[i]);
458 dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
459 tilesRhs, accIterArgs[i]);
461 accumulators.push_back(dp);
// Interior of a helper whose name is not visible in this excerpt.
// NOTE(review): the visible parameters match the createTileZeros(...) calls
// elsewhere in the file, so this is presumably that function -- confirm
// against the full file.
// From the visible lines: it appends `size` zero tiles of shape {16, 16} and
// element type `opType` to `loopItrArgs`.
467 Type opType, scf::ForOp outerLoop,
472 auto zeroTileType = amx::TileType::get({16, 16}, opType);
// One TileZeroOp per requested entry.
474 for (
int i = 0; i < size; i++) {
475 auto zeroTile = amx::TileZeroOp::create(rewriter, loc, zeroTileType);
476 loopItrArgs.push_back(zeroTile);
// Emits index arithmetic that yields a packed-buffer slot index. Some
// parameters, the definitions of `packOffset` and `c2`, and the final return
// are on lines not visible in this excerpt.
481static Value getIndxToLoadStoreFromPckBuffer(
483 bool isInnerLoopUBHasOddQuot,
bool isInnerLoopUBLarger,
bool pack,
484 unsigned int blockingFactor) {
// Default: (ivInnerLoop / packOffset) rem c2 (c2 is presumably the index
// constant 2 -- confirm against the full file).
490 Value quotientInnerLoop =
491 arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
492 Value remInnerLoop = arith::RemUIOp::create(
493 rewriter, loc, rewriter.
getIndexType(), quotientInnerLoop, c2);
// When the inner upper bound is not larger and `pack` is false, the value
// is recomputed as ivOuterLoop rem c2.
495 if (!isInnerLoopUBLarger && !pack) {
496 remInnerLoop = arith::RemUIOp::create(
497 rewriter, loc, rewriter.
getIndexType(), ivOuterLoop, c2);
// With the odd-quotient flag set, (ivOuterLoop rem c2) is added to the
// current value and a further remainder is taken (its operands are on a
// line not shown).
500 if (isInnerLoopUBHasOddQuot) {
501 auto remOuterLoop = arith::RemUIOp::create(
502 rewriter, loc, rewriter.
getIndexType(), ivOuterLoop, c2);
503 auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.
getIndexType(),
504 remInnerLoop, remOuterLoop);
505 remInnerLoop = arith::RemUIOp::create(rewriter, loc,
// Interior of a helper whose name is not visible in this excerpt.
// NOTE(review): the visible parameters match the createLoops(...) calls
// elsewhere in the file, so this is presumably that function -- confirm
// against the full file.
// From the visible lines: it creates an scf.for from the given bounds, step
// and iteration arguments; inside it clones the LHS and RHS source ops with
// remapped operands, computes packed-buffer store/load indices, calls a
// helper whose name is on a line not shown, and ends the body with scf.yield.
515 Type ipType,
Type opType,
unsigned int blockingFactor,
bool isVnni,
517 vector::ContractionOp contractOp, scf::ForOp outerLoop,
519 Value ivOuterLoop,
Value packedBuffer,
bool pack,
521 bool isInnerLoopUBHasOddQuot) {
// `offset` starts at 16 * blockingFactor and is replaced by a constant
// `cst` (its origin and the guard are on lines not shown).
527 int64_t offset = 16 * blockingFactor;
529 offset = cst.value();
531 auto newLoop = scf::ForOp::create(
532 rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
// Operand positions are getIndexPosition(...) + 1 (the +1 presumably skips
// a leading source operand -- TODO confirm).
538 getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
542 getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
// Clone the LHS source op using `mapping`.
544 auto lhsClone = rewriterNewInnerLoop.
clone(*vectorOpLhs, mapping);
// Both packed-buffer indices default to c0.
546 Value indxToStoreInBuffer = c0;
547 Value indxToLoadFromBuffer = c0;
550 if (innerLoopIndex.
value() == 0) {
553 ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
556 if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
557 indxToStoreInBuffer = arith::RemUIOp::create(
// Load index = (store index + c1) rem a value on a line not shown.
562 Value indxToLoadFromMatB = arith::AddIOp::create(
563 rewriter, loc, indxToStoreInBuffer, c1);
564 indxToLoadFromBuffer = arith::RemUIOp::create(
565 rewriter, loc, rewriter.
getIndexType(), indxToLoadFromMatB,
571 rewriter, locNewInnerLoop, offset);
// Advance the inner induction value by nLoadIndx, then derive the store
// index through the shared helper.
572 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
573 nLoadIndx, ivNewInnerLoop);
574 indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
575 rewriter, loc, ivNewInnerLoop, ivOuterLoop,
576 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
// Load index = (store index + c1) rem c2.
578 Value indxToLoadFromMatB =
579 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
580 indxToLoadFromBuffer =
581 arith::RemUIOp::create(rewriter, loc, rewriter.
getIndexType(),
582 indxToLoadFromMatB, c2);
587 rewriter, locNewInnerLoop, offset);
// Variant: store index = ((iv + nLoadIndx) / nLoadIndx) rem c2.
588 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
589 nLoadIndx, ivNewInnerLoop);
590 Value quotient_K = arith::DivUIOp::create(
591 rewriter, loc, ivNewInnerLoop, nLoadIndx);
592 indxToStoreInBuffer = arith::RemUIOp::create(
593 rewriter, loc, rewriter.
getIndexType(), quotient_K, c2);
595 Value indxToLoadFromMatB =
596 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
597 indxToLoadFromBuffer =
598 arith::RemUIOp::create(rewriter, loc, rewriter.
getIndexType(),
599 indxToLoadFromMatB, c2);
// RHS remapping: the operand index is bounds-checked against the op's
// operand count before use.
612 int64_t outerPos = getIndexPosition(contractOp.getRhs(), outerLoop);
615 unsigned operandIdx =
static_cast<unsigned>(outerPos + 1);
617 if (operandIdx < rhsOp->getNumOperands())
622 int64_t innerPos = getIndexPosition(contractOp.getRhs(), innerLoop);
625 unsigned operandIdx =
static_cast<unsigned>(innerPos + 1);
627 if (operandIdx < rhsOp->getNumOperands())
628 rhsMapping.
map(rhsOp->
getOperand(operandIdx), ivNewInnerLoop);
// Clone the RHS source op; its first result becomes matB.
631 auto rhsClone = rewriterNewInnerLoop.
clone(*rhsOp, rhsMapping);
632 matB = rhsClone->getResult(0);
643 rewriter, locNewInnerLoop, offset);
// NOTE(review): the c0 assignment is immediately followed by another
// assignment to the same variable -- check whether it is a dead store.
645 indxToLoadFromBuffer = c0;
646 indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
647 rewriter, loc, nLoadIndx, ivOuterLoop,
648 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
654 rewriter, locNewInnerLoop, offset);
// Variant: load index = (iv / nLoadIndx) rem c2.
656 Value quotient_K = arith::DivUIOp::create(
657 rewriter, loc, ivNewInnerLoop, nLoadIndx);
658 indxToLoadFromBuffer = arith::RemUIOp::create(
659 rewriter, loc, rewriter.
getIndexType(), quotient_K, c2);
// Arguments of a helper call whose callee name is on a line not shown.
665 rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
666 ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
667 packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);
// Terminate the loop body (yielded operands are on a line not shown).
669 scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
740struct VectorContractToAMXDotProduct
742 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
744 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
745 PatternRewriter &rewriter)
const override {
747 if (contractOp.getKind() != vector::CombiningKind::ADD)
749 "Expects add combining kind.");
751 unsigned int blockingFactor =
752 contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
755 contractOp.getIndexingMapsArray(), blockingFactor);
757 VectorType lhsTy = contractOp.getLhsType();
758 if (!lhsTy.getElementType().isBF16() &&
759 !lhsTy.getElementType().isSignlessInteger(8) &&
760 !lhsTy.getElementType().isF8E4M3FN() &&
761 !lhsTy.getElementType().isF8E5M2())
763 contractOp,
"Only BF16/Int8/F8 lowering is supported.");
765 if (lhsTy.getElementType() != contractOp.getRhsType().getElementType())
767 contractOp,
"Contraction should have same lhs and rhs type.");
769 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
773 if (((lhsTy.getElementType().isBF16() ||
774 lhsTy.getElementType().isF8E4M3FN() ||
775 lhsTy.getElementType().isF8E5M2()) &&
776 !accTy.getElementType().isF32()) ||
777 (lhsTy.getElementType().isSignlessInteger(8) &&
778 !accTy.getElementType().isSignlessInteger(32)))
780 "Only F32 for BF16 or Int32 for Int8 "
781 "accumulation type is supported.");
783 Operation *accReadOp =
786 Operation *resultWriteOp =
789 if (!accReadOp || !resultWriteOp)
791 contractOp,
"The ACC operand of the vector.contract should be a "
792 "transfer_read or a load. And, the result should be "
793 "stored using transfer_write or store.");
798 if (lhsTy.getElementType().isSignlessInteger(8)) {
803 if (lhsTy.getElementType().isF8E4M3FN())
806 if (lhsTy.getElementType().isF8E5M2())
809 if (accReadOp->
getBlock() == contractOp->getBlock() &&
810 resultWriteOp->
getBlock() != contractOp->getBlock())
812 contractOp,
"The accumulator store is in different block.");
814 if (accReadOp->
getBlock() != contractOp->getBlock() &&
815 resultWriteOp->
getBlock() == contractOp->getBlock())
817 contractOp,
"The accumulator read is in different block.");
819 unsigned int dimValue = blockingFactor;
821 dimValue = 16 * blockingFactor;
825 if (accReadOp->
getBlock() == contractOp->getBlock() &&
826 resultWriteOp->
getBlock() == contractOp->getBlock()) {
828 bool collapse =
false;
832 LogicalResult validate = validateContractOps(
833 rewriter, contractOp, dimValue, Value(), Value(),
false);
837 contractOp,
"The contract operation doesn't satisfy the operands "
838 "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
839 "The rest dims should be 1. Op should have one user.");
841 Location loc = contractOp.getLoc();
843 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
844 contractOp.getLhs(), collapse);
847 "The LHS src is not a MemRef type.");
848 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
850 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
851 contractOp.getRhs(), collapse);
854 "The RHS src is not a MemRef type.");
855 auto rhsSrc = *srcIndxRhs;
856 auto srcBuffRhs = rhsSrc.first;
857 auto indicesRhs = rhsSrc.second;
859 auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
860 contractOp.getAcc(),
false);
863 "The ACC src is not a MemRef type.");
864 auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
869 auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
870 auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
871 srcBuffLhs, indicesLhs);
874 amx::TileLoadOp loadRhs;
877 SmallVector<OpFoldResult> indexVals;
878 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
879 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
880 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
881 readOp.getIndices().end());
882 vecTy = readOp.getType();
885 SmallVector<OpFoldResult> strides(indexVals.size(), one);
887 contractOp.getRhs().getDefiningOp()->getContext(),
889 auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
890 indexVals, sizes, strides);
891 auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
892 auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
898 (blockingFactor * 16));
902 rewriter, loc, 16 * (blockingFactor / 2));
905 rewriter, loc, c0, uBound, step,
ValueRange{},
906 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
909 arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
911 indicesRhs[indicesRhs.size() - 2] = iv;
912 indicesRhs[indicesRhs.size() - 1] = c0;
914 auto vec1 = vector::LoadOp::create(
916 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
919 indicesRhs[indicesRhs.size() - 2] = i1_load;
921 auto vec2 = vector::LoadOp::create(
923 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
926 vector::ShuffleOp shuffle1;
927 vector::ShuffleOp shuffle2;
929 if (blockingFactor == 2) {
931 shuffle1 = vector::ShuffleOp::create(
932 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
933 ArrayRef<int64_t>{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21,
936 shuffle2 = vector::ShuffleOp::create(
937 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
938 ArrayRef<int64_t>{8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
939 29, 14, 30, 15, 31});
942 if (blockingFactor == 4) {
943 shuffle1 = vector::ShuffleOp::create(
944 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
945 ArrayRef<int64_t>{0, 16, 32, 48, 1, 17, 33, 49,
946 2, 18, 34, 50, 3, 19, 35, 51,
947 4, 20, 36, 52, 5, 21, 37, 53,
948 6, 22, 38, 54, 7, 23, 39, 55});
950 shuffle2 = vector::ShuffleOp::create(
951 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
952 ArrayRef<int64_t>{8, 24, 40, 56, 9, 25, 41, 57,
953 10, 26, 42, 58, 11, 27, 43, 59,
954 12, 28, 44, 60, 13, 29, 45, 61,
955 14, 30, 46, 62, 15, 31, 47, 63});
958 auto rem = arith::DivUIOp::create(
961 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
963 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
966 scf::YieldOp::create(nestedBuilder, loc);
968 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
972 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, srcBuffRhs,
976 auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
977 auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
978 srcBuffAcc, indicesAcc);
983 dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
987 dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
990 auto bufferType = MemRefType::get({16, 16}, opType);
991 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
993 amx::TileStoreOp::create(rewriter, loc, resultBuffer,
ValueRange{c0, c0},
996 auto vectorType = mlir::VectorType::get({16, 16}, opType);
998 (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
999 Value padding = ub::PoisonOp::create(rewriter, loc, opType);
1002 SmallVector<bool> inBounds(vectorType.getRank(),
true);
1004 Value vecRow = vector::TransferReadOp::create(
1005 rewriter, loc, vectorType, resultBuffer,
ValueRange{c0, c0}, padding,
1008 Value resultOp = contractionUsersAfterYield(contractOp.getResult());
1009 if (
auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType()))
1010 vecRow = vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
1020 SmallVector<scf::ForOp> loopLists;
1021 Operation *current = contractOp;
1028 "Accumulator read and contract op not within scf.for op");
1030 loopLists.push_back(dyn_cast<scf::ForOp>(parent));
1038 if (loopLists.size() > 2 || loopLists.size() == 0)
1040 contractOp,
"Rewrite is supported until reduction loop depth of 2.");
1042 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
1043 contractOp.getLhs(),
false);
1046 "The LHS src is not a MemRef type.");
1047 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
1049 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
1050 contractOp.getRhs(),
false);
1053 "The RHS src is not a MemRef type.");
1054 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
1055 Operation *vectorOpLhs;
1056 llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
1057 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
1058 vectorOpLhs = readOp.getBase().getDefiningOp();
1061 Operation *vectorOpRhs;
1062 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
1063 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
1064 vectorOpRhs = readOp.getBase().getDefiningOp();
1067 if (!vectorOpLhs || !vectorOpRhs)
1069 contractOp,
"Failed to find LHS or RHS read source operation");
1072 SmallVector<vector::ContractionOp> ops;
1073 for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
1075 if (
auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
1077 LogicalResult validate = validateContractOps(
1078 rewriter,
contract, dimValue, srcBuffLhs, srcBuffRhs,
true);
1083 "The associated contract operations doesn't satisfy "
1084 "the re-write conditions either the dimensions are "
1085 "wrong or MemRef source are different or many users.");
1092 unsigned int pairCount = 0;
1093 for (
size_t j = 0; j < ops.size(); j++) {
1094 for (
size_t i = j; i < ops.size(); i++) {
1096 pairCount = pairCount + 2;
1100 if (pairCount != ops.size())
1102 contractOp,
"Coudn't find the pair vector contract ");
1105 scf::ForOp innerLoop;
1106 scf::ForOp outerLoop;
1110 if (loopLists.size() == 2) {
1111 outerLoop = loopLists[1];
1112 innerLoop = loopLists[0];
1114 LogicalResult validateOuterLoopStep =
1115 validateLoopStep(rewriter, outerLoop.getStep(), 1);
1116 if (
failed(validateOuterLoopStep))
1119 int64_t stepValue = 16;
1121 stepValue = stepValue * blockingFactor;
1122 LogicalResult validateInnerLoopStep =
1123 validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
1124 if (
failed(validateInnerLoopStep))
1126 contractOp,
"Invalid loop step. The step should be 32 for BF16 and "
1129 SmallVector<Value> loopItrArgs = createTileZeros(
1130 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
1133 newLoop = scf::ForOp::create(
1134 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1135 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
1136 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1137 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
1138 auto newInnerLoop = createLoops(
1139 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1140 innerLoop.getUpperBound(), innerLoop.getStep(),
1141 iterArgsOuterLoop, ipType, opType, blockingFactor, isVnni,
1142 vectorOpLhs, vectorOpRhs, contractOp, outerLoop, innerLoop,
1143 ops, ivOuterLoop, nullptr, true, nullptr, false, false);
1145 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1146 newInnerLoop.getResults());
1151 bool isInnerLoopUBLarger =
false;
1152 bool isInnerLoopUBHasOddQuot =
false;
1154 int64_t ubVal = 16 * blockingFactor;
1155 mlir::Value ub = innerLoop.getUpperBound();
1156 if (
auto constOp = ub.
getDefiningOp<mlir::arith::ConstantOp>()) {
1158 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1159 ubVal = intAttr.getInt();
1163 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1164 isInnerLoopUBHasOddQuot =
1165 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1174 rewriter, outerLoop.getLoc(), 16 * blockingFactor);
1176 Value spillOuterLoop = arith::SubIOp::create(
1177 rewriter, outerLoop.getLoc(), outerLoop.getUpperBound(), c1);
1178 Value spillInnerLoop =
1179 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1180 innerLoop.getUpperBound(), spillLoopBound);
1182 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1184 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
1187 IRMapping rhsMapping;
1189 vectorOpRhs->getOperand(
1190 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
1191 outerLoop.getLowerBound());
1193 vectorOpRhs->getOperand(
1194 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1195 innerLoop.getLowerBound());
1196 auto rhsClone = rewriter.
clone(*vectorOpRhs, rhsMapping);
1198 Value quotient_batch = arith::DivUIOp::create(
1199 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1200 outerLoop.getStep());
1201 Value quotient_k = arith::DivUIOp::create(rewriter, outerLoop.getLoc(),
1202 innerLoop.getLowerBound(),
1203 innerLoop.getStep());
1205 Value quotient_add = arith::AddIOp::create(rewriter, outerLoop.getLoc(),
1206 quotient_batch, quotient_k);
1209 Value
rem = arith::RemUIOp::create(rewriter, outerLoop.getLoc(),
1212 performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
1213 ipType, blockingFactor, packedBuffer,
rem);
1216 auto newLoopNonSpill = scf::ForOp::create(
1217 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1218 spillOuterLoop, outerLoop.getStep(), loopItrArgs,
1219 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1220 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
1221 auto newInnerLoop1 = createLoops(
1222 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1223 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1224 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1225 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1226 ivOuterLoop, packedBuffer, true, spillLoopBound,
1227 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1229 auto newInnerLoop = createLoops(
1230 rewriter, innerLoop.getLoc(), spillInnerLoop,
1231 innerLoop.getUpperBound(), innerLoop.getStep(),
1232 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1233 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1234 innerLoop, ops, ivOuterLoop, packedBuffer, true, c0,
1235 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1237 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1238 newInnerLoop.getResults());
1242 newLoop = scf::ForOp::create(
1243 rewriter, outerLoop.getLoc(), spillOuterLoop,
1244 outerLoop.getUpperBound(), outerLoop.getStep(),
1245 newLoopNonSpill.getResults(),
1246 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1247 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
1248 auto newInnerLoop1 = createLoops(
1249 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1250 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1251 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1252 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1253 ivOuterLoop, packedBuffer, true, spillLoopBound,
1254 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1256 auto newInnerLoop = createLoops(
1257 rewriter, innerLoop.getLoc(), spillInnerLoop,
1258 innerLoop.getUpperBound(), innerLoop.getStep(),
1259 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1260 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1261 innerLoop, ops, ivOuterLoop, packedBuffer, false, c0,
1262 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1264 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1265 newInnerLoop.getResults());
1271 if (loopLists.size() == 1) {
1273 innerLoop = loopLists[0];
1274 int64_t stepValue = 16;
1276 stepValue = stepValue * blockingFactor;
1278 LogicalResult validateInnerLoopStep =
1279 validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
1280 if (
failed(validateInnerLoopStep))
1283 "Invalid loop step. The step should be 32 for BF16 and "
1284 "64 for Int8/F8 or 1 if it is rduction loop other than K.");
1286 SmallVector<Value> loopItrArgs = createTileZeros(
1287 rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
1290 newLoop = createLoops(
1291 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1292 innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
1293 opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1294 contractOp,
nullptr, innerLoop, ops,
nullptr,
nullptr,
true,
1295 nullptr,
false,
false);
1299 bool isInnerLoopUBLarger =
false;
1300 bool isInnerLoopUBHasOddQuot =
false;
1302 int64_t ubVal = 16 * blockingFactor;
1303 mlir::Value ub = innerLoop.getUpperBound();
1304 if (
auto constOp = ub.
getDefiningOp<mlir::arith::ConstantOp>()) {
1306 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1307 ubVal = intAttr.getInt();
1311 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1312 isInnerLoopUBHasOddQuot =
1313 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1319 int64_t offset = 16 * blockingFactor;
1321 innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>())
1322 offset = cst.value();
1325 rewriter, innerLoop.getLoc(), offset);
1326 Value spillInnerLoop =
1327 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1328 innerLoop.getUpperBound(), spillLoopBound);
1331 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1333 memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
1336 IRMapping rhsMapping;
1338 vectorOpRhs->getOperand(
1339 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1340 innerLoop.getLowerBound());
1341 auto rhsClone = rewriter.
clone(*vectorOpRhs, rhsMapping);
1343 Value quotient_k = arith::DivUIOp::create(rewriter, innerLoop.getLoc(),
1344 innerLoop.getLowerBound(),
1345 innerLoop.getStep());
1348 Value
rem = arith::RemUIOp::create(rewriter, innerLoop.getLoc(),
1351 performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
1352 ipType, blockingFactor, packedBuffer,
rem);
1354 auto newLoopNonSpill = createLoops(
1355 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1356 spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
1357 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs, contractOp,
1358 nullptr, innerLoop, ops,
nullptr, packedBuffer,
true,
1359 spillLoopBound, isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1361 newLoop = createLoops(rewriter, innerLoop.getLoc(), spillInnerLoop,
1362 innerLoop.getUpperBound(), innerLoop.getStep(),
1363 newLoopNonSpill.getResults(), ipType, opType,
1364 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1365 contractOp,
nullptr, innerLoop, ops,
nullptr,
1366 packedBuffer,
false, c0, isInnerLoopUBLarger,
1367 isInnerLoopUBHasOddQuot);
1372 outerLoop = innerLoop;
1377 Location loc = outerLoop.getLoc();
1379 SmallVector<Value> indicesAcc;
1381 llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
1383 srcBuffAcc = readOp.getOperand(0);
1385 auto indices = readOp.getIndices();
1386 indicesAcc.reserve(
indices.size());
1388 llvm::transform(
indices, std::back_inserter(indicesAcc),
1389 [&](OpFoldResult ofr) {
1391 rewriter, loc, ofr);
1396 mlir::cast<mlir::MemRefType>(srcBuffAcc.
getType()).getShape();
1397 unsigned int M = outputShapes[outputShapes.size() - 2];
1398 unsigned int N = outputShapes[outputShapes.size() - 1];
1400 SmallVector<Value> dps = newLoop.getResults();
1401 auto bufferType = MemRefType::get({M, N}, opType);
1402 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
1405 for (
unsigned int i = 0, k = 0; i < M; i = i + 16) {
1406 for (
unsigned int j = 0; j < N; j = j + 16) {
1409 amx::TileStoreOp::create(rewriter, loc, resultBuffer,
1422 rewriter, loc, c0, nBound, one,
ValueRange{},
1423 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
1426 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1430 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1433 Value shuffle1 = row;
1434 Value shuffle2 = row2;
1437 shuffle1 = vector::ShuffleOp::create(
1438 rewriter, loc, VectorType::get(16, opType), row, row2,
1439 ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
1442 shuffle2 = vector::ShuffleOp::create(
1443 rewriter, loc, VectorType::get(16, opType), row, row2,
1444 ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
1447 indicesAcc[indicesAcc.size() - 2] = iv;
1448 indicesAcc[indicesAcc.size() - 1] = c0;
1451 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1452 srcBuffAcc, indicesAcc);
1453 indicesAcc[indicesAcc.size() - 1] = c16;
1456 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1457 srcBuffAcc, indicesAcc);
1463 addOp = arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
1465 addOp2 = arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
1469 addOp = arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
1471 addOp2 = arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
1474 vector::StoreOp::create(rewriter, loc, addOp, resultBuffer,
1476 vector::StoreOp::create(rewriter, loc, addOp2, resultBuffer,
1479 scf::YieldOp::create(nestedBuilder, loc);
1482 SmallVector<Value> writeResults;
1483 for (
unsigned int i = 0; i < M; i = i + 16) {
1484 for (
unsigned int j = 0; j < N; j = j + 16) {
1488 auto vectorType = mlir::VectorType::get({16, 16}, opType);
1491 (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
1492 Value padding = ub::PoisonOp::create(rewriter, loc, opType);
1495 SmallVector<bool> inBounds(vectorType.getRank(),
true);
1497 auto vec1 = vector::TransferReadOp::create(
1498 rewriter, loc, vectorType, resultBuffer,
1499 ValueRange{indexOp_i, indexOp_j}, padding, map, inBounds);
1500 writeResults.push_back(vec1);
1505 for (
size_t i = 0; i < ops.size(); i++) {
1506 vector::ContractionOp contOp = ops[i];
1507 Value vecRow = writeResults[i];
1509 Value resultWriteOp = contractionUsersAfterYield(contOp.getResult());
1510 if (
auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.
getType()))
1511 vecRow = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
1525 patterns.
add<VectorContractToAMXDotProduct>(patterns.
getContext());
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
IntegerAttr getIndexAttr(int64_t value)
FloatType getF8E5M2Type()
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
FloatType getF8E4M3FNType()
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
unsigned getNumUses() const
This method computes the number of uses of this Value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
use_iterator use_begin() const
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
mlir::x86::AMXTileType TileType
Operation * traceToVectorWriteLikeUserOperation(Value v)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Operation * traceToVectorReadLikeParentOperation(Value v)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< int64_t, 2 > ReassociationIndices
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.