#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-vectorization"
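
/// Forward declaration of the convolution vectorizer defined later in this
/// file: tries to rewrite a conv/pooling LinalgOp into vector ops, optionally
/// flattening the channel dim of 1-D depthwise convs.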
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  block.walk([&](OpType op) {
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  if (isSingleChanneled) {
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input,

  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(vector::ExtractOp::create(
                        int64_t wSizeStep, bool isSingleChanneled) {
  if (isSingleChanneled) {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(

                                 bool isSingleChanneled) {
  if (isSingleChanneled) {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res,
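
/// State shared across the vectorization of one LinalgOp. It records the
/// canonical vector shape (user-provided vector sizes, or the op's static
/// loop ranges when none are given), the matching scalable-dim flags, the
/// static and SSA-value sizes of the iteration space, and a per-masking-map
/// cache of already-created masks.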
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          bool assumeDynamicDimsMatchVecSizes = false);

  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }

      std::optional<AffineMap> maybeIndexingMap = std::nullopt);

  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,

  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           std::optional<AffineMap> maybeMaskingMap);

  bool isValidMaskingMap(AffineMap maskingMap) {

  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {

  SmallVector<int64_t> iterSpaceStaticSizes;
  SmallVector<Value> iterSpaceValueSizes;
  SmallVector<int64_t> canonicalVecShape;
  SmallVector<bool> scalableVecDims;
  OpBuilder::InsertionGuard rewriterGuard;
  bool assumeDynamicDimsMatchVecSizes = false;
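
/// Precomputes an SSA value for each iteration-space dimension: a constant
/// for statically known sizes, otherwise a tensor.dim/memref.dim on an
/// operand that maps to that dimension.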
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
          rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));

    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
        linalgOp.hasPureTensorSemantics()
            ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
            : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
    iterSpaceValueSizes.push_back(dynamicDim);

                                       bool assumeDimsMatchVec) {
  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;

  if (!inputVectorSizes.empty()) {
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);

  if (ShapedType::isDynamicShape(canonicalVecShape))

  initIterSpaceStaticSizes(linalgOp);

  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
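
/// Returns the mask to apply to `opToMask`, creating a vector.create_mask
/// with the iteration-space sizes as upper bounds when needed. A null Value
/// means no mask is required. Results are cached per masking map.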
Value VectorizationState::getOrCreateMaskFor(
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
      maybeMaskingMap ? *maybeMaskingMap
                          linalgOp.getNumLoops(), rewriter.getContext());

  LDBG() << "Masking map: " << maskingMap;

  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG() << "Reusing mask: " << mask;

  SmallVector<int64_t> permutedStaticSizes =
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);

  if (permutedStaticSizes == maskShape) {
    LDBG() << "Masking is not needed for masking map: " << maskingMap;
    activeMaskCache[maskingMap] = Value();

  if (assumeDynamicDimsMatchVecSizes) {
    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
          return std::get<0>(it) == ShapedType::kDynamic
                     : std::get<0>(it) == std::get<1>(it);
      LDBG() << "Dynamic + static dimensions match vector sizes, masking is not "
      activeMaskCache[maskingMap] = Value();

  SmallVector<Value> upperBounds =
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
                                            maskType, upperBounds);
  LDBG() << "Creating new mask: " << mask;
  activeMaskCache[maskingMap] = mask;
                                       std::optional<AffineMap> maybeIndexingMap) {
  LDBG() << "Trying to mask: " << *opToMask;

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

    LDBG() << "No mask required";
    if (assumeDynamicDimsMatchVecSizes) {
          .Case<vector::TransferReadOp, vector::TransferWriteOp>(
            LDBG() << "Assuming dynamic dimensions match vector sizes and "
                      "setting their in-bounds to true!";
            ShapedType xferType = xferOp.getShapedType();
            for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
              auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
              unsigned pos = dimExpr.getPosition();
              if (xferType.isDynamicDim(pos))
                inBoundsMap[i] = true;
            }
            xferOp.setInBoundsAttr(

  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))

  LDBG() << "Masked operation: " << *maskOp;

         "expected projected permutation");
  assert(res.getNumDims() ==
             (res.getNumResults() - res.getNumOfZeroResults()) &&
         "expected reindexed map with same number of dims and results");
std::optional<vector::CombiningKind>
  using ::mlir::vector::CombiningKind;
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default(std::nullopt);
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());

  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)

  return combinerOps[0];

  auto dstVecType = dyn_cast<VectorType>(dstType);
  if (dstVecType.getRank() == 0)

  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);

  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return vector::MultiDimReductionOp::create(
      b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);

  return llvm::to_vector(

  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
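
/// Builds a vector.transfer_write of `value` into `outputOperand` of a
/// LinalgOp, broadcasting the value to the canonical vector shape first if
/// needed and masking the write according to the operand's indexing map.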
                                  VectorizationState &state) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
  auto vectorType = state.getCanonicalVecType(

  if (vectorType.getRank() > 0) {
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(
        rewriter, loc, value, outputOperand->get(), indices, writeMap);
    if (!isa<VectorType>(value.getType()))
      value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(rewriter, loc, value,

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());

  LDBG() << "vectorized op: " << *write;

using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;
                       const IRMapping &bvm, VectorizationState &state,
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
        linalgOp.getDpsInitOperand(output.index()), state);
    newResults.push_back(newResult);

                                                VectorizationState &state,
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  auto loc = indexOp.getLoc();

  auto dim = indexOp.getDim();
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);

  if (dim == targetShape.size() - 1)

      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());

  auto broadCastOp = vector::BroadcastOp::create(
      rewriter, loc,
      state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
      vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
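
// Preconditions for vectorizing tensor.extract, checked below: 0-D and 1-D
// extracts are always supported; n-D extracts additionally require the
// vectorize-nd-extract flag, and index/result element types must be valid
// vector element types.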
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)

  if (!extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))

  if (!llvm::all_of(extractOp->getResultTypes(),
                    VectorType::isValidElementType)) {
                                            VectorizationState &state,
                                            tensor::ExtractOp extractOp,
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
        tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),

    offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
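  // Net effect of the loop above: the n-D extract indices are linearized
  // row-major into a single vector of offsets,
  // offset = (offset * dimSize_i) + index_i.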
  assert((linalgOp.hasDynamicShape() ||
          llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
         "For statically shaped Linalg Ops, only one "
         "non-unit loop dim is expected");
  assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (isa<arith::ConstantOp>(ancestor))

  for (auto op : ancestor->getOperands())
                              bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))

  for (auto op : ancestor->getOperands())
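
/// Classifies the memory-access pattern of a tensor.extract inside a linalg
/// body as one of: gather load (varying leading indices), scalar broadcast
/// load (all indices loop-invariant), or contiguous load (trailing index
/// increments with the innermost non-unit loop).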
                                LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  if (inputShape.getShape().empty())

  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);

  if (!isOutput1DVector)

  bool leadingIdxsLoopInvariant = true;

  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)

  if (!leadingIdxsLoopInvariant) {
    LDBG() << "Found gather load: " << extractOp;

  auto extractOpTrailingIdx = indices.back();

  if (leadingIdxsLoopInvariant &&
    LDBG() << "Found scalar broadcast load: " << extractOp;

  bool foundIndexOp = false;
                                          foundIndexOp, resType);

  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG() << "Found contiguous load: " << extractOp;

  LDBG() << "Found gather load: " << extractOp;
static VectorizationHookResult
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);

  auto loc = extractOp.getLoc();

  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = arith::ConstantOp::create(
  auto passThruConstantOp = arith::ConstantOp::create(

      extractOp.getIndices().size(),

    Operation *gatherOp = vector::GatherOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG() << "Vectorised as gather load: " << extractOp;

  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
      transferReadIdxs.push_back(idx);

    auto indexAs1dVector = vector::ShapeCastOp::create(
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
    transferReadIdxs.push_back(
        vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();

    auto transferReadOp = vector::TransferReadOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
        std::nullopt, permutationMap, inBounds);

    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
    auto allTrue = vector::ConstantMaskOp::create(
    auto *maskedReadOp =

    LDBG() << "Vectorised as scalar broadcast load: " << extractOp;

      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(

  auto transferReadOp = vector::TransferReadOp::create(
      rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
      std::nullopt, permutationMap, inBounds);

  LDBG() << "Vectorised as contiguous load: " << extractOp;
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
      (outputType && reduceType.getShape() == outputType.getShape()))
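
/// Generic single-op vectorization driver: custom hooks (yield, index,
/// extract) run first; constants are cloned as-is; reduction combiners are
/// turned into multi-dim reductions; all remaining ops are cloned with their
/// operands broadcast to the largest-ranked operand vector type.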
static VectorizationHookResult
  LDBG() << "vectorize op " << *op;

  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {

  if (isa<arith::ConstantOp, func::ConstantOp>(op))
          rewriter.clone(*op)};

    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    reductionOperands.push_back(std::make_pair(reduceValue, operand));

  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
  }

  VectorType firstMaxRankedType;
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;

    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(vecOperand);

    resultTypes.push_back(
        ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                          firstMaxRankedType.getScalableDims())
  LDBG() << "Vectorizing operation as linalg generic";
  Block *block = linalgOp.getBlock();

  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)

  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
    }

    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    VectorType readType;
    if (linalgOp.isDpsInput(opOperand)) {
      readType = state.getCanonicalVecType(elemType);
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));

    Operation *read = vector::TransferReadOp::create(
        rewriter, loc, readType, opOperand->get(), indices,
        std::nullopt, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);

    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
          cast<vector::TransferReadOp>(maskOp.getMaskableOp())

    if (readType.getRank() == 0)
      readValue = vector::ExtractOp::create(rewriter, loc, readValue,

    LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
           << "): " << readValue;
    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  hooks.push_back(vectorizeYield);
  hooks.push_back(vectorizeIndex);
  hooks.push_back(vectorizeExtract);

    LDBG() << "failed to vectorize: " << op;

      state.maskOperation(rewriter, result.newOp, linalgOp);
  LDBG() << "New vector op: " << *maybeMaskedOp;
  if (ShapedType::isDynamicShape(destShape))

  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
      cstMaskSizes.push_back(*intSize);

  if (cstMaskSizes.size() != maskShape.size())

  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
      cstWriteIdxs.push_back(intVal.getSExtValue());

  if (cstWriteIdxs.size() != destShape.size())

  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
    if (maskShape[i] > destShape[rankDiff + i] ||
        destShape[rankDiff + i] <
            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
    bool useInBoundsInsteadOfMasking = false) {

  ShapedType destType = cast<ShapedType>(dest.getType());
  int64_t destRank = destType.getRank();
  auto destShape = destType.getShape();

  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
  int64_t vecToStoreRank = vecToStoreType.getRank();
  auto vecToStoreShape = vecToStoreType.getShape();

  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
  if (useInBoundsInsteadOfMasking) {
    for (unsigned i = 0; i < vecToStoreRank; i++)
          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
  }

  assert((writeIndices.empty() ||
          writeIndices.size() == static_cast<size_t>(destRank)) &&
         "Invalid number of write indices!");
  if (writeIndices.empty()) {
    writeIndices.assign(destRank, zero);
  }

  Operation *write = vector::TransferWriteOp::create(builder, loc,

  if (useInBoundsInsteadOfMasking)

  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))

  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
                                       vecToStoreType.getScalableDims());

  SmallVector<OpFoldResult> destSizes =
      isa<MemRefType>(dest.getType())
  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,

  Value maskForWrite =
      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
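
// Helper below: collapses a vector type according to a reassociation map.
// Each dim group folds to the product of its sizes; a group's scalable flag
// is the OR of its members, and at most one scalable dim may participate.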
  assert(type.getNumScalableDims() < 2 &&
         "Collapsing more than 1 scalable dim is not supported ATM");

  auto shape = type.getShape();
  auto scalableFlags = type.getScalableDims();

  unsigned currentDim = 0;
    unsigned dim = m.getNumResults();
    for (unsigned d = 0; d < dim; ++d) {
      size *= shape[currentDim + d];
      flag |= scalableFlags[currentDim + d];
    }
    newShape.push_back(size);
    newScalableFlags.push_back(flag);

  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
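
/// Vectorizes linalg.pack as: masked vector.transfer_read of the source,
/// vector.shape_cast to the pre-transpose write shape, vector.transpose
/// into the packed layout, and a vector.transfer_write into the dest.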
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == packOp.getDestRank() &&
           "Invalid number of input vector sizes!");
  }

  OpBuilder::InsertionGuard g(rewriter);

  Location loc = packOp.getLoc();
  std::optional<Value> padValue = packOp.getPaddingValue()
                                      ? std::optional(packOp.getPaddingValue())
                                      : std::nullopt;

  SmallVector<int64_t> destShape =
      SmallVector<int64_t>(packOp.getDestType().getShape());

  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;

  bool useInBoundsInsteadOfMasking = false;
  if (writeVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(destShape))
                                         "unable to infer vector sizes");

    writeVectorSizes = destShape;
    useInBoundsInsteadOfMasking = true;
  }

  PackingMetadata packMetadata;
  SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);

  auto preTransposeWriteVecType = VectorType::get(
      preTransposeWriteVecSizes, packOp.getType().getElementType());

      preTransposeWriteVecType,
          rewriter.getContext(), packMetadata.reassociations)));

      rewriter, loc, packOp.getSource(), readVecType, padValue,
      useInBoundsInsteadOfMasking);

  auto shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, preTransposeWriteVecType, maskedRead);

  auto transposeOp = vector::TransposeOp::create(
      rewriter, loc, shapeCastOp.getResult(), destPermutation);

      rewriter, loc, transposeOp.getResult(), packOp.getDest());
  newResults.push_back(write->getResult(0));
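
/// Vectorizes linalg.unpack as the inverse sequence: masked read of the
/// packed source, vector.transpose back to the unpacked order,
/// vector.shape_cast to collapse the tiled dims, and a write into the dest.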
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
           "Invalid number of input vector sizes!");
    assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
           "Incompatible number of vector sizes and vector scalable flags!");
  }

  OpBuilder::InsertionGuard g(rewriter);

  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;

  Location loc = unpackOp->getLoc();

  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);

  if (inputVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(sourceShape))
                                         "Unable to infer vector sizes!");

    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
    useInBoundsInsteadOfMasking = true;
  }

  VectorType readVecType =
      VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
                      readScalableVectorFlags);
      rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
      useInBoundsInsteadOfMasking);

  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
  vector::TransposeOp transposeOp = vector::TransposeOp::create(
      rewriter, loc, readResult, lastDimToInsertPosPerm);

      transposeOp.getType(),
          rewriter.getContext(), packMetadata.reassociations)));
  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, collapsedVecType, transposeOp->getResult(0));

      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
      {}, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
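
/// Vectorizes tensor.pad as a masked read padded with the (constant) pad
/// value, written into a tensor.empty of the reified result shape.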
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  OpBuilder::InsertionGuard g(rewriter);

  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  assert(succeeded(status) && "failed to reify result shapes");
  auto readType = VectorType::get(inputVectorSizes, padValue.getType());
      rewriter, loc, padOp.getSource(), readType, padValue,

  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
                                       padOp.getResultType().getElementType());
  newResults.push_back(write->getResult(0));
static LogicalResult reductionPreconditions(LinalgOp op) {
    LDBG() << "reduction precondition failed: no reduction iterator";

  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);

      LDBG() << "reduction precondition failed: reduction detection failed";

vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG() << "Vectorization of flattened convs with dynamic shapes is not "

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";

  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
              "channel dim can be dynamic";

vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

    return reductionPreconditions(op);

      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(

  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {

  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
      unpackOp.getSourceType().hasStaticShape())

  if (!inputVectorSizes.empty() &&
      (inputVectorSizes.size() != unpackOp.getSourceRank())) {
    LDBG() << "Incorrect number of input vector sizes";

          unpackOp.getSourceType().getShape(), inputVectorSizes))) {
    LDBG() << "Invalid vector sizes for the read operation";

vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {

  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))

  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG() << "Failed to get a pad value for out-of-bounds read access";
vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
                             SmallVectorImpl<Value> &newResults) {
  Location loc = linalgOp.getLoc();
  MLIRContext *ctx = linalgOp.getContext();

  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))

  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
    LDBG() << "Failed to determine contraction combining kind.";

  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
    LDBG() << "Contractions with broadcasts are not supported.";

  SmallVector<Value> vecOperands;
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    VectorType readType =
        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
        rewriter, loc, opOperand.get(), readType,
        arith::getZeroConstant(rewriter, loc, elemType),
    vecOperands.push_back(read);
  }

  SmallVector<Attribute> iterAttrs;
  auto iterators = linalgOp.getIteratorTypesArray();
  for (utils::IteratorType iter : iterators) {
    auto vecIter = iter == utils::IteratorType::parallel
                       ? vector::IteratorType::parallel
                       : vector::IteratorType::reduction;
    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
  }

  Operation *contractOp = vector::ContractionOp::create(
      rewriter, loc, vecOperands[0], vecOperands[1], vecOperands[2],
      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);

      rewriter, loc, contractOp->getResult(0), outOperand->get());

  newResults.push_back(write->getResult(0));
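
// Conv/pool discrimination used below: a reduction fed by (a cast of) a
// block argument is classified as pooling, while a multiplication (or `and`
// for i1) feed is classified as convolution.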
enum class ConvOperationKind { Conv, Pool };

static bool isCastOfBlockArgument(Operation *op) {

static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp)) {
      return ConvOperationKind::Pool;
    }

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            if (isa<BlockArgument>(v))
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
      return std::nullopt;

    return ConvOperationKind::Conv;
    return ConvOperationKind::Pool;
    return std::nullopt;
static bool isSupportedPoolKind(vector::CombiningKind kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:

static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));

  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))

  auto maybeOper = getConvOperationKind(reduceOp);
  if (!maybeOper.has_value())

  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
static LogicalResult vectorizeLinalgOpPrecondition(
    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {

  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
        return llvm::is_contained(linalgOp.getShape(&operand), 0);

  if (!inputVectorSizes.empty() &&

  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
                                        linalgOp, flatten1DDepthwiseConv))) {
    LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
            customPreconditions,
              customPrecondition(&innerOp, vectorizeNDExtract));

    if (!llvm::all_of(innerOp.getOperandTypes(),
                      VectorType::isValidElementType)) {
    if (!llvm::all_of(innerOp.getResultTypes(),
                      VectorType::isValidElementType)) {

  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return vectorizeConvOpPrecondition(linalgOp);

    LDBG() << "precondition failed: not projected permutations";

  if (failed(reductionPreconditions(linalgOp))) {
    LDBG() << "precondition failed: reduction preconditions";
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
    LDBG() << "pad value is not constant: " << packOp;

  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
          resultTensorShape.take_front(packOp.getSourceRank()),

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
    LDBG() << "inner_tiles must be constant: " << packOp;

vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
    LDBG() << "pad value is not constant: " << padOp;

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();

  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
        Value padValue = en.value();
        unsigned pos = en.index();
        std::optional<int64_t> pad = getConstantIntValue(padValue);
        return (!pad.has_value() || pad.value() != 0) &&
               resultTensorShape[pos] != 1;
    LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)

  auto linalgOp = dyn_cast<LinalgOp>(op);

    return success(isa<linalg::UnPackOp>(op));

  if (numOfScalableDims > 2)

  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
  }

  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG() << "Non-trailing reduction dim requested for scalable "

    if (isa<linalg::MatmulOp>(op)) {
      LDBG() << "Scalable vectorization of the reduction dim in Matmul-like ops "

  case utils::IteratorType::parallel: {
    if (seenNonUnitParallel) {
      LDBG() << "Inner parallel dim not requested for scalable "

  if (numOfScalableDims == 2) {
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG() << "Higher dim than the trailing reduction dim requested for "

    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))

      isa<linalg::BatchMatmulOp>(op) ||
      isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
      isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
      isa<linalg::BatchMmt4DOp>(op) ||
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {

  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))

      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             flatten1DDepthwiseConv);
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      .Case<linalg::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      .Case<linalg::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      .Default(failure());
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    auto expanded = affine::expandAffineExpr(
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));

  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
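
/// Main vectorization entry point: verifies preconditions, initializes the
/// VectorizationState for LinalgOps, then dispatches by op kind
/// (convolution, named contraction, generic linalg, pad, pack, unpack,
/// insert_slice) and returns the vectorized results.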
    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
    bool createNamedContraction) {
  LDBG() << "Attempting to vectorize: " << *op;
  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
  LDBG() << "Input scalable vector dims: "
         << llvm::interleaved(inputScalableVecDims);

                                  flatten1DDepthwiseConv))) {
    LDBG() << "Vectorization pre-conditions failed";

  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims,
                               assumeDynamicDimsMatchVecSizes))) {
      LDBG() << "Vectorization state couldn't be initialized";

  SmallVector<Value> results;
  auto vectorizeResult =
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());

              LDBG() << "Unsupported convolution can't be vectorized.";

            if (createNamedContraction &&
                isa<ContractionOpInterface>(linalgOp.getOperation()))
              return vectorizeAsLinalgContraction(rewriter, state, linalgOp,

            LDBG() << "Vectorize generic by broadcasting to the canonical vector "
            convertAffineApply(rewriter, linalgOp);

          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
          .Case<linalg::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
          .Case<linalg::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputScalableVecDims, results);
          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
          .Default(failure());

  if (failed(vectorizeResult)) {
    LDBG() << "Vectorization failed";

  return VectorizationResult{results};
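
/// Vectorizes memref.copy between statically shaped memrefs as one
/// full-shape transfer_read / transfer_write pair, with an extract plus
/// broadcast for the 0-D case.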
                                memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())

  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = vector::TransferReadOp::create(
      rewriter, loc, readType, copyOp.getSource(), indices,

  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                          ArrayRef<int64_t>());
        vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
  }
  Operation *writeValue = vector::TransferWriteOp::create(
      rewriter, loc, readValue, copyOp.getTarget(), indices,
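
/// Base class for patterns that rewrite one kind of user of tensor.pad:
/// matchAndRewrite forwards every user of type OpTy to the derived class's
/// rewriteUser hook and succeeds if at least one user was rewritten.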
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();

  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
struct PadOpVectorizationWithTransferReadPattern
    : public VectorizePadOpUserPattern<vector::TransferReadOp> {
  using VectorizePadOpUserPattern<
      vector::TransferReadOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    if (!padOp.hasZeroLowPad())
    auto padValue = padOp.getConstantPaddingValue();
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    xferOp->setAttr(xferOp.getInBoundsAttrName(),
    xferOp.getBaseMutable().assign(padOp.getSource());
    xferOp.getPaddingMutable().assign(padValue);
struct PadOpVectorizationWithTransferWritePattern
    : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
  using VectorizePadOpUserPattern<
      vector::TransferWriteOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    if (xferOp.getTransferRank() == 0)

    if (!padOp.hasZeroLowPad())
    auto padValue = padOp.getConstantPaddingValue();
    if (!xferOp->hasOneUse())
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding.hasZeroOffset())
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());

    if (t1.getRank() != t2.getRank())

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
    }

    if (t1.getNumDynamicDims() == 0)

    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (!t1.isDynamicDim(i))
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);

      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
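
// The chain of checks below walks producers (vector.broadcast, linalg.fill,
// tensor.generate, vector.transfer_write, tensor.insert_slice) looking for a
// scalar value that effectively pads the tensor.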
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))

  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {

  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))

  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))

                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {

  OpBuilder::InsertionGuard g(rewriter);

  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

    auto elemType = sourceType.getElementType();
    padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,

  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      vecShape.push_back(resultType.getDimSize(rankDiff + i));

  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  auto loc = sliceOp.getLoc();

  SmallVector<Value> readIndices(
      rewriter, loc, source, vecType, padValue,
      inputVectorSizes.empty());

      writeIndices, inputVectorSizes.empty());

  newResults.push_back(write->getResult(0));
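
/// Rewrites tensor.insert_slice(tensor.pad(x)) as an out-of-bounds
/// transfer_read of x padded with the constant pad value, followed by an
/// in-bounds transfer_write into the insert destination, provided the low
/// pad is zero, strides are unit, and the padded shape is static.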
struct PadOpVectorizationWithInsertSlicePattern
    : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
  using VectorizePadOpUserPattern<
      tensor::InsertSliceOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    if (!padOp.hasZeroLowPad())
    if (!insertOp.hasUnitStride())
    auto padValue = padOp.getConstantPaddingValue();
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
    if (insertOp.getDest() == padOp.getResult())

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
        llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
          return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);

    SmallVector<Value> readIndices(
    auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
                                               vecType, padOp.getSource(),
                                               readIndices, padValue);

        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});
    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
  patterns.add<PadOpVectorizationWithTransferReadPattern,
               PadOpVectorizationWithTransferWritePattern,
               PadOpVectorizationWithInsertSlicePattern>(
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
    LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
           << ", second op: " << *secondOp;

  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)

      LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
             << ", second op: " << *secondOp;
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
        return memref::SubViewOp();
      subViewOp = newSubViewOp;

    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  if (xferOp.getMask())

  Value viewOrAlloc = xferOp.getBase();

  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);

  Value subView = subViewOp.getResult();

  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))

  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
      maybeFillOp = newFillOp;

  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
                                       "padding value does not match fill");

  Value in = copyOp.getSource();

  auto vectorType = xferOp.getVectorType();
  Value res = vector::TransferReadOp::create(
      rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
          SmallVector<bool>(vectorType.getRank(), false)));

    rewriter.eraseOp(maybeFillOp);
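
// The write-side twin of the pattern above: a vector.transfer_write into a
// subview followed by a memref.copy is folded into a transfer_write directly
// on the copy target, provided no interleaved uses exist.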
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {

  if (xferOp.getMask())

  Value viewOrAlloc = xferOp.getBase();

  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);

  Value subView = subViewOp.getResult();

  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))

  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  auto vector = xferOp.getVector();
  vector::TransferWriteOp::create(
      rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
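
/// Generator for vectorizing 1-D convolution and pooling ops: it captures
/// the lhs (input), rhs (filter), and res (output) operands together with
/// their shaped types, reads the stride and dilation from the op's
/// attributes (defaulting both to 1), and emits transfer_reads, per-window
/// compute ops, and a final transfer_write.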
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    setConvOperationKind(reduceOp);

    reductionKind = maybeKind.value();

    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3568 int64_t nSize, wSize, cSize, kwSize, fSize;
3569 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3571 switch (conv1DOpOrder) {
3574 nSize = fSize = cSize = 0;
3576 bindShapeDims(resShapedType, wSize);
3578 bindShapeDims(rhsShapedType, kwSize);
3581 (wSize + kwSize - 1)};
3582 rhsShape = {kwSize};
3587 bindShapeDims(resShapedType, nSize, wSize, fSize);
3589 case ConvOperationKind::Conv:
3591 bindShapeDims(rhsShapedType, kwSize, cSize);
3593 case ConvOperationKind::Pool:
3595 bindShapeDims(rhsShapedType, kwSize);
3603 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3607 case ConvOperationKind::Conv:
3608 rhsShape = {kwSize, cSize, fSize};
3610 case ConvOperationKind::Pool:
3611 rhsShape = {kwSize};
3614 resShape = {nSize, wSize, fSize};
3618 bindShapeDims(resShapedType, nSize, fSize, wSize);
3620 case ConvOperationKind::Conv:
3622 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3624 case ConvOperationKind::Pool:
3626 bindShapeDims(rhsShapedType, kwSize);
3630 lhsShape = {nSize, cSize,
3634 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3637 case ConvOperationKind::Conv:
3638 rhsShape = {fSize, cSize, kwSize};
3640 case ConvOperationKind::Pool:
3641 rhsShape = {kwSize};
3644 resShape = {nSize, fSize, wSize};
3648 vector::TransferWriteOp write;
3654 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3656 Type lhsEltType = lhsShapedType.getElementType();
3657 Type rhsEltType = rhsShapedType.getElementType();
3658 Type resEltType = resShapedType.getElementType();
3659 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3660 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3661 auto resType = VectorType::get(resShape, resEltType);
3663 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3664 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3665 SmallVector<Value> resPadding(resShape.size(), zero);
    // Read the whole lhs, rhs and res in one shot (with zero padding).
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, lhsPadding,
        arith::getZeroConstant(rewriter, loc, lhsEltType));
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
      rhs = vector::TransferReadOp::create(
          rewriter, loc, rhsType, rhsShaped, rhsPadding,
          arith::getZeroConstant(rewriter, loc, rhsEltType));
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, resPadding,
        arith::getZeroConstant(rewriter, loc, resEltType));
    // The base vectorization case for channeled convolution is input:
    // {n, w, c}, weight: {kw, c, f}, output: {n, w, f}. To reuse that case,
    // pre-transpose the non-base layouts.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // {n, c, w} -> {n, w, c}
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
      // {f, c, kw} -> {kw, c, f}
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
      // This is needed only for Conv.
      if (oper == ConvOperationKind::Conv)
        rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
      // {n, f, w} -> {n, w, f}
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, permRes);
      break;
    }
    }
    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                                     kwSize, strideW, dilationW, wSizeStep,
                                     isSingleChanneled);
    // This is needed only for Conv.
    if (oper == ConvOperationKind::Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                      wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
    // an outer product for single-channel convolution, or a plain arith op
    // for pooling.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case ConvOperationKind::Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case ConvOperationKind::Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }
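    // E.g. with kwSize = 3, wSize = 8 and strideW == 1 (so wSizeStep == 8),
    // linearIndex(kw, 0) == kw: there is exactly one batched w-slice per
    // kernel position, and the inner loop runs once per kw.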
    // Write back the res slices. This does not depend on kw.
    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // The base vectorization case produces output in {n, w, f} order; post
    // transpose when the current case is Ncw.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, no transpose necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // {n, w, f} -> {n, f, w}
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, perm);
      break;
    }
    }

    return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
                                           resPadding)
        .getOperation();
  }
  // Take a value and widen it to have the same element type as `ty`.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return arith::SIToFPOp::create(rewriter, loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtFOp::create(rewriter, loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtSIOp::create(rewriter, loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = vector::ContractionOp::create(
        rewriter, loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }
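  // For f32 operands this materializes roughly as
  //   vector.contract {iterator_types = ["parallel", "parallel", "parallel",
  //                    "reduction"], kind = #vector.kind<add>, ...}
  //     %lhs, %rhs, %res
  //     : vector<1x8x3xf32>, vector<3x4xf32> into vector<1x8x4xf32>
  // with the {n, w, c} x {c, f} -> {n, w, f} indexing maps listed above.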
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
                                          rhs, res, vector::CombiningKind::ADD);
  }
  // Create a reduction: lhs{n, w, c} -> res{n, w, c}, building the pooling
  // extension cast (if any) and the combiner op by name.
  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
                    Value res) {
    if (isPoolExt)
      lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
    return rewriter
        .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
        ->getResult(0);
  }
  /// Generate a vector implementation for:
  ///   Op def: (     n,     w,     c,    kw)
  ///    Iters: ({Par(), Par(), Par(), Red()})
  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  /// kw is always unrolled.
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when the channel dim is dynamic and
      // the scalable flag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // w is unrolled (i.e. wSizeStep == tiling of w) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;
    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow - 1) * sw + (kw - 1) * dw + 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});
    // Masked xfer op getters branch on whether the op is masked, so build the
    // mask (when needed) right after creating each transfer op.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
      Value maskOp =
          vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };
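    // E.g. with a dynamic channel dim vectorized as [4] (scalable), the
    // transfer ops below end up wrapped roughly as
    //   %mask = vector.create_mask %c1, %c8, %ch : vector<1x8x[4]xi1>
    //   vector.mask %mask { vector.transfer_read ... }
    // where %ch is the runtime channel size.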
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
        arith::getZeroConstant(rewriter, loc, lhsEltType));
    auto *maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    Value rhs = vector::TransferReadOp::create(
        rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
        arith::getZeroConstant(rewriter, loc, rhsEltType));
    auto *maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
        arith::getZeroConstant(rewriter, loc, resEltType));
    auto *maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};

    // Extract lhs slices of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slices of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(
          vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
                                    /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slices of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }
    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Note: for the flattened conv, the shapes below only hold when
    // strideW == 1 (i.e. wSizeStep == wSize).
    SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);
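    // E.g. with nSize = 1, wSizeStep = 4 and cSize = 8, each {1, 4, 8} slice
    // is shape_cast to {1, 32} before the multiply-accumulate and cast back
    // to {1, 4, 8} afterwards.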
    // Compute the contraction O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel dim).
          lhsVal =
              vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
                                          lhsVals[linearIndex(kw, w)]);
          resVal = vector::ShapeCastOp::create(
              rewriter, loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dim).
          resVals[w] = vector::ShapeCastOp::create(
              rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
              resVals[w]);
        }
      }
    }
    // It is possible we failed to create the FMA.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }
    // Write back res slices of size {n, wSizeStep, c} @ [0, w, 0]. This does
    // not depend on kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back the res slice of size {n, w, c} @ [0, 0, 0].
    Operation *resOut = vector::TransferWriteOp::create(
        rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
  /// to MulAcc.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: This logic won't work for scalable vectors. For this reason,
      // "flattening" is not supported when shapes are dynamic.

      // Replicate the filter values so they line up with the flattened
      // {n, w * c} layout; use shuffle + broadcast to avoid a shape_cast.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = vector::BroadcastOp::create(
        rewriter, loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);

    auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
    return arith::AddIOp::create(rewriter, loc, mul, res);
  }
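  // E.g. with rhsSize = 4 and resSize = 8, the shuffle above yields the
  // repeated filter [0, 1, 2, 3, 0, 1, 2, 3], lining the filter up with the
  // flattened {n, w * c} operand layout.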
  /// Entry point for non-channeled convolution:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }
  /// Entry point that matches the common NWC form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }
  /// Entry point that matches the NCW form:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }
  /// Entry point that matches NWC pooling:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }
  /// Entry point that matches NCW pooling:
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }
  /// Entry point that matches the depthwise NWC form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
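  // E.g. `linalg.depthwise_conv_1d_nwc_wc` carries the indexing maps
  //   (n, w, c, kw) -> (n, strideW * w + dilationW * kw, c)  // input
  //   (n, w, c, kw) -> (kw, c)                               // filter
  //   (n, w, c, kw) -> (n, w, c)                             // output
  // which is exactly the layout matched above.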
private:
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr redOp;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;
  // Determine whether `reduceOp` implements a convolution or a pooling.
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // Will be a convolution if the feeder is a MulOp. A strength-reduced
      // version of MulOp for i1 type is AndOp, which is also supported.
      // Otherwise, it can be pooling.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // numBlockArguments == 2 means this is a pooling op.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
};
/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left; these can be vectorized using masks
  // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
  // vectorized with masking) is the channel dim (the trailing dim).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};

void mlir::linalg::populateConvolutionVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
}