19#include "llvm/Support/Casting.h"
32static Value contractionUsersAfterYield(
Value v) {
39 if (!isa<scf::YieldOp>(user))
42 auto yield = cast<scf::YieldOp>(user);
46 return contractionUsersAfterYield(parent->
getResult(idx));
53 ShapedType inputType = cast<ShapedType>(input.
getType());
54 int64_t firstDimToCollapse = inputType.getRank() - 2;
56 if (inputType.getRank() == 1)
60 for (
int64_t i = 0; i < firstDimToCollapse; ++i)
64 for (
int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
65 collapsedIndices.push_back(i);
67 reassociation.push_back(collapsedIndices);
68 return memref::CollapseShapeOp::create(builder, loc, input, reassociation);
73static FailureOr<std::pair<Value, SmallVector<Value>>>
83 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
85 readOp.getIndices().end());
86 srcBuff = readOp.getOperand(0);
96 indices.reserve(indexVals.size());
107 return std::make_pair(srcBuff,
indices);
111static LogicalResult validateLoopStep(
OpBuilder &rewriter,
Value step,
118 if (cst.value() != value && cst.value() != 1)
125static LogicalResult validateContractOps(
OpBuilder &rewriter,
126 vector::ContractionOp contractOp,
127 unsigned int blockingFactor,
133 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
134 contractOp.getLhs(),
false);
137 auto [buffLhs, indicesLhs] = *srcIndxLhs;
140 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
141 contractOp.getRhs(),
false);
144 auto [buffRhs, indicesRhs] = *srcIndxRhs;
147 if (buffLhs != srcBuffLhs)
150 if (buffRhs != srcBuffRhs)
154 if (!contractionUsersAfterYield(contractOp.getResult()))
157 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
164 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
165 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
167 if (nonUnitDimAcc.size() != 0)
172 VectorType lhsTy = contractOp.getLhsType();
175 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
176 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
178 if (nonUnitDimLhs.size() != 1)
181 if (nonUnitDimLhs[0] != blockingFactor)
186 VectorType rhsTy = contractOp.getRhsType();
189 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
190 [](
int64_t dim) {
return (dim != 16 && dim != 1); });
192 if (nonUnitDimRhs.size() != 1)
195 if (nonUnitDimRhs[0] != blockingFactor)
203static unsigned getIndexPosition(
Value operand, scf::ForOp loop) {
204 Value iv = loop.getInductionVar();
208 .Case<TransferReadOp, LoadOp>(
209 [&](
auto readOp) { srcBuff = readOp.getOperand(0); });
215 auto offsets = subview.getOffsets();
217 for (
auto it : llvm::enumerate(offsets)) {
218 if (it.value() == iv)
228 bool rhs,
unsigned int offset,
231 auto srcIndx = getSrcIndxValue(rewriter, loc, operand,
false);
232 auto [srcBuff,
indices] = *srcIndx;
243 amx::TileType tileType = amx::TileType::get({16, (16 * offset)}, ipType);
244 return amx::TileLoadOp::create(rewriter, loc, tileType, mat,
indices);
248 Type ipType,
unsigned int offset,
Value packedBuffer,
249 Value indxToStoreInBuffer) {
254 llvm::cast<MemRefType>(matB.
getType()).getRank(), c0);
262 rewriter, loc, c0, cBound, cStep,
ValueRange{},
265 subviewOffset[subviewOffset.size() - 2] = iv;
269 auto vectorType = VectorType::get({2, (16 * (offset / 2))}, ipType);
271 vectorType = VectorType::get((16 * offset), ipType);
273 int64_t srcRank = (dyn_cast<ShapedType>(matB.
getType())).getRank();
274 Value padding = ub::PoisonOp::create(rewriter, loc, ipType);
278 Value vec1 = vector::TransferReadOp::create(
279 rewriter, loc, vectorType, matB,
ValueRange(subviewOffset), padding,
283 vec1 = vector::ShapeCastOp::create(
284 rewriter, loc, VectorType::get((16 * offset), ipType), vec1);
288 Value incIV = arith::AddIOp::create(rewriter, loc, offsetIndx, iv);
289 subviewOffset[subviewOffset.size() - 2] = incIV;
291 Value vec2 = vector::TransferReadOp::create(
292 rewriter, loc, vectorType, matB,
ValueRange(subviewOffset), padding,
295 vec2 = vector::ShapeCastOp::create(
296 rewriter, loc, VectorType::get((16 * offset), ipType), vec2);
298 vector::ShuffleOp shuffle1;
299 vector::ShuffleOp shuffle2;
303 shuffle1 = vector::ShuffleOp::create(
304 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
306 ArrayRef<int64_t>{0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9,
307 41, 10, 42, 11, 43, 16, 48, 17, 49, 18, 50,
308 19, 51, 24, 56, 25, 57, 26, 58, 27, 59});
310 shuffle2 = vector::ShuffleOp::create(
311 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
313 ArrayRef<int64_t>{4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13,
314 45, 14, 46, 15, 47, 20, 52, 21, 53, 22, 54,
315 23, 55, 28, 60, 29, 61, 30, 62, 31, 63});
321 shuffle1 = vector::ShuffleOp::create(
322 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
325 0, 32, 64, 96, 1, 33, 65, 97, 2, 34, 66, 98, 3,
326 35, 67, 99, 8, 40, 72, 104, 9, 41, 73, 105, 10, 42,
327 74, 106, 11, 43, 75, 107, 16, 48, 80, 112, 17, 49, 81,
328 113, 18, 50, 82, 114, 19, 51, 83, 115, 24, 56, 88, 120,
329 25, 57, 89, 121, 26, 58, 90, 122, 27, 59, 91, 123});
331 shuffle2 = vector::ShuffleOp::create(
332 rewriter, loc, VectorType::get({(16 * offset)}, ipType), vec1,
335 4, 36, 68, 100, 5, 37, 69, 101, 6, 38, 70, 102, 7, 39,
336 71, 103, 12, 44, 76, 108, 13, 45, 77, 109, 14, 46, 78, 110,
337 15, 47, 79, 111, 20, 52, 84, 116, 21, 53, 85, 117, 22, 54,
338 86, 118, 23, 55, 87, 119, 28, 60, 92, 124, 29, 61, 93, 125,
339 30, 62, 94, 126, 31, 63, 95, 127});
343 Value ivShuff1 = arith::DivUIOp::create(rewriter, loc, iv, cStep);
344 Value ivShuff2 = arith::AddIOp::create(rewriter, loc, ivShuff1, c16);
346 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
347 ValueRange{indxToStoreInBuffer, ivShuff1, c0});
348 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
349 ValueRange{indxToStoreInBuffer, ivShuff2, c0});
351 scf::YieldOp::create(nestedBuilder, loc);
358 unsigned int offset,
Value packedBuffer,
bool pack,
359 Value indxToStoreInBuffer,
Value indxToLoadFromMatB) {
365 for (
size_t j = 0;
j < ops.size();
j++) {
366 for (
size_t i = 0; i < ops.size(); i++) {
370 Operation *readOpRhs = ops[
j].getRhs().getDefiningOp();
371 auto itRhs = readsToTileLoads.find(readOpRhs);
372 if (itRhs != readsToTileLoads.end()) {
377 performShuffle(rewriter, loc, matB, ipType, offset, packedBuffer,
378 indxToStoreInBuffer);
382 amx::TileType::get({16, (16 * offset)}, ipType);
384 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
388 amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
391 readsToTileLoads.try_emplace(readOpRhs, loadRow1);
392 readsToTileLoads.try_emplace(ops[i].getRhs().getDefiningOp(), loadRow2);
397 return readsToTileLoads;
405 unsigned int offset,
bool isVnni,
Value packedBuffer,
bool pack,
406 Value indxToStoreInBuffer,
Value indxToLoadFromMatB) {
421 packInputs(rewriter, loc, ops, matB, ipType, offset, packedBuffer, pack,
422 indxToStoreInBuffer, indxToLoadFromMatB);
426 for (
size_t i = 0; i < ops.size(); i++) {
428 Operation *readOpLhs = ops[i].getLhs().getDefiningOp();
429 amx::TileLoadOp tilesLhs;
430 auto itLhs = readsToTileLoads.find(readOpLhs);
431 if (itLhs != readsToTileLoads.end()) {
432 tilesLhs = itLhs->second;
434 tilesLhs = createTileLoads(rewriter, loc, ops[i].getLhs(), matA, ipType,
435 false, offset, isVnni);
436 readsToTileLoads.try_emplace(readOpLhs, tilesLhs);
439 Operation *readOpRhs = ops[i].getRhs().getDefiningOp();
440 amx::TileLoadOp tilesRhs;
441 auto itRhs = readsToTileLoads.find(readOpRhs);
442 if (itRhs != readsToTileLoads.end()) {
443 tilesRhs = itRhs->second;
445 tilesRhs = createTileLoads(rewriter, loc, ops[i].getRhs(), matB, ipType,
446 true, offset, isVnni);
447 readsToTileLoads.try_emplace(readOpRhs, tilesRhs);
450 auto accTileType = amx::TileType::get({16, 16}, opType);
454 dp = amx::TileMulFOp::create(rewriter, loc, accTileType, tilesLhs,
455 tilesRhs, accIterArgs[i]);
458 dp = amx::TileMulIOp::create(rewriter, loc, accTileType, tilesLhs,
459 tilesRhs, accIterArgs[i]);
461 accumulators.push_back(dp);
467 Type opType, scf::ForOp outerLoop,
472 auto zeroTileType = amx::TileType::get({16, 16}, opType);
474 for (
int i = 0; i < size; i++) {
475 auto zeroTile = amx::TileZeroOp::create(rewriter, loc, zeroTileType);
476 loopItrArgs.push_back(zeroTile);
481static Value getIndxToLoadStoreFromPckBuffer(
483 bool isInnerLoopUBHasOddQuot,
bool isInnerLoopUBLarger,
bool pack,
484 unsigned int blockingFactor) {
490 Value quotientInnerLoop =
491 arith::DivUIOp::create(rewriter, loc, ivInnerLoop, packOffset);
492 Value remInnerLoop = arith::RemUIOp::create(
493 rewriter, loc, rewriter.
getIndexType(), quotientInnerLoop, c2);
495 if (!isInnerLoopUBLarger && !pack) {
496 remInnerLoop = arith::RemUIOp::create(
497 rewriter, loc, rewriter.
getIndexType(), ivOuterLoop, c2);
500 if (isInnerLoopUBHasOddQuot) {
501 auto remOuterLoop = arith::RemUIOp::create(
502 rewriter, loc, rewriter.
getIndexType(), ivOuterLoop, c2);
503 auto remAdd = arith::AddIOp::create(rewriter, loc, rewriter.
getIndexType(),
504 remInnerLoop, remOuterLoop);
505 remInnerLoop = arith::RemUIOp::create(rewriter, loc,
515 Type ipType,
Type opType,
unsigned int blockingFactor,
bool isVnni,
517 vector::ContractionOp contractOp, scf::ForOp outerLoop,
519 Value ivOuterLoop,
Value packedBuffer,
bool pack,
521 bool isInnerLoopUBHasOddQuot) {
527 int64_t offset = 16 * blockingFactor;
529 offset = cst.value();
531 auto newLoop = scf::ForOp::create(
532 rewriter, loc, lowerBound, upperBound, step, loopItrArgs,
538 getIndexPosition(contractOp.getLhs(), outerLoop) + 1),
542 getIndexPosition(contractOp.getLhs(), innerLoop) + 1),
544 auto lhsClone = rewriterNewInnerLoop.
clone(*vectorOpLhs, mapping);
546 Value indxToStoreInBuffer = c0;
547 Value indxToLoadFromBuffer = c0;
550 if (innerLoopIndex.
value() == 0) {
553 ivOuterLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
556 if (!isInnerLoopUBLarger || isInnerLoopUBHasOddQuot) {
557 indxToStoreInBuffer = arith::RemUIOp::create(
562 Value indxToLoadFromMatB = arith::AddIOp::create(
563 rewriter, loc, indxToStoreInBuffer, c1);
564 indxToLoadFromBuffer = arith::RemUIOp::create(
565 rewriter, loc, rewriter.
getIndexType(), indxToLoadFromMatB,
571 rewriter, locNewInnerLoop, offset);
572 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
573 nLoadIndx, ivNewInnerLoop);
574 indxToStoreInBuffer = getIndxToLoadStoreFromPckBuffer(
575 rewriter, loc, ivNewInnerLoop, ivOuterLoop,
576 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
578 Value indxToLoadFromMatB =
579 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
580 indxToLoadFromBuffer =
581 arith::RemUIOp::create(rewriter, loc, rewriter.
getIndexType(),
582 indxToLoadFromMatB, c2);
587 rewriter, locNewInnerLoop, offset);
588 ivNewInnerLoop = arith::AddIOp::create(rewriter, locNewInnerLoop,
589 nLoadIndx, ivNewInnerLoop);
590 Value quotient_K = arith::DivUIOp::create(
591 rewriter, loc, ivNewInnerLoop, nLoadIndx);
592 indxToStoreInBuffer = arith::RemUIOp::create(
593 rewriter, loc, rewriter.
getIndexType(), quotient_K, c2);
595 Value indxToLoadFromMatB =
596 arith::AddIOp::create(rewriter, loc, indxToStoreInBuffer, c1);
597 indxToLoadFromBuffer =
598 arith::RemUIOp::create(rewriter, loc, rewriter.
getIndexType(),
599 indxToLoadFromMatB, c2);
612 int64_t outerPos = getIndexPosition(contractOp.getRhs(), outerLoop);
615 unsigned operandIdx =
static_cast<unsigned>(outerPos + 1);
617 if (operandIdx < rhsOp->getNumOperands())
622 int64_t innerPos = getIndexPosition(contractOp.getRhs(), innerLoop);
625 unsigned operandIdx =
static_cast<unsigned>(innerPos + 1);
627 if (operandIdx < rhsOp->getNumOperands())
628 rhsMapping.
map(rhsOp->
getOperand(operandIdx), ivNewInnerLoop);
631 auto rhsClone = rewriterNewInnerLoop.
clone(*rhsOp, rhsMapping);
632 matB = rhsClone->getResult(0);
643 rewriter, locNewInnerLoop, offset);
645 indxToLoadFromBuffer = c0;
646 indxToLoadFromBuffer = getIndxToLoadStoreFromPckBuffer(
647 rewriter, loc, nLoadIndx, ivOuterLoop,
648 isInnerLoopUBHasOddQuot, isInnerLoopUBLarger, pack,
654 rewriter, locNewInnerLoop, offset);
656 Value quotient_K = arith::DivUIOp::create(
657 rewriter, loc, ivNewInnerLoop, nLoadIndx);
658 indxToLoadFromBuffer = arith::RemUIOp::create(
659 rewriter, loc, rewriter.
getIndexType(), quotient_K, c2);
665 rewriter, locNewInnerLoop, ops, lhsClone->getResult(0), matB,
666 ipType, opType, iterArgsNewInnerLoop, blockingFactor, isVnni,
667 packedBuffer, pack, indxToStoreInBuffer, indxToLoadFromBuffer);
669 scf::YieldOp::create(rewriterNewInnerLoop, locNewInnerLoop,
740struct VectorContractToAMXDotProduct
742 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
744 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
745 PatternRewriter &rewriter)
const override {
747 if (contractOp.getKind() != vector::CombiningKind::ADD)
749 "Expects add combining kind.");
751 unsigned int blockingFactor =
752 contractOp.getLhsType().getElementType().isBF16() ? 2 : 4;
755 contractOp.getIndexingMapsArray(), blockingFactor);
757 VectorType lhsTy = contractOp.getLhsType();
758 if (!lhsTy.getElementType().isBF16() &&
759 !lhsTy.getElementType().isSignlessInteger(8) &&
760 !lhsTy.getElementType().isF8E4M3FN() &&
761 !lhsTy.getElementType().isF8E5M2())
763 contractOp,
"Only BF16/Int8/F8 lowering is supported.");
765 if (lhsTy.getElementType() != contractOp.getRhsType().getElementType())
767 contractOp,
"Contraction should have same lhs and rhs type.");
769 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
773 if (((lhsTy.getElementType().isBF16() ||
774 lhsTy.getElementType().isF8E4M3FN() ||
775 lhsTy.getElementType().isF8E5M2()) &&
776 !accTy.getElementType().isF32()) ||
777 (lhsTy.getElementType().isSignlessInteger(8) &&
778 !accTy.getElementType().isSignlessInteger(32)))
780 "Only F32 for BF16 or Int32 for Int8 "
781 "accumulation type is supported.");
783 Operation *accReadOp =
786 Operation *resultWriteOp =
789 if (!accReadOp || !resultWriteOp)
791 contractOp,
"The ACC operand of the vector.contract should be a "
792 "transfer_read or a load. And, the result should be "
793 "stored using transfer_write or store.");
798 if (lhsTy.getElementType().isSignlessInteger(8)) {
803 if (lhsTy.getElementType().isF8E4M3FN())
806 if (lhsTy.getElementType().isF8E5M2())
809 if (accReadOp->
getBlock() == contractOp->getBlock() &&
810 resultWriteOp->
getBlock() != contractOp->getBlock())
812 contractOp,
"The accumulator store is in different block.");
814 if (accReadOp->
getBlock() != contractOp->getBlock() &&
815 resultWriteOp->
getBlock() == contractOp->getBlock())
817 contractOp,
"The accumulator read is in different block.");
819 unsigned int dimValue = blockingFactor;
821 dimValue = 16 * blockingFactor;
825 if (accReadOp->
getBlock() == contractOp->getBlock() &&
826 resultWriteOp->
getBlock() == contractOp->getBlock()) {
828 bool collapse =
false;
832 LogicalResult validate = validateContractOps(
833 rewriter, contractOp, dimValue, Value(), Value(),
false);
837 contractOp,
"The contract operation doesn't satisfy the operands "
838 "dimensions. M, N, and vnni dims are 16, 16, and 2/4. "
839 "The rest dims should be 1. Op should have one user.");
841 Location loc = contractOp.getLoc();
843 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
844 contractOp.getLhs(), collapse);
847 "The LHS src is not a MemRef type.");
848 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
850 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
851 contractOp.getRhs(), collapse);
854 "The RHS src is not a MemRef type.");
855 auto rhsSrc = *srcIndxRhs;
856 auto srcBuffRhs = rhsSrc.first;
857 auto indicesRhs = rhsSrc.second;
859 auto srcIndxAcc = getSrcIndxValue(rewriter, contractOp.getLoc(),
860 contractOp.getAcc(),
false);
863 "The ACC src is not a MemRef type.");
864 auto [srcBuffAcc, indicesAcc] = *srcIndxAcc;
869 auto tileType = amx::TileType::get({16, (16 * blockingFactor)}, ipType);
870 auto loadLhs = amx::TileLoadOp::create(rewriter, loc, tileType,
871 srcBuffLhs, indicesLhs);
874 amx::TileLoadOp loadRhs;
877 SmallVector<OpFoldResult> indexVals;
878 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
879 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
880 indexVals = SmallVector<OpFoldResult>(readOp.getIndices().begin(),
881 readOp.getIndices().end());
882 vecTy = readOp.getType();
885 SmallVector<OpFoldResult> strides(indexVals.size(), one);
887 contractOp.getRhs().getDefiningOp()->getContext(),
889 auto subview = memref::SubViewOp::create(rewriter, loc, srcBuffRhs,
890 indexVals, sizes, strides);
891 auto bufferType = MemRefType::get({16, (16 * blockingFactor)}, ipType);
892 auto packedBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
898 (blockingFactor * 16));
902 rewriter, loc, 16 * (blockingFactor / 2));
905 rewriter, loc, c0, uBound, step,
ValueRange{},
906 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
909 arith::AddIOp::create(rewriter, loc, nextLoadIndx, iv);
911 indicesRhs[indicesRhs.size() - 2] = iv;
912 indicesRhs[indicesRhs.size() - 1] = c0;
914 auto vec1 = vector::LoadOp::create(
916 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
919 indicesRhs[indicesRhs.size() - 2] = i1_load;
921 auto vec2 = vector::LoadOp::create(
923 VectorType::get(16 * (blockingFactor / 2), ipType), subview,
926 vector::ShuffleOp shuffle1;
927 vector::ShuffleOp shuffle2;
929 if (blockingFactor == 2) {
931 shuffle1 = vector::ShuffleOp::create(
932 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
933 ArrayRef<int64_t>{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21,
936 shuffle2 = vector::ShuffleOp::create(
937 rewriter, loc, VectorType::get({16}, ipType), vec1, vec2,
938 ArrayRef<int64_t>{8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
939 29, 14, 30, 15, 31});
942 if (blockingFactor == 4) {
943 shuffle1 = vector::ShuffleOp::create(
944 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
945 ArrayRef<int64_t>{0, 16, 32, 48, 1, 17, 33, 49,
946 2, 18, 34, 50, 3, 19, 35, 51,
947 4, 20, 36, 52, 5, 21, 37, 53,
948 6, 22, 38, 54, 7, 23, 39, 55});
950 shuffle2 = vector::ShuffleOp::create(
951 rewriter, loc, VectorType::get({32}, ipType), vec1, vec2,
952 ArrayRef<int64_t>{8, 24, 40, 56, 9, 25, 41, 57,
953 10, 26, 42, 58, 11, 27, 43, 59,
954 12, 28, 44, 60, 13, 29, 45, 61,
955 14, 30, 46, 62, 15, 31, 47, 63});
958 auto rem = arith::DivUIOp::create(
961 vector::StoreOp::create(rewriter, loc, shuffle1, packedBuffer,
963 vector::StoreOp::create(rewriter, loc, shuffle2, packedBuffer,
966 scf::YieldOp::create(nestedBuilder, loc);
968 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, packedBuffer,
972 loadRhs = amx::TileLoadOp::create(rewriter, loc, tileType, srcBuffRhs,
976 auto tileTypeAcc = amx::TileType::get({16, 16}, opType);
977 auto loadAcc = amx::TileLoadOp::create(rewriter, loc, tileTypeAcc,
978 srcBuffAcc, indicesAcc);
983 dp = amx::TileMulFOp::create(rewriter, loc, tileTypeAcc, loadLhs,
987 dp = amx::TileMulIOp::create(rewriter, loc, tileTypeAcc, loadLhs,
990 auto bufferType = MemRefType::get({16, 16}, opType);
991 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
993 amx::TileStoreOp::create(rewriter, loc, resultBuffer,
ValueRange{c0, c0},
996 auto vectorType = mlir::VectorType::get({16, 16}, opType);
998 (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
999 Value padding = ub::PoisonOp::create(rewriter, loc, opType);
1002 SmallVector<bool> inBounds(vectorType.getRank(),
true);
1004 Value vecRow = vector::TransferReadOp::create(
1005 rewriter, loc, vectorType, resultBuffer,
ValueRange{c0, c0}, padding,
1008 Value resultOp = contractionUsersAfterYield(contractOp.getResult());
1009 if (
auto vecType = llvm::dyn_cast<VectorType>(resultOp.getType()))
1010 vecRow = vector::ShapeCastOp::create(rewriter, loc, vecType, vecRow);
1020 SmallVector<scf::ForOp> loopLists;
1021 Operation *current = contractOp;
1028 "Accumulator read and contract op not within scf.for op");
1030 loopLists.push_back(dyn_cast<scf::ForOp>(parent));
1038 if (loopLists.size() > 2 || loopLists.size() == 0)
1040 contractOp,
"Rewrite is supported until reduction loop depth of 2.");
1042 auto srcIndxLhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
1043 contractOp.getLhs(),
false);
1046 "The LHS src is not a MemRef type.");
1047 auto [srcBuffLhs, indicesLhs] = *srcIndxLhs;
1049 auto srcIndxRhs = getSrcIndxValue(rewriter, contractOp.getLoc(),
1050 contractOp.getRhs(),
false);
1053 "The RHS src is not a MemRef type.");
1054 auto [srcBuffRhs, indicesRhs] = *srcIndxRhs;
1055 Operation *vectorOpLhs;
1056 llvm::TypeSwitch<Operation *>(contractOp.getLhs().getDefiningOp())
1057 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
1058 vectorOpLhs = readOp.getBase().getDefiningOp();
1061 Operation *vectorOpRhs;
1062 llvm::TypeSwitch<Operation *>(contractOp.getRhs().getDefiningOp())
1063 .Case<TransferReadOp, LoadOp>([&](
auto readOp) {
1064 vectorOpRhs = readOp.getBase().getDefiningOp();
1068 SmallVector<vector::ContractionOp> ops;
1069 for (mlir::Operation &op : loopLists[0].getBody()->getOperations()) {
1071 if (
auto contract = llvm::dyn_cast<mlir::vector::ContractionOp>(op)) {
1073 LogicalResult validate = validateContractOps(
1074 rewriter,
contract, dimValue, srcBuffLhs, srcBuffRhs,
true);
1079 "The associated contract operations doesn't satisfy "
1080 "the re-write conditions either the dimensions are "
1081 "wrong or MemRef source are different or many users.");
1088 unsigned int pairCount = 0;
1089 for (
size_t j = 0; j < ops.size(); j++) {
1090 for (
size_t i = j; i < ops.size(); i++) {
1092 pairCount = pairCount + 2;
1096 if (pairCount != ops.size())
1098 contractOp,
"Coudn't find the pair vector contract ");
1101 scf::ForOp innerLoop;
1102 scf::ForOp outerLoop;
1106 if (loopLists.size() == 2) {
1107 outerLoop = loopLists[1];
1108 innerLoop = loopLists[0];
1110 LogicalResult validateOuterLoopStep =
1111 validateLoopStep(rewriter, outerLoop.getStep(), 1);
1112 if (
failed(validateOuterLoopStep))
1115 int64_t stepValue = 16;
1117 stepValue = stepValue * blockingFactor;
1118 LogicalResult validateInnerLoopStep =
1119 validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
1120 if (
failed(validateInnerLoopStep))
1122 contractOp,
"Invalid loop step. The step should be 32 for BF16 and "
1125 SmallVector<Value> loopItrArgs = createTileZeros(
1126 rewriter, outerLoop.getLoc(), opType, outerLoop, ops.size());
1129 newLoop = scf::ForOp::create(
1130 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1131 outerLoop.getUpperBound(), outerLoop.getStep(), loopItrArgs,
1132 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1133 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
1134 auto newInnerLoop = createLoops(
1135 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1136 innerLoop.getUpperBound(), innerLoop.getStep(),
1137 iterArgsOuterLoop, ipType, opType, blockingFactor, isVnni,
1138 vectorOpLhs, vectorOpRhs, contractOp, outerLoop, innerLoop,
1139 ops, ivOuterLoop, nullptr, true, nullptr, false, false);
1141 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1142 newInnerLoop.getResults());
1147 bool isInnerLoopUBLarger =
false;
1148 bool isInnerLoopUBHasOddQuot =
false;
1150 int64_t ubVal = 16 * blockingFactor;
1151 mlir::Value ub = innerLoop.getUpperBound();
1152 if (
auto constOp = ub.
getDefiningOp<mlir::arith::ConstantOp>()) {
1154 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1155 ubVal = intAttr.getInt();
1159 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1160 isInnerLoopUBHasOddQuot =
1161 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1170 rewriter, outerLoop.getLoc(), 16 * blockingFactor);
1172 Value spillOuterLoop = arith::SubIOp::create(
1173 rewriter, outerLoop.getLoc(), outerLoop.getUpperBound(), c1);
1174 Value spillInnerLoop =
1175 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1176 innerLoop.getUpperBound(), spillLoopBound);
1178 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1180 memref::AllocaOp::create(rewriter, outerLoop.getLoc(), bufferType);
1183 IRMapping rhsMapping;
1185 vectorOpRhs->getOperand(
1186 getIndexPosition(contractOp.getRhs(), outerLoop) + 1),
1187 outerLoop.getLowerBound());
1189 vectorOpRhs->getOperand(
1190 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1191 innerLoop.getLowerBound());
1192 auto rhsClone = rewriter.
clone(*vectorOpRhs, rhsMapping);
1194 Value quotient_batch = arith::DivUIOp::create(
1195 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1196 outerLoop.getStep());
1197 Value quotient_k = arith::DivUIOp::create(rewriter, outerLoop.getLoc(),
1198 innerLoop.getLowerBound(),
1199 innerLoop.getStep());
1201 Value quotient_add = arith::AddIOp::create(rewriter, outerLoop.getLoc(),
1202 quotient_batch, quotient_k);
1205 Value
rem = arith::RemUIOp::create(rewriter, outerLoop.getLoc(),
1208 performShuffle(rewriter, outerLoop.getLoc(), rhsClone->getResult(0),
1209 ipType, blockingFactor, packedBuffer,
rem);
1212 auto newLoopNonSpill = scf::ForOp::create(
1213 rewriter, outerLoop.getLoc(), outerLoop.getLowerBound(),
1214 spillOuterLoop, outerLoop.getStep(), loopItrArgs,
1215 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1216 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
1217 auto newInnerLoop1 = createLoops(
1218 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1219 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1220 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1221 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1222 ivOuterLoop, packedBuffer, true, spillLoopBound,
1223 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1225 auto newInnerLoop = createLoops(
1226 rewriter, innerLoop.getLoc(), spillInnerLoop,
1227 innerLoop.getUpperBound(), innerLoop.getStep(),
1228 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1229 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1230 innerLoop, ops, ivOuterLoop, packedBuffer, true, c0,
1231 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1233 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1234 newInnerLoop.getResults());
1238 newLoop = scf::ForOp::create(
1239 rewriter, outerLoop.getLoc(), spillOuterLoop,
1240 outerLoop.getUpperBound(), outerLoop.getStep(),
1241 newLoopNonSpill.getResults(),
1242 [&](OpBuilder &rewriterOuterLoop, Location locOuterLoop,
1243 Value ivOuterLoop,
ValueRange iterArgsOuterLoop) {
1244 auto newInnerLoop1 = createLoops(
1245 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1246 spillInnerLoop, innerLoop.getStep(), iterArgsOuterLoop,
1247 ipType, opType, blockingFactor, isVnni, vectorOpLhs,
1248 vectorOpRhs, contractOp, outerLoop, innerLoop, ops,
1249 ivOuterLoop, packedBuffer, true, spillLoopBound,
1250 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1252 auto newInnerLoop = createLoops(
1253 rewriter, innerLoop.getLoc(), spillInnerLoop,
1254 innerLoop.getUpperBound(), innerLoop.getStep(),
1255 newInnerLoop1.getResults(), ipType, opType, blockingFactor,
1256 isVnni, vectorOpLhs, vectorOpRhs, contractOp, outerLoop,
1257 innerLoop, ops, ivOuterLoop, packedBuffer, false, c0,
1258 isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1260 scf::YieldOp::create(rewriterOuterLoop, locOuterLoop,
1261 newInnerLoop.getResults());
1267 if (loopLists.size() == 1) {
1269 innerLoop = loopLists[0];
1270 int64_t stepValue = 16;
1272 stepValue = stepValue * blockingFactor;
1274 LogicalResult validateInnerLoopStep =
1275 validateLoopStep(rewriter, innerLoop.getStep(), stepValue);
1276 if (
failed(validateInnerLoopStep))
1279 "Invalid loop step. The step should be 32 for BF16 and "
1280 "64 for Int8/F8 or 1 if it is rduction loop other than K.");
1282 SmallVector<Value> loopItrArgs = createTileZeros(
1283 rewriter, innerLoop.getLoc(), opType, innerLoop, ops.size());
1286 newLoop = createLoops(
1287 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1288 innerLoop.getUpperBound(), innerLoop.getStep(), loopItrArgs, ipType,
1289 opType, blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1290 contractOp,
nullptr, innerLoop, ops,
nullptr,
nullptr,
true,
1291 nullptr,
false,
false);
1295 bool isInnerLoopUBLarger =
false;
1296 bool isInnerLoopUBHasOddQuot =
false;
1298 int64_t ubVal = 16 * blockingFactor;
1299 mlir::Value ub = innerLoop.getUpperBound();
1300 if (
auto constOp = ub.
getDefiningOp<mlir::arith::ConstantOp>()) {
1302 llvm::dyn_cast<mlir::IntegerAttr>(constOp.getValue())) {
1303 ubVal = intAttr.getInt();
1307 isInnerLoopUBLarger = ubVal > 16 * blockingFactor;
1308 isInnerLoopUBHasOddQuot =
1309 (((ubVal / (16 * blockingFactor)) % 2) == 1) && isInnerLoopUBLarger;
1315 int64_t offset = 16 * blockingFactor;
1317 innerLoop.getStep().getDefiningOp<arith::ConstantIndexOp>())
1318 offset = cst.value();
1321 rewriter, innerLoop.getLoc(), offset);
1322 Value spillInnerLoop =
1323 arith::SubIOp::create(rewriter, innerLoop.getLoc(),
1324 innerLoop.getUpperBound(), spillLoopBound);
1327 MemRefType::get({2, 32, (blockingFactor * 16)}, ipType);
1329 memref::AllocaOp::create(rewriter, innerLoop.getLoc(), bufferType);
1332 IRMapping rhsMapping;
1334 vectorOpRhs->getOperand(
1335 getIndexPosition(contractOp.getRhs(), innerLoop) + 1),
1336 innerLoop.getLowerBound());
1337 auto rhsClone = rewriter.
clone(*vectorOpRhs, rhsMapping);
1339 Value quotient_k = arith::DivUIOp::create(rewriter, innerLoop.getLoc(),
1340 innerLoop.getLowerBound(),
1341 innerLoop.getStep());
1344 Value
rem = arith::RemUIOp::create(rewriter, innerLoop.getLoc(),
1347 performShuffle(rewriter, innerLoop.getLoc(), rhsClone->getResult(0),
1348 ipType, blockingFactor, packedBuffer,
rem);
1350 auto newLoopNonSpill = createLoops(
1351 rewriter, innerLoop.getLoc(), innerLoop.getLowerBound(),
1352 spillInnerLoop, innerLoop.getStep(), loopItrArgs, ipType, opType,
1353 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs, contractOp,
1354 nullptr, innerLoop, ops,
nullptr, packedBuffer,
true,
1355 spillLoopBound, isInnerLoopUBLarger, isInnerLoopUBHasOddQuot);
1357 newLoop = createLoops(rewriter, innerLoop.getLoc(), spillInnerLoop,
1358 innerLoop.getUpperBound(), innerLoop.getStep(),
1359 newLoopNonSpill.getResults(), ipType, opType,
1360 blockingFactor, isVnni, vectorOpLhs, vectorOpRhs,
1361 contractOp,
nullptr, innerLoop, ops,
nullptr,
1362 packedBuffer,
false, c0, isInnerLoopUBLarger,
1363 isInnerLoopUBHasOddQuot);
1368 outerLoop = innerLoop;
1373 Location loc = outerLoop.getLoc();
1375 SmallVector<Value> indicesAcc;
1377 llvm::TypeSwitch<Operation *>(accReadOp).Case<TransferReadOp, LoadOp>(
1379 srcBuffAcc = readOp.getOperand(0);
1381 auto indices = readOp.getIndices();
1382 indicesAcc.reserve(
indices.size());
1384 llvm::transform(
indices, std::back_inserter(indicesAcc),
1385 [&](OpFoldResult ofr) {
1387 rewriter, loc, ofr);
1392 mlir::cast<mlir::MemRefType>(srcBuffAcc.
getType()).getShape();
1393 unsigned int M = outputShapes[outputShapes.size() - 2];
1394 unsigned int N = outputShapes[outputShapes.size() - 1];
1396 SmallVector<Value> dps = newLoop.getResults();
1397 auto bufferType = MemRefType::get({M, N}, opType);
1398 auto resultBuffer = memref::AllocaOp::create(rewriter, loc, bufferType);
1401 for (
unsigned int i = 0, k = 0; i < M; i = i + 16) {
1402 for (
unsigned int j = 0; j < N; j = j + 16) {
1405 amx::TileStoreOp::create(rewriter, loc, resultBuffer,
1418 rewriter, loc, c0, nBound, one,
ValueRange{},
1419 [&](OpBuilder &nestedBuilder, Location loc, Value iv,
1422 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1426 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1429 Value shuffle1 = row;
1430 Value shuffle2 = row2;
1433 shuffle1 = vector::ShuffleOp::create(
1434 rewriter, loc, VectorType::get(16, opType), row, row2,
1435 ArrayRef<int64_t>{0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20,
1438 shuffle2 = vector::ShuffleOp::create(
1439 rewriter, loc, VectorType::get(16, opType), row, row2,
1440 ArrayRef<int64_t>{8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15,
1443 indicesAcc[indicesAcc.size() - 2] = iv;
1444 indicesAcc[indicesAcc.size() - 1] = c0;
1447 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1448 srcBuffAcc, indicesAcc);
1449 indicesAcc[indicesAcc.size() - 1] = c16;
1452 vector::LoadOp::create(rewriter, loc, VectorType::get(16, opType),
1453 srcBuffAcc, indicesAcc);
1459 addOp = arith::AddFOp::create(rewriter, loc, shuffle1, valueCRow1);
1461 addOp2 = arith::AddFOp::create(rewriter, loc, shuffle2, valueCRow2);
1465 addOp = arith::AddIOp::create(rewriter, loc, shuffle1, valueCRow1);
1467 addOp2 = arith::AddIOp::create(rewriter, loc, shuffle2, valueCRow2);
1470 vector::StoreOp::create(rewriter, loc, addOp, resultBuffer,
1472 vector::StoreOp::create(rewriter, loc, addOp2, resultBuffer,
1475 scf::YieldOp::create(nestedBuilder, loc);
1478 SmallVector<Value> writeResults;
1479 for (
unsigned int i = 0; i < M; i = i + 16) {
1480 for (
unsigned int j = 0; j < N; j = j + 16) {
1484 auto vectorType = mlir::VectorType::get({16, 16}, opType);
1487 (dyn_cast<ShapedType>(resultBuffer.getType())).getRank();
1488 Value padding = ub::PoisonOp::create(rewriter, loc, opType);
1491 SmallVector<bool> inBounds(vectorType.getRank(),
true);
1493 auto vec1 = vector::TransferReadOp::create(
1494 rewriter, loc, vectorType, resultBuffer,
1495 ValueRange{indexOp_i, indexOp_j}, padding, map, inBounds);
1496 writeResults.push_back(vec1);
1501 for (
size_t i = 0; i < ops.size(); i++) {
1502 vector::ContractionOp contOp = ops[i];
1503 Value vecRow = writeResults[i];
1505 Value resultWriteOp = contractionUsersAfterYield(contOp.getResult());
1506 if (
auto vecType = llvm::dyn_cast<VectorType>(resultWriteOp.
getType()))
1507 vecRow = mlir::vector::ShapeCastOp::create(rewriter, loc, vecType,
1521 patterns.
add<VectorContractToAMXDotProduct>(patterns.
getContext());
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
IntegerAttr getIndexAttr(int64_t value)
FloatType getF8E5M2Type()
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
FloatType getF8E4M3FNType()
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
unsigned getNumUses() const
This method computes the number of uses of this Value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
use_iterator use_begin() const
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
mlir::x86::AMXTileType TileType
Operation * traceToVectorWriteLikeUserOperation(Value v)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Operation * traceToVectorReadLikeParentOperation(Value v)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
void populateVectorContractToAMXDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< int64_t, 2 > ReassociationIndices
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.