#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-vectorization"
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    // ... (reset `res` and interrupt the walk if a second instance is found)
  });
  return res;
}
/// Helper function to extract the input slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Non-channeled convolution: extract {wSizeStep}-sized slices @ [w + kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw},
            /*sizes=*/{wSizeStep}, /*strides=*/{1}));
      }
    }
  } else {
    // Channeled convolution: extract {n, wSizeStep, c}-sized slices
    // @ [0, w * strideW + kw * dilationW, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/{nSize, wSizeStep, cSize}, /*strides=*/{1, 1, 1}));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after the filter is unrolled
/// along kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  for (int64_t kw = 0; kw < kwSize; ++kw)
    result.push_back(vector::ExtractOp::create(rewriter, loc, filter, kw));
  return result;
}
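// Worked example (not part of the original file): unrolling a channeled 1-D
// conv with wSize = 2, kwSize = 3, strideW = 2, dilationW = 1 and
// wSizeStep = 1 extracts input slices at w-offsets
//   w * strideW + kw * dilationW = {0, 1, 2, 2, 3, 4},
// i.e. each (kw, w) pair reads the input window that contributes to output
// element w for filter tap kw.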
/// Helper function to extract the result slices after res is unrolled along w.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Non-channeled convolution: extract {wSizeStep}-sized slices @ [w].
    for (int64_t w = 0; w < wSize; w += wSizeStep)
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w},
          /*sizes=*/{wSizeStep}, /*strides=*/{1}));
  } else {
    // Channeled convolution: extract {n, wSizeStep, f}-sized slices @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep)
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/{nSize, wSizeStep, fSize}, /*strides=*/{1, 1, 1}));
  }
  return result;
}
/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize,
                                    int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Non-channeled convolution: insert {wSizeStep}-sized slices @ [w].
    for (int64_t w = 0; w < wSize; w += wSizeStep)
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
          /*strides=*/{1});
  } else {
    // Channeled convolution: insert {n, wSizeStep, f}-sized slices @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep)
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, /*strides=*/{1, 1, 1});
  }
  return res;
}
  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          bool assumeDynamicDimsMatchVecSizes = false);
  /// Returns the canonical vector type used to vectorize the iteration space,
  /// optionally permuted/projected by `dimPermutation`.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }
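  // Illustrative sketch (not part of the original file): how
  // applyPermutationMap reorders/projects a shape. Given the projection map
  // (d0, d1, d2) -> (d2, d0) and canonicalVecShape = {8, 16, 4}:
  //
  //   AffineMap perm = AffineMap::get(
  //       /*dimCount=*/3, /*symbolCount=*/0,
  //       {getAffineDimExpr(2, ctx), getAffineDimExpr(0, ctx)}, ctx);
  //   SmallVector<int64_t> shape =
  //       applyPermutationMap<int64_t>(perm, ArrayRef<int64_t>{8, 16, 4});
  //   // shape == {4, 8}; the resulting canonical type is vector<4x8xT>.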
  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

private:
  /// Initializes the iteration space static sizes.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for all the
  /// static and dynamic dimensions of the iteration space.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Creates or reuses a mask for `opToMask` given its masking map.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Returns true if `maskingMap` is a valid masking map, i.e., a projected
  /// permutation.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.isProjectedPermutation();
  }

  /// Whether to assume that dynamic dimension sizes match the corresponding
  /// vector sizes.
  bool assumeDynamicDimsMatchVecSizes = false;
LogicalResult VectorizationState::precomputeIterSpaceValueSizes(
    RewriterBase &rewriter, LinalgOp linalgOp) {
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
          rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim =
        linalgOp.hasPureTensorSemantics()
            ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
                                           operandDimPos)
            : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
                                           operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims,
                              bool assumeDimsMatchVec) {
  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
  // Initialize the canonical vector shape from the input vector sizes, if
  // provided, or from the static loop ranges otherwise.
  if (!inputVectorSizes.empty()) {
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize the iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for all the
  // static and dynamic dimensions of the iteration space.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG() << "Masking map: " << maskingMap;

  // Return the active mask for this masking map if it was already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG() << "Reusing mask: " << mask;
    return mask;
  }

  // Permute the iteration space static sizes to match the masking map and
  // compare them against the canonical mask shape: if they are equal, no mask
  // is needed.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);

  if (permutedStaticSizes == maskShape) {
    LDBG() << "Masking is not needed for masking map: " << maskingMap;
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  if (assumeDynamicDimsMatchVecSizes) {
    // While for _dynamic_ dim sizes we can _assume_ that the corresponding
    // vector sizes match, the _static_ dim sizes still have to be checked.
    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
                     [](auto it) {
                       return std::get<0>(it) == ShapedType::kDynamic
                                  ? true
                                  : std::get<0>(it) == std::get<1>(it);
                     })) {
      LDBG()
          << "Dynamic + static dimensions match vector sizes, masking is not "
             "required.";
      activeMaskCache[maskingMap] = Value();
      return Value();
    }
  }

  // Compute the mask upper bounds and create the mask.
  SmallVector<Value> upperBounds =
      applyPermutationMap<Value>(maskingMap, iterSpaceValueSizes);
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
                                            maskType, upperBounds);
  LDBG() << "Creating new mask: " << mask;
  activeMaskCache[maskingMap] = mask;
  return mask;
}
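// Illustrative sketch (not part of the original file): for a canonical vector
// shape 8x16 with runtime upper bounds (%d0, %d1), the generated mask is
//
//   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>
//
// where lane (i, j) is true iff i < %d0 and j < %d1. Wrapping an xfer op in
// vector.mask with this value restricts it to the in-bounds elements, which
// is why masked reads/writes can later set in_bounds to true.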
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG() << "Trying to mask: " << *opToMask;

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve the mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG() << "No mask required";
    if (assumeDynamicDimsMatchVecSizes) {
      llvm::TypeSwitch<Operation *>(opToMask)
          .Case<vector::TransferReadOp, vector::TransferWriteOp>(
              [&](auto xferOp) {
                LDBG() << "Assuming dynamic dimensions match vector sizes and "
                          "setting their in-bounds to true!";
                AffineMap permMap = xferOp.getPermutationMap();
                ShapedType xferType = xferOp.getShapedType();
                SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
                // Mark every dynamic dim accessed by the xfer op as in-bounds.
                for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
                  auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
                  if (!dimExpr)
                    continue;
                  unsigned pos = dimExpr.getPosition();
                  if (xferType.isDynamicDim(pos))
                    inBoundsMap[i] = true;
                }
                xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBoundsMap));
              });
    }
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update the use-def chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG() << "Masked operation: " << *maskOp;
  return maskOp;
}
597 "expected projected permutation");
599 assert(res.getNumDims() ==
600 (res.getNumResults() - res.getNumOfZeroResults()) &&
601 "expected reindexed map with same number of dims and results");
/// Check whether `combinerOp` is a supported reduction kernel and, if so,
/// return its vector::CombiningKind.
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
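// Illustrative, self-contained sketch (not part of the original file) of the
// llvm::TypeSwitch dispatch pattern used above, reduced to a toy hierarchy
// with LLVM-style RTTI. `combinerName`, `Op`, `AddOp` and `MulOp` are
// hypothetical names introduced only for this example.
//
//   #include "llvm/ADT/TypeSwitch.h"
//   #include <optional>
//   #include <string>
//
//   struct Op {
//     enum Kind { KAdd, KMul } kind;
//     explicit Op(Kind k) : kind(k) {}
//   };
//   struct AddOp : Op {
//     AddOp() : Op(KAdd) {}
//     static bool classof(const Op *op) { return op->kind == KAdd; }
//   };
//   struct MulOp : Op {
//     MulOp() : Op(KMul) {}
//     static bool classof(const Op *op) { return op->kind == KMul; }
//   };
//
//   // Dispatch on the dynamic type; Default yields nullopt, exactly like
//   // getCombinerOpKind does for unsupported combiner operations.
//   std::optional<std::string> combinerName(Op *op) {
//     return llvm::TypeSwitch<Op *, std::optional<std::string>>(op)
//         .Case<AddOp>([](AddOp *) { return std::string("add"); })
//         .Case<MulOp>([](MulOp *) { return std::string("mul"); })
//         .Default([](Op *) { return std::nullopt; });
//   }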
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, or nullptr
/// otherwise.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
/// Broadcast `value` to a vector of `dstType` shape if needed.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If there is no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}

/// Create MultiDimReductionOp to compute the reduction for `reduceOp`.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return vector::MultiDimReductionOp::create(
      b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}

/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`; where `outputOperand` is an output operand of the linalgOp
/// currently being vectorized.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store, keeping only the
  // iteration-space dims used by this operand, i.e. those for which
  // llvm::is_contained(opOperandMap.getResults(), dimExpr) holds.
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), opOperandMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        arith::ConstantIndexOp::create(rewriter, loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(
        rewriter, loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case: broadcast the scalar if needed and write without indices.
    if (!isa<VectorType>(value.getType()))
      value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(rewriter, loc, value,
                                            outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true: masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG() << "vectorized op: " << *write;
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
/// Custom vectorization precondition hook type: takes the operation to check
/// and a flag indicating whether n-D tensor.extract vectorization is enabled.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`.
static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    Value newResult = buildVectorWrite(
        rewriter, bvm.lookup(output.value()),
        linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }
  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`.
static VectorizationHookResult
vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state,
                     Operation *op, LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  auto dim = indexOp.getDim();
  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space...
  if (dim == targetShape.size() - 1)
    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
  // ...otherwise broadcast the index along the remaining dimensions and
  // transpose the broadcasted dimension back into place.
  SmallVector<unsigned> permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = vector::BroadcastOp::create(
      rewriter, loc,
      state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}
/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();
  // Currently only supports extraction with a 1-D index unless n-D extract
  // vectorization was requested.
  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type.
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (!llvm::all_of(extractOp->getResultTypes(),
                    VectorType::isValidElementType)) {
    return failure();
  }

  return success();
}
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
  }

  return offset;
}
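// Worked example (not part of the original file): for a 3-D tensor with dim
// sizes (d0, d1, d2) = (4, 5, 6) and indices (i0, i1, i2) = (1, 2, 3), the
// loop above computes the row-major linearized offset
//   ((i0 * d1) + i1) * d2 + i2 = ((1 * 5) + 2) * 6 + 3 = 45,
// i.e. exactly the position of tensor[1][2][3] in a flat buffer; the gather
// then uses one such offset per vector lane.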
/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
/// represents a contiguous load operation.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;
  return idx;
}
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, if `val` is a block argument of _this_ linalg.generic, it is loop
  // variant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // A linalg.index yields a loop-invariant value iff the corresponding loop
  // has a trip count of 1.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return true;

  // Values defined outside `linalgOp` or by constants are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}
/// Checks whether `val` could be used for calculating the trailing index for a
/// contiguous load operation. `foundIndexOp` is set to true when `val` is
/// computed from a linalg.index over the trailing, unit-stride loop dim.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Block arguments of _this_ linalg.generic cannot drive a contiguous load;
  // values from outside the op can.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return false;

  // Conservatively reject ops other than addition, constants and linalg.index.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
/// Infer the memory access pattern for the input tensor.extract op: gather,
/// scalar broadcast, or contiguous load.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1-D, e.g. `vector<1x4x1xi32>`,
  // false otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading a non-1-D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`: they must all be loop
  // invariant (unit input dims are skipped).
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;
    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG() << "Found gather load: " << extractOp;
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index of `extractOp`.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load: _all_ indices are loop invariant.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG() << "Found scalar broadcast load: " << extractOp;
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous load: the trailing index is computed from a linalg.index
  // over the trailing, contiguous dim, and the output is a row vector.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG() << "Found contiguous load: " << extractOp;
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback: gather load.
  LDBG() << "Found gather load: " << extractOp;
  return VectorMemoryAccessKind::Gather;
}
/// Helper function to vectorize the tensor.extract operations.
static VectorizationHookResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = arith::ConstantOp::create(
      rewriter, loc,
      DenseIntElementsAttr::get(
          state.getCanonicalVecType(rewriter.getI1Type()), /*value=*/true));
  auto passThruConstantOp = arith::ConstantOp::create(
      rewriter, loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      arith::ConstantIndexOp::create(rewriter, loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = vector::GatherOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG() << "Vectorised as gather load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
  }

  // 2. Handle contiguous/scalar-broadcast access: compute the read indices,
  // extracting the trailing element of any index that is itself a vector.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }
    auto indexAs1dVector = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Scalar broadcast load: read with a "zero" permutation map and wrap
  // the xfer_read in an all-true mask.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = vector::TransferReadOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
        /*padding=*/std::nullopt, permutationMap, inBounds);

    auto allTrue = vector::ConstantMaskOp::create(
        rewriter, loc, state.getCanonicalVecType(rewriter.getI1Type()),
        vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   maskedReadOp};
  }

  // 2b. Contiguous load: read with a minor-identity permutation map, padding
  // leading results with zeros where dstRank > srcRank.
  AffineMap permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min<int64_t>(dstRank, srcRank), rewriter.getContext());
  int32_t rankDiff = dstRank - srcRank;
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        rewriter.getAffineConstantExpr(0), /*pos=*/0);
    rankDiff--;
  }

  auto transferReadOp = vector::TransferReadOp::create(
      rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
      /*padding=*/std::nullopt, permutationMap, inBounds);

  LDBG() << "Vectorised as contiguous load: " << extractOp;
  return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                 transferReadOp};
}
/// Emit reduction operations if the value to reduce and the output have
/// different shapes.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp,
                                 Operation *op, Value reduceVec,
                                 Value outputVec, const IRMapping &bvm) {
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed, as the value may already have been reduced for
  // reduction ops.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`.
static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG() << "vectorize op " << *op;

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationHookResult result = customFunc(op, bvm);
      if (result.status == VectorizationHookStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   rewriter.clone(*op)};

  // 3. Only ElementwiseMappable ops are supported in the generic path.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction: an output block argument fed
  // through a chain of reduction ops.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
  }

  // 5. Generic path: broadcast all operands to the maximal-ranked vector type
  // and clone the op with vector result types.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }

  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(
          firstMaxRankedType.getShape(),
          getElementTypeOrSelf(vecOperand.getType()),
          firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }

  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }

  // Build and return the new op.
  return VectorizationHookResult{
      VectorizationHookStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. New result vector values are appended to `newResults`.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG() << "Vectorizing operation as linalg generic";
  Block *block = linalgOp.getBlock();

  // 2. Values defined above the region can only be broadcast for now; make
  // them map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // 3. Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // 3.a. Convert the indexing map for this input/output to a transfer read
    // permutation map.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // 3.a.i. For input reads we use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
      // reductions), the vector shape is computed by mapping the canonical
      // vector shape to the output indexing map.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = vector::TransferReadOp::create(
        rewriter, loc, readType, opOperand->get(), indices,
        /*padding=*/std::nullopt, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // 3.b. If masked, set in_bounds to true.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
    if (readType.getRank() == 0)
      readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                            ArrayRef<int64_t>());

    LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
           << "): " << readValue;
    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  SmallVector<CustomVectorizationHook> hooks;
  // 4a. Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
  };
  hooks.push_back(vectorizeYield);

  // 4b. Register CustomVectorizationHook for indexOp.
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  // 4c. Register CustomVectorizationHook for extractOp.
  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // 5. Iteratively call `vectorizeOneOp` on the region operations.
  for (Operation &op : block->getOperations()) {
    VectorizationHookResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationHookStatus::Failure) {
      LDBG() << "failed to vectorize: " << op;
      return failure();
    }
    if (result.status == VectorizationHookStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG() << "New vector op: " << *maybeMaskedOp;
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
/// Determines whether a mask for an xfer_write is trivially "all true": for
/// every dim, the (constant) mask size starting from the (constant) write
/// index must fit within the static destination size.
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
                                    SmallVector<Value> &writeIdxs,
                                    ArrayRef<int64_t> destShape,
                                    ArrayRef<int64_t> maskShape) {
  // Masking is unavoidable in the case of dynamic tensors.
  if (ShapedType::isDynamicShape(destShape))
    return false;

  // Collect the mask sizes; bail out if any is non-constant.
  SmallVector<int64_t> cstMaskSizes;
  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
    if (auto intSize = getConstantIntValue(dimSize))
      cstMaskSizes.push_back(*intSize);
  }
  if (cstMaskSizes.size() != maskShape.size())
    return false;

  // Collect the write indices; bail out if any is non-constant.
  SmallVector<int64_t> cstWriteIdxs;
  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
    APSInt intVal;
    if (matchPattern(idx, m_ConstantInt(&intVal)))
      cstWriteIdxs.push_back(intVal.getSExtValue());
  }
  if (cstWriteIdxs.size() != destShape.size())
    return false;

  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
    if (maskShape[i] > destShape[rankDiff + i] ||
        destShape[rankDiff + i] <
            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
             cstWriteIdxs[i]))
      return false;
  }

  return true;
}
/// Creates an optionally masked TransferWriteOp of `vecToStore` into `dest`.
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
                         Value dest, SmallVector<Value> writeIndices = {},
                         bool useInBoundsInsteadOfMasking = false) {

  ShapedType destType = cast<ShapedType>(dest.getType());
  int64_t destRank = destType.getRank();
  auto destShape = destType.getShape();

  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
  int64_t vecToStoreRank = vecToStoreType.getRank();
  auto vecToStoreShape = vecToStoreType.getShape();

  // Compute the in_bounds attribute.
  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
  if (useInBoundsInsteadOfMasking) {
    // In this case, assume that all the required vector sizes have been
    // pre-determined and compare them against the static destination sizes.
    for (unsigned i = 0; i < vecToStoreRank; i++)
      inBoundsVal[i] =
          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
  }

  // If missing, initialize the write indices to 0.
  assert((writeIndices.empty() ||
          writeIndices.size() == static_cast<size_t>(destRank)) &&
         "Invalid number of write indices!");
  if (writeIndices.empty()) {
    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
    writeIndices.assign(destRank, zero);
  }

  // Generate the xfer_write Op.
  Operation *write = vector::TransferWriteOp::create(builder, loc,
                                                     /*vector=*/vecToStore,
                                                     /*source=*/dest,
                                                     /*indices=*/writeIndices,
                                                     /*inBounds=*/inBoundsVal);

  // If masking was explicitly disabled, exit.
  if (useInBoundsInsteadOfMasking)
    return write;

  // Check if masking is needed at all: a write of the full trailing shape of
  // `dest` needs no mask.
  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
    return write;

  // Compute the mask and mask the write Op.
  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
                                       vecToStoreType.getScalableDims());

  SmallVector<OpFoldResult> destSizes =
      isa<MemRefType>(dest.getType())
          ? memref::getMixedSizes(builder, loc, dest)
          : tensor::getMixedSizes(builder, loc, dest);
  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                      destSizes.end());

  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
                              vecToStoreShape))
    return write;

  Value maskForWrite =
      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
  return mlir::vector::maskOperation(builder, write, maskForWrite);
}
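// Illustrative sketch (not part of the original file): writing a
// vector<4x8xf32> into a tensor<?x?xf32> at indices (0, 0) produces roughly
//
//   %d0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
//   %d1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
//   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
//   vector.mask %mask { vector.transfer_write %v, %dest[%c0, %c0] ... }
//
// whereas a static destination that is at least 4x8 takes one of the early
// exits above and skips the mask entirely.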
/// Vectorize linalg::PackOp with (1) static inner_tiles, (2) constant padding
/// value and (3) input vector sizes (or all-static shapes).
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  Location loc = packOp.getLoc();
  std::optional<Value> padValue = packOp.getPaddingValue()
                                      ? std::optional(packOp.getPaddingValue())
                                      : std::nullopt;

  // If the input vector sizes are not provided, then the vector sizes are
  // determined by the result tensor shape.
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }

  // Create masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  // ... (scale `inputShape` by the inner tiles along the tiled dims)
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking,
      /*inputScalableVecSizes=*/{});

  // Create ShapeCastOp.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
  auto shapeCastOp =
      vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);

  // Create TransposeOp.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = vector::TransposeOp::create(
      rewriter, loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), packOp.getDest());
  newResults.push_back(write->getResult(0));
  return success();
}
/// Collapse `type` along the dimension groups given by `reassociation`,
/// multiplying sizes and OR-ing scalable flags within each group.
static VectorType getCollapsedVecType(VectorType type,
                                      ArrayRef<AffineMap> reassociation) {
  assert(type.getNumScalableDims() < 2 &&
         "Collapsing more than 1 scalable dim is not supported ATM");

  auto shape = type.getShape();
  auto scalableFlags = type.getScalableDims();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableFlags;

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    bool flag = false;
    for (unsigned d = 0; d < dim; ++d) {
      size *= shape[currentDim + d];
      flag |= scalableFlags[currentDim + d];
    }
    newShape.push_back(size);
    newScalableFlags.push_back(flag);
    currentDim += dim;
  }
  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
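// Worked example (not part of the original file): collapsing
// vector<2x3x[4]xf32> with the reassociation {[d0], [d1, d2]} gives
// newShape = {2, 3 * 4} with scalable flags {false, true}, i.e.
// vector<2x[12]xf32>. The assert above guards the single-scalable-dim
// restriction: OR-ing two scalable flags in one group would be ambiguous.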
/// Vectorize linalg.unpack as:
///   transfer_read -> transpose -> shape_cast -> transfer_write
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
           "Invalid number of input vector sizes!");
    assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
           "Incompatible number of vector sizes and vector scalable flags!");
  }

  Location loc = unpackOp->getLoc();
  RankedTensorType unpackTensorType = unpackOp.getSourceType();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();

  bool useInBoundsInsteadOfMasking = false;
  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);

  // If no vector sizes were provided, the source shape must be fully static.
  if (inputVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(sourceShape))
      return failure();
    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
    useInBoundsInsteadOfMasking = true;
  }

  // -- Generate the read operation --
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, std::nullopt,
      useInBoundsInsteadOfMasking, readScalableVectorFlags);

  // -- Generate the transpose operation --
  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  vector::TransposeOp transposeOp = vector::TransposeOp::create(
      rewriter, loc, readResult, lastDimToInsertPosPerm);

  // -- Generate the shape_cast operation --
  VectorType collapsedVecType = getCollapsedVecType(
      transposeOp.getType(),
      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
          rewriter.getContext(), packMetadata.reassociations)));
  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, collapsedVecType, transposeOp->getResult(0));

  // -- Generate the write operation --
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
  return success();
}
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero low padding, using `inputVectorSizes`.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  // TransferWriteOp result will be stored in this empty destination.
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});

  // Create the xfer write Op.
  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
                                       padOp.getResultType().getElementType());
  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG() << "reduction precondition failed: no reduction iterator";
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG() << "reduction precondition failed: reduction detection failed";
      return failure();
    }
  }
  return success();
}
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
              "supported";
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
              "channel dim can be dynamic";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement ContractionOpInterface for now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
  return success();
}
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  // If there are no input vector sizes and all shapes are static, there is
  // nothing left to check.
  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
      unpackOp.getSourceType().hasStaticShape())
    return success();

  // The number of input vector sizes must match the source rank.
  if (!inputVectorSizes.empty() &&
      (inputVectorSizes.size() != unpackOp.getSourceRank())) {
    LDBG() << "Incorrect number of input vector sizes";
    return failure();
  }

  if (failed(vector::isValidMaskedInputVector(
          unpackOp.getSourceType().getShape(), inputVectorSizes))) {
    LDBG() << "Invalid vector sizes for the read operation";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {
  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // ...
  // An out-of-bounds read is generated when the source shape is dynamic and
  // no vector sizes were provided; in that case a pad value is mandatory.
  Value padValue = getStaticPadVal(sliceOp);
  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG() << "Failed to get a pad value for out-of-bounds read access";
    return failure();
  }
  return success();
}
/// Vectorize a named linalg contraction op into a vector.contract.
static LogicalResult
vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
                             LinalgOp linalgOp,
                             SmallVectorImpl<Value> &newResults) {
  Location loc = linalgOp.getLoc();
  MLIRContext *ctx = linalgOp.getContext();

  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
    return failure();

  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
  Operation *reduceOp = matchLinalgReduction(outOperand);
  auto maybeKind = getCombinerOpKind(reduceOp);
  if (!maybeKind) {
    LDBG() << "Failed to determine contraction combining kind.";
    return failure();
  }

  // Check that all dimensions are present in the input operands:
  // contractions with broadcasts are not supported.
  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
    LDBG() << "Contractions with broadcasts are not supported.";
    return failure();
  }

  // Load operands.
  SmallVector<Value> vecOperands;
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    // The operand vector shape is computed by mapping the canonical vector
    // shape to the operand's domain; further permutations are left to the
    // contraction itself.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
        indexingMap.getNumResults(), rewriter.getContext());
    Type elemType = getElementTypeOrSelf(opOperand.get());
    VectorType readType =
        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));

    Value read = mlir::vector::createReadOrMaskedRead(
        rewriter, loc, opOperand.get(), readType.getShape(),
        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
    vecOperands.push_back(read);
  }

  // Remap iterators from linalg to vector.
  SmallVector<Attribute> iterAttrs;
  auto iterators = linalgOp.getIteratorTypesArray();
  for (utils::IteratorType iter : iterators) {
    auto vecIter = iter == utils::IteratorType::parallel
                       ? vector::IteratorType::parallel
                       : vector::IteratorType::reduction;
    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
  }

  // Create the contraction.
  Operation *contractOp = vector::ContractionOp::create(
      rewriter, loc, /*lhs=*/vecOperands[0],
      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);

  // Store the result.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, contractOp->getResult(0), outOperand->get());

  // Finalize.
  if (!write->getResults().empty())
    newResults.push_back(write->getResult(0));

  return success();
}
enum class ConvOperationKind { Conv, Pool };

static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
  case 1: {
    // Will be convolution if the feeder is a MulOp. A strength-reduced version
    // of MulOp for i1 is AndOp, which is also supported. Otherwise, this can
    // be pooling.
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
    assert(feedValIt != reduceOp->operand_end() &&
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp))
      return ConvOperationKind::Pool;

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            feedOp->getResultTypes()[0].isInteger(1))) &&
          llvm::all_of(feedOp->getOperands(), [](Value v) {
            if (isa<BlockArgument>(v))
              return true;
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
            return false;
          })))
      return std::nullopt;

    return ConvOperationKind::Conv;
  }
  case 2:
    // Must be pooling.
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}
static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
  // (Non-channeled convolution -> LHS and RHS both have single dimensions.)
  // Note that this also ensures 2-D and 3-D convolutions are rejected.
  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
    return failure();

  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
  if (!reduceOp)
    return failure();

  auto maybeOper = getConvOperationKind(reduceOp);
  if (!maybeOper.has_value())
    return failure();

  auto maybeKind = getCombinerOpKind(reduceOp);
  // Typically convolution will have an `Add` CombiningKind, but for i1 it can
  // get strength-reduced to `OR` which is also supported.
  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {
    return failure();
  }

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1)
      return failure();
  } else if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3) {
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              bool vectorizeNDExtract,
                              bool flatten1DDepthwiseConv) {
  // A tensor with a dimension of 0 cannot be vectorized.
  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
        return llvm::is_contained(linalgOp.getShape(&operand), 0);
      }))
    return failure();
  // Check the API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
                                        linalgOp, flatten1DDepthwiseConv))) {
    LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register the CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be supported element types for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(customPreconditions,
                     [&](const CustomVectorizationPrecondition &pre) {
                       return succeeded(pre(&innerOp, vectorizeNDExtract));
                     })) {
      continue;
    }
    if (!llvm::all_of(innerOp.getOperandTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
    if (!llvm::all_of(innerOp.getResultTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return vectorizeConvOpPrecondition(linalgOp);

  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG() << "precondition failed: not projected permutations";
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG() << "precondition failed: reduction preconditions";
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  if (padValue && !matchPattern(padValue, m_Constant())) {
    LDBG() << "pad value is not constant: " << packOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG() << "inner_tiles must be constant: " << packOp;
    return failure();
  }

  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG() << "pad value is not constant: " << padOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  // Non-zero low padding is only supported on dims whose result size is 1:
  // anything else would require shifting the written results to the right by
  // the amount of low padding.
  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
        Value padValue = en.value();
        unsigned pos = en.index();
        std::optional<int64_t> pad = getConstantIntValue(padValue);
        return (!pad.has_value() || pad.value() != 0) &&
               resultTensorShape[pos] != 1;
      })) {
    LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
    return failure();
  }

  return success();
}
/// Preconditions for scalable vectors: for the most part, scalable
/// vectorization is restricted to the trailing dim(s) and to a small
/// allow-list of operations.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp) {
    // Among non-linalg ops, only linalg.unpack supports scalable vectors.
    return success(isa<linalg::UnPackOp>(op));
  }

  // Cond 1: Reject ops with more than 2 scalable dims.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 2: The trailing scalable dim must be either the trailing parallel
  // dim (with only unit parallel dims after it) or the trailing reduction
  // dim. Walk back over trailing non-scalable dims first.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG() << "Non-trailing reduction dim requested for scalable "
                "vectorization";
      return failure();
    }
    if (isa<linalg::MatmulOp>(op)) {
      LDBG()
          << "Scalable vectorization of the reduction dim in Matmul-like ops "
             "is not supported";
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    if (seenNonUnitParallel) {
      LDBG() << "Inner parallel dim not requested for scalable "
                "vectorization";
      return failure();
    }
    break;
  }
  }

  // Cond 3: With 2 scalable dims, the second one must be the parallel dim
  // immediately preceding a trailing scalable reduction dim (e.g. matmul with
  // both trailing dims scalable).
  if (numOfScalableDims == 2) {
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG() << "Higher dim than the trailing reduction dim requested for "
                "scalable vectorization";
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Cond 4: Only a small allow-list of ops supports scalable vectorization.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
                 isa<linalg::BatchMmt4DOp>(op) ||
                 hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<linalg::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<linalg::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
/// Converts affine.apply Ops to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
}
FailureOr<VectorizationResult>
mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                        ArrayRef<int64_t> inputVectorSizes,
                        ArrayRef<bool> inputScalableVecDims,
                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv,
                        bool assumeDynamicDimsMatchVecSizes,
                        bool createNamedContraction) {
  LDBG() << "Attempting to vectorize: " << *op;
  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
  LDBG() << "Input scalable vector dims: "
         << llvm::interleaved(inputScalableVecDims);

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG() << "Vectorization pre-conditions failed";
    return failure();
  }

  // Initialize the vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims,
                               assumeDynamicDimsMatchVecSizes))) {
      LDBG() << "Vectorization state couldn't be initialized";
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }
              LDBG() << "Unsupported convolution can't be vectorized.";
              return failure();
            }

            if (createNamedContraction &&
                isa<ContractionOpInterface>(linalgOp.getOperation()))
              return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
                                                  results);

            LDBG()
                << "Vectorize generic by broadcasting to the canonical vector "
                   "shape";

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);
            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<linalg::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<linalg::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes,
                                             inputScalableVecDims, results);
          })
          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp,
                                            inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG() << "Vectorization failed";
    return failure();
  }

  return VectorizationResult{results};
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = vector::TransferReadOp::create(
      rewriter, loc, readType, copyOp.getSource(), indices,
      /*padding=*/std::nullopt,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                          ArrayRef<int64_t>());
    readValue =
        vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
  }
  Operation *writeValue = vector::TransferWriteOp::create(
      rewriter, loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in a vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of the existing `xferOp` must be unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getBaseMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets are supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at the position of the old one.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }
  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., the same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both the CastOp
    // result and the CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType is supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same attribute or same value.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Case 2: Both sizes are affine.min results with the same map and the
      // same operands.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
      return false;
    }

    // All tests passed.
    return true;
  }
/// Returns the effective pad value for the input op, provided it's a scalar.
/// Many ops exhibit pad-like behaviour, but this isn't always explicit; if
/// this op performs padding, retrieve the scalar, static padding value.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast - return the value that's being broadcast,
  // provided that it's a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};
    return source;
  }

  // 2. linalg.fill - use the scalar input value.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generate - only generators that yield a single constant for
  // every element behave pad-like; recurse into the yielded value.
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    // ...
  }

  // 4. vector.transfer_write - inspect the value that's written.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // 5. tensor.insert_slice - inspect the destination.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
/// Vectorize tensor::InsertSliceOp as:
///   vector::TransferReadOp + vector::TransferWriteOp,
/// padding out-of-bounds reads of the source with `padValue`.
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  Value padValue = getStaticPadVal(sliceOp);
  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
                                         rewriter.getZeroAttr(elemType));
  }

  // Compute the vector shape, falling back to the result shape for dims the
  // source leaves dynamic.
  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but the result shape is: still
      // safe to vectorize, with a padded out-of-bounds read of the source.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
    } else {
      // Neither the source nor the result dim is static: bail out.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // Create the read.
  auto loc = sliceOp.getLoc();
  Value read = mlir::vector::createReadOrMaskedRead(
      rewriter, loc, source, vecType.getShape(), padValue,
      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
      /*inputScalableVecSizes=*/{});

  // Create the write.
  auto writeIndices =
      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
                               writeIndices, inputVectorSizes.empty());

  // Update the results.
  newResults.push_back(write->getResult(0));
  return success();
}
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride is supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes are not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // The pad result must not be used as the insert destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if the sizes match: the insert must write the entire padded
    // tensor into the destination.
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Generate the TransferReadOp: read the entire source tensor and add
    // high padding.
    SmallVector<Value> readIndices(
        vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
    auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
                                               vecType, padOp.getSource(),
                                               readIndices, padValue);

    // Generate the TransferWriteOp: write to the InsertSliceOp's dest tensor
    // at the specified offsets.
    SmallVector<Value> writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if `firstOp` and
/// `secondOp` are in different blocks.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
           << ", second op: " << *secondOp;
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
             << ", second op: " << *secondOp;
      return true;
    }
  }
  return false;
}
/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return failure();

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // Find the fill into `viewOrAlloc` without interleaved uses before the
  // copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // The `masked` attribute is only valid on the padded buffer; it must be
  // reset conservatively when forwarding to the transfer read.
  auto vectorType = xferOp.getVectorType();
  Value res = vector::TransferReadOp::create(
      rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return failure();

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the copy target that we forward the write to.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // The `masked` attribute must be reset conservatively when forwarding.
  auto vector = xferOp.getVector();
  vector::TransferWriteOp::create(
      rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(SmallVector<bool>(
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2...>(shapedType, vals...);
}

/// Bind a pack of int&s to the leading dimensions of shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
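// Illustrative sketch (not part of the original file): binding the leading
// dims of a shaped type to named scalars, as the convolution vectorizer does
// below. For a value of type memref<2x30x8xf32>:
//
//   int64_t n, w, c;
//   bindShapeDims(lhsShapedType, n, w, c); // n == 2, w == 30, c == 8
//
// Each recursive instantiation binds one more leading dimension, so binding
// fewer variables than the rank simply ignores the trailing dims.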
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    setConvOperationKind(reduceOp);

    auto maybeKind = getCombinerOpKind(reduceOp);
    reductionKind = maybeKind.value();

    // The ConvolutionOpInterface gives us guarantees of existence for
    // strides/dilations; we do not need to rely on them and can simply use
    // the defaults when absent.
    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  }
  /// Generate a vector implementation of the convolution or pooling.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Initialize unused dimensions.
      nSize = fSize = cSize = 0;
      // out{W} = in{W} * filter{KW}
      bindShapeDims(resShapedType, wSize);
      bindShapeDims(rhsShapedType, kwSize);
      // iw = ow + kw - 1 (e.g. 16 convolved with 3 -> 14).
      lhsShape = {(wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      // out{n, w, f} = in{n, sw * w + dw * kw, c} * filter{kw, c, f}
      bindShapeDims(resShapedType, nSize, wSize, fSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        bindShapeDims(rhsShapedType, kwSize, cSize);
        break;
      case ConvOperationKind::Pool:
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      // out{n, f, w} = in{n, c, sw * w + dw * kw} * filter{f, c, kw}
      bindShapeDims(resShapedType, nSize, fSize, wSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
        break;
      case ConvOperationKind::Pool:
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize, cSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }
    vector::TransferWriteOp write;
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    auto lhsType = VectorType::get(lhsShape, lhsEltType);
    auto rhsType = VectorType::get(rhsShape, rhsEltType);
    auto resType = VectorType::get(resShape, resEltType);
    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
    SmallVector<Value> resPadding(resShape.size(), zero);

    // Read lhs, rhs and res in one shot (with zero padding).
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, lhsPadding,
        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
      rhs = vector::TransferReadOp::create(
          rewriter, loc, rhsType, rhsShaped, rhsPadding,
          /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, resPadding,
        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));

    // The base vectorization case for channeled convolution is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse it, pre-transpose
    // the other layouts into this form.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case: no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
      // fcw -> wcf
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
      // This is needed only for Conv.
      if (oper == ConvOperationKind::Conv)
        rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, permRes);
      break;
    }
    }

    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals =
        extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize, kwSize,
                               strideW, dilationW, wSizeStep,
                               isSingleChanneled);
    // This is needed only for Conv.
    SmallVector<Value> rhsVals;
    if (oper == ConvOperationKind::Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    SmallVector<Value> resVals =
        extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };
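    // Worked example (not part of the original file): with wSize = 4 and
    // wSizeStep = 1 (strideW > 1), slices are stored as
    // [kw0w0, kw0w1, ..., kw1w0, ...], so linearIndex(1, 2) = 1 * 4 + 2 = 6.
    // With strideW == 1, wSizeStep == wSize, the w loop runs once with w == 0
    // and linearIndex(kw, 0) == kw simply selects the kw-th slice.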
    // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
    // (or the outer-product / pooling variant) for every kw and w.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case ConvOperationKind::Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case ConvOperationKind::Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // Transpose back to the original layout if necessary.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // nwf -> nfw
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, perm);
      break;
    }
    }

    return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
                                           resPadding)
        .getOperation();
  }
  /// Promote the element type of `val` to match `ty`'s element type, using
  /// sign- or fp-extension as needed.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return arith::SIToFPOp::create(rewriter, loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtFOp::create(rewriter, loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtSIOp::create(rewriter, loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  /// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = vector::ContractionOp::create(
        rewriter, loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }

  /// Create an outer product: lhs{w} * rhs{1} -> res{w} for single-channel
  /// convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
                                          rhs, res, vector::CombiningKind::ADD);
  }
  /// Generate a vector implementation for the depthwise 1-D convolution:
  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when the channel dim is dynamic _and_
      // the scalable flag was requested for it.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow - 1) * sw + (kw - 1) * dw + 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});

    // Masks the xfer ops when the channel dim is dynamic.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
        /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = vector::TransferReadOp::create(
        rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
        /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
        /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
    // Extract lhs slices of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slices of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(
          vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
                                    /*position=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slices of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Flattened slice types for the (optional) flattening of the channel dim.
    auto inOutFlattenSliceSizes =
        SmallVector<int64_t>{nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute the contraction O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel dim).
          lhsVal =
              vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
                                          lhsVals[linearIndex(kw, w)]);
          resVal = vector::ShapeCastOp::create(
              rewriter, loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dim).
          resVals[w] = vector::ShapeCastOp::create(
              rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
              resVals[w]);
        }
      }
    }

    // It is possible we failed to create the FMA.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slices of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back the result.
    Operation *resOut = vector::TransferWriteOp::create(
        rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c}       (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c}     (flatten = true)
  /// to a multiply-accumulate.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: This logic won't work for scalable vectors. For this reason,
      // "flattening" is not supported when shapes are dynamic (this should be
      // captured by one of the pre-conditions).

      // There are two options for handling the filter:
      //  * shape_cast(broadcast(filter))
      //  * broadcast(shuffle(filter))
      // Opt for the option without shape_cast to simplify the codegen.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = vector::BroadcastOp::create(rewriter, loc,
                                      resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);

    auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
    return arith::AddIOp::create(rewriter, loc, mul, res);
  }
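  // Worked example (not part of the original file): in the flattened case
  // with rhsSize = 2 (filter vector [f0, f1]) and resSize = 6, the shuffle
  // indices are [0, 1, 0, 1, 0, 1], yielding [f0, f1, f0, f1, f0, f1], so
  // each (w, c) position of the flattened lhs lines up with its channel's
  // filter value before the multiply-accumulate.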
  /// Entry point for non-channeled convolution:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return failure();
  }

  /// Entry point for NWC convolution:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return failure();
  }

  /// Entry point for NCW convolution:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return failure();
  }

  /// Entry point for NWC pooling:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return failure();
  }

  /// Entry point for NCW pooling:
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return failure();
  }

  /// Entry point for depthwise (dilated) NWC convolution:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return failure();
  }
private:
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr redOp;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;

  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // Will be convolution if the feeder is a MulOp; a cast of a block
      // argument instead indicates pooling with an extension op.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // numBlockArguments == 2: this is a pooling op.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
};
/// Helper function to vectorize a LinalgOp with convolution semantics.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
                     ArrayRef<int64_t> inputVecSizes,
                     ArrayRef<bool> inputScalableVecDims,
                     bool flatten1DDepthwiseConv) {
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1-D convs remain; they may use a user-specified vector
  // size for the channel dim.
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, Value dest, SmallVector< Value > writeIndices={}, bool useInBoundsInsteadOfMasking=false)
Creates an optionally masked TransferWriteOp.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static LogicalResult vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
static VectorType getCollapsedVecType(VectorType type, ArrayRef< AffineMap > reassociation)
Given the re-associations, "collapses" the input Vector type.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp)
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Vectorize a named linalg contraction op into: vector::TransferReadOp - Reads vectors from the operand...
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static bool isSupportedPoolKind(vector::CombiningKind kind)
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(linalg::PackOp packOp, ArrayRef< int64_t > destShape)
Given a linalg::PackOp, return the dest shape before any packing permutations.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
This hook considers two cases: (1) If the input-vector-sizes are empty, then the vector sizes will be...
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_iterator operand_end()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
enum WinogradConv2DFmr uint32_t std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > inputVectorSizes, std::optional< Value > padValue=std::nullopt, bool useInBoundsInsteadOfMasking=false, ArrayRef< bool > inputScalableVecDims={})
Creates a TransferReadOp from source.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims, bool assumeDynamicDimsMatchVecSizes=false)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Transformation information returned after vectorizing.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.