#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-vectorization"
/// Try to vectorize `convOp` as a convolution (forward declaration; the
/// parameter list is reconstructed from the call site further below).
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}
/// Extract the slices of `input` needed to compute a 1-D convolution.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Non-channeled convolution: extract a {wSizeStep} slice @ [w + kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw},
            /*sizes=*/ArrayRef<int64_t>{wSizeStep},
            /*strides=*/ArrayRef<int64_t>{1}));
      }
    }
  } else {
    // Channeled convolution: extract an {n, wSizeStep, c} slice
    // @ [0, w * strideW + kw * dilationW, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
      }
    }
  }
  return result;
}

/// Extract the filter slice @ [kw] for each kernel position.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  for (int64_t kw = 0; kw < kwSize; ++kw)
    result.push_back(vector::ExtractOp::create(rewriter, loc, filter, kw));
  return result;
}
/// Extract the result slices that a 1-D convolution updates.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Non-channeled convolution: extract a {wSizeStep} slice @ [w].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w},
          /*sizes=*/ArrayRef<int64_t>{wSizeStep},
          /*strides=*/ArrayRef<int64_t>{1}));
    }
  } else {
    // Channeled convolution: extract an {n, wSizeStep, f} slice @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
    }
  }
  return result;
}
/// Write back the computed result slices into `res`.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize,
                                    int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Non-channeled convolution: insert a {wSizeStep} slice @ [w].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
          /*strides=*/ArrayRef<int64_t>{1});
    }
  } else {
    // Channeled convolution: insert an {n, wSizeStep, f} slice @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }
  }
  return res;
}
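// Illustrative example (not from the source): for a linalg op with loop
// ranges (M, N, K) = (8, 16, 4) and no user-provided vector sizes, the
// canonical vector shape tracked by the state below is [8, 16, 4]; a masking
// map (d0, d1, d2) -> (d0, d2) then selects a vector<8x4xi1> mask shape.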
/// Contains the vectorization state and related methods used across the
/// vectorization process of a given operation.
struct VectorizationState {
  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}

  /// Initialize the state (canonical vector shape, scalable dims, iteration
  /// space sizes) for the vectorization of `linalgOp`.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          bool assumeDynamicDimsMatchVecSizes = false);

  /// Accessors for the canonical vector shape and its scalable dim flags.
  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }

  /// Return a vector type of the provided `elementType` with the canonical
  /// vector shape, optionally permuted by `dimPermutation`.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Mask `opToMask` in the canonical vector iteration space, if masking is
  /// required given its indexing map.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

private:
  /// Initialize the iteration space static sizes.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generate 'arith.constant' and 'tensor/memref.dim' operations for all the
  /// static and dynamic dimensions of the iteration space.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Create or retrieve an existing mask value to mask `opToMask`; masks are
  /// cached per masking map.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// A valid masking map is a projected permutation of the iteration space.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.isProjectedPermutation();
  }

  /// Turn an indexing map into a masking map by dropping broadcast (zero)
  /// results.
  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
    return indexingMap.dropZeroResults();
  }

  /// Static sizes of the iteration space.
  SmallVector<int64_t> iterSpaceStaticSizes;
  /// Runtime (Value) sizes of the iteration space dimensions.
  SmallVector<Value> iterSpaceValueSizes;
  /// Canonical vector shape used to vectorize the iteration space.
  SmallVector<int64_t> canonicalVecShape;
  /// Scalable flag for each canonical vector dimension.
  SmallVector<bool> scalableVecDims;
  /// Cache of masks, keyed by masking map.
  DenseMap<AffineMap, Value> activeMaskCache;
  /// Guard that restores the rewriter's insertion point on destruction.
  OpBuilder::InsertionGuard rewriterGuard;
  /// If true, dynamic dim sizes are assumed to match the vector sizes.
  bool assumeDynamicDimsMatchVecSizes = false;
};
LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
          rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim =
        linalgOp.hasPureTensorSemantics()
            ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
                                           operandDimPos)
            : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
                                           operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims,
                              bool assumeDimsMatchVec) {
  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Use the input vector sizes, when provided, as the canonical vector
    // shape.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Otherwise, derive the canonical vector shape from the static loop
    // ranges, with all dims fixed-size.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for all the
  // static and dynamic dimensions of the iteration space, needed later to
  // compute masks.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
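// Hedged sketch of what initState() enables: given tensor<?x8xf32> vectorized
// with sizes [4, 8], `iterSpaceValueSizes` holds {tensor.dim %t, %c0; %c8},
// which getOrCreateMaskFor() below turns into
//   %mask = vector.create_mask %dim, %c8 : vector<4x8xi1>
// exactly once per masking map (the result is cached in `activeMaskCache`).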
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {
  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG() << "Masking map: " << maskingMap;

  // Return the active mask for this masking map, if it was already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG() << "Reusing mask: " << mask;
    return mask;
  }

  // Compare the permuted iteration space sizes against the mask shape; if
  // they match, masking is not needed.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);

  if (permutedStaticSizes == maskShape) {
    LDBG() << "Masking is not needed for masking map: " << maskingMap;
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  if (assumeDynamicDimsMatchVecSizes) {
    // While dynamic dim sizes are _assumed_ to match the corresponding vector
    // sizes, the static dim sizes still need to be checked.
    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
                     [](auto it) {
                       return std::get<0>(it) == ShapedType::kDynamic
                                  ? true
                                  : std::get<0>(it) == std::get<1>(it);
                     })) {
      LDBG() << "Dynamic + static dimensions match vector sizes, masking is "
                "not required.";
      activeMaskCache[maskingMap] = Value();
      return Value();
    }
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap<Value>(maskingMap, iterSpaceValueSizes);
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values and cache it.
  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
                                            maskType, upperBounds);
  LDBG() << "Creating new mask: " << mask;
  activeMaskCache[maskingMap] = mask;
  return mask;
}
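// Illustrative result of maskOperation() below (shapes assumed): a
// transfer_read masked with the cached mask prints as
//   %r = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad
//         : tensor<?x8xf32>, vector<4x8xf32>
//   } : vector<4x8xi1> -> vector<4x8xf32>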
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG() << "Trying to mask: " << *opToMask;

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG() << "No mask required";
    if (assumeDynamicDimsMatchVecSizes) {
      llvm::TypeSwitch<Operation *>(opToMask)
          .Case<vector::TransferReadOp, vector::TransferWriteOp>(
              [&](auto xferOp) {
                LDBG() << "Assuming dynamic dimensions match vector sizes and "
                          "setting their in-bounds to true!";
                SmallVector<bool> inBoundsMap(xferOp.getTransferRank(), false);
                ShapedType xferType = xferOp.getShapedType();
                AffineMap permMap = xferOp.getPermutationMap();
                // Only set in-bounds for dynamic dims; static dims keep the
                // conservative default.
                for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
                  auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
                  if (!dimExpr)
                    continue;
                  unsigned pos = dimExpr.getPosition();
                  if (xferType.isDynamicDim(pos))
                    inBoundsMap[i] = true;
                }
                xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBoundsMap));
              });
    }
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG() << "Masked operation: " << *maskOp;
  return maskOp;
}
/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projected permutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  auto res = compressUnusedDims(map);
  assert(res.getNumDims() ==
             (res.getNumResults() - res.getNumOfZeroResults()) &&
         "expected reindexed map with same number of dims and results");
  return res;
}
/// Return the vector combining kind matching a supported binary combiner op,
/// or std::nullopt otherwise.
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case([&](arith::AndIOp op) { return CombiningKind::AND; })
      .Case([&](arith::MaxSIOp op) { return CombiningKind::MAXSI; })
      .Case([&](arith::MaxUIOp op) { return CombiningKind::MAXUI; })
      .Case([&](arith::MaximumFOp op) { return CombiningKind::MAXIMUMF; })
      .Case([&](arith::MaxNumFOp op) { return CombiningKind::MAXNUMF; })
      .Case([&](arith::MinSIOp op) { return CombiningKind::MINSI; })
      .Case([&](arith::MinUIOp op) { return CombiningKind::MINUI; })
      .Case([&](arith::MinimumFOp op) { return CombiningKind::MINIMUMF; })
      .Case([&](arith::MinNumFOp op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case([&](arith::OrIOp op) { return CombiningKind::OR; })
      .Case([&](arith::XOrIOp op) { return CombiningKind::XOR; })
      .Default(std::nullopt);
}
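// For example, a linalg.generic whose combiner is `arith.addf %in, %out`
// maps to CombiningKind::ADD; the kind later parameterizes the creation of
// vector.multi_reduction and vector.contract ops below.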
/// Check whether `outputOperand` is a reduction with a single combiner op.
/// Return the combiner op if it is, nullptr otherwise.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
/// Broadcast `value` to a vector of `dstType` shape if this is needed and
/// possible; return `value` unchanged otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
/// Create a multi-dim reduction matching the combining kind of `reduceOp`.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return vector::MultiDimReductionOp::create(
      b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// zero. `value` is broadcast and/or the write is masked as needed.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store: an identity/projection of
  // the canonical vector type without permutation (any permutation happens as
  // part of the write itself).
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        arith::ConstantIndexOp::create(rewriter, loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(
        rewriter, loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // The 0-d case is special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(rewriter, loc, value,
                                            outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true: masking guarantees that the access will
  // be in-bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG() << "vectorized op: " << *write;
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
// Custom vectorization hook and precondition function types, intended to be
// used to extend the generic vectorization below.
using CustomVectorizationHook =
    std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;

/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`.
static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    // TODO: Scan for an opportunity for reuse.
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`. Meant
/// to be used as a CustomVectorizationHook.
static VectorizationHookResult
vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state,
                     Operation *op, LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  auto targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space.
  if (dim == targetShape.size() - 1)
    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcast index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = vector::BroadcastOp::create(
      rewriter, loc,
      state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}
/// Helper function to check if the given `op` is a tensor.extract that can be
/// vectorized by the custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type of the operand.
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (!llvm::all_of(extractOp->getResultTypes(),
                    VectorType::isValidElementType)) {
    return failure();
  }

  return success();
}
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
  }

  return offset;
}
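// Worked example (hypothetical shapes): for tensor<4x5xf32> and indices
// (%i, %j), the loop above computes offset = %i * dim(1) + %j = %i * 5 + %j,
// with every term broadcast to the canonical index-vector type.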
/// Find the index of the trailing non-unit dim in `linalgOp`. Used when
/// checking whether a `tensor.extract` inside a `linalg.generic` represents a
/// contiguous load.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) ==
           1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) ==
           1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // Analysing block arguments of _this_ linalg.generic is trickier; just bail
  // out in that case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its result remains constant across
  // iterations, i.e. the corresponding loop dimension is 1.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Constants defined inside `linalgOp` are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}
/// Checks whether `val` could be used for calculating the trailing index of a
/// contiguous load operation.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) ==
           1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant; bail
  // out when analysing _this_ generic's own block arguments.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);

    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
/// Infer the memory access pattern for the input `extractOp`.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
  // otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading a non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`. If any of them is not loop
  // invariant, it's a gather load.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG() << "Found gather load: " << extractOp;
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index of `extractOp`.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load: the trailing index is also loop invariant.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG() << "Found scalar broadcast load: " << extractOp;

    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous load: the trailing index increments with every iteration
  // of the trailing loop, i.e. it is based on the trailing loop index.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  // TODO: Support generating contiguous loads for column vectors - that will
  // require adding a permutation map to transfer_read ops.
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG() << "Found contiguous load: " << extractOp;
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG() << "Found gather load: " << extractOp;
  return VectorMemoryAccessKind::Gather;
}
/// Helper function to vectorize tensor.extract operations. Meant to be used
/// as a CustomVectorizationHook.
static VectorizationHookResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the canonical result vector type, an all-true mask and a zero
  // pass-through vector for the gather case.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = arith::ConstantOp::create(
      rewriter, loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp = arith::ConstantOp::create(
      rewriter, loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      arith::ConstantIndexOp::create(rewriter, loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = vector::GatherOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG() << "Vectorised as gather load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
  }

  // 2. Handle contiguous access: compute scalar transfer-read indices from
  // the vectorized index values.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
  }

  // `tensor.extract` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = vector::TransferReadOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
        std::nullopt, permutationMap, inBounds);

    // Mask this broadcasting xfer_read here rather than relying on the
    // generic path (which assumes an identity masking map that would not be
    // valid here).
    SmallVector<int64_t> readMaskShape = {1};
    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
    auto allTrue = vector::ConstantMaskOp::create(
        rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   maskedReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unit leading
  // dims so that the ranks match, by extending the map with 0s.
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = vector::TransferReadOp::create(
      rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
      std::nullopt, permutationMap, inBounds);

  LDBG() << "Vectorised as contiguous load: " << extractOp;
  return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                 transferReadOp};
}
/// Emit reduction operations if the shape of the value to reduce differs from
/// the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed, as the value may already have been reduced for
  // reuse.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already
/// vectorized operands carried by `bvm`.
static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG() << "vectorize op " << *op;

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationHookResult result = customFunc(op, bvm);
      if (result.status == VectorizationHookStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but are cloned so that their users
  // can broadcast them lazily.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   rewriter.clone(*op)};

  // 3. Only elementwise-mappable ops are allowed in the generic path.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for elementwise-mappable ops.
  //   a. Get the first max-ranked vector operand type.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each operand to the max-ranked shape if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     getElementTypeOrSelf(vecOperand.getType()),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. The result types carry the max-ranked shape as well.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  return VectorizationHookResult{
      VectorizationHookStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}
/// Generic vectorization of a LinalgOp by broadcasting to the canonical
/// vector shape: turn bbargs into transfer_reads, vectorize the body op by
/// op, and turn the yield into transfer_writes.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG() << "Vectorizing operation as linalg generic";
  Block *block = linalgOp.getBlock();

  // 1. Values defined above the region are mapped to themselves; they can
  // only be broadcast for now.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // 2. Turn all BBArgs into vector.transfer_read ops.
  Location loc = linalgOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // 2.a. Convert the indexing map for this operand to a transfer read map.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // 2.a.i. For input reads, use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // 2.a.ii. For output reads (iteration-carried dependence, e.g.
      // reductions), the vector shape is computed by mapping the canonical
      // vector shape to the output indexing map.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = vector::TransferReadOp::create(
        rewriter, loc, readType, opOperand->get(), indices,
        std::nullopt, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // 2.b. If masked, set in_bounds to true: masking guarantees that the
    // access will be in-bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // 2.c. Not all ops support 0-d vectors, extract the scalar for now.
    // TODO: remove this.
    if (readType.getRank() == 0)
      readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                            ArrayRef<int64_t>());

    LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
           << "): " << readValue;
    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  // 3. Register CustomVectorizationHooks for yield, index and extract ops.
  SmallVector<CustomVectorizationHook> hooks;
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp,
                                newResults);
  };
  hooks.push_back(vectorizeYield);

  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // 4. Iteratively call `vectorizeOneOp` on the region operations.
  for (Operation &op : block->getOperations()) {
    VectorizationHookResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationHookStatus::Failure) {
      LDBG() << "failed to vectorize: " << op;
      return failure();
    }
    if (result.status == VectorizationHookStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG() << "New vector op: " << *maybeMaskedOp;
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
/// Determine whether a mask for a transfer_write is trivially "all true",
/// i.e. the masked extent provably covers the written extent.
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
                                    SmallVector<Value> &writeIdxs,
                                    ArrayRef<int64_t> destShape,
                                    ArrayRef<int64_t> maskShape) {
  // Masking is unavoidable in the case of dynamic tensors.
  if (ShapedType::isDynamicShape(destShape))
    return false;

  // Collect all constant mask sizes.
  SmallVector<int64_t, 4> cstMaskSizes;
  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      cstMaskSizes.push_back(*intSize);
    }
  }

  // If any mask size is non-constant, bail out.
  if (cstMaskSizes.size() != maskShape.size())
    return false;

  // Collect all constant write indices.
  SmallVector<int64_t, 4> cstWriteIdxs;
  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
    APSInt intVal;
    if (matchPattern(idx, m_ConstantInt(&intVal))) {
      cstWriteIdxs.push_back(intVal.getSExtValue());
    }
  }

  // If any write index is non-constant, bail out.
  if (cstWriteIdxs.size() != destShape.size())
    return false;

  // Check every destination dim: the mask must not exceed the destination
  // extent (1), and the write offset plus the clamped mask size must fit (2).
  // The number of mask sizes matches the rank of the vector to store, which
  // may be lower than the destination rank.
  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
    if (/* (1) */ maskShape[i] > destShape[rankDiff + i] ||
        /* (2) */ destShape[rankDiff + i] <
            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
             cstWriteIdxs[i]))
      return false;
  }

  return true;
}
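// Worked example (hypothetical values): destShape = [8, 16],
// maskShape = [4, 16], writeIdxs = [2, 0], maskSizes = [4, 16]. Then
// 4 <= 8, 2 + 4 <= 8, 16 <= 16 and 0 + 16 <= 16, so every masked dim covers
// the full write extent and the mask is trivially foldable (all-true).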
/// Create a transfer_write of `vecToStore` into `dest`, masking it (or
/// setting in_bounds) as required.
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
                         Value dest, SmallVector<Value> writeIndices = {},
                         bool useInBoundsInsteadOfMasking = false) {
  ShapedType destType = cast<ShapedType>(dest.getType());
  int64_t destRank = destType.getRank();
  auto destShape = destType.getShape();

  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
  int64_t vecToStoreRank = vecToStoreType.getRank();
  auto vecToStoreShape = vecToStoreType.getShape();

  // Compute the in_bounds attribute.
  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
  if (useInBoundsInsteadOfMasking) {
    // A dim is in-bounds when the (static) destination dim is at least as
    // large as the stored vector dim.
    for (unsigned i = 0; i < vecToStoreRank; i++)
      inBoundsVal[i] =
          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
  }

  // If missing, initialize the write indices to 0.
  bool useDefaultWriteIdxs = writeIndices.empty();
  assert((useDefaultWriteIdxs ||
          writeIndices.size() == static_cast<size_t>(destRank)) &&
         "Invalid number of write indices!");
  if (writeIndices.empty()) {
    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
    writeIndices.assign(destRank, zero);
  }

  // Generate the xfer_write op.
  Operation *write =
      vector::TransferWriteOp::create(builder, loc, /*vector=*/vecToStore,
                                      /*source=*/dest, /*indices=*/writeIndices,
                                      /*inBounds=*/inBoundsVal);

  // If masking was explicitly disabled, exit.
  if (useInBoundsInsteadOfMasking)
    return write;

  // Check whether masking is needed; if not, exit.
  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
    return write;

  // Compute the mask and mask the write op.
  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
                                       vecToStoreType.getScalableDims());

  SmallVector<OpFoldResult> destSizes =
      isa<MemRefType>(dest.getType())
          ? memref::getMixedSizes(builder, loc, dest)
          : tensor::getMixedSizes(builder, loc, dest);
  SmallVector<OpFoldResult> maskSizes;
  if (useDefaultWriteIdxs) {
    maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
                                          destSizes.end());
  } else {
    // Compute mask sizes as the distance between the destination size and
    // the write indices.
    size_t diff = destShape.size() - vecToStoreRank;
    for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
      Value value =
          getValueOrCreateConstantIndexOp(builder, loc, destSizes[diff + idx]);
      Value neg =
          builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
      maskSizes.push_back(OpFoldResult(neg));
    }
  }

  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
                              writeMaskType.getShape()))
    return write;

  Value maskForWrite =
      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
  return mlir::vector::maskOperation(builder, write, maskForWrite);
}
/// Given a vector type and a reassociation, return the collapsed vector type.
static VectorType getCollapsedVecType(VectorType type,
                                      ArrayRef<AffineMap> reassociation) {
  assert(type.getNumScalableDims() < 2 &&
         "Collapsing more than 1 scalable dim is not supported ATM");

  auto shape = type.getShape();
  auto scalableFlags = type.getScalableDims();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableFlags;

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    bool flag = false;
    for (unsigned d = 0; d < dim; ++d) {
      size *= shape[currentDim + d];
      flag |= scalableFlags[currentDim + d];
    }
    newShape.push_back(size);
    newScalableFlags.push_back(flag);
    currentDim += dim;
  }
  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
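// E.g. collapsing vector<2x3x5xf32> with reassociation [[0, 1], [2]] yields
// vector<6x5xf32>; a scalable dim anywhere in a group makes the whole
// collapsed dim scalable (hence the single-scalable-dim restriction above).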
/// Vectorize a linalg.pack op as masked read + shape_cast + transpose +
/// write.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == packOp.getDestRank() &&
           "Invalid number of input vector sizes!");
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  std::optional<Value> padValue = packOp.getPaddingValue()
                                      ? std::optional(packOp.getPaddingValue())
                                      : std::nullopt;

  SmallVector<int64_t> destShape =
      SmallVector<int64_t>(packOp.getDestType().getShape());

  // Convenience alias to clarify the meaning of the sizes.
  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;

  // If the write-vector-sizes are empty, try to infer them from the static
  // shape of the dest tensor.
  bool useInBoundsInsteadOfMasking = false;
  if (writeVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(destShape))
      return rewriter.notifyMatchFailure(packOp,
                                         "unable to infer vector sizes");

    writeVectorSizes = destShape;
    useInBoundsInsteadOfMasking = true;
  }

  // Compute the pre-transpose write vector type and the (collapsed) read
  // vector type from the packing metadata.
  PackingMetadata packMetadata;
  SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);
  applyPermutationToVector(preTransposeWriteVecSizes,
                           getPackInverseDestPerm(packOp, packMetadata));
  auto preTransposeWriteVecType =
      VectorType::get(preTransposeWriteVecSizes,
                      packOp.getResult().getType().getElementType());
  VectorType readVecType = getCollapsedVecType(
      preTransposeWriteVecType,
      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
          rewriter.getContext(), packMetadata.reassociations)));

  // Create masked TransferReadOp.
  Value maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), readVecType, padValue,
      useInBoundsInsteadOfMasking);

  // Create ShapeCastOp.
  auto shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, preTransposeWriteVecType, maskedRead);

  // Create TransposeOp.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp, packMetadata));
  auto transposeOp = vector::TransposeOp::create(
      rewriter, loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), packOp.getDest());
  newResults.push_back(write->getResult(0));
  return success();
}
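// Illustrative lowering produced above (sizes assumed):
//   %read  = vector.transfer_read %src ...           (masked if needed)
//   %cast  = vector.shape_cast %read : ... to the pre-transpose write type
//   %tr    = vector.transpose %cast, [...]
//   vector.transfer_write %tr, %dest ...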
/// Vectorize a linalg.unpack op as masked read + transpose + shape_cast +
/// write.
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
           "Invalid number of input vector sizes!");
    assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
           "Incompatible number of vector sizes and vector scalable flags!");
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  ShapedType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;

  Location loc = unpackOp->getLoc();

  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);

  // If the read-vector-sizes are empty, try to infer them from the static
  // shape of the source tensor.
  if (inputVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(sourceShape))
      return rewriter.notifyMatchFailure(unpackOp,
                                         "Unable to infer vector sizes!");

    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
    useInBoundsInsteadOfMasking = true;
  }

  // -- Generate the read operation --
  VectorType readVecType =
      VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
                      readScalableVectorFlags);
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
      useInBoundsInsteadOfMasking);

  // -- Generate the transpose operation --
  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  vector::TransposeOp transposeOp = vector::TransposeOp::create(
      rewriter, loc, readResult, lastDimToInsertPosPerm);

  // -- Generate the shape_cast operation --
  VectorType collapsedVecType = getCollapsedVecType(
      transposeOp.getType(),
      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
          rewriter.getContext(), packMetadata.reassociations)));
  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, collapsedVecType, transposeOp->getResult(0));

  // -- Generate the write operation --
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
  return success();
}
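// The unpack lowering above is the mirror image of the pack one: read
// (possibly masked) -> transpose -> shape_cast (collapse) -> write, using the
// inverse source permutation computed from the packing metadata.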
/// Vectorize a tensor.pad op with static result type and constant padding
/// value as masked read + in-bounds write into a new tensor.empty.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(padOp);

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");
  auto readType = VectorType::get(inputVectorSizes, padValue.getType());
  Value maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), readType, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  // Create the write op.
  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
                                       padOp.getResultType().getElementType());
  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
  newResults.push_back(write->getResult(0));
  return success();
}
/// Vectorization preconditions for reductions.
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG() << "reduction precondition failed: no reduction iterator";
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG() << "reduction precondition failed: reduction detection failed";
      return failure();
    }
  }
  return success();
}
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
              "supported";
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
              "channel dim can be dynamic";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement ContractionOpInterface for now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
  return success();
}
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  if (!unpackOp.hasPureTensorSemantics())
    return failure();

  // If the dest and source shapes are static, the vector sizes can be
  // inferred, so empty input vector sizes are fine.
  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
      unpackOp.getSourceType().hasStaticShape())
    return success();

  // Otherwise, the number of input vector sizes must match the source rank.
  if (!inputVectorSizes.empty() &&
      (inputVectorSizes.size() != unpackOp.getSourceRank())) {
    LDBG() << "Incorrect number of input vector sizes";
    return failure();
  }

  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(
          unpackOp.getSourceType().getShape(), inputVectorSizes))) {
    LDBG() << "Invalid vector sizes for the read operation";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {
  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // Get the pad value. Without a pad value, an out-of-bounds read (generated
  // when the source shape is dynamic and no vector sizes were provided) would
  // be unsafe.
  auto padValue = getStaticPadVal(sliceOp);
  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG() << "Failed to get a pad value for out-of-bounds read access";
    return failure();
  }
  return success();
}
/// Vectorize a named linalg contraction op into:
///   vector::TransferReadOp - reads vectors from the operands
///   vector::ContractionOp - performs the contraction
///   vector::TransferWriteOp - writes the result vector back
/// The operand shapes are preserved and loaded directly into vectors; any
/// permutations or numerical casting remain within the contraction op.
static LogicalResult
vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
                             LinalgOp linalgOp,
                             SmallVectorImpl<Value> &newResults) {
  Location loc = linalgOp.getLoc();
  MLIRContext *ctx = linalgOp.getContext();

  // For simplicity, contraction vectorization is limited to linalg named ops;
  // generic ops are excluded because not every arbitrary contraction body can
  // be expressed by a vector.contract.
  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
    return failure();

  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
  Operation *reduceOp = matchLinalgReduction(outOperand);
  auto maybeKind = getCombinerOpKind(reduceOp);
  if (!maybeKind) {
    LDBG() << "Failed to determine contraction combining kind.";
    return failure();
  }

  // Check that all dimensions are present in the input operands; contractions
  // with broadcast dims are not supported.
  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
    LDBG() << "Contractions with broadcasts are not supported.";
    return failure();
  }

  // Load operands.
  SmallVector<Value> vecOperands;
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    // The operand vector shape is computed by mapping the canonical vector
    // shape to the operand's indexing map.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    AffineMap readMap = inversePermutation(reindexIndexingMap(indexingMap));
    Type elemType = getElementTypeOrSelf(opOperand.get());
    VectorType readType =
        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));

    Value read = mlir::vector::createReadOrMaskedRead(
        rewriter, loc, opOperand.get(), readType,
        arith::getZeroConstant(rewriter, loc, elemType),
        /*useInBoundsInsteadOfMasking=*/false);
    vecOperands.push_back(read);
  }

  // Remap iterators from linalg to vector.
  SmallVector<Attribute> iterAttrs;
  auto iterators = linalgOp.getIteratorTypesArray();
  for (utils::IteratorType iter : iterators) {
    auto vecIter = iter == utils::IteratorType::parallel
                       ? vector::IteratorType::parallel
                       : vector::IteratorType::reduction;
    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
  }

  // Create the contraction.
  Operation *contractOp = vector::ContractionOp::create(
      rewriter, loc, /*lhs=*/vecOperands[0],
      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);

  // Store the result.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, contractOp->getResult(0), outOperand->get());

  // Finalize.
  if (!write->getResults().empty())
    newResults.push_back(write->getResult(0));

  return success();
}
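// Hedged example: a static linalg.matmul vectorized through this path becomes
// roughly
//   %a = vector.transfer_read %A ...
//   %b = vector.transfer_read %B ...
//   %c = vector.transfer_read %C ...
//   %r = vector.contract {iterator_types = ["parallel", "parallel",
//        "reduction"], ...} %a, %b, %c
//   vector.transfer_write %r, %C ...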
namespace {
enum class ConvOperationKind { Conv, Pool };
} // namespace

static bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

// Returns the ConvOperationKind of the op using the reduceOp of the generic
// payload. Returns std::nullopt if it is neither a convolution nor a pooling.
static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
  case 1: {
    // Will be convolution if the feeder is a MulOp. A strength-reduced
    // version of MulOp for i1 type is AndOp, which is also supported.
    // Otherwise it can be pooling.
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
    assert(feedValIt != reduceOp->operand_end() &&
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp)) {
      return ConvOperationKind::Pool;
    }

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            feedOp->getResultTypes()[0].isInteger(1))) &&
          llvm::all_of(feedOp->getOperands(), [](Value v) {
            if (isa<BlockArgument>(v))
              return true;
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
            return false;
          }))) {
      return std::nullopt;
    }

    return ConvOperationKind::Conv;
  }
  case 2:
    // Must be pooling.
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}
static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
  // (non-channeled convolution -> LHS and RHS both have single dimensions).
  // Note that this also rejects 2-D and 3-D convolutions.
  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
    return failure();

  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
  if (!reduceOp)
    return failure();

  auto maybeOper = getConvOperationKind(reduceOp);
  if (!maybeOper.has_value())
    return failure();

  auto maybeKind = getCombinerOpKind(reduceOp);
  // Typically convolution will have an `add` combining kind, but for i1 it
  // can get strength-reduced to `or`, which is also supported.
  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {
    return failure();
  }

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1)
      return failure();
  } else {
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return failure();
  }

  return success();
}
static LogicalResult vectorizeLinalgOpPrecondition(
    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
  // Tensors with a dimension of 0 cannot be vectorized.
  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
        return llvm::is_contained(linalgOp.getShape(&operand), 0);
      }))
    return failure();
  // Check the API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
                                        linalgOp, flatten1DDepthwiseConv))) {
    LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be supported element types for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(
            customPreconditions,
            [&](const CustomVectorizationPrecondition &customPrecondition) {
              return succeeded(
                  customPrecondition(&innerOp, vectorizeNDExtract));
            })) {
      continue;
    }
    if (!llvm::all_of(innerOp.getOperandTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
    if (!llvm::all_of(innerOp.getResultTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // TODO: isaConvolutionOpInterface that can also infer from generic
  // features, which would still need stride/dilation attributes.
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return vectorizeConvOpPrecondition(linalgOp);

  // TODO: the common vector shape equals the static loop sizes only when all
  // indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG() << "precondition failed: not projected permutations";
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG() << "precondition failed: reduction preconditions";
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  if (!packOp.hasPureTensorSemantics())
    return failure();

  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG() << "pad value is not constant: " << packOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG() << "inner_tiles must be constant: " << packOp;
    return failure();
  }

  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG() << "pad value is not constant: " << padOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  // Non-zero low padding is only supported when the corresponding result dim
  // is 1: anything else would require shifting the results to the right by
  // the amount of low padding.
  if (llvm::any_of(llvm::enumerate(padOp.getMixedLowPad()),
                   [&](const auto &en) {
                     OpFoldResult padValue = en.value();
                     unsigned pos = en.index();
                     std::optional<int64_t> pad = getConstantIntValue(padValue);
                     return (!pad.has_value() || pad.value() != 0) &&
                            resultTensorShape[pos] != 1;
                   })) {
    LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
    return failure();
  }

  return success();
}
/// Preconditions for scalable vectors. This is quite restrictive: it models
/// the fact that, in practice, only selected dimensions are made scalable.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  // Cond 1: reject ops that don't implement the LinalgOp interface, with the
  // exception of UnpackOp, which has a dedicated hook.
  if (!linalgOp) {
    return success(isa<linalg::UnPackOp>(op));
  }

  // Cond 2: there has been no need for more than 2 scalable dims so far.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 3: the configuration in `inputScalableVecDims` must match one of the
  // supported cases:
  //   1. exactly 1 scalable dim, and it is the last non-unit parallel dim;
  //   2. exactly 2 scalable dims, the last two adjacent parallel dims;
  //   3. exactly 1 scalable dim, and it is the trailing reduction dim.
  // Find the first scalable flag from the back, tracking whether a non-unit
  // parallel dim was skipped on the way.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  // Analyze the iterator corresponding to the first scalable dim.
  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // Check case 3. above.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG() << "Non-trailing reduction dim requested for scalable "
                "vectorization";
      return failure();
    }
    if (isa<linalg::MatmulOp>(op)) {
      LDBG()
          << "Scalable vectorization of the reduction dim in Matmul-like ops "
             "is not supported";
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // Check cases 1. and 2. above.
    if (seenNonUnitParallel) {
      LDBG() << "Inner parallel dim not requested for scalable "
                "vectorization";
      return failure();
    }
    break;
  }
  }

  // If present, check that the second scalable dim obeys case 2. above.
  if (numOfScalableDims == 2) {
    // Disallow iterators = [..., parallel, reduction] with scalable flags
    // [..., true, true], which would break case 3.
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG() << "Higher dim than the trailing reduction dim requested for "
                "scalable vectorization";
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Cond 4: only the following ops are supported in the presence of scalable
  // vectors.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::BatchMatmulOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
                 hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (!hasVectorizationImpl(op))
    return failure();

  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case([&](linalg::LinalgOp linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case([&](tensor::PadOp padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case([&](linalg::PackOp packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case([&](linalg::UnPackOp unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case([&](tensor::InsertSliceOp sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default(failure());
}
/// Convert affine.apply ops inside the linalg body to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}

bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
}
FailureOr<VectorizationResult> mlir::linalg::vectorize(
    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
    bool createNamedContraction) {
  LDBG() << "Attempting to vectorize: " << *op;
  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
  LDBG() << "Input scalable vector dims: "
         << llvm::interleaved(inputScalableVecDims);

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG() << "Vectorization pre-conditions failed";
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims,
                               assumeDynamicDimsMatchVecSizes))) {
      LDBG() << "Vectorization state couldn't be initialized";
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](linalg::LinalgOp linalgOp) {
            // TODO: isaConvolutionOpInterface that can also infer from
            // generic features. Will require stride/dilation attribute
            // inference.
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG() << "Unsupported convolution can't be vectorized.";
              return failure();
            }

            if (createNamedContraction &&
                isa<ContractionOpInterface>(linalgOp.getOperation()))
              return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
                                                  results);

            LDBG()
                << "Vectorize generic by broadcasting to the canonical vector "
                   "shape";

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            // TODO: 'vectorize' takes a 'RewriterBase' which is up-cast to
            // 'OpBuilder' when passed to methods like
            // 'vectorizeAsLinalgGeneric'. This is problematic: if an op is
            // erased within those methods, the actual rewriter won't be
            // notified and we can end up with read-after-free issues!
            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case([&](tensor::PadOp padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case([&](linalg::PackOp packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case([&](linalg::UnPackOp unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes,
                                             inputScalableVecDims, results);
          })
          .Case([&](tensor::InsertSliceOp sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp,
                                            inputVectorSizes, results);
          })
          .Default(failure());

  if (failed(vectorizeResult)) {
    LDBG() << "Vectorization failed";
    return failure();
  }

  return VectorizationResult{results};
}
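// Typical driver usage (a sketch; error handling is elided and the
// replacement step depends on the caller):
//   FailureOr<VectorizationResult> result =
//       linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{},
//                         /*inputScalableVecDims=*/{});
//   if (succeeded(result))
//     rewriter.replaceOp(op, result->replacements);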
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = vector::TransferReadOp::create(
      rewriter, loc, readType, copyOp.getSource(), indices,
      /*padding=*/std::nullopt);
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                          ArrayRef<int64_t>());
    readValue =
        vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
  }
  Operation *writeValue = vector::TransferWriteOp::create(
      rewriter, loc, readValue, copyOp.getTarget(), indices);
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Collect the users in a vector first, because some of them may be
    // replaced or removed during the rewrite.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
/// Rewrite a vector.transfer_read that reads from a tensor.pad result so
/// that it reads directly from the unpadded source, using the pad's constant
/// value as the transfer padding.
struct PadOpVectorizationWithTransferReadPattern
    : public VectorizePadOpUserPattern<vector::TransferReadOp> {
  using VectorizePadOpUserPattern<
      vector::TransferReadOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // The pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // The padding value of the existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getBaseMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
};
/// Rewrite a vector.transfer_write that writes into a tensor.pad result when
/// the result is immediately trimmed back to the source size by a
/// tensor.extract_slice: write directly into the unpadded source instead.
struct PadOpVectorizationWithTransferWritePattern
    : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
  using VectorizePadOpUserPattern<
      vector::TransferWriteOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // The pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // The TransferWriteOp result must be directly consumed by an
    // ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets are supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at the position of the old one.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }

  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e. the same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both the CastOp
    // result and the CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType is supported.
    if (!t1 || !t2)
      return false;
    // The ranks of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g. a dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast
    // thereof).
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same value or same constant int.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Other cases: take a deeper look at the defining ops of the values.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      // Case 2: Both values are identical AffineMinOps. (Should not happen
      // if CSE is run.)
      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed. This can be used to detect dynamic
      // sizes that are the same when a CSE-like analysis is available.
      return false;
    }

    // All tests passed.
    return true;
  }
};
/// Returns the effective pad value for the input op, provided it's a scalar.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value being
  // broadcast, provided that it's a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  // 2. linalg.fill - use the scalar input value.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generate - analyse the yielded value.
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    // ... (analysis of the generate body elided in this excerpt)
  }

  // 4. vector.transfer_write - inspect the written vector, recursively.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // 5. tensor.insert_slice - inspect the destination, recursively.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
/// Vectorize tensor::InsertSliceOp as a TransferReadOp + TransferWriteOp
/// pair. The vector sizes are either user-provided in `inputVectorSizes` or
/// inferred from the static dims of the source and result tensors.
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(sliceOp);

  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  Value padValue = getStaticPadVal(sliceOp);
  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
                                         rewriter.getZeroAttr(elemType));
  }

  // 1. Get the vector shape.
  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      // The source shape is not statically known, but the result shape is.
      // Vectorize with the size of the result shape; the read may be
      // out-of-bounds, which the pad value handles. `rankDiff` accounts for
      // leading result dims absent from the source.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
    } else {
      // Neither the source nor the result dim is static; cannot vectorize
      // the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // 2. Generate TransferReadOp + TransferWriteOp.
  auto loc = sliceOp.getLoc();

  // Create the read.
  SmallVector<Value> readIndices(
      vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
  Value read = mlir::vector::createReadOrMaskedRead(
      rewriter, loc, source, vecType, padValue,
      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());

  // Create the write.
  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
                               writeIndices, inputVectorSizes.empty());

  // 3. Finalize.
  newResults.push_back(write->getResult(0));

  return success();
}
/// Rewrite the use of a tensor.pad result in a tensor.insert_slice as a
/// direct transfer_read of the pad source + in-bounds transfer_write into the
/// insert destination.
struct PadOpVectorizationWithInsertSlicePattern
    : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
  using VectorizePadOpUserPattern<
      tensor::InsertSliceOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride is supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // The pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes are not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // The pad result must not be used as the insert destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if the sizes match: the entire tensor is inserted into the most
    // minor dims (no permutations allowed).
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Insert the TransferReadOp and TransferWriteOp at the position of the
    // InsertSliceOp.
    rewriter.setInsertionPoint(insertOp);

    // Generate the TransferReadOp: read the entire source tensor, with high
    // padding coming from the pad value.
    SmallVector<Value> readIndices(
        vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
    auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
                                               vecType, padOp.getSource(),
                                               readIndices, padValue);

    // Generate the TransferWriteOp: write to the InsertSliceOp's dest tensor
    // at the specified offsets. The write is fully in-bounds because an
    // InsertSliceOp's source is always within its dest.
    SmallVector<Value> writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
};
void mlir::linalg::populatePadOpVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
  patterns.add<PadOpVectorizationWithTransferReadPattern,
               PadOpVectorizationWithTransferWritePattern,
               PadOpVectorizationWithInsertSlicePattern>(
      patterns.getContext(), baseBenefit.getBenefit() + 1);
}
/// Check whether there is any interleaved use of any of `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
/// is in a different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
           << ", second op: " << *secondOp;
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative; use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
             << ", second op: " << *secondOp;
      return true;
    }
  }
  return false;
}
/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the
  // copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // memref.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer, so it must
  // be reset conservatively when forwarding to vector.transfer_read.
  auto vectorType = xferOp.getVectorType();
  Value res = vector::TransferReadOp::create(
      rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the copy target that we write into directly instead.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward vector.transfer into the copy. memref.copy + linalg.fill can be
  // used to create a padded local buffer; the `masked` attribute is only
  // valid on that buffer and must be reset conservatively.
  auto vector = xferOp.getVector();
  vector::TransferWriteOp::create(
      rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(SmallVector<bool>(
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

/// Bind a pack of int references to the leading dimensions of
/// shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
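// Usage example: for resShapedType with shape {N, W, F}:
//   int64_t n, w, f;
//   bindShapeDims(resShapedType, n, w, f); // n = dim 0, w = dim 1, f = dim 2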
static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
#define MATCH_1D_CONV_POOL_OP(ConvOpTy)                                        \
  if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op))                \
    return convParams;

  // ... (one MATCH_1D_CONV_POOL_OP(...) invocation per supported 1-D
  // convolution/pooling op, elided in this excerpt)

#undef MATCH_1D_CONV_POOL_OP

  return std::nullopt;
}
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  /// Create a Conv1DGenerator, provided `linalgOp` is a supported 1-D
  /// convolution or pooling op.
  static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
                                           LinalgOp linalgOp) {
    // Determine whether `linalgOp` can be generated by this generator.
    std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
    if (!convParams.has_value())
      return failure();

    int strideW = static_cast<int>(convParams->strides.front());
    int dilationW = static_cast<int>(convParams->dilations.front());
    return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
  }

  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    setConvOperationKind(reduceOp);

    auto maybeKind = getCombinerOpKind(reduceOp);
    reductionKind = maybeKind.value();
  }
3631 int64_t nSize, wSize, cSize, kwSize, fSize;
3632 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3634 switch (conv1DOpOrder) {
3637 nSize = fSize = cSize = 0;
3639 bindShapeDims(resShapedType, wSize);
3641 bindShapeDims(rhsShapedType, kwSize);
3644 (wSize + kwSize - 1)};
3645 rhsShape = {kwSize};
3650 bindShapeDims(resShapedType, nSize, wSize, fSize);
3652 case ConvOperationKind::Conv:
3654 bindShapeDims(rhsShapedType, kwSize, cSize);
3656 case ConvOperationKind::Pool:
3658 bindShapeDims(rhsShapedType, kwSize);
3666 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3670 case ConvOperationKind::Conv:
3671 rhsShape = {kwSize, cSize, fSize};
3673 case ConvOperationKind::Pool:
3674 rhsShape = {kwSize};
3677 resShape = {nSize, wSize, fSize};
3681 bindShapeDims(resShapedType, nSize, fSize, wSize);
3683 case ConvOperationKind::Conv:
3685 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3687 case ConvOperationKind::Pool:
3689 bindShapeDims(rhsShapedType, kwSize);
3693 lhsShape = {nSize, cSize,
3697 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3700 case ConvOperationKind::Conv:
3701 rhsShape = {fSize, cSize, kwSize};
3703 case ConvOperationKind::Pool:
3704 rhsShape = {kwSize};
3707 resShape = {nSize, fSize, wSize};
    vector::TransferWriteOp write;
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    auto lhsType = VectorType::get(lhsShape, lhsEltType);
    auto rhsType = VectorType::get(rhsShape, rhsEltType);
    auto resType = VectorType::get(resShape, resEltType);
    // All-zero index vectors for the transfer reads/writes below.
    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
    SmallVector<Value> resPadding(resShape.size(), zero);

    // Read the whole lhs, rhs and res in one shot (with zero padding).
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, lhsPadding,
        arith::getZeroConstant(rewriter, loc, lhsEltType));
    // The rhs is needed only for Conv.
    Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
      rhs = vector::TransferReadOp::create(
          rewriter, loc, rhsType, rhsShaped, rhsPadding,
          arith::getZeroConstant(rewriter, loc, rhsEltType));
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, resPadding,
        arith::getZeroConstant(rewriter, loc, resEltType));
    // The base vectorization case (Nwc) computes in the {n, w, c/f} order;
    // for Ncw, transpose the operands into that form here and transpose the
    // result back after the computation.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, nothing to do.
      break;
    case Conv1DOpOrder::Ncw: {
      // lhs{n, c, w} -> lhs{n, w, c}
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
      // rhs{f, c, kw} -> rhs{kw, c, f}
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};

      // This is needed only for Conv.
      if (oper == ConvOperationKind::Conv)
        rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
      // res{n, f, w} -> res{n, w, f}
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, permRes);
      break;
    }
    }
    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                                     kwSize, strideW, dilationW, wSizeStep,
                                     isSingleChanneled);
    // Do not do this for pooling.
    if (oper == ConvOperationKind::Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                      wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the contraction for every kw/w slice pair:
    //   O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case ConvOperationKind::Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case ConvOperationKind::Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // Transpose the result back for Ncw.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // res{n, w, f} -> res{n, f, w}
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, perm);
      break;
    }
    }

    // Write back the result.
    return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
                                           resPadding)
        .getOperation();
  }
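  // Rough sketch of the IR emitted per (kw, w) step for an Nwc conv (SSA
  // names and types illustrative only):
  //   %in  = vector.extract_strided_slice %lhs
  //            {offsets = [0, w * sw + kw * dw, 0], ...}
  //   %flt = vector.extract %rhs[kw]
  //   %acc = vector.contract {...} %in, %flt, %res_w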
  /// Promote `val` so that its element type matches the element type of `ty`,
  /// inserting the appropriate extension or conversion op.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType))
      return arith::SIToFPOp::create(rewriter, loc, dstType, val);

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtFOp::create(rewriter, loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtSIOp::create(rewriter, loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  /// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = vector::ContractionOp::create(
        rewriter, loc, lhs, rhs, res,
        /*indexingExprs=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }
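  // Rough sketch of the emitted op (shapes illustrative; both operands are
  // promoted to the result element type first). With iterators (n, w, f
  // parallel; c reduction), the indexing maps read:
  //   lhs: (n, w, f, c) -> (n, w, c)
  //   rhs: (n, w, f, c) -> (c, f)
  //   res: (n, w, f, c) -> (n, w, f)
  // e.g.:
  //   %o = vector.contract {kind = #vector.kind<add>, ...} %lhs, %rhs, %res
  //          : vector<1x2x4xf32>, vector<4x8xf32> into vector<1x2x8xf32>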
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
                                          rhs, res, vector::CombiningKind::ADD);
  }
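  // Rough sketch (shapes illustrative): for the single-channel W layout the
  // filter slice is a scalar, so each step is the AXPY form of outerproduct:
  //   %o = vector.outerproduct %lhs, %flt, %acc {kind = #vector.kind<add>}
  //          : vector<14xf32>, f32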
  /// Create a pooling update: apply the (optional) element-type extension to
  /// the input slice, then combine it into `res` with the matched reduction
  /// op: lhs{n, w, c} -> res{n, w, c}.
  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
                    Value res) {
    if (isPoolExt)
      lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
    return rewriter
        .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
        ->getResult(0);
  }
  /// Generate a vector implementation for a depthwise 1-D convolution:
  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when the channel dim is dynamic and
      // the corresponding scalability flag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);
    vector::TransferWriteOp write;
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow - 1) * sw + (kw - 1) * dw + 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});
    // Wraps a transfer op in a vector.mask when masking is required (i.e. the
    // channel dim is dynamic); otherwise returns the op unchanged.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };
    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
        arith::getZeroConstant(rewriter, loc, lhsEltType));
    auto *maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = vector::TransferReadOp::create(
        rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
        arith::getZeroConstant(rewriter, loc, rhsEltType));
    auto *maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
        arith::getZeroConstant(rewriter, loc, resEltType));
    auto *maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
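    // When useMasking is set, each read above ends up wrapped in a
    // vector.mask; a rough sketch (types and SSA names illustrative):
    //   %mask = vector.create_mask %n, %w, %c : vector<3x2x[8]xi1>
    //   %lhs  = vector.mask %mask {
    //             vector.transfer_read %A[%c0, %c0, %c0], %pad ...
    //           } : vector<3x2x[8]xi1> -> vector<3x2x[8]xf32>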
    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(
          vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
                                    /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };
    // Note: the scalable flags are ignored here because flattening combined
    // with scalable vectorization is not supported.
    SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute the multiply-accumulate for every kw/w slice pair:
    //   O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel dim).
          lhsVal =
              vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
                                          lhsVals[linearIndex(kw, w)]);
          resVal = vector::ShapeCastOp::create(
              rewriter, loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dim).
          resVals[w] = vector::ShapeCastOp::create(
              rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
              resVals[w]);
        }
      }
    }

    // It is possible we failed to create the FMA; if so, manually revert the
    // slices created above and bail out.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0]; this does
    // not depend on kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }
    // Write back the full res slice of size {n, w, c} @ [0, 0, 0].
    Operation *resOut = vector::TransferWriteOp::create(
        rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
  /// to a multiply-accumulate.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: this logic does not work for scalable vectors; for that reason
      // flattening is not supported when shapes are dynamic.

      // Replicate the filter values so they line up with the flattened
      // (w * c) dimension: broadcast(shuffle(filter)) avoids an extra
      // shape_cast.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = vector::BroadcastOp::create(rewriter, loc,
                                      resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);

    auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
    return arith::AddIOp::create(rewriter, loc, mul, res);
  }
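  // Worked example of the flattened filter shuffle (values hypothetical):
  // with rhsSize == 4 and resSize == 8, the shuffle indices are
  //   [0, 1, 2, 3, 0, 1, 2, 3]
  // i.e. the c filter values are repeated w times to cover the flattened
  // (w * c) dimension before the broadcast and multiply-accumulate.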
  /// Matcher for the non-channeled conv::W layout.
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");
    if (layout({{w + kw}, {kw}, {w}}))
      return conv(Conv1DOpOrder::W);
    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Matcher for the conv::Nwc layout.
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");
    if (layout({{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Matcher for the conv::Ncw layout.
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");
    if (layout({{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Matcher for the pooling::Nwc layout.
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    if (layout({{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Matcher for the pooling::Ncw layout.
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    if (layout({{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Matcher for the depthwise::Nwc conv layout.
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");
    if (layout({{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
private:
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr redOp;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;

  // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // This is a convolution if the value feeding the combiner is a multiply;
      // if it is a cast of a block argument, it is a pooling with an
      // element-type extension.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // With two block arguments, the combiner consumes the inputs directly:
    // this is a pooling.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
};
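// Illustrative generic-form bodies this classifier distinguishes (names
// illustrative):
//   Conv:          %m = arith.mulf %in, %flt       (one block argument feeds
//                  %r = arith.addf %out, %m         the combiner via a mul)
//   Pool (w/ ext): %e = arith.extf %in             (combiner fed by a cast of
//                  %r = arith.addf %out, %e         a block argument)
//   Pool (plain):  %r = arith.maximumf %out, %in   (combiner consumes the two
//                                                   block arguments directly)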
/// Helper function to vectorize a LinalgOp with convolution semantics.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
  if (failed(conv1dGen))
    return failure();
  auto res = conv1dGen->generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen->generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen->generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen->generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen->generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1-D convs are left; these can be vectorized using masks
  // and scalable vectors. Note that currently only the channel dim can be
  // dynamic (i.e. masked/scalable).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim; the
    // other vector dims are inferred from the op itself.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation()) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx = 0;
    // The channel dim index depends on the layout: trailing for Nwc,
    // second for Ncw.
    if (isa<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation()))
      chDimIdx = 2;
    else if (isa<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation()))
      chDimIdx = 1;

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                        flatten1DDepthwiseConv);
}
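// Illustrative dispatch through the cascade above (op names are real linalg
// named ops; the mapping is a sketch):
//   linalg.conv_1d                  -> generateNonChanneledConv
//   linalg.conv_1d_nwc_wcf          -> generateNwcConv
//   linalg.conv_1d_ncw_fcw          -> generateNcwConv
//   linalg.pooling_nwc_sum          -> generateNwcPooling
//   linalg.pooling_ncw_sum          -> generateNcwPooling
//   linalg.depthwise_conv_1d_nwc_wc -> generateDilatedConv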
/// Helper pattern that applies vectorizeConvolution to a LinalgOp with
/// convolution semantics.
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};
void mlir::linalg::populateConvolutionVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
}
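// Typical driver usage (sketch; `funcOp` is a hypothetical target op):
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();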