37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/Sequence.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/DebugLog.h"
42 #include "llvm/Support/InterleavedRange.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
50 #define DEBUG_TYPE "linalg-vectorization"
53 static FailureOr<Operation *>
57 bool flatten1DDepthwiseConv =
false);
92 template <
typename OpType>
95 block.
walk([&](OpType op) {
110 int64_t nSize, int64_t wSize, int64_t cSize,
111 int64_t kwSize,
int strideW,
int dilationW,
112 int64_t wSizeStep,
bool isSingleChanneled) {
114 if (isSingleChanneled) {
119 for (int64_t kw = 0; kw < kwSize; ++kw) {
120 for (int64_t w = 0; w < wSize; w += wSizeStep) {
121 result.push_back(vector::ExtractStridedSliceOp::create(
131 for (int64_t kw = 0; kw < kwSize; ++kw) {
132 for (int64_t w = 0; w < wSize; w += wSizeStep) {
133 result.push_back(vector::ExtractStridedSliceOp::create(
134 rewriter, loc, input,
151 for (int64_t kw = 0; kw < kwSize; ++kw) {
152 result.push_back(vector::ExtractOp::create(
162 int64_t nSize, int64_t wSize, int64_t fSize,
163 int64_t wSizeStep,
bool isSingleChanneled) {
165 if (isSingleChanneled) {
169 for (int64_t w = 0; w < wSize; w += wSizeStep) {
170 result.push_back(vector::ExtractStridedSliceOp::create(
179 for (int64_t w = 0; w < wSize; w += wSizeStep) {
180 result.push_back(vector::ExtractStridedSliceOp::create(
190 Value res, int64_t wSize, int64_t wSizeStep,
192 bool isSingleChanneled) {
194 if (isSingleChanneled) {
198 for (int64_t w = 0; w < wSize; w += wSizeStep) {
199 res = vector::InsertStridedSliceOp::create(
207 for (int64_t w = 0; w < wSize; w += wSizeStep) {
208 res = vector::InsertStridedSliceOp::create(
209 rewriter, loc, resVals[w], res,
223 LogicalResult initState(
RewriterBase &rewriter, LinalgOp linalgOp,
226 bool assumeDynamicDimsMatchVecSizes =
false);
241 std::optional<AffineMap> dimPermutation = std::nullopt)
const {
244 if (dimPermutation.has_value()) {
246 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
248 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
250 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
251 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
263 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
268 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
269 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
275 LogicalResult precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
284 std::optional<AffineMap> maybeMaskingMap);
289 bool isValidMaskingMap(
AffineMap maskingMap) {
342 bool assumeDynamicDimsMatchVecSizes =
false;
346 VectorizationState::precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
349 for (
int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
350 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
353 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
360 unsigned operandDimPos;
361 if (
failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
366 linalgOp.hasPureTensorSemantics()
367 ? (
Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
369 : (
Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
371 iterSpaceValueSizes.push_back(dynamicDim);
384 bool assumeDimsMatchVec) {
385 assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
389 if (!inputVectorSizes.empty()) {
393 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
394 scalableVecDims.append(inputScalableVecDims.begin(),
395 inputScalableVecDims.end());
400 canonicalVecShape = linalgOp.getStaticLoopRanges();
401 scalableVecDims.append(linalgOp.getNumLoops(),
false);
404 LDBG() <<
"Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
405 LDBG() <<
"Scalable vector dims: " << llvm::interleaved(scalableVecDims);
407 if (ShapedType::isDynamicShape(canonicalVecShape))
411 initIterSpaceStaticSizes(linalgOp);
416 if (
failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
426 Value VectorizationState::getOrCreateMaskFor(
428 std::optional<AffineMap> maybeMaskingMap) {
430 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
431 "Ill-formed masking map.");
434 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
438 assert(!maskableOp.isMasked() &&
439 "Masking an operation that is already masked");
442 assert((!maybeMaskingMap || *maybeMaskingMap) &&
443 "Unexpected null mask permutation map");
445 maybeMaskingMap ? *maybeMaskingMap
447 linalgOp.getNumLoops(), rewriter.
getContext());
449 LDBG() <<
"Masking map: " << maskingMap;
453 auto activeMaskIt = activeMaskCache.find(maskingMap);
454 if (activeMaskIt != activeMaskCache.end()) {
455 Value mask = activeMaskIt->second;
456 LDBG() <<
"Reusing mask: " << mask;
467 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
468 auto maskType = getCanonicalVecType(rewriter.
getI1Type(), maskingMap);
469 auto maskShape = maskType.getShape();
471 LDBG() <<
"Mask shape: " << llvm::interleaved(maskShape);
473 if (permutedStaticSizes == maskShape) {
474 LDBG() <<
"Masking is not needed for masking map: " << maskingMap;
475 activeMaskCache[maskingMap] =
Value();
479 if (assumeDynamicDimsMatchVecSizes) {
483 if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
485 return std::get<0>(it) == ShapedType::kDynamic
487 : std::get<0>(it) == std::get<1>(it);
490 <<
"Dynamic + static dimensions match vector sizes, masking is not "
492 activeMaskCache[maskingMap] =
Value();
500 assert(!maskShape.empty() && !upperBounds.empty() &&
501 "Masked 0-d vectors are not supported yet");
504 Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
505 maskType, upperBounds);
506 LDBG() <<
"Creating new mask: " << mask;
507 activeMaskCache[maskingMap] = mask;
514 std::optional<AffineMap> maybeIndexingMap) {
515 LDBG() <<
"Trying to mask: " << *opToMask;
517 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
518 if (maybeIndexingMap)
519 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
523 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
526 LDBG() <<
"No mask required";
531 assert(opToMask &&
"Expected a valid operation to mask");
532 auto maskOp = cast<vector::MaskOp>(
534 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
540 LDBG() <<
"Masked operation: " << *maskOp;
563 "expected projected permutation");
565 assert(res.getNumDims() ==
566 (res.getNumResults() - res.getNumOfZeroResults()) &&
567 "expected reindexed map with same number of dims and results");
603 std::optional<vector::CombiningKind>
605 using ::mlir::vector::CombiningKind;
610 .Case<arith::AddIOp, arith::AddFOp>(
611 [&](
auto op) {
return CombiningKind::ADD; })
612 .Case<arith::AndIOp>([&](
auto op) {
return CombiningKind::AND; })
613 .Case<arith::MaxSIOp>([&](
auto op) {
return CombiningKind::MAXSI; })
614 .Case<arith::MaxUIOp>([&](
auto op) {
return CombiningKind::MAXUI; })
615 .Case<arith::MaximumFOp>([&](
auto op) {
return CombiningKind::MAXIMUMF; })
616 .Case<arith::MaxNumFOp>([&](
auto op) {
return CombiningKind::MAXNUMF; })
617 .Case<arith::MinSIOp>([&](
auto op) {
return CombiningKind::MINSI; })
619 .Case<arith::MinimumFOp>([&](
auto op) {
return CombiningKind::MINIMUMF; })
620 .Case<arith::MinNumFOp>([&](
auto op) {
return CombiningKind::MINNUMF; })
621 .Case<arith::MulIOp, arith::MulFOp>(
622 [&](
auto op) {
return CombiningKind::MUL; })
623 .Case<arith::OrIOp>([&](
auto op) {
return CombiningKind::OR; })
624 .Case<arith::XOrIOp>([&](
auto op) {
return CombiningKind::XOR; })
625 .Default([&](
auto op) {
return std::nullopt; });
636 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
641 if (!
matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
642 combinerOps.size() != 1)
646 return combinerOps[0];
652 auto dstVecType = dyn_cast<VectorType>(dstType);
654 if (dstVecType.getRank() == 0)
660 return b.
createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
672 assert(maybeKind &&
"Failed precondition: could not get reduction kind");
673 return vector::MultiDimReductionOp::create(
674 b, reduceOp->
getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
678 return llvm::to_vector(
685 return isa<linalg::ReduceOp>(op) ||
686 (isa<linalg::GenericOp>(op) &&
700 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
701 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
710 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
712 auto vectorType = state.getCanonicalVecType(
716 if (vectorType.getRank() > 0) {
719 linalgOp.getRank(outputOperand),
722 assert(value.
getType() == vectorType &&
"Incorrect type");
723 write = vector::TransferWriteOp::create(
724 rewriter, loc, value, outputOperand->
get(), indices, writeMap);
727 if (!isa<VectorType>(value.
getType()))
728 value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
729 assert(value.
getType() == vectorType &&
"Incorrect type");
730 write = vector::TransferWriteOp::create(rewriter, loc, value,
734 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
738 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
739 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
744 LDBG() <<
"vectorized op: " << *write;
754 std::function<LogicalResult(
Operation *,
bool)>;
773 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
782 linalgOp.getDpsInitOperand(output.index()), state);
784 newResults.push_back(newResult);
798 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
801 auto loc = indexOp.getLoc();
804 auto dim = indexOp.getDim();
806 auto indexVectorType =
808 state.getScalableVecDims()[dim]);
809 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
813 if (dim == targetShape.size() - 1)
819 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
820 std::swap(permPattern[dim], permPattern.back());
824 auto broadCastOp = vector::BroadcastOp::create(
826 state.getCanonicalVecType(rewriter.
getIndexType(), permMap), indexSteps);
828 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
829 std::swap(transposition.back(), transposition[dim]);
831 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
839 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
843 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
848 if (not extractOp.getIndices().empty()) {
849 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
853 if (!llvm::all_of(extractOp->getResultTypes(),
854 VectorType::isValidElementType)) {
873 tensor::ExtractOp extractOp,
876 auto indexVecType = state.getCanonicalVecType(rewriter.
getIndexType());
877 auto loc = extractOp.getLoc();
880 rewriter, bvm.
lookup(extractOp.getIndices()[0]), indexVecType);
882 const size_t numIndices = extractOp.getIndices().size();
883 for (
size_t i = 1; i < numIndices; i++) {
888 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
891 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
894 rewriter, bvm.
lookup(extractOp.getIndices()[i]), indexVecType);
896 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
922 (linalgOp.hasDynamicShape() ||
923 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
924 "For statically shaped Linalg Ops, only one "
925 "non-unit loop dim is expected");
926 assert(loopRanges.size() != 0 &&
"Empty loops, nothing to analyse.");
928 size_t idx = loopRanges.size() - 1;
929 for (; idx != 0; idx--)
930 if (loopRanges[idx] != 1)
938 VectorType resType) {
940 assert(((llvm::count_if(resType.getShape(),
941 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
942 "n-D vectors are not yet supported");
948 auto *block = linalgOp.getBlock();
949 if (isa<BlockArgument>(val))
950 return !llvm::is_contained(block->getArguments(), val);
953 assert(defOp &&
"This is neither a block argument nor an operation result");
958 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
959 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
962 auto *ancestor = block->findAncestorOpInBlock(*defOp);
969 if (isa<arith::ConstantOp>(ancestor))
973 for (
auto op : ancestor->getOperands())
997 bool &foundIndexOp, VectorType resType) {
999 assert(((llvm::count_if(resType.getShape(),
1000 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1001 "n-D vectors are not yet supported");
1007 auto *block = linalgOp.getBlock();
1008 if (isa<BlockArgument>(val))
1009 return !llvm::is_contained(block->getArguments(), val);
1012 assert(defOp &&
"This is neither a block argument nor an operation result");
1014 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1017 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1021 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1028 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1031 bool result =
false;
1032 for (
auto op : ancestor->getOperands())
1052 LinalgOp &linalgOp, VectorType resType) {
1054 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1057 if (inputShape.getShape().empty())
1062 bool isOutput1DVector =
1063 (llvm::count_if(resType.getShape(),
1064 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1066 if (!isOutput1DVector)
1069 bool leadingIdxsLoopInvariant =
true;
1075 auto indices = extractOp.getIndices();
1076 auto leadIndices = indices.drop_back(1);
1079 if (inputShape.getShape()[i] == 1)
1085 if (!leadingIdxsLoopInvariant) {
1086 LDBG() <<
"Found gather load: " << extractOp;
1094 auto extractOpTrailingIdx = indices.back();
1098 if (leadingIdxsLoopInvariant &&
1100 LDBG() <<
"Found scalar broadcast load: " << extractOp;
1109 bool foundIndexOp =
false;
1111 foundIndexOp, resType);
1114 bool isRowVector = resType.getShape().back() != 1;
1115 isContiguousLoad &= (foundIndexOp && isRowVector);
1117 if (isContiguousLoad) {
1118 LDBG() <<
"Found contigous load: " << extractOp;
1123 LDBG() <<
"Found gather load: " << extractOp;
1134 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1137 auto loc = extractOp.getLoc();
1140 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1141 auto maskConstantOp = arith::ConstantOp::create(
1145 auto passThruConstantOp = arith::ConstantOp::create(
1151 extractOp.getIndices().size(),
1162 Operation *gatherOp = vector::GatherOp::create(
1163 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1164 maskConstantOp, passThruConstantOp);
1165 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1167 LDBG() <<
"Vectorised as gather load: " << extractOp;
1190 for (
size_t i = 0; i < extractOp.getIndices().size(); i++) {
1191 Value idx = bvm.
lookup(extractOp.getIndices()[i]);
1193 transferReadIdxs.push_back(idx);
1197 auto indexAs1dVector = vector::ShapeCastOp::create(
1200 resultType.getScalableDims().back()),
1202 transferReadIdxs.push_back(
1203 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1207 auto dstRank = resultType.getRank();
1208 auto srcRank = extractOp.getTensor().getType().getRank();
1217 auto transferReadOp = vector::TransferReadOp::create(
1218 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1219 std::nullopt, permutationMap, inBounds);
1226 auto allTrue = vector::ConstantMaskOp::create(
1228 auto *maskedReadOp =
1231 LDBG() <<
"Vectorised as scalar broadcast load: " << extractOp;
1240 int32_t rankDiff = dstRank - srcRank;
1248 while (rankDiff > 0) {
1249 permutationMap = permutationMap.insertResult(
1254 auto transferReadOp = vector::TransferReadOp::create(
1255 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1256 std::nullopt, permutationMap, inBounds);
1258 LDBG() <<
"Vectorised as contiguous load: " << extractOp;
1272 auto reduceType = dyn_cast<VectorType>(reduceVec.
getType());
1273 auto outputType = dyn_cast<VectorType>(outputVec.
getType());
1277 (outputType && reduceType.getShape() == outputType.getShape()))
1306 LDBG() <<
"vectorize op " << *op;
1309 if (!customVectorizationHooks.empty()) {
1310 for (
auto &customFunc : customVectorizationHooks) {
1320 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1322 rewriter.
clone(*op)};
1331 auto blockArg = dyn_cast<BlockArgument>(operand);
1332 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1333 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1337 linalgOp.getRegionOutputArgs(),
1338 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1341 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1343 if (!reductionOperands.empty()) {
1344 assert(reductionOperands.size() == 1);
1346 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1347 reductionOperands[0].second, bvm);
1354 VectorType firstMaxRankedType;
1356 auto vecOperand = bvm.
lookup(operand);
1357 assert(vecOperand &&
"Vector operand couldn't be found");
1359 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1360 if (vecType && (!firstMaxRankedType ||
1361 firstMaxRankedType.getRank() < vecType.getRank()))
1362 firstMaxRankedType = vecType;
1368 assert(vecOperand &&
"Vector operand couldn't be found");
1370 if (firstMaxRankedType) {
1373 firstMaxRankedType.getScalableDims());
1376 vecOperands.push_back(vecOperand);
1382 resultTypes.push_back(
1385 firstMaxRankedType.getScalableDims())
1417 static LogicalResult
1421 LDBG() <<
"Vectorizing operation as linalg generic/n";
1422 Block *block = linalgOp.getBlock();
1429 bvm.
map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1431 if (linalgOp.getNumDpsInits() == 0)
1437 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1438 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1439 if (linalgOp.isScalar(opOperand)) {
1440 bvm.
map(bbarg, opOperand->get());
1446 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1449 VectorType readType;
1451 if (linalgOp.isDpsInput(opOperand)) {
1454 readType = state.getCanonicalVecType(elemType);
1461 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
1466 Operation *read = vector::TransferReadOp::create(
1467 rewriter, loc, readType, opOperand->get(), indices,
1468 std::nullopt, readMap);
1469 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1474 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1476 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1482 if (readType.getRank() == 0)
1486 LDBG() <<
"New vectorized bbarg(" << bbarg.
getArgNumber()
1498 hooks.push_back(vectorizeYield);
1505 hooks.push_back(vectorizeIndex);
1512 hooks.push_back(vectorizeExtract);
1519 LDBG() <<
"failed to vectorize: " << op;
1524 state.maskOperation(rewriter, result.
newOp, linalgOp);
1525 LDBG() <<
"New vector op: " << *maybeMaskedOp;
1591 if (ShapedType::isDynamicShape(destShape))
1598 cstMaskSizes.push_back(*intSize);
1603 if (cstMaskSizes.size() != maskShape.size())
1611 cstWriteIdxs.push_back(intVal.getSExtValue());
1616 if (cstWriteIdxs.size() != destShape.size())
1625 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1627 if ( maskShape[i] > destShape[rankDiff + i] ||
1628 destShape[rankDiff + i] <
1629 (
std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1665 bool useInBoundsInsteadOfMasking =
false) {
1667 ShapedType destType = cast<ShapedType>(dest.
getType());
1668 int64_t destRank = destType.getRank();
1669 auto destShape = destType.getShape();
1671 VectorType vecToStoreType = cast<VectorType>(vecToStore.
getType());
1672 int64_t vecToStoreRank = vecToStoreType.getRank();
1673 auto vecToStoreShape = vecToStoreType.getShape();
1677 if (useInBoundsInsteadOfMasking) {
1680 for (
unsigned i = 0; i < vecToStoreRank; i++)
1682 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1683 ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1687 assert((writeIndices.empty() ||
1688 writeIndices.size() ==
static_cast<size_t>(destRank)) &&
1689 "Invalid number of write indices!");
1690 if (writeIndices.empty()) {
1692 writeIndices.assign(destRank, zero);
1696 Operation *write = vector::TransferWriteOp::create(builder, loc,
1703 if (useInBoundsInsteadOfMasking)
1707 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1712 vecToStoreType.getScalableDims());
1715 isa<MemRefType>(dest.
getType())
1725 Value maskForWrite =
1726 builder.
createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1764 static LogicalResult
1773 auto padValue = packOp.getPaddingValue();
1775 padValue = arith::ConstantOp::create(
1777 rewriter.
getZeroAttr(packOp.getSourceType().getElementType()));
1780 LogicalResult status =
1781 cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1782 .reifyResultShapes(rewriter, reifiedReturnShapes);
1784 assert(succeeded(status) &&
"failed to reify result shapes");
1789 bool useInBoundsInsteadOfMasking =
false;
1790 if (inputVectorSizes.empty()) {
1792 inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1793 useInBoundsInsteadOfMasking =
true;
1798 auto innerTiles = packOp.getStaticInnerTiles();
1807 rewriter, loc, packOp.getSource(), inputShape, padValue,
1808 useInBoundsInsteadOfMasking,
1815 packOp.getDestType().getElementType());
1817 vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
1820 auto destPermutation =
1822 auto transposeOp = vector::TransposeOp::create(
1823 rewriter, loc, shapeCastOp.getResult(), destPermutation);
1826 Value dest = tensor::EmptyOp::create(
1827 rewriter, loc, reifiedReturnShapes[0],
1828 transposeOp.getResult().getType().getElementType());
1831 newResults.push_back(write->
getResult(0));
1853 assert(type.getNumScalableDims() < 2 &&
1854 "Collapsing more than 1 scalable dim is not supported ATM");
1860 auto shape = type.getShape();
1861 auto scalableFlags = type.getScalableDims();
1865 unsigned currentDim = 0;
1867 unsigned dim = m.getNumResults();
1870 for (
unsigned d = 0; d < dim; ++d) {
1871 size *= shape[currentDim + d];
1872 flag |= scalableFlags[currentDim + d];
1874 newShape.push_back(size);
1875 newScalableFlags.push_back(flag);
1879 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1911 static LogicalResult
1916 if (!inputVectorSizes.empty()) {
1917 assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1918 "Invalid number of input vector sizes!");
1919 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1920 "Incompatible number of vector sizes and vector scalable flags!");
1927 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1930 bool useInBoundsInsteadOfMasking =
false;
1939 if (inputVectorSizes.empty()) {
1940 if (ShapedType::isDynamicShape(sourceShape))
1943 readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1944 useInBoundsInsteadOfMasking =
true;
1948 auto padValue = arith::ConstantOp::create(
1950 rewriter.
getZeroAttr(unpackOp.getSourceType().getElementType()));
1952 rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1953 useInBoundsInsteadOfMasking, readScalableVectorFlags);
1956 PackingMetadata packMetadata;
1959 vector::TransposeOp transposeOp = vector::TransposeOp::create(
1960 rewriter, loc, readResult, lastDimToInsertPosPerm);
1964 transposeOp.getType(),
1966 rewriter.
getContext(), packMetadata.reassociations)));
1967 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1968 rewriter, loc, collapsedVecType, transposeOp->getResult(0));
1972 rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
1973 {}, useInBoundsInsteadOfMasking);
1975 newResults.push_back(write->
getResult(0));
1982 static LogicalResult
1986 auto padValue = padOp.getConstantPaddingValue();
1994 LogicalResult status =
1995 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1996 .reifyResultShapes(rewriter, reifiedReturnShapes);
1998 assert(succeeded(status) &&
"failed to reify result shapes");
2000 rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
2004 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2005 padOp.getResultType().getElementType());
2007 newResults.push_back(write->
getResult(0));
2015 LDBG() <<
"reduction precondition failed: no reduction iterator";
2018 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
2019 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2025 LDBG() <<
"reduction precondition failed: reduction detection failed";
2032 static LogicalResult
2034 bool flatten1DDepthwiseConv) {
2035 if (flatten1DDepthwiseConv) {
2036 LDBG() <<
"Vectorization of flattened convs with dynamic shapes is not "
2041 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2042 LDBG() <<
"Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2048 Value lhs = conv.getDpsInputOperand(0)->get();
2050 auto shapeWithoutCh = lhsShape.drop_back(1);
2051 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2052 LDBG() <<
"Dynamically-shaped op vectorization precondition failed: only "
2053 "channel dim can be dynamic";
2060 static LogicalResult
2062 bool flatten1DDepthwiseConv) {
2063 if (isa<ConvolutionOpInterface>(op.getOperation()))
2072 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2076 LDBG() <<
"Dynamically-shaped op meets vectorization pre-conditions";
2085 static LogicalResult
2090 if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2091 unpackOp.getSourceType().hasStaticShape())
2096 if (!inputVectorSizes.empty() &&
2097 (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2098 LDBG() <<
"Incorrect number of input vector sizes";
2104 unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2105 LDBG() <<
"Invalid vector sizes for the read operation";
2112 static LogicalResult
2117 auto sourceType = source.getType();
2118 if (!VectorType::isValidElementType(sourceType.getElementType()))
2134 bool isOutOfBoundsRead =
2135 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2137 if (!padValue && isOutOfBoundsRead) {
2138 LDBG() <<
"Failed to get a pad value for out-of-bounds read access";
2151 static LogicalResult
2161 if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2164 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2168 LDBG() <<
"Failed to determine contraction combining kind.";
2175 AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2176 AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2178 LDBG() <<
"Contractions with broadcasts are not supported.";
2184 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
2188 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2192 VectorType readType =
2193 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
2196 rewriter, loc, opOperand.get(), readType.getShape(),
2198 false, readType.getScalableDims());
2199 vecOperands.push_back(read);
2204 auto iterators = linalgOp.getIteratorTypesArray();
2205 for (utils::IteratorType iter : iterators) {
2206 auto vecIter = iter == utils::IteratorType::parallel
2207 ? vector::IteratorType::parallel
2208 : vector::IteratorType::reduction;
2213 Operation *contractOp = vector::ContractionOp::create(
2214 rewriter, loc, vecOperands[0],
2215 vecOperands[1], vecOperands[2],
2216 linalgOp.getIndexingMaps(), rewriter.
getArrayAttr(iterAttrs), *maybeKind);
2217 contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2221 rewriter, loc, contractOp->
getResult(0), outOperand->
get());
2225 newResults.push_back(write->
getResult(0));
2231 enum class ConvOperationKind { Conv, Pool };
2249 static std::optional<ConvOperationKind>
2251 int numBlockArguments =
2252 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
2254 switch (numBlockArguments) {
2260 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
2261 llvm::IsaPred<BlockArgument>);
2263 "Expected a non-block argument operand");
2264 Operation *feedOp = (*feedValIt).getDefiningOp();
2266 return ConvOperationKind::Pool;
2269 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2270 (isa<arith::AndIOp>(feedOp) &&
2273 if (isa<BlockArgument>(v))
2275 if (Operation *op = v.getDefiningOp())
2276 return isCastOfBlockArgument(op);
2279 return std::nullopt;
2282 return ConvOperationKind::Conv;
2286 return ConvOperationKind::Pool;
2288 return std::nullopt;
2294 case vector::CombiningKind::ADD:
2295 case vector::CombiningKind::MAXNUMF:
2296 case vector::CombiningKind::MAXIMUMF:
2297 case vector::CombiningKind::MAXSI:
2298 case vector::CombiningKind::MAXUI:
2299 case vector::CombiningKind::MINNUMF:
2300 case vector::CombiningKind::MINIMUMF:
2301 case vector::CombiningKind::MINSI:
2310 auto getOperandType = [&](
auto operand) {
2311 return dyn_cast<ShapedType>((operand->get()).getType());
2313 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2314 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2315 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2319 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2320 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2328 if (!maybeOper.has_value())
2335 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2336 *maybeKind != vector::CombiningKind::OR) &&
2337 (*maybeOper != ConvOperationKind::Pool ||
2342 auto rhsRank = rhsShapedType.getRank();
2343 if (*maybeOper == ConvOperationKind::Pool) {
2347 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2356 bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
2358 if (llvm::any_of(linalgOp->getOpOperands(), [&](
OpOperand &operand) {
2359 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2363 if (!inputVectorSizes.empty() &&
2369 linalgOp, flatten1DDepthwiseConv))) {
2370 LDBG() <<
"Dynamically-shaped op failed vectorization pre-conditions";
2383 customPreconditions,
2386 customPrecondition(&innerOp, vectorizeNDExtract));
2390 if (!llvm::all_of(innerOp.getOperandTypes(),
2391 VectorType::isValidElementType)) {
2394 if (!llvm::all_of(innerOp.getResultTypes(),
2395 VectorType::isValidElementType)) {
2405 if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2412 LDBG() <<
"precondition failed: not projected permutations";
2416 LDBG() <<
"precondition failed: reduction preconditions";
2422 static LogicalResult
2425 auto padValue = packOp.getPaddingValue();
2428 LDBG() <<
"pad value is not constant: " << packOp;
2433 bool satisfyEmptyCond =
true;
2434 if (inputVectorSizes.empty()) {
2435 if (!packOp.getDestType().hasStaticShape() ||
2436 !packOp.getSourceType().hasStaticShape())
2437 satisfyEmptyCond =
false;
2440 if (!satisfyEmptyCond &&
2442 resultTensorShape.take_front(packOp.getSourceRank()),
2446 if (llvm::any_of(packOp.getInnerTiles(), [](
OpFoldResult v) {
2447 return !getConstantIntValue(v).has_value();
2449 LDBG() <<
"inner_tiles must be constant: " << packOp;
2456 static LogicalResult
2459 auto padValue = padOp.getConstantPaddingValue();
2461 LDBG() <<
"pad value is not constant: " << padOp;
2481 if (llvm::any_of(
llvm::enumerate(padOp.getLow()), [&](
const auto &en) {
2482 Value padValue = en.value();
2483 unsigned pos = en.index();
2484 std::optional<int64_t> pad = getConstantIntValue(padValue);
2485 return (!pad.has_value() || pad.value() != 0) &&
2486 resultTensorShape[pos] != 1;
2488 LDBG() <<
"low pad must all be zero for all non unit dims: " << padOp;
2501 static LogicalResult
2505 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2506 "Number of input vector sizes and scalable dims doesn't match");
2508 size_t numOfScalableDims =
2509 llvm::count_if(inputScalableVecDims, [](
bool flag) {
return flag; });
2511 if (numOfScalableDims == 0)
2514 auto linalgOp = dyn_cast<LinalgOp>(op);
2519 return success(isa<linalg::UnPackOp>(op));
2523 if (numOfScalableDims > 2)
2543 bool seenNonUnitParallel =
false;
2544 auto iterators = linalgOp.getIteratorTypesArray();
2546 int64_t idx = scalableFlags.size() - 1;
2547 while (!scalableFlags[idx]) {
2548 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2549 seenNonUnitParallel |=
2550 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2552 iterators.pop_back();
2553 scalableFlags.pop_back();
2558 switch (iterators.back()) {
2559 case utils::IteratorType::reduction: {
2561 if (iterators.size() != inputVectorSizes.size()) {
2562 LDBG() <<
"Non-trailing reduction dim requested for scalable "
2566 if (isa<linalg::MatmulOp>(op)) {
2568 <<
"Scalable vectorization of the reduction dim in Matmul-like ops "
2574 case utils::IteratorType::parallel: {
2576 if (seenNonUnitParallel) {
2577 LDBG() <<
"Inner parallel dim not requested for scalable "
2589 if (numOfScalableDims == 2) {
2593 if (iterators.back() == utils::IteratorType::reduction) {
2594 LDBG() <<
"Higher dim than the trailing reduction dim requested for "
2599 scalableFlags.pop_back();
2600 iterators.pop_back();
2602 if (!scalableFlags.back() ||
2603 (iterators.back() != utils::IteratorType::parallel))
2609 return success(
isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2610 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2611 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2612 isa<linalg::BatchMmt4DOp>(op) ||
2619 bool flatten1DDepthwiseConv) {
2625 inputScalableVecDims)))
2629 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
2632 flatten1DDepthwiseConv);
2634 .Case<tensor::PadOp>([&](
auto padOp) {
2637 .Case<linalg::PackOp>([&](
auto packOp) {
2640 .Case<linalg::UnPackOp>([&](
auto unpackOp) {
2643 .Case<tensor::InsertSliceOp>([&](
auto sliceOp) {
2646 .Default([](
auto) {
return failure(); });
2652 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2654 for (
auto op : make_early_inc_range(toReplace)) {
2657 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2658 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2659 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2665 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2666 tensor::InsertSliceOp>(op);
2672 bool flatten1DDepthwiseConv,
bool assumeDynamicDimsMatchVecSizes,
2673 bool createNamedContraction) {
2674 LDBG() <<
"Attempting to vectorize: " << *op;
2675 LDBG() <<
"Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2676 LDBG() <<
"Input scalable vector dims: "
2677 << llvm::interleaved(inputScalableVecDims);
2681 flatten1DDepthwiseConv))) {
2682 LDBG() <<
"Vectorization pre-conditions failed";
2688 if (
auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2689 if (
failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2690 inputScalableVecDims,
2691 assumeDynamicDimsMatchVecSizes))) {
2692 LDBG() <<
"Vectorization state couldn't be initialized";
2698 auto vectorizeResult =
2700 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
2704 if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2706 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2707 flatten1DDepthwiseConv);
2708 if (succeeded(convOr)) {
2709 llvm::append_range(results, (*convOr)->getResults());
2713 LDBG() <<
"Unsupported convolution can't be vectorized.";
2717 if (createNamedContraction &&
2718 isa<ContractionOpInterface>(linalgOp.getOperation()))
2723 <<
"Vectorize generic by broadcasting to the canonical vector "
2736 .Case<tensor::PadOp>([&](
auto padOp) {
2740 .Case<linalg::PackOp>([&](
auto packOp) {
2744 .Case<linalg::UnPackOp>([&](
auto unpackOp) {
2747 inputScalableVecDims, results);
2749 .Case<tensor::InsertSliceOp>([&](
auto sliceOp) {
2753 .Default([](
auto) {
return failure(); });
2755 if (
failed(vectorizeResult)) {
2756 LDBG() <<
"Vectorization failed";
2764 memref::CopyOp copyOp) {
2765 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2766 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2767 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2772 if (!VectorType::isValidElementType(srcElementType) ||
2773 !VectorType::isValidElementType(dstElementType))
2784 rewriter, loc, readType, copyOp.getSource(), indices,
2787 if (cast<VectorType>(
readValue.getType()).getRank() == 0) {
2791 vector::BroadcastOp::create(rewriter, loc, writeType,
readValue);
2793 Operation *writeValue = vector::TransferWriteOp::create(
2794 rewriter, loc,
readValue, copyOp.getTarget(), indices,
2805 template <
typename OpTy>
2813 for (
auto *user : llvm::to_vector<4>(padOp->getUsers()))
2814 if (
auto op = dyn_cast<OpTy>(user))
2815 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2821 tensor::PadOp padOp, OpTy op)
const = 0;
2849 vector::TransferReadOp xferOp)
const override {
2851 if (!padOp.hasZeroLowPad())
2854 auto padValue = padOp.getConstantPaddingValue();
2858 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2863 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2865 xferOp.getBaseMutable().assign(padOp.getSource());
2866 xferOp.getPaddingMutable().assign(padValue);
2911 vector::TransferWriteOp xferOp)
const override {
2913 if (xferOp.getTransferRank() == 0)
2917 if (!padOp.hasZeroLowPad())
2920 auto padValue = padOp.getConstantPaddingValue();
2924 if (!xferOp->hasOneUse())
2926 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2930 if (!trimPadding.hasZeroOffset())
2933 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2941 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2942 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2944 rewriter.
replaceOp(trimPadding, newXferOp->getResult(0));
2960 tensor::ExtractSliceOp afterTrimming)
const {
2963 if (
auto castOp = beforePadding.
getDefiningOp<tensor::CastOp>())
2964 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2967 auto t1 = dyn_cast<RankedTensorType>(beforePadding.
getType());
2968 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2973 if (t1.getRank() != t2.getRank())
2978 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2979 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2981 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2986 if (t1.getNumDynamicDims() == 0)
2994 auto beforeSlice = beforePadding.
getDefiningOp<tensor::ExtractSliceOp>();
2998 assert(
static_cast<size_t>(t1.getRank()) ==
2999 beforeSlice.getMixedSizes().size());
3000 assert(
static_cast<size_t>(t2.getRank()) ==
3001 afterTrimming.getMixedSizes().size());
3003 for (
unsigned i = 0; i < t1.getRank(); ++i) {
3005 if (!t1.isDynamicDim(i))
3007 auto size1 = beforeSlice.getMixedSizes()[i];
3008 auto size2 = afterTrimming.getMixedSizes()[i];
3015 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3016 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3022 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3023 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3024 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3025 minOp1.getOperands() == minOp2.getOperands())
3051 if (
auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3052 auto source = bcast.getSource();
3053 if (llvm::dyn_cast<VectorType>(source.getType()))
3061 if (
auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3062 return fill.getInputs()[0];
3067 if (
auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3074 if (
auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3082 if (
auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3088 static LogicalResult
3097 auto sourceType = source.getType();
3098 auto resultType = sliceOp.getResultType();
3103 auto elemType = sourceType.getElementType();
3104 padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3110 size_t rankDiff = resultType.getRank() - sourceType.getRank();
3111 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3112 if (!inputVectorSizes.empty()) {
3113 vecShape.push_back(inputVectorSizes[i]);
3114 }
else if (!sourceType.isDynamicDim(i)) {
3115 vecShape.push_back(sourceType.getDimSize(i));
3116 }
else if (!resultType.isDynamicDim(i)) {
3122 vecShape.push_back(resultType.getDimSize(rankDiff + i));
3129 auto vecType =
VectorType::get(vecShape, sourceType.getElementType());
3132 auto loc = sliceOp.getLoc();
3138 rewriter, loc, source, vecType.getShape(), padValue,
3139 inputVectorSizes.empty(),
3147 writeIndices, inputVectorSizes.empty());
3150 newResults.push_back(write->
getResult(0));
3184 tensor::InsertSliceOp insertOp)
const override {
3186 if (!padOp.hasZeroLowPad())
3189 if (!insertOp.hasUnitStride())
3192 auto padValue = padOp.getConstantPaddingValue();
3196 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3199 if (insertOp.getDest() == padOp.getResult())
3203 padOp.getType().getElementType());
3204 unsigned vecRank = vecType.getRank();
3205 unsigned tensorRank = insertOp.getType().getRank();
3210 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3212 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](
auto it) {
3213 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3225 auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3226 vecType, padOp.getSource(),
3227 readIndices, padValue);
3233 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3236 insertOp, read, insertOp.getDest(), writeIndices,
3262 LDBG() <<
"interleavedUses precondition failed, firstOp: " << *firstOp
3263 <<
", second op: " << *secondOp;
3266 for (
auto v : values) {
3267 for (
auto &u : v.getUses()) {
3269 if (owner == firstOp || owner == secondOp)
3275 LDBG() <<
" found interleaved op " << *owner <<
", firstOp: " << *firstOp
3276 <<
", second op: " << *secondOp;
3286 memref::SubViewOp subViewOp;
3288 if (
auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3290 return memref::SubViewOp();
3291 subViewOp = newSubViewOp;
3303 if (xferOp.getMask())
3307 Value viewOrAlloc = xferOp.getBase();
3316 Value subView = subViewOp.getResult();
3319 memref::CopyOp copyOp;
3320 for (
auto &u : subView.
getUses()) {
3321 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3322 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3323 if (newCopyOp.getTarget() != subView)
3337 for (
auto &u : viewOrAlloc.
getUses()) {
3338 if (
auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3339 assert(isa<MemRefType>(newFillOp.output().getType()));
3340 if (newFillOp.output() != viewOrAlloc)
3344 maybeFillOp = newFillOp;
3349 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3351 "padding value does not match fill");
3354 Value in = copyOp.getSource();
3360 auto vectorType = xferOp.getVectorType();
3361 Value res = vector::TransferReadOp::create(
3362 rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3363 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3368 rewriter.
eraseOp(maybeFillOp);
3380 if (xferOp.getMask())
3384 Value viewOrAlloc = xferOp.getBase();
3393 Value subView = subViewOp.getResult();
3396 memref::CopyOp copyOp;
3397 for (
auto &u : subViewOp.getResult().getUses()) {
3398 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3399 if (newCopyOp.getSource() != subView)
3411 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3412 Value out = copyOp.getTarget();
3419 auto vector = xferOp.getVector();
3420 vector::TransferWriteOp::create(
3421 rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3422 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3424 dyn_cast<VectorType>(vector.getType()).getRank(),
false)));
3439 template <
int N,
typename IntTy,
typename... IntTy2>
3441 val = shapedType.getShape()[N];
3446 template <
typename... IntTy>
3448 bindShapeDims<0>(shapedType, vals...);
3486 struct Conv1DGenerator
3488 Conv1DGenerator(
RewriterBase &rewriter, LinalgOp linalgOp)
3491 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3492 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3493 resShaped = linalgOp.getDpsInitOperand(0)->get();
3494 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3495 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3496 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3501 setConvOperationKind(reduceOp);
3504 reductionKind = maybeKind.value();
3512 strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3513 dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3535 int64_t nSize, wSize, cSize, kwSize, fSize;
3538 switch (conv1DOpOrder) {
3541 nSize = fSize = cSize = 0;
3548 (wSize + kwSize - 1)};
3549 rhsShape = {kwSize};
3556 case ConvOperationKind::Conv:
3560 case ConvOperationKind::Pool:
3570 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3574 case ConvOperationKind::Conv:
3575 rhsShape = {kwSize, cSize, fSize};
3577 case ConvOperationKind::Pool:
3578 rhsShape = {kwSize};
3581 resShape = {nSize, wSize, fSize};
3587 case ConvOperationKind::Conv:
3591 case ConvOperationKind::Pool:
3597 lhsShape = {nSize, cSize,
3601 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3604 case ConvOperationKind::Conv:
3605 rhsShape = {fSize, cSize, kwSize};
3607 case ConvOperationKind::Pool:
3608 rhsShape = {kwSize};
3611 resShape = {nSize, fSize, wSize};
3615 vector::TransferWriteOp write;
3621 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3623 Type lhsEltType = lhsShapedType.getElementType();
3624 Type rhsEltType = rhsShapedType.getElementType();
3625 Type resEltType = resShapedType.getElementType();
3635 Value lhs = vector::TransferReadOp::create(
3636 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3639 Value rhs =
nullptr;
3640 if (oper == ConvOperationKind::Conv)
3641 rhs = vector::TransferReadOp::create(
3642 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3644 Value res = vector::TransferReadOp::create(
3645 rewriter, loc, resType, resShaped, resPadding,
3651 switch (conv1DOpOrder) {
3659 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3660 lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3662 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3665 if (oper == ConvOperationKind::Conv)
3666 rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3668 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3669 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3680 kwSize, strideW, dilationW, wSizeStep,
3683 if (oper == ConvOperationKind::Conv)
3686 wSizeStep, isSingleChanneled);
3688 auto linearIndex = [&](int64_t kw, int64_t w) {
3689 return kw * (wSize / wSizeStep) + w;
3695 for (int64_t kw = 0; kw < kwSize; ++kw) {
3696 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3698 case ConvOperationKind::Conv:
3699 if (isSingleChanneled) {
3700 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3701 lhsVals[linearIndex(kw, w)],
3702 rhsVals[kw], resVals[w]);
3704 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3705 lhsVals[linearIndex(kw, w)],
3706 rhsVals[kw], resVals[w]);
3709 case ConvOperationKind::Pool:
3710 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3726 switch (conv1DOpOrder) {
3733 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3734 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3739 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3748 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3749 if (srcElementType == dstElementType)
3754 const Type dstType =
3755 cast<ShapedType>(val.
getType()).cloneWith(std::nullopt, dstElementType);
3757 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3758 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3761 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3762 srcWidth < dstWidth)
3763 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3765 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3766 srcWidth < dstWidth)
3767 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3769 assert(
false &&
"unhandled promotion case");
3776 vector::IteratorType par = vector::IteratorType::parallel;
3777 vector::IteratorType red = vector::IteratorType::reduction;
3782 auto contrationOp = vector::ContractionOp::create(
3783 rewriter, loc, lhs, rhs, res,
3784 MapList{{n, w, c}, {c, f}, {n, w, f}},
3786 contrationOp.setKind(reductionKind);
3787 return contrationOp;
3794 return vector::OuterProductOp::create(rewriter, loc, res.
getType(), lhs,
3795 rhs, res, vector::CombiningKind::ADD);
3817 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3818 bool channelDimScalableFlag,
3820 bool scalableChDim =
false;
3821 bool useMasking =
false;
3822 int64_t nSize, wSize, cSize, kwSize;
3825 if (ShapedType::isDynamic(cSize)) {
3826 assert(channelDimVecSize != 0 &&
"Channel dim vec size must be > 0");
3827 cSize = channelDimVecSize;
3831 scalableChDim = channelDimScalableFlag;
3835 assert(!(useMasking && flatten) &&
3836 "Unsupported flattened conv with dynamic shapes");
3841 vector::TransferWriteOp write;
3847 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3849 Type lhsEltType = lhsShapedType.getElementType();
3850 Type rhsEltType = rhsShapedType.getElementType();
3851 Type resEltType = resShapedType.getElementType();
3856 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3858 lhsEltType, {
false,
false, scalableChDim});
3859 VectorType rhsType =
3861 {
false, scalableChDim});
3862 VectorType resType =
3864 {
false,
false, scalableChDim});
3877 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3878 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3882 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3885 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3892 Value lhs = vector::TransferReadOp::create(
3893 rewriter, loc, lhsType, lhsShaped,
ValueRange{zero, zero, zero},
3895 auto maybeMaskedLhs = maybeMaskXferOp(
3896 lhsType.getShape(), lhsType.getScalableDims(), lhs.
getDefiningOp());
3899 Value rhs = vector::TransferReadOp::create(
3900 rewriter, loc, rhsType, rhsShaped,
ValueRange{zero, zero},
3902 auto maybeMaskedRhs = maybeMaskXferOp(
3903 rhsType.getShape(), rhsType.getScalableDims(), rhs.
getDefiningOp());
3906 Value res = vector::TransferReadOp::create(
3907 rewriter, loc, resType, resShaped,
ValueRange{zero, zero, zero},
3909 auto maybeMaskedRes = maybeMaskXferOp(
3910 resType.getShape(), resType.getScalableDims(), res.
getDefiningOp());
3922 for (int64_t kw = 0; kw < kwSize; ++kw) {
3923 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3924 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3925 rewriter, loc, maybeMaskedLhs->getResult(0),
3927 inOutSliceSizes, inOutStrides));
3931 for (int64_t kw = 0; kw < kwSize; ++kw) {
3933 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3937 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3938 resVals.push_back(vector::ExtractStridedSliceOp::create(
3939 rewriter, loc, maybeMaskedRes->getResult(0),
3944 auto linearIndex = [&](int64_t kw, int64_t w) {
3945 return kw * (wSize / wSizeStep) + w;
3951 auto lhsTypeAfterFlattening =
3953 auto resTypeAfterFlattening =
3957 for (int64_t kw = 0; kw < kwSize; ++kw) {
3958 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3959 Value lhsVal = lhsVals[linearIndex(kw, w)];
3960 Value resVal = resVals[w];
3965 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3966 lhsVals[linearIndex(kw, w)]);
3967 resVal = vector::ShapeCastOp::create(
3968 rewriter, loc, resTypeAfterFlattening, resVals[w]);
3970 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3971 rhsVals[kw], resVal, flatten);
3974 resVals[w] = vector::ShapeCastOp::create(
3982 if (!llvm::all_of(resVals, [](
Value v) {
return v; })) {
3984 for (
auto &collection :
3985 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3986 for (
Value v : collection)
3993 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3994 maybeMaskedRes = vector::InsertStridedSliceOp::create(
3995 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
4004 Operation *resOut = vector::TransferWriteOp::create(
4005 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4007 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4018 auto rhsTy = cast<ShapedType>(rhs.
getType());
4019 auto resTy = cast<ShapedType>(res.
getType());
4022 lhs =
promote(rewriter, loc, lhs, resTy);
4033 auto rhsSize = cast<VectorType>(rhs.
getType()).getShape()[0];
4034 auto resSize = cast<VectorType>(res.
getType()).getShape()[1];
4037 for (
int i = 0; i < resSize / rhsSize; ++i) {
4038 for (
int j = 0;
j < rhsSize; ++
j)
4039 indices.push_back(
j);
4042 rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4045 rhs = vector::BroadcastOp::create(rewriter, loc,
4046 resTy.clone(rhsTy.getElementType()), rhs);
4048 rhs =
promote(rewriter, loc, rhs, resTy);
4053 if (isa<FloatType>(resTy.getElementType()))
4054 return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4056 auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4057 return arith::AddIOp::create(rewriter, loc, mul, res);
4062 FailureOr<Operation *> generateNonChanneledConv() {
4065 if (!iters({Par(), Red()}))
4067 "failed to match conv::W 1-par 1-red");
4070 if (layout({ {w + kw},
4080 FailureOr<Operation *> generateNwcConv() {
4083 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4085 op,
"failed to match conv::Nwc 3-par 2-red");
4088 if (layout({ {n, strideW * w + dilationW * kw, c},
4098 FailureOr<Operation *> generateNcwConv() {
4101 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4103 op,
"failed to match conv::Ncw 3-par 2-red");
4105 if (layout({ {n, c, strideW * w + dilationW * kw},
4115 FailureOr<Operation *> generateNwcPooling() {
4118 if (!iters({Par(), Par(), Par(), Red()}))
4120 "failed to match pooling 3-par 1-red");
4123 if (layout({ {n, strideW * w + dilationW * kw, c},
4133 FailureOr<Operation *> generateNcwPooling() {
4136 if (!iters({Par(), Par(), Par(), Red()}))
4138 "failed to match pooling 3-par 1-red");
4140 if (layout({ {n, c, strideW * w + dilationW * kw},
4150 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4151 bool vecChDimScalableFlag =
false,
4152 bool flatten =
false) {
4155 if (!iters({Par(), Par(), Par(), Red()}))
4157 op,
"failed to match depthwise::Nwc conv 3-par 1-red");
4160 if (layout({ {n, strideW * w + dilationW * kw, c},
4163 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4169 ConvOperationKind oper = ConvOperationKind::Conv;
4171 StringAttr poolExtOp;
4172 bool isPoolExt =
false;
4173 int strideW, dilationW;
4174 Value lhsShaped, rhsShaped, resShaped;
4175 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4176 vector::CombiningKind reductionKind;
4179 void setConvOperationKind(
Operation *reduceOp) {
4180 int numBlockArguments =
4181 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
4182 if (numBlockArguments == 1) {
4187 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
4188 llvm::IsaPred<BlockArgument>);
4189 Operation *feedOp = (*feedValIt).getDefiningOp();
4191 oper = ConvOperationKind::Pool;
4196 oper = ConvOperationKind::Conv;
4200 oper = ConvOperationKind::Pool;
4210 ArrayRef<bool> inputScalableVecDims,
bool flatten1DDepthwiseConv) {
4211 Conv1DGenerator conv1dGen(rewriter, op);
4212 auto res = conv1dGen.generateNonChanneledConv();
4215 res = conv1dGen.generateNwcConv();
4218 res = conv1dGen.generateNcwConv();
4221 res = conv1dGen.generateNwcPooling();
4224 res = conv1dGen.generateNcwPooling();
4231 uint64_t vecChDimSize = ShapedType::kDynamic;
4232 bool vecChDimScalableFlag =
false;
4233 if (!inputVecSizes.empty()) {
4236 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4237 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4238 "Not a 1D depthwise conv!");
4241 .Case<linalg::DepthwiseConv1DNwcWcOp>([](
auto conv) {
return 2; })
4242 .Case<linalg::DepthwiseConv1DNcwCwOp>([](
auto conv) {
return 1; });
4244 vecChDimSize = inputVecSizes[chDimIdx];
4245 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4247 return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4248 flatten1DDepthwiseConv);
4257 if (
failed(resultOrFail))
4261 rewriter.
eraseOp(op.getOperation());
4264 assert(newOp->
getNumResults() == 1 &&
"expected single result");