36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/ADT/iterator_range.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
45 #include <type_traits>
50 #define DEBUG_TYPE "linalg-vectorization"
52 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
53 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
56 static FailureOr<Operation *>
60 bool flatten1DDepthwiseConv =
false);
64 template <
typename OpType>
67 block.
walk([&](OpType op) {
82 int64_t nSize, int64_t wSize, int64_t cSize,
83 int64_t kwSize,
int strideW,
int dilationW,
84 int64_t wSizeStep,
bool isSingleChanneled) {
86 if (isSingleChanneled) {
91 for (int64_t kw = 0; kw < kwSize; ++kw) {
92 for (int64_t w = 0; w < wSize; w += wSizeStep) {
93 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
102 for (int64_t kw = 0; kw < kwSize; ++kw) {
103 for (int64_t w = 0; w < wSize; w += wSizeStep) {
104 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
122 for (int64_t kw = 0; kw < kwSize; ++kw) {
123 result.push_back(rewriter.
create<vector::ExtractOp>(
133 int64_t nSize, int64_t wSize, int64_t fSize,
134 int64_t wSizeStep,
bool isSingleChanneled) {
136 if (isSingleChanneled) {
140 for (int64_t w = 0; w < wSize; w += wSizeStep) {
141 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
149 for (int64_t w = 0; w < wSize; w += wSizeStep) {
150 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
159 Value res, int64_t wSize, int64_t wSizeStep,
161 bool isSingleChanneled) {
163 if (isSingleChanneled) {
167 for (int64_t w = 0; w < wSize; w += wSizeStep) {
168 res = rewriter.
create<vector::InsertStridedSliceOp>(
175 for (int64_t w = 0; w < wSize; w += wSizeStep) {
176 res = rewriter.
create<vector::InsertStridedSliceOp>(
191 LogicalResult initState(
RewriterBase &rewriter, LinalgOp linalgOp,
208 std::optional<AffineMap> dimPermutation = std::nullopt)
const {
211 if (dimPermutation.has_value()) {
213 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
215 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
217 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
218 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
230 std::optional<AffineMap> maybeMaskingMap = std::nullopt);
235 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
236 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
242 LogicalResult precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
251 std::optional<AffineMap> maybeMaskingMap);
279 VectorizationState::precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
282 for (
int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
283 if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
285 iterSpaceValueSizes.push_back(rewriter.
create<arith::ConstantIndexOp>(
286 linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
293 unsigned operandDimPos;
294 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
298 Value dynamicDim = linalgOp.hasPureTensorSemantics()
300 linalgOp.getLoc(), operand, operandDimPos)
302 linalgOp.getLoc(), operand, operandDimPos);
303 iterSpaceValueSizes.push_back(dynamicDim);
319 if (!inputVectorSizes.empty()) {
323 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
324 scalableVecDims.append(inputScalableVecDims.begin(),
325 inputScalableVecDims.end());
330 canonicalVecShape = linalgOp.getStaticLoopRanges();
331 scalableVecDims.append(linalgOp.getNumLoops(),
false);
334 LDBG(
"Canonical vector shape: ");
335 LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
336 LLVM_DEBUG(llvm::dbgs() <<
"\n");
337 LDBG(
"Scalable vector dims: ");
338 LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
339 LLVM_DEBUG(llvm::dbgs() <<
"\n");
341 if (ShapedType::isDynamicShape(canonicalVecShape))
345 initIterSpaceStaticSizes(linalgOp);
350 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
360 Value VectorizationState::getOrCreateMaskFor(
362 std::optional<AffineMap> maybeMaskingMap) {
364 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
368 assert(!maskableOp.isMasked() &&
369 "Masking an operation that is already masked");
372 assert((!maybeMaskingMap || *maybeMaskingMap) &&
373 "Unexpected null mask permutation map");
375 maybeMaskingMap ? *maybeMaskingMap
377 linalgOp.getNumLoops(), rewriter.
getContext());
379 LDBG(
"Masking map: " << maskingMap <<
"\n");
383 auto activeMaskIt = activeMaskCache.find(maskingMap);
384 if (activeMaskIt != activeMaskCache.end()) {
385 Value mask = activeMaskIt->second;
386 LDBG(
"Reusing mask: " << mask <<
"\n");
397 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
398 auto maskType = getCanonicalVecType(rewriter.
getI1Type(), maskingMap);
399 auto maskShape = maskType.getShape();
401 LDBG(
"Mask shape: ");
402 LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
403 LLVM_DEBUG(llvm::dbgs() <<
"\n");
405 if (permutedStaticSizes == maskShape) {
406 LDBG(
"Masking is not needed for masking map: " << maskingMap <<
"\n");
407 activeMaskCache[maskingMap] =
Value();
414 assert(!maskShape.empty() && !upperBounds.empty() &&
415 "Masked 0-d vectors are not supported yet");
418 Value mask = rewriter.
create<vector::CreateMaskOp>(linalgOp.getLoc(),
419 maskType, upperBounds);
420 LDBG(
"Creating new mask: " << mask <<
"\n");
421 activeMaskCache[maskingMap] = mask;
432 std::optional<AffineMap> maybeMaskingMap) {
433 LDBG(
"Trying to mask: " << *opToMask <<
"\n");
437 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
440 LDBG(
"No mask required\n");
445 assert(opToMask &&
"Expected a valid operation to mask");
446 auto maskOp = cast<vector::MaskOp>(
448 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
454 LDBG(
"Masked operation: " << *maskOp <<
"\n");
477 "expected projected permutation");
479 assert(res.getNumDims() == res.getNumResults() &&
480 "expected reindexed map with same number of dims and results");
512 std::optional<vector::CombiningKind>
514 using ::mlir::vector::CombiningKind;
519 .Case<arith::AddIOp, arith::AddFOp>(
520 [&](
auto op) {
return CombiningKind::ADD; })
521 .Case<arith::AndIOp>([&](
auto op) {
return CombiningKind::AND; })
522 .Case<arith::MaxSIOp>([&](
auto op) {
return CombiningKind::MAXSI; })
523 .Case<arith::MaxUIOp>([&](
auto op) {
return CombiningKind::MAXUI; })
524 .Case<arith::MaximumFOp>([&](
auto op) {
return CombiningKind::MAXIMUMF; })
525 .Case<arith::MinSIOp>([&](
auto op) {
return CombiningKind::MINSI; })
527 .Case<arith::MinimumFOp>([&](
auto op) {
return CombiningKind::MINIMUMF; })
528 .Case<arith::MulIOp, arith::MulFOp>(
529 [&](
auto op) {
return CombiningKind::MUL; })
530 .Case<arith::OrIOp>([&](
auto op) {
return CombiningKind::OR; })
531 .Case<arith::XOrIOp>([&](
auto op) {
return CombiningKind::XOR; })
532 .Default([&](
auto op) {
return std::nullopt; });
543 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
548 if (!
matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
549 combinerOps.size() != 1)
553 return combinerOps[0];
559 auto dstVecType = dyn_cast<VectorType>(dstType);
561 if (dstVecType.getRank() == 0)
567 return b.
createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
579 assert(maybeKind &&
"Failed precondition: could not get reduction kind");
580 return b.
create<vector::MultiDimReductionOp>(
581 reduceOp->
getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
585 return llvm::to_vector(
592 return isa<linalg::ReduceOp>(op) ||
593 (isa<linalg::GenericOp>(op) &&
607 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
608 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
617 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
619 auto vectorType = state.getCanonicalVecType(
623 if (vectorType.getRank() > 0) {
626 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
628 assert(value.
getType() == vectorType &&
"Incorrect type");
629 write = rewriter.
create<vector::TransferWriteOp>(
630 loc, value, outputOperand->
get(), indices, writeMap);
633 if (!isa<VectorType>(value.
getType()))
634 value = rewriter.
create<vector::BroadcastOp>(loc, vectorType, value);
635 assert(value.
getType() == vectorType &&
"Incorrect type");
636 write = rewriter.
create<vector::TransferWriteOp>(
640 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
644 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
645 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
650 LDBG(
"vectorized op: " << *write <<
"\n");
660 std::function<LogicalResult(
Operation *,
bool)>;
679 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
688 linalgOp.getDpsInitOperand(output.index()), state);
690 newResults.push_back(newResult);
704 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
707 auto loc = indexOp.getLoc();
710 auto dim = indexOp.getDim();
712 auto indexVectorType =
714 state.getScalableVecDims()[dim]);
715 auto indexSteps = rewriter.
create<vector::StepOp>(loc, indexVectorType);
719 if (dim == targetShape.size() - 1)
725 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
726 std::swap(permPattern[dim], permPattern.back());
730 auto broadCastOp = rewriter.
create<vector::BroadcastOp>(
731 loc, state.getCanonicalVecType(rewriter.
getIndexType(), permMap),
734 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
735 std::swap(transposition.back(), transposition[dim]);
737 rewriter.
create<vector::TransposeOp>(loc, broadCastOp, transposition);
745 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
749 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
754 if (not extractOp.getIndices().empty()) {
755 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
759 if (llvm::any_of(extractOp->getResultTypes(), [](
Type type) {
760 return !VectorType::isValidElementType(type);
780 tensor::ExtractOp extractOp,
783 auto indexVecType = state.getCanonicalVecType(rewriter.
getIndexType());
784 auto loc = extractOp.getLoc();
787 rewriter, bvm.
lookup(extractOp.getIndices()[0]), indexVecType);
789 const size_t numIndices = extractOp.getIndices().size();
790 for (
size_t i = 1; i < numIndices; i++) {
791 Value dimIdx = rewriter.
create<arith::ConstantIndexOp>(loc, i);
795 rewriter.
create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
798 offset = rewriter.
create<arith::MulIOp>(loc, offset, dimSize);
801 rewriter, bvm.
lookup(extractOp.getIndices()[i]), indexVecType);
803 offset = rewriter.
create<arith::AddIOp>(loc, extractOpIndex, offset);
814 auto targetShape = linalgOp.getStaticLoopRanges();
815 assert(((llvm::count_if(targetShape,
816 [](int64_t dimSize) {
return dimSize > 1; }) == 1)) &&
817 "n-D vectors are not yet supported");
818 assert(targetShape.back() != 1 &&
819 "1-D vectors with the trailing dim eqaual 1 are not yet supported");
825 auto *block = linalgOp.getBlock();
826 if (isa<BlockArgument>(val))
827 return llvm::all_of(block->getArguments(),
828 [&val](
Value v) { return (v != val); });
831 assert(defOp &&
"This is neither a block argument nor an operation result");
836 auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
837 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
838 return (indexOp.getDim() != trailingLoopDim);
840 auto *ancestor = block->findAncestorOpInBlock(*defOp);
847 if (isa<arith::ConstantOp>(ancestor))
874 bool &foundIndexOp) {
876 auto targetShape = linalgOp.getStaticLoopRanges();
877 assert(((llvm::count_if(targetShape,
878 [](int64_t dimSize) {
return dimSize > 1; }) == 1)) &&
879 "n-D vectors are not yet supported");
880 assert(targetShape.back() != 1 &&
881 "1-D vectors with the trailing dim 1 are not yet supported");
887 auto *block = linalgOp.getBlock();
888 if (isa<BlockArgument>(val))
889 return llvm::all_of(block->getArguments(),
890 [&val](
Value v) { return (v != val); });
893 assert(defOp &&
"This is neither a block argument nor an operation result");
897 auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
898 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
899 foundIndexOp = (indexOp.getDim() == trailingLoopDim);
903 auto *ancestor = block->findAncestorOpInBlock(*defOp);
910 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
932 LinalgOp &linalgOp) {
934 auto targetShape = linalgOp.getStaticLoopRanges();
935 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
938 if (inputShape.getShape().empty())
944 if (linalgOp.hasDynamicShape())
952 if ((llvm::count_if(targetShape,
953 [](int64_t dimSize) {
return dimSize > 1; }) != 1) ||
954 targetShape.back() == 1)
960 if (inputShape.getShape().back() == 1)
963 bool leadingIdxsLoopInvariant =
true;
968 auto indices = extractOp.getIndices();
969 auto leadIndices = indices.drop_back(1);
972 if (inputShape.getShape()[i] == 1)
978 if (!leadingIdxsLoopInvariant) {
979 LDBG(
"Found gather load: " << extractOp);
987 auto extractOpTrailingIdx = indices.back();
991 if (leadingIdxsLoopInvariant &&
993 LDBG(
"Found scalar broadcast load: " << extractOp);
1002 bool foundIndexOp =
false;
1003 bool isContiguousLoad =
1005 isContiguousLoad &= foundIndexOp;
1007 if (isContiguousLoad) {
1008 LDBG(
"Found contigous load: " << extractOp);
1013 LDBG(
"Found gather load: " << extractOp);
1024 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1027 auto loc = extractOp.getLoc();
1030 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1031 auto maskConstantOp = rewriter.
create<arith::ConstantOp>(
1035 auto passThruConstantOp =
1041 extractOp.getIndices().size(),
1042 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
1053 loc, resultType, extractOp.getTensor(), baseIndices, offset,
1054 maskConstantOp, passThruConstantOp);
1055 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1057 LDBG(
"Vectorised as gather load: " << extractOp <<
"\n");
1080 auto zero = rewriter.
create<arith::ConstantOp>(
1082 for (
size_t i = 0; i < extractOp.getIndices().size(); i++) {
1083 Value idx = bvm.
lookup(extractOp.getIndices()[i]);
1085 transferReadIdxs.push_back(idx);
1089 auto indexAs1dVector = rewriter.
create<vector::ShapeCastOp>(
1092 resultType.getScalableDims().back()),
1094 transferReadIdxs.push_back(
1095 rewriter.
create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1099 auto dstRank = resultType.getRank();
1100 auto srcRank = extractOp.getTensor().getType().getRank();
1109 auto transferReadOp = rewriter.
create<vector::TransferReadOp>(
1110 loc, resultType, extractOp.getTensor(), transferReadIdxs,
1111 permutationMap, inBounds);
1113 LDBG(
"Vectorised as scalar broadcast load: " << extractOp <<
"\n");
1121 int32_t rankDiff = dstRank - srcRank;
1129 while (rankDiff > 0) {
1130 permutationMap = permutationMap.insertResult(
1135 auto transferReadOp = rewriter.
create<vector::TransferReadOp>(
1136 loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
1139 LDBG(
"Vectorised as contiguous load: " << extractOp);
1152 auto reduceType = dyn_cast<VectorType>(reduceVec.
getType());
1153 auto outputType = dyn_cast<VectorType>(outputVec.
getType());
1157 (outputType && reduceType.getShape() == outputType.getShape()))
1186 LDBG(
"vectorize op " << *op <<
"\n");
1189 if (!customVectorizationHooks.empty()) {
1190 for (
auto &customFunc : customVectorizationHooks) {
1200 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1210 auto blockArg = dyn_cast<BlockArgument>(operand);
1211 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1212 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1216 linalgOp.getRegionOutputArgs(),
1217 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1220 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1222 if (!reductionOperands.empty()) {
1223 assert(reductionOperands.size() == 1);
1225 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1226 reductionOperands[0].second, bvm);
1233 VectorType firstMaxRankedType;
1235 auto vecOperand = bvm.
lookup(operand);
1236 assert(vecOperand &&
"Vector operand couldn't be found");
1238 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1239 if (vecType && (!firstMaxRankedType ||
1240 firstMaxRankedType.getRank() < vecType.getRank()))
1241 firstMaxRankedType = vecType;
1247 assert(vecOperand &&
"Vector operand couldn't be found");
1249 if (firstMaxRankedType) {
1252 firstMaxRankedType.getScalableDims());
1255 vecOperands.push_back(vecOperand);
1261 resultTypes.push_back(
1264 firstMaxRankedType.getScalableDims())
1296 static LogicalResult
1300 LDBG(
"Vectorizing operation as linalg generic\n");
1301 Block *block = linalgOp.getBlock();
1308 bvm.
map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1310 if (linalgOp.getNumDpsInits() == 0)
1315 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
1316 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1317 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1318 if (linalgOp.isScalar(opOperand)) {
1319 bvm.
map(bbarg, opOperand->get());
1325 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1331 if (isa<AffineConstantExpr>(result.value())) {
1332 zeroPos.push_back(result.index());
1338 VectorType readType;
1340 if (linalgOp.isDpsInput(opOperand)) {
1343 readType = state.getCanonicalVecType(elemType);
1350 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
1360 for (
auto idx : broadcastedDims)
1361 inBounds[idx] =
true;
1364 loc, readType, opOperand->get(), indices, readMap,
1366 read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
1371 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1373 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1379 if (readType.getRank() == 0)
1394 hooks.push_back(vectorizeYield);
1401 hooks.push_back(vectorizeIndex);
1408 hooks.push_back(vectorizeExtract);
1415 LDBG(
"failed to vectorize: " << op <<
"\n");
1420 state.maskOperation(rewriter, result.
newOp, linalgOp);
1421 LDBG(
"New vector op: " << *maybeMaskedOp <<
"\n");
1446 bool useInBoundsInsteadOfMasking) {
1448 auto inputType = cast<VectorType>(input.
getType());
1449 Value dest = builder.
create<tensor::EmptyOp>(loc, destSizes,
1450 inputType.getElementType());
1451 int64_t rank = cast<ShapedType>(dest.
getType()).getRank();
1452 auto zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
1453 auto destShape = cast<ShapedType>(dest.
getType()).getShape();
1455 if (useInBoundsInsteadOfMasking) {
1457 for (
unsigned i = 0; i < rank; i++)
1458 inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
1459 !ShapedType::isDynamic(destShape[i]);
1467 assert(llvm::none_of(
1468 destShape.drop_front(inputVectorSizes.size()),
1469 [](int64_t size) { return size == ShapedType::kDynamic; }) &&
1470 "Only dims aligned with inputVectorSizes may be dynamic");
1471 if (useInBoundsInsteadOfMasking)
1473 bool needMaskForWrite = !llvm::equal(
1474 inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
1475 if (needMaskForWrite) {
1477 writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
1478 writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
1481 Value maskForWrite =
1482 builder.
create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
1514 static LogicalResult
1522 auto padValue = packOp.getPaddingValue();
1524 padValue = rewriter.
create<arith::ConstantOp>(
1525 loc, rewriter.
getZeroAttr(packOp.getSourceType().getElementType()));
1528 LogicalResult status =
1529 cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1530 .reifyResultShapes(rewriter, reifiedReturnShapes);
1532 assert(succeeded(status) &&
"failed to reify result shapes");
1537 bool useInBoundsInsteadOfMasking =
false;
1538 if (inputVectorSizes.empty()) {
1540 inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1541 useInBoundsInsteadOfMasking =
true;
1546 auto innerTiles = packOp.getStaticInnerTiles();
1547 auto innerDimsPos = packOp.getInnerDimsPos();
1548 auto outerDimsPerm = packOp.getOuterDimsPerm();
1549 if (!outerDimsPerm.empty())
1552 for (
auto [idx, size] :
enumerate(innerTiles))
1553 inputShape[innerDimsPos[idx]] *= size;
1555 rewriter, loc, packOp.getSource(), inputShape, padValue,
1556 useInBoundsInsteadOfMasking);
1560 destShape.append(innerTiles.begin(), innerTiles.end());
1562 packOp.getDestType().getElementType());
1564 rewriter.
create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1567 auto destPermutation =
1569 auto transposeOp = rewriter.
create<vector::TransposeOp>(
1570 loc, shapeCastOp.getResult(), destPermutation);
1574 rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
1575 inputVectorSizes,
false);
1576 newResults.push_back(write->getResult(0));
1589 static LogicalResult
1597 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1602 bool useInBoundsInsteadOfMasking =
false;
1605 auto destSize = unpackOp.getDestRank();
1607 if (!inputVectorSizes.empty())
1608 assert(inputVectorSizes.size() == destSize &&
1609 "Incorrect number of input vector sizes");
1620 if (vectorSizes.empty()) {
1621 llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1622 if (!outerDimsPerm.empty())
1625 vectorSizes[pos] *= innerTiles[i];
1627 useInBoundsInsteadOfMasking =
true;
1651 for (
auto [index, size] :
enumerate(innerTiles)) {
1652 readVectorSizes[innerDimPos[index]] =
1655 if (!outerDimsPerm.empty()) {
1658 readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1662 LogicalResult status =
1663 cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1664 .reifyResultShapes(rewriter, reifiedRetShapes);
1665 if (status.failed()) {
1666 LDBG(
"Unable to reify result shapes of " << unpackOp);
1671 auto padValue = rewriter.
create<arith::ConstantOp>(
1672 loc, rewriter.
getZeroAttr(unpackOp.getSourceType().getElementType()));
1677 rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1680 PackingMetadata packMetadata;
1683 ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1685 mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1687 RankedTensorType stripMineTensorType =
1690 vector::TransposeOp transposeOp = rewriter.
create<vector::TransposeOp>(
1691 loc, readResult, lastDimToInsertPosPerm);
1694 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1695 stripMineTensorType, packMetadata.reassociations);
1696 mlir::VectorType vecCollapsedType =
1697 VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1698 vector::ShapeCastOp shapeCastOp = rewriter.
create<vector::ShapeCastOp>(
1699 loc, vecCollapsedType, transposeOp->getResult(0));
1704 unpackOp.getDestType().hasStaticShape()
1706 : shapeCastOp.getResultVectorType().getShape());
1708 rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
1709 writeVectorSizes, useInBoundsInsteadOfMasking);
1710 newResults.push_back(write->
getResult(0));
1717 static LogicalResult
1721 auto padValue = padOp.getConstantPaddingValue();
1729 LogicalResult status =
1730 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1731 .reifyResultShapes(rewriter, reifiedReturnShapes);
1733 assert(succeeded(status) &&
"failed to reify result shapes");
1735 rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1738 rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
1740 newResults.push_back(write->
getResult(0));
1748 LDBG(
"reduction precondition failed: no reduction iterator\n");
1751 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
1752 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1758 LDBG(
"reduction precondition failed: reduction detection failed\n");
1765 static LogicalResult
1767 bool flatten1DDepthwiseConv) {
1768 if (flatten1DDepthwiseConv) {
1769 LDBG(
"Vectorization of flattened convs with dynamic shapes is not "
1774 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
1775 LDBG(
"Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
1781 Value lhs = conv.getDpsInputOperand(0)->get();
1783 auto shapeWithoutCh = lhsShape.drop_back(1);
1784 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1785 LDBG(
"Dynamically-shaped op vectorization precondition failed: only "
1786 "channel dim can be dynamic\n");
1793 static LogicalResult
1795 bool flatten1DDepthwiseConv) {
1796 if (isa<ConvolutionOpInterface>(op.getOperation()))
1805 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1809 LDBG(
"Dynamically-shaped op meets vectorization pre-conditions\n");
1814 static LogicalResult
1818 if (llvm::any_of(unpackOp.getInnerTiles(), [](
OpFoldResult res) {
1819 return !getConstantIntValue(res).has_value();
1821 LDBG(
"Inner-tiles must be constant: " << unpackOp <<
"\n");
1825 bool satisfyEmptyCond = inputVectorSizes.empty() &&
1826 unpackOp.getDestType().hasStaticShape() &&
1827 unpackOp.getSourceType().hasStaticShape();
1828 if (!satisfyEmptyCond &&
1837 bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
1839 if (llvm::is_contained(linalgOp.getStaticShape(), 0))
1842 if (!inputVectorSizes.empty() &&
1848 linalgOp, flatten1DDepthwiseConv))) {
1849 LDBG(
"Dynamically-shaped op failed vectorization pre-conditions\n");
1862 customPreconditions,
1865 customPrecondition(&innerOp, vectorizeNDExtract));
1869 if (llvm::any_of(innerOp.getOperandTypes(), [](
Type type) {
1870 return !VectorType::isValidElementType(type);
1874 if (llvm::any_of(innerOp.getResultTypes(), [](
Type type) {
1875 return !VectorType::isValidElementType(type);
1886 if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
1892 LDBG(
"precondition failed: not projected permutations\n");
1896 LDBG(
"precondition failed: reduction preconditions\n");
1902 static LogicalResult
1905 auto padValue = packOp.getPaddingValue();
1908 LDBG(
"pad value is not constant: " << packOp <<
"\n");
1912 bool satisfyEmptyCond =
true;
1913 if (inputVectorSizes.empty()) {
1914 if (!packOp.getDestType().hasStaticShape() ||
1915 !packOp.getSourceType().hasStaticShape())
1916 satisfyEmptyCond =
false;
1919 if (!satisfyEmptyCond &&
1921 resultTensorShape.take_front(packOp.getSourceRank()),
1925 if (llvm::any_of(packOp.getInnerTiles(), [](
OpFoldResult v) {
1926 return !getConstantIntValue(v).has_value();
1928 LDBG(
"inner_tiles must be constant: " << packOp <<
"\n");
1935 static LogicalResult
1938 auto padValue = padOp.getConstantPaddingValue();
1940 LDBG(
"pad value is not constant: " << padOp <<
"\n");
1949 if (llvm::any_of(padOp.getLow(), [](
Value v) {
1950 std::optional<int64_t> res = getConstantIntValue(v);
1951 return !res.has_value() || res.value() != 0;
1953 LDBG(
"low pad must all be zero: " << padOp <<
"\n");
1962 static LogicalResult
1966 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1967 "Number of input vector sizes and scalable dims doesn't match");
1969 size_t numOfScalableDims =
1970 llvm::count_if(inputScalableVecDims, [](
bool flag) {
return flag; });
1972 if (numOfScalableDims == 0)
1975 auto linalgOp = dyn_cast<LinalgOp>(op);
1983 if (numOfScalableDims > 2)
1998 bool seenParalell =
false;
1999 auto iterators = linalgOp.getIteratorTypesArray();
2001 while (!scalableFlags.back()) {
2002 seenParalell |= (iterators.back() == utils::IteratorType::parallel);
2004 iterators.pop_back();
2005 scalableFlags.pop_back();
2008 switch (iterators.back()) {
2009 case utils::IteratorType::reduction: {
2011 if (iterators.size() != inputVectorSizes.size()) {
2012 LDBG(
"Non-trailing reduction dim requested for scalable "
2016 if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2017 LDBG(
"Scalable vectorization of the reduction dim in Matmul-like ops "
2018 "is not supported\n");
2023 case utils::IteratorType::parallel: {
2026 LDBG(
"Inner parallel dim not requested for scalable "
2038 if (numOfScalableDims == 2) {
2042 if (iterators.back() == utils::IteratorType::reduction) {
2043 LDBG(
"Higher dim than the trailing reduction dim requested for scalable "
2047 scalableFlags.pop_back();
2048 iterators.pop_back();
2050 if (!scalableFlags.back() ||
2051 (iterators.back() != utils::IteratorType::parallel))
2057 return success(
isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2058 isa<linalg::MatmulTransposeAOp>(op) ||
2059 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2066 bool flatten1DDepthwiseConv) {
2068 inputScalableVecDims)))
2072 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
2075 flatten1DDepthwiseConv);
2077 .Case<tensor::PadOp>([&](
auto padOp) {
2080 .Case<tensor::PackOp>([&](
auto packOp) {
2083 .Case<tensor::UnPackOp>([&](
auto unpackOp) {
2086 .Default([](
auto) {
return failure(); });
2092 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2094 for (
auto op : make_early_inc_range(toReplace)) {
2098 op.
getOperands().take_front(op.getAffineMap().getNumDims()),
2099 op.
getOperands().take_back(op.getAffineMap().getNumSymbols()));
2113 bool vectorizeNDExtract,
2114 bool flatten1DDepthwiseConv) {
2115 LDBG(
"Attempting to vectorize:\n" << *op <<
"\n");
2116 LDBG(
"Input vector sizes: ");
2117 LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2118 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2119 LDBG(
"Input scalable vector dims: ");
2120 LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2121 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2125 flatten1DDepthwiseConv))) {
2126 LDBG(
"Vectorization pre-conditions failed\n");
2132 if (
auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2133 if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2134 inputScalableVecDims))) {
2135 LDBG(
"Vectorization state couldn't be initialized\n");
2141 auto vectorizeResult =
2143 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
2147 if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2149 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2150 flatten1DDepthwiseConv);
2151 if (succeeded(convOr)) {
2152 llvm::append_range(results, (*convOr)->getResults());
2156 LDBG(
"Unsupported convolution can't be vectorized.\n");
2160 LDBG(
"Vectorize generic by broadcasting to the canonical vector "
2173 .Case<tensor::PadOp>([&](
auto padOp) {
2177 .Case<tensor::PackOp>([&](
auto packOp) {
2181 .Case<tensor::UnPackOp>([&](
auto unpackOp) {
2183 inputVectorSizes, results);
2185 .Default([](
auto) {
return failure(); });
2187 if (failed(vectorizeResult)) {
2188 LDBG(
"Vectorization failed\n");
2192 if (!results.empty())
2201 memref::CopyOp copyOp) {
2202 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2203 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2204 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2209 if (!VectorType::isValidElementType(srcElementType) ||
2210 !VectorType::isValidElementType(dstElementType))
2221 loc, readType, copyOp.getSource(), indices,
2223 if (cast<VectorType>(
readValue.getType()).getRank() == 0) {
2228 loc,
readValue, copyOp.getTarget(), indices,
2240 return cast<IntegerAttr>(attr).getInt();
2249 for (
auto o : ofrs) {
2250 if (
auto val = llvm::dyn_cast_if_present<Value>(o)) {
2251 result.push_back(val);
2253 result.push_back(rewriter.
create<arith::ConstantIndexOp>(
2272 tensor::PadOp padOp,
Value dest) {
2273 auto sourceType = padOp.getSourceType();
2274 auto resultType = padOp.getResultType();
2275 if (!VectorType::isValidElementType(sourceType.getElementType()))
2281 auto padValue = padOp.getConstantPaddingValue();
2283 if (!sourceType.hasStaticShape())
2286 auto elemType = sourceType.getElementType();
2287 padValue = rewriter.
create<arith::ConstantOp>(
2288 padOp.getLoc(), elemType, rewriter.
getZeroAttr(elemType));
2294 for (
unsigned i = 0; i < sourceType.getRank(); ++i) {
2295 if (!sourceType.isDynamicDim(i)) {
2296 vecShape.push_back(sourceType.getDimSize(i));
2299 readInBounds.push_back(
true);
2300 writeInBounds.push_back(
true);
2301 }
else if (!resultType.isDynamicDim(i)) {
2305 vecShape.push_back(resultType.getDimSize(i));
2308 readInBounds.push_back(
false);
2310 writeInBounds.push_back(
2312 static_cast<int64_t
>(0));
2319 auto vecType =
VectorType::get(vecShape, sourceType.getElementType());
2324 rewriter.
create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2325 auto read = rewriter.
create<vector::TransferReadOp>(
2326 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
2331 if (llvm::equal(vecShape, resultType.getShape()) &&
2332 llvm::all_of(writeInBounds, [](
bool b) {
return b; }))
2334 dest = fill.output();
2348 template <
typename OpTy>
2354 bool changed =
false;
2356 for (
auto *user : llvm::to_vector<4>(padOp->getUsers()))
2357 if (
auto op = dyn_cast<OpTy>(user))
2358 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2359 return success(changed);
2364 tensor::PadOp padOp, OpTy op)
const = 0;
2392 vector::TransferReadOp xferOp)
const override {
2394 if (!padOp.hasZeroLowPad())
2397 auto padValue = padOp.getConstantPaddingValue();
2401 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2406 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2408 xferOp.getSourceMutable().assign(padOp.getSource());
2409 xferOp.getPaddingMutable().assign(padValue);
2454 vector::TransferWriteOp xferOp)
const override {
2456 if (xferOp.getTransferRank() == 0)
2460 if (!padOp.hasZeroLowPad())
2463 auto padValue = padOp.getConstantPaddingValue();
2467 if (!xferOp->hasOneUse())
2469 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2473 if (!trimPadding.hasZeroOffset())
2476 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2484 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2485 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2487 rewriter.
replaceOp(trimPadding, newXferOp->getResult(0));
2503 tensor::ExtractSliceOp afterTrimming)
const {
2506 if (
auto castOp = beforePadding.
getDefiningOp<tensor::CastOp>())
2507 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2510 auto t1 = dyn_cast<RankedTensorType>(beforePadding.
getType());
2511 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2516 if (t1.getRank() != t2.getRank())
2521 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2522 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2524 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2529 if (t1.getNumDynamicDims() == 0)
2537 auto beforeSlice = beforePadding.
getDefiningOp<tensor::ExtractSliceOp>();
2541 assert(
static_cast<size_t>(t1.getRank()) ==
2542 beforeSlice.getMixedSizes().size());
2543 assert(
static_cast<size_t>(t2.getRank()) ==
2544 afterTrimming.getMixedSizes().size());
2546 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2548 if (!t1.isDynamicDim(i))
2550 auto size1 = beforeSlice.getMixedSizes()[i];
2551 auto size2 = afterTrimming.getMixedSizes()[i];
2558 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2559 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2565 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2566 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2567 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2568 minOp1.getOperands() == minOp2.getOperands())
2608 tensor::InsertSliceOp insertOp)
const override {
2610 if (!padOp.hasZeroLowPad())
2613 if (!insertOp.hasUnitStride())
2616 auto padValue = padOp.getConstantPaddingValue();
2620 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
2623 if (insertOp.getDest() == padOp.getResult())
2627 padOp.getType().getElementType());
2628 unsigned vecRank = vecType.getRank();
2629 unsigned tensorRank = insertOp.getType().getRank();
2634 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
2636 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](
auto it) {
2637 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
2648 vecRank, rewriter.
create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
2649 auto read = rewriter.
create<vector::TransferReadOp>(
2650 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
2659 insertOp, read, insertOp.getDest(), writeIndices,
2688 LDBG(
"interleavedUses precondition failed, firstOp: "
2689 << *firstOp <<
", second op: " << *secondOp <<
"\n");
2692 for (
auto v : values) {
2693 for (
auto &u : v.getUses()) {
2695 if (owner == firstOp || owner == secondOp)
2701 LDBG(
" found interleaved op " << *owner <<
", firstOp: " << *firstOp
2702 <<
", second op: " << *secondOp <<
"\n");
2712 memref::SubViewOp subViewOp;
2714 if (
auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
2716 return memref::SubViewOp();
2717 subViewOp = newSubViewOp;
2729 if (xferOp.getMask())
2733 Value viewOrAlloc = xferOp.getSource();
2742 Value subView = subViewOp.getResult();
2745 memref::CopyOp copyOp;
2746 for (
auto &u : subView.
getUses()) {
2747 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2748 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
2749 if (newCopyOp.getTarget() != subView)
2763 for (
auto &u : viewOrAlloc.
getUses()) {
2764 if (
auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
2765 assert(isa<MemRefType>(newFillOp.output().getType()));
2766 if (newFillOp.output() != viewOrAlloc)
2770 maybeFillOp = newFillOp;
2775 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
2777 "padding value does not match fill");
2780 Value in = copyOp.getSource();
2786 auto vectorType = xferOp.getVectorType();
2787 Value res = rewriter.
create<vector::TransferReadOp>(
2788 xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
2789 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
2794 rewriter.
eraseOp(maybeFillOp);
2806 if (xferOp.getMask())
2810 Value viewOrAlloc = xferOp.getSource();
2819 Value subView = subViewOp.getResult();
2822 memref::CopyOp copyOp;
2823 for (
auto &u : subViewOp.getResult().getUses()) {
2824 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
2825 if (newCopyOp.getSource() != subView)
2837 assert(isa<MemRefType>(copyOp.getTarget().getType()));
2838 Value out = copyOp.getTarget();
2845 auto vector = xferOp.getVector();
2846 rewriter.
create<vector::TransferWriteOp>(
2847 xferOp.getLoc(), vector, out, xferOp.getIndices(),
2848 xferOp.getPermutationMapAttr(), xferOp.getMask(),
2865 template <
int N,
typename IntTy,
typename... IntTy2>
2867 val = shapedType.getShape()[N];
2872 template <
typename... IntTy>
2874 bindShapeDims<0>(shapedType, vals...);
2878 bool isCastOfBlockArgument(
Operation *op) {
2883 bool isSupportedPoolKind(vector::CombiningKind kind) {
2885 case vector::CombiningKind::ADD:
2886 case vector::CombiningKind::MAXNUMF:
2887 case vector::CombiningKind::MAXIMUMF:
2888 case vector::CombiningKind::MAXSI:
2889 case vector::CombiningKind::MAXUI:
2890 case vector::CombiningKind::MINNUMF:
2891 case vector::CombiningKind::MINIMUMF:
2892 case vector::CombiningKind::MINSI:
2934 struct Conv1DGenerator
2936 Conv1DGenerator(
RewriterBase &rewriter, LinalgOp linalgOp,
int strideW,
2939 strideW(strideW), dilationW(dilationW) {
2941 if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
2943 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
2944 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
2945 resShaped = linalgOp.getDpsInitOperand(0)->get();
2946 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
2947 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
2948 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
2949 if (!lhsShapedType || !rhsShapedType || !resShapedType)
2953 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2954 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2962 if (!setOperKind(reduceOp))
2965 if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
2966 (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
2970 auto rhsRank = rhsShapedType.getRank();
2973 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
3007 int64_t nSize, wSize, cSize, kwSize, fSize;
3010 switch (conv1DOpOrder) {
3013 nSize = fSize = cSize = 0;
3020 (wSize + kwSize - 1)};
3021 rhsShape = {kwSize};
3042 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3047 rhsShape = {kwSize, cSize, fSize};
3050 rhsShape = {kwSize};
3053 resShape = {nSize, wSize, fSize};
3069 lhsShape = {nSize, cSize,
3073 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3077 rhsShape = {fSize, cSize, kwSize};
3080 rhsShape = {kwSize};
3083 resShape = {nSize, fSize, wSize};
3087 vector::TransferWriteOp write;
3088 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
3093 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3095 Type lhsEltType = lhsShapedType.getElementType();
3096 Type rhsEltType = rhsShapedType.getElementType();
3097 Type resEltType = resShapedType.getElementType();
3107 Value lhs = rewriter.
create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
3110 Value rhs =
nullptr;
3112 rhs = rewriter.
create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3114 Value res = rewriter.
create<vector::TransferReadOp>(loc, resType, resShaped,
3120 switch (conv1DOpOrder) {
3128 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3129 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, permLhs);
3131 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3135 rhs = rewriter.
create<vector::TransposeOp>(loc, rhs, permRhs);
3137 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3138 res = rewriter.
create<vector::TransposeOp>(loc, res, permRes);
3149 kwSize, strideW, dilationW, wSizeStep,
3155 wSizeStep, isSingleChanneled);
3157 auto linearIndex = [&](int64_t kw, int64_t w) {
3158 return kw * (wSize / wSizeStep) + w;
3164 for (int64_t kw = 0; kw < kwSize; ++kw) {
3165 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3168 if (isSingleChanneled) {
3169 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3170 lhsVals[linearIndex(kw, w)],
3171 rhsVals[kw], resVals[w]);
3173 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3174 lhsVals[linearIndex(kw, w)],
3175 rhsVals[kw], resVals[w]);
3179 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3195 switch (conv1DOpOrder) {
3202 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3203 res = rewriter.
create<vector::TransposeOp>(loc, res, perm);
3209 .
create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3217 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3218 if (srcElementType == dstElementType)
3223 const Type dstType =
3224 cast<ShapedType>(val.
getType()).cloneWith(std::nullopt, dstElementType);
3226 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3227 return rewriter.
create<arith::SIToFPOp>(loc, dstType, val);
3230 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3231 srcWidth < dstWidth)
3232 return rewriter.
create<arith::ExtFOp>(loc, dstType, val);
3234 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3235 srcWidth < dstWidth)
3236 return rewriter.
create<arith::ExtSIOp>(loc, dstType, val);
3238 assert(
false &&
"unhandled promotion case");
3245 vector::IteratorType par = vector::IteratorType::parallel;
3246 vector::IteratorType red = vector::IteratorType::reduction;
3251 return rewriter.
create<vector::ContractionOp>(
3253 MapList{{n, w, c}, {c, f}, {n, w, f}},
3261 return rewriter.
create<vector::OuterProductOp>(
3262 loc, res.
getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3284 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3285 bool channelDimScalableFlag,
3290 bool scalableChDim =
false;
3291 bool useMasking =
false;
3292 int64_t nSize, wSize, cSize, kwSize;
3295 if (ShapedType::isDynamic(cSize)) {
3296 assert(channelDimVecSize != 0 &&
"Channel dim vec size must be > 0");
3297 cSize = channelDimVecSize;
3301 scalableChDim = channelDimScalableFlag;
3305 assert(!(useMasking && flatten) &&
3306 "Unsupported flattened conv with dynamic shapes");
3311 vector::TransferWriteOp write;
3312 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
3317 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3319 Type lhsEltType = lhsShapedType.getElementType();
3320 Type rhsEltType = rhsShapedType.getElementType();
3321 Type resEltType = resShapedType.getElementType();
3326 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3328 lhsEltType, {
false,
false, scalableChDim});
3329 VectorType rhsType =
3331 {
false, scalableChDim});
3332 VectorType resType =
3334 {
false,
false, scalableChDim});
3347 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3348 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3352 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3355 rewriter.
create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3362 Value lhs = rewriter.
create<vector::TransferReadOp>(
3363 loc, lhsType, lhsShaped,
ValueRange{zero, zero, zero});
3364 auto maybeMaskedLhs = maybeMaskXferOp(
3365 lhsType.getShape(), lhsType.getScalableDims(), lhs.
getDefiningOp());
3368 Value rhs = rewriter.
create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
3370 auto maybeMaskedRhs = maybeMaskXferOp(
3371 rhsType.getShape(), rhsType.getScalableDims(), rhs.
getDefiningOp());
3374 Value res = rewriter.
create<vector::TransferReadOp>(
3375 loc, resType, resShaped,
ValueRange{zero, zero, zero});
3376 auto maybeMaskedRes = maybeMaskXferOp(
3377 resType.getShape(), resType.getScalableDims(), res.
getDefiningOp());
3389 for (int64_t kw = 0; kw < kwSize; ++kw) {
3390 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3391 lhsVals.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
3392 loc, maybeMaskedLhs->getResult(0),
3394 inOutSliceSizes, inOutStrides));
3398 for (int64_t kw = 0; kw < kwSize; ++kw) {
3399 rhsVals.push_back(rewriter.
create<vector::ExtractOp>(
3400 loc, maybeMaskedRhs->getResult(0),
3404 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3405 resVals.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
3406 loc, maybeMaskedRes->getResult(0),
3411 auto linearIndex = [&](int64_t kw, int64_t w) {
3412 return kw * (wSize / wSizeStep) + w;
3417 auto inOutFlattenSliceSizes =
3419 auto lhsTypeAfterFlattening =
3421 auto resTypeAfterFlattening =
3425 for (int64_t kw = 0; kw < kwSize; ++kw) {
3426 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3427 Value lhsVal = lhsVals[linearIndex(kw, w)];
3428 Value resVal = resVals[w];
3432 lhsVal = rewriter.
create<vector::ShapeCastOp>(
3433 loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3434 resVal = rewriter.
create<vector::ShapeCastOp>(
3435 loc, resTypeAfterFlattening, resVals[w]);
3437 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3438 rhsVals[kw], resVal, flatten);
3441 resVals[w] = rewriter.
create<vector::ShapeCastOp>(
3448 if (!llvm::all_of(resVals, [](
Value v) {
return v; })) {
3450 for (
auto &collection :
3451 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3452 for (
Value v : collection)
3459 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3460 maybeMaskedRes = rewriter.
create<vector::InsertStridedSliceOp>(
3461 loc, resVals[w], maybeMaskedRes->getResult(0),
3471 loc, maybeMaskedRes->getResult(0), resShaped,
3473 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3484 auto rhsTy = cast<ShapedType>(rhs.
getType());
3485 auto resTy = cast<ShapedType>(res.
getType());
3488 lhs =
promote(rewriter, loc, lhs, resTy);
3499 auto rhsSize = cast<VectorType>(rhs.
getType()).getShape()[0];
3500 auto resSize = cast<VectorType>(res.
getType()).getShape()[1];
3503 for (
int i = 0; i < resSize / rhsSize; ++i) {
3504 for (
int j = 0;
j < rhsSize; ++
j)
3505 indices.push_back(
j);
3508 rhs = rewriter.
create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3511 rhs = rewriter.
create<vector::BroadcastOp>(
3512 loc, resTy.clone(rhsTy.getElementType()), rhs);
3514 rhs =
promote(rewriter, loc, rhs, resTy);
3519 if (isa<FloatType>(resTy.getElementType()))
3520 return rewriter.
create<vector::FMAOp>(loc, lhs, rhs, res);
3522 auto mul = rewriter.
create<arith::MulIOp>(loc, lhs, rhs);
3523 return rewriter.
create<arith::AddIOp>(loc, mul, res);
3528 FailureOr<Operation *> generateNonChanneledConv() {
3531 if (!iters({Par(), Red()}))
3533 "failed to match conv::W 1-par 1-red");
3536 if (layout({ {w + kw},
3546 FailureOr<Operation *> generateNwcConv() {
3549 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3551 op,
"failed to match conv::Nwc 3-par 2-red");
3554 if (layout({ {n, strideW * w + dilationW * kw, c},
3564 FailureOr<Operation *> generateNcwConv() {
3567 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3569 op,
"failed to match conv::Ncw 3-par 2-red");
3571 if (layout({ {n, c, strideW * w + dilationW * kw},
3581 FailureOr<Operation *> generateNwcPooling() {
3584 if (!iters({Par(), Par(), Par(), Red()}))
3586 "failed to match pooling 3-par 1-red");
3589 if (layout({ {n, strideW * w + dilationW * kw, c},
3599 FailureOr<Operation *> generateNcwPooling() {
3602 if (!iters({Par(), Par(), Par(), Red()}))
3604 "failed to match pooling 3-par 1-red");
3606 if (layout({ {n, c, strideW * w + dilationW * kw},
3616 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
3617 bool vecChDimScalableFlag =
false,
3618 bool flatten =
false) {
3621 if (!iters({Par(), Par(), Par(), Red()}))
3623 op,
"failed to match depthwise::Nwc conv 3-par 1-red");
3626 if (layout({ {n, strideW * w + dilationW * kw, c},
3629 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
3635 enum OperKind { Conv, Pool };
3637 OperKind oper = Conv;
3639 StringAttr poolExtOp;
3640 bool isPoolExt =
false;
3641 int strideW, dilationW;
3642 Value lhsShaped, rhsShaped, resShaped;
3643 ShapedType lhsShapedType, rhsShapedType, resShapedType;
3654 int numBlockArguments =
3655 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
3656 switch (numBlockArguments) {
3660 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
3661 llvm::IsaPred<BlockArgument>);
3662 Operation *feedOp = (*feedValIt).getDefiningOp();
3663 if (isCastOfBlockArgument(feedOp)) {
3667 }
else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
3669 if (isa<BlockArgument>(v))
3671 if (Operation *op = v.getDefiningOp())
3672 return isCastOfBlockArgument(op);
3695 ArrayRef<bool> inputScalableVecDims,
bool flatten1DDepthwiseConv) {
3702 auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3703 auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3704 Conv1DGenerator e(rewriter, op, stride, dilation);
3705 auto res = e.generateNonChanneledConv();
3708 res = e.generateNwcConv();
3711 res = e.generateNcwConv();
3714 res = e.generateNwcPooling();
3717 res = e.generateNcwPooling();
3724 uint64_t vecChDimSize = ShapedType::kDynamic;
3725 bool vecChDimScalableFlag =
false;
3726 if (!inputVecSizes.empty()) {
3729 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3730 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3731 "Not a 1D depthwise conv!");
3734 .Case<linalg::DepthwiseConv1DNwcWcOp>([](
auto conv) {
return 2; })
3735 .Case<linalg::DepthwiseConv1DNcwCwOp>([](
auto conv) {
return 1; });
3737 vecChDimSize = inputVecSizes[chDimIdx];
3738 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3740 return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3741 flatten1DDepthwiseConv);
3750 if (failed(resultOrFail))
3754 rewriter.
eraseOp(op.getOperation());
3757 assert(newOp->
getNumResults() == 1 &&
"expected single result");
static VectorShape vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: Vector::TransferReadOp - Reads a vector from the source ...
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val)
Checks whether /p val can be used for calculating a loop invariant index.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles (2) constant padding value and (3) input vector s...
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< Value > ofrToIndexValues(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > ofrs)
Given an ArrayRef of OpFoldResults, return a vector of Values.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static int64_t getIntFromAttr(Attribute attr)
Helper function that retrieves the value of an IntegerAttr.
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
AffineMap dropResults(ArrayRef< int64_t > positions) const
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
ArrayRef< AffineExpr > getResults() const
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter, tensor::PadOp padOp, Value dest)
Vectorize the copying of a tensor::PadOp's source.
GenericPadOpVectorizationPattern(MLIRContext *context, PatternBenefit benefit=1)
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.