#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

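// Example (illustrative, not part of the original file): with these macros,
//   LDBG("Masking map: " << maskingMap);
// prints "[linalg-vectorization] Masking map: ..." in debug builds when the
// pass pipeline is run with -debug-only=linalg-vectorization.
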
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);

/// Return the unique instance of OpType in `block` if it is indeed unique.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    // ...
  });
  return res;
}

/// Helper function to extract the input slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slices of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            /*...*/));
      }
    }
  } else {
    // Extract input slices of size {nSize, wSizeStep, cSize} @
    // [0, w * strideW + kw * dilationW, 0] for channeled convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            /*...*/));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after the filter is unrolled
/// along kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(/*...*/));
  }
  return result;
}

/// Helper function to extract the result slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slices of size {wSizeStep} @ [w] for non-channeled conv.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          /*...*/));
    }
  } else {
    // Extract res slices of size {nSize, wSizeStep, fSize} @ [0, w, 0] for
    // channeled convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          /*...*/));
    }
  }
  return result;
}

/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Write back res slices of size {wSizeStep} @ [w] for non-channeled conv.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(/*...*/);
    }
  } else {
    // Write back res slices of size {nSize, wSizeStep, fSize} @ [0, w, 0] for
    // channeled convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(/*...*/);
    }
  }
  return res;
}

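// Example (illustrative sketch, not in the original file): for wSize = 4,
// wSizeStep = 2 and kwSize = 3, the extract/insert helpers above produce
// 3 * 2 = 6 input slices and 2 result slices; the result slice at offset w
// accumulates the contributions of every kw position before being inserted
// back at the same offset.
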
  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimension bits. If
  /// `dimPermutation` is provided, the canonical vector dimensions are
  /// permuted accordingly.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

private:
  /// Initializes the iteration space static sizes.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for all
  /// the static and dynamic sizes of the iteration space.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Create or retrieve an existing mask to mask `opToMask` in the canonical
  /// vector iteration space.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Checks whether `maskingMap` is a valid masking map.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.isProjectedPermutation();
  }

LogicalResult VectorizationState::precomputeIterSpaceValueSizes(
    RewriterBase &rewriter, LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}

LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    // This path should be taken to vectorize code with dynamic shapes and
    // when using vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there
    // are dynamic shapes, the operation won't be vectorized. We assume all
    // the vector dimensions are fixed.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for all the
  // static and dynamic sizes of the iteration space, needed to compute a mask
  // during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}

Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute the permuted projection of the iteration space to be masked and
  // the corresponding mask shape. If the resulting iteration space dimensions
  // are static and identical to the mask shape, masking is not needed for
  // this operation.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}

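// Example (illustrative, not from the original file): when vectorizing an op
// with canonical vector shape [8, 16] and a dynamic iteration space, an
// operand indexed by (d0, d1) -> (d0, d1) gets
//   %mask = vector.create_mask %dim0, %dim1 : vector<8x16xi1>
// and the mask is cached under that masking map, so other operands with the
// same map reuse it instead of creating a duplicate.
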
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  // ...

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}

538 "expected projected permutation");
540 assert(res.getNumDims() ==
541 (res.getNumResults() - res.getNumOfZeroResults()) &&
542 "expected reindexed map with same number of dims and results");
/// If `op` is a combiner op supported by vector reductions, return its
/// CombiningKind.
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}

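// Example (illustrative): for the combiner of a floating-point sum reduction,
//   %0 = arith.addf %in, %acc : f32
// getCombinerOpKind returns CombiningKind::ADD, which later parameterizes the
// generated vector.multi_reduction.
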
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, which is
/// assumed to be a binary operation.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}

/// Broadcast `value` to a vector of `dstType` if this is required. Return
/// `value` otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  // ...
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}

/// Create MultiDimReductionOp to compute the reduction for `reduceOp`. This
/// assumes that `reduceOp` has two operands and one of them is the reduction
/// initial value.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

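// Example (illustrative): reducing a vector<8x16xf32> over its second
// dimension with an ADD combiner yields
//   %r = vector.multi_reduction <add>, %v, %acc [1]
//        : vector<8x16xf32> to vector<8xf32>
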
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}

/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}

/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store, as the identity or a
  // projection of the canonical vector type.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true: masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

/// Custom vectorization precondition function type. Returns success if the
/// corresponding custom hook can vectorize the op.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Helper function to vectorize the index operations of a `linalgOp`.
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute the static loop sizes of the index op.
  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  // Compute a one-dimensional index vector for the index op dimension.
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space, since the vectorization algorithm can
  // handle that broadcast.
  if (dim == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
  // Otherwise, permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}

/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();
  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type of the operand.
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      })) {
    return failure();
  }

  return success();
}

/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///   offset = extractOp.indices[0]
///   for (i = 1; i < numIndices; i++)
///     offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}

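// Example (illustrative): for a 2-D tensor extract t[%i, %j], the loop above
// computes offset = broadcast(%i) * dim(t, 1) + broadcast(%j), i.e. the
// row-major linearized index that vector.gather expects.
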
/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
/// represents a contiguous load operation.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}

/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its result remains constant across
  // iterations, i.e. the corresponding loop dim has range 1.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp`, which are constant, are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}

/// Check whether `val` could be used for calculating the trailing index for a
/// contiguous load operation.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}

/// Infer the memory access pattern for the input ExtractOp.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`,
  // false otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`. Look at the way each index
  // is calculated and decide whether it is suitable for a contiguous load,
  // i.e. whether it's loop invariant. If not, it's a gather load.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index of `extractOp`. At this point we know that
  // the leading indices are loop invariant, so this is potentially a scalar
  // or a contiguous load. Decide based on the trailing index.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load: if the trailing index is loop invariant then
  // this is a scalar load.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG("Found scalar broadcast load: " << extractOp);

    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous loads: the trailing `extractOp` index should increment
  // with every loop iteration, i.e. be based on the trailing loop index.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  // TODO: Support generating contiguous loads for column vectors - that will
  // require adding a permutation map to transfer_read Ops.
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}

/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations.
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc, DenseIntElementsAttr::get(
               state.getCanonicalVecType(rewriter.getI1Type()),
               /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0. We will need to re-visit if more
  // generic scenarios are to be supported.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }

  // 2. Handle contiguous access and scalar broadcast, both lowered to
  // vector.transfer_read. Collect the scalar read indices: for vector
  // indices only the 0-th element is needed.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        permutationMap, inBounds);

    // Mask this broadcasting xfer_read here rather than relying on the
    // generic path (the generic path assumes an identity masking map, which
    // wouldn't be valid here).
    SmallVector<int64_t> readMaskShape = {1};
    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
    auto allTrue = rewriter.create<vector::ConstantMaskOp>(
        loc, readMaskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary
  // leading dims so that the ranks match. This is done by extending the map
  // with constant-zero results, e.g. (d0, d1) -> (d0, d1) becomes
  // (d0, d1) -> (0, d0, d1).
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}

/// Emit reduction operations if the shape of the value to reduce is different
/// from the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp,
                                 Operation *op, Value reduceValue,
                                 Value initialValue, const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}

/// Generic vectorization for a single operation `op`, given already
/// vectorized operands carried by `bvm`.
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp,
                               rewriter.clone(*op)};

  // ...

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each operand if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     getElementTypeOrSelf(vecOperand.getType()),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. For elementwise ops, the result is a vector with the max ranked
  //      shape.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}

/// Generic vectorization function that rewrites the body of a `linalgOp`
/// into vector form. On success, fills `newResults` with the vectorized
/// results.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  // 2. Values defined above the region can only be broadcast for now. Make
  // them map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // 3. Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // 3.a. Convert the indexing map for this input/output to a transfer read
    // permutation map and masking map.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // 3.a.i. For input reads we use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // 3.a.ii. For output reads (iteration-carried dependence, e.g.
      // reductions), the vector shape is computed from the indexing map.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // 3.b. If masked, set in_bounds to true. Masking guarantees that the
    // access will be in bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
    if (readType.getRank() == 0)
      readValue = rewriter.create<vector::ExtractOp>(loc, readValue);

    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  // 4. Register CustomVectorizationHooks for the terminator (yield), index
  // and tensor.extract ops.
  SmallVector<CustomVectorizationHook> hooks;
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
  };
  hooks.push_back(vectorizeYield);

  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // 5. Iteratively call `vectorizeOneOp` on each op in the linalgOp body.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}

/// Given an input, the mixed destSizes, and the vector sizes for
/// vectorization, create an empty destination tensor and create a
/// TransferWriteOp from the input to the empty tensor. If the destination
/// shape is not the same as the inputVectorSizes for the first rank(input)
/// dims, then create a mask for the write. If `useInBoundsInsteadOfMasking`
/// is set, then update the in_bounds attribute of the transfer write op
/// instead of masking.
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
                                           Value input,
                                           SmallVector<OpFoldResult> destSizes,
                                           ArrayRef<int64_t> inputVectorSizes,
                                           bool useInBoundsInsteadOfMasking) {
  auto inputType = cast<VectorType>(input.getType());
  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                               inputType.getElementType());
  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  auto destShape = cast<ShapedType>(dest.getType()).getShape();
  SmallVector<bool> inBoundsVal(rank, true);
  if (useInBoundsInsteadOfMasking) {
    // Update the inBounds attribute.
    for (unsigned i = 0; i < rank; i++)
      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
                       !ShapedType::isDynamic(destShape[i]);
  }
  Operation *write = builder.create<vector::TransferWriteOp>(
      loc,
      /*vector=*/input,
      /*source=*/dest,
      /*indices=*/SmallVector<Value>(rank, zero),
      /*inBounds=*/inBoundsVal);
  assert(llvm::none_of(
             destShape.drop_front(inputVectorSizes.size()),
             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
         "Only dims aligned with inputVectorSizes may be dynamic");
  if (useInBoundsInsteadOfMasking)
    return write;
  bool needMaskForWrite = !llvm::equal(
      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
  if (needMaskForWrite) {
    SmallVector<int64_t> writeMaskShape;
    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
                          destShape.end());
    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
    Value maskForWrite =
        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
    write = mlir::vector::maskOperation(builder, write, maskForWrite);
  }
  return write;
}

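// Example (illustrative): writing a vector<8x16xf32> into a destination whose
// reified sizes are (%d0, 16), with inputVectorSizes = [8, 16], produces
//   %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
//   vector.mask %mask { vector.transfer_write %v, %empty ... }
// whereas destination sizes that statically match the vector sizes skip the
// mask entirely.
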
/// Vectorize linalg::PackOp with (1) static inner_tiles, (2) constant
/// padding value and (3) input vector sizes.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  auto padValue = packOp.getPaddingValue();
  if (!padValue) {
    padValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
  }
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");

  // If the input vector sizes are not provided, then the vector sizes are
  // determined by the result tensor shape.
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }

  // Create masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  // ... (adjust inputShape by the inner tiles and outer dims permutation)
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking);

  // Create ShapeCastOp.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
  auto shapeCastOp =
      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

  // Create TransposeOp.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}

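// Example (illustrative): a linalg.pack tiling a 32x8 source with
// inner_tiles = [16, 2] lowers, roughly, to
//   %r = vector.transfer_read ... : tensor<32x8xf32>, vector<32x8xf32>
//   %c = vector.shape_cast %r : vector<32x8xf32> to vector<2x16x4x2xf32>
//   %t = vector.transpose %c, [0, 2, 1, 3] : ... to vector<2x4x16x2xf32>
//   vector.transfer_write %t, %empty ...
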
/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
///   Vector::TransferReadOp - Reads a vector from the source tensor
///   vector::TransposeOp - Transpose the Source tensor
///   ShapeCastOp - Reshape the data based on the target.
///   vector::TransferWriteOp - Write the result vector back to the
///   destination tensor.
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          SmallVectorImpl<Value> &newResults) {
  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;
  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();

  auto destSize = unpackOp.getDestRank();

  if (!inputVectorSizes.empty())
    assert(inputVectorSizes.size() == destSize &&
           "Incorrect number of input vector sizes");

  // vectorSizes is the shape of the vector that will be used to do the final
  // write on the destination tensor.
  SmallVector<int64_t> vectorSizes(inputVectorSizes);
  if (vectorSizes.empty()) {
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
    // ...
    useInBoundsInsteadOfMasking = true;
  }

  // readVectorSizes is the size of the vector used to read the source tensor
  // and apply the mask: the vector sizes divided (ceiling) by the inner
  // tiles, permuted by outerDimsPerm, then extended with the trailing source
  // dims.
  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
  for (auto [index, size] : enumerate(innerTiles)) {
    readVectorSizes[innerDimPos[index]] =
        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(readVectorSizes, outerDimsPerm);
  }
  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());

  ReifiedRankedShapedTypeDims reifiedRetShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
          .reifyResultShapes(rewriter, reifiedRetShapes);
  if (status.failed()) {
    LDBG("Unable to reify result shapes of " << unpackOp);
    return failure();
  }
  Location loc = unpackOp->getLoc();

  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

  // Read result, mask if necessary. If the transferReadOp shape is not equal
  // to the shape of the source, then a mask is necessary.
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
  RankedTensorType stripMineTensorType =
      RankedTensorType::get(stripMineShape, stripMineElemType);
  // Transpose the appropriate rows to match the output.
  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
      loc, readResult, lastDimToInsertPosPerm);

  // Collapse the vector to the size required by the result.
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMineTensorType, packMetadata.reassociations);
  mlir::VectorType vecCollapsedType =
      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
      loc, vecCollapsedType, transposeOp->getResult(0));

  // writeVectorSizes must match the shape-cast shape for dynamic sizes,
  // otherwise the validator complains that the mask size is invalid.
  SmallVector<int64_t> writeVectorSizes(
      unpackOp.getDestType().hasStaticShape()
          ? vectorSizes
          : shapeCastOp.getResultVectorType().getShape());
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
      writeVectorSizes, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
  return success();
}

/// Vectorize a `padOp` with (1) static result type, (2) constant padding
/// value and (3) all-zero lowPad to
///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(padOp);

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
      /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}

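// Example (illustrative): with inputVectorSizes = [2, 4], a tensor.pad of a
// tensor<?x3xf32> by %cst becomes a masked read yielding vector<2x4xf32>
// (out-of-bounds lanes filled with %cst), followed by a write of that vector
// into the reified result tensor.
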
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}

static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG("Vectorization of flattened convs with dynamic shapes is not "
         "supported\n");
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG("Dynamically-shaped op vectorization precondition failed: only "
         "channel dim can be dynamic\n");
    return failure();
  }

  return success();
}

static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic
  // ops, linalg.copy ops and ops that implement ContractionOpInterface for
  // now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}

/// Need to check if the inner-tiles are static/constant.
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {

  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
      })) {
    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
  bool satisfyEmptyCond = inputVectorSizes.empty() &&
                          unpackOp.getDestType().hasStaticShape() &&
                          unpackOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
    return failure();

  return success();
}

static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {
  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // Get the pad value. If not present, an out-of-bounds read is performed
  // and we need to be able to compute a pad value instead.
  Value padValue = getStaticPadVal(sliceOp);
  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG("Failed to get a pad value for out-of-bounds read access\n");
    return failure();
  }
  return success();
}

enum class ConvOperationKind { Conv, Pool };

static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
  case 1: {
    // Will be convolution if the feeder is a MulOp. A strength-reduced
    // version of MulOp for i1 type is AndOp, which is also supported.
    // Otherwise, it can be pooling.
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
    assert(feedValIt != reduceOp->operand_end() &&
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp))
      return ConvOperationKind::Pool;

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            feedOp->getResultTypes()[0].isInteger(1))) &&
          llvm::all_of(feedOp->getOperands(), [](Value v) {
            if (isa<BlockArgument>(v))
              return true;
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
            return false;
          })))
      return std::nullopt;

    return ConvOperationKind::Conv;
  }
  case 2:
    // Must be pooling.
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}

static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}

static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
  // (Non-channeled convolution -> LHS and RHS both have single dimensions.)
  // Note that this also ensures 2D and 3D convolutions are rejected.
  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
    return failure();

  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
  if (!reduceOp)
    return failure();

  auto maybeOper = getConvOperationKind(reduceOp);
  if (!maybeOper.has_value())
    return failure();

  auto maybeKind = getCombinerOpKind(reduceOp);
  // Typically convolution will have an `Add` CombiningKind, but for i1 type
  // it can get strength-reduced to `OR` which is also supported.
  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {
    return failure();
  }

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1)
      return failure();
  } else {
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return failure();
  }

  return success();
}

static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              bool vectorizeNDExtract,
                              bool flatten1DDepthwiseConv) {
  // Tensor with dimension of 0 cannot be vectorized.
  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() &&
      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp,
                                                  flatten1DDepthwiseConv))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(
            customPreconditions,
            [&](const CustomVectorizationPrecondition &customPrecondition) {
              return succeeded(
                  customPrecondition(&innerOp, vectorizeNDExtract));
            })) {
      continue;
    }
    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // TODO: isaConvolutionOpInterface that can also infer from generic
  // features. Need to add structural checks on the convolution interface.
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return vectorizeConvOpPrecondition(linalgOp);

  // TODO: the common vector shape is equal to the static loop sizes only
  // when all indexing maps are projected permutations. For convs and
  // stencils the logic will need to evolve.
  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
  return success();
}

static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG("inner_tiles must be constant: " << packOp << "\n");
    return failure();
  }

  return success();
}

static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  // Padding with non-zero low pad values is not supported, unless the
  // corresponding result dim is 1, since that would require shifting the
  // results to the right by the amount of low padding. For unit result dims,
  // a low pad implies a dynamically-zero input size, so the lowering creates
  // a zero mask and loads the pad value, which is the desired behaviour.
  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
        Value padValue = en.value();
        unsigned pos = en.index();
        std::optional<int64_t> pad = getConstantIntValue(padValue);
        return (!pad.has_value() || pad.value() != 0) &&
               resultTensorShape[pos] != 1;
      })) {
    LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
    return failure();
  }

  return success();
}

/// Preconditions for scalable vectors.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  // Cond 1: There's been no need for more than 2 scalable dims so far.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 2: Exactly 1 or 2 trailing dims can be scalable. Examine the
  // trailing dims, skipping the leading non-scalable dims.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  // Analyze the iterator corresponding to the trailing scalable dim.
  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // Check that the reduction dim is the trailing iteration-space dim.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG("Non-trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
      LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
           "is not supported\n");
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // Check that the trailing scalable dim is the inner parallel dim.
    if (seenNonUnitParallel) {
      LDBG("Inner parallel dim not requested for scalable "
           "vectorization\n");
      return failure();
    }
    break;
  }
  }

  // If present, check the 2nd scalable dim. At the moment, this is only
  // supported for matmul-like ops: the trailing reduction and the parallel
  // dim just before it may both be scalable.
  if (numOfScalableDims == 2) {
    // Disallow the case where the trailing dim is parallel but the one above
    // it is a scalable reduction, which would break the condition above:
    //    * iterators = [..., parallel, reduction]
    //    * scalable flags = [..., true, true]
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG("Higher dim than the trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Don't let matmuls with extended semantics (user-defined maps) through.
  if (linalgOp.hasUserDefinedMaps())
    return failure();

  // Cond 4: Only the following ops are supported in the presence of scalable
  // vectors.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::MatmulTransposeAOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
}

LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<linalg::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<linalg::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}

/// Converts affine.apply Ops to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}

bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
}

/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation. They must match
/// the rank of the iteration space of the operation, and the input vector
/// sizes must be greater than or equal to their counterpart iteration space
/// sizes, if static. `inputScalableVecDims` additionally enables the
/// vectorization of operations with scalable vector sizes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      ArrayRef<bool> inputScalableVecDims,
                                      bool vectorizeNDExtract,
                                      bool flatten1DDepthwiseConv) {
  LDBG("Attempting to vectorize:\n" << *op << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Input scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            // TODO: isaConvolutionOpInterface that can also infer from
            // generic features. Will require stride/dilation attributes
            // inference.
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG("Unsupported convolution can't be vectorized.\n");
              return failure();
            }

            LDBG("Vectorize generic by broadcasting to the canonical vector "
                 "shape\n");

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<linalg::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<linalg::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
          })
          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp,
                                            inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG("Vectorization failed\n");
    return failure();
  }

  if (!results.empty())
    rewriter.replaceOp(op, results);
  else
    rewriter.eraseOp(op);

  return success();
}

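// Example (illustrative driver sketch, not part of this file; the collection
// of candidate ops is a placeholder):
//   IRRewriter rewriter(ctx);
//   for (Operation *target : candidateOps)
//     if (linalg::hasVectorizationImpl(target))
//       (void)linalg::vectorize(rewriter, target, /*inputVectorSizes=*/{},
//                               /*inputScalableVecDims=*/{});
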
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = rewriter.create<vector::ExtractOp>(loc, readValue);
    readValue =
        rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
  }
  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}

/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};

struct PadOpVectorizationWithTransferReadPattern
    : public VectorizePadOpUserPattern<vector::TransferReadOp> {
  using VectorizePadOpUserPattern<
      vector::TransferReadOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getSourceMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
};

struct PadOpVectorizationWithTransferWritePattern
    : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
  using VectorizePadOpUserPattern<
      vector::TransferWriteOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at position of the old TransferWriteOp.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }

  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both the CastOp
    // result and the CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast
    // thereof).
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same integer attrs.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Case 2: Both values are identical AffineMinOps. (Should not happen
      // if CSE is run.)
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
      return false;
    }

    // All tests passed.
    return true;
  }
};

/// Returns the effective Pad value for the input op, provided it's a scalar.
///
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
/// this Op performs padding, retrieve the padding value provided that it's
/// a scalar and static/fixed for all the padded values. Returns an empty
/// value otherwise.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
  // being broadcast, provided that it's a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  // 2. linalg.fill - use the scalar input value.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generate - restricted to bodies that yield a constant.
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    // ...
  }

  // 4. vector.transfer_write - inspect the vector that's written.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // 5. tensor.insert_slice - inspect the destination.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}

/// Vectorize tensor::InsertSliceOp with:
///   * vector::TransferReadOp + vector::TransferWriteOp
/// The vector sizes are either:
///   * user-provided in `inputVectorSizes`, or
///   * inferred from the static dims in the input and output tensors.
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(sliceOp);

  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  Value padValue = getStaticPadVal(sliceOp);
  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = rewriter.create<arith::ConstantOp>(
        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
  }

  // 2. Get the vector shape and in-bounds attributes.
  SmallVector<int64_t> vecShape;
  SmallVector<bool> readInBounds;
  SmallVector<bool> writeInBounds;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
      readInBounds.push_back(false);
      writeInBounds.push_back(false);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
      // Source shape is statically known: read and write are in-bounds.
      readInBounds.push_back(true);
      writeInBounds.push_back(true);
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but the result shape is.
      // Vectorize with the size of the result shape; the read may be
      // out-of-bounds, the write is in-bounds as long as the result dim is
      // not smaller, which we cannot verify statically here.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
      readInBounds.push_back(false);
      writeInBounds.push_back(false);
    } else {
      // Neither the source nor the result dim is static. Cannot vectorize
      // the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // 3. Generate TransferReadOp + TransferWriteOp.
  // Create read.
  SmallVector<Value> readIndices(
      vecType.getRank(),
      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
  Operation *read = rewriter.create<vector::TransferReadOp>(
      sliceOp.getLoc(), vecType, source, readIndices, padValue,
      ArrayRef<bool>{readInBounds});

  // If vector sizes are user provided, make sure to mask.
  if (!inputVectorSizes.empty()) {
    auto *srcDefOp = source.getDefiningOp();
    if (!srcDefOp) {
      LDBG("Unable to get the defining Op of " << sliceOp);
      return failure();
    }

    ReifiedRankedShapedTypeDims reifiedSrcSizes;
    LogicalResult status =
        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
            rewriter, reifiedSrcSizes);
    if (status.failed()) {
      LDBG("Unable to reify result shapes of " << srcDefOp);
      return failure();
    }

    // Create the mask.
    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
    Value maskOp = rewriter.create<vector::CreateMaskOp>(
        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);

    // Mask the xfer_read Op.
    read = mlir::vector::maskOperation(rewriter, read, maskOp);
  }

  // Create write.
  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
  Operation *write = rewriter.create<vector::TransferWriteOp>(
      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
      ArrayRef<bool>{writeInBounds});

  // 4. Finalize.
  newResults.push_back(write->getResult(0));

  return success();
}

struct PadOpVectorizationWithInsertSlicePattern
    : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
  using VectorizePadOpUserPattern<
      tensor::InsertSliceOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result not used as destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if sizes match: insert the entire tensor into the destination.
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Insert the TransferReadOp and TransferWriteOp at the position of the
    // InsertSliceOp.
    rewriter.setInsertionPoint(insertOp);

    // Generate TransferReadOp: read the entire source tensor and add high
    // padding.
    SmallVector<Value> readIndices(
        vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

    // Generate TransferWriteOp: write to the InsertSliceOp's dest tensor at
    // the specified offsets.
    SmallVector<Value> writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
};

/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservative result also works if `firstOp` and
/// `secondOp` are in different blocks.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the
  // copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // memref.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer. When
  // forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  auto vectorType = xferOp.getVectorType();
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}

/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the target of the copy that we replace.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward vector.transfer into copy.
  // memref.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer. When
  // forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  auto vector = xferOp.getVector();
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(SmallVector<bool>(
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}

template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}

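// Example (usage within this file): binding the leading dimensions of a
// shaped type to named sizes,
//   int64_t nSize, wSize, cSize;
//   bindShapeDims(lhsShapedType, nSize, wSize, cSize);
// reads shape[0], shape[1] and shape[2] in order.
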
/// Helper StructuredGenerator class to manipulate and rewrite ops with
/// `StructuredOpInterface`.
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter,
                                                           linalgOp) {
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    setConvOperationKind(reduceOp);

    auto maybeKind = getCombinerOpKind(reduceOp);
    reductionKind = maybeKind.value();

    // The ConvolutionOpInterface gives us guarantees of existence for
    // strides/dilations. However, we do not need to rely on those: simply
    // use them if present, otherwise use the default and let the generic
    // conv. matcher in the ConvGenerator succeed or fail.
    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  }

  /// Generate a vector implementation of a 1-D convolution/pooling, with kw
  /// unrolled.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // First operand: lhs; second operand: rhs; third operand: res.
      nSize = fSize = cSize = 0;
      bindShapeDims(resShapedType, wSize);
      bindShapeDims(rhsShapedType, kwSize);
      lhsShape = {// iw = ow + kw - 1
                  //   (i.e. 16 convolved with 3 -> 14)
                  (wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      bindShapeDims(resShapedType, nSize, wSize, fSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        bindShapeDims(rhsShapedType, kwSize, cSize);
        break;
      case ConvOperationKind::Pool:
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      bindShapeDims(resShapedType, nSize, fSize, wSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
        break;
      case ConvOperationKind::Pool:
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize, cSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. When strideW == 1,
    // we can batch the contiguous loads and avoid unrolling.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    auto lhsType = VectorType::get(lhsShape, lhsEltType);
    auto rhsType = VectorType::get(rhsShape, rhsEltType);
    auto resType = VectorType::get(resShape, resEltType);
    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
    SmallVector<Value> resPadding(resShape.size(), zero);

    // Read the whole lhs, rhs and res in one shot (with zero padding).
    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
                                                        lhsPadding);
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                    rhsPadding);
    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                        resPadding);

    // The base vectorization case for channeled convolution is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base case,
    // pre-transpose the input, weight and output when necessary.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, so no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
      // fcw -> wcf
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};

      // This is needed only for Conv.
      if (oper == ConvOperationKind::Conv)
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
      break;
    }
    }

    // Unroll along kw and read slices of lhs and rhs.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                                     kwSize, strideW, dilationW, wSizeStep,
                                     isSingleChanneled);
    // This is needed only for Conv.
    if (oper == ConvOperationKind::Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                      wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} *
    // F{c, f}, an outerproduct for the non-channeled case, or a simple arith
    // operation for pooling.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case ConvOperationKind::Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case ConvOperationKind::Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // Post-transpose the result back when the base case was reached via a
    // pre-transpose.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // nwf -> nfw
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
      break;
    }
    }

    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
  }

  // Take a value and widen it to have the same element type as `ty`.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }

  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }

  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single-channel
  // convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
  }

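  // Example (illustrative): with lhs = vector<1x4x3xf32>, rhs =
  // vector<3x8xf32> and res = vector<1x4x8xf32>, the vector.contract built
  // above reduces the shared c dimension (size 3) and accumulates into res,
  // i.e. res[n, w, f] += lhs[n, w, c] * rhs[c, f].
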
  /// Generate a vector implementation of a depthwise 1-D convolution, with
  /// kw always unrolled.
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when both conditions are met:
      //  1. the channel dim is dynamic,
      //  2. channelDimScalableFlag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. When strideW == 1,
    // we can batch the contiguous loads and avoid unrolling.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow - 1) * sw + (kw - 1) * dw + 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});

    // Masks the input xfer Op along the channel dim, iff the corresponding
    // scalable flag is set.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @
    // [0, 0, 0].
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                        ValueRange{zero, zero});
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

    // Unroll along kw and read slices of lhs and rhs.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c} @
    // [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, maybeMaskedRhs->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Note - the scalable flags are ignored as flattening combined with
    // scalable vectorization is not supported.
    SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute the contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel
          // dimension).
          lhsVal = rewriter.create<vector::ShapeCastOp>(
              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
          resVal = rewriter.create<vector::ShapeCastOp>(
              loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dimension).
          resVals[w] = rewriter.create<vector::ShapeCastOp>(
              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
        }
      }
    }

    // It's possible we failed to create the FMA.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0].
    // This does not depend on kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back the final res slice of size {n, w, c} @ [0, 0, 0].
    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
        loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }

  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c}      (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c}    (flatten = true)
  /// to MulAcc.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    // TODO(suderman): Change this to use a vector.ima intrinsic.
    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: This logic won't work for scalable vectors. For this reason,
      // "flattening" is not supported when shapes are dynamic (this should
      // be captured by one of the pre-conditions).

      // There are two options for handling the filter:
      //  * shape_cast(broadcast(filter))
      //  * broadcast(shuffle(filter))
      // Opt for the option without shape_cast to simplify the codegen.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }

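  // Example (illustrative): for a flattened res = vector<1x8xf32>
  // (w * c = 8) and a filter rhs = vector<2xf32>, the shuffle above produces
  // the repeated pattern [0, 1, 0, 1, 0, 1, 0, 1], so that a single
  // vector.fma multiplies each channel lane by its filter coefficient.
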
  /// Entry point for the non-channeled convolution:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point for the channeled NWC convolution:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point for the channeled NCW convolution:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point for NWC pooling:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point for NCW pooling:
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point for the depthwise NWC convolution:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }

private:
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;

  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // Will be convolution if the feeder is a MulOp. A strength-reduced
      // version of MulOp for i1 type is AndOp, which is also supported.
      // Otherwise, it can be pooling.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // numBlockArguments == 2 and this is a pooling op.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
};

/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left - these can be vectorized using
  // masks and scalable vectors. Note that at the moment the only dim that
  // can be dynamic (i.e. vectorized with masking) is the channel dim (the
  // trailing dim).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
}

struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};

// Function index (from the generated documentation):
//
// static std::optional<ConvOperationKind> getConvOperationKind(Operation *reduceOp)
// static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, ArrayRef<int64_t> inputVectorSizes, SmallVectorImpl<Value> &newResults)
//     Vectorize a linalg::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
// static memref::SubViewOp getSubViewUseIfUnique(Value v)
//     Return the unique subview use of v if it is indeed unique, null otherwise.
// static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, ArrayRef<int64_t> inputVectorSizes, SmallVectorImpl<Value> &newResults)
//     Vectorize linalg::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
// static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
//     Checks whether val can be used for calculating a loop invariant index.
// static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl<Value> &resVals, bool isSingleChanneled)
//     Helper function to insert the computed result slices.
// static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
//     Infer the memory access pattern for the input ExtractOp.
// static LogicalResult reductionPreconditions(LinalgOp op)
// static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef<CustomVectorizationHook> customVectorizationHooks)
//     Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
// static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector<OpFoldResult> destSizes, ArrayRef<int64_t> inputVectorSizes, bool useInBoundsInsteadOfMasking)
//     Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
// static LogicalResult vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, ArrayRef<int64_t> inputVectorSizes, SmallVectorImpl<Value> &newResults)
//     Vectorize tensor::InsertSliceOp with:
// static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef<int64_t> inputVectorSizes)
// static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
//     Helper function to extract the filter slices after filter is unrolled along kw.
// static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef<int64_t> inputVecSizes={}, ArrayRef<bool> inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
//     Try to vectorize convOp as a convolution.
// static bool isCastOfBlockArgument(Operation *op)
// static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults)
//     Generic vectorization function that rewrites the body of a linalgOp into vector form.
// static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults)
//     Helper function to vectorize the terminator of a linalgOp.
// static SmallVector<Value> extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
//     Helper function to extract the input slices after filter is unrolled along kw.
// static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
//     Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
// static Value getStaticPadVal(Operation *op)
//     Returns the effective Pad value for the input op, provided it's a scalar.
// static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
//     Emit reduction operations if the shapes of the value to reduce is different that the result shape.
// static LogicalResult vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef<int64_t> inputVectorSizes)
// static void bindShapeDims(ShapedType shapedType)
// static bool hasReductionIterator(LinalgOp &op)
//     Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
// static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
//     Check whether there is any interleaved use of any values between firstOp and secondOp.
// static Operation *matchLinalgReduction(OpOperand *outputOperand)
//     Check whether outputOperand is a reduction with a single combiner operation.
// static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
//     Helper function to vectorize the tensor.extract operations.
// static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef<bool> dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp)
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static bool isSupportedPoolKind(vector::CombiningKind kind)
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(linalg::PackOp packOp, ArrayRef< int64_t > destShape)
Given a linalg::PackOp, return the dest shape before any packing permutations.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_iterator operand_end()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.