#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#include <type_traits>

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
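// Note: the LDBG(...) lines throughout this file are only emitted in debug
// builds when running with `-debug-only=linalg-vectorization`; each message
// is prefixed with "[linalg-vectorization] ".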
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
static SmallVector<Value> extractConvInputSlices(
    RewriterBase &rewriter, Location loc, Value input, int64_t nSize,
    int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW,
    int64_t wSizeStep, bool isSingleChanneled) {
  if (isSingleChanneled) {
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(
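// A small worked example of the unrolling above: with kwSize = 3 and
// wSize == wSizeStep (the stride-1 case), the loops extract one input slice
// per (kw, w) pair, i.e. kwSize * (wSize / wSizeStep) = 3 input slices, plus
// one filter slice per kw; the slices are later combined pairwise.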
    int64_t nSize, int64_t wSize, int64_t fSize,
    int64_t wSizeStep, bool isSingleChanneled) {
  if (isSingleChanneled) {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
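// The extract*/insertConvResultSlices helpers are symmetric: the computation
// is unrolled over (kw, w) into slices, and the per-w partial results are
// re-assembled into the full result vector with insert_strided_slice before
// the final vector.transfer_write.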
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,

  VectorType getCanonicalVecType(
      Type elemType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }

  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  bool isValidMaskingMap(AffineMap maskingMap) {
LogicalResult VectorizationState::precomputeIterSpaceValueSizes(
    RewriterBase &rewriter, LinalgOp linalgOp) {
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  if (!inputVectorSizes.empty()) {
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }
  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  initIterSpaceStaticSizes(linalgOp);

  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();
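  // Illustrative example (not taken from the code above): for a linalg op
  // with static loop ranges [8, 16] and no user-provided sizes, the canonical
  // vector shape becomes [8, 16]; if the user passes inputVectorSizes = [4, 8]
  // with inputScalableVecDims = [false, true], the canonical vector type is
  // vector<4x[8]x...>.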
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {
  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}
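// For instance (illustrative only), masking a 2-D iteration space whose
// trailing size is dynamic (%d) could produce IR along the lines of:
//   %mask = vector.create_mask %c4, %d : vector<4x8xi1>
// The mask is cached per masking map, so subsequent ops with the same map
// reuse the same mask value.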
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}
538 "expected projected permutation");
540 assert(res.getNumDims() ==
541 (res.getNumResults() - res.getNumOfZeroResults()) &&
542 "expected reindexed map with same number of dims and results");
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
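// Example: for a reduction whose combiner is `arith.addf`, this returns
// CombiningKind::ADD; an unrecognized combiner yields std::nullopt, which
// callers treat as an unsupported reduction.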
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());

  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  return combinerOps[0];
  auto dstVecType = dyn_cast<VectorType>(dstType);
  if (dstVecType.getRank() == 0)
    return value;
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
  return llvm::to_vector(

static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

    return llvm::is_contained(opOperandMap.getResults(), dimExpr);
  auto vectorType = state.getCanonicalVecType(

  Operation *write;
  if (vectorType.getRank() > 0) {
    SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                               rewriter.create<arith::ConstantIndexOp>(loc, 0));
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());

  LDBG("vectorized op: " << *write << "\n");
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);

                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);

  auto loc = indexOp.getLoc();

  auto dim = indexOp.getDim();

  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);

  if (dim == targetShape.size() - 1)

  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);

  auto transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)

  if (!extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      }))
                                   tensor::ExtractOp extractOp,

  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
                              VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
                                 bool &foundIndexOp, VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);

  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

    if (inputShape.getShape()[i] == 1)
      continue;

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }
  auto extractOpTrailingIdx = indices.back();

  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG("Found scalar broadcast load: " << extractOp);
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);

  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);

  auto loc = extractOp.getLoc();

  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(

  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  Operation *gatherOp = rewriter.create<vector::GatherOp>(
      loc, resultType, extractOp.getTensor(), baseIndices, offset,
      maskConstantOp, passThruConstantOp);
  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

  LDBG("Vectorised as gather load: " << extractOp << "\n");
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
  }
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs,
      permutationMap, inBounds);
    auto allTrue = rewriter.create<vector::ConstantMaskOp>(

    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
  int32_t rankDiff = dstRank - srcRank;

  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  LDBG("vectorize op " << *op << "\n");

  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
  if (isa<arith::ConstantOp, func::ConstantOp>(op))

    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;

        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);

      reductionOperands.push_back(std::make_pair(reduceValue, operand));
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }

    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     getElementTypeOrSelf(vecOperand.getType()),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    VectorType readType;
    if (linalgOp.isDpsInput(opOperand)) {
      readType = state.getCanonicalVecType(elemType);
    } else {
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
          cast<vector::TransferReadOp>(maskOp.getMaskableOp())

    if (readType.getRank() == 0)
  hooks.push_back(vectorizeYield);

  hooks.push_back(vectorizeIndex);

  hooks.push_back(vectorizeExtract);
      LDBG("failed to vectorize: " << op << "\n");
      return failure();

    Operation *maybeMaskedOp =
        state.maskOperation(rewriter, result.newOp, linalgOp);
    LDBG("New vector op: " << *maybeMaskedOp << "\n");
                                     bool useInBoundsInsteadOfMasking = false) {

  ShapedType destType = cast<ShapedType>(dest.getType());
  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
             static_cast<int64_t>(destType.getRank()) &&

  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
  auto destShape = cast<ShapedType>(dest.getType()).getShape();
  if (useInBoundsInsteadOfMasking) {
    assert(inputVecSizesForLeadingDims.size() ==
               static_cast<size_t>(destType.getRank()) &&
           "Insufficient number of input vector sizes!");
    for (unsigned i = 0; i < rank; i++)
      inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                       !ShapedType::isDynamic(destShape[i]);
  }

  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  assert(llvm::none_of(
             destShape.drop_front(inputVecSizesForLeadingDims.size()),
             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");

  if (useInBoundsInsteadOfMasking)
    return write;

  bool needMaskForWrite =
      !llvm::equal(inputVecSizesForLeadingDims,
                   destShape.take_front(inputVecSizesForLeadingDims.size()));
  if (needMaskForWrite) {
    SmallVector<int64_t> writeMaskShape;
    writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                          inputVecSizesForLeadingDims.end());
    writeMaskShape.append(destShape.begin() +
                              inputVecSizesForLeadingDims.size(),
                          destShape.end());

    Value maskForWrite = builder.create<vector::CreateMaskOp>(
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  auto padValue = packOp.getPaddingValue();
  if (!padValue) {
    padValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
  }

  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // Ignore in RELEASE mode.
  assert(succeeded(status) && "failed to reify result shapes");
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }
  auto innerTiles = packOp.getStaticInnerTiles();

      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking);

      packOp.getDestType().getElementType());
  auto shapeCastOp =
      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp.getResult(), destPermutation);

  Value dest = rewriter.create<tensor::EmptyOp>(
      loc, reifiedReturnShapes[0],
      transposeOp.getResult().getType().getElementType());

  newResults.push_back(write->getResult(0));
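// The pack lowering above is thus a four-step pipeline: a (masked)
// vector.transfer_read of the source, a vector.shape_cast into the tiled
// shape, a vector.transpose into the destination layout, and a
// vector.transfer_write into an empty destination tensor.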
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          SmallVectorImpl<Value> &newResults) {
  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  bool useInBoundsInsteadOfMasking = false;

  auto destSize = unpackOp.getDestRank();

  if (!inputVectorSizes.empty())
    assert(inputVectorSizes.size() == destSize &&
           "Incorrect number of input vector sizes");
  if (vectorSizes.empty()) {
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));

    useInBoundsInsteadOfMasking = true;
  }
      readVectorSizes[innerDimPos[index]] =

  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());
  ReifiedRankedShapedTypeDims reifiedRetShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
          .reifyResultShapes(rewriter, reifiedRetShapes);
  if (status.failed()) {
    LDBG("Unable to reify result shapes of " << unpackOp);
    return failure();
  }
  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
  PackingMetadata packMetadata;

  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());

  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();

  RankedTensorType stripMineTensorType =
      RankedTensorType::get(stripMineShape, stripMineElemType);

  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
      loc, readResult, lastDimToInsertPosPerm);

  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMineTensorType, packMetadata.reassociations);
  mlir::VectorType vecCollapsedType =
      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
      loc, vecCollapsedType, transposeOp->getResult(0));

  ArrayRef<int64_t> writeVectorSizes(
      unpackOp.getDestType().hasStaticShape()
          ? vectorSizes
          : shapeCastOp.getResultVectorType().getShape());
  Value dest = rewriter.create<tensor::EmptyOp>(
      loc, reifiedRetShapes[0],
      shapeCastOp.getResult().getType().getElementType());

      useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
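// Mirroring the pack case, unpack is lowered to four ops: a masked
// vector.transfer_read, a vector.transpose, a vector.shape_cast that
// collapses the tiled dims, and a vector.transfer_write into the reified
// destination tensor.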
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();

  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // Ignore in RELEASE mode.
  assert(succeeded(status) && "failed to reify result shapes");

      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,

      loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());

  newResults.push_back(write->getResult(0));
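// Design note: tensor.pad is vectorized as a masked/out-of-bounds
// vector.transfer_read of the un-padded source whose padding operand supplies
// the pad value, written into an empty tensor of the reified padded shape, so
// no explicit pad loop is materialized.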
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);

      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG("Vectorization of flattened convs with dynamic shapes is not "
         "supported\n");
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
    return failure();
  }

  Value lhs = conv.getDpsInputOperand(0)->get();

  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG("Dynamically-shaped op vectorization precondition failed: only "
         "channel dim can be dynamic\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (!isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
      })) {
    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
    return failure();
  }

  bool satisfyEmptyCond = inputVectorSizes.empty() &&
                          unpackOp.getDestType().hasStaticShape() &&
                          unpackOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          unpackOp.getDestType().getShape(), inputVectorSizes)))
    return failure();

  return success();
}
static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {
  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG("Failed to get a pad value for out-of-bounds read access\n");
    return failure();
  }
enum class ConvOperationKind { Conv, Pool };
static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
  case 1: {
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
    assert(feedValIt != reduceOp->operand_end() &&
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp))
      return ConvOperationKind::Pool;

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            feedOp->getResultTypes()[0].isInteger(1))) &&
          llvm::all_of(feedOp->getOperands(), [](Value v) {
            if (isa<BlockArgument>(v))
              return true;
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
            return false;
          })))
      return std::nullopt;

    return ConvOperationKind::Conv;
  }
  case 2:
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));

  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
    return failure();

  if (!maybeOper.has_value())
    return failure();

  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {
    return failure();
  }

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1)
      return failure();
  } else {
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return failure();
  }
                                     bool vectorizeNDExtract,
                                     bool flatten1DDepthwiseConv) {
  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
    return failure();

  if (!inputVectorSizes.empty() &&

          linalgOp, flatten1DDepthwiseConv))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

            customPreconditions,
            [&](const CustomVectorizationPrecondition &customPrecondition) {
              return succeeded(
                  customPrecondition(&innerOp, vectorizeNDExtract));
            });

    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        }))
      return WalkResult::interrupt();

    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        }))
      return WalkResult::interrupt();

  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return success();

  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG("inner_tiles must be constant: " << packOp << "\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }

  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
        Value padValue = en.value();
        unsigned pos = en.index();
        std::optional<int64_t> pad = getConstantIntValue(padValue);
        return (!pad.has_value() || pad.value() != 0) &&
               resultTensorShape[pos] != 1;
      })) {
    LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  if (numOfScalableDims > 2)
    return failure();
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }
  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG("Non-trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
      LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
           "is not supported\n");
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    if (seenNonUnitParallel) {
      LDBG("Inner parallel dim not requested for scalable "
           "vectorization\n");
      return failure();
    }
    break;
  }
  }
  if (numOfScalableDims == 2) {
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG("Higher dim than the trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  if (linalgOp.hasUserDefinedMaps())
    return failure();

  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::MatmulTransposeAOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
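// Illustrative example: for a linalg.matmul with iterators (par, par, red),
// requesting scalability on the outermost dim, e.g. sizes [[8], 16, 4], trips
// the seenNonUnitParallel check above and is rejected, while making the
// trailing parallel dim scalable instead ([8, [16], 4]) is accepted.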
static LogicalResult
vectorizeOpPrecondition(Operation *op, ArrayRef<int64_t> inputVectorSizes,
                        ArrayRef<bool> inputScalableVecDims,
                        bool vectorizeNDExtract,
                        bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<linalg::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<linalg::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
                               bool vectorizeNDExtract,
                               bool flatten1DDepthwiseConv) {
  LDBG("Attempting to vectorize:\n" << *op << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Input scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
  }
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG("Unsupported convolution can't be vectorized.\n");
              return failure();
            }

            LDBG("Vectorize generic by broadcasting to the canonical vector "
                 "shape\n");
            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<linalg::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<linalg::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
          })
          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp,
                                            inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });
  if (failed(vectorizeResult)) {
    LDBG("Vectorization failed\n");
    return failure();
  }

  if (!results.empty())
    rewriter.replaceOp(op, results);
  else
    rewriter.eraseOp(op);

  return success();
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {

  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {

    bool changed = false;
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();

  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    if (!padOp.hasZeroLowPad())
      return failure();
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    xferOp->setAttr(xferOp.getInBoundsAttrName(),
                    rewriter.getBoolArrayAttr(inBounds));
    xferOp.getBaseMutable().assign(padOp.getSource());
    xferOp.getPaddingMutable().assign(padValue);
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    if (xferOp.getTransferRank() == 0)
      return failure();

    if (!padOp.hasZeroLowPad())
      return failure();
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding.hasZeroOffset())
      return failure();
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    if (!t1 || !t2)
      return false;

    if (t1.getRank() != t2.getRank())
      return false;

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    if (t1.getNumDynamicDims() == 0)
      return true;
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};
    return source;
  }

  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {

  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = rewriter.create<arith::ConstantOp>(
        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
  }
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
      readInBounds.push_back(false);
      writeInBounds.push_back(false);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
      readInBounds.push_back(true);
      writeInBounds.push_back(true);
    } else if (!resultType.isDynamicDim(i)) {
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
      readInBounds.push_back(false);
      writeInBounds.push_back(false);
    }
  }

  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
  if (!inputVectorSizes.empty()) {
    auto *srcDefOp = source.getDefiningOp();
    if (!srcDefOp) {
      LDBG("Unable to get the defining Op of " << sliceOp);
      return failure();
    }

    LogicalResult status =
        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
            rewriter, reifiedSrcSizes);
    if (status.failed()) {
      LDBG("Unable to reify result shapes of " << srcDefOp);
      return failure();
    }

    maskOp = rewriter.create<vector::CreateMaskOp>(
        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
  }
  SmallVector<Value> readIndices(
      vecType.getRank(),
      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
  Operation *read = rewriter.create<vector::TransferReadOp>(
      sliceOp.getLoc(), vecType, source, readIndices, padValue,

  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
  Operation *write = rewriter.create<vector::TransferWriteOp>(
      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,

  newResults.push_back(write->getResult(0));
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    if (!padOp.hasZeroLowPad())
      return failure();
    if (!insertOp.hasUnitStride())
      return failure();
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();
    SmallVector<Value> readIndices(
        vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

    auto writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;

      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  Value viewOrAlloc = xferOp.getBase();

  Value subView = subViewOp.getResult();
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;

      maybeFillOp = newFillOp;
    }
  }

  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");
  Value in = copyOp.getSource();

  auto vectorType = xferOp.getVectorType();
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  Value viewOrAlloc = xferOp.getBase();

  Value subView = subViewOp.getResult();

  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;

  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  auto vector = xferOp.getVector();
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(SmallVector<bool>(
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));
template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    setConvOperationKind(reduceOp);

    reductionKind = maybeKind.value();

    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
    int64_t nSize, wSize, cSize, kwSize, fSize;

    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      nSize = fSize = cSize = 0;

      lhsShape = {(wSize + kwSize - 1)};
      rhsShape = {kwSize};
      break;
    case Conv1DOpOrder::Nwc:
      switch (oper) {
      case ConvOperationKind::Conv:
      case ConvOperationKind::Pool:
      }
      lhsShape = {nSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      switch (oper) {
      case ConvOperationKind::Conv:
      case ConvOperationKind::Pool:
      }
      lhsShape = {nSize, cSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }
    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,

    Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,

    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
    switch (conv1DOpOrder) {

      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);

      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};

      if (oper == ConvOperationKind::Conv)
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);

      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
                                kwSize, strideW, dilationW, wSizeStep,
                                isSingleChanneled);
    if (oper == ConvOperationKind::Conv)
                                 wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };
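    // linearIndex flattens the (kw, w) unroll space. E.g. with
    // wSize == wSizeStep (stride 1) there is a single w step, so
    // linearIndex(kw, 0) == kw and lhsVals holds one slice per filter tap.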
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case ConvOperationKind::Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case ConvOperationKind::Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }
    switch (conv1DOpOrder) {

      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
    }

    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

    assert(false && "unhandled promotion case");
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;

    auto contractionOp = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        MapList{{n, w, c}, {c, f}, {n, w, f}},
        ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
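  // Design note: the single-channel slice is combined with a rank-1
  // vector.outerproduct (ADD combining kind), while the channeled case uses a
  // vector.contract whose combining kind is taken from the matched reduction.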
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;

    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;

      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");
    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();

    VectorType lhsType = VectorType::get(
        {nSize,
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, {false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        {false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        {false, false, scalableChDim});
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                        ValueRange{zero, zero});
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, maybeMaskedLhs->getResult(0),
            inOutSliceSizes, inOutStrides));
      }
    }

    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, maybeMaskedRhs->getResult(0),
    }

    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, maybeMaskedRes->getResult(0),
    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    auto lhsTypeAfterFlattening =
    auto resTypeAfterFlattening =
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          lhsVal = rewriter.create<vector::ShapeCastOp>(
              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
          resVal = rewriter.create<vector::ShapeCastOp>(
              loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          resVals[w] = rewriter.create<vector::ShapeCastOp>(
        }
      }
    }

    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], maybeMaskedRes->getResult(0),
    }

    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
        loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);
    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  FailureOr<Operation *> generateNonChanneledConv() {
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    if (layout({{w + kw},
  FailureOr<Operation *> generateNwcConv() {
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    if (layout({{n, strideW * w + dilationW * kw, c},
  FailureOr<Operation *> generateNcwConv() {
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({{n, c, strideW * w + dilationW * kw},
  FailureOr<Operation *> generateNwcPooling() {
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({{n, strideW * w + dilationW * kw, c},
  FailureOr<Operation *> generateNcwPooling() {
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({{n, c, strideW * w + dilationW * kw},
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    if (layout({{n, strideW * w + dilationW * kw, c},
                {kw, c},
                {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
      } else {
        oper = ConvOperationKind::Conv;
      }
    } else {
      oper = ConvOperationKind::Pool;
    }
  }
                     ArrayRef<bool> inputScalableVecDims,
                     bool flatten1DDepthwiseConv) {
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
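// The generators above are tried from the most specific layout to the most
// general; the dilated depthwise case runs last and is the only one that
// honours the requested channel-dim vector size and scalability flag.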
  if (failed(resultOrFail))
    return failure();

  rewriter.eraseOp(op.getOperation());

  assert(newOp->getNumResults() == 1 && "expected single result");
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp)
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static bool isSupportedPoolKind(vector::CombiningKind kind)
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(linalg::PackOp packOp, ArrayRef< int64_t > destShape)
Given a linalg::PackOp, return the dest shape before any packing permutations.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore, Value dest, ArrayRef< int64_t > inputVecSizesForLeadingDims, bool useInBoundsInsteadOfMasking=false)
Creates an optionally masked TransferWriteOp.
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_iterator operand_end()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.