37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/Sequence.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/DebugLog.h"
42 #include "llvm/Support/InterleavedRange.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
50 #define DEBUG_TYPE "linalg-vectorization"
53 static FailureOr<Operation *>
57 bool flatten1DDepthwiseConv =
false);
92 template <
typename OpType>
95 block.
walk([&](OpType op) {
110 int64_t nSize, int64_t wSize, int64_t cSize,
111 int64_t kwSize,
int strideW,
int dilationW,
112 int64_t wSizeStep,
bool isSingleChanneled) {
114 if (isSingleChanneled) {
119 for (int64_t kw = 0; kw < kwSize; ++kw) {
120 for (int64_t w = 0; w < wSize; w += wSizeStep) {
121 result.push_back(vector::ExtractStridedSliceOp::create(
131 for (int64_t kw = 0; kw < kwSize; ++kw) {
132 for (int64_t w = 0; w < wSize; w += wSizeStep) {
133 result.push_back(vector::ExtractStridedSliceOp::create(
134 rewriter, loc, input,
151 for (int64_t kw = 0; kw < kwSize; ++kw) {
152 result.push_back(vector::ExtractOp::create(
162 int64_t nSize, int64_t wSize, int64_t fSize,
163 int64_t wSizeStep,
bool isSingleChanneled) {
165 if (isSingleChanneled) {
169 for (int64_t w = 0; w < wSize; w += wSizeStep) {
170 result.push_back(vector::ExtractStridedSliceOp::create(
179 for (int64_t w = 0; w < wSize; w += wSizeStep) {
180 result.push_back(vector::ExtractStridedSliceOp::create(
190 Value res, int64_t wSize, int64_t wSizeStep,
192 bool isSingleChanneled) {
194 if (isSingleChanneled) {
198 for (int64_t w = 0; w < wSize; w += wSizeStep) {
199 res = vector::InsertStridedSliceOp::create(
207 for (int64_t w = 0; w < wSize; w += wSizeStep) {
208 res = vector::InsertStridedSliceOp::create(
209 rewriter, loc, resVals[w], res,
223 LogicalResult initState(
RewriterBase &rewriter, LinalgOp linalgOp,
226 bool assumeDynamicDimsMatchVecSizes =
false);
241 std::optional<AffineMap> dimPermutation = std::nullopt)
const {
244 if (dimPermutation.has_value()) {
246 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
248 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
250 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
251 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
263 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
268 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
269 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
275 LogicalResult precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
284 std::optional<AffineMap> maybeMaskingMap);
289 bool isValidMaskingMap(
AffineMap maskingMap) {
342 bool assumeDynamicDimsMatchVecSizes =
false;
346 VectorizationState::precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
349 for (
int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
350 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
353 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
360 unsigned operandDimPos;
361 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
366 linalgOp.hasPureTensorSemantics()
367 ? (
Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
369 : (
Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
371 iterSpaceValueSizes.push_back(dynamicDim);
384 bool assumeDimsMatchVec) {
385 assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
389 if (!inputVectorSizes.empty()) {
393 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
394 scalableVecDims.append(inputScalableVecDims.begin(),
395 inputScalableVecDims.end());
400 canonicalVecShape = linalgOp.getStaticLoopRanges();
401 scalableVecDims.append(linalgOp.getNumLoops(),
false);
404 LDBG() <<
"Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
405 LDBG() <<
"Scalable vector dims: " << llvm::interleaved(scalableVecDims);
407 if (ShapedType::isDynamicShape(canonicalVecShape))
411 initIterSpaceStaticSizes(linalgOp);
416 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
426 Value VectorizationState::getOrCreateMaskFor(
428 std::optional<AffineMap> maybeMaskingMap) {
430 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
431 "Ill-formed masking map.");
434 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
438 assert(!maskableOp.isMasked() &&
439 "Masking an operation that is already masked");
442 assert((!maybeMaskingMap || *maybeMaskingMap) &&
443 "Unexpected null mask permutation map");
445 maybeMaskingMap ? *maybeMaskingMap
447 linalgOp.getNumLoops(), rewriter.
getContext());
449 LDBG() <<
"Masking map: " << maskingMap;
453 auto activeMaskIt = activeMaskCache.find(maskingMap);
454 if (activeMaskIt != activeMaskCache.end()) {
455 Value mask = activeMaskIt->second;
456 LDBG() <<
"Reusing mask: " << mask;
467 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
468 auto maskType = getCanonicalVecType(rewriter.
getI1Type(), maskingMap);
469 auto maskShape = maskType.getShape();
471 LDBG() <<
"Mask shape: " << llvm::interleaved(maskShape);
473 if (permutedStaticSizes == maskShape) {
474 LDBG() <<
"Masking is not needed for masking map: " << maskingMap;
475 activeMaskCache[maskingMap] =
Value();
479 if (assumeDynamicDimsMatchVecSizes) {
483 if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
485 return std::get<0>(it) == ShapedType::kDynamic
487 : std::get<0>(it) == std::get<1>(it);
490 <<
"Dynamic + static dimensions match vector sizes, masking is not "
492 activeMaskCache[maskingMap] =
Value();
500 assert(!maskShape.empty() && !upperBounds.empty() &&
501 "Masked 0-d vectors are not supported yet");
504 Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
505 maskType, upperBounds);
506 LDBG() <<
"Creating new mask: " << mask;
507 activeMaskCache[maskingMap] = mask;
514 std::optional<AffineMap> maybeIndexingMap) {
515 LDBG() <<
"Trying to mask: " << *opToMask;
517 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
518 if (maybeIndexingMap)
519 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
523 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
526 LDBG() <<
"No mask required";
531 assert(opToMask &&
"Expected a valid operation to mask");
532 auto maskOp = cast<vector::MaskOp>(
534 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
540 LDBG() <<
"Masked operation: " << *maskOp;
563 "expected projected permutation");
565 assert(res.getNumDims() ==
566 (res.getNumResults() - res.getNumOfZeroResults()) &&
567 "expected reindexed map with same number of dims and results");
603 std::optional<vector::CombiningKind>
605 using ::mlir::vector::CombiningKind;
610 .Case<arith::AddIOp, arith::AddFOp>(
611 [&](
auto op) {
return CombiningKind::ADD; })
612 .Case<arith::AndIOp>([&](
auto op) {
return CombiningKind::AND; })
613 .Case<arith::MaxSIOp>([&](
auto op) {
return CombiningKind::MAXSI; })
614 .Case<arith::MaxUIOp>([&](
auto op) {
return CombiningKind::MAXUI; })
615 .Case<arith::MaximumFOp>([&](
auto op) {
return CombiningKind::MAXIMUMF; })
616 .Case<arith::MaxNumFOp>([&](
auto op) {
return CombiningKind::MAXNUMF; })
617 .Case<arith::MinSIOp>([&](
auto op) {
return CombiningKind::MINSI; })
619 .Case<arith::MinimumFOp>([&](
auto op) {
return CombiningKind::MINIMUMF; })
620 .Case<arith::MinNumFOp>([&](
auto op) {
return CombiningKind::MINNUMF; })
621 .Case<arith::MulIOp, arith::MulFOp>(
622 [&](
auto op) {
return CombiningKind::MUL; })
623 .Case<arith::OrIOp>([&](
auto op) {
return CombiningKind::OR; })
624 .Case<arith::XOrIOp>([&](
auto op) {
return CombiningKind::XOR; })
625 .Default([&](
auto op) {
return std::nullopt; });
636 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
641 if (!
matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
642 combinerOps.size() != 1)
646 return combinerOps[0];
652 auto dstVecType = dyn_cast<VectorType>(dstType);
654 if (dstVecType.getRank() == 0)
660 return b.
createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
672 assert(maybeKind &&
"Failed precondition: could not get reduction kind");
673 return vector::MultiDimReductionOp::create(
674 b, reduceOp->
getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
678 return llvm::to_vector(
685 return isa<linalg::ReduceOp>(op) ||
686 (isa<linalg::GenericOp>(op) &&
700 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
701 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
710 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
712 auto vectorType = state.getCanonicalVecType(
716 if (vectorType.getRank() > 0) {
719 linalgOp.getRank(outputOperand),
722 assert(value.
getType() == vectorType &&
"Incorrect type");
723 write = vector::TransferWriteOp::create(
724 rewriter, loc, value, outputOperand->
get(), indices, writeMap);
727 if (!isa<VectorType>(value.
getType()))
728 value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
729 assert(value.
getType() == vectorType &&
"Incorrect type");
730 write = vector::TransferWriteOp::create(rewriter, loc, value,
734 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
738 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
739 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
744 LDBG() <<
"vectorized op: " << *write;
754 std::function<LogicalResult(
Operation *,
bool)>;
773 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
782 linalgOp.getDpsInitOperand(output.index()), state);
784 newResults.push_back(newResult);
798 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
801 auto loc = indexOp.getLoc();
804 auto dim = indexOp.getDim();
806 auto indexVectorType =
808 state.getScalableVecDims()[dim]);
809 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
813 if (dim == targetShape.size() - 1)
819 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
820 std::swap(permPattern[dim], permPattern.back());
824 auto broadCastOp = vector::BroadcastOp::create(
826 state.getCanonicalVecType(rewriter.
getIndexType(), permMap), indexSteps);
828 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
829 std::swap(transposition.back(), transposition[dim]);
831 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
839 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
843 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
848 if (not extractOp.getIndices().empty()) {
849 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
853 if (!llvm::all_of(extractOp->getResultTypes(),
854 VectorType::isValidElementType)) {
873 tensor::ExtractOp extractOp,
876 auto indexVecType = state.getCanonicalVecType(rewriter.
getIndexType());
877 auto loc = extractOp.getLoc();
880 rewriter, bvm.
lookup(extractOp.getIndices()[0]), indexVecType);
882 const size_t numIndices = extractOp.getIndices().size();
883 for (
size_t i = 1; i < numIndices; i++) {
888 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
891 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
894 rewriter, bvm.
lookup(extractOp.getIndices()[i]), indexVecType);
896 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
922 (linalgOp.hasDynamicShape() ||
923 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
924 "For statically shaped Linalg Ops, only one "
925 "non-unit loop dim is expected");
926 assert(loopRanges.size() != 0 &&
"Empty loops, nothing to analyse.");
928 size_t idx = loopRanges.size() - 1;
929 for (; idx != 0; idx--)
930 if (loopRanges[idx] != 1)
938 VectorType resType) {
940 assert(((llvm::count_if(resType.getShape(),
941 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
942 "n-D vectors are not yet supported");
948 auto *block = linalgOp.getBlock();
949 if (isa<BlockArgument>(val))
950 return !llvm::is_contained(block->getArguments(), val);
953 assert(defOp &&
"This is neither a block argument nor an operation result");
958 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
959 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
962 auto *ancestor = block->findAncestorOpInBlock(*defOp);
969 if (isa<arith::ConstantOp>(ancestor))
973 for (
auto op : ancestor->getOperands())
997 bool &foundIndexOp, VectorType resType) {
999 assert(((llvm::count_if(resType.getShape(),
1000 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1001 "n-D vectors are not yet supported");
1007 auto *block = linalgOp.getBlock();
1008 if (isa<BlockArgument>(val))
1009 return !llvm::is_contained(block->getArguments(), val);
1012 assert(defOp &&
"This is neither a block argument nor an operation result");
1014 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1017 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1021 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1028 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1031 bool result =
false;
1032 for (
auto op : ancestor->getOperands())
1052 LinalgOp &linalgOp, VectorType resType) {
1054 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1057 if (inputShape.getShape().empty())
1062 bool isOutput1DVector =
1063 (llvm::count_if(resType.getShape(),
1064 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1066 if (!isOutput1DVector)
1069 bool leadingIdxsLoopInvariant =
true;
1075 auto indices = extractOp.getIndices();
1076 auto leadIndices = indices.drop_back(1);
1079 if (inputShape.getShape()[i] == 1)
1085 if (!leadingIdxsLoopInvariant) {
1086 LDBG() <<
"Found gather load: " << extractOp;
1094 auto extractOpTrailingIdx = indices.back();
1098 if (leadingIdxsLoopInvariant &&
1100 LDBG() <<
"Found scalar broadcast load: " << extractOp;
1109 bool foundIndexOp =
false;
1111 foundIndexOp, resType);
1114 bool isRowVector = resType.getShape().back() != 1;
1115 isContiguousLoad &= (foundIndexOp && isRowVector);
1117 if (isContiguousLoad) {
1118 LDBG() <<
"Found contigous load: " << extractOp;
1123 LDBG() <<
"Found gather load: " << extractOp;
1134 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1137 auto loc = extractOp.getLoc();
1140 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1141 auto maskConstantOp = arith::ConstantOp::create(
1145 auto passThruConstantOp = arith::ConstantOp::create(
1151 extractOp.getIndices().size(),
1162 Operation *gatherOp = vector::GatherOp::create(
1163 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1164 maskConstantOp, passThruConstantOp);
1165 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1167 LDBG() <<
"Vectorised as gather load: " << extractOp;
1190 for (
size_t i = 0; i < extractOp.getIndices().size(); i++) {
1191 Value idx = bvm.
lookup(extractOp.getIndices()[i]);
1193 transferReadIdxs.push_back(idx);
1197 auto indexAs1dVector = vector::ShapeCastOp::create(
1200 resultType.getScalableDims().back()),
1202 transferReadIdxs.push_back(
1203 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1207 auto dstRank = resultType.getRank();
1208 auto srcRank = extractOp.getTensor().getType().getRank();
1217 auto transferReadOp = vector::TransferReadOp::create(
1218 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1219 std::nullopt, permutationMap, inBounds);
1226 auto allTrue = vector::ConstantMaskOp::create(
1228 auto *maskedReadOp =
1231 LDBG() <<
"Vectorised as scalar broadcast load: " << extractOp;
1240 int32_t rankDiff = dstRank - srcRank;
1248 while (rankDiff > 0) {
1249 permutationMap = permutationMap.insertResult(
1254 auto transferReadOp = vector::TransferReadOp::create(
1255 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1256 std::nullopt, permutationMap, inBounds);
1258 LDBG() <<
"Vectorised as contiguous load: " << extractOp;
1272 auto reduceType = dyn_cast<VectorType>(reduceVec.
getType());
1273 auto outputType = dyn_cast<VectorType>(outputVec.
getType());
1277 (outputType && reduceType.getShape() == outputType.getShape()))
1306 LDBG() <<
"vectorize op " << *op;
1309 if (!customVectorizationHooks.empty()) {
1310 for (
auto &customFunc : customVectorizationHooks) {
1320 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1322 rewriter.
clone(*op)};
1331 auto blockArg = dyn_cast<BlockArgument>(operand);
1332 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1333 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1337 linalgOp.getRegionOutputArgs(),
1338 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1341 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1343 if (!reductionOperands.empty()) {
1344 assert(reductionOperands.size() == 1);
1346 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1347 reductionOperands[0].second, bvm);
1354 VectorType firstMaxRankedType;
1356 auto vecOperand = bvm.
lookup(operand);
1357 assert(vecOperand &&
"Vector operand couldn't be found");
1359 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1360 if (vecType && (!firstMaxRankedType ||
1361 firstMaxRankedType.getRank() < vecType.getRank()))
1362 firstMaxRankedType = vecType;
1368 assert(vecOperand &&
"Vector operand couldn't be found");
1370 if (firstMaxRankedType) {
1373 firstMaxRankedType.getScalableDims());
1376 vecOperands.push_back(vecOperand);
1382 resultTypes.push_back(
1385 firstMaxRankedType.getScalableDims())
1417 static LogicalResult
1421 LDBG() <<
"Vectorizing operation as linalg generic/n";
1422 Block *block = linalgOp.getBlock();
1429 bvm.
map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1431 if (linalgOp.getNumDpsInits() == 0)
1437 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1438 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1439 if (linalgOp.isScalar(opOperand)) {
1440 bvm.
map(bbarg, opOperand->get());
1446 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1449 VectorType readType;
1451 if (linalgOp.isDpsInput(opOperand)) {
1454 readType = state.getCanonicalVecType(elemType);
1461 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
1466 Operation *read = vector::TransferReadOp::create(
1467 rewriter, loc, readType, opOperand->get(), indices,
1468 std::nullopt, readMap);
1469 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1474 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1476 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1482 if (readType.getRank() == 0)
1486 LDBG() <<
"New vectorized bbarg(" << bbarg.
getArgNumber()
1498 hooks.push_back(vectorizeYield);
1505 hooks.push_back(vectorizeIndex);
1512 hooks.push_back(vectorizeExtract);
1519 LDBG() <<
"failed to vectorize: " << op;
1524 state.maskOperation(rewriter, result.
newOp, linalgOp);
1525 LDBG() <<
"New vector op: " << *maybeMaskedOp;
1591 if (ShapedType::isDynamicShape(destShape))
1598 cstMaskSizes.push_back(*intSize);
1603 if (cstMaskSizes.size() != maskShape.size())
1611 cstWriteIdxs.push_back(intVal.getSExtValue());
1616 if (cstWriteIdxs.size() != destShape.size())
1625 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1627 if ( maskShape[i] > destShape[rankDiff + i] ||
1628 destShape[rankDiff + i] <
1629 (
std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1665 bool useInBoundsInsteadOfMasking =
false) {
1667 ShapedType destType = cast<ShapedType>(dest.
getType());
1668 int64_t destRank = destType.getRank();
1669 auto destShape = destType.getShape();
1671 VectorType vecToStoreType = cast<VectorType>(vecToStore.
getType());
1672 int64_t vecToStoreRank = vecToStoreType.getRank();
1673 auto vecToStoreShape = vecToStoreType.getShape();
1677 if (useInBoundsInsteadOfMasking) {
1680 for (
unsigned i = 0; i < vecToStoreRank; i++)
1682 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1683 ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1687 assert((writeIndices.empty() ||
1688 writeIndices.size() ==
static_cast<size_t>(destRank)) &&
1689 "Invalid number of write indices!");
1690 if (writeIndices.empty()) {
1692 writeIndices.assign(destRank, zero);
1696 Operation *write = vector::TransferWriteOp::create(builder, loc,
1703 if (useInBoundsInsteadOfMasking)
1707 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1712 vecToStoreType.getScalableDims());
1715 isa<MemRefType>(dest.
getType())
1725 Value maskForWrite =
1726 builder.
createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1764 static LogicalResult
1773 auto padValue = packOp.getPaddingValue();
1775 padValue = arith::ConstantOp::create(
1777 rewriter.
getZeroAttr(packOp.getSourceType().getElementType()));
1780 LogicalResult status =
1781 cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1782 .reifyResultShapes(rewriter, reifiedReturnShapes);
1784 assert(succeeded(status) &&
"failed to reify result shapes");
1789 bool useInBoundsInsteadOfMasking =
false;
1790 if (inputVectorSizes.empty()) {
1792 inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1793 useInBoundsInsteadOfMasking =
true;
1798 auto innerTiles = packOp.getStaticInnerTiles();
1807 rewriter, loc, packOp.getSource(), inputShape, padValue,
1808 useInBoundsInsteadOfMasking,
1815 packOp.getDestType().getElementType());
1817 vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead);
1820 auto destPermutation =
1822 auto transposeOp = vector::TransposeOp::create(
1823 rewriter, loc, shapeCastOp.getResult(), destPermutation);
1826 Value dest = tensor::EmptyOp::create(
1827 rewriter, loc, reifiedReturnShapes[0],
1828 transposeOp.getResult().getType().getElementType());
1831 newResults.push_back(write->
getResult(0));
1853 assert(type.getNumScalableDims() < 2 &&
1854 "Collapsing more than 1 scalable dim is not supported ATM");
1860 auto shape = type.getShape();
1861 auto scalableFlags = type.getScalableDims();
1865 unsigned currentDim = 0;
1867 unsigned dim = m.getNumResults();
1870 for (
unsigned d = 0; d < dim; ++d) {
1871 size *= shape[currentDim + d];
1872 flag |= scalableFlags[currentDim + d];
1874 newShape.push_back(size);
1875 newScalableFlags.push_back(flag);
1879 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1911 static LogicalResult
1916 if (!inputVectorSizes.empty()) {
1917 assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1918 "Invalid number of input vector sizes!");
1919 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1920 "Incompatible number of vector sizes and vector scalable flags!");
1927 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1930 bool useInBoundsInsteadOfMasking =
false;
1939 if (inputVectorSizes.empty()) {
1940 if (ShapedType::isDynamicShape(sourceShape))
1943 readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1944 useInBoundsInsteadOfMasking =
true;
1948 auto padValue = arith::ConstantOp::create(
1950 rewriter.
getZeroAttr(unpackOp.getSourceType().getElementType()));
1952 rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1953 useInBoundsInsteadOfMasking, readScalableVectorFlags);
1956 PackingMetadata packMetadata;
1959 vector::TransposeOp transposeOp = vector::TransposeOp::create(
1960 rewriter, loc, readResult, lastDimToInsertPosPerm);
1964 transposeOp.getType(),
1966 rewriter.
getContext(), packMetadata.reassociations)));
1967 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1968 rewriter, loc, collapsedVecType, transposeOp->getResult(0));
1972 rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
1973 {}, useInBoundsInsteadOfMasking);
1975 newResults.push_back(write->
getResult(0));
1982 static LogicalResult
1986 auto padValue = padOp.getConstantPaddingValue();
1994 LogicalResult status =
1995 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1996 .reifyResultShapes(rewriter, reifiedReturnShapes);
1998 assert(succeeded(status) &&
"failed to reify result shapes");
2000 rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
2004 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2005 padOp.getResultType().getElementType());
2007 newResults.push_back(write->
getResult(0));
2015 LDBG() <<
"reduction precondition failed: no reduction iterator";
2018 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
2019 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2025 LDBG() <<
"reduction precondition failed: reduction detection failed";
2032 static LogicalResult
2034 bool flatten1DDepthwiseConv) {
2035 if (flatten1DDepthwiseConv) {
2036 LDBG() <<
"Vectorization of flattened convs with dynamic shapes is not "
2041 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2042 LDBG() <<
"Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2048 Value lhs = conv.getDpsInputOperand(0)->get();
2050 auto shapeWithoutCh = lhsShape.drop_back(1);
2051 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2052 LDBG() <<
"Dynamically-shaped op vectorization precondition failed: only "
2053 "channel dim can be dynamic";
2060 static LogicalResult
2062 bool flatten1DDepthwiseConv) {
2063 if (isa<ConvolutionOpInterface>(op.getOperation()))
2072 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2076 LDBG() <<
"Dynamically-shaped op meets vectorization pre-conditions";
2085 static LogicalResult
2090 if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2091 unpackOp.getSourceType().hasStaticShape())
2096 if (!inputVectorSizes.empty() &&
2097 (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2098 LDBG() <<
"Incorrect number of input vector sizes";
2104 unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2105 LDBG() <<
"Invalid vector sizes for the read operation";
2112 static LogicalResult
2117 auto sourceType = source.getType();
2118 if (!VectorType::isValidElementType(sourceType.getElementType()))
2134 bool isOutOfBoundsRead =
2135 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2137 if (!padValue && isOutOfBoundsRead) {
2138 LDBG() <<
"Failed to get a pad value for out-of-bounds read access";
2151 static LogicalResult
2161 if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2164 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2168 LDBG() <<
"Failed to determine contraction combining kind.";
2175 AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2176 AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2178 LDBG() <<
"Contractions with broadcasts are not supported.";
2184 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
2188 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2192 VectorType readType =
2193 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
2196 rewriter, loc, opOperand.get(), readType.getShape(),
2198 false, readType.getScalableDims());
2199 vecOperands.push_back(read);
2204 auto iterators = linalgOp.getIteratorTypesArray();
2205 for (utils::IteratorType iter : iterators) {
2206 auto vecIter = iter == utils::IteratorType::parallel
2207 ? vector::IteratorType::parallel
2208 : vector::IteratorType::reduction;
2213 Operation *contractOp = vector::ContractionOp::create(
2214 rewriter, loc, vecOperands[0],
2215 vecOperands[1], vecOperands[2],
2216 linalgOp.getIndexingMaps(), rewriter.
getArrayAttr(iterAttrs), *maybeKind);
2217 contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2221 rewriter, loc, contractOp->
getResult(0), outOperand->
get());
2225 newResults.push_back(write->
getResult(0));
2231 enum class ConvOperationKind { Conv, Pool };
2249 static std::optional<ConvOperationKind>
2251 int numBlockArguments =
2252 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
2254 switch (numBlockArguments) {
2260 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
2261 llvm::IsaPred<BlockArgument>);
2263 "Expected a non-block argument operand");
2264 Operation *feedOp = (*feedValIt).getDefiningOp();
2266 return ConvOperationKind::Pool;
2269 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2270 (isa<arith::AndIOp>(feedOp) &&
2273 if (isa<BlockArgument>(v))
2275 if (Operation *op = v.getDefiningOp())
2276 return isCastOfBlockArgument(op);
2279 return std::nullopt;
2282 return ConvOperationKind::Conv;
2286 return ConvOperationKind::Pool;
2288 return std::nullopt;
2294 case vector::CombiningKind::ADD:
2295 case vector::CombiningKind::MAXNUMF:
2296 case vector::CombiningKind::MAXIMUMF:
2297 case vector::CombiningKind::MAXSI:
2298 case vector::CombiningKind::MAXUI:
2299 case vector::CombiningKind::MINNUMF:
2300 case vector::CombiningKind::MINIMUMF:
2301 case vector::CombiningKind::MINSI:
2310 auto getOperandType = [&](
auto operand) {
2311 return dyn_cast<ShapedType>((operand->get()).getType());
2313 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2314 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2315 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2319 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2320 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2328 if (!maybeOper.has_value())
2335 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2336 *maybeKind != vector::CombiningKind::OR) &&
2337 (*maybeOper != ConvOperationKind::Pool ||
2342 auto rhsRank = rhsShapedType.getRank();
2343 if (*maybeOper == ConvOperationKind::Pool) {
2347 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2356 bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
2358 if (llvm::any_of(linalgOp->getOpOperands(), [&](
OpOperand &operand) {
2359 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2363 if (!inputVectorSizes.empty() &&
2369 linalgOp, flatten1DDepthwiseConv))) {
2370 LDBG() <<
"Dynamically-shaped op failed vectorization pre-conditions";
2383 customPreconditions,
2386 customPrecondition(&innerOp, vectorizeNDExtract));
2390 if (!llvm::all_of(innerOp.getOperandTypes(),
2391 VectorType::isValidElementType)) {
2394 if (!llvm::all_of(innerOp.getResultTypes(),
2395 VectorType::isValidElementType)) {
2405 if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2412 LDBG() <<
"precondition failed: not projected permutations";
2416 LDBG() <<
"precondition failed: reduction preconditions";
2422 static LogicalResult
2425 auto padValue = packOp.getPaddingValue();
2428 LDBG() <<
"pad value is not constant: " << packOp;
2433 bool satisfyEmptyCond =
true;
2434 if (inputVectorSizes.empty()) {
2435 if (!packOp.getDestType().hasStaticShape() ||
2436 !packOp.getSourceType().hasStaticShape())
2437 satisfyEmptyCond =
false;
2440 if (!satisfyEmptyCond &&
2442 resultTensorShape.take_front(packOp.getSourceRank()),
2446 if (llvm::any_of(packOp.getInnerTiles(), [](
OpFoldResult v) {
2447 return !getConstantIntValue(v).has_value();
2449 LDBG() <<
"inner_tiles must be constant: " << packOp;
2456 static LogicalResult
2459 auto padValue = padOp.getConstantPaddingValue();
2461 LDBG() <<
"pad value is not constant: " << padOp;
2481 if (llvm::any_of(
llvm::enumerate(padOp.getLow()), [&](
const auto &en) {
2482 Value padValue = en.value();
2483 unsigned pos = en.index();
2484 std::optional<int64_t> pad = getConstantIntValue(padValue);
2485 return (!pad.has_value() || pad.value() != 0) &&
2486 resultTensorShape[pos] != 1;
2488 LDBG() <<
"low pad must all be zero for all non unit dims: " << padOp;
2501 static LogicalResult
2505 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2506 "Number of input vector sizes and scalable dims doesn't match");
2508 size_t numOfScalableDims =
2509 llvm::count_if(inputScalableVecDims, [](
bool flag) {
return flag; });
2511 if (numOfScalableDims == 0)
2514 auto linalgOp = dyn_cast<LinalgOp>(op);
2519 return success(isa<linalg::UnPackOp>(op));
2523 if (numOfScalableDims > 2)
2543 bool seenNonUnitParallel =
false;
2544 auto iterators = linalgOp.getIteratorTypesArray();
2546 int64_t idx = scalableFlags.size() - 1;
2547 while (!scalableFlags[idx]) {
2548 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2549 seenNonUnitParallel |=
2550 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2552 iterators.pop_back();
2553 scalableFlags.pop_back();
2558 switch (iterators.back()) {
2559 case utils::IteratorType::reduction: {
2561 if (iterators.size() != inputVectorSizes.size()) {
2562 LDBG() <<
"Non-trailing reduction dim requested for scalable "
2566 if (isa<linalg::MatmulOp>(op)) {
2568 <<
"Scalable vectorization of the reduction dim in Matmul-like ops "
2574 case utils::IteratorType::parallel: {
2576 if (seenNonUnitParallel) {
2577 LDBG() <<
"Inner parallel dim not requested for scalable "
2589 if (numOfScalableDims == 2) {
2593 if (iterators.back() == utils::IteratorType::reduction) {
2594 LDBG() <<
"Higher dim than the trailing reduction dim requested for "
2599 scalableFlags.pop_back();
2600 iterators.pop_back();
2602 if (!scalableFlags.back() ||
2603 (iterators.back() != utils::IteratorType::parallel))
2609 return success(
isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2610 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2611 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2618 bool flatten1DDepthwiseConv) {
2624 inputScalableVecDims)))
2628 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
2631 flatten1DDepthwiseConv);
2633 .Case<tensor::PadOp>([&](
auto padOp) {
2636 .Case<linalg::PackOp>([&](
auto packOp) {
2639 .Case<linalg::UnPackOp>([&](
auto unpackOp) {
2642 .Case<tensor::InsertSliceOp>([&](
auto sliceOp) {
2645 .Default([](
auto) {
return failure(); });
2651 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2653 for (
auto op : make_early_inc_range(toReplace)) {
2656 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2657 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2658 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2664 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2665 tensor::InsertSliceOp>(op);
2671 bool flatten1DDepthwiseConv,
bool assumeDynamicDimsMatchVecSizes,
2672 bool createNamedContraction) {
2673 LDBG() <<
"Attempting to vectorize: " << *op;
2674 LDBG() <<
"Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2675 LDBG() <<
"Input scalable vector dims: "
2676 << llvm::interleaved(inputScalableVecDims);
2680 flatten1DDepthwiseConv))) {
2681 LDBG() <<
"Vectorization pre-conditions failed";
2687 if (
auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2688 if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2689 inputScalableVecDims,
2690 assumeDynamicDimsMatchVecSizes))) {
2691 LDBG() <<
"Vectorization state couldn't be initialized";
2697 auto vectorizeResult =
2699 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
2703 if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2705 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2706 flatten1DDepthwiseConv);
2707 if (succeeded(convOr)) {
2708 llvm::append_range(results, (*convOr)->getResults());
2712 LDBG() <<
"Unsupported convolution can't be vectorized.";
2716 if (createNamedContraction &&
2717 isa<ContractionOpInterface>(linalgOp.getOperation()))
2722 <<
"Vectorize generic by broadcasting to the canonical vector "
2735 .Case<tensor::PadOp>([&](
auto padOp) {
2739 .Case<linalg::PackOp>([&](
auto packOp) {
2743 .Case<linalg::UnPackOp>([&](
auto unpackOp) {
2746 inputScalableVecDims, results);
2748 .Case<tensor::InsertSliceOp>([&](
auto sliceOp) {
2752 .Default([](
auto) {
return failure(); });
2754 if (failed(vectorizeResult)) {
2755 LDBG() <<
"Vectorization failed";
2763 memref::CopyOp copyOp) {
2764 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2765 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2766 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2771 if (!VectorType::isValidElementType(srcElementType) ||
2772 !VectorType::isValidElementType(dstElementType))
2783 rewriter, loc, readType, copyOp.getSource(), indices,
2786 if (cast<VectorType>(
readValue.getType()).getRank() == 0) {
2790 vector::BroadcastOp::create(rewriter, loc, writeType,
readValue);
2792 Operation *writeValue = vector::TransferWriteOp::create(
2793 rewriter, loc,
readValue, copyOp.getTarget(), indices,
2804 template <
typename OpTy>
2812 for (
auto *user : llvm::to_vector<4>(padOp->getUsers()))
2813 if (
auto op = dyn_cast<OpTy>(user))
2814 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2820 tensor::PadOp padOp, OpTy op)
const = 0;
2848 vector::TransferReadOp xferOp)
const override {
2850 if (!padOp.hasZeroLowPad())
2853 auto padValue = padOp.getConstantPaddingValue();
2857 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2862 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2864 xferOp.getBaseMutable().assign(padOp.getSource());
2865 xferOp.getPaddingMutable().assign(padValue);
2910 vector::TransferWriteOp xferOp)
const override {
2912 if (xferOp.getTransferRank() == 0)
2916 if (!padOp.hasZeroLowPad())
2919 auto padValue = padOp.getConstantPaddingValue();
2923 if (!xferOp->hasOneUse())
2925 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2929 if (!trimPadding.hasZeroOffset())
2932 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2940 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2941 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2943 rewriter.
replaceOp(trimPadding, newXferOp->getResult(0));
2959 tensor::ExtractSliceOp afterTrimming)
const {
2962 if (
auto castOp = beforePadding.
getDefiningOp<tensor::CastOp>())
2963 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2966 auto t1 = dyn_cast<RankedTensorType>(beforePadding.
getType());
2967 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2972 if (t1.getRank() != t2.getRank())
2977 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2978 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2980 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2985 if (t1.getNumDynamicDims() == 0)
2993 auto beforeSlice = beforePadding.
getDefiningOp<tensor::ExtractSliceOp>();
2997 assert(
static_cast<size_t>(t1.getRank()) ==
2998 beforeSlice.getMixedSizes().size());
2999 assert(
static_cast<size_t>(t2.getRank()) ==
3000 afterTrimming.getMixedSizes().size());
3002 for (
unsigned i = 0; i < t1.getRank(); ++i) {
3004 if (!t1.isDynamicDim(i))
3006 auto size1 = beforeSlice.getMixedSizes()[i];
3007 auto size2 = afterTrimming.getMixedSizes()[i];
3014 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3015 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3021 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3022 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3023 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3024 minOp1.getOperands() == minOp2.getOperands())
3050 if (
auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3051 auto source = bcast.getSource();
3052 if (llvm::dyn_cast<VectorType>(source.getType()))
3060 if (
auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3061 return fill.getInputs()[0];
3066 if (
auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3073 if (
auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3081 if (
auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3087 static LogicalResult
3096 auto sourceType = source.getType();
3097 auto resultType = sliceOp.getResultType();
3102 auto elemType = sourceType.getElementType();
3103 padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3109 size_t rankDiff = resultType.getRank() - sourceType.getRank();
3110 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3111 if (!inputVectorSizes.empty()) {
3112 vecShape.push_back(inputVectorSizes[i]);
3113 }
else if (!sourceType.isDynamicDim(i)) {
3114 vecShape.push_back(sourceType.getDimSize(i));
3115 }
else if (!resultType.isDynamicDim(i)) {
3121 vecShape.push_back(resultType.getDimSize(rankDiff + i));
3128 auto vecType =
VectorType::get(vecShape, sourceType.getElementType());
3131 auto loc = sliceOp.getLoc();
3137 rewriter, loc, source, vecType.getShape(), padValue,
3138 inputVectorSizes.empty(),
3146 writeIndices, inputVectorSizes.empty());
3149 newResults.push_back(write->
getResult(0));
3183 tensor::InsertSliceOp insertOp)
const override {
3185 if (!padOp.hasZeroLowPad())
3188 if (!insertOp.hasUnitStride())
3191 auto padValue = padOp.getConstantPaddingValue();
3195 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3198 if (insertOp.getDest() == padOp.getResult())
3202 padOp.getType().getElementType());
3203 unsigned vecRank = vecType.getRank();
3204 unsigned tensorRank = insertOp.getType().getRank();
3209 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3211 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](
auto it) {
3212 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3224 auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3225 vecType, padOp.getSource(),
3226 readIndices, padValue);
3232 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3235 insertOp, read, insertOp.getDest(), writeIndices,
3261 LDBG() <<
"interleavedUses precondition failed, firstOp: " << *firstOp
3262 <<
", second op: " << *secondOp;
3265 for (
auto v : values) {
3266 for (
auto &u : v.getUses()) {
3268 if (owner == firstOp || owner == secondOp)
3274 LDBG() <<
" found interleaved op " << *owner <<
", firstOp: " << *firstOp
3275 <<
", second op: " << *secondOp;
3285 memref::SubViewOp subViewOp;
3287 if (
auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3289 return memref::SubViewOp();
3290 subViewOp = newSubViewOp;
3302 if (xferOp.getMask())
3306 Value viewOrAlloc = xferOp.getBase();
3315 Value subView = subViewOp.getResult();
3318 memref::CopyOp copyOp;
3319 for (
auto &u : subView.
getUses()) {
3320 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3321 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3322 if (newCopyOp.getTarget() != subView)
3336 for (
auto &u : viewOrAlloc.
getUses()) {
3337 if (
auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3338 assert(isa<MemRefType>(newFillOp.output().getType()));
3339 if (newFillOp.output() != viewOrAlloc)
3343 maybeFillOp = newFillOp;
3348 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3350 "padding value does not match fill");
3353 Value in = copyOp.getSource();
3359 auto vectorType = xferOp.getVectorType();
3360 Value res = vector::TransferReadOp::create(
3361 rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3362 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3367 rewriter.
eraseOp(maybeFillOp);
3379 if (xferOp.getMask())
3383 Value viewOrAlloc = xferOp.getBase();
3392 Value subView = subViewOp.getResult();
3395 memref::CopyOp copyOp;
3396 for (
auto &u : subViewOp.getResult().getUses()) {
3397 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3398 if (newCopyOp.getSource() != subView)
3410 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3411 Value out = copyOp.getTarget();
3418 auto vector = xferOp.getVector();
3419 vector::TransferWriteOp::create(
3420 rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3421 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3423 dyn_cast<VectorType>(vector.getType()).getRank(),
false)));
3438 template <
int N,
typename IntTy,
typename... IntTy2>
3440 val = shapedType.getShape()[N];
3445 template <
typename... IntTy>
3447 bindShapeDims<0>(shapedType, vals...);
3485 struct Conv1DGenerator
3487 Conv1DGenerator(
RewriterBase &rewriter, LinalgOp linalgOp)
3490 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3491 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3492 resShaped = linalgOp.getDpsInitOperand(0)->get();
3493 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3494 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3495 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3500 setConvOperationKind(reduceOp);
3503 reductionKind = maybeKind.value();
3511 strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3512 dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3534 int64_t nSize, wSize, cSize, kwSize, fSize;
3537 switch (conv1DOpOrder) {
3540 nSize = fSize = cSize = 0;
3547 (wSize + kwSize - 1)};
3548 rhsShape = {kwSize};
3555 case ConvOperationKind::Conv:
3559 case ConvOperationKind::Pool:
3569 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3573 case ConvOperationKind::Conv:
3574 rhsShape = {kwSize, cSize, fSize};
3576 case ConvOperationKind::Pool:
3577 rhsShape = {kwSize};
3580 resShape = {nSize, wSize, fSize};
3586 case ConvOperationKind::Conv:
3590 case ConvOperationKind::Pool:
3596 lhsShape = {nSize, cSize,
3600 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3603 case ConvOperationKind::Conv:
3604 rhsShape = {fSize, cSize, kwSize};
3606 case ConvOperationKind::Pool:
3607 rhsShape = {kwSize};
3610 resShape = {nSize, fSize, wSize};
3614 vector::TransferWriteOp write;
3620 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3622 Type lhsEltType = lhsShapedType.getElementType();
3623 Type rhsEltType = rhsShapedType.getElementType();
3624 Type resEltType = resShapedType.getElementType();
3634 Value lhs = vector::TransferReadOp::create(
3635 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3638 Value rhs =
nullptr;
3639 if (oper == ConvOperationKind::Conv)
3640 rhs = vector::TransferReadOp::create(
3641 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3643 Value res = vector::TransferReadOp::create(
3644 rewriter, loc, resType, resShaped, resPadding,
3650 switch (conv1DOpOrder) {
3658 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3659 lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3661 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3664 if (oper == ConvOperationKind::Conv)
3665 rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3667 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3668 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3679 kwSize, strideW, dilationW, wSizeStep,
3682 if (oper == ConvOperationKind::Conv)
3685 wSizeStep, isSingleChanneled);
3687 auto linearIndex = [&](int64_t kw, int64_t w) {
3688 return kw * (wSize / wSizeStep) + w;
3694 for (int64_t kw = 0; kw < kwSize; ++kw) {
3695 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3697 case ConvOperationKind::Conv:
3698 if (isSingleChanneled) {
3699 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3700 lhsVals[linearIndex(kw, w)],
3701 rhsVals[kw], resVals[w]);
3703 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3704 lhsVals[linearIndex(kw, w)],
3705 rhsVals[kw], resVals[w]);
3708 case ConvOperationKind::Pool:
3709 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3725 switch (conv1DOpOrder) {
3732 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3733 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3738 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3747 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3748 if (srcElementType == dstElementType)
3753 const Type dstType =
3754 cast<ShapedType>(val.
getType()).cloneWith(std::nullopt, dstElementType);
3756 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3757 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3760 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3761 srcWidth < dstWidth)
3762 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3764 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3765 srcWidth < dstWidth)
3766 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3768 assert(
false &&
"unhandled promotion case");
3775 vector::IteratorType par = vector::IteratorType::parallel;
3776 vector::IteratorType red = vector::IteratorType::reduction;
3781 auto contrationOp = vector::ContractionOp::create(
3782 rewriter, loc, lhs, rhs, res,
3783 MapList{{n, w, c}, {c, f}, {n, w, f}},
3785 contrationOp.setKind(reductionKind);
3786 return contrationOp;
3793 return vector::OuterProductOp::create(rewriter, loc, res.
getType(), lhs,
3794 rhs, res, vector::CombiningKind::ADD);
3816 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3817 bool channelDimScalableFlag,
3819 bool scalableChDim =
false;
3820 bool useMasking =
false;
3821 int64_t nSize, wSize, cSize, kwSize;
3824 if (ShapedType::isDynamic(cSize)) {
3825 assert(channelDimVecSize != 0 &&
"Channel dim vec size must be > 0");
3826 cSize = channelDimVecSize;
3830 scalableChDim = channelDimScalableFlag;
3834 assert(!(useMasking && flatten) &&
3835 "Unsupported flattened conv with dynamic shapes");
3840 vector::TransferWriteOp write;
3846 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3848 Type lhsEltType = lhsShapedType.getElementType();
3849 Type rhsEltType = rhsShapedType.getElementType();
3850 Type resEltType = resShapedType.getElementType();
3855 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3857 lhsEltType, {
false,
false, scalableChDim});
3858 VectorType rhsType =
3860 {
false, scalableChDim});
3861 VectorType resType =
3863 {
false,
false, scalableChDim});
3876 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3877 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3881 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3884 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3891 Value lhs = vector::TransferReadOp::create(
3892 rewriter, loc, lhsType, lhsShaped,
ValueRange{zero, zero, zero},
3894 auto maybeMaskedLhs = maybeMaskXferOp(
3895 lhsType.getShape(), lhsType.getScalableDims(), lhs.
getDefiningOp());
3898 Value rhs = vector::TransferReadOp::create(
3899 rewriter, loc, rhsType, rhsShaped,
ValueRange{zero, zero},
3901 auto maybeMaskedRhs = maybeMaskXferOp(
3902 rhsType.getShape(), rhsType.getScalableDims(), rhs.
getDefiningOp());
3905 Value res = vector::TransferReadOp::create(
3906 rewriter, loc, resType, resShaped,
ValueRange{zero, zero, zero},
3908 auto maybeMaskedRes = maybeMaskXferOp(
3909 resType.getShape(), resType.getScalableDims(), res.
getDefiningOp());
3921 for (int64_t kw = 0; kw < kwSize; ++kw) {
3922 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3923 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3924 rewriter, loc, maybeMaskedLhs->getResult(0),
3926 inOutSliceSizes, inOutStrides));
3930 for (int64_t kw = 0; kw < kwSize; ++kw) {
3932 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3936 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3937 resVals.push_back(vector::ExtractStridedSliceOp::create(
3938 rewriter, loc, maybeMaskedRes->getResult(0),
3943 auto linearIndex = [&](int64_t kw, int64_t w) {
3944 return kw * (wSize / wSizeStep) + w;
3950 auto lhsTypeAfterFlattening =
3952 auto resTypeAfterFlattening =
3956 for (int64_t kw = 0; kw < kwSize; ++kw) {
3957 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3958 Value lhsVal = lhsVals[linearIndex(kw, w)];
3959 Value resVal = resVals[w];
3964 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3965 lhsVals[linearIndex(kw, w)]);
3966 resVal = vector::ShapeCastOp::create(
3967 rewriter, loc, resTypeAfterFlattening, resVals[w]);
3969 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3970 rhsVals[kw], resVal, flatten);
3973 resVals[w] = vector::ShapeCastOp::create(
3981 if (!llvm::all_of(resVals, [](
Value v) {
return v; })) {
3983 for (
auto &collection :
3984 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3985 for (
Value v : collection)
3992 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3993 maybeMaskedRes = vector::InsertStridedSliceOp::create(
3994 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
4003 Operation *resOut = vector::TransferWriteOp::create(
4004 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4006 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4017 auto rhsTy = cast<ShapedType>(rhs.
getType());
4018 auto resTy = cast<ShapedType>(res.
getType());
4021 lhs =
promote(rewriter, loc, lhs, resTy);
4032 auto rhsSize = cast<VectorType>(rhs.
getType()).getShape()[0];
4033 auto resSize = cast<VectorType>(res.
getType()).getShape()[1];
4036 for (
int i = 0; i < resSize / rhsSize; ++i) {
4037 for (
int j = 0;
j < rhsSize; ++
j)
4038 indices.push_back(
j);
4041 rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4044 rhs = vector::BroadcastOp::create(rewriter, loc,
4045 resTy.clone(rhsTy.getElementType()), rhs);
4047 rhs =
promote(rewriter, loc, rhs, resTy);
4052 if (isa<FloatType>(resTy.getElementType()))
4053 return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4055 auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4056 return arith::AddIOp::create(rewriter, loc, mul, res);
4061 FailureOr<Operation *> generateNonChanneledConv() {
4064 if (!iters({Par(), Red()}))
4066 "failed to match conv::W 1-par 1-red");
4069 if (layout({ {w + kw},
4079 FailureOr<Operation *> generateNwcConv() {
4082 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4084 op,
"failed to match conv::Nwc 3-par 2-red");
4087 if (layout({ {n, strideW * w + dilationW * kw, c},
4097 FailureOr<Operation *> generateNcwConv() {
4100 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4102 op,
"failed to match conv::Ncw 3-par 2-red");
4104 if (layout({ {n, c, strideW * w + dilationW * kw},
4114 FailureOr<Operation *> generateNwcPooling() {
4117 if (!iters({Par(), Par(), Par(), Red()}))
4119 "failed to match pooling 3-par 1-red");
4122 if (layout({ {n, strideW * w + dilationW * kw, c},
4132 FailureOr<Operation *> generateNcwPooling() {
4135 if (!iters({Par(), Par(), Par(), Red()}))
4137 "failed to match pooling 3-par 1-red");
4139 if (layout({ {n, c, strideW * w + dilationW * kw},
4149 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4150 bool vecChDimScalableFlag =
false,
4151 bool flatten =
false) {
4154 if (!iters({Par(), Par(), Par(), Red()}))
4156 op,
"failed to match depthwise::Nwc conv 3-par 1-red");
4159 if (layout({ {n, strideW * w + dilationW * kw, c},
4162 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4168 ConvOperationKind oper = ConvOperationKind::Conv;
4170 StringAttr poolExtOp;
4171 bool isPoolExt =
false;
4172 int strideW, dilationW;
4173 Value lhsShaped, rhsShaped, resShaped;
4174 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4175 vector::CombiningKind reductionKind;
4178 void setConvOperationKind(
Operation *reduceOp) {
4179 int numBlockArguments =
4180 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
4181 if (numBlockArguments == 1) {
4186 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
4187 llvm::IsaPred<BlockArgument>);
4188 Operation *feedOp = (*feedValIt).getDefiningOp();
4190 oper = ConvOperationKind::Pool;
4195 oper = ConvOperationKind::Conv;
4199 oper = ConvOperationKind::Pool;
4209 ArrayRef<bool> inputScalableVecDims,
bool flatten1DDepthwiseConv) {
4210 Conv1DGenerator conv1dGen(rewriter, op);
4211 auto res = conv1dGen.generateNonChanneledConv();
4214 res = conv1dGen.generateNwcConv();
4217 res = conv1dGen.generateNcwConv();
4220 res = conv1dGen.generateNwcPooling();
4223 res = conv1dGen.generateNcwPooling();
4230 uint64_t vecChDimSize = ShapedType::kDynamic;
4231 bool vecChDimScalableFlag =
false;
4232 if (!inputVecSizes.empty()) {
4235 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4236 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4237 "Not a 1D depthwise conv!");
4240 .Case<linalg::DepthwiseConv1DNwcWcOp>([](
auto conv) {
return 2; })
4241 .Case<linalg::DepthwiseConv1DNcwCwOp>([](
auto conv) {
return 1; });
4243 vecChDimSize = inputVecSizes[chDimIdx];
4244 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4246 return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4247 flatten1DDepthwiseConv);
4256 if (failed(resultOrFail))
4260 rewriter.
eraseOp(op.getOperation());
4263 assert(newOp->
getNumResults() == 1 &&
"expected single result");
SmallVector< int64_t > outerDimsPerm
SmallVector< OpFoldResult > innerTiles
SmallVector< int64_t > innerDimsPos
union mlir::linalg::@1227::ArityGroupAndKind::Kind kind
static std::optional< VectorShape > vectorShape(Type type)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< ConvOperationKind > getConvOperationKind(Operation *reduceOp)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize linalg::PackOp with (1) static inner_tiles (2) constant padding value and (3) input vector ...
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static bool isMaskTriviallyFoldable(SmallVector< OpFoldResult > &maskSizes, SmallVector< Value > &writeIdxs, ArrayRef< int64_t > destShape, ArrayRef< int64_t > maskShape)
Determines whether a mask for xfer_write is trivially "all true".
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims, SmallVectorImpl< Value > &newResults)
Vectorize linalg.unpack as:
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationHookResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
static LogicalResult vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::InsertSliceOp with:
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static bool isCastOfBlockArgument(Operation *op)
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static VectorizationHookResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, Value dest, SmallVector< Value > writeIndices={}, bool useInBoundsInsteadOfMasking=false)
Creates an optionally masked TransferWriteOp.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shape of the value to reduce differs from the result shape.
static LogicalResult vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
static VectorType getCollapsedVecType(VectorType type, ArrayRef< AffineMap > reassociation)
Given the re-associations, "collapses" the input Vector type.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
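A minimal sketch of a custom vectorization hook (hypothetical, for illustration only): it claims only the ops it recognizes and reports Failure for everything else, letting vectorizeOneOp fall through to the other hooks and the generic path.
CustomVectorizationHook hook =
    [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
      if (!isa<linalg::IndexOp>(op))
        return {VectorizationHookStatus::Failure, nullptr};
      // Build the replacement using operands remapped through `bvm`
      // (construction elided in this sketch).
      Operation *newOp = nullptr;
      return {VectorizationHookStatus::NewOp, newOp};
    };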
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp)
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Vectorize a named linalg contraction op into: vector::TransferReadOp - Reads vectors from the operand...
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static bool isSupportedPoolKind(vector::CombiningKind kind)
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(linalg::PackOp packOp, ArrayRef< int64_t > destShape)
Given a linalg::PackOp, return the dest shape before any packing permutations.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
This hook considers two cases: (1) If the input-vector-sizes are empty, then the vector sizes will be...
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad to ...
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. AffineMaps are immutable, like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
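A worked example of the definition above:
// getMinorIdentityMap(3, 2, ctx) produces (d0, d1, d2) -> (d1, d2),
// i.e. the identity on the 2 most minor of 3 dimensions.
AffineMap minor = AffineMap::getMinorIdentityMap(3, 2, ctx);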
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumResults() const
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
operand_iterator operand_end()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
Value getZeroConstant(OpBuilder &builder, Location loc, Type type)
Creates an arith.constant operation with a zero value of type type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
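A minimal call-site sketch (variable names illustrative; assumes VectorizationResult exposes the replacement values as `replacements`): vectorize a linalg op with explicit vector sizes, marking the trailing dim scalable, then wire the replacements back into the IR.
IRRewriter rewriter(linalgOp->getContext());
FailureOr<VectorizationResult> vecResult =
    linalg::vectorize(rewriter, linalgOp,
                      /*inputVectorSizes=*/{8, 16},
                      /*inputScalableVecDims=*/{false, true});
if (succeeded(vecResult))
  rewriter.replaceOp(linalgOp, vecResult->replacements);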
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
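Typical driver usage, as a sketch (the greedy-driver spelling is assumed to match current upstream):
// Inside a pass' runOnOperation().
RewritePatternSet patterns(funcOp.getContext());
linalg::populateConvolutionVectorizationPatterns(patterns);
if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
  return signalPassFailure();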
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
enum WinogradConv2DFmr : uint32_t
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for given shape,...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
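A short sketch of wrapping a maskable op, per the signature above (`transferReadOp` and `maskValue` are assumed to exist at the call site):
// Wrap an existing vector.transfer_read in a vector.mask region.
Operation *masked = vector::maskOperation(builder, transferReadOp, maskValue);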
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking=false, ArrayRef< bool > inputScalableVecDims={})
Creates a TransferReadOp from source.
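Sketch, per the signature above (source, padValue and the sizes are illustrative):
// Reads an 8x16 vector from `source`, masking the out-of-bounds tail with
// `padValue` when the source dims are smaller or dynamic.
Value vec = vector::createReadOrMaskedRead(builder, loc, source,
                                           /*inputVectorSizes=*/{8, 16},
                                           padValue);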
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
bool * changed
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
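A worked example, matching the declaration above:
// With map = affine_map<(d0, d1, d2) -> (d2, d0)>:
SmallVector<int64_t> r =
    applyPermutationMap<int64_t>(map, ArrayRef<int64_t>{5, 7, 9});
// r == {9, 5}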
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. size(exprs)).
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
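A worked example from the definition:
// p maps result position i to source position p[i]; its inverse q satisfies
// q[p[i]] == i.
SmallVector<int64_t> q = invertPermutationVector({1, 2, 0});
// q == {2, 0, 1}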
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
Contains the vectorization state and related methods used across the vectorization process of a given...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims, bool assumeDynamicDimsMatchVecSizes=false)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Transformation information returned after vectorizing.