#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-vectorization"
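
/// Forward declaration: tries to vectorize `convOp` as a 1-D convolution or
/// pooling. The generic vectorization driver below dispatches to this path
/// for ops implementing `ConvolutionOpInterface`.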
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}
/// Helper function to extract the input slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
            strides));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
    // for channeled convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            sizes, strides));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after the filter is unrolled
/// along kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
  // non-channeled convolution] @ [kw].
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(vector::ExtractOp::create(rewriter, loc, filter, kw));
  }
  return result;
}
/// Helper function to extract the result slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice of size {wSizeStep} @ [w] for non-channeled
    // convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
          strides));
    }
  } else {
    // Extract res slice of size {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
          strides));
    }
  }
  return result;
}
/// Helper function to insert the computed result slices back into the result
/// vector.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize,
                                    int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Write back res slice of size {wSizeStep} @ [w] for non-channeled
    // convolution.
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
          strides);
    }
  } else {
    // Write back res slice of size {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
    }
  }
  return res;
}
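
/// Contains the vectorization state and related methods used across the
/// vectorization process of a given operation: the canonical vector shape,
/// the scalable-dim flags, the precomputed iteration-space sizes, and a cache
/// of the masks created so far (keyed by masking map).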
struct VectorizationState {
  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}

  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          bool assumeDynamicDimsMatchVecSizes = false);

  /// Returns the canonical vector shape used to vectorize the iteration space.
  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }

  /// Returns the vector dimensions that are scalable in the canonical vector
  /// shape.
  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }

  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimension flags. If
  /// `dimPermutation` is provided, the canonical vector dimensions are
  /// permuted accordingly.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }

    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking. Returns the masked operation or the original operation if
  /// masking is not needed.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

private:
  /// Initializes the iteration space static sizes using the Linalg op
  /// information.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for all the
  /// static and dynamic dimensions of the iteration space and stores them in
  /// `iterSpaceValueSizes`.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Creates or retrieves an existing mask value to mask `opToMask` in the
  /// canonical vector iteration space.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Checks whether `maskingMap` is well formed, i.e. a projected permutation.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.isProjectedPermutation();
  }

  /// Turns an indexing map into a valid masking map by dropping broadcast
  /// (zero) results.
  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
    return indexingMap.dropZeroResults();
  }

  /// Holds the compile-time static sizes of the iteration space to vectorize.
  /// Dynamic dimensions are represented using ShapedType::kDynamic.
  SmallVector<int64_t> iterSpaceStaticSizes;

  /// Holds the value sizes of the iteration space to vectorize. Static
  /// dimensions are represented by 'arith.constant' and dynamic dimensions by
  /// 'tensor/memref.dim'.
  SmallVector<Value> iterSpaceValueSizes;

  /// Holds the canonical vector shape used to vectorize the iteration space.
  SmallVector<int64_t> canonicalVecShape;

  /// Holds the vector dimensions that are scalable in the canonical vector
  /// shape.
  SmallVector<bool> scalableVecDims;

  /// Holds the active masks for permutations of the canonical vector iteration
  /// space.
  DenseMap<AffineMap, Value> activeMaskCache;

  /// Global insertion-point guard for the incoming rewriter; initialized when
  /// the vectorization state is initialized.
  OpBuilder::InsertionGuard rewriterGuard;

  /// Do all dynamic dims match the corresponding vector sizes?
  bool assumeDynamicDimsMatchVecSizes = false;
};
LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
          rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
                                                          operandDimPos)
                           : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
                                                          operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims,
                              bool assumeDimsMatchVec) {
  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    // This path should be taken to vectorize code with dynamic shapes and
    // when using vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there
    // are dynamic shapes, the operation won't be vectorized.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for all the
  // static and dynamic dimensions of the iteration space, needed to compute
  // masks during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG() << "Masking map: " << maskingMap;

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG() << "Reusing mask: " << mask;
    return mask;
  }

  // Compute permuted projection of the iteration space to be masked and the
  // corresponding mask shape. If the resulting iteration space dimensions are
  // static and identical to the mask shape, masking is not needed for this
  // operation.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);

  if (permutedStaticSizes == maskShape) {
    LDBG() << "Masking is not needed for masking map: " << maskingMap;
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  if (assumeDynamicDimsMatchVecSizes) {
    // While for _dynamic_ dim sizes we can _assume_ that the corresponding
    // vector sizes match, we still need to check the _static_ sizes.
    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
                     [](auto it) {
                       return std::get<0>(it) == ShapedType::kDynamic
                                  ? true
                                  : std::get<0>(it) == std::get<1>(it);
                     })) {
      LDBG()
          << "Dynamic + static dimensions match vector sizes, masking is not "
             "needed.";
      activeMaskCache[maskingMap] = Value();
      return Value();
    }
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap<Value>(maskingMap, iterSpaceValueSizes);
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
                                            maskType, upperBounds);
  LDBG() << "Creating new mask: " << mask;
  activeMaskCache[maskingMap] = mask;
  return mask;
}
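
// Illustrative example (not taken from a test): masking a transfer_read with
// the canonical mask produces IR of the form
//
//   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
//   %read = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad
//         : tensor<?x?xf32>, vector<4x8xf32>
//   } : vector<4x8xi1> -> vector<4x8xf32>
//
// where %d0/%d1 are the iteration-space sizes precomputed in
// `precomputeIterSpaceValueSizes`.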
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG() << "Trying to mask: " << *opToMask;

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG() << "No mask required";
    if (assumeDynamicDimsMatchVecSizes) {
      llvm::TypeSwitch<Operation *>(opToMask)
          .Case<vector::TransferReadOp, vector::TransferWriteOp>(
              [&](auto xferOp) {
                LDBG() << "Assuming dynamic dimensions match vector sizes and "
                          "setting their in-bounds to true!";
                // Set the in-bounds attribute to true for dynamic dims that
                // are assumed to match the corresponding vector sizes.
                ShapedType xferType = xferOp.getShapedType();
                AffineMap permMap = xferOp.getPermutationMap();
                SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
                for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
                  auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
                  if (!dimExpr)
                    continue;
                  unsigned pos = dimExpr.getPosition();
                  if (xferType.isDynamicDim(pos))
                    inBoundsMap[i] = true;
                }
                xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBoundsMap));
              });
    }
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG() << "Masked operation: " << *maskOp;
  return maskOp;
}
/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projected permutation, compress the unused dimensions to serve as a
/// permutation_map for a vector.transfer_write.
static AffineMap reindexIndexingMap(AffineMap map) {
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  auto res = compressUnusedDims(map);
  assert(res.getNumDims() ==
             (res.getNumResults() - res.getNumOfZeroResults()) &&
         "expected reindexed map with same number of dims and results");
  return res;
}
std::optional<vector::CombiningKind>
mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default(std::nullopt);
}
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, or nullptr
/// otherwise.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
/// Broadcast `value` to a vector of `dstType` if this is required. Return
/// `value` otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
/// Create MultiDimReductionOp to compute the reduction for `reduceOp`. This
/// assumes that `reduceOp` has two operands and one of them is the reduction
/// initial value.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return vector::MultiDimReductionOp::create(
      b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`, where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. Return the produced value or null if no value
/// is produced.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store. This type should be an
  // identity or projection of the canonical vector type without any
  // permutation applied, given that any permutation in a transfer write
  // happens as part of the write itself.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        arith::ConstantIndexOp::create(rewriter, loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(
        rewriter, loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(rewriter, loc, value,
                                            outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true. Masking guarantees that the access will
  // be in-bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG() << "vectorized op: " << *write;
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
// Custom vectorization precondition function type. This is intended to be used
// with CustomVectorizationHook: implement it to check whether the
// corresponding hook can vectorize a given operation.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. This function is meant to be
/// used as a CustomVectorizationHook.
static VectorizationHookResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    // TODO: Scan for an opportunity for reuse.
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that
/// it should map the produced operations. This function is meant to be used as
/// a CustomVectorizationHook.
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                    VectorizationState &state,
                                                    Operation *op,
                                                    LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space, since the vectorization algorithm in
  // this case can handle the broadcast.
  if (dim == targetShape.size() - 1)
    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = vector::BroadcastOp::create(
      rewriter, loc,
      state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
}
/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type, but only for non 0-d tensors (for which we do need
  // access indices).
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (!llvm::all_of(extractOp->getResultTypes(),
                    VectorType::isValidElementType)) {
    return failure();
  }

  return success();
}
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
  }

  return offset;
}
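
// Worked example: for a 2-D tensor `%t` accessed as `tensor.extract
// %t[%i, %j]`, the loop above computes `offset = %i * dim(%t, 1) + %j`,
// broadcast to the canonical vector shape so that it can feed a
// vector.gather's `$index_vec` operand.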
// Returns the idx of the trailing non-unit dim of `linalgOp`'s iteration
// space, asserting that (for statically shaped ops) there is exactly one.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic
  // (`linalgOp.getBlock()`) is a bit more involved (it would require
  // analysing the corresponding operand).
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its result remains constant across
  // iterations, i.e. as long as it indexes a unit loop dim.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp`, which are constant, are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

/// Checks whether `val` can be used for calculating the trailing index for a
/// contiguous load operation, i.e. whether it increments by 1 with the
/// trailing loop dimension.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject ops that could lead to indices with a stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
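
// Illustrative example for the classifier below: given
//   %v = tensor.extract %t[%c0, %idx] : tensor<1x8xf32>
// where %idx is based on the trailing linalg.index, the access is classified
// as Contiguous; if %idx were an arbitrary computed value instead, it would
// be classified as Gather.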
/// Infer the memory access pattern for the input ExtractOp.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
  // otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`.
  // Look at the way each index is calculated and decide whether it is suitable
  // for a contiguous load, i.e. whether it's loop invariant.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG() << "Found gather load: " << extractOp;
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index for `extractOp`.
  // At this point we know that the leading indices are loop invariant. This
  // means that this is potentially a scalar or a contiguous load. We can
  // decide based on the trailing idx.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load
  // If the trailing index is loop invariant then this is a scalar load.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG() << "Found scalar broadcast load: " << extractOp;
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous loads
  // The trailing `extractOp` index should increment with every loop iteration,
  // which effectively means it must be based on the trailing loop index.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  // TODO: Support generating contiguous loads for column vectors - that will
  // require adding a permutation map to transfer_read ops.
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG() << "Found contiguous load: " << extractOp;
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG() << "Found gather load: " << extractOp;
  return VectorMemoryAccessKind::Gather;
}
/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that
/// it should map the produced operations. This function is meant to be used as
/// a CustomVectorizationHook.
static VectorizationHookResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = arith::ConstantOp::create(
      rewriter, loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp = arith::ConstantOp::create(
      rewriter, loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0. We will need to re-visit if more
  // generic scenarios are to be supported.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      arith::ConstantIndexOp::create(rewriter, loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = vector::GatherOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG() << "Vectorised as gather load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
  }

  // 2. Handle contiguous/scalar-broadcast access: compute the transfer-read
  // indices from the (vectorized) extract indices.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = vector::ShapeCastOp::create(
        rewriter, loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = vector::TransferReadOp::create(
        rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
        std::nullopt, permutationMap, inBounds);

    // Mask this broadcasting xfer_read here rather than relying on the generic
    // masking infrastructure: the shape of the read is different from the
    // canonical vector shape.
    SmallVector<int64_t> readMaskShape = {1};
    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
    auto allTrue = vector::ConstantMaskOp::create(
        rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   maskedReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
  // dims so that the ranks match. This is done by extending the map with 0s.
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = vector::TransferReadOp::create(
      rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
      std::nullopt, permutationMap, inBounds);

  LDBG() << "Vectorised as contiguous load: " << extractOp;
  return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                 transferReadOp};
}
/// Emit reduction operations if the shape of the value to reduce is different
/// from the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
                                 Value reduceValue, Value initialValue,
                                 const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already vectorized
/// operands carried by `bvm`. Try the custom hooks first, clone constants,
/// handle reductions, and otherwise clone the op on broadcast vector operands.
static VectorizationHookResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG() << "vectorize op " << *op;

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationHookResult result = customFunc(op, bvm);
      if (result.status == VectorizationHookStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationHookResult{VectorizationHookStatus::NewOp,
                                   rewriter.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max-ranked vector operand type.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each operand to that shape if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                     getElementTypeOrSelf(vecOperand.getType()),
                                     firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. The results have the max-ranked vector shape.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  return VectorizationHookResult{
      VectorizationHookStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form. New result vector values are appended to `newResults`.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG() << "Vectorizing operation as linalg generic";
  Block *block = linalgOp.getBlock();

  // 2. Values defined above the region can only be broadcast for now. Make
  // them map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // 3. Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // 3.a. Convert the indexing map for this input/output to a transfer read
    // permutation map.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // 3.a.i. For input reads we use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // 3.a.ii. For output reads (iteration-carried dependence, e.g.
      // reductions), the vector shape is computed by mapping the canonical
      // vector shape to the output domain.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = vector::TransferReadOp::create(
        rewriter, loc, readType, opOperand->get(), indices,
        std::nullopt, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // 3.b. If masked, set in_bounds to true. Masking guarantees that the
    // access will be in-bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
    // TODO: remove this.
    if (readType.getRank() == 0)
      readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                            ArrayRef<int64_t>());

    LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
           << "): " << readValue;
    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  SmallVector<CustomVectorizationHook> hooks;
  // 4a. Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp,
                                newResults);
  };
  hooks.push_back(vectorizeYield);

  // 4b. Register CustomVectorizationHook for indexOp.
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  // 4c. Register CustomVectorizationHook for extractOp.
  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // 5. Iteratively call `vectorizeOneOp` on each op of the block.
  for (Operation &op : block->getOperations()) {
    VectorizationHookResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationHookStatus::Failure) {
      LDBG() << "failed to vectorize: " << op;
      return failure();
    }
    if (result.status == VectorizationHookStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG() << "New vector op: " << *maybeMaskedOp;
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
/// Determine whether the mask for a transfer_write is trivially "all true",
/// i.e. whether (1) the mask sizes cover the vector shape and (2) the mask
/// plus write offsets fit within the destination.
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
                                    SmallVector<Value> &writeIdxs,
                                    ArrayRef<int64_t> destShape,
                                    ArrayRef<int64_t> maskShape) {
  // Masking is unavoidable in the case of dynamic tensors.
  if (ShapedType::isDynamicShape(destShape))
    return false;

  // Collect all constant mask sizes.
  SmallVector<int64_t, 4> cstMaskSizes;
  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
    if (auto intSize = getConstantIntValue(dimSize)) {
      cstMaskSizes.push_back(*intSize);
    }
  }

  // If any of the mask sizes is non-constant, bail out.
  if (cstMaskSizes.size() != maskShape.size())
    return false;

  // Collect all constant write indices.
  SmallVector<int64_t, 4> cstWriteIdxs;
  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
    APSInt intVal;
    if (matchPattern(idx, m_ConstantInt(&intVal))) {
      cstWriteIdxs.push_back(intVal.getSExtValue());
    }
  }

  // If any of the write indices is non-constant, bail out.
  if (cstWriteIdxs.size() != destShape.size())
    return false;

  // Go over all destination dims and check (1) and (2). Take into account that
  // the number of mask sizes matches the rank of the vector to store, but the
  // rank of the destination tensor may be higher.
  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
    if (/* (1) */ maskShape[i] > destShape[rankDiff + i] ||
        /* (2) */ destShape[rankDiff + i] <
            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
             cstWriteIdxs[i]))
      return false;
  }

  return true;
}
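
// Numeric example (hypothetical sizes): for a 4x3 destination, mask sizes
// [2, 2] and write indices [2, 0], both checks hold (2 <= 4 with 2 + 2 <= 4,
// and 2 <= 3 with 2 + 0 <= 3), so the mask is provably all-true and can be
// dropped.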
/// Create a TransferWriteOp that writes `vecToStore` into `dest`, masking the
/// write if needed (unless `useInBoundsInsteadOfMasking` is set).
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
                         Value dest, SmallVector<Value> writeIndices = {},
                         bool useInBoundsInsteadOfMasking = false) {

  ShapedType destType = cast<ShapedType>(dest.getType());
  int64_t destRank = destType.getRank();
  auto destShape = destType.getShape();

  VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
  int64_t vecToStoreRank = vecToStoreType.getRank();
  auto vecToStoreShape = vecToStoreType.getShape();

  // Compute the in_bounds attribute.
  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
  if (useInBoundsInsteadOfMasking) {
    // In this case, assume that all the required vector sizes have been
    // computed statically.
    for (unsigned i = 0; i < vecToStoreRank; i++)
      inBoundsVal[i] =
          (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
          ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
  }

  // If missing, initialize the write indices to 0.
  assert((writeIndices.empty() ||
          writeIndices.size() == static_cast<size_t>(destRank)) &&
         "Invalid number of write indices!");
  if (writeIndices.empty()) {
    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
    writeIndices.assign(destRank, zero);
  }

  // Generate the xfer_write op.
  Operation *write = vector::TransferWriteOp::create(builder, loc, vecToStore,
                                                     dest, writeIndices,
                                                     inBoundsVal);

  // If masking is disabled, exit.
  if (useInBoundsInsteadOfMasking)
    return write;

  // Check if masking is needed. If not, exit.
  if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
    return write;

  // Compute the mask and mask the write op.
  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
                                       vecToStoreType.getScalableDims());

  SmallVector<OpFoldResult> destSizes =
      isa<MemRefType>(dest.getType())
          ? memref::getMixedSizes(builder, loc, dest)
          : tensor::getMixedSizes(builder, loc, dest);
  SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                      destSizes.end());

  if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
                              vecToStoreShape))
    return write;

  Value maskForWrite =
      builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
  return mlir::vector::maskOperation(builder, write, maskForWrite);
}
/// Given the re-associations, "collapses" the input vector type.
static VectorType getCollapsedVecType(VectorType type,
                                      ArrayRef<AffineMap> reassociation) {
  assert(type.getNumScalableDims() < 2 &&
         "Collapsing more than 1 scalable dim is not supported ATM");

  auto shape = type.getShape();
  auto scalableFlags = type.getScalableDims();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableFlags;

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    int64_t size = 1;
    bool flag = false;
    for (unsigned d = 0; d < dim; ++d) {
      size *= shape[currentDim + d];
      flag |= scalableFlags[currentDim + d];
    }
    newShape.push_back(size);
    newScalableFlags.push_back(flag);
    currentDim += dim;
  }

  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
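
// E.g. collapsing vector<2x[4]x8xf32> with reassociation [[0, 1], [2]] yields
// vector<[8]x8xf32>: sizes within a group multiply, and a group is scalable
// if any of its members is.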
/// Vectorize linalg.pack as a masked transfer_read + shape_cast + transpose +
/// transfer_write.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == packOp.getDestRank() &&
           "Invalid number of input vector sizes!");
  }

  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  std::optional<Value> padValue = packOp.getPaddingValue()
                                      ? std::optional(packOp.getPaddingValue())
                                      : std::nullopt;

  SmallVector<int64_t> destShape =
      SmallVector<int64_t>(packOp.getDestType().getShape());

  // The write vector sizes are the input vector sizes, if provided.
  ArrayRef<int64_t> writeVectorSizes = inputVectorSizes;

  bool useInBoundsInsteadOfMasking = false;
  if (writeVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(destShape))
      return rewriter.notifyMatchFailure(packOp,
                                         "unable to infer vector sizes");

    writeVectorSizes = destShape;
    useInBoundsInsteadOfMasking = true;
  }

  // Compute the pre-transpose write shape by inverting the destination
  // permutation, then collapse it to obtain the read vector type.
  PackingMetadata packMetadata;
  SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);
  SmallVector<int64_t> destInvPermutation =
      getPackInverseDestPerm(packOp, packMetadata);
  applyPermutationToVector(preTransposeWriteVecSizes, destInvPermutation);
  auto preTransposeWriteVecType = VectorType::get(
      preTransposeWriteVecSizes, packOp.getType().getElementType());
  VectorType readVecType = getCollapsedVecType(
      preTransposeWriteVecType,
      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
          rewriter.getContext(), packMetadata.reassociations)));

  // Create masked read.
  Value maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), readVecType, padValue,
      useInBoundsInsteadOfMasking);

  // Create ShapeCastOp to the pre-transpose write shape.
  auto shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, preTransposeWriteVecType, maskedRead);

  // Create TransposeOp to the destination layout.
  SmallVector<int64_t> destPermutation =
      invertPermutationVector(destInvPermutation);
  auto transposeOp = vector::TransposeOp::create(
      rewriter, loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), packOp.getDest());
  newResults.push_back(write->getResult(0));
  return success();
}
/// Vectorize linalg.unpack as a masked transfer_read + transpose + shape_cast
/// + transfer_write.
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          SmallVectorImpl<Value> &newResults) {
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
           "Invalid number of input vector sizes!");
    assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
           "Incompatible number of vector sizes and vector scalable flags!");
  }

  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;

  Location loc = unpackOp->getLoc();

  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);

  // If the vector sizes are not provided, infer them from the source shape.
  if (inputVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(sourceShape))
      return rewriter.notifyMatchFailure(unpackOp,
                                         "Unable to infer vector sizes!");

    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
    useInBoundsInsteadOfMasking = true;
  }

  // -- Generate the read operation --
  VectorType readVecType =
      VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
                      readScalableVectorFlags);
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
      useInBoundsInsteadOfMasking);

  // -- Generate the transpose operation --
  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  vector::TransposeOp transposeOp = vector::TransposeOp::create(
      rewriter, loc, readResult, lastDimToInsertPosPerm);

  // -- Generate the shape_cast operation --
  VectorType collapsedVecType = getCollapsedVecType(
      transposeOp.getType(),
      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
          rewriter.getContext(), packMetadata.reassociations)));
  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, collapsedVecType, transposeOp->getResult(0));

  // -- Generate the write operation --
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);

  newResults.push_back(write->getResult(0));
  return success();
}
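
// The emitted sequence for unpack is thus: (masked) transfer_read ->
// vector.transpose -> vector.shape_cast -> (masked) transfer_write, which
// mirrors the tile-read / permute / collapse / copy-back semantics of
// linalg.unpack.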
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero low padding, using a masked transfer_read + transfer_write
/// pair.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(padOp);

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");
  auto readType = VectorType::get(inputVectorSizes, padValue.getType());
  Value maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), readType, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  // Create the write op.
  Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
                                       padOp.getResultType().getElementType());
  Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG() << "reduction precondition failed: no reduction iterator";
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG() << "reduction precondition failed: reduction detection failed";
      return failure();
    }
  }
  return success();
}
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
              "supported";
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
              "channel dim can be dynamic";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement contraction semantics for now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
  return success();
}
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  // If there are no input vector sizes and all shapes are static, there is
  // nothing left to check.
  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
      unpackOp.getSourceType().hasStaticShape())
    return success();

  // The number of input vector sizes must be equal to the source rank.
  if (!inputVectorSizes.empty() &&
      (inputVectorSizes.size() != unpackOp.getSourceRank())) {
    LDBG() << "Incorrect number of input vector sizes";
    return failure();
  }

  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(
          unpackOp.getSourceType().getShape(), inputVectorSizes))) {
    LDBG() << "Invalid vector sizes for the read operation";
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
                                   ArrayRef<int64_t> inputVectorSizes) {
  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  // Get the pad value. If not defined, this is an out-of-bounds read that
  // requires a pad value.
  Value padValue = getStaticPadVal(sliceOp);
  bool isOutOfBoundsRead =
      !sourceType.hasStaticShape() && inputVectorSizes.empty();

  if (!padValue && isOutOfBoundsRead) {
    LDBG() << "Failed to get a pad value for out-of-bounds read access";
    return failure();
  }
  return success();
}
/// Vectorize a named linalg contraction op into a vector.contract plus
/// transfer_read/transfer_write ops.
static LogicalResult
vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
                             LinalgOp linalgOp,
                             SmallVectorImpl<Value> &newResults) {
  Location loc = linalgOp.getLoc();
  MLIRContext *ctx = linalgOp.getContext();

  // For simplicity, contraction vectorization is limited to linalg named ops.
  // Generic ops are ignored as not every arbitrary contraction body can be
  // expressed by a vector.contract.
  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
    return failure();

  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
  Operation *reduceOp = matchLinalgReduction(outOperand);
  auto maybeKind = getCombinerOpKind(reduceOp);
  if (!maybeKind) {
    LDBG() << "Failed to determine contraction combining kind.";
    return failure();
  }

  // Check that all dimensions are present in the maps.
  // TODO: Skip broadcasts instead of failing.
  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
    LDBG() << "Contractions with broadcasts are not supported.";
    return failure();
  }

  // Load operands.
  SmallVector<Value> vecOperands;
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    // The operand vector shape is computed by mapping the canonical vector
    // shape to the operand's domain. Further permutations are left as a part
    // of the contraction.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
        indexingMap.getNumResults(), rewriter.getContext());
    Type elemType = getElementTypeOrSelf(opOperand.get());
    VectorType readType =
        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));

    Value read = mlir::vector::createReadOrMaskedRead(
        rewriter, loc, opOperand.get(), readType,
        arith::getZeroConstant(rewriter, loc, elemType),
        /*useInBoundsInsteadOfMasking=*/false);
    vecOperands.push_back(read);
  }

  // Remap iterators from linalg to vector.
  SmallVector<Attribute> iterAttrs;
  auto iterators = linalgOp.getIteratorTypesArray();
  for (utils::IteratorType iter : iterators) {
    auto vecIter = iter == utils::IteratorType::parallel
                       ? vector::IteratorType::parallel
                       : vector::IteratorType::reduction;
    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
  }

  // Create the contraction.
  Operation *contractOp = vector::ContractionOp::create(
      rewriter, loc, /*lhs=*/vecOperands[0],
      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);

  // Store the result.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, contractOp->getResult(0), outOperand->get());

  // Finalize.
  if (!write->getResults().empty())
    newResults.push_back(write->getResult(0));

  return success();
}
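
// Illustrative result for a plain matmul (types abbreviated):
//
//   %lhs = vector.transfer_read %A[...] : ..., vector<MxKxf32>
//   %rhs = vector.transfer_read %B[...] : ..., vector<KxNxf32>
//   %acc = vector.transfer_read %C[...] : ..., vector<MxNxf32>
//   %res = vector.contract {indexing_maps = [...],
//            iterator_types = ["parallel", "parallel", "reduction"],
//            kind = #vector.kind<add>} %lhs, %rhs, %acc
//   vector.transfer_write %res, %C[...]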
enum class ConvOperationKind { Conv, Pool };

static bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

// Returns the ConvOperationKind of the op using reduceOp of the generic
// payload. If it is neither a convolution nor a pooling, it returns
// std::nullopt.
static std::optional<ConvOperationKind>
getConvOperationKind(Operation *reduceOp) {
  int numBlockArguments =
      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);

  switch (numBlockArguments) {
  case 1: {
    // Will be convolution if feeder is a MulOp. A strength-reduced version of
    // MulOp for i1 type is AndOp, which is also supported. Otherwise, it can
    // be pooling.
    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                       llvm::IsaPred<BlockArgument>);
    assert(feedValIt != reduceOp->operand_end() &&
           "Expected a non-block argument operand");
    Operation *feedOp = (*feedValIt).getDefiningOp();
    if (isCastOfBlockArgument(feedOp)) {
      return ConvOperationKind::Pool;
    }

    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
           (isa<arith::AndIOp>(feedOp) &&
            feedOp->getResultTypes()[0].isInteger(1))) &&
          llvm::all_of(feedOp->getOperands(), [](Value v) {
            if (isa<BlockArgument>(v))
              return true;
            if (Operation *op = v.getDefiningOp())
              return isCastOfBlockArgument(op);
            return false;
          }))) {
      return std::nullopt;
    }

    return ConvOperationKind::Conv;
  }
  case 2:
    // Must be pooling.
    return ConvOperationKind::Pool;
  default:
    return std::nullopt;
  }
}
static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
  auto getOperandType = [&](auto operand) {
    return dyn_cast<ShapedType>((operand->get()).getType());
  };
  ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
  ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
  ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
  // (non-channeled convolution -> LHS and RHS both have single dimensions).
  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
    return failure();

  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
  if (!reduceOp)
    return failure();

  auto maybeOper = getConvOperationKind(reduceOp);
  if (!maybeOper.has_value())
    return failure();

  auto maybeKind = getCombinerOpKind(reduceOp);
  // Typically convolution will have an `Add` CombiningKind, but for i1 type
  // it can get strength-reduced to `OR` which is also supported.
  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                      *maybeKind != vector::CombiningKind::OR) &&
                     (*maybeOper != ConvOperationKind::Pool ||
                      !isSupportedPoolKind(*maybeKind)))) {
    return failure();
  }

  auto rhsRank = rhsShapedType.getRank();
  if (*maybeOper == ConvOperationKind::Pool) {
    if (rhsRank != 1)
      return failure();
  } else {
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return failure();
  }

  return success();
}
static LogicalResult vectorizeLinalgOpPrecondition(
    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
    bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
  // A tensor with a dimension of 0 cannot be vectorized.
  if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
        return llvm::is_contained(linalgOp.getShape(&operand), 0);
      }))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
                                        linalgOp, flatten1DDepthwiseConv))) {
    LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(customPreconditions,
                     [&](const CustomVectorizationPrecondition
                             &customPrecondition) {
                       return succeeded(
                           customPrecondition(&innerOp, vectorizeNDExtract));
                     })) {
      continue;
    }
    if (!llvm::all_of(innerOp.getOperandTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
    if (!llvm::all_of(innerOp.getResultTypes(),
                      VectorType::isValidElementType)) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // TODO: isaConvolutionOpInterface that can also infer from generic
  // features. Convolution vectorization is still subject to some specific
  // restrictions.
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return vectorizeConvOpPrecondition(linalgOp);

  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG() << "precondition failed: not projected permutations";
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG() << "precondition failed: reduction preconditions";
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG() << "pad value is not constant: " << packOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG() << "inner_tiles must be constant: " << packOp;
    return failure();
  }

  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG() << "pad value is not constant: " << padOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  // Padding with non-zero low pad values is not supported, unless the
  // corresponding result dim is 1: shifting the result right by the low pad
  // amount would otherwise be required. When a low pad is applied to a unit
  // result dim, the input size of that dimension is dynamically zero and the
  // lowering ends up loading the pad value, which is the desired behaviour.
  if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
        Value padValue = en.value();
        unsigned pos = en.index();
        std::optional<int64_t> pad = getConstantIntValue(padValue);
        return (!pad.has_value() || pad.value() != 0) &&
               resultTensorShape[pos] != 1;
      })) {
    LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
    return failure();
  }

  return success();
}
/// Preconditions for scalable vectors. This is quite restrictive: it models
/// the fact that in practice we would only make selected dimensions scalable.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
  // exception of UnpackOp for which there is a dedicated hook.
  if (!linalgOp)
    return success(isa<linalg::UnPackOp>(op));

  // Cond 2: There's been no need for more than 2 scalable dims so far.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify
  // that it matches one of the supported cases. Walk the flags from the back
  // to find the trailing scalable dim.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // The trailing scalable dim must be the trailing loop dim.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG() << "Non-trailing reduction dim requested for scalable "
                "vectorization";
      return failure();
    }
    if (isa<linalg::MatmulOp>(op)) {
      LDBG()
          << "Scalable vectorization of the reduction dim in Matmul-like ops "
             "is not supported";
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // Non-unit parallel dims before the scalable dim are not supported.
    if (seenNonUnitParallel) {
      LDBG() << "Inner parallel dim not requested for scalable "
                "vectorization";
      return failure();
    }
    break;
  }
  }

  // If present, check the 2nd scalable dim. ATM, only the trailing dims can
  // be scalable.
  if (numOfScalableDims == 2) {
    // Disallow: iterators = [..., parallel, reduction] with scalable flags
    // [..., true, true], which would break the condition above.
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG() << "Higher dim than the trailing reduction dim requested for "
                "scalable vectorization";
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Cond 4: Only the following ops are supported in the presence of scalable
  // vectors.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
                 isa<linalg::BatchMmt4DOp>(op) ||
                 hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (!hasVectorizationImpl(op))
    return failure();

  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<linalg::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<linalg::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
/// Converts affine.apply ops inside the linalgOp body to arithmetic
/// operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}

bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
}
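
// Typical usage from a pass or a transform op (a sketch; `rewriter` and
// `target` are assumed to be in scope):
//
//   FailureOr<VectorizationResult> result =
//       linalg::vectorize(rewriter, target, /*inputVectorSizes=*/{},
//                         /*inputScalableVecDims=*/{});
//   if (succeeded(result))
//     rewriter.replaceOp(target, result->replacements);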
FailureOr<VectorizationResult> mlir::linalg::vectorize(
    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
    bool createNamedContraction) {
  LDBG() << "Attempting to vectorize: " << *op;
  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
  LDBG() << "Input scalable vector dims: "
         << llvm::interleaved(inputScalableVecDims);

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG() << "Vectorization pre-conditions failed";
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims,
                               assumeDynamicDimsMatchVecSizes))) {
      LDBG() << "Vectorization state couldn't be initialized";
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            // TODO: isaConvolutionOpInterface that can also infer from
            // generic features. But we will still need stride/dilation
            // attributes that will be annoying to reverse-engineer.
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG() << "Unsupported convolution can't be vectorized.";
              return failure();
            }

            if (createNamedContraction &&
                isa<ContractionOpInterface>(linalgOp.getOperation()))
              return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
                                                  results);

            LDBG()
                << "Vectorize generic by broadcasting to the canonical vector "
                   "shape";

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
            // to 'OpBuilder' when it is passed over to some methods like
            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
            // erase an op within these methods, the actual rewriter won't be
            // notified and we will end up with read-after-free issues!
            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<linalg::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<linalg::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes,
                                             inputScalableVecDims, results);
          })
          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp,
                                            inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG() << "Vectorization failed";
    return failure();
  }

  return VectorizationResult{results};
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = vector::TransferReadOp::create(
      rewriter, loc, readType, copyOp.getSource(), indices,
      /*padding=*/std::nullopt);
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = vector::ExtractOp::create(rewriter, loc, readValue,
                                          ArrayRef<int64_t>());
    readValue =
        vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
  }
  Operation *writeValue = vector::TransferWriteOp::create(
      rewriter, loc, readValue, copyOp.getTarget(), indices);
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
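
// E.g. `memref.copy %src, %dst : memref<4x8xf32>` becomes a single
// vector.transfer_read of vector<4x8xf32> followed by a matching
// vector.transfer_write; 0-d copies go through an extract + broadcast.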
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
/// Rewrite use of tensor::PadOp result in TransferReadOp: the read is
/// redirected to the pad source and the pad value becomes the transfer-read
/// padding.
struct PadOpVectorizationWithTransferReadPattern
    : public VectorizePadOpUserPattern<vector::TransferReadOp> {
  using VectorizePadOpUserPattern<
      vector::TransferReadOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getBaseMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
};
/// Rewrite use of tensor::PadOp result in TransferWriteOp: write directly to
/// the pad source when the trimming ExtractSliceOp restores the original
/// tensor size.
struct PadOpVectorizationWithTransferWritePattern
    : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
  using VectorizePadOpUserPattern<
      vector::TransferWriteOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at position of the old TransferWriteOp.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }

  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions. Dimensions may be static, dynamic, or a mix of
  /// both.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both CastOp result
    // and CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same value or same constant int.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Other cases: Take a deeper look at defining ops of values.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      // Case 2: Both values are identical AffineMinOps. (Should not happen if
      // CSE is run.)
      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed. This can also be a subset/superset
      // relationship.
      return false;
    }

    // All tests passed.
    return true;
  }
};
/// Returns the effective pad value for the input op, provided it's a scalar.
/// Many ops exhibit pad-like behaviour, but this isn't always explicit;
/// peel through the common producers to find it.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast - return the broadcast source if it is a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  // 2. linalg.fill - the fill input is the pad value.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generate - only uniform, scalar yields qualify (handling
  // elided here).
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    return {};
  }

  // 4. vector.transfer_write - inspect the destination's producer.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getBase().getDefiningOp());

  // 5. tensor.insert_slice - inspect the destination's producer.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  // TODO: Introduce a parent class that will handle the insertion point
  // update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(sliceOp);

  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  Value padValue = getStaticPadVal(sliceOp);

  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
                                         rewriter.getZeroAttr(elemType));
  }

  // 2. Get the vector shape.
  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but result shape is.
      // Vectorize with size of result shape; this may be larger than the
      // source size.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
    } else {
      // Neither source nor result dim is static. Cannot vectorize the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // 3. Generate TransferReadOp + TransferWriteOp.
  auto loc = sliceOp.getLoc();

  // Create the read.
  SmallVector<Value> readIndices(
      vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
  Value read = mlir::vector::createReadOrMaskedRead(
      rewriter, loc, source, vecType, padValue,
      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());

  // Create the write.
  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
                               writeIndices, inputVectorSizes.empty());

  // 4. Finalize.
  newResults.push_back(write->getResult(0));

  return success();
}
/// Rewrite use of tensor::PadOp result in InsertSliceOp: read the padded
/// source directly and write it into the insert destination.
struct PadOpVectorizationWithInsertSlicePattern
    : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
  using VectorizePadOpUserPattern<
      tensor::InsertSliceOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result not used as destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if sizes match: insert the entire tensor into the most minor dims.
    // (No permutations allowed.)
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Insert the TransferReadOp and TransferWriteOp at the position of the
    // InsertSliceOp.
    rewriter.setInsertionPoint(insertOp);

    // Generate TransferReadOp: read the entire source tensor with the pad
    // value for high padding.
    SmallVector<Value> readIndices(
        vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
    auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
                                               vecType, padOp.getSource(),
                                               readIndices, padValue);

    // Generate TransferWriteOp: write to InsertSliceOp's dest tensor at the
    // specified offsets. The write is fully in-bounds because an
    // InsertSliceOp's source is always in-bounds.
    auto writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
};
void mlir::linalg::populatePadOpVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
  patterns.add<PadOpVectorizationWithTransferReadPattern,
               PadOpVectorizationWithTransferWritePattern,
               PadOpVectorizationWithInsertSlicePattern>(
      patterns.getContext(), baseBenefit.getBenefit() + 1);
}
/// Check whether there may be interleaved uses of `values` between `firstOp`
/// and `secondOp` that would invalidate a copy-forwarding rewrite.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
           << ", second op: " << *secondOp;
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
             << ", second op: " << *secondOp;
      return true;
    }
  }
  return false;
}
/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // Find the fill into `viewOrAlloc` without interleaved uses before the
  // copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // memref.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer; when
  // forwarding to vector.transfer_read, it must be reset conservatively.
  auto vectorType = xferOp.getVectorType();
  Value res = vector::TransferReadOp::create(
      rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return failure();

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return failure();
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return failure();

  // `out` is the subview copied into that we replace.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward vector.transfer into copy.
  // memref.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer; when
  // forwarding to vector.transfer_write, it must be reset conservatively.
  auto vector = xferOp.getVector();
  vector::TransferWriteOp::create(
      rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(SmallVector<bool>(
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
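
// E.g. `bindShapeDims(resShapedType, nSize, wSize, fSize)` binds nSize, wSize
// and fSize to the first three dimensions of `resShapedType`, mirroring how
// `bindDims` binds AffineExprs.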
/// Generates vector implementations of 1-D convolutions and poolings.
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {

    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    setConvOperationKind(reduceOp);

    auto maybeKind = getCombinerOpKind(reduceOp);
    reductionKind = maybeKind.value();

    // The ConvolutionOpInterface gives us guarantees of existence for
    // strides/dilations. However, we do not need to rely on those; we simply
    // use them if present, otherwise use the default and let the generic conv.
    // matcher in the ConvGenerator succeed or fail.
    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  }
  /// Generate a vector implementation of a 1-D convolution/pooling for the
  /// given `conv1DOpOrder` (W, NWC or NCW), with kw unrolled.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Initialize unused dimensions.
      nSize = fSize = cSize = 0;
      // out{W}
      bindShapeDims(resShapedType, wSize);
      // kernel{kw}
      bindShapeDims(rhsShapedType, kwSize);
      lhsShape = {// iw = ow + kw - 1
                  //   (i.e. 16 convolved with 3 -> 14)
                  (wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      // out{n, w, f}
      bindShapeDims(resShapedType, nSize, wSize, fSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        // kernel{kw, c, f}
        bindShapeDims(rhsShapedType, kwSize, cSize);
        break;
      case ConvOperationKind::Pool:
        // kernel{kw}
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      // out{n, f, w}
      bindShapeDims(resShapedType, nSize, fSize, wSize);
      switch (oper) {
      case ConvOperationKind::Conv:
        // kernel{f, c, kw}
        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
        break;
      case ConvOperationKind::Pool:
        // kernel{kw}
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize, cSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case ConvOperationKind::Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case ConvOperationKind::Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }
3647 vector::TransferWriteOp write;
3653 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3655 Type lhsEltType = lhsShapedType.getElementType();
3656 Type rhsEltType = rhsShapedType.getElementType();
3657 Type resEltType = resShapedType.getElementType();
3658 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3659 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3660 auto resType = VectorType::get(resShape, resEltType);
3662 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3663 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3664 SmallVector<Value> resPadding(resShape.size(), zero);
    // Read the whole lhs, rhs and res in one shot (with zero padding).
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, lhsPadding,
        arith::getZeroConstant(rewriter, loc, lhsEltType));
    // The rhs is read only for convolutions; pooling carries no filter values.
    Value rhs = nullptr;
    if (oper == ConvOperationKind::Conv)
      rhs = vector::TransferReadOp::create(
          rewriter, loc, rhsType, rhsShaped, rhsPadding,
          arith::getZeroConstant(rewriter, loc, rhsEltType));
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, resPadding,
        arith::getZeroConstant(rewriter, loc, resEltType));
    // The base vectorization case for channeled convolutions is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. For Ncw layouts, transpose
    // the operands into that form first.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case; nothing to do.
      break;
    case Conv1DOpOrder::Ncw: {
      // lhs {n, c, w} -> {n, w, c}
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
      // rhs {f, c, kw} -> {kw, c, f}; only convolutions carry a filter.
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
      if (oper == ConvOperationKind::Conv)
        rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
      // res {n, f, w} -> {n, w, f}
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, permRes);
      break;
    }
    }
    // Unroll along kw and extract slices of lhs, rhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                                     kwSize, strideW, dilationW, wSizeStep,
                                     isSingleChanneled);
    if (oper == ConvOperationKind::Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                      wSizeStep, isSingleChanneled);

    // lhs slices were pushed in (kw, w) order; recover the flat index.
    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute the contraction for every (kw, w) slice pair:
    //   O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case ConvOperationKind::Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case ConvOperationKind::Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }
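    // For reference (an illustrative sketch, not from the original source):
    // with kwSize = 2 and a single w step, the loop above chains two
    // contractions that accumulate into the same result slice:
    //   %r0 = conv1dSliceAsContraction(lhsVals[0], rhsVals[0], resVals[0])
    //   %r1 = conv1dSliceAsContraction(lhsVals[1], rhsVals[1], %r0)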
    // Write back the res slices, then undo the Ncw transposition if needed.
    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      // res {n, w, f} -> {n, f, w}
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = vector::TransposeOp::create(rewriter, loc, res, perm);
      break;
    }
    }
    return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
                                           resPadding)
        .getOperation();
  }
  /// Promote `val` to the element type of `ty`, inserting the appropriate
  /// extension or int-to-float conversion.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return arith::SIToFPOp::create(rewriter, loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtFOp::create(rewriter, loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return arith::ExtSIOp::create(rewriter, loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  /// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = vector::ContractionOp::create(
        rewriter, loc, lhs, rhs, res,
        /*indexingExprs=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }
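  // Reference (sketch): the maps {(n, w, c), (c, f), (n, w, f)} with iterator
  // types [par, par, par, red] encode res[n][w][f] += lhs[n][w][c] *
  // rhs[c][f], i.e. a matmul against the filter per (n, w) position, reducing
  // over the channel dim c.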
  /// Create an outer-product-with-accumulate for the single-channel case.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
                                          rhs, res, vector::CombiningKind::ADD);
  }
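  // In the single-channel (conv::W) case lhs is the 1-D input window and rhs
  // is a single filter tap, so the outer product degenerates to an axpy-style
  // multiply-accumulate: res[w] += lhs[w] * rhs.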
  /// Create a pooling reduction: lhs{n, w, c} -> res{n, w, c} (body elided
  /// from this excerpt).
  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
                    Value res);
  /// Generate the vector form of a depthwise 1-D convolution:
  ///   O{n, w, c} += I{n, sw * w + dw * kw, c} * F{kw, c}
  /// `channelDimVecSize` supplies the vector size for a dynamic channel dim,
  /// and `flatten` collapses the {w, c} dims of each slice.
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Masking is needed when the channel dim is dynamic: the reads and
      // writes below must not touch elements past the runtime size.
      useMasking = true;
      scalableChDim = channelDimScalableFlag;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // w is unrolled when strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow * sw + kw * dw - 1) * 1
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});
    // Wrap a transfer op in a vector.mask when the shapes are dynamic;
    // otherwise return the op unchanged.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };
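    // Masking example (an illustrative sketch): with kwSize = 3, a scalable
    // channel dim vectorized to [4], and a runtime channel size %dim_c, the
    // rhs read of vector<3x[4]xf32> becomes, roughly:
    //   %mask = vector.create_mask %c3, %dim_c : vector<3x[4]xi1>
    //   %rhs = vector.mask %mask { vector.transfer_read ... }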
    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = vector::TransferReadOp::create(
        rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
        arith::getZeroConstant(rewriter, loc, lhsEltType));
    auto *maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = vector::TransferReadOp::create(
        rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
        arith::getZeroConstant(rewriter, loc, rhsEltType));
    auto *maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = vector::TransferReadOp::create(
        rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
        arith::getZeroConstant(rewriter, loc, resEltType));
    auto *maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
    // Unroll along kw and extract slices of lhs and res.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(
          vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
                                    /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }
    // Slices were pushed in (kw, w) order; recover the flat index.
    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Note: these flattened types are only used when `flatten` is true.
    SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute the multiply-accumulate per (kw, w) slice pair:
    //   O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel dim).
          lhsVal =
              vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
                                          lhsVals[linearIndex(kw, w)]);
          resVal = vector::ShapeCastOp::create(
              rewriter, loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dim).
          resVals[w] = vector::ShapeCastOp::create(
              rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
              resVals[w]);
        }
      }
    }
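    // Flattening example (illustrative): with nSize = 1, wSizeStep = 2 and
    // cSize = 4, each lhs slice vector<1x2x4xf32> was shape_cast to
    // vector<1x8xf32> for the multiply-accumulate, and the result cast back
    // to vector<1x2x4xf32> before re-insertion below.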
    // It is possible the FMA creation failed; if so, erase everything created
    // above (in reverse order) so no half-vectorized IR is left behind.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          if (v)
            rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }
    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0]; this does
    // not depend on kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back the result.
    Operation *resOut = vector::TransferWriteOp::create(
        rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c}     (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c}   (flatten = true)
  /// to a multiply-accumulate.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: This logic does not work for scalable vectors; flattening is
      // rejected earlier when shapes are dynamic. The filter is replicated
      // with a shuffle (rather than broadcast + shape_cast) to simplify the
      // generated code.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = vector::BroadcastOp::create(
        rewriter, loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);

    auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
    return arith::AddIOp::create(rewriter, loc, mul, res);
  }
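  // Worked example for the flattened shuffle above (illustrative): with
  // rhsSize = 4 and resSize = 8 the index vector is {0, 1, 2, 3, 0, 1, 2, 3},
  // tiling the filter along the flattened w * c dim so each input element is
  // still multiplied by its own channel's tap.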
  /// Entry point for non-channeled convolution:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");
    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);
    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }
  /// Entry point that matches the NWC convolution layout:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");
    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }
  /// Entry point that matches the NCW convolution layout:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");
    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }
  /// Entry point that matches the NWC pooling layout:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }
  /// Entry point that matches the NCW pooling layout:
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }
  /// Entry point that matches the depthwise NWC layout:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");
    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
  // Match state computed by the constructor.
  ConvOperationKind oper = ConvOperationKind::Conv;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;
  // Classify the op as a convolution (mul feeder) or a pooling (reduction
  // consuming two block arguments, or an extended block argument).
  void setConvOperationKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    if (numBlockArguments == 1) {
      // Will be a convolution if the feeder is a MulOp; a feeder that merely
      // casts a block argument indicates a pooling with an extension op.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = ConvOperationKind::Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
        return;
      }
      oper = ConvOperationKind::Conv;
      return;
    }
    // With two block arguments the reduction combines input and output
    // directly: pooling.
    oper = ConvOperationKind::Pool;
    isPoolExt = false;
  }
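  // Example (illustrative): a body of `%0 = arith.mulf %in, %filter` feeding
  // `arith.addf %out, %0` has one block-argument operand on the reduction, so
  // it is classified Conv; `arith.maximumf %out, %in` consumes two block
  // arguments directly and is classified Pool.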
};
} // namespace

/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
static FailureOr<Operation *> vectorizeConvolution(
    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
  // Try each supported 1-D conv/pooling form in turn; the first generator
  // whose iterator types and indexing maps match wins.
  Conv1DGenerator conv1dGen(rewriter, op);
  auto res = conv1dGen.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwConv();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = conv1dGen.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1-D convs are left; they may take the channel vector size
  // from the user-provided vector sizes.
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                                       flatten1DDepthwiseConv);
}
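// Usage sketch (illustrative; `linalgOp` is an assumed conv-like op):
//   FailureOr<Operation *> maybeConv = vectorizeConvolution(
//       rewriter, linalgOp, /*inputVecSizes=*/{},
//       /*inputScalableVecDims=*/{}, /*flatten1DDepthwiseConv=*/false);
//   if (succeeded(maybeConv))
//     ... // replace linalgOp with the results of *maybeConv, if any.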
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};
void mlir::linalg::populateConvolutionVectorizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
}
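// Typical driver (a sketch; `funcOp` and `ctx` are assumed to come from the
// surrounding pass):
//   RewritePatternSet patterns(ctx);
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();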