#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-vectorization"
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
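// Forward declaration only; the definition appears further down in this
// file, after the Conv1DGenerator helper that does the actual work.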
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}
/// Helper function to extract the input slices after filter is unrolled along
/// kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
            strides));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
    // for channeled convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            sizes, strides));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after filter is unrolled along
/// kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
  // non-channeled convolution] @ [kw].
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(vector::ExtractOp::create(
        rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
  }
  return result;
}
/// Helper function to extract the result slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
          strides));
    }
  } else {
    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
          strides));
    }
  }
  return result;
}
/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
          strides);
    }
  } else {
    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
    }
  }
  return res;
}
  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          bool assumeDynamicDimsMatchVecSizes = false);
  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimensions bit. If
  /// `dimPermutation` is provided, the canonical vector dimensions are
  /// permuted accordingly.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }
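  // For example, with canonicalVecShape = [4, 8] and a permutation map
  // (d0, d1) -> (d1, d0), getCanonicalVecType(f32, permMap) yields
  // vector<8x4xf32>; without a permutation it yields vector<4x8xf32>.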
  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking. Returns the masked operation or the original operation if
  /// masking is not needed.
  Operation *maskOperation(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeIndexingMap =
                               std::nullopt);
  /// Initializes the iteration space static sizes using the Linalg op
  /// information.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for all
  /// the static and dynamic dimensions of the iteration space.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Creates or retrieves an existing mask for the given masking map.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Checks whether this permutation map can be used for masking. At the
  /// moment we only make sure that there are no broadcast dimensions.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.isProjectedPermutation();
  }

  // ...

  bool assumeDynamicDimsMatchVecSizes = false;
LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
          rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim =
        linalgOp.hasPureTensorSemantics()
            ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(),
                                           operand, operandDimPos)
            : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(),
                                           operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims,
                              bool assumeDimsMatchVec) {
  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    // This path should be taken to vectorize code with dynamic shapes and
    // when using vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there
    // are dynamic shapes, the operation won't be vectorized.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for all the
  // static and dynamic dimensions of the iteration space, needed to compute
  // vector masks.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
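// Note: initState fails (and with it the whole vectorization attempt) when
// the canonical vector shape still contains dynamic sizes, i.e. when the
// caller did not provide explicit vector sizes for an op with dynamic loop
// ranges.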
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG() << "Masking map: " << maskingMap;

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG() << "Reusing mask: " << mask;
    return mask;
  }

  // Compute permuted projection of the iteration space to be masked and the
  // corresponding mask shape. If the resulting iteration space dimensions are
  // static and identical to the mask shape, masking is not needed for this
  // operation.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);

  if (permutedStaticSizes == maskShape) {
    LDBG() << "Masking is not needed for masking map: " << maskingMap;
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  if (assumeDynamicDimsMatchVecSizes) {
    // While for _dynamic_ dim sizes we can _assume_ that the corresponding
    // vector sizes match, we still need to check the _static_ dim sizes.
    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
                     [](auto it) {
                       return std::get<0>(it) == ShapedType::kDynamic
                                  ? true
                                  : std::get<0>(it) == std::get<1>(it);
                     })) {
      LDBG()
          << "Dynamic + static dimensions match vector sizes, masking is not "
             "required.";
      activeMaskCache[maskingMap] = Value();
      return Value();
    }
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension upper bounds.
  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
                                            maskType, upperBounds);
  LDBG() << "Creating new mask: " << mask;
  activeMaskCache[maskingMap] = mask;
  return mask;
}
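// Note: a null Value stored in `activeMaskCache` is meaningful - it records
// that masking was already analysed for this masking map and found to be
// unnecessary, so later queries for the same map return early without
// repeating the analysis.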
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG() << "Trying to mask: " << *opToMask;

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG() << "No mask required";
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG() << "Masked operation: " << *maskOp;
  return maskOp;
}
563 "expected projected permutation");
565 assert(res.getNumDims() ==
566 (res.getNumResults() - res.getNumOfZeroResults()) &&
567 "expected reindexed map with same number of dims and results");
/// Return the combining kind of `combinerOp` if it is supported, std::nullopt
/// otherwise.
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
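// For example, a floating-point reduction body such as
//   %0 = arith.addf %in, %acc : f32
// maps to CombiningKind::ADD, which later parameterizes the generated
// vector.multi_reduction / vector.contract operations.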
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, or nullptr.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
/// Broadcast `value` to a vector of `dstType` if it is not already a vector
/// of that shape.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
/// Create a MultiDimReductionOp reducing `valueToReduce` into `acc` along the
/// masked dimensions, using the combining kind inferred from `reduceOp`.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return vector::MultiDimReductionOp::create(
      b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store, keeping only the iteration
  // dims that actually appear in the output map.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        arith::ConstantIndexOp::create(rewriter, loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(
        rewriter, loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(rewriter, loc, value,
                                            outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true. Masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG() << "vectorized op: " << *write;
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
/// Custom vectorization precondition function type: given an operation and
/// the `vectorizeNDExtract` flag, return success when the op can be
/// vectorized.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }
/// Helper function to vectorize the index operations of a `linalgOp`.
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  auto dim = indexOp.getDim();
  auto indexVectorType =
      VectorType::get({state.getCanonicalVecShape()[dim]},
                      rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space, since the vectorization algorithm in
  // this case can handle the broadcast.
  auto targetShape = state.getCanonicalVecShape();
  if (dim == targetShape.size() - 1)
    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = vector::BroadcastOp::create(
      rewriter, loc,
      state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type, but only for non 0-d tensors.
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (!llvm::all_of(extractOp->getResultTypes(),
                    VectorType::isValidElementType)) {
    return failure();
  }

  return success();
}
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///   offset = extractOp.indices[0]
///   for (i = 1; i < numIndices; i++)
///     offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
  }

  return offset;
}
/// Find the index of the trailing non-unit dim in linalgOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}
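// For example, for a linalg op with static loop ranges [1, 1, 8, 1] this
// returns 2 - the index of the only non-unit loop dimension.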
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its corresponding loop dim is unit.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp` that are constant are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}
/// Checks whether `val` could be used for calculating the trailing index for
/// a contiguous load operation, i.e. whether it increments by 1 along the
/// trailing non-unit loop dimension.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
/// Infer the memory access pattern for the input ExtractOp.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`,
  // false otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`: if they are all loop
  // invariant, this may still be a contiguous load.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG() << "Found gather load: " << extractOp;
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index of `extractOp`.
  auto extractOpTrailingIdx = indices.back();

  // 3a. A loop-invariant trailing index means a scalar broadcast load.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG() << "Found scalar broadcast load: " << extractOp;
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. A trailing index that increments by 1 means a contiguous load
  // (column vectors, i.e. trailing dim == 1, are not yet supported).
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG() << "Found contiguous load: " << extractOp;
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG() << "Found gather load: " << extractOp;
  return VectorMemoryAccessKind::Gather;
}
/// Helper function to vectorize the tensor.extract operations.
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the canonical vector type and the all-true mask / zero
  // pass-through values for a gather.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = arith::ConstantOp::create(
      rewriter, loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp = arith::ConstantOp::create(
      rewriter, loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      arith::ConstantIndexOp::create(rewriter, loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = vector::GatherOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG() << "Vectorised as gather load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
  }

  // 2. Handle contiguous and scalar-broadcast access: compute the scalar
  // indices for the underlying transfer_read.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = vector::TransferReadOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
        /*padding=*/std::nullopt, permutationMap, inBounds);

    // Mask the read here rather than on the generic path (the generic path
    // assumes an identity masking map, which wouldn't be valid here).
    // ... compute the mask type for the broadcast read ...
    auto allTrue = vector::ConstantMaskOp::create(
        rewriter, loc, maskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   maskedReadOp};
  }

  // 2b. Handle contiguous access. Broadcast the leading dims when the result
  // rank exceeds the source rank.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = vector::TransferReadOp::create(
      rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
      /*padding=*/std::nullopt, permutationMap, inBounds);

  LDBG() << "Vectorised as contiguous load: " << extractOp;
  return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                 transferReadOp};
/// Emit reduction operations if the shape of the value to reduce differs from
/// the result shape.
static Operation *reduceIfNeeded(RewriterBase &rewriter, LinalgOp linalgOp,
                                 Operation *op, Value reduceValue,
                                 Value initialValue, const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;

  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(rewriter, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already
/// vectorized operands carried by `bvm`.
  LDBG() << "vectorize op " << *op;

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationHookResult result = customFunc(op, bvm);
      if (result.status == VectorizationHookStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   rewriter.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each operand to the max ranked shape if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(
          firstMaxRankedType.getShape(),
          getElementTypeOrSelf(vecOperand.getType()),
          firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. The result is a vector with the max ranked shape.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
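// vectorizeOneOp is deliberately conservative: anything that is neither
// handled by a custom hook, a cloneable constant, a recognised reduction, nor
// an ElementwiseMappable op makes the whole linalg-op vectorization fail
// rather than produce partially vectorized IR.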
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG() << "Vectorizing operation as linalg generic";
  Block *block = linalgOp.getBlock();

  // Values defined above the region can only be broadcast for now; make them
  // map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // Compute the read map and type for this operand.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // Input reads use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // Output reads (iteration-carried dependence, e.g. reductions) map the
      // canonical vector shape to the output domain.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = vector::TransferReadOp::create(
        rewriter, loc, readType, opOperand->get(), indices,
        /*padding=*/std::nullopt, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // If masked, set in_bounds to true. Masking guarantees that the access
    // will be in bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // Not all ops support 0-d vectors, extract the scalar for now.
    if (readType.getRank() == 0)
      readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                            ArrayRef<int64_t>());

    LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
           << "): " << readValue;
    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  SmallVector<CustomVectorizationHook> hooks;
  // Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp,
                                newResults);
  };
  hooks.push_back(vectorizeYield);

  // Register CustomVectorizationHook for indexOp.
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  // Register CustomVectorizationHook for extractOp.
  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // Iteratively call `vectorizeOneOp` on each op in the body.
  for (Operation &op : block->getOperations()) {
    VectorizationHookResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationHookStatus::Failure) {
      LDBG() << "failed to vectorize: " << op;
      return failure();
    }
    if (result.status == VectorizationHookStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG() << "New vector op: " << *maybeMaskedOp;
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
/// Determines whether a mask for an xfer_write is trivially "all true".
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
                                    SmallVector<Value> &writeIdxs,
                                    ArrayRef<int64_t> destShape,
                                    ArrayRef<int64_t> maskShape) {
  // Masking is unavoidable in the case of dynamic tensors.
  if (ShapedType::isDynamicShape(destShape))
    return false;

  // Collect all constant mask sizes.
  SmallVector<int64_t, 4> cstMaskSizes;
  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      cstMaskSizes.push_back(*intSize);
    }
  }

  // If any of the mask sizes is non-constant, bail out.
  if (cstMaskSizes.size() != maskShape.size())
    return false;

  // Collect all constant write indices.
  SmallVector<int64_t, 4> cstWriteIdxs;
  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
    APSInt intVal;
    if (matchPattern(idx, m_ConstantInt(&intVal))) {
      cstWriteIdxs.push_back(intVal.getSExtValue());
    }
  }

  // If any of the write indices is non-constant, bail out.
  if (cstWriteIdxs.size() != destShape.size())
    return false;

  // Compare mask sizes + write offsets against the destination sizes,
  // accounting for a possible rank difference between mask and destination.
  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
    if (maskShape[i] > destShape[rankDiff + i] ||
        destShape[rankDiff + i] <
            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
             cstWriteIdxs[i]))
      return false;
  }

  return true;
}
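// Illustrative example: writing a vector<2x4xf32> with mask sizes [2, 4] at
// offsets [0, 0] into a static tensor<2x4xf32> is trivially foldable - every
// (clamped) mask size plus its write offset fits the destination dim, so the
// mask would be all-true and can be dropped.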
/// Creates an optionally masked TransferWriteOp.
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
                         Value dest, SmallVector<Value> writeIndices = {},
                         bool useInBoundsInsteadOfMasking = false) {
  ShapedType destType = cast<ShapedType>(dest.getType());
  int64_t destRank = destType.getRank();
  auto destShape = destType.getShape();

  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
  int64_t vecToStoreRank = vecToStoreType.getRank();
  auto vecToStoreShape = vecToStoreType.getShape();

  // Compute the in_bounds attribute.
  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
  if (useInBoundsInsteadOfMasking) {
    // In this case, assume that all the required vector sizes have been
    // pre-computed.
    for (unsigned i = 0; i < vecToStoreRank; i++)
      inBoundsVal[i] =
          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
  }

  // If missing, initialize the write indices to 0.
  assert((writeIndices.empty() ||
          writeIndices.size() == static_cast<size_t>(destRank)) &&
         "Invalid number of write indices!");
  if (writeIndices.empty()) {
    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
    writeIndices.assign(destRank, zero);
  }

  // Generate the xfer_write Op.
  Operation *write = vector::TransferWriteOp::create(builder, loc,
                                                     /*vector=*/vecToStore,
                                                     /*source=*/dest,
                                                     /*indices=*/writeIndices,
                                                     /*inBounds=*/inBoundsVal);

  // If masking was disabled, exit.
  if (useInBoundsInsteadOfMasking)
    return write;

  // Check if masking is needed. If not, exit.
  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
    return write;

  // Compute the mask and mask the write Op.
  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
                                       vecToStoreType.getScalableDims());

  SmallVector<OpFoldResult> destSizes =
      isa<MemRefType>(dest.getType())
          ? memref::getMixedSizes(builder, loc, dest)
          : tensor::getMixedSizes(builder, loc, dest);
  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                      destSizes.end());

  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
                              vecToStoreShape))
    return write;

  Value maskForWrite =
      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
  return mlir::vector::maskOperation(builder, write, maskForWrite);
}
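// A minimal usage sketch (hypothetical values, for illustration only):
// given a builder `b`, a location `loc`, a vector `vec` of type
// vector<4x8xf32> and a destination tensor of type tensor<?x8xf32>,
//   createWriteOrMaskedWrite(b, loc, vec, dest);
// emits a vector.transfer_write wrapped in vector.mask, with the mask sizes
// taken from the runtime sizes of `dest`, unless the mask folds to all-true.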
/// Vectorize linalg::PackOp with (1) static inner_tiles, (2) constant
/// padding value and (3) input vector sizes:
///   masked_transfer_read -> shape_cast -> transpose -> transfer_write
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  auto padValue = packOp.getPaddingValue();
  if (!padValue) {
    padValue = arith::ConstantOp::create(
        rewriter, loc,
        rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
  }

  // If the input vector sizes are not provided, then the vector sizes are
  // determined by the result tensor shape.
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }

  // Create masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  // ... adjust `inputShape` by the outer-dims permutation and inner tiles ...
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking,
      /*inputScalableVecSizes=*/{});

  // Create ShapeCastOp.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
  auto shapeCastOp =
      vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);

  // Create TransposeOp.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = vector::TransposeOp::create(
      rewriter, loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), packOp.getDest());
  newResults.push_back(write->getResult(0));
  return success();
}
/// Given the re-associations, "collapses" the input vector type.
static VectorType getCollapsedVecType(VectorType type,
                                      ArrayRef<AffineMap> reassociation) {
  assert(type.getNumScalableDims() < 2 &&
         "Collapsing more than 1 scalable dim is not supported ATM");

  auto shape = type.getShape();
  auto scalableFlags = type.getScalableDims();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableFlags;

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    bool flag = false;
    for (unsigned d = 0; d < dim; ++d) {
      size *= shape[currentDim + d];
      flag |= scalableFlags[currentDim + d];
    }
    newShape.push_back(size);
    newScalableFlags.push_back(flag);
    currentDim += dim;
  }
  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
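// For example, collapsing vector<2x[4]x8xf32> with reassociation
// [[0, 1], [2]] yields vector<[8]x8xf32>: sizes within a group multiply, and
// a group is scalable if any of its members is scalable.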
/// Vectorize linalg.unpack as:
///   transfer_read -> transpose -> shape_cast -> transfer_write
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
           "Invalid number of input vector sizes!");
    assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
           "Incompatible number of vector sizes and vector scalable flags!");
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  RankedTensorType unpackTensorType = unpackOp.getSourceType();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;

  Location loc = unpackOp->getLoc();

  // Obtain vector sizes for the read operation.
  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
  if (inputVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(sourceShape))
      return failure();

    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
    useInBoundsInsteadOfMasking = true;
  }

  // -- Generate the read operation --
  auto padValue = arith::ConstantOp::create(
      rewriter, loc,
      rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      useInBoundsInsteadOfMasking, readScalableVectorFlags);

  // -- Generate the transpose operation --
  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  vector::TransposeOp transposeOp = vector::TransposeOp::create(
      rewriter, loc, readResult, lastDimToInsertPosPerm);

  // -- Generate the shape_cast operation --
  VectorType collapsedVecType = getCollapsedVecType(
      transposeOp.getType(),
      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
          rewriter.getContext(), packMetadata.reassociations)));
  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, collapsedVecType, transposeOp->getResult(0));

  // -- Generate the write operation --
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);

  newResults.push_back(write->getResult(0));
  return success();
}
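// The unpack lowering above mirrors the pack lowering in reverse:
// read (masked) -> transpose (undo the inner-tile permutation) ->
// shape_cast (collapse the tiled dims) -> write (masked or in-bounds).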
/// Vectorize a `padOp` with (1) static result type, (2) constant padding
/// value and (3) all-zero low padding.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(padOp);

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});

  // Create the transfer write.
  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
                                       padOp.getResultType().getElementType());
  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG() << "reduction precondition failed: no reduction iterator";
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG() << "reduction precondition failed: reduction detection failed";
      return failure();
    }
  }
  return success();
}
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
              "supported";
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
              "channel dim can be dynamic";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic
  // ops, linalg.copy ops and ops that implement ContractionOpInterface for
  // now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
  return success();
}
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  // If there are no input vector sizes and all shapes are static, there is
  // nothing left to verify.
  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
      unpackOp.getSourceType().hasStaticShape())
    return success();

  // The number of input vector sizes must match the source rank.
  if (!inputVectorSizes.empty() &&
      (inputVectorSizes.size() != unpackOp.getSourceRank())) {
    LDBG() << "Incorrect number of input vector sizes";
    return failure();
  }

  // Check the vector sizes for the read operation.
  if (failed(vector::isValidMaskedInputVector(
          unpackOp.getSourceType().getShape(), inputVectorSizes))) {
    LDBG() << "Invalid vector sizes for the read operation";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {
  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // Reading dynamic shapes without a pad value and without user-provided
  // vector sizes would be an out-of-bounds read.
  Value padValue = getStaticPadVal(sliceOp);
  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG() << "Failed to get a pad value for out-of-bounds read access";
    return failure();
  }
  return success();
}
/// Vectorize a named linalg contraction op into:
///   vector::TransferReadOp - Reads vectors from the operands
///   vector::ContractionOp - Performs contraction
///   vector::TransferWriteOp - Write the result vector back to the
///   destination
static LogicalResult
vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
                             LinalgOp linalgOp,
                             SmallVectorImpl<Value> &newResults) {
  Location loc = linalgOp.getLoc();
  MLIRContext *ctx = linalgOp.getContext();

  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
    return failure();

  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
  Operation *reduceOp = matchLinalgReduction(outOperand);
  auto maybeKind = getCombinerOpKind(reduceOp);
  if (!maybeKind) {
    LDBG() << "Failed to determine contraction combining kind.";
    return failure();
  }

  // Check that all dimensions are present in the input operands; contractions
  // with broadcasts are not supported.
  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
    LDBG() << "Contractions with broadcasts are not supported.";
    return failure();
  }

  // Load operands.
  SmallVector<Value> vecOperands;
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    // The operand vector shape is computed by mapping the canonical vector
    // shape to the operand's domain.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
        indexingMap.getNumResults(), rewriter.getContext());
    Type elemType = getElementTypeOrSelf(opOperand.get());
    VectorType readType =
        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));

    Value read = mlir::vector::createReadOrMaskedRead(
        rewriter, loc, opOperand.get(), readType.getShape(),
        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
    vecOperands.push_back(read);
  }

  // Remap iterators from linalg to vector.
  SmallVector<Attribute> iterAttrs;
  auto iterators = linalgOp.getIteratorTypesArray();
  for (utils::IteratorType iter : iterators) {
    auto vecIter = iter == utils::IteratorType::parallel
                       ? vector::IteratorType::parallel
                       : vector::IteratorType::reduction;
    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
  }

  // Create contraction.
  Operation *contractOp = vector::ContractionOp::create(
      rewriter, loc, /*lhs=*/vecOperands[0],
      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);

  // Store result.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, contractOp->getResult(0), outOperand->get());

  // Finalize.
  if (!write->getResults().empty())
    newResults.push_back(write->getResult(0));

  return success();
}
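// Unlike the generic path, this named-contraction path emits a single
// vector.contract that carries the original indexing maps, so the operands do
// not need to be broadcast to the canonical vector shape first.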
enum class ConvOperationKind { Conv, Pool };
static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
  case 1: {
    // Will be convolution if the feeder is a MulOp. A strength-reduced
    // version of MulOp for i1 type is AndOp, which is also supported.
    // Otherwise, it can be pooling.
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
    assert(feedValIt != reduceOp->operand_end() &&
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp)) {
      return ConvOperationKind::Pool;
    }

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            feedOp->getResultTypes()[0].isInteger(1))) &&
          llvm::all_of(feedOp->getOperands(), [](Value v) {
            if (isa<BlockArgument>(v))
              return true;
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
            return false;
          }))) {
      return std::nullopt;
    }

    return ConvOperationKind::Conv;
  }
  case 2:
    // Must be pooling.
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}
static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
  // (LHS and RES are both rank 3: channeled convolution) OR (LHS and RES are
  // both rank 1: non-channeled convolution). This also rejects 2-D and 3-D
  // convolutions.
  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
    return failure();

  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
  if (!reduceOp)
    return failure();

  auto maybeOper = getConvOperationKind(reduceOp);
  if (!maybeOper.has_value())
    return failure();

  auto maybeKind = getCombinerOpKind(reduceOp);
  // Typically convolution will have an `Add` CombiningKind, but for i1 it can
  // get strength-reduced to `OR`, which is also supported.
  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {
    return failure();
  }

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1)
      return failure();
  } else {
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return failure();
  }

  return success();
}
static LogicalResult vectorizeLinalgOpPrecondition(
    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
  // Tensors with a dimension of 0 cannot be vectorized.
  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
        return llvm::is_contained(linalgOp.getShape(&operand), 0);
      }))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() &&
      failed(vectorizeDynamicLinalgOpPrecondition(
          linalgOp, flatten1DDepthwiseConv))) {
    LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(customPreconditions,
                     [&](const CustomVectorizationPrecondition
                             &customPrecondition) {
                       return succeeded(
                           customPrecondition(&innerOp, vectorizeNDExtract));
                     })) {
      continue;
    }
    if (!llvm::all_of(innerOp.getOperandTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
    if (!llvm::all_of(innerOp.getResultTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return vectorizeConvOpPrecondition(linalgOp);

  // The common vector shape equals the static loop sizes only when all
  // indexing maps are projected permutations.
  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG() << "precondition failed: not projected permutations";
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG() << "precondition failed: reduction preconditions";
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG() << "pad value is not constant: " << packOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG() << "inner_tiles must be constant: " << packOp;
    return failure();
  }

  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG() << "pad value is not constant: " << padOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  // Padding with non-zero low pad values is not supported, unless the
  // corresponding result dim is 1 (in which case it is effectively unused).
  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
        Value padValue = en.value();
        unsigned pos = en.index();
        std::optional<int64_t> pad = getConstantIntValue(padValue);
        return (!pad.has_value() || pad.value() != 0) &&
               resultTensorShape[pos] != 1;
      })) {
    LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
    return failure();
  }

  return success();
}
/// Preconditions for scalable vectors.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
  // exception of UnpackOp, which has a dedicated hook.
  if (!linalgOp) {
    return success(isa<linalg::UnPackOp>(op));
  }

  // Cond 2: There's been no need for more than 2 scalable dims so far.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify
  // that it matches one of the supported cases. Skip the trailing non-scalable
  // dims first, tracking whether a non-unit parallel dim was seen.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  // Analyze the iterator corresponding to the first scalable dim.
  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // Only the trailing reduction dim may be scalable.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG() << "Non-trailing reduction dim requested for scalable "
                "vectorization";
      return failure();
    }
    if (isa<linalg::MatmulOp>(op)) {
      LDBG()
          << "Scalable vectorization of the reduction dim in Matmul-like ops "
             "is not supported";
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // The scalable parallel dim must be the inner-most non-unit one.
    if (seenNonUnitParallel) {
      LDBG() << "Inner parallel dim not requested for scalable "
                "vectorization";
      return failure();
    }
    break;
  }
  }

  // If present, check the second scalable dim.
  if (numOfScalableDims == 2) {
    // Disallow: trailing reduction + the dim above it both scalable.
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG() << "Higher dim than the trailing reduction dim requested for "
                "scalable vectorization";
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Cond 4: Only the following ops are supported in the presence of scalable
  // vectors.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
                 isa<linalg::BatchMmt4DOp>(op) ||
                 hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (!hasVectorizationImpl(op))
    return failure();

  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<linalg::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<linalg::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
/// Converts affine.apply Ops to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
}
FailureOr<VectorizationResult> mlir::linalg::vectorize(
    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
    bool createNamedContraction) {
  LDBG() << "Attempting to vectorize: " << *op;
  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
  LDBG() << "Input scalable vector dims: "
         << llvm::interleaved(inputScalableVecDims);

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG() << "Vectorization pre-conditions failed";
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims,
                               assumeDynamicDimsMatchVecSizes))) {
      LDBG() << "Vectorization state couldn't be initialized";
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG() << "Unsupported convolution can't be vectorized.";
              return failure();
            }

            if (createNamedContraction &&
                isa<ContractionOpInterface>(linalgOp.getOperation()))
              return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
                                                  results);

            LDBG()
                << "Vectorize generic by broadcasting to the canonical vector "
                   "shape";

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);
            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<linalg::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<linalg::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes,
                                             inputScalableVecDims, results);
          })
          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp,
                                            inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG() << "Vectorization failed";
    return failure();
  }

  return VectorizationResult{results};
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = vector::TransferReadOp::create(
      rewriter, loc, readType, copyOp.getSource(), indices,
      /*padding=*/std::nullopt,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = vector::ExtractOp::create(rewriter, loc, readValue);
    readValue =
        vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
  }
  Operation *writeValue = vector::TransferWriteOp::create(
      rewriter, loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in a vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getBaseMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // The TransferWriteOp result must be directly consumed by an
    // ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at the position of the old one.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }
  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both the CastOp
    // result and the CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType is supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast
    // thereof).
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same value or same constant int.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Other cases: Take a deeper look at the defining ops of the values.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      // Case 2: Both values are identical AffineMinOps. (Should not happen if
      // CSE is run.)
      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
    }

    // All tests passed.
    return true;
  }
/// Returns the effective Pad value for the input op, provided it's a scalar.
///
/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
/// the op performs padding, retrieve the padding value provided that it's a
/// scalar and static/fixed for all the padded values. Returns an empty value
/// otherwise.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast - return the value being broadcast, provided that
  // it's a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  // 2. linalg.fill - use the scalar input value used to `fill` the output.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generate - the padding value depends on the body.
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    return {};
  }

  // 4. vector.transfer_write - follow the vector being written.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // 5. tensor.insert_slice - follow the destination.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
/// Vectorize tensor::InsertSliceOp with:
///   * vector::TransferReadOp + vector::TransferWriteOp
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(sliceOp);

  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  Value padValue = getStaticPadVal(sliceOp);
  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
                                         rewriter.getZeroAttr(elemType));
  }

  // Compute the vector shape.
  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but the result shape is.
      // Vectorize with the size of the result shape; the read may be out of
      // bounds, hence the pad value above.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
    } else {
      // Neither source nor result dim is static. Cannot vectorize the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // Generate TransferReadOp + TransferWriteOp.
  auto loc = sliceOp.getLoc();

  // Create the read.
  Value read = mlir::vector::createReadOrMaskedRead(
      rewriter, loc, source, vecType.getShape(), padValue,
      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
      /*inputScalableVecSizes=*/{});

  // Create the write.
  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
                               writeIndices, inputVectorSizes.empty());

  // Finalize.
  newResults.push_back(write->getResult(0));
  return success();
}
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes are not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result must not be used as the destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if sizes match: insert the entire tensor into the most minor
    // dims. (No permutations allowed.)
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Generate the TransferReadOp: read the entire source tensor with the pad
    // value applied to the out-of-bounds (high padding) region.
    SmallVector<Value> readIndices(
        vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
    auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
                                               vecType, padOp.getSource(),
                                               readIndices, padValue);

    // Generate the TransferWriteOp: write to the InsertSliceOp's destination
    // tensor at the specified offsets.
    SmallVector<Value> writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return true if `firstOp` and
/// `secondOp` are in different blocks.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
           << ", second op: " << *secondOp;
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
             << ", second op: " << *secondOp;
      return true;
    }
  }
  return false;
}
/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the
  // copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // The `masked` attribute is only valid on the padded buffer, so the read is
  // rebuilt as out-of-bounds on the original source.
  auto vectorType = xferOp.getVectorType();
  Value res = vector::TransferReadOp::create(
      rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the target that the copy writes to; forward the transfer there.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward vector.transfer into copy. The write is rebuilt as out-of-bounds
  // on the original destination.
  auto vector = xferOp.getVector();
  vector::TransferWriteOp::create(
      rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(SmallVector<bool>(
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2...>(shapedType, vals...);
}

/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
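// Example: given a ShapedType with shape [5, 7, 9],
//   int64_t a, b, c;
//   bindShapeDims(type, a, b, c);
// binds a = 5, b = 7, c = 9 (leading dims, in order).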
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter,
                                                           linalgOp) {
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    setConvOperationKind(reduceOp);

    auto maybeKind = getCombinerOpKind(reduceOp);
    reductionKind = maybeKind.value();

    // The ConvolutionOpInterface gives us guarantees of existence for
    // strides/dilations; use them if present, otherwise fall back to the
    // defaults.
    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  }
  /// Generate a vector implementation of the convolution/pooling for the
  /// given layout order, with `w` and `kw` unrolled.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Initialize unused dimensions.
      nSize = fSize = cSize = 0;
      // out{W} and kernel{kw}.
      bindShapeDims(resShapedType, wSize);
      bindShapeDims(rhsShapedType, kwSize);
      lhsShape = {// iw = ow + kw - 1
                  (wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      // out{n, w, f}
      bindShapeDims(resShapedType, nSize, wSize, fSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        // kernel{kw, c, f}
        bindShapeDims(rhsShapedType, kwSize, cSize);
        break;
      case ConvOperationKind::Pool:
        // kernel{kw}
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize,
                  // iw = ow * sw + kw * dw - 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      // out{n, f, w}
      bindShapeDims(resShapedType, nSize, fSize, wSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        // kernel{f, c, kw}
        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
        break;
      case ConvOperationKind::Pool:
        // kernel{kw}
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize, cSize,
                  // iw = ow * sw + kw * dw - 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }
    vector::TransferWriteOp write;
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    auto lhsType = VectorType::get(lhsShape, lhsEltType);
    auto rhsType = VectorType::get(rhsShape, rhsEltType);
    auto resType = VectorType::get(resShape, resEltType);
    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
    SmallVector<Value> resPadding(resShape.size(), zero);

    // Read the whole lhs, rhs and res in one shot (with zero padding).
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, lhsPadding,
        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
      rhs = vector::TransferReadOp::create(
          rewriter, loc, rhsType, rhsShaped, rhsPadding,
          /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, resPadding,
        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));

    // The base vectorization case for channeled convolution is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse it, pre-transpose
    // the Ncw case.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, so no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
      // fcw -> wcf
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};

      // This is needed only for Conv.
      if (oper == ConvOperationKind::Conv)
        rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, permRes);
      break;
    }
    }

    // Unroll along kw and read slices of lhs, rhs, and res.
    SmallVector<Value> lhsVals =
        extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                               kwSize, strideW, dilationW, wSizeStep,
                               isSingleChanneled);
    SmallVector<Value> rhsVals;
    if (oper == ConvOperationKind::Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    SmallVector<Value> resVals = extractConvResultSlices(
        rewriter, loc, res, nSize, wSize, fSize, wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} *
    // F{c, f}, an outer product for non-channeled convolution, or a simple
    // arith operation for pooling.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case ConvOperationKind::Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case ConvOperationKind::Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    // Write back the result slices.
    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep,
                                 resVals, isSingleChanneled);

    // Post-transpose back to the original layout if needed.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // nwf -> nfw
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, perm);
      break;
    }
    }

    return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
                                           resPadding)
        .getOperation();
  }
  // Take a value and widen it to have the same element type as `ty`.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return arith::SIToFPOp::create(rewriter, loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtFOp::create(rewriter, loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtSIOp::create(rewriter, loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = vector::ContractionOp::create(
        rewriter, loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }

  // Create an outer product: lhs{w} * rhs{1} -> res{w} for single-channel
  // convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
                                          rhs, res, vector::CombiningKind::ADD);
  }
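  // The outer-product form is used only for the single-channel conv::W case,
  // where lhs and res are 1-D vectors; CombiningKind::ADD matches the
  // additive accumulation of the convolution.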
  /// Generate a vector implementation of a depthwise convolution, with `kw`
  /// and `w` unrolled and optional masking/flattening of the channel dim.
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when the channel dim is dynamic and
      // the scalable flag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = ow * sw + kw * dw - 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});

    // Masks the input xfer Op along the channel dim, iff the corresponding
    // scalable flag is set.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = vector::TransferReadOp::create(
        rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
        /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

    // Unroll along kw and extract slices of lhs, rhs, and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(
          vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
                                    /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Note - the scalable flags are ignored as flattening combined with
    // scalable vectorization is not supported.
    SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute the contraction O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel dim).
          lhsVal =
              vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
                                          lhsVals[linearIndex(kw, w)]);
          resVal = vector::ShapeCastOp::create(
              rewriter, loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dim).
          resVals[w] = vector::ShapeCastOp::create(
              rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
              resVals[w]);
        }
      }
    }

    // It is possible we failed to create the FMA; manually revert the rewrite
    // to avoid leaving the IR in a bad state.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back res slice of size {n, w, c} @ [0, 0, 0].
    Operation *resOut = vector::TransferWriteOp::create(
        rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }

  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c}       (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c}     (flatten = true)
  /// to MulAcc.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: This logic won't work for scalable vectors; for this reason,
      // "flattening" is not supported when shapes are dynamic (this should be
      // captured by one of the pre-conditions).

      // There are two options for handling the filter:
      //  * shape_cast(broadcast(filter))
      //  * broadcast(shuffle(filter))
      // Opt for the option without shape_cast to simplify the codegen.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = vector::BroadcastOp::create(rewriter, loc,
                                      resTy.clone(rhsTy.getElementType()),
                                      rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);

    auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
    return arith::AddIOp::create(rewriter, loc, mul, res);
  }
  /// Entry point that transposes into the common form:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point for Nwc pooling:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point for Ncw pooling:
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point for dilated (depthwise) convolution:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
private:
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // Will be convolution if the feeder is a MulOp; if the feeder is a cast
      // of a block argument, it is pooling.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // numBlockArguments == 2: pooling.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
};
/// Helper function to vectorize a LinalgOp with convolution semantics.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left - these can be vectorized using
  // masks and scalable vectors. Note that ATM the only dim that can be
  // dynamic (i.e. vectorized with a mask) is the channel dim (the trailing
  // dim).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static LogicalResult vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
static VectorType getCollapsedVecType(VectorType type, ArrayRef< AffineMap > reassociation)
Given the re-associations, "collapses" the input Vector type.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp)
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Vectorize a named linalg contraction op into: vector::TransferReadOp - Reads vectors from the operand...
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static bool isSupportedPoolKind(vector::CombiningKind kind)
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(linalg::PackOp packOp, ArrayRef< int64_t > destShape)
Given a linalg::PackOp, return the dest shape before any packing permutations.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
This hook considers two cases: (1) If the input-vector-sizes are empty, then the vector sizes will be...
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumResults() const
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_iterator operand_end()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
enum WinogradConv2DFmr uint32_t std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking=false, ArrayRef< bool > inputScalableVecDims={})
Creates a TransferReadOp from source.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims, bool assumeDynamicDimsMatchVecSizes=false)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Transformation information returned after vectorizing.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.