#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include <type_traits>

#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}
/// Helper function to extract the input slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw},
            /*sizes=*/ArrayRef<int64_t>{wSizeStep},
            /*strides=*/ArrayRef<int64_t>{1}));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
    // for channeled convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
      }
    }
  }
  return result;
}
/// Helper function to extract the filter slices after the filter is unrolled
/// along kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  // Extract rhs slice of size [{c, f} or {f}] @ [kw].
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(
        loc, filter, /*position=*/ArrayRef<int64_t>{kw}));
  }
  return result;
}
/// Helper function to extract the result slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice of size {wSizeStep} @ [w] for non-channeled
    // convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{w},
          /*sizes=*/ArrayRef<int64_t>{wSizeStep},
          /*strides=*/ArrayRef<int64_t>{1}));
    }
  } else {
    // Extract res slice of size {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
    }
  }
  return result;
}
/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize, int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Write back res slice of size {wSizeStep} @ [w] for non-channeled
    // convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
          /*strides=*/ArrayRef<int64_t>{1});
    }
  } else {
    // Write back res slice of size {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }
  }
  return res;
}
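// For illustration, a hedged sketch (all shapes are assumptions, not taken
// from this file) of the IR the slicing helpers above produce for an NWC
// case with n = 1, wSizeStep = 4, c = 8:
//
//   %in = vector.extract_strided_slice %lhs
//       {offsets = [0, 2, 0], sizes = [1, 4, 8], strides = [1, 1, 1]}
//       : vector<1x11x8xf32> to vector<1x4x8xf32>
//   %out = vector.insert_strided_slice %val, %acc
//       {offsets = [0, 4, 0], strides = [1, 1, 1]}
//       : vector<1x4x8xf32> into vector<1x8x8xf32>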
  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking. Returns the masked operation or the original operation if
  /// masking is not needed. If provided, `maybeMaskingMap` maps the operation
  /// to the canonical vector iteration space.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeMaskingMap = std::nullopt);

  /// Initializes the iteration space static sizes using the Linalg op
  /// information.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Returns an existing mask for the given masking map, or creates and
  /// caches a new one.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);
LogicalResult
VectorizationState::precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
                                                    LinalgOp linalgOp) {
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Add an empty value for static dimensions.
      iterSpaceDynamicSizes.push_back(Value());
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size from.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceDynamicSizes.push_back(dynamicDim);
  }
  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    // This is required for vectorizing dynamically shaped ops and for using
    // vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there
    // are dynamic shapes, the operation won't be vectorized.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Extract and register the runtime values of any dynamic dimension of the
  // iteration space.
  if (failed(precomputeIterSpaceDynamicSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {
  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Permute the iteration space static sizes and the canonical vector shape
  // with the masking map to compute the mask shape.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap(maskingMap, ArrayRef<int64_t>(iterSpaceStaticSizes));
  SmallVector<int64_t> maskShape =
      applyPermutationMap(maskingMap, ArrayRef<int64_t>(canonicalVecShape));

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  // If the static sizes already match the vector sizes, all the accesses are
  // statically in bounds and no mask is needed for this map.
  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> permutedDynamicSizes =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceDynamicSizes));
  SmallVector<Value> upperBounds;
  for (auto [staticBound, dynBound] :
       llvm::zip(permutedStaticSizes, permutedDynamicSizes))
    upperBounds.push_back(ShapedType::isDynamic(staticBound)
                              ? dynBound
                              : rewriter.create<arith::ConstantIndexOp>(
                                    linalgOp.getLoc(), staticBound));

  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension size values.
  auto maskType = VectorType::get(maskShape, rewriter.getI1Type());
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}
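// As a hedged example (sizes are hypothetical): for a canonical vector shape
// of 4x8 over an iteration space with dynamic sizes %d0 and %d1, the code
// above materializes
//
//   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
//
// where lane (i, j) is active iff i < %d0 and j < %d1.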
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeMaskingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  // Create or retrieve the mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}
434 "expected projected permutation");
436 assert(res.getNumDims() == res.getNumResults() &&
437 "expected reindexed map with same number of dims and results");
/// Return vector::CombiningKind for the given op.
std::optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaxFOp>([&](auto op) { return CombiningKind::MAXF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinFOp>([&](auto op) { return CombiningKind::MINF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, or nullptr
/// otherwise.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
/// Broadcast `value` to a vector of `shape` if possible. Return `value`
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value,
                               ArrayRef<int64_t> shape) {
  // If no shape to broadcast to, just return `value`.
  if (shape.empty())
    return value;
  VectorType targetVectorType =
      VectorType::get(shape, getElementTypeOrSelf(value));
  if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the
/// reduction initial value.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}
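// A hedged sketch of the op this builds, assuming a 4x8 value reduced along
// its trailing dimension into a 4-element accumulator:
//
//   %r = vector.multi_reduction <add>, %vec, %acc [1]
//       : vector<4x8xf32> to vector<4xf32>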
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If the value to write is already a vector, use
/// it directly; otherwise broadcast it to the required shape first.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store.
  auto vectorType =
      VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()),
                      getElementTypeOrSelf(outputOperand->get().getType()));

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!value.getType().isa<VectorType>())
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true. Masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations with the results of the original op.
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  auto targetShape = llvm::to_vector(state.getCanonicalVecShape());
  SmallVector<int64_t> constantSeq =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
  auto constantOp = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIndexVectorAttr(constantSeq));
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space, since the vectorization algorithm in
  // this case can handle the broadcast.
  if (indexOp.getDim() == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, constantOp};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  std::swap(targetShape[indexOp.getDim()], targetShape.back());
  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, VectorType::get(targetShape, rewriter.getIndexType()), constantOp);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[indexOp.getDim()]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
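// A hedged sketch, assuming a 4x8 iteration space: `linalg.index 0` (a
// non-trailing dim) becomes a constant step vector that is broadcast along
// the other dimension and transposed back into place:
//
//   %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
//   %bc  = vector.broadcast %cst : vector<4xindex> to vector<8x4xindex>
//   %idx = vector.transpose %bc, [1, 0]
//       : vector<8x4xindex> to vector<4x8xindex>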
/// Helper function to check if the `tensor.extract` can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
    return failure();

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      }))
    return failure();

  return success();
}
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
///
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm,
                                   const ArrayRef<int64_t> targetShape) {
  // The vector of indices for the GatherOp should be shaped as the output
  // vector.
  auto indexVecType = VectorType::get(targetShape, rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType.getShape());

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType.getShape());

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex =
        broadcastIfNeeded(rewriter, bvm.lookup(extractOp.getIndices()[i]),
                          indexVecType.getShape());

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}
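// Concretely (a worked example, not from this file): for a tensor<3x5xf32>
// extracted at indices (%i, %j), the loop above computes
// offset = %i * 5 + %j, so extracting at (2, 4) linearizes to offset 14.
// Each step operates on index vectors broadcast to the target shape, with
// the dimension size read at runtime via tensor.dim.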
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
  auto targetShape = linalgOp.getStaticLoopRanges();
  assert(((llvm::count_if(targetShape,
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");
  assert(targetShape.back() != 1 &&
         "1-D vectors with the trailing dim equal 1 are not yet supported");

  // Block arguments of _this_ linalg.generic are tricky to analyse, so only
  // values that are not among its block arguments are accepted here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // Given the assumptions on the loop ranges above, only the trailing loop
  // index varies, so a linalg.index is loop invariant iff it refers to a
  // different dimension.
  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
    return (indexOp.getDim() != trailingLoopDim);

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp` that are constant are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op);

  return result;
}
/// Check whether `val` could be used for calculating the trailing index for a
/// contiguous load operation. `foundIndexOp` is updated to `true` when a
/// linalg.index op referring to the trailing dimension is found on the way.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp) {
  auto targetShape = linalgOp.getStaticLoopRanges();
  assert(((llvm::count_if(targetShape,
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");
  assert(targetShape.back() != 1 &&
         "1-D vectors with the trailing dim 1 are not yet supported");

  // Block arguments of _this_ linalg.generic are tricky to analyse, so only
  // values that are not among its block arguments are accepted here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // Given the assumptions on the loop ranges above, the trailing loop index
  // is the only dimension that actually varies.
  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return false;

  // Conservatively reject ops that could produce indices with a stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
          ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);

  return result;
}
/// Check whether `extractOp` would be a gather or a contiguous load Op after
/// vectorising `linalgOp`.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp) {
  auto targetShape = linalgOp.getStaticLoopRanges();

  // Only 1-D vectors with a trailing dim > 1 are supported by the contiguous
  // load analysis at the moment; everything else is treated as a gather.
  if ((llvm::count_if(targetShape,
                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
      targetShape.back() == 1)
    return VectorMemoryAccessKind::Gather;

  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();

  // If the trailing input dim is 1, there is nothing contiguous to load.
  if (inputShape.getShape().back() == 1)
    return VectorMemoryAccessKind::Gather;

  bool isContiguous = true;

  // All leading indices must be loop invariant.
  auto indices = extractOp.getIndices();
  auto leadIndices = ValueRange(indices.drop_back(1));

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    isContiguous &= isLoopInvariantIdx(linalgOp, indexVal);
  }

  // The trailing index must increment by 1 together with the trailing loop
  // dimension.
  auto extractOpTrailingIdx = indices.back();
  bool foundIndexOp = false;
  isContiguous &=
      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
  isContiguous &= foundIndexOp;

  if (isContiguous) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  return VectorMemoryAccessKind::Gather;
}
/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations with the results of the original op.
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto targetShape = state.getCanonicalVecShape();
  auto resultType =
      VectorType::get(targetShape, extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc, DenseIntElementsAttr::get(
               VectorType::get(targetShape, rewriter.getI1Type()),
               /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0. We will need to re-visit if more
  // generic scenarios are to be supported.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);

    // Generate the gather load.
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp);
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }
  // 2. Handle contiguous access.
  auto resTrailingDim = resultType.getShape().back();
  auto zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));

  // Scalarize the index vectors: a transfer_read only needs the starting
  // scalar index per dimension; contiguity takes care of the rest.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    auto idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
        bvm.lookup(extractOp.getIndices()[i]));
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  SmallVector<bool> inBounds(resultType.getRank(), true);

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}
/// Emit reduction operations if the shape of the value to reduce is different
/// than the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
  auto outputType = outputVec.getType().dyn_cast<VectorType>();
  // Reduce only if needed as the value may already have been reduced for
  // reduction ops.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`.
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = operand.dyn_cast<BlockArgument>();
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops: find the first
  // max ranked operand shape, broadcast the remaining operands to it and
  // clone the op with vector operands and result types.
  SmallVector<int64_t, 4> firstMaxRankedShape;
  for (Value operand : op->getOperands()) {
    auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
    if (vt && firstMaxRankedShape.size() < vt.getShape().size())
      firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
  }
  auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) {
    return firstMaxRankedShape.empty()
               ? bvm.lookup(v)
               : broadcastIfNeeded(rewriter, bvm.lookup(v),
                                   firstMaxRankedShape);
  });
  auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) {
    return firstMaxRankedShape.empty()
               ? t
               : VectorType::get(firstMaxRankedShape, t);
  });

  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                      llvm::to_vector<4>(vectorizedOperands),
                      llvm::to_vector<4>(returnTypes), op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  // 1. Values defined above the region can only be broadcast for now; make
  // them map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // 2. Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // 2.a. Convert the indexing map for this input/output to a transfer read
    // permutation map and masking map. Zero results are dropped from the
    // indexing map so that it can serve as a masking map.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
    SmallVector<int64_t> zeroPos;
    for (const auto &result : llvm::enumerate(indexingMap.getResults())) {
      if (result.value().isa<AffineConstantExpr>())
        zeroPos.push_back(result.index());
    }
    AffineMap maskingMap = indexingMap.dropResults(zeroPos);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // 2.a.i. For input reads, use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = VectorType::get(state.getCanonicalVecShape(), elemType);
    } else {
      // 2.a.ii. For output reads (iteration-carried dependences), use the
      // permuted projection of the canonical vector shape.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType = VectorType::get(readMap.compose(state.getCanonicalVecShape()),
                                 elemType);
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
    Value readValue = read->getResult(0);

    // 2.b. If masked, set in_bounds to true. Masking guarantees that the
    // access will be in bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // 2.c. Not all ops support 0-d vectors; extract the scalar for now.
    if (readValue.getType().cast<VectorType>().getRank() == 0)
      readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);

    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  // 3. Register CustomVectorizationHooks for the terminator, linalg.index
  // and tensor.extract ops.
  SmallVector<CustomVectorizationHook> hooks;
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
  };
  hooks.push_back(vectorizeYield);

  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // 4. Iteratively call `vectorizeOneOp` on each op in the body.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand *opOperand : op.getDpsInitOperands()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
  // TODO: Masking only supports dynamic generic ops for now.
  if (!isa<linalg::GenericOp, linalg::FillOp>(op))
    return failure();

  // TODO: Index vectorization assumes static shape.
  if (op.hasIndexSemantics())
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}
LogicalResult
mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                                            ArrayRef<int64_t> inputVectorSizes,
                                            bool vectorizeNDExtract) {
  // Tensors with a dimension of 0 cannot be vectorized.
  if (llvm::any_of(linalgOp.getStaticShape(),
                   [](int64_t dim) { return dim == 0; }))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == linalgOp.getNumLoops() &&
           "Input vector sizes don't match the number of loops");
    assert(!ShapedType::isDynamicShape(inputVectorSizes) &&
           "Input vector sizes can't have dynamic dimensions");
    assert(llvm::all_of(
               llvm::zip(linalgOp.getStaticLoopRanges(), inputVectorSizes),
               [](std::tuple<int64_t, int64_t> sizePair) {
                 int64_t staticSize = std::get<0>(sizePair);
                 int64_t inputSize = std::get<1>(sizePair);
                 return ShapedType::isDynamic(staticSize) ||
                        staticSize <= inputSize;
               }) &&
           "Input vector sizes must be greater than or equal to iteration "
           "space static sizes");
  }

  if (linalgOp.hasDynamicShape() &&
      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be supported element types for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(customPreconditions,
                     [&](const CustomVectorizationPrecondition &precondition) {
                       return succeeded(
                           precondition(&innerOp, vectorizeNDExtract));
                     }))
      continue;
    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        }))
      return failure();
    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        }))
      return failure();
  }
  if (isElementwise(linalgOp))
    return success();
  // TODO: isaConvolutionOpInterface that can also infer from generic features.
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return success();
  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
  return success();
}
/// Converts affine.apply Ops to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}
/// Emit a suitable vector form for a Linalg op.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      bool vectorizeNDExtract) {
  LDBG("Attempting to vectorize:\n" << linalgOp << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                           vectorizeNDExtract))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
    LDBG("Vectorization state couldn't be initialized\n");
    return failure();
  }

  SmallVector<Value> results;
  // Try to vectorize as a convolution first; otherwise fall back to the
  // generic path.
  FailureOr<Operation *> convOr = vectorizeConvolution(rewriter, linalgOp);
  if (succeeded(convOr)) {
    llvm::append_range(results, (*convOr)->getResults());
  } else {
    if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract)))
      return failure();
    LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");

    // Pre-process before proceeding.
    convertAffineApply(rewriter, linalgOp);

    if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results)))
      return failure();
  }

  if (!results.empty())
    rewriter.replaceOp(linalgOp, results);
  else
    rewriter.eraseOp(linalgOp);

  return success();
}
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = copyOp.getSource().getType().cast<MemRefType>();
  auto dstType = copyOp.getTarget().getType().cast<MemRefType>();
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (readValue.getType().cast<VectorType>().getRank() == 0) {
    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
  }
  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
/// Helper function that retrieves the value of an IntegerAttr.
static int64_t getIntFromAttr(Attribute attr) {
  return attr.cast<IntegerAttr>().getInt();
}
/// Given an ArrayRef of OpFoldResults, return a vector of Values.
/// IntegerAttrs are converted to ConstantIndexOps; other attribute types are
/// not supported.
static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
                                           ArrayRef<OpFoldResult> ofrs) {
  SmallVector<Value> result;
  for (auto o : ofrs) {
    if (auto val = o.template dyn_cast<Value>()) {
      result.push_back(val);
    } else {
      result.push_back(rewriter.create<arith::ConstantIndexOp>(
          loc, getIntFromAttr(o.template get<Attribute>())));
    }
  }
  return result;
}
/// Vectorize the copying of a tensor::PadOp's source. This is possible if
/// each dimension size is statically known in the source type or the result
/// type (or both).
static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
                                      tensor::PadOp padOp, Value dest) {
  auto sourceType = padOp.getSourceType();
  auto resultType = padOp.getResultType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // The copy cannot be vectorized if the pad value is non-constant and the
  // source shape is dynamic, since padding must then be appended by the
  // TransferReadOp, which supports only constant padding.
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    if (!sourceType.hasStaticShape())
      return failure();
    // Create a dummy padding value.
    auto elemType = sourceType.getElementType();
    padValue = rewriter.create<arith::ConstantOp>(
        padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
  }

  SmallVector<int64_t> vecShape;
  SmallVector<bool> readInBounds;
  SmallVector<bool> writeInBounds;
  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
    if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
      // Source shape is statically known: read and write are in bounds.
      readInBounds.push_back(true);
      writeInBounds.push_back(true);
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is dynamic but result shape is static: vectorize with
      // the result size; the read may be out of bounds.
      vecShape.push_back(resultType.getDimSize(i));
      readInBounds.push_back(false);
      // The write is out of bounds if low padding > 0.
      writeInBounds.push_back(
          getConstantIntValue(padOp.getMixedLowPad()[i]) ==
          static_cast<int64_t>(0));
    } else {
      // Neither source nor result dim of padOp is static. Cannot vectorize
      // the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // Generate TransferReadOp.
  SmallVector<Value> readIndices(
      vecType.getRank(),
      rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
  auto read = rewriter.create<vector::TransferReadOp>(
      padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
      ArrayRef<bool>{readInBounds});

  // If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
  // tensor, write directly to the FillOp's operand.
  if (llvm::equal(vecShape, resultType.getShape()) &&
      llvm::all_of(writeInBounds, [](bool b) { return b; }))
    if (auto fill = dest.getDefiningOp<FillOp>())
      dest = fill.output();

  // Generate TransferWriteOp.
  auto writeIndices =
      ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});

  return success();
}
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Collect the users in a vector first, as some may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
  /// Rewrite use of tensor::PadOp result in TransferReadOp.
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of the existing `xferOp` must be unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.updateRootInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getSourceMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
  /// Rewrite use of tensor::PadOp result in TransferWriteOp.
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TransferWriteOps with a vector of rank 0 are not supported.
    if (xferOp.getTransferRank() == 0)
      return failure();
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // The TransferWriteOp result must be directly consumed by an
    // ExtractSliceOp with zero offsets that trims exactly the padding that
    // was added earlier.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    if (!trimPadding.hasZeroOffset())
      return failure();
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at the position of the old one.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }
  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both the CastOp
    // result and the CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
    auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
    // Only RankedTensorType is supported, and ranks must match.
    if (!t1 || !t2)
      return false;
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., a dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. Apart from CastOp, only
    // ExtractSliceOp is supported as the defining op of `beforePadding`.
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same value or same constant int.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Case 2: Both dynamic values are identical AffineMinOps. (Should not
      // happen if CSE is run.)
      auto v1 = size1.dyn_cast<Value>();
      auto v2 = size2.dyn_cast<Value>();
      if (!v1 || !v2)
        return false;
      auto minOp1 = v1.getDefiningOp<AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
      return false;
    }

    // All tests passed.
    return true;
  }
  /// Rewrite use of tensor::PadOp result in InsertSliceOp.
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0 and strides must be unit.
    if (!padOp.hasZeroLowPad())
      return failure();
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant and shapes must be static.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    if (!padOp.getResult().getType().cast<ShapedType>().hasStaticShape())
      return failure();
    // The pad result must not be used as the insert destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check that the entire tensor is inserted into dest.
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Generate TransferReadOp: read the entire source tensor and add high
    // padding.
    SmallVector<Value> readIndices(
        vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

    // Generate TransferWriteOp: write to InsertSliceOp's dest tensor at the
    // specified offsets.
    SmallVector<Value> writeIndices =
        ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if `firstOp` and
/// `secondOp` are in different blocks.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
    }
  }
  return false;
}
/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(newCopyOp.getTarget().getType().isa<MemRefType>());
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(newFillOp.output().getType().isa<MemRefType>());
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the padding matches the fill value.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // The `in_bounds` attribute is only valid on the padded buffer, so it must
  // be reset conservatively when forwarding to the new transfer_read.
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      /*inBoundsAttr=*/ArrayAttr());

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the copy target that we forward the write to.
  assert(copyOp.getTarget().getType().isa<MemRefType>());
  Value out = copyOp.getTarget();

  // Forward the vector.transfer_write to the copy target, conservatively
  // resetting the `in_bounds` attribute.
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      /*inBoundsAttr=*/ArrayAttr());

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
/// Recursion base case: nothing left to bind.
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2...>(shapedType, vals...);
}

/// Bind a pack of int64_t& to the leading dimensions of
/// shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
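// Usage sketch (variable names assumed for illustration): bind the leading
// dims of a shaped type to named integers, mirroring how the conv generator
// below consumes it:
//
//   int64_t kwSize, cSize;
//   bindShapeDims(rhsShapedType, kwSize, cSize); // kw = shape[0], c = shape[1]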
static bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
/// Helper StructuredGenerator class to manipulate and rewrite 1-D convolution
/// and pooling ops.
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {
    // Determine whether `linalgOp` can be generated with this generator.
    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
      return;
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
    rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
    resShapedType = resShaped.getType().dyn_cast<ShapedType>();
    if (!lhsShapedType || !rhsShapedType || !resShapedType)
      return;
    // (LHS has dimension NCW/NWC and RES has dimension NFW/NWF/CNW/NWC) or
    // (non-channeled convolution -> LHS and RHS have single dimensions).
    if (!((lhsShapedType.getRank() == 3 && resShapedType.getRank() == 3) ||
          (lhsShapedType.getRank() == 1 && resShapedType.getRank() == 1)))
      return;

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    if (!reduceOp)
      return;

    if (!setOperKind(reduceOp))
      return;
    auto maybeKind = getCombinerOpKind(reduceOp);
    if (!maybeKind || (*maybeKind != vector::CombiningKind::ADD &&
                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
      return;
    }

    auto rhsRank = rhsShapedType.getRank();
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return;

    // The op is now known to be valid.
    valid = true;
  }
  /// Generate the vector form of the convolution, with the innermost
  /// unrolling along kw.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    // The sizes of the loop nest (n, w, c, kw, f).
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Non-channeled convolution: the degenerate sizes are simply zero.
      nSize = fSize = cSize = 0;
      // kernel{kw} and out{w}.
      bindShapeDims(rhsShapedType, kwSize);
      bindShapeDims(resShapedType, wSize);
      lhsShape = {// iw = ow + kw - 1
                  (wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      // out{n, w, f} and kernel{kw, c, f}.
      bindShapeDims(resShapedType, nSize, wSize, fSize);
      bindShapeDims(rhsShapedType, kwSize, cSize);
      lhsShape = {nSize,
                  // iw = (ow * sw + kw * dw - 1) + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      if (oper == Conv)
        rhsShape = {kwSize, cSize, fSize};
      else // Pool.
        rhsShape = {kwSize};
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      // out{n, f, w} and kernel{f, c, kw}.
      bindShapeDims(resShapedType, nSize, fSize, wSize);
      bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
      lhsShape = {nSize, cSize,
                  // iw = (ow * sw + kw * dw - 1) + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      if (oper == Conv)
        rhsShape = {fSize, cSize, kwSize};
      else // Pool.
        rhsShape = {kwSize};
      resShape = {nSize, fSize, wSize};
      break;
    }

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    auto lhsType = VectorType::get(lhsShape, lhsEltType);
    auto rhsType = VectorType::get(rhsShape, rhsEltType);
    auto resType = VectorType::get(resShape, resEltType);

    // Read the whole lhs, rhs and res in one shot (with zero padding).
    // (lhsPadding, rhsPadding and resPadding are all-zero index vectors,
    // elided here.)
    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
                                                        lhsPadding);
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == Conv)
      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                    rhsPadding);
    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                        resPadding);
    // The base vectorization case for channeled convolution is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse it, the Ncw case
    // is transposed here and transposed back after the compute.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, so no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // Pre-transpose the input from {n,c,w} to {n,w,c} ...
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
      // ... the weight from {f,c,kw} to {kw,c,f} ...
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
      // ... and the output from {n,f,w} to {n,w,f}.
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
      break;
    }
    }

    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals =
        extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize, kwSize,
                               strideW, dilationW, wSizeStep,
                               isSingleChanneled);
    SmallVector<Value> rhsVals;
    if (oper == Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    SmallVector<Value> resVals = extractConvResultSlices(
        rewriter, loc, res, nSize, wSize, fSize, wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
    // an outer product for non-channeled convolution, or a simple arith
    // operation for pooling.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // For the Ncw case, transpose the output back from {n,w,f} to {n,f,w}.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
      break;
    }
    }
    // Write back the result.
    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
  }

  /// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    return rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
  }

  /// Create an outer product: lhs{w} * rhs{1} -> res{w} for single-channel
  /// convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
  }
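  // A hedged sketch of both flavors, with assumed shapes. The channeled case
  // contracts over the channel dimension:
  //
  //   %r = vector.contract {
  //          indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
  //                           affine_map<(n, w, f, c) -> (c, f)>,
  //                           affine_map<(n, w, f, c) -> (n, w, f)>],
  //          iterator_types = ["parallel", "parallel", "parallel",
  //                            "reduction"], kind = #vector.kind<add>}
  //        %lhs, %rhs, %acc
  //        : vector<1x4x8xf32>, vector<8x16xf32> into vector<1x4x16xf32>
  //
  // while the single-channel case degenerates to an axpy-style outer product:
  //
  //   %r = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
  //        : vector<4xf32>, f32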
  /// Generate the vector form of a depthwise convolution, with the innermost
  /// unrolling along kw.
  FailureOr<Operation *> depthwiseConv() {
    if (!valid)
      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");

    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c} and out{n, w, c}.
    bindShapeDims(rhsShapedType, kwSize, cSize);
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow * sw + kw * dw - 1) + 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType);
    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                        ValueRange{zero, zero});
    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    SmallVector<Value> lhsVals;
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, lhs,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    SmallVector<Value> rhsVals;
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, rhs, /*position=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    SmallVector<Value> resVals;
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the mul-acc on the unrolled slices.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
      }
    }
    // It's possible we failed to create the FMA.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0].
    // This does not depend on kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back the result.
    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped,
                                         ValueRange{zero, zero, zero})
        .getOperation();
  }
  /// Promote `val` to the element type of `ty`, extending floats and
  /// sign-extending integers as needed. Returns null on unsupported cases.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();

    if (srcElementType.isa<FloatType>() && srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, ty, val);

    if (srcElementType.isa<IntegerType>() && srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, ty, val);

    return nullptr;
  }

  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to a mul-acc.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res) {
    auto rhsTy = rhs.getType().cast<ShapedType>();
    auto resTy = res.getType().cast<ShapedType>();

    lhs = promote(rewriter, loc, lhs, resTy);

    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);
    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (resTy.getElementType().isa<FloatType>())
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }
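  // A hedged sketch of the float path with assumed shapes {n=1, w=4, c=8}:
  // the per-channel filter vector is broadcast across n and w, then fused:
  //
  //   %rhsB = vector.broadcast %rhs : vector<8xf32> to vector<1x4x8xf32>
  //   %r    = vector.fma %lhs, %rhsB, %res : vector<1x4x8xf32>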
  /// Matcher for non-channeled convolution: {{w + kw}, {kw}, {w}}.
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");
    if (layout({/*lhs=*/{w + kw}, /*rhs=*/{kw}, /*res=*/{w}}))
      return conv(Conv1DOpOrder::W);
    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Matcher for NWC convolution.
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");
    if (layout({/*lhs=*/{n, strideW * w + dilationW * kw, c},
                /*rhs=*/{kw, c, f}, /*res=*/{n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Matcher for NCW convolution.
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");
    if (layout({/*lhs=*/{n, c, strideW * w + dilationW * kw},
                /*rhs=*/{f, c, kw}, /*res=*/{n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Matcher for NWC pooling.
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    if (layout({/*lhs=*/{n, strideW * w + dilationW * kw, c},
                /*rhs=*/{kw}, /*res=*/{n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Matcher for NCW pooling.
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    if (layout({/*lhs=*/{n, c, strideW * w + dilationW * kw},
                /*rhs=*/{kw}, /*res=*/{n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Matcher for NWC depthwise (dilated) convolution.
  FailureOr<Operation *> generateDilatedConv() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");
    if (layout({/*lhs=*/{n, strideW * w + dilationW * kw, c},
                /*rhs=*/{kw, c}, /*res=*/{n, w, c}}))
      return depthwiseConv();
    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
private:
  enum OperKind { Conv, Pool };
  bool valid = false;
  OperKind oper = Conv;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  /// Returns true iff this is the body of a valid conv/pooling op: either the
  /// reduction is fed directly by block arguments (pooling, possibly through
  /// a cast/extension), or it is fed by a multiplication of block arguments
  /// (convolution).
  bool setOperKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(),
                       [](Value v) { return v.isa<BlockArgument>(); });
    switch (numBlockArguments) {
    case 1: {
      // The feed value is the one operand that is not a block argument.
      auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
        return !v.isa<BlockArgument>();
      });
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
      } else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
                   llvm::all_of(feedOp->getOperands(), [](Value v) {
                     if (v.isa<BlockArgument>())
                       return true;
                     if (Operation *op = v.getDefiningOp())
                       return isCastOfBlockArgument(op);
                     return false;
                   }))) {
        return false;
      }
      return true;
    }
    case 2:
      // Must be pooling.
      oper = Pool;
      isPoolExt = false;
      return true;
    default:
      return false;
    }
  }
};
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
                                                   LinalgOp op) {
  // The ConvolutionOpInterface gives us guarantees of existence for
  // strides/dilations, but we do not need to rely on them: use the attributes
  // if present, otherwise fall back to the defaults and let the generic conv
  // matcher in the Conv1DGenerator succeed or fail.
  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  Conv1DGenerator e(rewriter, op, stride, dilation);
  auto res = e.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcConv();
  if (succeeded(res))
    return res;
  res = e.generateNcwConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = e.generateNcwPooling();
  if (succeeded(res))
    return res;
  return e.generateDilatedConv();
}

struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};