#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include <type_traits>

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
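/// Return the unique instance of OpType in `block` if it is indeed unique.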
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}
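/// Helper function to extract the input slices after filter is unrolled along
/// kw.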
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slices for the non-channeled convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(/*...*/));
      }
    }
  } else {
    // Extract input slices for the channeled convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(/*...*/));
      }
    }
  }
  return result;
}
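/// Helper function to extract the filter slices after filter is unrolled along
/// kw.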
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(/*...*/));
  }
  return result;
}
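/// Helper function to extract the result slices after filter is unrolled along
/// kw.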
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(
          rewriter.create<vector::ExtractStridedSliceOp>(/*...*/));
    }
  } else {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(
          rewriter.create<vector::ExtractStridedSliceOp>(/*...*/));
    }
  }
  return result;
}
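/// Helper function to insert the computed result slices.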
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(/*...*/);
    }
  } else {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(/*...*/);
    }
  }
  return res;
}
  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimensions bit.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if needed.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeMaskingMap = std::nullopt);

  /// Initializes the iteration space static sizes.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);
LogicalResult VectorizationState::precomputeIterSpaceValueSizes(
    RewriterBase &rewriter, LinalgOp linalgOp) {
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Dynamic dimension: retrieve its runtime size from a matching operand.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  // Initialize the canonical vector shape.
  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    // This path should be taken to vectorize code with dynamic shapes and
    // when using vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there
    // are dynamic shapes, the operation won't be vectorized.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate the value sizes of the iteration space, needed to compute masks
  // during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {
  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute the permuted projection of the iteration space to be masked and
  // the corresponding mask shape.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeMaskingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}
461 "expected projected permutation");
463 assert(res.getNumDims() == res.getNumResults() &&
464 "expected reindexed map with same number of dims and results");
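/// Return vector::CombiningKind for the given op.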
std::optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
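/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, or nullptr
/// otherwise.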
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
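/// Broadcast `value` to a vector of `dstType` shape if possible. Return
/// `value` otherwise.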
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  // ... (skip the broadcast if `value` is already broadcastable to `dstType`)
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
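/// Create MultiDimReductionOp to compute the reduction for `reductionOp`.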
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
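/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`; where `outputOperand` is an output operand of the linalgOp
/// currently being vectorized.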
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store, keeping only those dims of
  // the canonical vector shape that appear in the operand's indexing map.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      linalgOp.getContext(), opOperandMap.getNumDims(),
      [&](AffineDimExpr dimExpr) {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true. Masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    Value newResult = buildVectorWrite(
        rewriter, bvm.lookup(output.value()),
        linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
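/// Helper function to vectorize the index operations of a `linalgOp`.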
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  auto targetShape = state.getCanonicalVecShape();
  SmallVector<int64_t> constantSeq =
      llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
  auto indexSteps = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIndexVectorAttr(constantSeq));
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space.
  if (indexOp.getDim() == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[indexOp.getDim()], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[indexOp.getDim()]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
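/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.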
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  // Only 1-D extract indices are supported unless `vectorizeNDExtract` is set.
  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  if (!extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      })) {
    return failure();
  }

  return success();
}
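/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`.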
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}
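/// Checks whether `val` can be used for calculating a loop invariant index.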
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
  auto targetShape = linalgOp.getStaticLoopRanges();
  assert(((llvm::count_if(targetShape,
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");
  assert(targetShape.back() != 1 &&
         "1-D vectors with the trailing dim equal 1 are not yet supported");

  // Values defined outside `linalgOp` are loop invariant. Block arguments of
  // `linalgOp` itself, however, are loop variant, so `val` is invariant only
  // if it is not one of this block's arguments.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // An IndexOp result is loop invariant as long as it does not index the
  // trailing loop dim (the only non-unit dim, see the asserts above).
  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
    return (indexOp.getDim() != trailingLoopDim);

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Constants defined inside `linalgOp` are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op);

  return result;
}
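/// Check whether `val` could be used for calculating the trailing index for a
/// contiguous load operation.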
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp) {
  auto targetShape = linalgOp.getStaticLoopRanges();
  assert(((llvm::count_if(targetShape,
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");
  assert(targetShape.back() != 1 &&
         "1-D vectors with the trailing dim 1 are not yet supported");

  // As in isLoopInvariantIdx, block arguments of `linalgOp` itself disqualify
  // `val` from being a contiguous index.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // Given the assumption on the loop ranges above, only the trailing loop
  // index is not constant.
  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
          ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);

  return result;
}
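/// Check whether `extractOp` would be a gather or a contiguous load Op after
/// vectorising `linalgOp`.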
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp) {
  auto targetShape = linalgOp.getStaticLoopRanges();
  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // 1. Assume gather load when reading _into_ a "dynamic" shape.
  if (linalgOp.hasDynamicShape())
    return VectorMemoryAccessKind::Gather;

  // 2. Assume gather load when reading _into_ a vector that is not 1-D, or
  // whose trailing dim is 1.
  if ((llvm::count_if(targetShape,
                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
      targetShape.back() == 1)
    return VectorMemoryAccessKind::Gather;

  // 3. Assume gather load when the trailing dim of the source tensor is 1.
  if (inputShape.getShape().back() == 1)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 4a. Analyze the leading indices of `extractOp`.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;
    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 4b. Analyze the trailing index for contiguous loads.
  auto extractOpTrailingIdx = indices.back();

  // 4b.1. A scalar broadcast load if the trailing index is loop invariant.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
    LDBG("Found scalar broadcast load: " << extractOp);
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 4b.2. A contiguous load if the trailing index increments by 1 with the
  // trailing loop dimension.
  bool foundIndexOp = false;
  bool isContiguousLoad =
      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
  isContiguousLoad &= foundIndexOp;

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  // 5. Fallback: gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the canonical result, mask and pass-through vectors.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }

  // 2. Handle contiguous access: compute the transfer-read indices.
  SmallVector<Value> transferReadIdxs;
  auto resTrailingDim = resultType.getShape().back();
  auto zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    auto idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
        bvm.lookup(extractOp.getIndices()[i]));
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        permutationMap, inBounds);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
  // dims so that the ranks match. This is done by extending the map with 0s.
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}
/// Emit reduction operations if the shapes of the value to reduce is
/// different from the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
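/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`.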
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");
    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each operand if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");
    if (firstMaxRankedType) {
      auto vecType =
          VectorType::get(firstMaxRankedType.getShape(),
                          getElementTypeOrSelf(vecOperand.getType()),
                          firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. For elementwise ops, the result is a vector with the first max
  //   ranked shape.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  // ... (create the vectorized op from `vecOperands` and `resultTypes`)
}
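/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form.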
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  // 1. Values defined above the region can only be broadcast for now: map
  // them to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // 2. Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // 2.a. Convert the indexing map for this input/output to a transfer read
    // or write map, remembering the broadcast (zero) result positions.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
    SmallVector<int64_t> zeroPos;
    // ... (collect positions of zero results in `indexingMap` via
    //      zeroPos.push_back(result.index());)
    AffineMap maskingMap = indexingMap.dropResults(zeroPos);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // 2.a.i. For input reads we use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // 2.a.ii. For output reads (iteration-carried dependence, e.g.,
      // reductions), the vector shape is computed from the permuted map.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
    Value readValue = read->getResult(0);

    // 2.b. If masked, set in_bounds to true. Masking guarantees that the
    // access will be in bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // 2.c. Not all ops support 0-d vectors, extract the scalar for now.
    if (readType.getRank() == 0)
      readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);

    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  // 3. Register CustomVectorizationHooks for the terminator, index ops and
  // tensor.extract ops.
  SmallVector<CustomVectorizationHook> hooks;
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp,
                                newResults);
  };
  hooks.push_back(vectorizeYield);
  // ... (vectorizeIndex and vectorizeExtract hooks are defined similarly)
  hooks.push_back(vectorizeIndex);
  hooks.push_back(vectorizeExtract);

  // 4. Iteratively call `vectorizeOneOp` on the region operations.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();
  int64_t rank = inputVectorSizes.size();
  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());

  // Create a masked transfer read of the pad source, padded with the constant
  // pad value, and write it into an empty tensor of the result shape.
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");
  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
                                                  padValue.getType());
  SmallVector<OpFoldResult> mixedSourceDims =
      tensor::getMixedSizes(rewriter, loc, padOp.getSource());
  Value mask =
      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc,
      /*vectorType=*/vectorType,
      /*source=*/padOp.getSource(),
      /*indices=*/SmallVector<Value>(rank, zero),
      /*padding=*/padValue,
      /*inBounds=*/SmallVector<bool>(rank, true));
  auto maskedOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
  Operation *write = rewriter.create<vector::TransferWriteOp>(
      loc,
      /*vector=*/maskedOp->getResult(0),
      /*source=*/emptyOp,
      /*indices=*/SmallVector<Value>(rank, zero),
      /*inBounds=*/SmallVector<bool>(rank, true));
  bool needMaskForWrite = llvm::any_of(
      llvm::zip_equal(inputVectorSizes, padOp.getResultType().getShape()),
      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
  if (needMaskForWrite) {
    Value maskForWrite = rewriter.create<vector::CreateMaskOp>(
        loc, maskType, reifiedReturnShapes[0]);
    write = mlir::vector::maskOperation(rewriter, write, maskForWrite);
  }
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}

static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
  // TODO: Masking only supports dynamic generic ops for now.
  if (!isa<linalg::GenericOp, linalg::FillOp, linalg::CopyOp,
           linalg::ContractionOpInterface>(op.getOperation()))
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}
static LogicalResult
isValidMaskedInputVector(ArrayRef<int64_t> shape,
                         ArrayRef<int64_t> inputVectorSizes) {
  LDBG("Iteration space static sizes:");
  LLVM_DEBUG(llvm::interleaveComma(shape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (inputVectorSizes.size() != shape.size()) {
    LDBG("Input vector sizes don't match the number of loops");
    return failure();
  }
  if (ShapedType::isDynamicShape(inputVectorSizes)) {
    LDBG("Input vector sizes can't have dynamic dimensions");
    return failure();
  }
  if (!llvm::all_of(llvm::zip(shape, inputVectorSizes),
                    [](std::tuple<int64_t, int64_t> sizePair) {
                      int64_t staticSize = std::get<0>(sizePair);
                      int64_t inputSize = std::get<1>(sizePair);
                      return ShapedType::isDynamic(staticSize) ||
                             staticSize <= inputSize;
                    })) {
    LDBG("Input vector sizes must be greater than or equal to iteration space "
         "static sizes");
    return failure();
  }
  return success();
}
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              bool vectorizeNDExtract) {
  // A tensor with a dimension of 0 cannot be vectorized.
  if (llvm::any_of(linalgOp.getStaticShape(),
                   [](int64_t dim) { return dim == 0; }))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                      inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() &&
      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;
  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(customPreconditions,
                     [&](CustomVectorizationPrecondition customPrecondition) {
                       return succeeded(
                           customPrecondition(&innerOp, vectorizeNDExtract));
                     })) {
      continue;
    }
    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // TODO: isaConvolutionOpInterface that can also infer from generic features.
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return success();

  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }
  // ... (validate `inputVectorSizes` against the pad result shape)
  if (llvm::any_of(padOp.getLow(), [](Value v) {
        std::optional<int64_t> res = getConstantIntValue(v);
        return !res.has_value() || res.value() != 0;
      })) {
    LDBG("low pad must all be zero: " << padOp << "\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  if (inputVectorSizes.empty())
    return success();

  bool isScalable = inputScalableVecDims.back();
  if (!isScalable)
    return success();

  // Only element-wise ops are supported in the presence of scalable dims.
  auto linalgOp = dyn_cast<LinalgOp>(op);
  return success(linalgOp && isElementwise(linalgOp));
}

/// Return success if the operation can be vectorized.
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
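/// Converts affine.apply Ops to arithmetic operations.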
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}
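/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation.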
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      ArrayRef<bool> inputScalableVecDims,
                                      bool vectorizeNDExtract) {
  LDBG("Attempting to vectorize:\n" << *op << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Input scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims,
                                     vectorizeNDExtract))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr =
                  vectorizeConvolution(rewriter, linalgOp);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }
              return failure();
            }

            LDBG("Vectorize generic by broadcasting to the canonical vector "
                 "shape\n");

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);
            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG("Vectorization failed\n");
    return failure();
  }

  if (!results.empty())
    rewriter.replaceOp(op, results);
  else
    rewriter.eraseOp(op);

  return success();
}
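/// Emit a suitable vector form for a Copy op with fully static shape.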
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  SmallVector<Value> indices(srcType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
  }
  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
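/// Helper function that retrieves the value of an IntegerAttr.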
static int64_t getIntFromAttr(Attribute attr) {
  return cast<IntegerAttr>(attr).getInt();
}

/// Given an ArrayRef of OpFoldResults, return a vector of Values.
/// IntegerAttrs are converted to ConstantIndexOps; other attribute types are
/// not supported.
static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
                                           ArrayRef<OpFoldResult> ofrs) {
  SmallVector<Value> result;
  for (auto o : ofrs) {
    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
      result.push_back(val);
    } else {
      result.push_back(rewriter.create<arith::ConstantIndexOp>(
          loc, getIntFromAttr(o.template get<Attribute>())));
    }
  }
  return result;
}
/// Vectorize the copying of a tensor::PadOp's source.
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
                                      tensor::PadOp padOp, Value dest) {
  auto sourceType = padOp.getSourceType();
  auto resultType = padOp.getResultType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // Copy cannot be vectorized if pad value is non-constant and the source
  // shape is dynamic.
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    if (!sourceType.hasStaticShape())
      return failure();
    // Create dummy padding value.
    auto elemType = sourceType.getElementType();
    padValue = rewriter.create<arith::ConstantOp>(
        padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
  }

  SmallVector<int64_t> vecShape;
  SmallVector<bool> readInBounds;
  SmallVector<bool> writeInBounds;
  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
    if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
      // Source shape is statically known: read and write are in bounds.
      readInBounds.push_back(true);
      writeInBounds.push_back(true);
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but result shape is: vectorize
      // with the size of the result shape, which may be larger than the
      // source, so the read may be out of bounds.
      vecShape.push_back(resultType.getDimSize(i));
      readInBounds.push_back(false);
      // Write is in bounds iff low padding is zero on this dim.
      writeInBounds.push_back(
          getConstantIntValue(padOp.getMixedLowPad()[i]) ==
          static_cast<int64_t>(0));
    } else {
      // Neither source nor result dim is static: cannot vectorize the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // Generate TransferReadOp.
  SmallVector<Value> readIndices(
      vecType.getRank(),
      rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
  auto read = rewriter.create<vector::TransferReadOp>(
      padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
      ArrayRef<bool>{readInBounds});

  // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
  // tensor, write directly into the FillOp's operand.
  if (llvm::equal(vecShape, resultType.getShape()) &&
      llvm::all_of(writeInBounds, [](bool b) { return b; }))
    if (auto fill = dest.getDefiningOp<FillOp>())
      dest = fill.output();

  // Generate TransferWriteOp.
  auto writeIndices =
      ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});

  return success();
}
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
  /// Rewrite use of tensor::PadOp result in TransferReadOp.
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.updateRootInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getSourceMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
  /// Rewrite use of tensor::PadOp result in TransferWriteOp.
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at position of the old TransferWriteOp.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }
  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both CastOp result
    // and CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same: the only supported case is when
    // `beforePadding` is an ExtractSliceOp with matching mixed sizes.
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Only care about dynamic dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same attribute or same SSA value.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Case 2: Both values are identical AffineMinOps. (Should not be a
      // common case, but supported anyway.)
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed. This can be extended arbitrarily.
      return false;
    }

    // All tests passed.
    return true;
  }
  /// Rewrite use of tensor::PadOp result in InsertSliceOp.
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result not used as destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if sizes match: insert the entire tensor into the most minor dims.
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Generate TransferReadOp: read the entire source tensor and add high
    // padding.
    SmallVector<Value> readIndices(
        vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

    // Generate TransferWriteOp: write to the InsertSliceOp's dest tensor.
    auto writeIndices =
        ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices, inBounds);

    return success();
  }
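/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`.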
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
    }
  }
  return false;
}
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // The forwarded read cannot be in-bounds: the copy is performed into the
  // padded view, so in_bounds is explicitly reset.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      /*inBoundsAttr=*/ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the subview copied into that we replace.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward the vector.transfer_write into the copy target; in_bounds is
  // explicitly reset.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      /*inBoundsAttr=*/ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

/// Bind a pack of int64_t& to the leading dimensions of shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
static bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {
    // Determine whether `linalgOp` can be generated with this generator.
    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
      return;
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());
    if (!lhsShapedType || !rhsShapedType || !resShapedType)
      return;
    // A channeled convolution (LHS and RES both of rank 3) or a non-channeled
    // convolution (LHS and RES both of rank 1).
    if (!((lhsShapedType.getRank() == 3 && resShapedType.getRank() == 3) ||
          (lhsShapedType.getRank() == 1 && resShapedType.getRank() == 1)))
      return;

    // ... (match the body reduction and classify the op as conv or pool)
    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    if (!reduceOp || !setOperKind(reduceOp))
      return;
    auto maybeKind = getCombinerOpKind(reduceOp);
    if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
      return;
    }

    auto rhsRank = rhsShapedType.getRank();
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return;

    // The op is now known to be valid.
    valid = true;
  }
  /// Generate the 1-D convolution as a sequence of vector ops over an
  /// unrolled kw dimension.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Initialize unused dimensions.
      nSize = fSize = cSize = 0;
      // ... (bind wSize and kwSize from the res/rhs shapes)
      lhsShape = {// iw = ow + kw - 1
                  //   (i.e. 16 convolved with 3 -> 14)
                  (wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      // ... (bind nSize, wSize, fSize, kwSize, cSize from the res/rhs shapes)
      lhsShape = {nSize,
                  // iw = ow * sw + kw * dw - 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      // ... (bind nSize, fSize, wSize, cSize, kwSize from the res/rhs shapes)
      lhsShape = {nSize, cSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(lhsShape, lhsEltType);
    VectorType rhsType = VectorType::get(rhsShape, rhsEltType);
    VectorType resType = VectorType::get(resShape, resEltType);
    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
    SmallVector<Value> resPadding(resShape.size(), zero);

    // Read lhs/rhs/res in one shot.
    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
                                                        lhsPadding);
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == Conv)
      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                    rhsPadding);
    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                        resPadding);

    // The base vectorization case for channeled convolution is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse it, pre-transpose
    // the non-base cases.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, so no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
      // fcw -> wcf
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
      // This is needed only for Conv.
      if (oper == Conv)
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
      break;
    }
    }

    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals =
        extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                               kwSize, strideW, dilationW, wSizeStep,
                               isSingleChanneled);
    // ... (extract the filter slices with extractConvFilterSlices)
    SmallVector<Value> resVals = extractConvResultSlices(
        rewriter, loc, res, nSize, wSize, fSize, wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the contraction (or outerproduct for non-channeled convolution,
    // or a simple arith op for pooling) for every kw and w slice.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // Post-transpose the output back if it was pre-transposed above.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // nwf -> nfw
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
      break;
    }
    }

    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
  }
  /// Take a value and widen it to have the same element type as `ty`.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  /// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    return rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
  }

  /// Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single-channel
  /// convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
  }
  /// Generate a vector implementation of the depthwise convolution:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> depthwiseConv() {
    if (!valid)
      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");

    int64_t nSize, wSize, cSize, kwSize;
    // ... (bind kwSize, cSize from rhs; nSize, wSize from res)

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();

    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = ow * sw + kw * dw - 1
         //   (i.e. 16 convolved with 3 (@stride 1 dilation 2) -> 12)
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType);
    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                        ValueRange{zero, zero});
    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});

    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, lhs,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute res{n, w, c} += lhs{n, sw * w + dw * kw, c} * rhs{kw, c}.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
      }
    }

    // It is possible we failed to create the FMA.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1});
    }

    // Write back res slice of size {n, w, c} @ [0, 0, 0].
    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped,
                                         ValueRange{zero, zero, zero})
        .getOperation();
  }
  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma/mul+add.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);
    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }
  /// Entry point that matches and rewrites a non-channeled convolution.
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");
    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);
    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point that matches and rewrites an NWC convolution.
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");
    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point that matches and rewrites an NCW convolution.
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");
    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point that matches and rewrites an NWC pooling op.
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point that matches and rewrites an NCW pooling op.
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point that matches and rewrites a dilated (depthwise) convolution.
  FailureOr<Operation *> generateDilatedConv() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");
    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv();
    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
private:
  enum OperKind { Conv, Pool };
  bool valid = false;
  OperKind oper = Conv;
  StringAttr redOp;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  // Returns true iff it is a valid conv/pooling op.
  bool setOperKind(Operation *reduceOp) {
    int numBlockArguments = llvm::count_if(
        reduceOp->getOperands(),
        [](Value v) { return isa<BlockArgument>(v); });
    switch (numBlockArguments) {
    case 1: {
      // Will be convolution if the feed is a multiplication of (possibly
      // cast) block arguments; a cast feed alone indicates pooling.
      auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
        return !isa<BlockArgument>(v);
      });
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
                   llvm::all_of(feedOp->getOperands(), [](Value v) {
                     if (isa<BlockArgument>(v))
                       return true;
                     if (Operation *op = v.getDefiningOp())
                       return isCastOfBlockArgument(op);
                     return false;
                   }))) {
        return false;
      }
      return true;
    }
    case 2:
      // Must be pooling.
      oper = Pool;
      isPoolExt = false;
      return true;
    default:
      return false;
    }
  }
};
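/// Try to vectorize `convOp` as a convolution.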
static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
                                                   LinalgOp op) {
  // We use the strides/dilations attributes if present, and fall back to the
  // defaults otherwise, letting the generic conv matcher succeed or fail.
  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  Conv1DGenerator e(rewriter, op, stride, dilation);
  auto res = e.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcConv();
  if (succeeded(res))
    return res;
  res = e.generateNcwConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = e.generateNcwPooling();
  if (succeeded(res))
    return res;
  return e.generateDilatedConv();
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};