#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
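// Editor's usage note: LDBG streams to llvm::dbgs() only in debug builds and
// only when `-debug-only=linalg-vectorization` is enabled. For example,
// assuming a Value `v` is in scope:
//   LDBG("vectorizing value: " << v);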
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
/// Helper function to extract the input slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw},
            /*sizes=*/ArrayRef<int64_t>{wSizeStep},
            /*strides=*/ArrayRef<int64_t>{1}));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
    // for channeled convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after the filter is unrolled
/// along kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  for (int64_t kw = 0; kw < kwSize; ++kw)
    result.push_back(rewriter.create<vector::ExtractOp>(loc, filter, kw));
  return result;
}
/// Helper function to extract the result slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice of size {wSizeStep} @ [w] for non-channeled conv.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{w},
          /*sizes=*/ArrayRef<int64_t>{wSizeStep},
          /*strides=*/ArrayRef<int64_t>{1}));
    }
  } else {
    // Extract res slice of size {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
    }
  }
  return result;
}
/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Write back res slice of size {wSizeStep} @ [w] for non-channeled conv.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
          /*strides=*/ArrayRef<int64_t>{1});
    }
  } else {
    // Write back res slice of size {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }
  }
  return res;
}
  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  /// Returns the canonical vector type, optionally permuted by
  /// `dimPermutation`.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if it needs masking.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);
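  // Editor's example (hypothetical values): with a canonical vector shape of
  // [8, 16, 4] and dimPermutation = affine_map<(d0, d1, d2) -> (d2, d0)>,
  // getCanonicalVecType(f32, permutation) would return vector<4x8xf32>.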
  /// Initializes the iteration space static sizes.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for all
  /// the static and dynamic sizes of the iteration space.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Creates or retrieves (from the cache) the mask for a given masking map.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Checks whether a masking map is valid, i.e. has no broadcast dimensions.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.getBroadcastDims().size() == 0;
  }
LogicalResult VectorizationState::precomputeIterSpaceValueSizes(
    RewriterBase &rewriter, LinalgOp linalgOp) {
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create a constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size from.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation's loop ranges and
    // mark all dims as non-scalable.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize the iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Precompute value sizes for all static and dynamic dims of the iteration
  // space.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute the permuted projection of the iteration space sizes and the
  // corresponding mask shape.
  auto permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}
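// Editor's design note: masks are cached per masking map in `activeMaskCache`,
// so distinct transfer ops that project the iteration space the same way share
// a single vector.create_mask. A null cached Value encodes "masking provably
// not needed" for that map, avoiding re-analysis.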
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve the mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update its uses.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}
539 "expected projected permutation");
541 assert(res.getNumDims() ==
542 (res.getNumResults() - res.getNumOfZeroResults()) &&
543 "expected reindexed map with same number of dims and results");
/// Return the vector::CombiningKind for `combinerOp`, a binary op used as a
/// reduction combiner, if it maps to one.
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
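// Editor's example (hypothetical IR): for a reduction whose combiner is
//   %0 = arith.addf %in, %acc : f32
//   linalg.yield %0 : f32
// getCombinerOpKind(%0.getDefiningOp()) returns CombiningKind::ADD, which
// later parameterizes vector.multi_reduction / vector.contract creation.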
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, or nullptr
/// otherwise.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
/// Broadcast `value` to a vector of `dstType` if it is not already a vector.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If there is no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
/// Create a MultiDimReductionOp that performs the reduction encoded by
/// `reduceOp` over the masked dims.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`, where `outputOperand` is an output operand of the linalgOp.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store: keep only the dims of the
  // canonical vector shape that appear in the operand's indexing map.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case: broadcast the scalar if needed and write with no indices.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true: masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
/// Conservative vectorization precondition hook: given an operation inside
/// the linalg body and the vectorize-nd-extract flag, decide whether a custom
/// hook can vectorize the operation.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`.
static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`.
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                    VectorizationState &state,
                                                    Operation *op,
                                                    LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space.
  if (dim == targetShape.size() - 1)
    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}
/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();
  // n-D extracts are only vectorized when explicitly requested.
  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type, but only for non 0-d tensors.
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (!llvm::all_of(extractOp->getResultTypes(),
                    VectorType::isValidElementType)) {
    return failure();
  }

  return success();
}
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}
/// Find the index of the trailing non-unit loop dim in `linalgOp`. This is
/// used when checking whether a `tensor.extract` (within a `linalg.generic`)
/// represents a contiguous load.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // A linalg.index is loop invariant iff its dim has a unit loop range.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Constants defined inside `linalgOp` are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}
/// Checks whether `val` could be used for calculating the trailing index of a
/// contiguous load. `foundIndexOp` is set to true when a linalg.index over the
/// trailing non-unit loop dim is found along the way.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return false;

  // Conservatively reject ops that could lead to indices with a stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
/// Infer the memory access pattern for the input ExtractOp: gather,
/// contiguous load, or scalar broadcast.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading a non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`: all but the trailing index
  // must be loop invariant for a contiguous or broadcast load.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index of `extractOp`.
  auto extractOpTrailingIdx = indices.back();

  // 3a. A loop-invariant trailing index means a scalar broadcast load.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG("Found scalar broadcast load: " << extractOp);
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. A trailing index that increments by 1 with the trailing non-unit loop
  // dim means a contiguous (row-vector) load.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::ContiguousLoad;
  }

  // 4. Fallback: a gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}
/// Helper function to vectorize the tensor.extract operations.
static VectorizationHookResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the result type, mask and pass-through constants.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
  }

  // 2. Handle contiguous/broadcast access: compute the transfer read indices.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
  }

  // `tensor.extract` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        /*padding=*/std::nullopt, permutationMap, inBounds);

    // This 0-d read only needs a trivially-true scalar mask.
    auto readMaskType = VectorType::get(ArrayRef<int64_t>{1},
                                        rewriter.getI1Type());
    auto allTrue = rewriter.create<vector::ConstantMaskOp>(
        loc, readMaskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   maskedReadOp};
  }

  // 2b. Handle contiguous access: pad the permutation map with leading 0s
  // when dstRank > srcRank so the ranks match.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());
  int32_t rankDiff = dstRank - srcRank;
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs,
      /*padding=*/std::nullopt, permutationMap, inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                 transferReadOp};
}
/// Emit reduction operations if the shape of the value to reduce differs from
/// the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed, as the value may already have been reduced for
  // reduction parallelization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already
/// vectorized operands carried by `bvm`: custom hooks first, then special
/// cases (constants, reductions), then generic elementwise mapping.
static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationHookResult result = customFunc(op, bvm);
      if (result.status == VectorizationHookStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   rewriter.clone(*op)};

  // 3. Check if the op is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
  }

  // 4. Generic path: find the maximal common vector rank among the operands.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }

  // Broadcast each operand to the maximal common shape.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType =
          VectorType::get(firstMaxRankedType.getShape(),
                          getElementTypeOrSelf(vecOperand.getType()),
                          firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  // Compute the vectorized result types and build the new op.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  return VectorizationHookResult{
      VectorizationHookStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form: map values defined above to themselves, turn all BBArgs into
/// vector.transfer_read, register custom hooks for yield/index/extract, then
/// iteratively vectorize each body operation.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  // Values defined above the region can only be broadcast for now: make them
  // map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // Convert the indexing map for this operand to a transfer read permutation
    // map, and compute the read vector type.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // For input reads, use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // For output reads (iteration-carried dependence, e.g., reductions), the
      // vector shape is computed by mapping the canonical vector shape to the
      // output indexing map.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices,
        /*padding=*/std::nullopt, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // If masked, set in_bounds to true: masking guarantees the access will be
    // in bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // Not all ops support 0-d vectors, extract the scalar for now.
    if (readType.getRank() == 0)
      readValue = rewriter.create<vector::ExtractOp>(loc, readValue);

    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  // Register CustomVectorizationHooks for yield, index and extract ops.
  SmallVector<CustomVectorizationHook> hooks;
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
  };
  hooks.push_back(vectorizeYield);
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);
  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // Iteratively call `vectorizeOneOp` on each op of the region.
  for (Operation &op : block->getOperations()) {
    VectorizationHookResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationHookStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
    if (result.status == VectorizationHookStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
/// Determines whether a mask for an xfer_write is trivially "all true": for
/// every dim, the (constant) mask size plus the (constant) write index must
/// fit in the static destination shape.
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
                                    SmallVector<Value> &writeIdxs,
                                    ArrayRef<int64_t> destShape,
                                    ArrayRef<int64_t> maskShape) {
  // Masking is unavoidable in the case of dynamic tensors.
  if (ShapedType::isDynamicShape(destShape))
    return false;

  // Collect all constant mask sizes; bail out if any is non-constant.
  SmallVector<int64_t, 4> cstMaskSizes;
  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
    if (auto intSize = getConstantIntValue(dimSize))
      cstMaskSizes.push_back(*intSize);
  }
  if (cstMaskSizes.size() != maskShape.size())
    return false;

  // Collect all constant write indices; bail out if any is non-constant.
  SmallVector<int64_t, 4> cstWriteIdxs;
  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
    APSInt intVal;
    if (matchPattern(idx, m_ConstantInt(&intVal)))
      cstWriteIdxs.push_back(intVal.getSExtValue());
  }
  if (cstWriteIdxs.size() != destShape.size())
    return false;

  // The number of mask sizes matches the rank of the vector to store, but the
  // destination rank might be greater.
  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
    if (maskShape[i] > destShape[rankDiff + i] ||
        destShape[rankDiff + i] <
            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
             cstWriteIdxs[i]))
      return false;
  }

  return true;
}
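// Editor's worked example (hypothetical numbers): writing a vector<4x2xf32>
// with mask sizes {4, 2} at indices {0, 0} into a tensor<4x8xf32> satisfies
// the inequality in both dims (4 + 0 <= 4, 2 + 0 <= 8), so the mask folds to
// all-true; the same write at indices {1, 0} fails in dim 0 (4 + 1 > 4) and
// the mask is kept.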
/// Creates an optionally masked TransferWriteOp.
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
                         Value dest, SmallVector<Value> writeIndices = {},
                         bool useInBoundsInsteadOfMasking = false) {
  ShapedType destType = cast<ShapedType>(dest.getType());
  int64_t destRank = destType.getRank();
  auto destShape = destType.getShape();

  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
  int64_t vecToStoreRank = vecToStoreType.getRank();
  auto vecToStoreShape = vecToStoreType.getShape();

  // Compute the in_bounds attribute.
  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
  if (useInBoundsInsteadOfMasking) {
    for (unsigned i = 0; i < vecToStoreRank; i++)
      inBoundsVal[i] =
          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
          !ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
  }

  // If missing, initialize the write indices to 0.
  assert((writeIndices.empty() ||
          writeIndices.size() == static_cast<size_t>(destRank)) &&
         "Invalid number of write indices!");
  if (writeIndices.empty()) {
    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
    writeIndices.assign(destRank, zero);
  }

  // Generate the xfer_write Op.
  Operation *write = builder.create<vector::TransferWriteOp>(
      loc, vecToStore, dest, writeIndices, inBoundsVal);

  // If masking is disabled, or provably not needed, we are done.
  if (useInBoundsInsteadOfMasking)
    return write;

  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
    return write;

  // Compute the mask and mask the write Op.
  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
                                       vecToStoreType.getScalableDims());
  SmallVector<OpFoldResult> destSizes =
      tensor::getMixedSizes(builder, loc, dest);
  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                      destSizes.end());

  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
                              vecToStoreShape))
    return write;

  Value maskForWrite =
      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
  return mlir::vector::maskOperation(builder, write, maskForWrite);
}
/// Vectorize linalg::PackOp with (1) static inner_tiles, (2) constant padding
/// value and (3) input vector sizes covering the source shape, as
/// masked-read -> shape_cast -> transpose -> write.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  Location loc = packOp.getLoc();
  auto padValue = packOp.getPaddingValue();
  if (!padValue) {
    padValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
  }
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");

  // If the input vector sizes are not provided, they are determined by the
  // result tensor shape; in that case in_bounds is used instead of masking.
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }

  // Create the masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking);

  // Create the ShapeCastOp to the tiled shape.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
  auto shapeCastOp =
      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

  // Create the TransposeOp and write back the result.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp.getResult(), destPermutation);
  Value dest = rewriter.create<tensor::EmptyOp>(
      loc, reifiedReturnShapes[0],
      transposeOp.getResult().getType().getElementType());
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest);
  newResults.push_back(write->getResult(0));
  return success();
}
/// Vectorize a linalg::UnPackOp to these 4 Ops:
///   vector::TransferReadOp - reads a vector from the source tensor
///   vector::TransposeOp - transposes the read vector
///   vector::ShapeCastOp - reshapes the data based on the target
///   vector::TransferWriteOp - writes the result vector back
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          SmallVectorImpl<Value> &newResults) {
  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
  bool useInBoundsInsteadOfMasking = false;

  auto destSize = unpackOp.getDestRank();

  if (!inputVectorSizes.empty())
    assert(inputVectorSizes.size() == destSize &&
           "Incorrect number of input vector sizes");

  // If no input vector sizes are provided, infer them from the source shape.
  SmallVector<int64_t> vectorSizes(inputVectorSizes);
  if (vectorSizes.empty()) {
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
    if (!outerDimsPerm.empty())
      applyPermutationToVector(vectorSizes, outerDimsPerm);
    for (auto [i, pos] : llvm::enumerate(innerDimPos))
      vectorSizes[pos] *= innerTiles[i];

    useInBoundsInsteadOfMasking = true;
  }

  // readVectorSizes: each dim sliced by an inner tile is ceil-divided by the
  // tile size and the tile sizes are appended at the back.
  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
  for (auto [index, size] : enumerate(innerTiles)) {
    readVectorSizes[innerDimPos[index]] =
        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(readVectorSizes, outerDimsPerm);
  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());

  ReifiedRankedShapedTypeDims reifiedRetShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
          .reifyResultShapes(rewriter, reifiedRetShapes);
  if (status.failed()) {
    LDBG("Unable to reify result shapes of " << unpackOp);
    return failure();
  }
  Location loc = unpackOp->getLoc();

  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

  // Read the result, masking if necessary.
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      useInBoundsInsteadOfMasking);

  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
  RankedTensorType stripMineTensorType =
      RankedTensorType::get(stripMineShape, stripMineElemType);
  // Transpose the appropriate rows to match the output.
  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
      loc, readResult, lastDimToInsertPosPerm);

  // Collapse the vector to the size required by the result.
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMineTensorType, packMetadata.reassociations);
  mlir::VectorType vecCollapsedType =
      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
      loc, vecCollapsedType, transposeOp->getResult(0));

  // For dynamic destinations, the write vector sizes must match the shape-cast
  // shape, otherwise the verifier rejects the mask size.
  SmallVector<int64_t> writeVectorSizes(
      unpackOp.getDestType().hasStaticShape()
          ? vectorSizes
          : shapeCastOp.getResultVectorType().getShape());
  Value dest = rewriter.create<tensor::EmptyOp>(
      loc, reifiedRetShapes[0],
      shapeCastOp.getResult().getType().getElementType());
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), dest,
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
  return success();
}
/// Vectorize a tensor::PadOp with constant padding value as a masked read
/// followed by a write into an empty destination tensor.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  // Create the write and update the results.
  auto dest = rewriter.create<tensor::EmptyOp>(
      loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG("Vectorization of flattened convs with dynamic shapes is not "
         "supported\n");
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG("Dynamically-shaped op vectorization precondition failed: only "
         "channel dim can be dynamic\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement ContractionOpInterface for now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}
/// Checks that the inner tiles of an UnPackOp are static/constant and that
/// the input vector sizes, when provided, are valid for masking.
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
      })) {
    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
  bool satisfyEmptyCond = inputVectorSizes.empty() &&
                          unpackOp.getDestType().hasStaticShape() &&
                          unpackOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
    return failure();

  return success();
}
static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = getStaticPadVal(sliceOp);
  auto source = sliceOp.getSource();
  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // Without a pad value, only reads that are statically guaranteed to remain
  // in bounds of the source can be vectorized.
  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG("Failed to get a pad value for out-of-bounds read access\n");
    return failure();
  }
  return success();
}
enum class ConvOperationKind { Conv, Pool };

static bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

// Returns the ConvOperationKind of the op using the reduceOp of the generic
// payload. Returns std::nullopt if it cannot determine the answer.
static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
  case 1: {
    // Will be convolution if the feed op is a MulOp. A strength-reduced
    // version of MulOp for i1 type is AndOp, which is also supported.
    // Otherwise, it can be pooling.
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
    assert(feedValIt != reduceOp->operand_end() &&
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp)) {
      return ConvOperationKind::Pool;
    }

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            feedOp->getResultTypes()[0].isInteger(1))) &&
          llvm::all_of(feedOp->getOperands(), [](Value v) {
            if (isa<BlockArgument>(v))
              return true;
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
            return false;
          }))) {
      return std::nullopt;
    }

    return ConvOperationKind::Conv;
  }
  case 2:
    // Must be pooling.
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}
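// Editor's note: the classification above keys off the reduction combiner's
// operands. One block argument plus a mul/and feed means a convolution
// (acc += lhs * rhs); two block arguments feeding the combiner directly
// (e.g. acc = max(acc, in)) means a pooling op.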
static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
  // (non-channeled convolution -> LHS and RHS both have single dimension).
  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
    return failure();

  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
  if (!reduceOp)
    return failure();

  auto maybeOper = getConvOperationKind(reduceOp);
  if (!maybeOper.has_value())
    return failure();

  auto maybeKind = getCombinerOpKind(reduceOp);
  // Typically convolution will have an `Add` CombiningKind, but for i1 type
  // it can get strength-reduced to `OR` which is also supported.
  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {
    return failure();
  }

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1)
      return failure();
  } else {
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return failure();
  }

  return success();
}
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              bool vectorizeNDExtract,
                              bool flatten1DDepthwiseConv) {
  // Tensors with a dimension of 0 cannot be vectorized.
  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
        return llvm::is_contained(linalgOp.getShape(&operand), 0);
      }))
    return failure();
  // Check the API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
                                        linalgOp, flatten1DDepthwiseConv))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(
            customPreconditions,
            [&](const CustomVectorizationPrecondition &customPrecondition) {
              return succeeded(
                  customPrecondition(&innerOp, vectorizeNDExtract));
            })) {
      continue;
    }
    if (!llvm::all_of(innerOp.getOperandTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
    if (!llvm::all_of(innerOp.getResultTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // TODO: isaConvolutionOpInterface that can also infer from generic features.
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return vectorizeConvOpPrecondition(linalgOp);

  // TODO: the common vector shape equals the static loop sizes only when all
  // indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG("inner_tiles must be constant: " << packOp << "\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  // Padding with non-zero low pad values is not supported, unless the
  // corresponding result dim is 1, as that would require shifting the results
  // to the right by the amount of low padding.
  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
        Value padValue = en.value();
        unsigned pos = en.index();
        std::optional<int64_t> pad = getConstantIntValue(padValue);
        return (!pad.has_value() || pad.value() != 0) &&
               resultTensorShape[pos] != 1;
      })) {
    LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
    return failure();
  }

  return success();
}
/// Preconditions for scalable vectors: the scalable dims must match one of
/// the supported trailing-dim patterns, and only a limited set of ops is
/// supported.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();

  // At most 2 dims can be scalable.
  if (numOfScalableDims > 2)
    return failure();

  // Find the trailing scalable flag and analyze the iterators behind it,
  // tracking whether a non-unit parallel dim has been seen.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // The trailing scalable dim must be the trailing loop dim.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG("Non-trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
      LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
           "is not supported\n");
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // The trailing scalable parallel dim must be the innermost non-unit one.
    if (seenNonUnitParallel) {
      LDBG("Inner parallel dim not requested for scalable "
           "vectorization\n");
      return failure();
    }
    break;
  }
  }

  // With 2 scalable dims, the second-from-last must also be a scalable
  // parallel dim.
  if (numOfScalableDims == 2) {
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG("Higher dim than the trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Don't let matmuls with extended semantics through this transform.
  if (linalgOp.hasUserDefinedMaps())
    return failure();

  // Only a supported set of ops for now.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::MatmulTransposeAOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<linalg::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<linalg::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
/// Converts affine.apply ops inside the linalg body to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}

bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
}
FailureOr<VectorizationResult>
mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                        ArrayRef<int64_t> inputVectorSizes,
                        ArrayRef<bool> inputScalableVecDims,
                        bool vectorizeNDExtract,
                        bool flatten1DDepthwiseConv) {
  LDBG("Attempting to vectorize:\n" << *op << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Input scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
                                     vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            // TODO: isaConvolutionOpInterface that can also infer from
            // generic features. Will require stride/dilation attribute
            // inference.
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG("Unsupported convolution can't be vectorized.\n");
              return failure();
            }

            LDBG("Vectorize generic by broadcasting to the canonical vector "
                 "shape\n");

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<linalg::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<linalg::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
          })
          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp,
                                            inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG("Vectorization failed\n");
    return failure();
  }

  return VectorizationResult{results};
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
      /*padding=*/std::nullopt,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = rewriter.create<vector::ExtractOp>(loc, readValue);
    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
  }
  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in a vector, because some of them may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getBaseMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at position of the old TransferWriteOp.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }
  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both the CastOp
    // result and the CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType is supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same attribute or same value.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Case 2: Both sizes are values of (matching) affine.min ops.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
      return false;
    }

    // All tests passed.
    return true;
  }
/// Returns the effective Pad value for the input op, provided it's a scalar.
/// Many ops exhibit pad-like behaviour, but this isn't always explicit; if
/// this op performs padding, retrieve the padding value provided it's a
/// scalar and static/fixed for all padded values. Returns an empty value
/// otherwise.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // vector.broadcast: return the (scalar) source value.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  // linalg.fill: use the scalar input.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // tensor.generate: only a constant yielded for all elements would qualify.
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    // TODO: Handle this case.
    return {};
  }

  // vector.transfer_write: inspect the vector that's written.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // tensor.insert_slice: inspect the destination.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
/// Vectorize tensor::InsertSliceOp with a masked read followed by a masked
/// write. The vector sizes come from `inputVectorSizes` or, when absent, from
/// the static sizes of the source/result types.
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  auto sourceType = sliceOp.getSource().getType();
  auto resultType = sliceOp.getResultType();

  Value padValue = getStaticPadVal(sliceOp);
  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = rewriter.create<arith::ConstantOp>(
        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
  }

  // Compute the vector shape.
  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but the result shape is:
      // vectorize with the size of the result dim (adjusted for the rank
      // difference) and pad/mask the rest.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
    } else {
      // Neither source nor result dim is static; cannot vectorize the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // Create the read, padded with `padValue` when out of bounds of the source.
  auto loc = sliceOp.getLoc();
  SmallVector<Value> readIndices(
      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value read = mlir::vector::createReadOrMaskedRead(
      rewriter, loc, sliceOp.getSource(), vecType.getShape(), padValue,
      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());

  // Create the write at the slice offsets.
  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
                               writeIndices, inputVectorSizes.empty());
  newResults.push_back(write->getResult(0));
  return success();
}
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result not used as destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if the sizes match: the entire padded tensor must be inserted, so
    // that the generated transfer_read/transfer_write need no masking.
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Generate a TransferReadOp: read the entire source tensor with high
    // padding; the shape of the read is the shape of the padded result.
    SmallVector<Value> readIndices(
        vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

    // Generate a TransferWriteOp at the offsets of the InsertSliceOp.
    auto writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices, inBounds);

    return success();
  }
/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservative result also works if `firstOp` and
/// `secondOp` are in different blocks.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
    }
  }
  return false;
}
/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the
  // copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // The `masked` attribute is only valid on the padded local buffer; when
  // forwarding to vector.transfer_read it must be reset conservatively.
  auto vectorType = xferOp.getVectorType();
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the target of the copy that we replace.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward the vector.transfer to the output; in_bounds is conservatively
  // reset to false.
  auto vector = xferOp.getVector();
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(SmallVector<bool>(
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
/// Recursion base case: no more dims to bind.
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

/// Bind the dim at position N of `shapedType` to `val` and recurse on the
/// remaining values.
template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2...>(shapedType, vals...);
}

/// Bind a pack of integer references to the leading dims of
/// shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
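// Editor's usage example: given `ShapedType lhsShapedType` of shape (N, W, C),
//   int64_t nSize, wSize, cSize;
//   bindShapeDims(lhsShapedType, nSize, wSize, cSize);
// binds nSize = N, wSize = W and cSize = C, mirroring how the Conv1DGenerator
// below reads off convolution dimensions.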
/// Helper StructuredGenerator class to manipulate and rewrite ops with
/// StructuredOpInterface; used for 1-D convolution vectorization.
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter,
                                                           linalgOp) {
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    setConvOperationKind(reduceOp);

    auto maybeKind = getCombinerOpKind(reduceOp);
    reductionKind = maybeKind.value();

    // The ConvolutionOpInterface gives us guarantees of existence for
    // strides/dilations, but we do not need to rely on them: use them if
    // present, otherwise fall back to the default of 1.
    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
    auto dilations =
        linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  }
  /// Generate a vector implementation for 1-D convolutions and poolings in
  /// the W, NWC and NCW layouts.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Initialize unused dimensions.
      nSize = fSize = cSize = 0;
      // out{W}, in{W}, filter{KW}
      bindShapeDims(resShapedType, wSize);
      bindShapeDims(rhsShapedType, kwSize);
      lhsShape = {// iw = ow + kw - 1
                  //   (i.e. 16 convolved with 3 -> 14)
                  (wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      // out{n, w, f}, in{n, strideW * w + dilationW * kw, c}, filter{kw, c, f}
      bindShapeDims(resShapedType, nSize, wSize, fSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        bindShapeDims(rhsShapedType, kwSize, cSize);
        break;
      case ConvOperationKind::Pool:
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize,
                  // iw = (ow * sw + kw * dw - 1) * 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      // out{n, f, w}, in{n, c, strideW * w + dilationW * kw}, filter{f, c, kw}
      bindShapeDims(resShapedType, nSize, fSize, wSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
        break;
      case ConvOperationKind::Pool:
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize, cSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    auto lhsType = VectorType::get(lhsShape, lhsEltType);
    auto rhsType = VectorType::get(rhsShape, rhsEltType);
    auto resType = VectorType::get(resShape, resEltType);

    // Read the whole lhs, rhs and res in one shot.
    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
    SmallVector<Value> resPadding(resShape.size(), zero);
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, lhsPadding,
        /*padding=*/std::nullopt);
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
      rhs = rewriter.create<vector::TransferReadOp>(
          loc, rhsType, rhsShaped, rhsPadding,
          /*padding=*/std::nullopt);
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, resPadding,
        /*padding=*/std::nullopt);

    // The base vectorization case for channeled convolution is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern,
    // transpose the computation first for Ncw layouts.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, no transposition needed.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
      // fcw -> wcf
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};

      // This is needed only for Conv.
      if (oper == ConvOperationKind::Conv)
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
      break;
    }
    }

    // Unroll along kw and read slices of lhs, rhs and res.
    SmallVector<Value> lhsVals =
        extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                               kwSize, strideW, dilationW, wSizeStep,
                               isSingleChanneled);
    // Do not do this for pooling.
    SmallVector<Value> rhsVals;
    if (oper == ConvOperationKind::Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    SmallVector<Value> resVals = extractConvResultSlices(
        rewriter, loc, res, nSize, wSize, fSize, wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the contraction / outer product / pooling combiner per slice.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case ConvOperationKind::Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case ConvOperationKind::Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep,
                                 resVals, isSingleChanneled);

    // Transpose the output back for Ncw layouts.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // nwf -> nfw
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
      break;
    }
    }

    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
  }
  /// Promote `val` so that its element type matches `ty`'s element type,
  /// converting integers to floats or extending to wider integers/floats.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  /// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }

  /// Create an outer product: lhs{w} * rhs{1} -> res{w} (single channel).
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
  }
  /// Generate a vector implementation for a depthwise 1-D convolution with
  /// layout {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}.
  /// kw is always unrolled. If `flatten` is set, the channel dim is flattened
  /// into w.
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when the channel dim is dynamic and
      // the scalable flag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = ow * sw + kw * dw - 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});

    // Masks the input xfer Op along the channel dim when masking is needed.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
        /*padding=*/std::nullopt);
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = rewriter.create<vector::TransferReadOp>(
        loc, rhsType, rhsShaped, ValueRange{zero, zero},
        /*padding=*/std::nullopt);
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero},
        /*padding=*/std::nullopt);
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

    // Unroll along kw and read slices of lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, maybeMaskedRhs->getResult(0), /*position=*/kw));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Note: the scalable flags are ignored, as flattening combined with
    // scalable vectorization is not supported.
    SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel dim).
          lhsVal = rewriter.create<vector::ShapeCastOp>(
              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
          resVal = rewriter.create<vector::ShapeCastOp>(
              loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal,
                                                  flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dim).
          resVals[w] = rewriter.create<vector::ShapeCastOp>(
              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
        }
      }
    }

    // It's possible we failed to create the Fma.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice: {n, wSizeStep, c} @ [0, w, 0]; independent of kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back the full res slice of size {n, w, c} @ [0, 0, 0].
    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
        loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
  /// to MulAcc.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    // TODO(suderman): Change this to use a vector.ima intrinsic.
    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: This logic won't work for scalable vectors; for that reason,
      // "flattening" is not supported when shapes are dynamic (this should be
      // captured by one of the pre-conditions).

      // There are two options for handling the filter:
      //  * shape_cast(broadcast(filter))
      //  * broadcast(shuffle(filter))
      // Opt for the option without shape_cast to simplify the codegen.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }
  /// Match and vectorize a non-channeled convolution: 1 parallel (w) and 1
  /// reduction (kw) iterator with layout {{w + kw}, {kw}, {w}}.
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Match and vectorize an NWC convolution (3 parallel, 2 reduction dims).
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Match and vectorize an NCW convolution (3 parallel, 2 reduction dims).
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Match and vectorize NWC pooling (3 parallel, 1 reduction dim).
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Match and vectorize NCW pooling (3 parallel, 1 reduction dim).
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Match and vectorize a dilated (depthwise) NWC convolution.
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
private:
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr redOp;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;

  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // Will be a convolution if the feed op is a MulOp; otherwise it can be
      // pooling.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // numBlockArguments == 2: must be pooling.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
};
/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left; they may benefit from masking, so
  // determine the channel dim vector size and whether it is scalable.
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
}

struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};
Check whether outputOperand is a reduction with a single combiner operation.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp)
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static bool isSupportedPoolKind(vector::CombiningKind kind)
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(linalg::PackOp packOp, ArrayRef< int64_t > destShape)
Given a linalg::PackOp, return the dest shape before any packing permutations.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_iterator operand_end()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
enum WinogradConv2DFmr uint32_t std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Transformation information returned after vectorizing.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.