36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
45 #include <type_traits>
50 #define DEBUG_TYPE "linalg-vectorization"
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
60 bool flatten1DDepthwiseConv =
false);
64 template <
typename OpType>
67 block.
walk([&](OpType op) {
82 int64_t nSize, int64_t wSize, int64_t cSize,
83 int64_t kwSize,
int strideW,
int dilationW,
84 int64_t wSizeStep,
bool isSingleChanneled) {
86 if (isSingleChanneled) {
91 for (int64_t kw = 0; kw < kwSize; ++kw) {
92 for (int64_t w = 0; w < wSize; w += wSizeStep) {
93 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
102 for (int64_t kw = 0; kw < kwSize; ++kw) {
103 for (int64_t w = 0; w < wSize; w += wSizeStep) {
104 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
122 for (int64_t kw = 0; kw < kwSize; ++kw) {
123 result.push_back(rewriter.
create<vector::ExtractOp>(
133 int64_t nSize, int64_t wSize, int64_t fSize,
134 int64_t wSizeStep,
bool isSingleChanneled) {
136 if (isSingleChanneled) {
140 for (int64_t w = 0; w < wSize; w += wSizeStep) {
141 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
149 for (int64_t w = 0; w < wSize; w += wSizeStep) {
150 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
159 Value res, int64_t wSize, int64_t wSizeStep,
161 bool isSingleChanneled) {
163 if (isSingleChanneled) {
167 for (int64_t w = 0; w < wSize; w += wSizeStep) {
168 res = rewriter.
create<vector::InsertStridedSliceOp>(
175 for (int64_t w = 0; w < wSize; w += wSizeStep) {
176 res = rewriter.
create<vector::InsertStridedSliceOp>(
204 std::optional<AffineMap> dimPermutation = std::nullopt)
const {
207 if (dimPermutation.has_value()) {
209 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
211 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
213 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
214 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
226 std::optional<AffineMap> maybeMaskingMap = std::nullopt);
231 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
232 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
247 std::optional<AffineMap> maybeMaskingMap);
275 VectorizationState::precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
278 for (
int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
279 if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
281 iterSpaceValueSizes.push_back(rewriter.
create<arith::ConstantIndexOp>(
282 linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
289 unsigned operandDimPos;
290 if (
failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
294 Value dynamicDim = linalgOp.hasPureTensorSemantics()
296 linalgOp.getLoc(), operand, operandDimPos)
298 linalgOp.getLoc(), operand, operandDimPos);
299 iterSpaceValueSizes.push_back(dynamicDim);
315 if (!inputVectorSizes.empty()) {
319 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
320 scalableVecDims.append(inputScalableVecDims.begin(),
321 inputScalableVecDims.end());
326 canonicalVecShape = linalgOp.getStaticLoopRanges();
327 scalableVecDims.append(linalgOp.getNumLoops(),
false);
330 LDBG(
"Canonical vector shape: ");
331 LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
332 LLVM_DEBUG(llvm::dbgs() <<
"\n");
333 LDBG(
"Scalable vector dims: ");
334 LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
335 LLVM_DEBUG(llvm::dbgs() <<
"\n");
337 if (ShapedType::isDynamicShape(canonicalVecShape))
341 initIterSpaceStaticSizes(linalgOp);
346 if (
failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
356 Value VectorizationState::getOrCreateMaskFor(
358 std::optional<AffineMap> maybeMaskingMap) {
360 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
364 assert(!maskableOp.isMasked() &&
365 "Masking an operation that is already masked");
368 assert((!maybeMaskingMap || *maybeMaskingMap) &&
369 "Unexpected null mask permutation map");
371 maybeMaskingMap ? *maybeMaskingMap
373 linalgOp.getNumLoops(), rewriter.
getContext());
375 LDBG(
"Masking map: " << maskingMap <<
"\n");
379 auto activeMaskIt = activeMaskCache.find(maskingMap);
380 if (activeMaskIt != activeMaskCache.end()) {
381 Value mask = activeMaskIt->second;
382 LDBG(
"Reusing mask: " << mask <<
"\n");
393 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
394 auto maskType = getCanonicalVecType(rewriter.
getI1Type(), maskingMap);
395 auto maskShape = maskType.getShape();
397 LDBG(
"Mask shape: ");
398 LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
399 LLVM_DEBUG(llvm::dbgs() <<
"\n");
401 if (permutedStaticSizes == maskShape) {
402 LDBG(
"Masking is not needed for masking map: " << maskingMap <<
"\n");
403 activeMaskCache[maskingMap] =
Value();
410 assert(!maskShape.empty() && !upperBounds.empty() &&
411 "Masked 0-d vectors are not supported yet");
414 Value mask = rewriter.
create<vector::CreateMaskOp>(linalgOp.getLoc(),
415 maskType, upperBounds);
416 LDBG(
"Creating new mask: " << mask <<
"\n");
417 activeMaskCache[maskingMap] = mask;
428 std::optional<AffineMap> maybeMaskingMap) {
429 LDBG(
"Trying to mask: " << *opToMask <<
"\n");
433 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
436 LDBG(
"No mask required\n");
441 assert(opToMask &&
"Expected a valid operation to mask");
442 auto maskOp = cast<vector::MaskOp>(
444 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
450 LDBG(
"Masked operation: " << *maskOp <<
"\n");
473 "expected projected permutation");
475 assert(res.getNumDims() == res.getNumResults() &&
476 "expected reindexed map with same number of dims and results");
508 std::optional<vector::CombiningKind>
510 using ::mlir::vector::CombiningKind;
515 .Case<arith::AddIOp, arith::AddFOp>(
516 [&](
auto op) {
return CombiningKind::ADD; })
517 .Case<arith::AndIOp>([&](
auto op) {
return CombiningKind::AND; })
518 .Case<arith::MaxSIOp>([&](
auto op) {
return CombiningKind::MAXSI; })
519 .Case<arith::MaxUIOp>([&](
auto op) {
return CombiningKind::MAXUI; })
520 .Case<arith::MaximumFOp>([&](
auto op) {
return CombiningKind::MAXIMUMF; })
521 .Case<arith::MinSIOp>([&](
auto op) {
return CombiningKind::MINSI; })
523 .Case<arith::MinimumFOp>([&](
auto op) {
return CombiningKind::MINIMUMF; })
524 .Case<arith::MulIOp, arith::MulFOp>(
525 [&](
auto op) {
return CombiningKind::MUL; })
526 .Case<arith::OrIOp>([&](
auto op) {
return CombiningKind::OR; })
527 .Case<arith::XOrIOp>([&](
auto op) {
return CombiningKind::XOR; })
528 .Default([&](
auto op) {
return std::nullopt; });
539 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
544 if (!
matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
545 combinerOps.size() != 1)
549 return combinerOps[0];
555 auto dstVecType = dyn_cast<VectorType>(dstType);
557 if (dstVecType.getRank() == 0)
563 return b.
createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
575 assert(maybeKind &&
"Failed precondition: could not get reduction kind");
576 return b.
create<vector::MultiDimReductionOp>(
577 reduceOp->
getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
581 return llvm::to_vector(
595 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
596 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
605 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
607 auto vectorType = state.getCanonicalVecType(
611 if (vectorType.getRank() > 0) {
614 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
616 assert(value.
getType() == vectorType &&
"Incorrect type");
617 write = rewriter.
create<vector::TransferWriteOp>(
618 loc, value, outputOperand->
get(), indices, writeMap);
621 if (!isa<VectorType>(value.
getType()))
622 value = rewriter.
create<vector::BroadcastOp>(loc, vectorType, value);
623 assert(value.
getType() == vectorType &&
"Incorrect type");
624 write = rewriter.
create<vector::TransferWriteOp>(
628 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
632 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
633 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
638 LDBG(
"vectorized op: " << *write <<
"\n");
667 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
676 linalgOp.getDpsInitOperand(output.index()), state);
678 newResults.push_back(newResult);
692 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
695 auto loc = indexOp.getLoc();
697 auto targetShape = state.getCanonicalVecShape();
700 llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
701 auto indexSteps = rewriter.
create<arith::ConstantOp>(
706 if (indexOp.getDim() == targetShape.size() - 1)
712 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
713 std::swap(permPattern[indexOp.getDim()], permPattern.back());
717 auto broadCastOp = rewriter.
create<vector::BroadcastOp>(
718 loc, state.getCanonicalVecType(rewriter.
getIndexType(), permMap),
721 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
722 std::swap(transposition.back(), transposition[indexOp.getDim()]);
724 rewriter.
create<vector::TransposeOp>(loc, broadCastOp, transposition);
732 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
736 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
741 if (not extractOp.getIndices().empty()) {
742 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
746 if (llvm::any_of(extractOp->getResultTypes(), [](
Type type) {
747 return !VectorType::isValidElementType(type);
767 tensor::ExtractOp extractOp,
770 auto indexVecType = state.getCanonicalVecType(rewriter.
getIndexType());
771 auto loc = extractOp.getLoc();
774 rewriter, bvm.
lookup(extractOp.getIndices()[0]), indexVecType);
776 const size_t numIndices = extractOp.getIndices().size();
777 for (
size_t i = 1; i < numIndices; i++) {
778 Value dimIdx = rewriter.
create<arith::ConstantIndexOp>(loc, i);
782 rewriter.
create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
785 offset = rewriter.
create<arith::MulIOp>(loc, offset, dimSize);
788 rewriter, bvm.
lookup(extractOp.getIndices()[i]), indexVecType);
790 offset = rewriter.
create<arith::AddIOp>(loc, extractOpIndex, offset);
801 auto targetShape = linalgOp.getStaticLoopRanges();
802 assert(((llvm::count_if(targetShape,
803 [](int64_t dimSize) {
return dimSize > 1; }) == 1)) &&
804 "n-D vectors are not yet supported");
805 assert(targetShape.back() != 1 &&
806 "1-D vectors with the trailing dim eqaual 1 are not yet supported");
812 auto *block = linalgOp.getBlock();
813 if (isa<BlockArgument>(val))
814 return llvm::all_of(block->getArguments(),
815 [&val](
Value v) { return (v != val); });
818 assert(defOp &&
"This is neither a block argument nor an operation result");
823 auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
824 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
825 return (indexOp.getDim() != trailingLoopDim);
827 auto *ancestor = block->findAncestorOpInBlock(*defOp);
834 if (isa<arith::ConstantOp>(ancestor))
861 bool &foundIndexOp) {
863 auto targetShape = linalgOp.getStaticLoopRanges();
864 assert(((llvm::count_if(targetShape,
865 [](int64_t dimSize) {
return dimSize > 1; }) == 1)) &&
866 "n-D vectors are not yet supported");
867 assert(targetShape.back() != 1 &&
868 "1-D vectors with the trailing dim 1 are not yet supported");
874 auto *block = linalgOp.getBlock();
875 if (isa<BlockArgument>(val))
876 return llvm::all_of(block->getArguments(),
877 [&val](
Value v) { return (v != val); });
880 assert(defOp &&
"This is neither a block argument nor an operation result");
884 auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
885 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
886 foundIndexOp = (indexOp.getDim() == trailingLoopDim);
890 auto *ancestor = block->findAncestorOpInBlock(*defOp);
897 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
919 LinalgOp &linalgOp) {
921 auto targetShape = linalgOp.getStaticLoopRanges();
922 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
925 if (inputShape.getShape().empty())
931 if (linalgOp.hasDynamicShape())
939 if ((llvm::count_if(targetShape,
940 [](int64_t dimSize) {
return dimSize > 1; }) != 1) ||
941 targetShape.back() == 1)
947 if (inputShape.getShape().back() == 1)
950 bool leadingIdxsLoopInvariant =
true;
955 auto indices = extractOp.getIndices();
956 auto leadIndices = indices.drop_back(1);
959 if (inputShape.getShape()[i] == 1)
965 if (!leadingIdxsLoopInvariant) {
966 LDBG(
"Found gather load: " << extractOp);
974 auto extractOpTrailingIdx = indices.back();
978 if (leadingIdxsLoopInvariant &&
980 LDBG(
"Found scalar broadcast load: " << extractOp);
989 bool foundIndexOp =
false;
990 bool isContiguousLoad =
992 isContiguousLoad &= foundIndexOp;
994 if (isContiguousLoad) {
995 LDBG(
"Found contigous load: " << extractOp);
1000 LDBG(
"Found gather load: " << extractOp);
1011 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1014 auto loc = extractOp.getLoc();
1017 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1018 auto maskConstantOp = rewriter.
create<arith::ConstantOp>(
1022 auto passThruConstantOp =
1028 extractOp.getIndices().size(),
1029 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
1040 loc, resultType, extractOp.getTensor(), baseIndices, offset,
1041 maskConstantOp, passThruConstantOp);
1042 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1044 LDBG(
"Vectorised as gather load: " << extractOp <<
"\n");
1067 auto resTrailingDim = resultType.getShape().back();
1068 auto zero = rewriter.
create<arith::ConstantOp>(
1070 for (
size_t i = 0; i < extractOp.getIndices().size(); i++) {
1071 auto idx = bvm.
lookup(extractOp.getIndices()[i]);
1072 if (idx.getType().isIndex()) {
1073 transferReadIdxs.push_back(idx);
1077 auto indexAs1dVector = rewriter.
create<vector::ShapeCastOp>(
1079 bvm.
lookup(extractOp.getIndices()[i]));
1080 transferReadIdxs.push_back(
1081 rewriter.
create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1085 auto dstRank = resultType.getRank();
1086 auto srcRank = extractOp.getTensor().getType().getRank();
1095 auto transferReadOp = rewriter.
create<vector::TransferReadOp>(
1096 loc, resultType, extractOp.getTensor(), transferReadIdxs,
1097 permutationMap, inBounds);
1099 LDBG(
"Vectorised as scalar broadcast load: " << extractOp <<
"\n");
1107 int32_t rankDiff = dstRank - srcRank;
1115 while (rankDiff > 0) {
1116 permutationMap = permutationMap.insertResult(
1121 auto transferReadOp = rewriter.
create<vector::TransferReadOp>(
1122 loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1125 LDBG(
"Vectorised as contiguous load: " << extractOp);
1138 auto reduceType = dyn_cast<VectorType>(reduceVec.
getType());
1139 auto outputType = dyn_cast<VectorType>(outputVec.
getType());
1143 (outputType && reduceType.getShape() == outputType.getShape()))
1172 LDBG(
"vectorize op " << *op <<
"\n");
1175 if (!customVectorizationHooks.empty()) {
1176 for (
auto &customFunc : customVectorizationHooks) {
1186 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1196 auto blockArg = dyn_cast<BlockArgument>(operand);
1197 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1198 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1202 linalgOp.getRegionOutputArgs(),
1203 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1206 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1208 if (!reductionOperands.empty()) {
1209 assert(reductionOperands.size() == 1);
1211 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1212 reductionOperands[0].second, bvm);
1219 VectorType firstMaxRankedType;
1221 auto vecOperand = bvm.
lookup(operand);
1222 assert(vecOperand &&
"Vector operand couldn't be found");
1224 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1225 if (vecType && (!firstMaxRankedType ||
1226 firstMaxRankedType.getRank() < vecType.getRank()))
1227 firstMaxRankedType = vecType;
1233 assert(vecOperand &&
"Vector operand couldn't be found");
1235 if (firstMaxRankedType) {
1238 firstMaxRankedType.getScalableDims());
1241 vecOperands.push_back(vecOperand);
1247 resultTypes.push_back(
1250 firstMaxRankedType.getScalableDims())
1286 LDBG(
"Vectorizing operation as linalg generic\n");
1287 Block *block = linalgOp.getBlock();
1294 bvm.
map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1296 if (linalgOp.getNumDpsInits() == 0)
1301 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
1302 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1303 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1304 if (linalgOp.isScalar(opOperand)) {
1305 bvm.
map(bbarg, opOperand->get());
1311 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1317 if (isa<AffineConstantExpr>(result.value())) {
1318 zeroPos.push_back(result.index());
1324 VectorType readType;
1326 if (linalgOp.isDpsInput(opOperand)) {
1329 readType = state.getCanonicalVecType(elemType);
1336 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
1342 loc, readType, opOperand->get(), indices, readMap);
1343 read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1348 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1350 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1356 if (readType.getRank() == 0)
1371 hooks.push_back(vectorizeYield);
1378 hooks.push_back(vectorizeIndex);
1385 hooks.push_back(vectorizeExtract);
1392 LDBG(
"failed to vectorize: " << op <<
"\n");
1397 state.maskOperation(rewriter, result.
newOp, linalgOp);
1398 LDBG(
"New vector op: " << *maybeMaskedOp <<
"\n");
1422 auto inputType = cast<VectorType>(input.
getType());
1423 Value dest = builder.
create<tensor::EmptyOp>(loc, destSizes,
1424 inputType.getElementType());
1425 int64_t rank = cast<ShapedType>(dest.
getType()).getRank();
1426 auto zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
1433 auto destShape = cast<ShapedType>(dest.
getType()).getShape();
1434 assert(llvm::none_of(
1435 destShape.drop_front(inputVectorSizes.size()),
1436 [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1437 "Only dims aligned with inputVectorSizes may be dynamic");
1438 bool needMaskForWrite = !llvm::equal(
1439 inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1440 if (needMaskForWrite) {
1442 writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1443 writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1446 Value maskForWrite =
1447 builder.
create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1487 auto padValue = packOp.getPaddingValue();
1489 padValue = rewriter.
create<arith::ConstantOp>(
1490 loc, rewriter.
getZeroAttr(packOp.getSourceType().getElementType()));
1494 cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1495 .reifyResultShapes(rewriter, reifiedReturnShapes);
1497 assert(
succeeded(status) &&
"failed to reify result shapes");
1502 bool useInBoundsInsteadOfMasking =
true;
1503 if (inputVectorSizes.empty()) {
1505 inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1506 useInBoundsInsteadOfMasking =
false;
1511 auto innerTiles = packOp.getStaticInnerTiles();
1512 auto innerDimsPos = packOp.getInnerDimsPos();
1513 auto outerDimsPerm = packOp.getOuterDimsPerm();
1514 if (!outerDimsPerm.empty())
1517 for (
auto [idx, size] :
enumerate(innerTiles))
1518 inputShape[innerDimsPos[idx]] *= size;
1520 rewriter, loc, packOp.getSource(), inputShape, padValue,
1521 useInBoundsInsteadOfMasking);
1525 destShape.append(innerTiles.begin(), innerTiles.end());
1527 packOp.getDestType().getElementType());
1529 rewriter.
create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1532 auto destPermutation =
1534 auto transposeOp = rewriter.
create<vector::TransposeOp>(
1535 loc, shapeCastOp.getResult(), destPermutation);
1540 reifiedReturnShapes[0], inputVectorSizes);
1541 newResults.push_back(write->getResult(0));
1559 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1565 inputVectorSizes.end());
1588 for (
auto [index, size] :
enumerate(innerTiles)) {
1589 readMaskShape[innerDimPos[index]] =
1592 if (!outerDimsPerm.empty()) {
1595 readMaskShape.append(sourceShape.begin() + inputVectorSizes.size(),
1600 cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1601 .reifyResultShapes(rewriter, reifiedRetShapes);
1603 LDBG(
"Unable to reify result shapes of " << unpackOp);
1608 auto padValue = rewriter.
create<arith::ConstantOp>(
1609 loc, rewriter.
getZeroAttr(unpackOp.getSourceType().getElementType()));
1614 rewriter, loc, unpackOp.getSource(),
1617 PackingMetadata packMetadata;
1620 ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1622 mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1624 RankedTensorType stripMineTensorType =
1627 vector::TransposeOp transposeOp = rewriter.
create<vector::TransposeOp>(
1628 loc, readResult, lastDimToInsertPosPerm);
1631 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1632 stripMineTensorType, packMetadata.reassociations);
1633 mlir::VectorType vecCollapsedType =
1634 VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1635 vector::ShapeCastOp shapeCastOp = rewriter.
create<vector::ShapeCastOp>(
1636 loc, vecCollapsedType, transposeOp->getResult(0));
1641 unpackOp.getDestType().hasStaticShape()
1643 : shapeCastOp.getResultVectorType().getShape());
1646 reifiedRetShapes[0], writeMaskShape);
1647 newResults.push_back(write->
getResult(0));
1658 auto padValue = padOp.getConstantPaddingValue();
1667 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1668 .reifyResultShapes(rewriter, reifiedReturnShapes);
1670 assert(
succeeded(status) &&
"failed to reify result shapes");
1672 rewriter, loc, padOp.getSource(), inputVectorSizes, padValue);
1674 rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
1675 newResults.push_back(write->
getResult(0));
1683 LDBG(
"reduction precondition failed: no reduction iterator\n");
1686 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
1687 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1693 LDBG(
"reduction precondition failed: reduction detection failed\n");
1702 bool flatten1DDepthwiseConv) {
1703 if (flatten1DDepthwiseConv) {
1704 LDBG(
"Vectorization of flattened convs with dynamic shapes is not "
1709 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1710 LDBG(
"Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1716 Value lhs = conv.getDpsInputOperand(0)->get();
1718 auto shapeWithoutCh = lhsShape.drop_back(1);
1719 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1720 LDBG(
"Dynamically-shaped op vectorization precondition failed: only "
1721 "channel dim can be dynamic\n");
1730 bool flatten1DDepthwiseConv) {
1731 if (isa<ConvolutionOpInterface>(op.getOperation()))
1737 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1741 LDBG(
"Dynamically-shaped op meets vectorization pre-conditions\n");
1750 if (llvm::any_of(unpackOp.getInnerTiles(), [](
OpFoldResult res) {
1751 return !getConstantIntValue(res).has_value();
1753 LDBG(
"Inner-tiles must be constant: " << unpackOp <<
"\n");
1757 if (!inputVectorSizes.empty() &&
1766 bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
1768 if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1771 if (!inputVectorSizes.empty() &&
1777 linalgOp, flatten1DDepthwiseConv))) {
1778 LDBG(
"Dynamically-shaped op failed vectorization pre-conditions\n");
1791 customPreconditions,
1794 customPrecondition(&innerOp, vectorizeNDExtract));
1798 if (llvm::any_of(innerOp.getOperandTypes(), [](
Type type) {
1799 return !VectorType::isValidElementType(type);
1803 if (llvm::any_of(innerOp.getResultTypes(), [](
Type type) {
1804 return !VectorType::isValidElementType(type);
1815 if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1821 LDBG(
"precondition failed: not projected permutations\n");
1825 LDBG(
"precondition failed: reduction preconditions\n");
1834 auto padValue = packOp.getPaddingValue();
1837 LDBG(
"pad value is not constant: " << packOp <<
"\n");
1841 bool satisfyEmptyCond =
true;
1842 if (inputVectorSizes.empty()) {
1843 if (!packOp.getDestType().hasStaticShape() ||
1844 !packOp.getSourceType().hasStaticShape())
1845 satisfyEmptyCond =
false;
1848 if (!satisfyEmptyCond &&
1850 resultTensorShape.take_front(packOp.getSourceRank()),
1854 if (llvm::any_of(packOp.getInnerTiles(), [](
OpFoldResult v) {
1855 return !getConstantIntValue(v).has_value();
1857 LDBG(
"inner_tiles must be constant: " << packOp <<
"\n");
1867 auto padValue = padOp.getConstantPaddingValue();
1869 LDBG(
"pad value is not constant: " << padOp <<
"\n");
1878 if (llvm::any_of(padOp.getLow(), [](
Value v) {
1879 std::optional<int64_t> res = getConstantIntValue(v);
1880 return !res.has_value() || res.value() != 0;
1882 LDBG(
"low pad must all be zero: " << padOp <<
"\n");
1894 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1895 "Number of input vector sizes and scalable dims doesn't match");
1897 if (inputVectorSizes.empty())
1900 bool isScalable = inputScalableVecDims.back();
1906 auto linalgOp = dyn_cast<LinalgOp>(op);
1908 isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
1914 bool flatten1DDepthwiseConv) {
1916 inputScalableVecDims)))
1920 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
1923 flatten1DDepthwiseConv);
1925 .Case<tensor::PadOp>([&](
auto padOp) {
1928 .Case<tensor::PackOp>([&](
auto packOp) {
1931 .Case<tensor::UnPackOp>([&](
auto unpackOp) {
1934 .Default([](
auto) {
return failure(); });
1940 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
1942 for (
auto op : make_early_inc_range(toReplace)) {
1946 op.
getOperands().take_front(op.getAffineMap().getNumDims()),
1947 op.
getOperands().take_back(op.getAffineMap().getNumSymbols()));
1961 bool vectorizeNDExtract,
1962 bool flatten1DDepthwiseConv) {
1963 LDBG(
"Attempting to vectorize:\n" << *op <<
"\n");
1964 LDBG(
"Input vector sizes: ");
1965 LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
1966 LLVM_DEBUG(llvm::dbgs() <<
"\n");
1967 LDBG(
"Input scalable vector dims: ");
1968 LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
1969 LLVM_DEBUG(llvm::dbgs() <<
"\n");
1973 flatten1DDepthwiseConv))) {
1974 LDBG(
"Vectorization pre-conditions failed\n");
1980 if (
auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
1981 if (
failed(state.initState(rewriter, linalgOp, inputVectorSizes,
1982 inputScalableVecDims))) {
1983 LDBG(
"Vectorization state couldn't be initialized\n");
1989 auto vectorizeResult =
1991 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
1995 if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
1997 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
1998 flatten1DDepthwiseConv);
2000 llvm::append_range(results, (*convOr)->getResults());
2004 LDBG(
"Unsupported convolution can't be vectorized.\n");
2008 LDBG(
"Vectorize generic by broadcasting to the canonical vector "
2021 .Case<tensor::PadOp>([&](
auto padOp) {
2025 .Case<tensor::PackOp>([&](
auto packOp) {
2029 .Case<tensor::UnPackOp>([&](
auto unpackOp) {
2031 inputVectorSizes, results);
2033 .Default([](
auto) {
return failure(); });
2035 if (
failed(vectorizeResult)) {
2036 LDBG(
"Vectorization failed\n");
2040 if (!results.empty())
2049 memref::CopyOp copyOp) {
2050 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2051 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2052 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2057 if (!VectorType::isValidElementType(srcElementType) ||
2058 !VectorType::isValidElementType(dstElementType))
2069 loc, readType, copyOp.getSource(), indices,
2071 if (cast<VectorType>(
readValue.getType()).getRank() == 0) {
2076 loc,
readValue, copyOp.getTarget(), indices,
2088 return cast<IntegerAttr>(attr).getInt();
2097 for (
auto o : ofrs) {
2098 if (
auto val = llvm::dyn_cast_if_present<Value>(o)) {
2099 result.push_back(val);
2101 result.push_back(rewriter.
create<arith::ConstantIndexOp>(
2120 tensor::PadOp padOp,
Value dest) {
2121 auto sourceType = padOp.getSourceType();
2122 auto resultType = padOp.getResultType();
2123 if (!VectorType::isValidElementType(sourceType.getElementType()))
2129 auto padValue = padOp.getConstantPaddingValue();
2131 if (!sourceType.hasStaticShape())
2134 auto elemType = sourceType.getElementType();
2135 padValue = rewriter.
create<arith::ConstantOp>(
2136 padOp.getLoc(), elemType, rewriter.
getZeroAttr(elemType));
2142 for (
unsigned i = 0; i < sourceType.getRank(); ++i) {
2143 if (!sourceType.isDynamicDim(i)) {
2144 vecShape.push_back(sourceType.getDimSize(i));
2147 readInBounds.push_back(
true);
2148 writeInBounds.push_back(
true);
2149 }
else if (!resultType.isDynamicDim(i)) {
2153 vecShape.push_back(resultType.getDimSize(i));
2156 readInBounds.push_back(
false);
2158 writeInBounds.push_back(
2160 static_cast<int64_t
>(0));
2167 auto vecType =
VectorType::get(vecShape, sourceType.getElementType());
2172 rewriter.
create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2173 auto read = rewriter.
create<vector::TransferReadOp>(
2174 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2179 if (llvm::equal(vecShape, resultType.getShape()) &&
2180 llvm::all_of(writeInBounds, [](
bool b) {
return b; }))
2182 dest = fill.output();
2196 template <
typename OpTy>
2202 bool changed =
false;
2204 for (
auto *user : llvm::to_vector<4>(padOp->getUsers()))
2205 if (
auto op = dyn_cast<OpTy>(user))
2206 changed |= rewriteUser(rewriter, padOp, op).
succeeded();
2212 tensor::PadOp padOp, OpTy op)
const = 0;
2240 vector::TransferReadOp xferOp)
const override {
2242 if (!padOp.hasZeroLowPad())
2245 auto padValue = padOp.getConstantPaddingValue();
2249 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2254 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2256 xferOp.getSourceMutable().assign(padOp.getSource());
2257 xferOp.getPaddingMutable().assign(padValue);
2302 vector::TransferWriteOp xferOp)
const override {
2304 if (xferOp.getTransferRank() == 0)
2308 if (!padOp.hasZeroLowPad())
2311 auto padValue = padOp.getConstantPaddingValue();
2315 if (!xferOp->hasOneUse())
2317 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2321 if (!trimPadding.hasZeroOffset())
2324 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2332 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2333 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2335 rewriter.
replaceOp(trimPadding, newXferOp->getResult(0));
2351 tensor::ExtractSliceOp afterTrimming)
const {
2354 if (
auto castOp = beforePadding.
getDefiningOp<tensor::CastOp>())
2355 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2358 auto t1 = dyn_cast<RankedTensorType>(beforePadding.
getType());
2359 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2364 if (t1.getRank() != t2.getRank())
2369 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2370 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2372 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2377 if (t1.getNumDynamicDims() == 0)
2385 auto beforeSlice = beforePadding.
getDefiningOp<tensor::ExtractSliceOp>();
2389 assert(
static_cast<size_t>(t1.getRank()) ==
2390 beforeSlice.getMixedSizes().size());
2391 assert(
static_cast<size_t>(t2.getRank()) ==
2392 afterTrimming.getMixedSizes().size());
2394 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2396 if (!t1.isDynamicDim(i))
2398 auto size1 = beforeSlice.getMixedSizes()[i];
2399 auto size2 = afterTrimming.getMixedSizes()[i];
2406 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2407 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2413 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2414 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2415 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2416 minOp1.getOperands() == minOp2.getOperands())
2456 tensor::InsertSliceOp insertOp)
const override {
2458 if (!padOp.hasZeroLowPad())
2461 if (!insertOp.hasUnitStride())
2464 auto padValue = padOp.getConstantPaddingValue();
2468 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2471 if (insertOp.getDest() == padOp.getResult())
2475 padOp.getType().getElementType());
2476 unsigned vecRank = vecType.getRank();
2477 unsigned tensorRank = insertOp.getType().getRank();
2482 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2484 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](
auto it) {
2485 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2496 vecRank, rewriter.
create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2497 auto read = rewriter.
create<vector::TransferReadOp>(
2498 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2507 insertOp, read, insertOp.getDest(), writeIndices,
2536 LDBG(
"interleavedUses precondition failed, firstOp: "
2537 << *firstOp <<
", second op: " << *secondOp <<
"\n");
2540 for (
auto v : values) {
2541 for (
auto &u : v.getUses()) {
2543 if (owner == firstOp || owner == secondOp)
2549 LDBG(
" found interleaved op " << *owner <<
", firstOp: " << *firstOp
2550 <<
", second op: " << *secondOp <<
"\n");
2560 memref::SubViewOp subViewOp;
2562 if (
auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2564 return memref::SubViewOp();
2565 subViewOp = newSubViewOp;
2577 if (xferOp.getMask())
2581 Value viewOrAlloc = xferOp.getSource();
2590 Value subView = subViewOp.getResult();
2593 memref::CopyOp copyOp;
2594 for (
auto &u : subView.
getUses()) {
2595 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2596 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2597 if (newCopyOp.getTarget() != subView)
2611 for (
auto &u : viewOrAlloc.
getUses()) {
2612 if (
auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2613 assert(isa<MemRefType>(newFillOp.output().getType()));
2614 if (newFillOp.output() != viewOrAlloc)
2618 maybeFillOp = newFillOp;
2623 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2625 "padding value does not match fill");
2628 Value in = copyOp.getSource();
2634 Value res = rewriter.
create<vector::TransferReadOp>(
2635 xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
2636 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2641 rewriter.
eraseOp(maybeFillOp);
2653 if (xferOp.getMask())
2657 Value viewOrAlloc = xferOp.getSource();
2666 Value subView = subViewOp.getResult();
2669 memref::CopyOp copyOp;
2670 for (
auto &u : subViewOp.getResult().getUses()) {
2671 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2672 if (newCopyOp.getSource() != subView)
2684 assert(isa<MemRefType>(copyOp.getTarget().getType()));
2685 Value out = copyOp.getTarget();
2692 rewriter.
create<vector::TransferWriteOp>(
2693 xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
2694 xferOp.getPermutationMapAttr(), xferOp.getMask(),
2711 template <
int N,
typename IntTy,
typename... IntTy2>
2713 val = shapedType.getShape()[N];
2718 template <
typename... IntTy>
2720 bindShapeDims<0>(shapedType, vals...);
2724 bool isCastOfBlockArgument(
Operation *op) {
2729 bool isSupportedPoolKind(vector::CombiningKind kind) {
2731 case vector::CombiningKind::ADD:
2732 case vector::CombiningKind::MAXNUMF:
2733 case vector::CombiningKind::MAXIMUMF:
2734 case vector::CombiningKind::MAXSI:
2735 case vector::CombiningKind::MAXUI:
2736 case vector::CombiningKind::MINNUMF:
2737 case vector::CombiningKind::MINIMUMF:
2738 case vector::CombiningKind::MINSI:
2780 struct Conv1DGenerator
2782 Conv1DGenerator(
RewriterBase &rewriter, LinalgOp linalgOp,
int strideW,
2785 strideW(strideW), dilationW(dilationW) {
2787 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2789 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2790 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2791 resShaped = linalgOp.getDpsInitOperand(0)->get();
2792 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2793 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2794 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2795 if (!lhsShapedType || !rhsShapedType || !resShapedType)
2799 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2800 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2808 if (!setOperKind(reduceOp))
2811 if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2812 (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2816 auto rhsRank = rhsShapedType.getRank();
2819 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2853 int64_t nSize, wSize, cSize, kwSize, fSize;
2856 switch (conv1DOpOrder) {
2859 nSize = fSize = cSize = 0;
2866 (wSize + kwSize - 1)};
2867 rhsShape = {kwSize};
2888 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2893 rhsShape = {kwSize, cSize, fSize};
2896 rhsShape = {kwSize};
2899 resShape = {nSize, wSize, fSize};
2915 lhsShape = {nSize, cSize,
2919 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
2923 rhsShape = {fSize, cSize, kwSize};
2926 rhsShape = {kwSize};
2929 resShape = {nSize, fSize, wSize};
2933 vector::TransferWriteOp write;
2934 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
2939 int64_t wSizeStep = strideW == 1 ? wSize : 1;
2941 Type lhsEltType = lhsShapedType.getElementType();
2942 Type rhsEltType = rhsShapedType.getElementType();
2943 Type resEltType = resShapedType.getElementType();
2953 Value lhs = rewriter.
create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
2956 Value rhs =
nullptr;
2958 rhs = rewriter.
create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
2960 Value res = rewriter.
create<vector::TransferReadOp>(loc, resType, resShaped,
2966 switch (conv1DOpOrder) {
2974 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
2975 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, permLhs);
2977 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
2981 rhs = rewriter.
create<vector::TransposeOp>(loc, rhs, permRhs);
2983 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
2984 res = rewriter.
create<vector::TransposeOp>(loc, res, permRes);
2995 kwSize, strideW, dilationW, wSizeStep,
3001 wSizeStep, isSingleChanneled);
3003 auto linearIndex = [&](int64_t kw, int64_t w) {
3004 return kw * (wSize / wSizeStep) + w;
3010 for (int64_t kw = 0; kw < kwSize; ++kw) {
3011 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3014 if (isSingleChanneled) {
3015 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3016 lhsVals[linearIndex(kw, w)],
3017 rhsVals[kw], resVals[w]);
3019 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3020 lhsVals[linearIndex(kw, w)],
3021 rhsVals[kw], resVals[w]);
3025 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3041 switch (conv1DOpOrder) {
3048 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3049 res = rewriter.
create<vector::TransposeOp>(loc, res, perm);
3055 .
create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3063 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3064 if (srcElementType == dstElementType)
3069 const Type dstType =
3070 cast<ShapedType>(val.
getType()).cloneWith(std::nullopt, dstElementType);
3072 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3073 return rewriter.
create<arith::SIToFPOp>(loc, dstType, val);
3076 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3077 srcWidth < dstWidth)
3078 return rewriter.
create<arith::ExtFOp>(loc, dstType, val);
3080 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3081 srcWidth < dstWidth)
3082 return rewriter.
create<arith::ExtSIOp>(loc, dstType, val);
3084 assert(
false &&
"unhandled promotion case");
3091 vector::IteratorType par = vector::IteratorType::parallel;
3092 vector::IteratorType red = vector::IteratorType::reduction;
3097 return rewriter.
create<vector::ContractionOp>(
3099 MapList{{n, w, c}, {c, f}, {n, w, f}},
3107 return rewriter.
create<vector::OuterProductOp>(
3108 loc, res.
getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3131 bool channelDimScalableFlag,
3136 bool scalableChDim =
false;
3137 bool useMasking =
false;
3138 int64_t nSize, wSize, cSize, kwSize;
3141 if (ShapedType::isDynamic(cSize)) {
3142 assert(channelDimVecSize != 0 &&
"Channel dim vec size must be > 0");
3143 cSize = channelDimVecSize;
3147 scalableChDim = channelDimScalableFlag;
3151 assert(!(useMasking && flatten) &&
3152 "Unsupported flattened conv with dynamic shapes");
3157 vector::TransferWriteOp write;
3158 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
3163 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3165 Type lhsEltType = lhsShapedType.getElementType();
3166 Type rhsEltType = rhsShapedType.getElementType();
3167 Type resEltType = resShapedType.getElementType();
3172 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3174 lhsEltType, {
false,
false, scalableChDim});
3175 VectorType rhsType =
3177 {
false, scalableChDim});
3178 VectorType resType =
3180 {
false,
false, scalableChDim});
3193 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3196 rewriter.
create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3203 Value lhs = rewriter.
create<vector::TransferReadOp>(
3204 loc, lhsType, lhsShaped,
ValueRange{zero, zero, zero});
3205 auto maybeMaskedLhs = maybeMaskXferOp(
3206 lhsType.getShape(), lhsType.getScalableDims(), lhs.
getDefiningOp());
3209 Value rhs = rewriter.
create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3211 auto maybeMaskedRhs = maybeMaskXferOp(
3212 rhsType.getShape(), rhsType.getScalableDims(), rhs.
getDefiningOp());
3215 Value res = rewriter.
create<vector::TransferReadOp>(
3216 loc, resType, resShaped,
ValueRange{zero, zero, zero});
3217 auto maybeMaskedRes = maybeMaskXferOp(
3218 resType.getShape(), resType.getScalableDims(), res.
getDefiningOp());
3230 for (int64_t kw = 0; kw < kwSize; ++kw) {
3231 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3232 lhsVals.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
3233 loc, maybeMaskedLhs->getResult(0),
3235 inOutSliceSizes, inOutStrides));
3239 for (int64_t kw = 0; kw < kwSize; ++kw) {
3240 rhsVals.push_back(rewriter.
create<vector::ExtractOp>(
3241 loc, maybeMaskedRhs->getResult(0),
3245 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3246 resVals.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
3247 loc, maybeMaskedRes->getResult(0),
3252 auto linearIndex = [&](int64_t kw, int64_t w) {
3253 return kw * (wSize / wSizeStep) + w;
3258 auto inOutFlattenSliceSizes =
3260 auto lhsTypeAfterFlattening =
3262 auto resTypeAfterFlattening =
3266 for (int64_t kw = 0; kw < kwSize; ++kw) {
3267 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3268 Value lhsVal = lhsVals[linearIndex(kw, w)];
3269 Value resVal = resVals[w];
3273 lhsVal = rewriter.
create<vector::ShapeCastOp>(
3274 loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3275 resVal = rewriter.
create<vector::ShapeCastOp>(
3276 loc, resTypeAfterFlattening, resVals[w]);
3278 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3279 rhsVals[kw], resVal, flatten);
3282 resVals[w] = rewriter.
create<vector::ShapeCastOp>(
3289 if (!llvm::all_of(resVals, [](
Value v) {
return v; })) {
3291 for (
auto &collection :
3292 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3293 for (
Value v : collection)
3300 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3301 maybeMaskedRes = rewriter.
create<vector::InsertStridedSliceOp>(
3302 loc, resVals[w], maybeMaskedRes->getResult(0),
3312 loc, maybeMaskedRes->getResult(0), resShaped,
3314 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3325 auto rhsTy = cast<ShapedType>(rhs.
getType());
3326 auto resTy = cast<ShapedType>(res.
getType());
3329 lhs =
promote(rewriter, loc, lhs, resTy);
3340 auto rhsSize = cast<VectorType>(rhs.
getType()).getShape()[0];
3341 auto resSize = cast<VectorType>(res.
getType()).getShape()[1];
3344 for (
int i = 0; i < resSize / rhsSize; ++i) {
3345 for (
int j = 0;
j < rhsSize; ++
j)
3346 indicies.push_back(
j);
3349 rhs = rewriter.
create<vector::ShuffleOp>(loc, rhs, rhs, indicies);
3352 rhs = rewriter.
create<vector::BroadcastOp>(
3353 loc, resTy.clone(rhsTy.getElementType()), rhs);
3355 rhs =
promote(rewriter, loc, rhs, resTy);
3360 if (isa<FloatType>(resTy.getElementType()))
3361 return rewriter.
create<vector::FMAOp>(loc, lhs, rhs, res);
3363 auto mul = rewriter.
create<arith::MulIOp>(loc, lhs, rhs);
3364 return rewriter.
create<arith::AddIOp>(loc, mul, res);
3372 if (!iters({Par(), Red()}))
3374 "failed to match conv::W 1-par 1-red");
3377 if (layout({ {w + kw},
3390 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3392 op,
"failed to match conv::Nwc 3-par 2-red");
3395 if (layout({ {n, strideW * w + dilationW * kw, c},
3408 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3410 op,
"failed to match conv::Ncw 3-par 2-red");
3412 if (layout({ {n, c, strideW * w + dilationW * kw},
3425 if (!iters({Par(), Par(), Par(), Red()}))
3427 "failed to match pooling 3-par 1-red");
3430 if (layout({ {n, strideW * w + dilationW * kw, c},
3443 if (!iters({Par(), Par(), Par(), Red()}))
3445 "failed to match pooling 3-par 1-red");
3447 if (layout({ {n, c, strideW * w + dilationW * kw},
3458 bool vecChDimScalableFlag =
false,
3459 bool flatten =
false) {
3462 if (!iters({Par(), Par(), Par(), Red()}))
3464 op,
"failed to match depthwise::Nwc conv 3-par 1-red");
3467 if (layout({ {n, strideW * w + dilationW * kw, c},
3470 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3476 enum OperKind { Conv, Pool };
3478 OperKind oper = Conv;
3480 StringAttr poolExtOp;
3481 bool isPoolExt =
false;
3482 int strideW, dilationW;
3483 Value lhsShaped, rhsShaped, resShaped;
3484 ShapedType lhsShapedType, rhsShapedType, resShapedType;
3495 int numBlockArguments =
3496 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
3497 switch (numBlockArguments) {
3501 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
3502 llvm::IsaPred<BlockArgument>);
3503 Operation *feedOp = (*feedValIt).getDefiningOp();
3504 if (isCastOfBlockArgument(feedOp)) {
3508 }
else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3510 if (isa<BlockArgument>(v))
3512 if (Operation *op = v.getDefiningOp())
3513 return isCastOfBlockArgument(op);
3536 ArrayRef<bool> inputScalableVecDims,
bool flatten1DDepthwiseConv) {
3543 auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3544 auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3545 Conv1DGenerator e(rewriter, op, stride, dilation);
3546 auto res = e.generateNonChanneledConv();
3549 res = e.generateNwcConv();
3552 res = e.generateNcwConv();
3555 res = e.generateNwcPooling();
3558 res = e.generateNcwPooling();
3565 uint64_t vecChDimSize = ShapedType::kDynamic;
3566 bool vecChDimScalableFlag =
false;
3567 if (!inputVecSizes.empty()) {
3570 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3571 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3572 "Not a 1D depthwise conv!");
3575 .Case<linalg::DepthwiseConv1DNwcWcOp>([](
auto conv) {
return 2; })
3576 .Case<linalg::DepthwiseConv1DNcwCwOp>([](
auto conv) {
return 1; });
3578 vecChDimSize = inputVecSizes[chDimIdx];
3579 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3581 return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3582 flatten1DDepthwiseConv);
3591 if (
failed(resultOrFail))
3595 rewriter.
eraseOp(op.getOperation());
3598 assert(newOp->
getNumResults() == 1 &&
"expected single result");
static VectorShape vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val)
Checks whether /p val can be used for calculating a loop invariant index.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< Value > ofrToIndexValues(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > ofrs)
Given an ArrayRef of OpFoldResults, return a vector of Values.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static int64_t getIntFromAttr(Attribute attr)
Helper function that retrieves the value of an IntegerAttr.
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
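A quick worked example (ctx stands in for an MLIRContext* assumed to be in scope):

  // With 3 dims and 2 results, only the two most-minor dims survive:
  // (d0, d1, d2) -> (d1, d2).
  AffineMap minorId =
      AffineMap::getMinorIdentityMap(/*dims=*/3, /*results=*/2, ctx);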
AffineMap dropResults(ArrayRef< int64_t > positions) const
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
ArrayRef< AffineExpr > getResults() const
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter: dimensions for which the filter returns true are kept in the results; the rest are dropped.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
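For example, a cyclic shift (again with an assumed ctx):

  // The permutation vector {1, 2, 0} builds (d0, d1, d2) -> (d1, d2, d0).
  AffineMap perm =
      AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 2, 0}, ctx);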
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
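A small sanity check of the composition order:

  // A swap composed with itself restores the identity:
  // with f = (d0, d1) -> (d1, d0), f.compose(f) is (d0, d1) -> (d0, d1).
  AffineMap f = AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 0}, ctx);
  AffineMap identity = f.compose(f);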
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callback.
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
DenseIntElementsAttr getIndexVectorAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation via the provided map (leaving them alone if no entry is present).
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator over the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most benefit).
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provides a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification.
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors, and for systems operating on such operations to opt in to those generalizations.
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to the provided dimension and symbol values.
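A hedged sketch, assuming builder, loc, and index-typed dimVal/symVal are already in scope:

  // Lower d0 * 2 + s0 into the corresponding arith.muli/arith.addi chain.
  AffineExpr d0, s0;
  bindDims(builder.getContext(), d0);
  bindSymbols(builder.getContext(), s0);
  SmallVector<Value> dims = {dimVal}, syms = {symVal};
  Value expanded = expandAffineExpr(builder, loc, d0 * 2 + s0, dims, syms);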
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next integer.
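For instance (a sketch of the intended arithmetic):

  // Ceiling division on the known-min value: ceil(10 / 4) == 3; the
  // scalable flag of the input is carried through.
  llvm::TypeSize n = llvm::TypeSize::getFixed(10);
  llvm::TypeSize r = divideCeil(n, 4); // fixed size 3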
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
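A hedged usage sketch (rewriter and linalgOp assumed in scope; the sizes and flags are purely illustrative):

  // Vectorize with an 8x16 canonical iteration-space shape, no scalable
  // dims, and n-D tensor.extract vectorization enabled.
  if (failed(linalg::vectorize(rewriter, linalgOp,
                               /*inputVectorSizes=*/{8, 16},
                               /*inputScalableVecDims=*/{false, false},
                               /*vectorizeNDExtract=*/true)))
    return failure();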
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape.
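Illustrative check, using the usual dynamic-size sentinel:

  // Vector sizes {8, 32} can mask a ?x32 shape; a rank mismatch or a
  // too-small static size would yield failure.
  LogicalResult ok =
      isValidMaskedInputVector({ShapedType::kDynamic, 32}, {8, 32});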
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
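A minimal sketch of the intended use (maskableOp and mask assumed in scope):

  // Wraps `maskableOp` in a vector.mask region and returns the mask op;
  // with a null mask the original op is expected back unchanged.
  Operation *maybeMasked = maskOperation(builder, maskableOp, mask);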
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking=true)
Create a TransferReadOp from source with static shape readShape.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write ops (for source and destination, respectively).
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
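For example, with an assumed projected permutation map built elsewhere:

  // With map = (d0, d1, d2) -> (d2, d0), the source {5, 6, 7} is
  // projected and permuted to {7, 5}.
  SmallVector<int64_t> res =
      applyPermutationMap<int64_t>(map, ArrayRef<int64_t>{5, 6, 7});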
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into zero constants; e.g., the inverse of (d0, d1, d2) -> (d1, d2) is (d0, d1) -> (0, d0, d1).
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
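A concrete instance (ctx assumed):

  // Inverting (d0, d1, d2) -> (d2, d0, d1) yields (d0, d1, d2) -> (d1, d2, d0),
  // so composing the two gives the identity map.
  AffineMap m = AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx);
  AffineMap inv = inversePermutation(m);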
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
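For example:

  // result[i] = input[permutation[i]]: {10, 20, 30} permuted by {2, 0, 1}
  // becomes {30, 10, 20}.
  SmallVector<int64_t> vals = {10, 20, 30};
  applyPermutationToVector(vals, {2, 0, 1});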
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
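For example:

  // inverse[permutation[i]] == i, so the inverse of {2, 0, 1} is {1, 2, 0}.
  SmallVector<int64_t> inv = invertPermutationVector({2, 0, 1});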
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter, tensor::PadOp padOp, Value dest)
Vectorize the copying of a tensor::PadOp's source.
GenericPadOpVectorizationPattern(MLIRContext *context, PatternBenefit benefit=1)
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given operation.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the corresponding fixed/scalable dimension flags.
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vectorization.
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
This class represents an efficient way to signal success or failure.
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
bool failed() const
Returns true if the provided LogicalResult corresponds to a failure value.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.