#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
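// Illustration (not part of the original source): with the macros above,
//   LDBG("Masking map: " << map);
// expands to
//   LLVM_DEBUG(llvm::dbgs() << "[linalg-vectorization] " << "Masking map: "
//                           << map << "\n");
// and only prints in builds with assertions enabled, when running with
// -debug-only=linalg-vectorization.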
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);

/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}
/// Helper function to extract the input slices after filter is unrolled along
/// kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Non-channeled case: extract input slices of size {wSizeStep} @ [w + kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw},
            /*sizes=*/ArrayRef<int64_t>{wSizeStep},
            /*strides=*/ArrayRef<int64_t>{1}));
      }
    }
  } else {
    // Channeled case: extract slices @ [0, w * strideW + kw * dilationW, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after filter is unrolled along
/// kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(
        rewriter.create<vector::ExtractOp>(loc, filter, /*position=*/kw));
  }
  return result;
}

/// Helper function to extract the result slices after filter is unrolled along
/// kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Non-channeled case: extract result slices of size {wSizeStep} @ [w].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{w},
          /*sizes=*/ArrayRef<int64_t>{wSizeStep},
          /*strides=*/ArrayRef<int64_t>{1}));
    }
  } else {
    // Channeled case: slices of size {nSize, wSizeStep, fSize} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
    }
  }
  return result;
}

/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Non-channeled case: insert slices of size {wSizeStep} @ [w].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
          /*strides=*/ArrayRef<int64_t>{1});
    }
  } else {
    // Channeled case: insert slices @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }
  }
  return res;
}
  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimensions bit. If
  /// `dimPermutation` is provided, the canonical vector dimensions are
  /// permuted accordingly.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

private:
  /// Initializes the iteration space static sizes using the Linalg op
  /// information.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for all the
  /// static and dynamic sizes of the iteration space to be vectorized and
  /// store them in `iterSpaceValueSizes`.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Create or retrieve an existing mask value to mask `opToMask` in the
  /// canonical vector iteration space.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Check whether this permutation map can be used for masking.
  bool isValidMaskingMap(AffineMap maskingMap) {
LogicalResult VectorizationState::precomputeIterSpaceValueSizes(
    RewriterBase &rewriter, LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    // This path should be taken to vectorize code with dynamic shapes and
    // when using vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation's loop ranges.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for all the
  // static and dynamic sizes of the iteration space.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute permuted projection of the iteration space to be masked and the
  // corresponding mask shape. If the resulting iteration space dimensions are
  // static and identical to the mask shape, masking is not needed for this
  // operation.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}
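// Worked example (hypothetical sizes): for an iteration space with static
// sizes (8, ShapedType::kDynamic) vectorized with canonical shape [8, 16] and
// an identity masking map, the permuted static sizes (8, dynamic) differ from
// the mask shape (8, 16), so a vector.create_mask with upper bounds
// (c8, %dim) is built and cached. A later op masked with the same masking map
// reuses the cached mask instead of creating a new one.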
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}

/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projectedPermutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  auto res = compressUnusedDims(map);
  assert(res.getNumDims() ==
             (res.getNumResults() - res.getNumOfZeroResults()) &&
         "expected reindexed map with same number of dims and results");
  return res;
}
/// Return the combining kind of `combinerOp`, if recognized.
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
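// The llvm::TypeSwitch idiom above generalizes to any classification over op
// kinds. A minimal sketch (the helper below is hypothetical, not part of the
// original file):
static bool isAddLikeOp(Operation *op) {
  return llvm::TypeSwitch<Operation *, bool>(op)
      .Case<arith::AddIOp, arith::AddFOp>([](auto) { return true; })
      .Default([](auto) { return false; });
}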
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, or nullptr.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}

/// Broadcast `value` to a vector of `dstType` if this is not already the case.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;

  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}

/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes the reduction has a single combiner operation.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}

/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

    return llvm::is_contained(opOperandMap.getResults(), dimExpr);

  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(vectorTypeMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true. Masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the IRMapping.
using CustomVectorizationHook =
    std::function<VectorizationResult(Operation *, const IRMapping &)>;

// Custom vectorization precondition function type. This is intended to be used
// with CustomVectorizationHook: a hook is only applied if its precondition
// succeeds on the op.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;
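// A minimal sketch of a hook matching the precondition type above (the
// function name and check are illustrative, not part of the original file):
static LogicalResult exampleExtractPrecondition(Operation *op,
                                                bool vectorizeNDExtract) {
  (void)vectorizeNDExtract;
  // Succeed only for tensor.extract ops; everything else is left to the
  // generic preconditions.
  return success(isa<tensor::ExtractOp>(op));
}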
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    // TODO: Scan for an opportunity for reuse.
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}

/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations given the custom vectorization.
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  auto targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space.
  if (dim == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
/// Helper function to check if the `tensor.extract` can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type of the operand.
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  // Reject result types that cannot be vectorized.
  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      })) {
    return failure();
  }

  return success();
}

/// Calculates the offsets (`$index_vecs`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}
/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
/// represents a contiguous load operation.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}

/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its dim is a unit dim.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp`, which are constant, are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}
/// Checks whether `val` can be used for calculating an index that increments
/// by one in the trailing loop dimension (i.e. a contiguous load index).
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
/// Infer the memory access pattern for the input ExtractOp.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Reject 0-d tensors.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
  // otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`.
  // Look at the way each index is calculated and decide whether it is taken
  // from a loop-invariant value or not.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;
    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index for `extractOp`.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load: the trailing index is also loop invariant.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG("Found scalar broadcast load: " << extractOp);
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous loads: the trailing index increments by one with the
  // trailing non-unit loop dimension.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  // TODO: Support generating contiguous loads for column vectors - that will
  // require adding a permutation map to tranfer_read Ops.
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::ContiguousLoad;
  }

  // 4. Fallback case - gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}
/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations given the custom vectorization.
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(

  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0. We will need to re-visit if more
  // generic scenarios are to be supported.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }
  // 2. Handle contiguous access.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        permutationMap, inBounds);

    // Mask this broadcasting xfer_read here rather than relying on the generic
    // path (the generic path assumes identity masking map, which wouldn't be
    // valid here).
    auto allTrue = rewriter.create<vector::ConstantMaskOp>(

    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
  }

  // 2b. Handle contiguous access.
  int32_t rankDiff = dstRank - srcRank;
  // Introduce leading 0s and/or extract/truncate leading dims to adjust the
  // permutation map to the destination rank.
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}
/// Emit reduction operations if the shapes of the value to reduce is different
/// that the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp,
                                 Operation *op, Value reduceValue,
                                 Value initialValue, const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`.
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};

  // 3. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 4. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each op if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     getElementTypeOrSelf(vecOperand.getType()),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. For elementwise ops, the result is a vector with the max ranked
  //   shape.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  // Values defined above the region can only be broadcast for now. Make them
  // map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // Input reads use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // Output reads (iteration-carried dependencies, e.g. reductions) map
      // the canonical vector shape to the output domain.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // If masked, set in_bounds to true. Masking guarantees that the access
    // will be in bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // Not all ops support 0-d vectors, extract the scalar for now.
    if (readType.getRank() == 0)
      readValue = rewriter.create<vector::ExtractOp>(loc, readValue);

    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  SmallVector<CustomVectorizationHook> hooks;
  // Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
  };
  hooks.push_back(vectorizeYield);

  // Register CustomVectorizationHook for indexOp.
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  // Register CustomVectorizationHook for extractOp.
  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // Iteratively call `vectorizeOneOp` on each op in the slice.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
/// Given an input, the mixed destSizes, and the vector sizes for
/// vectorization, create an empty destination tensor and create a
/// TransferWriteOp from the input to the empty tensor. If the destination
/// shape is not the same as the inputVectorSizes for the first rank
/// dimensions, compute a mask for the write (unless
/// useInBoundsInsteadOfMasking is set).
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
                                           Value input,
                                           SmallVector<OpFoldResult> destSizes,
                                           ArrayRef<int64_t> inputVectorSizes,
                                           bool useInBoundsInsteadOfMasking) {
  auto inputType = cast<VectorType>(input.getType());
  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                               inputType.getElementType());
  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  auto destShape = cast<ShapedType>(dest.getType()).getShape();
  SmallVector<bool> inBoundsVal(rank, true);
  if (useInBoundsInsteadOfMasking) {
    // In this case, assume that all the required vector sizes have been
    // provided.
    for (unsigned i = 0; i < rank; i++)
      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
                       !ShapedType::isDynamic(destShape[i]);
  }
  Operation *write = builder.create<vector::TransferWriteOp>(

  assert(llvm::none_of(
             destShape.drop_front(inputVectorSizes.size()),
             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
         "Only dims aligned with inputVectorSizes may be dynamic");
  if (useInBoundsInsteadOfMasking)
    return write;
  bool needMaskForWrite = !llvm::equal(
      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
  if (needMaskForWrite) {
    SmallVector<int64_t> writeMaskShape;
    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
                          destShape.end());
    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
    Value maskForWrite =
        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
    write = mlir::vector::maskOperation(builder, write, maskForWrite);
  }
  return write;
}
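// Worked example (hypothetical sizes): with inputVectorSizes = {8, 16} and a
// destination of shape {5, 16, 4}, the leading dim is not aligned (5 != 8),
// so needMaskForWrite is true; the mask shape becomes {8, 16} + {4}, i.e.
// vector<8x16x4xi1>, built from the destination sizes, and the transfer_write
// is wrapped in vector.mask.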
/// Vectorize tensor::PackOp with (1) static innerTiles, (2) constant padding
/// value and (3) input vector sizes into masked TransferReadOp, ShapeCastOp,
/// TransposeOp and TransferWriteOp.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  auto padValue = packOp.getPaddingValue();
  if (!padValue) {
    padValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
  }
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");

  // If the input vector sizes are not provided, then the vector sizes are
  // determined by the result tensor shape.
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }

  // Create masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  auto innerDimsPos = packOp.getInnerDimsPos();
  auto outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty())
    applyPermutationToVector(inputShape,
                             invertPermutationVector(outerDimsPerm));
  for (auto [idx, size] : enumerate(innerTiles))
    inputShape[innerDimsPos[idx]] *= size;
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking);

  // Create ShapeCastOp.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
  auto shapeCastOp =
      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

  // Create TransposeOp.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}
/// Vectorize a `tensor::UnPackOp` to these 4 Ops:
///   vector::TransferReadOp - Reads a vector from the source tensor
///   vector::TransposeOp - Transpose the source vector
///   vector::ShapeCastOp - Reshape the data based on the target
///   vector::TransferWriteOp - Write the result vector back to the destination
///   tensor.
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;
  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();

  auto destSize = unpackOp.getDestRank();

  if (!inputVectorSizes.empty())
    assert(inputVectorSizes.size() == destSize &&
           "Incorrect number of input vector sizes");

  // vectorSizes is the shape of the vector that will be used to read the
  // unpack op source; it has rank destSize.
  SmallVector<int64_t> vectorSizes(inputVectorSizes);
  if (vectorSizes.empty()) {
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
    if (!outerDimsPerm.empty())
      applyPermutationToVector(vectorSizes, outerDimsPerm);
    for (auto [i, pos] : llvm::enumerate(innerDimPos))
      vectorSizes[pos] *= innerTiles[i];

    useInBoundsInsteadOfMasking = true;
  }

  // readVectorSizes is the size of the vector used to read the unpack op
  // source and to apply the mask; its trailing dims match the inner tiles.
  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());

  for (auto [index, size] : enumerate(innerTiles)) {
    readVectorSizes[innerDimPos[index]] =
        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(readVectorSizes, outerDimsPerm);
  }
  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());

  ReifiedRankedShapedTypeDims reifiedRetShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
          .reifyResultShapes(rewriter, reifiedRetShapes);
  if (status.failed()) {
    LDBG("Unable to reify result shapes of " << unpackOp);
    return failure();
  }
  Location loc = unpackOp->getLoc();

  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

  // Read result, mask if necessary. If transferReadOp shape is not equal to
  // the shape of source, then a mask is necessary.
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      tensor::getUnPackInverseSrcPerm(unpackOp, packMetadata);
  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
  RankedTensorType stripMineTensorType =
      RankedTensorType::get(stripMineShape, stripMineElemType);
  // Transpose the appropriate rows to match output.
  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
      loc, readResult, lastDimToInsertPosPerm);

  // Collapse the vector to the size required by result.
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMineTensorType, packMetadata.reassociations);
  mlir::VectorType vecCollapsedType =
      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
      loc, vecCollapsedType, transposeOp->getResult(0));

  // writeVectorSizes had to match the shapecast shape for dynamic sizes;
  // otherwise the validator complains of the vector.transfer_write,
  // vector.transfer_read mismatch.
  SmallVector<int64_t> writeVectorSizes(
      unpackOp.getDestType().hasStaticShape()
          ? vectorSizes
          : shapeCastOp.getResultVectorType().getShape());
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
      writeVectorSizes, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
  return success();
}
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  // TransferWriteOp result will be stored to newResults.
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
      /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG("Vectorization of flattened convs with dynamic shapes is not "
         "supported\n");
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // `channel` dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG("Dynamically-shaped op vectorization precondition failed: only "
         "channel dim can be dynamic\n");
    return failure();
  }

  return success();
}

static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (op.hasReductionIterator())
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement ContractionOpInterface for now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}

/// Need to check if the inner-tiles are static/constant.
static LogicalResult
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
      })) {
    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
  bool satisfyEmptyCond = inputVectorSizes.empty() &&
                          unpackOp.getDestType().hasStaticShape() &&
                          unpackOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
    return failure();

  return success();
}
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              bool vectorizeNDExtract,
                              bool flatten1DDepthwiseConv) {
  // tensor with dimension of 0 cannot be vectorized.
  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
                                        linalgOp, flatten1DDepthwiseConv))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(customPreconditions,
                     [&](const CustomVectorizationPrecondition &precondition) {
                       return succeeded(
                           precondition(&innerOp, vectorizeNDExtract));
                     })) {
      continue;
    }
    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // TODO: isaConvolutionOpInterface that can also infer from generic features.
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return success();

  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG("inner_tiles must be constant: " << packOp << "\n");
    return failure();
  }

  return success();
}

static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }

  // Padding with non-zero low pad values is not supported.
  if (llvm::any_of(padOp.getLow(), [](Value v) {
        std::optional<int64_t> res = getConstantIntValue(v);
        return !res.has_value() || res.value() != 0;
      })) {
    LDBG("low pad must all be zero: " << padOp << "\n");
    return failure();
  }

  return success();
}
/// Preconditions for scalable vectors. This is quite restrictive: it models
/// the fact that in practice we would only make selected dimensions scalable.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  // Cond 1: There's been no need for more than 2 scalable dims so far.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 2: Examine the trailing dims, ignoring non-scalable unit dims.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  // Analyze the iterator corresponding to the trailing scalable dim.
  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // Check that the reduction dim is the trailing dim.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG("Non-trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
      LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
           "is not supported\n");
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // Check that there is no non-unit parallel dim after it.
    if (seenNonUnitParallel) {
      LDBG("Inner parallel dim not requested for scalable "
           "vectorization\n");
      return failure();
    }
    break;
  }
  }

  // If present, check the second scalable dim.
  if (numOfScalableDims == 2) {
    // Disallow scalable dims "higher" than a trailing reduction dim.
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG("Higher dim than the trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Check to not let go the matmul with extended semantics through this
  // transform.
  if (linalgOp.hasUserDefinedMaps())
    return failure();

  // Cond 4: Only a restricted set of ops is supported.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::MatmulTransposeAOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<tensor::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<tensor::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}

/// Converts affine.apply Ops to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}

bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp,
             tensor::UnPackOp>(op);
}
/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      ArrayRef<bool> inputScalableVecDims,
                                      bool vectorizeNDExtract,
                                      bool flatten1DDepthwiseConv) {
  LDBG("Attempting to vectorize:\n" << *op << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Input scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            // TODO: isaConvolutionOpInterface that can also infer from
            // generic features. Will require stride/dilation attributes
            // inference.
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG("Unsupported convolution can't be vectorized.\n");
              return failure();
            }

            LDBG("Vectorize generic by broadcasting to the canonical vector "
                 "shape\n");

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<tensor::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<tensor::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG("Vectorization failed\n");
    return failure();
  }

  if (!results.empty())
    rewriter.replaceOp(op, results);
  else
    rewriter.eraseOp(op);

  return success();
}
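// A minimal usage sketch of the entry point above (the pattern name is
// hypothetical, not part of the original file):
struct ExampleVectorizeLinalgOpPattern
    : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    // No explicit vector sizes: relies on static shapes and the default
    // arguments of `vectorize`.
    return vectorize(rewriter, op);
  }
};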
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = rewriter.create<vector::ExtractOp>(loc, readValue);
    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
  }
  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};

/// Rewrite use of tensor::PadOp result in TransferReadOp, e.g.:
///   %r = vector.transfer_read %0[%c0, %c0], %cst
/// is rewritten to read from the pad source directly.
struct PadOpVectorizationWithTransferReadPattern
    : public VectorizePadOpUserPattern<vector::TransferReadOp> {
  using VectorizePadOpUserPattern<
      vector::TransferReadOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getSourceMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
};
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at position of the old TransferWriteOp.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }

  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both CastOp result
    // and CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same attribute or same value.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Case 2: Both values are identical AffineMinOps. (Should not happen if
      // CSE is run.)
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;
      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
      return false;
    }

    // All tests passed.
    return true;
  }
/// Returns the effective Pad value for the input op, provided it's a scalar.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
  // being broadcast, provided that it's a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  // 2. linalg.fill - use the scalar input value.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generate - continue the analysis through the yielded value.
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {

  // 4. vector.transfer_write - continue the analysis through the vector
  // that's written.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // 5. tensor.insert_slice - continue the analysis through the destination.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
  // Vectorize a tensor::InsertSliceOp with a transfer_read + transfer_write
  // pair.
  auto sourceType = sliceOp.getSource().getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  auto resultType = sliceOp.getResultType();

  // An out-of-bounds read is needed if the source shape is not static.
  bool isOutOfBoundsRead = !sourceType.hasStaticShape();

  if (!padValue && isOutOfBoundsRead) {
    LDBG("Failed to get a pad value for out-of-bounds read access\n");
    return failure();
  }

  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = rewriter.create<arith::ConstantOp>(
        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
  }

  // Compute the vector shape and the in-bounds attributes for the read and
  // the write.
  SmallVector<int64_t> vecShape;
  SmallVector<bool> readInBounds;
  SmallVector<bool> writeInBounds;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
    if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
      // Source shape is statically known: read and write are in-bounds.
      readInBounds.push_back(true);
      writeInBounds.push_back(true);
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but the result shape is:
      // vectorize with the size of the result, padding the source.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
      // The read may be out of bounds, and the write cannot be proven
      // in-bounds either.
      readInBounds.push_back(false);
      writeInBounds.push_back(false);
    } else {
      // Neither source nor result dim is static: cannot vectorize the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // Generate TransferReadOp + TransferWriteOp.
  SmallVector<Value> readIndices(
      vecType.getRank(),
      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
  auto read = rewriter.create<vector::TransferReadOp>(
      sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
      ArrayRef<bool>{readInBounds});

  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());

  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      sliceOp, read, sliceOp.getDest(), writeIndices,
      ArrayRef<bool>{writeInBounds});
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result not used as destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if sizes match: insert the entire tensor into the destination.
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Generate TransferReadOp: read the entire source tensor and add high
    // padding.
    SmallVector<Value> readIndices(
        vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

    // Generate TransferWriteOp: write to InsertSliceOp's dest tensor at the
    // specified offsets.
    SmallVector<Value> writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservative result also works in other cases.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getSource();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // The `in` type passed to the new vector.transfer_read matches the original
  // one; the transformation only forwards the source through the copy.
  auto vectorType = xferOp.getVectorType();
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getSource();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the subview copied into that we replace.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward vector.transfer into copy.
  auto vector = xferOp.getVector();
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vector.getType().getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2...>(shapedType, vals...);
}

/// Bind a pack of int&'s to the leading dimensions of shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
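// Usage sketch (hypothetical shaped type): for a ShapedType with shape
// {N, W, C},
//   int64_t n, w, c;
//   bindShapeDims(shapedType, n, w, c);
// binds n = shape[0], w = shape[1] and c = shape[2].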
bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
/// Generate a vector implementation of 1-D convolution and pooling ops.
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {
    // Determine whether `linalgOp` can be vectorized by this generator.
    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
      return;
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());
    if (!lhsShapedType || !rhsShapedType || !resShapedType)
      return;
    // (LHS has dimension NCW/NWC and RES has dimension NFW/NWF) OR
    // (non-channeled convolution -> LHS and RHS have single dimensions).
    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
      return;

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    if (!reduceOp)
      return;
    // Classify the op as a conv or a pooling based on the body.
    if (!setOperKind(reduceOp))
      return;
    auto maybeKind = getCombinerOpKind(reduceOp);
    // Typically convolution will have an `Add` CombiningKind, but for i1 it
    // can get strength-reduced to `OR` which is also supported.
    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                        *maybeKind != vector::CombiningKind::OR) &&
                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
      return;
    }
    reductionKind = maybeKind.value();

    auto rhsRank = rhsShapedType.getRank();
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return;
  /// Generate a vector implementation of a 1-D convolution/pooling for the
  /// given traversal order.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Initialize unused dimensions.
      nSize = fSize = cSize = 0;
      lhsShape = {// iw = ow + kw - 1
                  (wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      lhsShape = {nSize,
                  // iw = ow * sw + kw * dw - 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      lhsShape = {nSize, cSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();

    // Read the whole lhs, rhs and res in one shot (with zero padding).
    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
                                                        lhsPadding);
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == Conv)
      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                    rhsPadding);
    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                        resPadding);

    // The base vectorization case for channeled convolution is output
    // {n,w,f}. To reuse that, transpose non-canonical layouts up front.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case: nothing to do, the layout is already canonical.
      break;
    case Conv1DOpOrder::Ncw: {
      // Transpose lhs {n,c,w} -> {n,w,c}.
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
      // Transpose rhs {f,c,kw} -> {kw,c,f} (only needed for Conv).
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
      if (oper == Conv)
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
      // Transpose res {n,f,w} -> {n,w,f}.
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
      break;
    }
    }

    // Unroll along kw and extract slices of lhs and res.
    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                                     kwSize, strideW, dilationW, wSizeStep,
                                     isSingleChanneled);

    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                      wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };
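    // Worked example (hypothetical values): with wSize = 8 and wSizeStep = 1
    // (i.e. strideW > 1), linearIndex(1, 2) == 1 * 8 + 2 == 10, so the
    // extracted slices are laid out kw-major in `lhsVals`.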
    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
    // or perform a pooling operation.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // Transpose the result back if the layout is non-canonical.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // Transpose res {n,w,f} -> {n,f,w}.
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
      break;
    }
    }

    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
  }
  /// Take `val` and widen it so its element type matches that of `ty`.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  /// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contrationOp = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contrationOp.setKind(reductionKind);
    return contrationOp;
  }

  /// Create an outer product: lhs{w} * rhs{1} -> res{w} for single-channel
  /// convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
  }
  /// Generate a vector implementation for a channeled depthwise convolution:
  ///   {n, w, c} += {n, w, c} * {kw, c}
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when both conditions are met:
      //  1. The channel dim is dynamic.
      //  2. The corresponding scalability flag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();

    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = ow * sw + kw * dw - 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});

    // Masks the input xfer Op along the channel dim, iff the corresponding
    // scalable flag is set.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                        ValueRange{zero, zero});
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, maybeMaskedRhs->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Note - the scalable flags are ignored as flattening combined with
    // scalable vectorization is not supported.
    auto inOutFlattenSliceSizes =
        SmallVector<int64_t>{nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel
          // dimension).
          lhsVal = rewriter.create<vector::ShapeCastOp>(
              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
          resVal = rewriter.create<vector::ShapeCastOp>(
              loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dimension).
          resVals[w] = rewriter.create<vector::ShapeCastOp>(
              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
        }
      }
    }

    // It's possible we failed to create the Fma.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0].
    // This does not depend on kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back the result.
    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
        loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
  /// to MulAcc.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: This following logic won't work for scalable vectors. For this
      // reason, "flattening" is not supported when shapes are dynamic (this
      // should be captured by one of the pre-conditions).

      // There are two options for handling the filter:
      //  * shuffle copies the filter values so they match the (flattened)
      //    input shape, or
      //  * broadcast + shape_cast.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }
  /// Entry point for non-channeled convolution:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point for channeled convolution in NWC layout.
  FailureOr<Operation *> generateNwcConv() {
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point for channeled convolution in NCW layout.
  FailureOr<Operation *> generateNcwConv() {
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point for pooling in NWC layout.
  FailureOr<Operation *> generateNwcPooling() {
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point for pooling in NCW layout.
  FailureOr<Operation *> generateNcwPooling() {
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point for dilated (depthwise) convolution in NWC layout.
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
private:
  enum OperKind { Conv, Pool };
  OperKind oper = Conv;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;

  // Returns true iff it is a valid conv/pooling op.
  bool setOperKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    switch (numBlockArguments) {
    case 1: {
      // Will be convolution if the feeder is a MulOp. A strength-reduced
      // version of MulOp for i1 type is AndOp, which is also supported.
      // Otherwise, it can be pooling.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
                    (isa<arith::AndIOp>(feedOp) &&
                     feedOp->getResultTypes()[0].isInteger(1))) &&
                   llvm::all_of(feedOp->getOperands(), [](Value v) {
                     if (isa<BlockArgument>(v))
                       return true;
                     if (Operation *op = v.getDefiningOp())
                       return isCastOfBlockArgument(op);
                     return false;
                   }))) {
        return false;
      }
      return true;
    }
    case 2:
      // Must be pooling.
      oper = Pool;
      isPoolExt = false;
      return true;
    default:
      return false;
    }
  }
};
} // namespace
3824 ArrayRef<bool> inputScalableVecDims,
bool flatten1DDepthwiseConv) {
3831 auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
3832 auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3833 Conv1DGenerator e(rewriter, op, stride, dilation);
3834 auto res = e.generateNonChanneledConv();
3837 res = e.generateNwcConv();
3840 res = e.generateNcwConv();
3843 res = e.generateNwcPooling();
3846 res = e.generateNcwPooling();
3853 uint64_t vecChDimSize = ShapedType::kDynamic;
3854 bool vecChDimScalableFlag =
false;
3855 if (!inputVecSizes.empty()) {
3858 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
3859 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
3860 "Not a 1D depthwise conv!");
3863 .Case<linalg::DepthwiseConv1DNwcWcOp>([](
auto conv) {
return 2; })
3864 .Case<linalg::DepthwiseConv1DNcwCwOp>([](
auto conv) {
return 1; });
3866 vecChDimSize = inputVecSizes[chDimIdx];
3867 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
3869 return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
3870 flatten1DDepthwiseConv);
3879 if (failed(resultOrFail))
3883 rewriter.
eraseOp(op.getOperation());
3886 assert(newOp->
getNumResults() == 1 &&
"expected single result");
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reduceOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: vector::TransferReadOp (reads a vector from the source tensor), vector::TransposeOp, vector::ShapeCastOp, and vector::TransferWriteOp (writes the result back to the destination tensor).
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
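For illustration, a minimal call-site sketch; the block and the op type scanned for are assumptions, not taken from this file:

// Fetch the unique arith.addf in `block`; the helper returns a null op
// when the block contains zero or more than one instance.
auto addOp = getSingleOpOfType<arith::AddFOp>(block);
if (!addOp)
  return failure();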
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles, (2) constant padding value, and (3) input vector sizes.
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation, compress the unused dimensions to serve as a permutation_map for a vector transfer operation.
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorExtract.
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of dstType if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor.extract.
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value, and (3) all-zero lowPad into an in-bounds vector.transfer_write of a masked vector.transfer_read of the pad source.
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. AffineMaps are immutable, like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
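For example, a sketch assuming a live MLIRContext *ctx is in scope:

// dims = 4, results = 2 keeps only the two most-minor dims:
// (d0, d1, d2, d3) -> (d2, d3).
AffineMap minorId = AffineMap::getMinorIdentityMap(/*dims=*/4, /*results=*/2, ctx);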
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter.
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
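A small sketch, again assuming MLIRContext *ctx is in scope:

// The permutation vector {1, 2, 0} encodes (d0, d1, d2) -> (d1, d2, d0),
// i.e. result i is dim permutation[i].
AffineMap perm = AffineMap::getPermutationMap(ArrayRef<unsigned>{1, 2, 0}, ctx);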
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e. dims indicated by a constant 0 in the results).
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callback.
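A usage sketch (the op type scanned for is an arbitrary choice, and `block` is assumed in scope):

// Scan `block` and stop at the first vector.transfer_write found.
WalkResult res = block.walk([](vector::TransferWriteOp) {
  return WalkResult::interrupt();
});
bool hasWrite = res.wasInterrupted();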
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
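A minimal sketch of the usual clone-with-remap pattern; `builder`, `op`, `oldVal`, and `newVal` are assumed to be in scope:

// Clone `op` while substituting one operand: uses of `oldVal` inside
// the clone become `newVal`.
IRMapping mapping;
mapping.map(oldVal, newVal);
Operation *cloned = builder.clone(*op, mapping);
Value newResult = cloned->getResult(0);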
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation via the provided map (leaving them alone if no entry is present).
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most benefit).
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification.
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier, and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors, and systematize conversion between these forms.
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to the provided dimension and symbol values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next integer value.
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op, false otherwise.
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
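A sketch of a typical driver; `rewriter` and `target` are assumed in scope, and the vector sizes are illustrative:

// Vectorize `target` with masked vector sizes {8, 16}; preconditions
// are checked first so that `vectorize` does not fail mid-rewrite.
SmallVector<int64_t> sizes = {8, 16};
if (failed(linalg::vectorizeOpPrecondition(target, sizes)))
  return failure();
return linalg::vectorize(rewriter, target, sizes);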
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns)
Populates patterns with vectorisation patterns for tensor.insert_slice.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper function computePackUnPackPerm to get the permutation vector.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape, i.e., both arrays have the same number of elements, inputVectorSizes contains no dynamic dimensions, and every value in inputVectorSizes is greater than or equal to the corresponding static size in shape.
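As a concrete sketch of those conditions: for a source of shape {?, 32} (? dynamic), sizes {8, 32} are accepted, while {32} (rank mismatch) or {8, 16} (16 < 32) are rejected. `srcType` and `inputVectorSizes` are assumed in scope:

if (failed(vector::isValidMaskedInputVector(srcType.getShape(), inputVectorSizes)))
  return failure();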
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
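A short sketch; `builder`, `readOp`, and `mask` (an i1 vector) are assumed in scope:

// Wrap a maskable op (e.g. a vector.transfer_read) in vector.mask;
// a passthru value, when provided, fills masked-off result lanes.
Operation *masked = vector::maskOperation(builder, readOp, mask);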
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
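A sketch with assumed operands (`builder`, `loc`, `source`, `padValue`) and an illustrative shape:

// Read a vector<8x16xf32>-shaped slice of `source`; lanes that fall
// outside the source bounds are replaced by `padValue` via masking.
Value read = vector::createReadOrMaskedRead(
    builder, loc, source, /*readShape=*/{8, 16}, padValue,
    /*useInBoundsInsteadOfMasking=*/false);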
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper around getMixedSizes for vector.transfer_read and vector.transfer_write ops (for the source and destination, respectively).
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
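For instance (a sketch; `map` is assumed to be a projected permutation):

// With map = (d0, d1, d2) -> (d2, d0), result i takes the source
// element selected by result expr i, so {4, 8, 16} becomes {16, 4}.
SmallVector<int64_t> projected =
    applyPermutationMap<int64_t>(map, ArrayRef<int64_t>{4, 8, 16});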
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation, where the projected dimensions are transformed into 0s.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExprs at positions 0 through sizeof...(exprs) - 1.
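A sketch assuming MLIRContext *ctx is in scope:

// Bind d0/d1 and build the swap map (d0, d1) -> (d1, d0).
AffineExpr d0, d1;
bindDims(ctx, d0, d1);
AffineMap swap = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, {d1, d0}, ctx);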
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
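For a permutation map this is the usual inverse; a sketch with `perm` assumed in scope:

// perm = (d0, d1, d2) -> (d2, d0, d1); its inverse is
// (d0, d1, d2) -> (d1, d2, d0), and inv.compose(perm) is the identity.
AffineMap inv = inversePermutation(perm);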
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
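A worked sketch of the inversion:

// For p = {2, 0, 1}, the inverse is {1, 2, 0}: applying p and then its
// inverse (via applyPermutation) restores the original order.
SmallVector<int64_t> inv = invertPermutationVector({2, 0, 1});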
Rewrite tensor.insert_slice as a vector.transfer_read + vector.transfer_write pair.
LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp, PatternRewriter &rewriter) const final
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given operation.
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the corresponding fixed/scalable dimension flags; if dimPermutation is provided, the shape and flags are permuted accordingly.
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vectorization.
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface rather than a raw Operation.
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.