#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
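/// Try to vectorize `convOp` as a convolution.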
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
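/// Return the unique instance of OpType in `block` if it is indeed unique.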
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    // ...
  });
  return res;
}
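/// Helper function to extract the input slices after filter is unrolled
/// along kw.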
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw].
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
            strides));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            sizes, strides));
      }
    }
  }
  return result;
}
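/// Helper function to extract the filter slices after filter is unrolled
/// along kw.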
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  // ...
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(
        loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
  }
  return result;
}
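/// Helper function to extract the result slices after filter is unrolled
/// along kw.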
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice: {wSizeStep} @ [w].
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides));
    }
  } else {
    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
    SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides));
    }
  }
  return result;
}
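/// Helper function to insert the computed result slices.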
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize,
                                    int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Write back res slice: {wSizeStep} @ [w].
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides);
    }
  } else {
    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0].
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          strides);
    }
  }
  return res;
}
struct VectorizationState {
  // ...
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  // ...
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  // ...
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeMaskingMap = std::nullopt);
private:
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  // ...
};
LogicalResult VectorizationState::precomputeIterSpaceValueSizes(
    RewriterBase &rewriter, LinalgOp linalgOp) {
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end;
       ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create a constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // ...
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  // ...
  if (!inputVectorSizes.empty()) {
    // Use the input vector sizes, if provided.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  initIterSpaceStaticSizes(linalgOp);

  // ...
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {
  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for this masking map if it was already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // ...
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // ...
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeMaskingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update its uses.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
  // ...
  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}
477 "expected projected permutation");
479 assert(res.getNumDims() == res.getNumResults() &&
480 "expected reindexed map with same number of dims and results");
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>(
          [&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>(
          [&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
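/// Check whether `outputOperand` is a reduction with a single combiner
/// operation.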
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos,
                      combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
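/// Broadcast `value` to a vector of `dstType`'s shape if possible.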
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If there is no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  // ...
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
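/// Create MultiDimReductionOp to compute the reduction for `reduceOp`.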
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
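/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.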
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}
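/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`.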
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand->get()),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true: masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    // ...
  }

  LDBG("vectorized op: " << *write << "\n");
  // ...
}
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;
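/// Helper function to vectorize the terminator of a `linalgOp`.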
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    // ...
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
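/// Helper function to vectorize the index operations of a `linalgOp`.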
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  auto dim = indexOp.getDim();
  auto indexVectorType =
      VectorType::get({state.getCanonicalVecShape()[dim]},
                      rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space.
  auto targetShape = state.getCanonicalVecShape();
  if (dim == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
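/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.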
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();
  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // ...
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      })) {
    return failure();
  }

  return success();
}
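/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`.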
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}
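/// Checks whether `val` can be used for calculating a loop invariant index.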
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
  auto targetShape = linalgOp.getStaticLoopRanges();
  assert(llvm::count_if(targetShape,
                        [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its dim is not the trailing loop dim.
  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
    return (indexOp.getDim() != trailingLoopDim);

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return false;

  // Values defined outside `linalgOp` are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op);

  return result;
}
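/// Check whether `val` could be used for calculating the trailing index for a
/// contiguous load operation.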
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp) {
  auto targetShape = linalgOp.getStaticLoopRanges();
  assert(((llvm::count_if(targetShape,
                          [](int64_t dimSize) { return dimSize > 1; }) ==
           1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // Given the assumption on the loop ranges above, only the trailing loop
  // index is not a constant.
  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);

  return result;
}
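/// Infer the memory access pattern for the input ExtractOp.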
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp) {
  auto targetShape = linalgOp.getStaticLoopRanges();
  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
  // gather load.
  if (linalgOp.hasDynamicShape())
    return VectorMemoryAccessKind::Gather;

  // 1. Assume that it's a gather load when reading non-1D vectors.
  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
                             return dimSize > 1;
                           }) == 1);
  // 1.1 Reject n-D vectors.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index for `extractOp`.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
    LDBG("Found scalar broadcast load: " << extractOp);
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous loads.
  bool foundIndexOp = false;
  bool isContiguousLoad =
      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
  isContiguousLoad &= foundIndexOp;

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp,
                       const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc, DenseIntElementsAttr::get(
               state.getCanonicalVecType(rewriter.getI1Type()),
               /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }

  // 2. Handle scalar broadcast and contiguous access.
  SmallVector<Value> transferReadIdxs;
  auto zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        permutationMap, inBounds);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary
  // leading dims so that the ranks match.
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp,
                                 Operation *op, Value reduceValue,
                                 Value initialValue, const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
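/// Generic vectorization for a single operation `op`, given already
/// vectorized operands carried by `bvm`.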
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp,
                               rewriter.clone(*op)};

  // ...

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  // 5a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  // 5b. Broadcast each operand to the max ranked shape.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(
          firstMaxRankedType.getShape(),
          getElementTypeOrSelf(vecOperand.getType()),
          firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  // 5c. Compute the result types.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  // ...
}
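/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form.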
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  // Values defined above the region are mapped to themselves; they can only
  // be broadcast for now.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
    // Remove zeros from the indexing map to use it as the masking map.
    SmallVector<int64_t> zeroPos;
    auto results = indexingMap.getResults();
    for (const auto &result : llvm::enumerate(results)) {
      if (isa<AffineConstantExpr>(result.value())) {
        zeroPos.push_back(result.index());
      }
    }
    AffineMap maskingMap = indexingMap.dropResults(zeroPos);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // Broadcast if needed.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // If the operand is an output, explicitly permute.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    // Make sure in_bounds is set to `true` for broadcast dims.
    SmallVector<unsigned> broadcastedDims = readMap.getBroadcastDims();
    SmallVector<bool> inBounds(readType.getRank(), false);
    for (auto idx : broadcastedDims)
      inBounds[idx] = true;

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, readMap,
        ArrayRef<bool>(inBounds));
    read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
    Value readValue = read->getResult(0);

    // If masked, set in_bounds to true: masking guarantees that the access
    // will be in bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // Not all ops support 0-d vectors, extract the scalar for now.
    if (readType.getRank() == 0)
      readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);

    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  SmallVector<CustomVectorizationHook> hooks;
  // Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp,
                                newResults);
  };
  hooks.push_back(vectorizeYield);

  // Register CustomVectorizationHook for indexOp.
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  // Register CustomVectorizationHook for extractOp.
  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // Vectorize all operations in the block.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
                                           Value input,
                                           SmallVector<OpFoldResult> destSizes,
                                           ArrayRef<int64_t> inputVectorSizes,
                                           bool useInBoundsInsteadOfMasking) {
  auto inputType = cast<VectorType>(input.getType());
  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                               inputType.getElementType());
  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  auto destShape = cast<ShapedType>(dest.getType()).getShape();
  SmallVector<bool> inBoundsVal(rank, true);
  if (useInBoundsInsteadOfMasking) {
    // Update the inBounds attribute.
    for (unsigned i = 0; i < rank; i++)
      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
                       !ShapedType::isDynamic(destShape[i]);
  }
  Operation *write = builder.create<vector::TransferWriteOp>(
      loc,
      /*vector=*/input,
      /*source=*/dest,
      /*indices=*/SmallVector<Value>(rank, zero),
      /*inBounds=*/inBoundsVal);
  assert(llvm::none_of(
             destShape.drop_front(inputVectorSizes.size()),
             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
         "Only dims aligned with inputVectorSizes may be dynamic");
  if (useInBoundsInsteadOfMasking)
    return write;
  bool needMaskForWrite = !llvm::equal(
      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
  if (needMaskForWrite) {
    SmallVector<int64_t> writeMaskShape;
    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
                          destShape.end());
    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
    Value maskForWrite =
        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
    write = mlir::vector::maskOperation(builder, write, maskForWrite);
  }
  return write;
}
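/// Vectorize tensor::PackOp with (1) static innerTiles, (2) constant padding
/// value and (3) input vector sizes.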
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  auto padValue = packOp.getPaddingValue();
  if (!padValue) {
    padValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
  }
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");

  // If the input vector sizes are not provided, then the vector sizes are
  // determined by the result tensor shape.
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }

  // Create masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  auto innerDimsPos = packOp.getInnerDimsPos();
  auto outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty())
    applyPermutationToVector(inputShape,
                             invertPermutationVector(outerDimsPerm));
  for (auto [idx, size] : enumerate(innerTiles))
    inputShape[innerDimsPos[idx]] *= size;
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking);

  // Create ShapeCastOp.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
  auto shapeCastOp =
      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

  // Create TransposeOp.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}
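/// Vectorize a tensor::UnPackOp to these 4 Ops: a TransferReadOp that reads a
/// vector from the source, a TransposeOp, a ShapeCastOp, and a
/// TransferWriteOp that writes the result.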
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          SmallVectorImpl<Value> &newResults) {
  // ...
  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;
  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();

  auto destSize = unpackOp.getDestRank();

  if (!inputVectorSizes.empty())
    assert(inputVectorSizes.size() == destSize &&
           "Incorrect number of input vector sizes");

  // vectorSizes is the shape of the vector that will be used to do final
  // write on the destination tensor. It is set like this: let's say the
  // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
  // ...
  SmallVector<int64_t> vectorSizes(inputVectorSizes);
  if (vectorSizes.empty()) {
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
    if (!outerDimsPerm.empty())
      applyPermutationToVector(vectorSizes, outerDimsPerm);
    for (auto [i, pos] : llvm::enumerate(innerDimPos))
      vectorSizes[pos] *= innerTiles[i];

    useInBoundsInsteadOfMasking = true;
  }

  // readVectorSizes is the size of tensor used to read and apply mask. It is
  // set like this: let's say the vectorSize (VS) array is size 'N' and the
  // sourceShape(SS) is 'M' where M >= N.
  // ...
  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());

  for (auto [index, size] : enumerate(innerTiles)) {
    readVectorSizes[innerDimPos[index]] =
        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(readVectorSizes, outerDimsPerm);
  }
  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());

  ReifiedRankedShapedTypeDims reifiedRetShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
          .reifyResultShapes(rewriter, reifiedRetShapes);
  if (status.failed()) {
    LDBG("Unable to reify result shapes of " << unpackOp);
    return failure();
  }
  Location loc = unpackOp->getLoc();

  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

  // Read result, mask if necessary. If transferReadOp shape is not equal to
  // the shape of source, then a mask is necessary.
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
  RankedTensorType stripMineTensorType =
      RankedTensorType::get(stripMineShape, stripMineElemType);
  // Transpose the appropriate rows to match output.
  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
      loc, readResult, lastDimToInsertPosPerm);

  // Collapse the vector to the size required by result.
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMineTensorType, packMetadata.reassociations);
  mlir::VectorType vecCollapsedType =
      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
      loc, vecCollapsedType, transposeOp->getResult(0));

  // writeVectorSizes had to match the shapecast shape for dynamic sizes.
  SmallVector<int64_t> writeVectorSizes(
      unpackOp.getDestType().hasStaticShape()
          ? vectorSizes
          : shapeCastOp.getResultVectorType().getShape());
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
      writeVectorSizes, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
  return success();
}
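/// Vectorize a padOp with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad, to a sequence of masked read and write ops.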
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(padOp);

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
      /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG("Vectorization of flattened convs with dynamic shapes is not "
         "supported\n");
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG("Dynamically-shaped op vectorization precondition failed: only "
         "channel dim can be dynamic\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}
static LogicalResult
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
      })) {
    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
  bool satisfyEmptyCond = inputVectorSizes.empty() &&
                          unpackOp.getDestType().hasStaticShape() &&
                          unpackOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
    return failure();

  return success();
}
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              bool vectorizeNDExtract,
                              bool flatten1DDepthwiseConv) {
  // A tensor with a dimension of 0 cannot be vectorized.
  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() &&
      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp,
                                                  flatten1DDepthwiseConv))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(
            customPreconditions,
            [&](const CustomVectorizationPrecondition &customPrecondition) {
              return succeeded(
                  customPrecondition(&innerOp, vectorizeNDExtract));
            })) {
      continue;
    }
    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // ...
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return success();

  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG("inner_tiles must be constant: " << packOp << "\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  if (llvm::any_of(padOp.getLow(), [](Value v) {
        std::optional<int64_t> res = getConstantIntValue(v);
        return !res.has_value() || res.value() != 0;
      })) {
    LDBG("low pad must all be zero: " << padOp << "\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  // Cond 1: There's been no need for scalable vectorisation of
  // non-linalg Ops so far.
  if (!linalgOp)
    return failure();

  // Cond 2: There's been no need for more than 2 scalable dims so far.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify
  // that it matches one of the supported cases.
  SmallVector<bool> scalableFlags(inputScalableVecDims.begin(),
                                  inputScalableVecDims.end());
  bool seenParalell = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  while (!scalableFlags.back()) {
    seenParalell |= (iterators.back() == utils::IteratorType::parallel);

    iterators.pop_back();
    scalableFlags.pop_back();
  }

  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // Check that the scalable reduction dim is the trailing dim.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG("Non-trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
      LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
           "is not supported\n");
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // Check that the scalable parallel dim is the inner-most parallel dim.
    if (seenParalell) {
      LDBG("Inner parallel dim not requested for scalable "
           "vectorization\n");
      return failure();
    }
    break;
  }
  }

  // If present, check the 2nd scalable dim.
  if (numOfScalableDims == 2) {
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG("Higher dim than the trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Cond 4: Only the following ops are supported in the presence of scalable
  // vectors.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::MatmulTransposeAOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<tensor::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<tensor::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
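/// Converts affine.apply Ops to arithmetic operations.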
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      ArrayRef<bool> inputScalableVecDims,
                                      bool vectorizeNDExtract,
                                      bool flatten1DDepthwiseConv) {
  LDBG("Attempting to vectorize:\n" << *op << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Input scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            // ...
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG("Unsupported convolution can't be vectorized.\n");
              return failure();
            }

            LDBG("Vectorize generic by broadcasting to the canonical vector "
                 "shape\n");

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<tensor::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<tensor::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG("Vectorization failed\n");
    return failure();
  }

  if (!results.empty())
    rewriter.replaceOp(op, results);
  else
    rewriter.eraseOp(op);

  return success();
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
    readValue =
        rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
  }
  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
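/// Helper function that retrieves the value of an IntegerAttr.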
static int64_t getIntFromAttr(Attribute attr) {
  return cast<IntegerAttr>(attr).getInt();
}
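/// Given an ArrayRef of OpFoldResults, return a vector of Values.
/// IntegerAttrs are converted to ConstantIndexOps.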
static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter,
                                           Location loc,
                                           ArrayRef<OpFoldResult> ofrs) {
  SmallVector<Value> result;
  for (auto o : ofrs) {
    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
      result.push_back(val);
    } else {
      result.push_back(rewriter.create<arith::ConstantIndexOp>(
          loc, getIntFromAttr(cast<Attribute>(o))));
    }
  }
  return result;
}
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
                                      tensor::PadOp padOp, Value dest) {
  auto sourceType = padOp.getSourceType();
  auto resultType = padOp.getResultType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // Copy cannot be vectorized if pad value is non-constant and the source
  // shape is dynamic.
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    if (!sourceType.hasStaticShape())
      return failure();
    // Create a dummy padding value.
    auto elemType = sourceType.getElementType();
    padValue = rewriter.create<arith::ConstantOp>(
        padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
  }

  SmallVector<int64_t> vecShape;
  SmallVector<bool> readInBounds;
  SmallVector<bool> writeInBounds;
  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
    if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
      // Source shape is statically known: read and write are in-bounds.
      readInBounds.push_back(true);
      writeInBounds.push_back(true);
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but result shape is. The
      // read may be out-of-bounds, so use the result size for the vector.
      vecShape.push_back(resultType.getDimSize(i));
      readInBounds.push_back(false);
      // Write is in-bounds iff low padding is zero.
      writeInBounds.push_back(
          getConstantIntValue(padOp.getMixedLowPad()[i]) ==
          static_cast<int64_t>(0));
    } else {
      // Neither source nor result dim of padOp is static.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // Generate TransferReadOp.
  SmallVector<Value> readIndices(
      vecType.getRank(),
      rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
  auto read = rewriter.create<vector::TransferReadOp>(
      padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
      ArrayRef<bool>{readInBounds});

  // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
  // tensor, write directly to the FillOp's operand.
  if (llvm::equal(vecShape, resultType.getShape()) &&
      llvm::all_of(writeInBounds, [](bool b) { return b; }))
    if (auto fill = dest.getDefiningOp<FillOp>())
      dest = fill.output();

  // Generate TransferWriteOp.
  auto writeIndices =
      ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
  rewriter.create<vector::TransferWriteOp>(padOp.getLoc(), read, dest,
                                           writeIndices,
                                           ArrayRef<bool>{writeInBounds});

  return success();
}
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in a vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
struct PadOpVectorizationWithTransferReadPattern
    : public VectorizePadOpUserPattern<vector::TransferReadOp> {
  using VectorizePadOpUserPattern<
      vector::TransferReadOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getSourceMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
};
struct PadOpVectorizationWithTransferWritePattern
    : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
  using VectorizePadOpUserPattern<
      vector::TransferWriteOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding =
        dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at position of the old TransferWriteOp.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }
  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both the CastOp
    // result and the CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same; the only supported case is when
    // `beforePadding` is an ExtractSliceOp.
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Only intested in dynamic dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same attribute or same value.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Case 2: Both values are identical AffineMinOps.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
      return false;
    }

    // All tests passed.
    return true;
  }
};
struct PadOpVectorizationWithInsertSlicePattern
    : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
  using VectorizePadOpUserPattern<
      tensor::InsertSliceOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result not used as destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if sizes match: insert the entire tensor into dest.
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Generate TransferReadOp: read the entire source tensor, possibly
    // reading the padding value.
    SmallVector<Value> readIndices(
        vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

    // Generate TransferWriteOp: write to InsertSliceOp's dest tensor.
    auto writeIndices = ofrToIndexValues(rewriter, padOp.getLoc(),
                                         insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
};
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
    }
  }
  return false;
}
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the
  // copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // ...
  auto vectorType = xferOp.getVectorType();
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the subview copied into that we replace.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // ...
  auto vector = xferOp.getVector();
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vector.getType().getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
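/// Bind a pack of integer references to the leading dimensions of
/// `shapedType.getShape()` (base case plus recursive step below).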
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
namespace {
bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter,
                                                           linalgOp),
        strideW(strideW), dilationW(dilationW) {
    // Determine whether `linalgOp` can be generated with this generator.
    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
      return;
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());
    if (!lhsShapedType || !rhsShapedType || !resShapedType)
      return;
    // (LHS has dimension NCW/NWC and RES has dimension NFW/NWF/CW/WC) OR
    // (non-channeled convolution -> LHS and RHS have a single dimension).
    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
      return;

    // ...
    if (!setOperKind(reduceOp))
      return;
    auto maybeKind = getCombinerOpKind(reduceOp);
    if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
      return;
    }

    auto rhsRank = rhsShapedType.getRank();
    // ...
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return;
    // ...
  }
  /// Generate a vector implementation for a 1-D convolution or pooling.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    // ...
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    // ...
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Initialize unused dimensions.
      nSize = fSize = cSize = 0;
      // ...
      lhsShape = {// iw = ow + kw - 1
                  (wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      // ...
      lhsShape = {nSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      // ...
      lhsShape = {nSize, cSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    // ...
    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType,
                                                        lhsShaped, lhsPadding);
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == Conv)
      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                    rhsPadding);
    Value res = rewriter.create<vector::TransferReadOp>(loc, resType,
                                                        resShaped, resPadding);

    // The base vectorization case for channeled convolution is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse that case, do a
    // pre-transpose on input, weight, and output.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, so no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
      // fcw -> wcf
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};

      // This is needed only for Conv.
      if (oper == Conv)
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
      break;
    }
    }

    // Unroll along kw and read slices of lhs, rhs and res.
    SmallVector<Value> lhsVals =
        extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                               kwSize, strideW, dilationW, wSizeStep,
                               isSingleChanneled);
    // ...
    SmallVector<Value> resVals = extractConvResultSlices(
        rewriter, loc, res, nSize, wSize, fSize, wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the contraction for each slice.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }
    // ...
    // The base vectorization case produces output {n,w,f}. Transpose back if
    // needed.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // nwf -> nfw
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
      break;
    }
    }

    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
  }
  /// Take `val` and widen it to have the same element type as `ty`.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt,
                                                  dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  /// Compute the conv1d slice as a contraction:
  ///   O{n, w, f} += I{n, w, c} * F{c, f}
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    return rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
  }

  /// Create an outer product: lhs{w} * rhs{1} -> res{w} for single-channel
  /// convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
  }
  /// Generate a vector implementation for a 1-D depthwise convolution.
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    // ...
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when the channel dim is dynamic and
      // the corresponding scalable flag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();

    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow - 1) * sw + (kw - 1) * dw + 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});

    // Mask the input xfer Op along the channel dim, iff the corresponding
    // scalable flag is set.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = rewriter.create<vector::TransferReadOp>(
        loc, rhsType, rhsShaped, ValueRange{zero, zero});
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

    // ...
    auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
    auto inOutStrides = SmallVector<int64_t>{1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    SmallVector<Value> lhsVals, rhsVals, resVals;
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, maybeMaskedRhs->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Note - the scalable flags are ignored as flattening combined with
    // scalable vectorization is not supported.
    auto inOutFlattenSliceSizes =
        SmallVector<int64_t>{nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel
          // dimension).
          lhsVal = rewriter.create<vector::ShapeCastOp>(
              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
          resVal = rewriter.create<vector::ShapeCastOp>(
              loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal,
                                                  flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dimension).
          resVals[w] = rewriter.create<vector::ShapeCastOp>(
              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
        }
      }
    }

    // It is possible we failed to create the Fma.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }
    // ...
    // Write back res slice: {n, w, c} @ [0, 0, 0].
    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
        loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower a depthwise conv1d slice to a multiply-accumulate.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    // TODO(suderman): Change this to use a vector.ima intrinsic.
    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: This following logic won't work for scalable vectors. For the
      // filter, opt for the option without shape_cast to simplify the
      // codegen: broadcast(shuffle(filter)).
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }
  /// Entry point for non-channeled convolution: {{w + kw}, {kw}, {w}}.
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point for NWC convolution.
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point for NCW convolution.
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point for NWC pooling.
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point for NCW pooling.
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point for dilated (depthwise) convolution.
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
private:
  enum OperKind { Conv, Pool };
  OperKind oper = Conv;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  // Returns true iff it is a valid conv/pooling op, and sets `oper`,
  // `isPoolExt` and `poolExtOp` accordingly.
  bool setOperKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    switch (numBlockArguments) {
    case 1: {
      // Will be convolution if the feeder is a MulOp; otherwise, if the
      // feeder is a cast of a block argument, it is a pooling with extension.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
                   llvm::all_of(feedOp->getOperands(), [](Value v) {
                     if (isa<BlockArgument>(v))
                       return true;
                     if (Operation *op = v.getDefiningOp())
                       return isCastOfBlockArgument(op);
                     return false;
                   }))) {
        return false;
      }
      return true;
    }
    case 2:
      // Must be pooling.
      oper = Pool;
      isPoolExt = false;
      return true;
    default:
      return false;
    }
  }
};
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  // The ConvolutionOpInterface gives us guarantees of existence for
  // strides/dilations, but we do not need to rely on those: use them if
  // present, otherwise use the default.
  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  Conv1DGenerator e(rewriter, op, stride, dilation);
  auto res = e.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcConv();
  if (succeeded(res))
    return res;
  res = e.generateNcwConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = e.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left - these can be vectorized using
  // masks and scalable vectors.
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                               flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};
static VectorShape vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: vector::TransferReadOp reads a vector from the source tensor, vector::TransposeOp transposes it, vector::ShapeCastOp reshapes the data to match the target layout, and vector::TransferWriteOp writes the result vector back to the destination tensor.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val)
Checks whether \p val can be used for calculating a loop invariant index.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles, (2) constant padding value, and (3) input vector sizes.
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< Value > ofrToIndexValues(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > ofrs)
Given an ArrayRef of OpFoldResults, return a vector of Values.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation, compress the unused dimensions to serve as a permutation_map for a vector transfer operation.
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorExtract.
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of dstType's shape if possible; return value unchanged otherwise.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor.extract.
static int64_t getIntFromAttr(Attribute attr)
Helper function that retrieves the value of an IntegerAttr.
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value, and (3) all-zero lowPad, into an in-bounds vector.transfer_write of a masked vector.transfer_read of the pad source (with the padding value filling masked-off lanes).
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
AffineMap dropResults(ArrayRef< int64_t > positions) const
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
ArrayRef< AffineExpr > getResults() const
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter: a result dimension is kept when the filter returns true for it, and dropped otherwise.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e. dims indicated by value 0 in the result).
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callback.
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
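A minimal sketch of this constructor (the shape and values are made up): build a dense <4xi64> attribute from a plain array.

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static DenseIntElementsAttr makeDenseI64(MLIRContext *ctx) {
  // A static 4-element i64 vector type holding the constants.
  auto type = VectorType::get({4}, IntegerType::get(ctx, 64));
  return DenseIntElementsAttr::get(type, ArrayRef<int64_t>{0, 1, 2, 3});
}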
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
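A minimal sketch of the single-result form (operands hypothetical): the op is built at the insertion point and immediately folded when possible, so e.g. adding a known zero simply returns the other operand.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Value addValues(OpBuilder &b, Location loc, Value lhs, Value rhs) {
  // Builds arith.addi and immediately tries to fold it; if the fold
  // succeeds, no new operation is inserted.
  return b.createOrFold<arith::AddIOp>(loc, lhs, rhs);
}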
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
AttrClass getAttrOfType(StringAttr name)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most benefit).
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit; otherwise, return the benefit of an impossible match.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replaces).
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors, and systematize conversion between these forms.
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to the provided dimension and symbol values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next integer.
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
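A minimal usage sketch, not taken from this file: gate the call on vectorizeOpPrecondition, then vectorize with fixed vector sizes. The wrapper name and the concrete sizes (an 8x16 non-scalable tile for an assumed 2-D iteration space) are hypothetical.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static LogicalResult vectorizeIfPossible(RewriterBase &rewriter,
                                         linalg::LinalgOp linalgOp) {
  // One static size per iteration-space dimension; none of them scalable.
  SmallVector<int64_t> vectorSizes = {8, 16};
  SmallVector<bool> scalableDims = {false, false};
  if (failed(linalg::vectorizeOpPrecondition(linalgOp, vectorSizes,
                                             scalableDims)))
    return failure();
  return linalg::vectorize(rewriter, linalgOp, vectorSizes, scalableDims);
}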
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper function computePackUnPackPerm to get the permutation vector.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape, i.e. the two arrays have the same number of elements, inputVectorSizes has no dynamic dimensions, and each entry is greater than or equal to the corresponding static size in shape.
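A minimal sketch of the check (the shapes are made up): {8, 16} has matching rank, is fully static, and covers the static size 16, so it is a valid masking configuration for a ?x16 shape.

#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

static LogicalResult checkMaskingConfig() {
  int64_t shape[] = {ShapedType::kDynamic, 16};
  int64_t inputVectorSizes[] = {8, 16};
  return vector::isValidMaskedInputVector(shape, inputVectorSizes);
}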
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
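A minimal sketch (helper name and operands hypothetical): build an i1 mask matching the vector shape, then wrap a maskable op, e.g. a vector.transfer_read, in vector.mask.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Operation *maskWith(OpBuilder &b, Location loc, Operation *maskableOp,
                           VectorType vecType, ValueRange upperBounds) {
  // Lanes at positions >= upperBounds[d] in dimension d are disabled.
  auto maskType = VectorType::get(vecType.getShape(), b.getI1Type(),
                                  vecType.getScalableDims());
  Value mask = b.create<vector::CreateMaskOp>(loc, maskType, upperBounds);
  // Produces: vector.mask %mask { <maskableOp> }.
  return vector::maskOperation(b, maskableOp, mask);
}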
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
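A minimal sketch (operands hypothetical): read an 8x16 vector from source; when the source shape cannot guarantee in-bounds access, the read is masked and padValue fills the masked-off lanes.

#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Value readTile(OpBuilder &b, Location loc, Value source,
                      Value padValue) {
  int64_t readShape[] = {8, 16};
  return vector::createReadOrMaskedRead(b, loc, source, readShape, padValue,
                                        /*useInBoundsInsteadOfMasking=*/false);
}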
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and destination, respectively).
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
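A minimal sketch: an OpFoldResult carries either an Attribute or a Value, and getConstantIntValue extracts the integer only when it is statically known. The builder and the dynamicIndex operand are hypothetical.

#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include <optional>

using namespace mlir;

static void constantIntDemo(Builder &b, Value dynamicIndex) {
  // Attribute case: the constant is recovered directly.
  std::optional<int64_t> c = getConstantIntValue(b.getIndexAttr(4)); // c == 4
  // Value case: set only if the value is defined by a constant op.
  std::optional<int64_t> d = getConstantIntValue(OpFoldResult(dynamicIndex));
  (void)c;
  (void)d;
}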
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into 0s, e.g. (d0, d1, d2) -> (d2, d0) gives (d0, d1) -> (d1, 0, d0).
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. sizeof...(exprs)].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
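A minimal sketch tying these AffineMap helpers together (the shape values are made up): build a permutation map, push a shape through it, rebuild it with bindDims, and invert it.

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>

using namespace mlir;

static void permutationMapDemo(MLIRContext *ctx) {
  // (d0, d1, d2) -> (d1, d2, d0), built from a permutation vector.
  SmallVector<unsigned> perm = {1, 2, 0};
  AffineMap map = AffineMap::getPermutationMap(perm, ctx);
  // Pushing a shape through the map permutes it: {4, 8, 16} -> {8, 16, 4}.
  SmallVector<int64_t> shape = {4, 8, 16};
  SmallVector<int64_t> permuted =
      applyPermutationMap<int64_t>(map, ArrayRef<int64_t>(shape));
  // The same map written with explicit dim exprs via bindDims.
  AffineExpr d0, d1, d2;
  bindDims(ctx, d0, d1, d2);
  assert(map == AffineMap::get(3, 0, {d1, d2, d0}, ctx));
  // inversePermutation undoes the map: composing the two is the identity.
  assert(inversePermutation(map).compose(map).isIdentity());
  (void)permuted;
}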
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
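A minimal sketch of these two vector-permutation utilities (the values are made up):

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static void permutationVectorDemo() {
  SmallVector<int64_t> perm = {2, 0, 1};
  SmallVector<int64_t> shape = {4, 8, 16};
  // result[i] = shape[perm[i]], so {4, 8, 16} becomes {16, 4, 8}.
  applyPermutationToVector(shape, perm);
  // The inverse permutation maps each position back: {2, 0, 1} -> {1, 2, 0}.
  SmallVector<int64_t> inv = invertPermutationVector(perm);
  (void)inv;
}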
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter, tensor::PadOp padOp, Value dest)
Vectorize the copying of a tensor::PadOp's source.
GenericPadOpVectorizationPattern(MLIRContext *context, PatternBenefit benefit=1)
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given operation.
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the corresponding fixed/scalable dimensions bit; if dimPermutation is provided, the canonical vector dimensions are permuted accordingly.
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vectorization.
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.