#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-vectorization"
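/// Forward declaration of the convolution vectorizer; the definition, built on
/// the Conv1DGenerator helper, appears towards the end of this file.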
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
  } else {
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input,
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(vector::ExtractOp::create(
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
  } else {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize,
                                    int64_t wSizeStep,
                                    ArrayRef<Value> resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
  } else {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res,
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          bool assumeDynamicDimsMatchVecSizes = false);
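  /// Returns the canonical vector shape as a vector type with the given
  /// element type, optionally permuting it with `dimPermutation`.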
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape.append(applyPermutationMap<int64_t>(
          *dimPermutation, ArrayRef<int64_t>(canonicalVecShape)));
      scalableDims.append(applyPermutationMap<bool>(
          *dimPermutation, ArrayRef<bool>(scalableVecDims)));
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }

    return VectorType::get(vectorShape, elementType, scalableDims);
  }
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.isProjectedPermutation();
  }

  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
    return indexingMap.dropZeroResults();
  }
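  // State fields: precomputed iteration-space sizes (static and as Values),
  // the canonical vector shape, and its scalable-dimension flags.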
  SmallVector<int64_t> iterSpaceStaticSizes;

  SmallVector<Value> iterSpaceValueSizes;

  SmallVector<int64_t> canonicalVecShape;

  SmallVector<bool> scalableVecDims;

  OpBuilder::InsertionGuard rewriterGuard;

  bool assumeDynamicDimsMatchVecSizes = false;
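/// Precomputes the iteration-space sizes as SSA values: constants for static
/// dimensions, tensor/memref dim ops for dynamic ones.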
LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
      iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
          rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim =
        linalgOp.hasPureTensorSemantics()
            ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(),
                                           operand, operandDimPos)
            : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(),
                                           operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims,
                              bool assumeDimsMatchVec) {
  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
  if (!inputVectorSizes.empty()) {
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  initIterSpaceStaticSizes(linalgOp);

  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
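/// Returns (and caches) a mask for `opToMask`, or an empty Value when the
/// permuted static sizes already match the mask shape and no masking is
/// needed.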
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {
  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG() << "Masking map: " << maskingMap;

  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG() << "Reusing mask: " << mask;
    return mask;
  }

  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap,
                                   ArrayRef<int64_t>(iterSpaceStaticSizes));
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);

  if (permutedStaticSizes == maskShape) {
    LDBG() << "Masking is not needed for masking map: " << maskingMap;
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  if (assumeDynamicDimsMatchVecSizes) {
    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
                     [](auto it) {
                       return std::get<0>(it) == ShapedType::kDynamic
                                  ? true
                                  : std::get<0>(it) == std::get<1>(it);
                     })) {
      LDBG()
          << "Dynamic + static dimensions match vector sizes, masking is not "
             "needed";
      activeMaskCache[maskingMap] = Value();
      return Value();
    }
  }

  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
                                            maskType, upperBounds);
  LDBG() << "Creating new mask: " << mask;
  activeMaskCache[maskingMap] = mask;
  return mask;
}
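/// Wraps `opToMask` in a vector.mask op using the mask computed by
/// getOrCreateMaskFor, rewiring the uses of its results.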
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG() << "Trying to mask: " << *opToMask;

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG() << "No mask required";
    if (assumeDynamicDimsMatchVecSizes) {
      llvm::TypeSwitch<Operation *>(opToMask)
          .Case<vector::TransferReadOp, vector::TransferWriteOp>(
              [&](auto xferOp) {
                LDBG() << "Assuming dynamic dimensions match vector sizes and "
                          "setting their in-bounds to true!";
                ShapedType xferType = xferOp.getShapedType();
                AffineMap permMap = xferOp.getPermutationMap();
                SmallVector<bool> inBoundsMap(xferOp.getInBoundsValues());
                for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
                  auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
                  if (!dimExpr)
                    continue;
                  unsigned pos = dimExpr.getPosition();
                  if (xferType.isDynamicDim(pos))
                    inBoundsMap[i] = true;
                }
                xferOp.setInBoundsAttr(
                    rewriter.getBoolArrayAttr(inBoundsMap));
              });
    }
    return opToMask;
  }

  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG() << "Masked operation: " << *maskOp;
  return maskOp;
}
599 "expected projected permutation");
601 assert(res.getNumDims() ==
602 (res.getNumResults() - res.getNumOfZeroResults()) &&
603 "expected reindexed map with same number of dims and results");
std::optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case([&](arith::AndIOp op) { return CombiningKind::AND; })
      .Case([&](arith::MaxSIOp op) { return CombiningKind::MAXSI; })
      .Case([&](arith::MaxUIOp op) { return CombiningKind::MAXUI; })
      .Case([&](arith::MaximumFOp op) { return CombiningKind::MAXIMUMF; })
      .Case([&](arith::MaxNumFOp op) { return CombiningKind::MAXNUMF; })
      .Case([&](arith::MinSIOp op) { return CombiningKind::MINSI; })
      .Case([&](arith::MinUIOp op) { return CombiningKind::MINUI; })
      .Case([&](arith::MinimumFOp op) { return CombiningKind::MINIMUMF; })
      .Case([&](arith::MinNumFOp op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case([&](arith::OrIOp op) { return CombiningKind::OR; })
      .Case([&](arith::XOrIOp op) { return CombiningKind::XOR; })
      .Default(std::nullopt);
}
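/// Matches the single combiner op of a reduction on `outputOperand`, or
/// returns nullptr.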
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  return combinerOps[0];
}
  auto dstVecType = dyn_cast<VectorType>(dstType);
  if (dstVecType.getRank() == 0)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return vector::MultiDimReductionOp::create(
      b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
  return llvm::map_to_vector(linalgOp.getIteratorTypesArray(),
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
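/// Builds a vector.transfer_write of `value` into `outputOperand`, masking it
/// according to the operand's indexing map.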
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
  auto vectorType = state.getCanonicalVecType(

  Operation *write;
  if (vectorType.getRank() > 0) {
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(rewriter, loc, value,
                                            outputOperand->get(), indices,
                                            writeMap);
  } else {
    if (!isa<VectorType>(value.getType()))
      value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(rewriter, loc, value,
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
  }

  LDBG() << "vectorized op: " << *write;
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;
static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    Value newResult = buildVectorWrite(
        rewriter, bvm.lookup(output.value()),
        linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                    VectorizationState &state,
                                                    Operation *op,
                                                    LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  auto targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);

  if (dim == targetShape.size() - 1)
    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};

  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = vector::BroadcastOp::create(
      rewriter, loc,
      state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}
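/// Precondition hook: tensor.extract is only vectorizable with 1-D indexing
/// unless n-D extract vectorization was explicitly requested.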
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (!llvm::all_of(extractOp->getResultTypes(),
                    VectorType::isValidElementType)) {
    return failure();
  }

  return success();
}
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
  }

  return offset;
}
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return false;

  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return false;

  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
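/// Classifies how a tensor.extract inside the op accesses memory: gather,
/// scalar broadcast, or contiguous load (using the helpers above).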
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {
  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG() << "Found gather load: " << extractOp;
    return VectorMemoryAccessKind::Gather;
  }

  auto extractOpTrailingIdx = indices.back();

  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG() << "Found scalar broadcast load: " << extractOp;
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG() << "Found contiguous load: " << extractOp;
    return VectorMemoryAccessKind::Contiguous;
  }

  LDBG() << "Found gather load: " << extractOp;
  return VectorMemoryAccessKind::Gather;
}
static VectorizationHookResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = arith::ConstantOp::create(
  auto passThruConstantOp = arith::ConstantOp::create(
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      arith::ConstantIndexOp::create(rewriter, loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    Operation *gatherOp = vector::GatherOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG() << "Vectorised as gather load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
  }

  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
  }

  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();

  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    auto transferReadOp = vector::TransferReadOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
        std::nullopt, permutationMap, inBounds);

    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
    auto allTrue = vector::ConstantMaskOp::create(
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   maskedReadOp};
  }

  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
    rankDiff--;
  }

  auto transferReadOp = vector::TransferReadOp::create(
      rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
      std::nullopt, permutationMap, inBounds);

  LDBG() << "Vectorised as contiguous load: " << extractOp;
  return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                 transferReadOp};
}
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
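/// Generic vectorization hook for a single op inside the linalg body: custom
/// hooks first, then constants, then reductions, then elementwise
/// broadcasting to the rank of the largest vector operand.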
static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG() << "vectorize op " << *op;

  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationHookResult result = customFunc(op, bvm);
      if (result.status == VectorizationHookStatus::Failure)
        continue;
      return result;
    }
  }

  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   rewriter.clone(*op)};

  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
  }

  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }

  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     getElementTypeOrSelf(vecOperand.getType()),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }

  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG() << "Vectorizing operation as linalg generic";
  Block *block = linalgOp.getBlock();

  IRMapping bvm;
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  Location loc = linalgOp.getLoc();
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    VectorType readType;
    AffineMap readMap;
    if (linalgOp.isDpsInput(opOperand)) {
      readType = state.getCanonicalVecType(elemType);
    } else {
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    Operation *read = vector::TransferReadOp::create(
        rewriter, loc, readType, opOperand->get(), indices,
        std::nullopt, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      readValue = cast<vector::TransferReadOp>(maskOp.getMaskableOp())
                      .getResult();
    }

    if (readType.getRank() == 0)
      readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                            ArrayRef<int64_t>());

    LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
           << "): " << readValue;
    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  SmallVector<CustomVectorizationHook> hooks;
  hooks.push_back(vectorizeYield);

  hooks.push_back(vectorizeIndex);

  hooks.push_back(vectorizeExtract);

  for (Operation &op : block->getOperations()) {
    VectorizationHookResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationHookStatus::Failure) {
      LDBG() << "failed to vectorize: " << op;
      return failure();
    }
    if (result.status == VectorizationHookStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG() << "New vector op: " << *maybeMaskedOp;
    }
  }

  return success();
}
  if (ShapedType::isDynamicShape(destShape))
    return false;

  SmallVector<int64_t> cstMaskSizes;
  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
    if (std::optional<int64_t> intSize = getConstantIntValue(dimSize))
      cstMaskSizes.push_back(*intSize);
  }

  if (cstMaskSizes.size() != maskShape.size())
    return false;

  SmallVector<int64_t> cstWriteIdxs;
  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
    APInt intVal;
    if (matchPattern(idx, m_ConstantInt(&intVal)))
      cstWriteIdxs.push_back(intVal.getSExtValue());
  }

  if (cstWriteIdxs.size() != destShape.size())
    return false;

  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
    if (maskShape[i] > destShape[rankDiff + i] ||
        destShape[rankDiff + i] <
            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
             cstWriteIdxs[i]))
      return false;
  }

  return true;
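/// Creates a vector.transfer_write of `vecToStore` into `dest`, masked with a
/// vector.create_mask unless in-bounds attributes suffice.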
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
                         Value dest, SmallVector<Value> writeIndices = {},
                         bool useInBoundsInsteadOfMasking = false) {
  ShapedType destType = cast<ShapedType>(dest.getType());
  int64_t destRank = destType.getRank();
  auto destShape = destType.getShape();

  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
  int64_t vecToStoreRank = vecToStoreType.getRank();
  auto vecToStoreShape = vecToStoreType.getShape();

  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
  if (useInBoundsInsteadOfMasking) {
    for (unsigned i = 0; i < vecToStoreRank; i++)
      inBoundsVal[i] =
          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
  }

  bool useDefaultWriteIdxs = writeIndices.empty();
  assert((useDefaultWriteIdxs ||
          writeIndices.size() == static_cast<size_t>(destRank)) &&
         "Invalid number of write indices!");
  if (writeIndices.empty()) {
    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
    writeIndices.assign(destRank, zero);
  }

  Operation *write = vector::TransferWriteOp::create(builder, loc,

  if (useInBoundsInsteadOfMasking)
    return write;

  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
    return write;

  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
                                       vecToStoreType.getScalableDims());

  SmallVector<OpFoldResult> destSizes =
      isa<MemRefType>(dest.getType())
          ? memref::getMixedSizes(builder, loc, dest)
          : tensor::getMixedSizes(builder, loc, dest);

  SmallVector<OpFoldResult> maskSizes;
  if (useDefaultWriteIdxs) {
    maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
                                          destSizes.end());
  } else {
    size_t diff = destShape.size() - vecToStoreRank;
    for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
      Value neg =
          builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
      maskSizes.push_back(OpFoldResult(neg));
    }
  }

  Value maskForWrite =
      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
static VectorType getCollapsedVecType(VectorType type,
                                      ArrayRef<AffineMap> reassociation) {
  assert(type.getNumScalableDims() < 2 &&
         "Collapsing more than 1 scalable dim is not supported ATM");

  auto shape = type.getShape();
  auto scalableFlags = type.getScalableDims();

  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableFlags;

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    bool flag = false;
    for (unsigned d = 0; d < dim; ++d) {
      size *= shape[currentDim + d];
      flag |= scalableFlags[currentDim + d];
    }
    newShape.push_back(size);
    newScalableFlags.push_back(flag);
    currentDim += dim;
  }
  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
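/// Vectorizes linalg.pack as: masked read -> shape_cast -> transpose ->
/// (masked) write.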
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == packOp.getDestRank() &&
           "Invalid number of input vector sizes!");
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  std::optional<Value> padValue = packOp.getPaddingValue()
                                      ? std::optional(packOp.getPaddingValue())
                                      : std::nullopt;

  SmallVector<int64_t> destShape =
      SmallVector<int64_t>(packOp.getDestType().getShape());

  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;

  bool useInBoundsInsteadOfMasking = false;
  if (writeVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(destShape))
      return rewriter.notifyMatchFailure(packOp,
                                         "unable to infer vector sizes");
    writeVectorSizes = destShape;
    useInBoundsInsteadOfMasking = true;
  }

  PackingMetadata packMetadata;
  SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);

  auto preTransposeWriteVecType =
      VectorType::get(preTransposeWriteVecSizes,
                      packOp.getResult().getType().getElementType());

  VectorType readVecType = getCollapsedVecType(
      preTransposeWriteVecType,
      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
          rewriter.getContext(), packMetadata.reassociations)));

  Value maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), readVecType, padValue,
      useInBoundsInsteadOfMasking);

  auto shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, preTransposeWriteVecType, maskedRead);

  auto transposeOp = vector::TransposeOp::create(
      rewriter, loc, shapeCastOp.getResult(), destPermutation);

  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), packOp.getDest());
  newResults.push_back(write->getResult(0));
  return success();
}
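/// Vectorizes linalg.unpack as: (masked) read -> transpose -> shape_cast ->
/// (masked) write.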
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
           "Invalid number of input vector sizes!");
    assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
           "Incompatible number of vector sizes and vector scalable flags!");
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  ShapedType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;

  Location loc = unpackOp->getLoc();

  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);

  if (inputVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(sourceShape))
      return rewriter.notifyMatchFailure(unpackOp,
                                         "Unable to infer vector sizes!");
    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
    useInBoundsInsteadOfMasking = true;
  }

  VectorType readVecType =
      VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
                      readScalableVectorFlags);
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
      useInBoundsInsteadOfMasking);

  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  vector::TransposeOp transposeOp = vector::TransposeOp::create(
      rewriter, loc, readResult, lastDimToInsertPosPerm);

  VectorType collapsedVecType = getCollapsedVecType(
      transposeOp.getType(),
      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
          rewriter.getContext(), packMetadata.reassociations)));
  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, collapsedVecType, transposeOp->getResult(0));

  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);

  newResults.push_back(write->getResult(0));
  return success();
}
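/// Vectorizes tensor.pad as a masked read followed by a write into an empty
/// tensor of the reified result shape.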
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(padOp);

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status;
  assert(succeeded(status) && "failed to reify result shapes");
  auto readType = VectorType::get(inputVectorSizes, padValue.getType());
  Value maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), readType, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
                                       padOp.getResultType().getElementType());
  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG() << "reduction precondition failed: no reduction iterator";
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG() << "reduction precondition failed: reduction detection failed";
      return failure();
    }
  }
  return success();
}
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
              "supported";
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
    return failure();
  }

  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
              "channel dim can be dynamic";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
  return success();
}
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  if (!unpackOp.hasPureTensorSemantics())
    return failure();

  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
      unpackOp.getSourceType().hasStaticShape())
    return success();

  if (!inputVectorSizes.empty() &&
      (inputVectorSizes.size() != unpackOp.getSourceRank())) {
    LDBG() << "Incorrect number of input vector sizes";
    return failure();
  }

  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(
          unpackOp.getSourceType().getShape(), inputVectorSizes))) {
    LDBG() << "Invalid vector sizes for the read operation";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {
  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  Value padValue = getStaticPadVal(sliceOp);
  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG() << "Failed to get a pad value for out-of-bounds read access";
    return failure();
  }
  return success();
}
static LogicalResult
vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
                             LinalgOp linalgOp,
                             SmallVectorImpl<Value> &newResults) {
  Location loc = linalgOp.getLoc();
  MLIRContext *ctx = linalgOp.getContext();

  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
    return failure();

  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
  Operation *reduceOp = matchLinalgReduction(outOperand);
  auto maybeKind = getCombinerOpKind(reduceOp);
  if (!maybeKind) {
    LDBG() << "Failed to determine contraction combining kind.";
    return failure();
  }

  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
    LDBG() << "Contractions with broadcasts are not supported.";
    return failure();
  }

  SmallVector<Value> vecOperands;
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    AffineMap readMap = inversePermutation(reindexIndexingMap(indexingMap));
    Type elemType = getElementTypeOrSelf(opOperand.get());
    VectorType readType =
        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));

    Value read = mlir::vector::createReadOrMaskedRead(
        rewriter, loc, opOperand.get(), readType,
        arith::getZeroConstant(rewriter, loc, elemType),
        /*useInBoundsInsteadOfMasking=*/false);
    vecOperands.push_back(read);
  }

  SmallVector<Attribute> iterAttrs;
  auto iterators = linalgOp.getIteratorTypesArray();
  for (utils::IteratorType iter : iterators) {
    auto vecIter = iter == utils::IteratorType::parallel
                       ? vector::IteratorType::parallel
                       : vector::IteratorType::reduction;
    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
  }

  Operation *contractOp = vector::ContractionOp::create(
      rewriter, loc, /*lhs=*/vecOperands[0],
      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);

  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, contractOp->getResult(0), outOperand->get());

  newResults.push_back(write->getResult(0));
  return success();
}
enum class ConvOperationKind { Conv, Pool };

static bool isCastOfBlockArgument(Operation *op) {
static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
  case 1: {
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
    assert(feedValIt != reduceOp->operand_end() &&
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp)) {
      return ConvOperationKind::Pool;
    }

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            feedOp->getResultTypes()[0].isInteger(1))) &&
          llvm::all_of(feedOp->getOperands(), [](Value v) {
            if (isa<BlockArgument>(v))
              return true;
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
            return false;
          })))
      return std::nullopt;

    return ConvOperationKind::Conv;
  }
  case 2:
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}
static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));

  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
    return failure();

  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
  auto maybeOper = getConvOperationKind(reduceOp);
  if (!maybeOper.has_value())
    return failure();

  auto maybeKind = getCombinerOpKind(reduceOp);
  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {
    return failure();
  }

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1)
      return failure();
  } else if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
    return failure();

  return success();
}
static LogicalResult vectorizeLinalgOpPrecondition(
    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
        return llvm::is_contained(linalgOp.getShape(&operand), 0);
      }))
    return failure();

  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() &&
      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp,
                                                  flatten1DDepthwiseConv))) {
    LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    if (llvm::any_of(
            customPreconditions,
            [&](CustomVectorizationPrecondition customPrecondition) {
              return succeeded(
                  customPrecondition(&innerOp, vectorizeNDExtract));
            })) {
      continue;
    }
    if (!llvm::all_of(innerOp.getOperandTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
    if (!llvm::all_of(innerOp.getResultTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
  }

  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return vectorizeConvOpPrecondition(linalgOp);

  if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
                    [](AffineMap m) { return m.isProjectedPermutation(); })) {
    LDBG() << "precondition failed: not projected permutations";
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG() << "precondition failed: reduction preconditions";
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  if (!packOp.hasPureTensorSemantics())
    return failure();

  auto padValue = packOp.getPaddingValue();
  if (padValue && !matchPattern(padValue, m_Constant())) {
    LDBG() << "pad value is not constant: " << packOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG() << "inner_tiles must be constant: " << packOp;
    return failure();
  }

  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG() << "pad value is not constant: " << padOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();

  if (llvm::any_of(llvm::enumerate(padOp.getMixedLowPad()),
                   [&](const auto &en) {
                     OpFoldResult padValue = en.value();
                     unsigned pos = en.index();
                     std::optional<int64_t> pad = getConstantIntValue(padValue);
                     return (!pad.has_value() || pad.value() != 0) &&
                            resultTensorShape[pos] != 1;
                   })) {
    LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return success(isa<linalg::UnPackOp>(op));

  if (numOfScalableDims > 2)
    return failure();

  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG() << "Non-trailing reduction dim requested for scalable "
                "vectorization";
      return failure();
    }
    if (isa<linalg::MatmulOp>(op)) {
      LDBG()
          << "Scalable vectorization of the reduction dim in Matmul-like ops "
             "is not supported";
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    if (seenNonUnitParallel) {
      LDBG() << "Inner parallel dim not requested for scalable "
                "vectorization";
      return failure();
    }
    break;
  }
  }

  if (numOfScalableDims == 2) {
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG() << "Higher dim than the trailing reduction dim requested for "
                "scalable vectorization";
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::BatchMatmulOp>(op) ||
                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
                 hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (!hasVectorizationImpl(op))
    return failure();

  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case([&](linalg::LinalgOp linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case([&](tensor::PadOp padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case([&](linalg::PackOp packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case([&](linalg::UnPackOp unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case([&](tensor::InsertSliceOp sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default(failure());
}
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}

bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
}
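// Example driver usage of the entry point below (a sketch, not from this
// file; assumes a captured `rewriter` and a `linalgOp` with an 8x16
// iteration space, and that VectorizationResult exposes the replacement
// values as `replacements`):
//
//   FailureOr<VectorizationResult> vecRes =
//       linalg::vectorize(rewriter, linalgOp,
//                         /*inputVectorSizes=*/{8, 16},
//                         /*inputScalableVecDims=*/{false, false});
//   if (succeeded(vecRes))
//     rewriter.replaceOp(linalgOp, vecRes->replacements);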
FailureOr<VectorizationResult> mlir::linalg::vectorize(
    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
    bool createNamedContraction) {
  LDBG() << "Attempting to vectorize: " << *op;
  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
  LDBG() << "Input scalable vector dims: "
         << llvm::interleaved(inputScalableVecDims);

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG() << "Vectorization pre-conditions failed";
    return failure();
  }

  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims,
                               assumeDynamicDimsMatchVecSizes))) {
      LDBG() << "Vectorization state couldn't be initialized";
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](linalg::LinalgOp linalgOp) {
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG() << "Unsupported convolution can't be vectorized.";
              return failure();
            }

            if (createNamedContraction &&
                isa<ContractionOpInterface>(linalgOp.getOperation()))
              return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
                                                  results);

            LDBG()
                << "Vectorize generic by broadcasting to the canonical vector "
                   "shape";

            convertAffineApply(rewriter, linalgOp);

            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case([&](tensor::PadOp padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case([&](linalg::PackOp packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case([&](linalg::UnPackOp unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes,
                                             inputScalableVecDims, results);
          })
          .Case([&](tensor::InsertSliceOp sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp,
                                            inputVectorSizes, results);
          })
          .Default(failure());

  if (failed(vectorizeResult)) {
    LDBG() << "Vectorization failed";
    return failure();
  }

  return VectorizationResult{results};
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = vector::TransferReadOp::create(
      rewriter, loc, readType, copyOp.getSource(), indices,
      std::nullopt);
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                          ArrayRef<int64_t>());
    readValue =
        vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
  }
  Operation *writeValue = vector::TransferWriteOp::create(
      rewriter, loc, readValue, copyOp.getTarget(), indices);
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
struct PadOpVectorizationWithTransferReadPattern
    : public VectorizePadOpUserPattern<vector::TransferReadOp> {
  using VectorizePadOpUserPattern<
      vector::TransferReadOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    if (!padOp.hasZeroLowPad())
      return failure();
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getBaseMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
};
struct PadOpVectorizationWithTransferWritePattern
    : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
  using VectorizePadOpUserPattern<
      vector::TransferWriteOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    if (xferOp.getTransferRank() == 0)
      return failure();

    if (!padOp.hasZeroLowPad())
      return failure();
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    if (!trimPadding.hasZeroOffset())
      return failure();
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    if (!t1 || !t2)
      return false;
    if (t1.getRank() != t2.getRank())
      return false;

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    if (t1.getNumDynamicDims() == 0)
      return true;

    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      return false;
    }

    return true;
  }
};
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {

  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());
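/// Vectorizes tensor.insert_slice as a (masked) read of the source followed
/// by a (masked) write into the destination.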
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(sliceOp);

  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  Value padValue = getStaticPadVal(sliceOp);
  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
                                         rewriter.getZeroAttr(elemType));
  }

  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
    } else {
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  auto loc = sliceOp.getLoc();

  SmallVector<Value> readIndices(
      vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
  Value read = mlir::vector::createReadOrMaskedRead(
      rewriter, loc, source, vecType, padValue,
      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());

  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
                                              sliceOp.getDest(),
                                              writeIndices,
                                              inputVectorSizes.empty());

  newResults.push_back(write->getResult(0));
  return success();
}
struct PadOpVectorizationWithInsertSlicePattern
    : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
  using VectorizePadOpUserPattern<
      tensor::InsertSliceOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    if (!padOp.hasZeroLowPad())
      return failure();
    if (!insertOp.hasUnitStride())
      return failure();
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    SmallVector<Value> readIndices(
        vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
    auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
                                               vecType, padOp.getSource(),
                                               readIndices, padValue);

    auto writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
};
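// Typical usage of the population function below (a sketch; assumes a
// standard greedy pattern driver and a surrounding funcOp/ctx):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));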
void mlir::linalg::populatePadOpVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
  patterns.add<PadOpVectorizationWithTransferReadPattern,
               PadOpVectorizationWithTransferWritePattern,
               PadOpVectorizationWithInsertSlicePattern>(
      patterns.getContext(), baseBenefit.getBenefit() + 1);
}
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
           << ", second op: " << *secondOp;
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
             << ", second op: " << *secondOp;
      return true;
    }
  }
  return false;
}
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  if (xferOp.getMask())
    return failure();

  Value viewOrAlloc = xferOp.getBase();

  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  Value in = copyOp.getSource();

  auto vectorType = xferOp.getVectorType();
  Value res = vector::TransferReadOp::create(
      rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  if (xferOp.getMask())
    return failure();

  Value viewOrAlloc = xferOp.getBase();

  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  auto vector = xferOp.getVector();
  vector::TransferWriteOp::create(
      rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(SmallVector<bool>(
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
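/// Binds the leading dimensions of `shapedType` to the given integer
/// references, e.g. bindShapeDims(t, n, w, c) reads dims 0, 1, 2.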
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
#define MATCH_1D_CONV_POOL_OP(ConvOpTy)                                        \
  if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op))                \
    return convParams;

#undef MATCH_1D_CONV_POOL_OP

  return std::nullopt;
}
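/// Helper that builds the vector ops for 1-D convolutions and pooling ops,
/// given the strides/dilations matched by match1DConvPoolOp above.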
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
                                           LinalgOp linalgOp) {
    std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
    if (!convParams.has_value())
      return failure();

    int strideW = static_cast<int>(convParams->strides.front());
    int dilationW = static_cast<int>(convParams->dilations.front());
    return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
  }

  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    setConvOperationKind(reduceOp);

    reductionKind = maybeKind.value();
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;

    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      nSize = fSize = cSize = 0;
      bindShapeDims(resShapedType, wSize);
      bindShapeDims(rhsShapedType, kwSize);
      lhsShape = {(wSize + kwSize - 1)};
      rhsShape = {kwSize};
      break;
    case Conv1DOpOrder::Nwc:
      bindShapeDims(resShapedType, nSize, wSize, fSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        bindShapeDims(rhsShapedType, kwSize, cSize);
        break;
      case ConvOperationKind::Pool:
        bindShapeDims(rhsShapedType, kwSize);
        break;
      }
      lhsShape = {nSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      bindShapeDims(resShapedType, nSize, fSize, wSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
        break;
      case ConvOperationKind::Pool:
        bindShapeDims(rhsShapedType, kwSize);
        break;
      }
      lhsShape = {nSize, cSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }

    vector::TransferWriteOp write;
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    auto lhsType = VectorType::get(lhsShape, lhsEltType);
    auto rhsType = VectorType::get(rhsShape, rhsEltType);
    auto resType = VectorType::get(resShape, resEltType);
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
    SmallVector<Value> resPadding(resShape.size(), zero);

    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, lhsPadding,
        arith::getZeroConstant(rewriter, loc, lhsEltType));

    Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
      rhs = vector::TransferReadOp::create(
          rewriter, loc, rhsType, rhsShaped, rhsPadding,
          arith::getZeroConstant(rewriter, loc, rhsEltType));
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, resPadding,
        arith::getZeroConstant(rewriter, loc, resEltType));
    // The base vectorization case for channeled convolutions is input
    // {n,w,c}, weight {kw,c,f}, output {n,w,f}. For the Ncw case,
    // pre-transpose the operands into that form.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
      // {f, c, kw} -> {kw, c, f}; this is needed only for Conv.
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
      if (oper == ConvOperationKind::Conv)
        rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, permRes);
      break;
    }
    }
    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                                     kwSize, strideW, dilationW, wSizeStep,
                                     isSingleChanneled);
    // Extract rhs slices (only needed for Conv).
    if (oper == ConvOperationKind::Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                      wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} *
    // F{c, f}, an outer product for single-channel convolutions, or a simple
    // arith op for pooling.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case ConvOperationKind::Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case ConvOperationKind::Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    // Insert the computed result slices back into the result vector.
    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // For the Ncw case, transpose the result back: nwf -> nfw.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, perm);
      break;
    }
    }

    return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
                                           resPadding)
        .getOperation();
  }
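  // As a rough sketch (not verbatim output), a linalg.conv_1d_nwc_wcf with
  // kw = 3 and strideW = 1 lowers to: one vector.transfer_read per operand,
  // three vector.contract ops (one per kernel position, each consuming a
  // shifted lhs slice), and a single vector.transfer_write of the result.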
  /// Take `val` and widen it to have the same element type as `ty`.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();

    // Preserve the shape of `val` if it is a vector.
    Type dstType = dstElementType;
    if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
      dstType = shapedType.cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType))
      return arith::SIToFPOp::create(rewriter, loc, dstType, val);

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtFOp::create(rewriter, loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtSIOp::create(rewriter, loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  /// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = vector::ContractionOp::create(
        rewriter, loc, lhs, rhs, res,
        /*indexingExprs=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }
  /// Create an outer product: lhs{w} * rhs{1} -> res{w} for single-channel
  /// convolutions.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
                                          rhs, res, vector::CombiningKind::ADD);
  }
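  // For instance, with wSize = 4 this multiplies a vector<4xf32> lhs slice by
  // a scalar filter element and accumulates into the vector<4xf32> result
  // slice via the ADD combining kind.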
  /// Create a reduction: lhs{n, w, c} -> res{n, w, c}.
  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
                    Value res) {
    if (isPoolExt)
      lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
    return rewriter
        .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
        ->getResult(0);
  }
  /// Generate a vector implementation of a depthwise 1-D NWC convolution;
  /// kw is always unrolled.
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when both the channel dim is dynamic
      // and the scalable flag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow - 1) * sw + (kw - 1) * dw + 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});
    // Mask an xfer op with a `vector.create_mask` sized from the (possibly
    // dynamic) operand dims, if masking is needed.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };
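    // For a dynamic channel dim, each read below ends up wrapped roughly as
    // (a sketch, not verbatim output):
    //   %mask = vector.create_mask %n, %iw, %c : vector<3x4x[4]xi1>
    //   %0 = vector.mask %mask { vector.transfer_read ... } : ...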
    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
        arith::getZeroConstant(rewriter, loc, lhsEltType));
    auto *maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = vector::TransferReadOp::create(
        rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
        arith::getZeroConstant(rewriter, loc, rhsEltType));
    auto *maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
        arith::getZeroConstant(rewriter, loc, resEltType));
    auto *maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(vector::ExtractOp::create(
          rewriter, loc, maybeMaskedRhs->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };
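    // Worked example for the indexing: with wSize = 8 and strideW = 1,
    // wSizeStep == 8, so there is exactly one slice per kernel position and
    // linearIndex(kw, 0) == kw. With strideW > 1, wSizeStep == 1 and
    // linearIndex(kw, w) == kw * 8 + w enumerates the unrolled w slices.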
    // Note: the scalable flags are ignored here, as flattening combined with
    // scalable vectorization is not supported.
    SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute the contraction O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel
          // dimension).
          lhsVal = vector::ShapeCastOp::create(
              rewriter, loc, lhsTypeAfterFlattening,
              lhsVals[linearIndex(kw, w)]);
          resVal = vector::ShapeCastOp::create(
              rewriter, loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Re-expand the output vector (un-collapse the channel dimension).
          resVals[w] = vector::ShapeCastOp::create(
              rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
              resVals[w]);
        }
      }
    }
    // It is possible the FMA creation failed; in that case revert manually
    // to avoid leaving the IR in an invalid state.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0]; this does
    // not depend on kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }
    // Write back the full res vector of size {n, w, c} @ [0, 0, 0].
    Operation *resOut = vector::TransferWriteOp::create(
        rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c}       (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c}     (flatten = true)
  /// to MulAcc.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // There are two options for handling the filter:
      //  * shape_cast(broadcast(filter))
      //  * broadcast(shuffle(filter))
      // Opt for the option without shape_cast to simplify the codegen.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = vector::BroadcastOp::create(
        rewriter, loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);

    auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
    return arith::AddIOp::create(rewriter, loc, mul, res);
  }
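  // For example, flattening with wSizeStep = 2 and cSize = 4 repeats the
  // 4-element filter twice via the shuffle (indices 0,1,2,3,0,1,2,3) so it
  // lines up with the flattened {n, 2 * 4} input and output slices.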
  /// Entry point for non-channeled convolution:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }
  /// Entry point for NWC convolution:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }
  /// Entry point for NCW convolution:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }
  /// Entry point for NWC pooling:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }
  /// Entry point for NCW pooling:
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }
  /// Entry point for dilated (depthwise) NWC convolution:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr redOp;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;

  /// Classify the op as Conv or Pool based on the body of its reduction.
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // Will be a convolution if the feeder is a MulOp; if the feeder is a
      // cast of a block argument instead, this is a pooling with an
      // extension op.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // numBlockArguments == 2 means this is a pooling op.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
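  // For example, a max-pooling body of the form
  //   %e = arith.extf %arg0 : f16 to f32
  //   %m = arith.maximumf %e, %arg1 : f32
  // has exactly one block-argument operand on the reduction, with the other
  // operand fed by a cast: it is classified as Pool with
  // poolExtOp = arith.extf.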
};

/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
  if (failed(conv1dGen))
    return failure();
  auto res = conv1dGen->generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen->generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen->generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen->generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen->generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1-D convs are left. These can be vectorized using masks
  // and scalable vectors; at the moment the only dim that can be dynamic
  // (i.e. vectorized using masked vectors) is the channel dim.
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim; the
    // other vector dims are inferred from the op.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx = 0;
    if (isa<linalg::DepthwiseConv1DNwcWcOp>(*op))
      chDimIdx = 2;
    else if (isa<linalg::DepthwiseConv1DNcwCwOp>(*op))
      chDimIdx = 1;

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                        flatten1DDepthwiseConv);
}
/// Helper pattern that rewrites a LinalgOp with convolution semantics into
/// vector form.
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      // With memref semantics there is nothing to replace; just erase the op.
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};

void mlir::linalg::populateConvolutionVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
}
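// Typical usage is to add these patterns to a set and apply them greedily,
// e.g. (a sketch; `funcOp` and `ctx` are placeholders):
//   RewritePatternSet patterns(ctx);
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));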