38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/Sequence.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/SmallVectorExtras.h"
42#include "llvm/ADT/TypeSwitch.h"
43#include "llvm/Support/DebugLog.h"
44#include "llvm/Support/InterleavedRange.h"
45#include "llvm/Support/MathExtras.h"
46#include "llvm/Support/raw_ostream.h"
52#define DEBUG_TYPE "linalg-vectorization"
55static FailureOr<Operation *>
59 bool flatten1DDepthwiseConv =
false);
94template <
typename OpType>
97 block.
walk([&](OpType op) {
113 int64_t kwSize,
int strideW,
int dilationW,
114 int64_t wSizeStep,
bool isSingleChanneled) {
116 if (isSingleChanneled) {
121 for (
int64_t kw = 0; kw < kwSize; ++kw) {
122 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
123 result.push_back(vector::ExtractStridedSliceOp::create(
133 for (
int64_t kw = 0; kw < kwSize; ++kw) {
134 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
135 result.push_back(vector::ExtractStridedSliceOp::create(
136 rewriter, loc, input,
153 for (
int64_t kw = 0; kw < kwSize; ++kw) {
154 result.push_back(vector::ExtractOp::create(
165 int64_t wSizeStep,
bool isSingleChanneled) {
167 if (isSingleChanneled) {
171 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
172 result.push_back(vector::ExtractStridedSliceOp::create(
181 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
182 result.push_back(vector::ExtractStridedSliceOp::create(
194 bool isSingleChanneled) {
196 if (isSingleChanneled) {
200 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
201 res = vector::InsertStridedSliceOp::create(
209 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
210 res = vector::InsertStridedSliceOp::create(
211 rewriter, loc, resVals[w], res,
225 LogicalResult initState(
RewriterBase &rewriter, LinalgOp linalgOp,
228 bool assumeDynamicDimsMatchVecSizes =
false);
243 std::optional<AffineMap> dimPermutation = std::nullopt)
const {
246 if (dimPermutation.has_value()) {
252 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
253 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
256 return VectorType::get(
vectorShape, elementType, scalableDims);
265 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
270 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
271 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
277 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
284 Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
286 std::optional<AffineMap> maybeMaskingMap);
291 bool isValidMaskingMap(AffineMap maskingMap) {
310 AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
316 SmallVector<int64_t> iterSpaceStaticSizes;
321 SmallVector<Value> iterSpaceValueSizes;
324 SmallVector<int64_t> canonicalVecShape;
328 SmallVector<bool> scalableVecDims;
336 OpBuilder::InsertionGuard rewriterGuard;
344 bool assumeDynamicDimsMatchVecSizes =
false;
348VectorizationState::precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
351 for (
int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
352 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
355 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
362 unsigned operandDimPos;
363 if (
failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
368 linalgOp.hasPureTensorSemantics()
369 ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
371 : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
373 iterSpaceValueSizes.push_back(dynamicDim);
386 bool assumeDimsMatchVec) {
387 assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
391 if (!inputVectorSizes.empty()) {
395 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
396 scalableVecDims.append(inputScalableVecDims.begin(),
397 inputScalableVecDims.end());
402 canonicalVecShape = linalgOp.getStaticLoopRanges();
403 scalableVecDims.append(linalgOp.getNumLoops(),
false);
406 LDBG() <<
"Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
407 LDBG() <<
"Scalable vector dims: " << llvm::interleaved(scalableVecDims);
409 if (ShapedType::isDynamicShape(canonicalVecShape))
413 initIterSpaceStaticSizes(linalgOp);
418 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
428Value VectorizationState::getOrCreateMaskFor(
430 std::optional<AffineMap> maybeMaskingMap) {
432 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
433 "Ill-formed masking map.");
436 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
440 assert(!maskableOp.isMasked() &&
441 "Masking an operation that is already masked");
444 assert((!maybeMaskingMap || *maybeMaskingMap) &&
445 "Unexpected null mask permutation map");
447 maybeMaskingMap ? *maybeMaskingMap
449 linalgOp.getNumLoops(), rewriter.
getContext());
451 LDBG() <<
"Masking map: " << maskingMap;
455 auto activeMaskIt = activeMaskCache.find(maskingMap);
456 if (activeMaskIt != activeMaskCache.end()) {
457 Value mask = activeMaskIt->second;
458 LDBG() <<
"Reusing mask: " << mask;
468 SmallVector<int64_t> permutedStaticSizes =
470 auto maskType = getCanonicalVecType(rewriter.
getI1Type(), maskingMap);
471 auto maskShape = maskType.getShape();
473 LDBG() <<
"Mask shape: " << llvm::interleaved(maskShape);
475 if (permutedStaticSizes == maskShape) {
476 LDBG() <<
"Masking is not needed for masking map: " << maskingMap;
477 activeMaskCache[maskingMap] = Value();
481 if (assumeDynamicDimsMatchVecSizes) {
485 if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
487 return std::get<0>(it) == ShapedType::kDynamic
489 : std::get<0>(it) == std::get<1>(it);
492 <<
"Dynamic + static dimensions match vector sizes, masking is not "
494 activeMaskCache[maskingMap] = Value();
500 SmallVector<Value> upperBounds =
502 assert(!maskShape.empty() && !upperBounds.empty() &&
503 "Masked 0-d vectors are not supported yet");
506 Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
507 maskType, upperBounds);
508 LDBG() <<
"Creating new mask: " << mask;
509 activeMaskCache[maskingMap] = mask;
516 std::optional<AffineMap> maybeIndexingMap) {
517 LDBG() <<
"Trying to mask: " << *opToMask;
519 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
520 if (maybeIndexingMap)
521 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
525 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
528 LDBG() <<
"No mask required";
529 if (assumeDynamicDimsMatchVecSizes) {
531 .Case<vector::TransferReadOp, vector::TransferWriteOp>(
537 LDBG() <<
"Assuming dynamic dimensions match vector sizes and "
538 "setting their in-bounds to true!";
540 ShapedType xferType = xferOp.getShapedType();
545 for (
unsigned i = 0; i < xferOp.getTransferRank(); i++) {
546 auto dimExpr = dyn_cast<AffineDimExpr>(permMap.
getResult(i));
550 unsigned pos = dimExpr.getPosition();
551 if (xferType.isDynamicDim(pos))
552 inBoundsMap[i] =
true;
555 xferOp.setInBoundsAttr(
567 assert(opToMask &&
"Expected a valid operation to mask");
568 auto maskOp = cast<vector::MaskOp>(
570 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
572 for (
auto [resIdx, resVal] : llvm::enumerate(opToMask->
getResults()))
576 LDBG() <<
"Masked operation: " << *maskOp;
599 "expected projected permutation");
601 assert(res.getNumDims() ==
602 (res.getNumResults() - res.getNumOfZeroResults()) &&
603 "expected reindexed map with same number of dims and results");
639std::optional<vector::CombiningKind>
641 using ::mlir::vector::CombiningKind;
646 .Case<arith::AddIOp, arith::AddFOp>(
647 [&](
auto op) {
return CombiningKind::ADD; })
648 .Case([&](arith::AndIOp op) {
return CombiningKind::AND; })
649 .Case([&](arith::MaxSIOp op) {
return CombiningKind::MAXSI; })
650 .Case([&](arith::MaxUIOp op) {
return CombiningKind::MAXUI; })
651 .Case([&](arith::MaximumFOp op) {
return CombiningKind::MAXIMUMF; })
652 .Case([&](arith::MaxNumFOp op) {
return CombiningKind::MAXNUMF; })
653 .Case([&](arith::MinSIOp op) {
return CombiningKind::MINSI; })
654 .Case([&](arith::MinUIOp op) {
return CombiningKind::MINUI; })
655 .Case([&](arith::MinimumFOp op) {
return CombiningKind::MINIMUMF; })
656 .Case([&](arith::MinNumFOp op) {
return CombiningKind::MINNUMF; })
657 .Case<arith::MulIOp, arith::MulFOp>(
658 [&](
auto op) {
return CombiningKind::MUL; })
659 .Case([&](arith::OrIOp op) {
return CombiningKind::OR; })
660 .Case([&](arith::XOrIOp op) {
return CombiningKind::XOR; })
661 .Default(std::nullopt);
672 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
677 if (!
matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
678 combinerOps.size() != 1)
682 return combinerOps[0];
688 auto dstVecType = dyn_cast<VectorType>(dstType);
690 if (dstVecType.getRank() == 0)
695 Location loc =
b.getInsertionPoint()->getLoc();
696 return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
708 assert(maybeKind &&
"Failed precondition: could not get reduction kind");
709 return vector::MultiDimReductionOp::create(
710 b, reduceOp->
getLoc(), valueToReduce,
acc, dimsToMask, *maybeKind);
714 return llvm::map_to_vector(linalgOp.getIteratorTypesArray(),
721 return isa<linalg::ReduceOp>(op) ||
722 (isa<linalg::GenericOp>(op) &&
734 VectorizationState &state) {
736 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
737 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
746 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
748 auto vectorType = state.getCanonicalVecType(
755 if (vectorType.getRank() > 0) {
758 assert(value.
getType() == vectorType &&
"Incorrect type");
759 write = vector::TransferWriteOp::create(
760 rewriter, loc, value, outputOperand->
get(),
indices, writeMap);
763 if (!isa<VectorType>(value.
getType()))
764 value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
765 assert(value.
getType() == vectorType &&
"Incorrect type");
766 write = vector::TransferWriteOp::create(rewriter, loc, value,
770 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
774 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
775 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
780 LDBG() <<
"vectorized op: " << *write;
790 std::function<LogicalResult(
Operation *,
bool)>;
807 const IRMapping &bvm, VectorizationState &state,
809 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
812 for (
const auto &output : llvm::enumerate(yieldOp.getValues())) {
818 linalgOp.getDpsInitOperand(output.index()), state);
820 newResults.push_back(newResult);
831 VectorizationState &state,
834 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
837 auto loc = indexOp.getLoc();
840 auto dim = indexOp.getDim();
842 auto indexVectorType =
843 VectorType::get({targetShape[dim]}, rewriter.
getIndexType(),
844 state.getScalableVecDims()[dim]);
845 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
849 if (dim == targetShape.size() - 1)
855 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
856 std::swap(permPattern[dim], permPattern.back());
860 auto broadCastOp = vector::BroadcastOp::create(
862 state.getCanonicalVecType(rewriter.
getIndexType(), permMap), indexSteps);
864 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
865 std::swap(transposition.back(), transposition[dim]);
867 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
875 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
879 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
884 if (not extractOp.getIndices().empty()) {
885 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
889 if (!llvm::all_of(extractOp->getResultTypes(),
890 VectorType::isValidElementType)) {
908 VectorizationState &state,
909 tensor::ExtractOp extractOp,
912 auto indexVecType = state.getCanonicalVecType(rewriter.
getIndexType());
913 auto loc = extractOp.getLoc();
916 rewriter, bvm.
lookup(extractOp.getIndices()[0]), indexVecType);
918 const size_t numIndices = extractOp.getIndices().size();
919 for (
size_t i = 1; i < numIndices; i++) {
924 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
927 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
930 rewriter, bvm.
lookup(extractOp.getIndices()[i]), indexVecType);
932 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
958 (linalgOp.hasDynamicShape() ||
959 llvm::count_if(loopRanges, [](
int64_t dim) { return dim != 1; }) == 1) &&
960 "For statically shaped Linalg Ops, only one "
961 "non-unit loop dim is expected");
962 assert(!loopRanges.empty() &&
"Empty loops, nothing to analyse.");
964 size_t idx = loopRanges.size() - 1;
965 for (; idx != 0; idx--)
966 if (loopRanges[idx] != 1)
974 VectorType resType) {
976 assert(((llvm::count_if(resType.getShape(),
977 [](
int64_t dimSize) { return dimSize > 1; }) == 1)) &&
978 "n-D vectors are not yet supported");
984 auto *block = linalgOp.getBlock();
985 if (isa<BlockArgument>(val))
986 return !llvm::is_contained(block->getArguments(), val);
989 assert(defOp &&
"This is neither a block argument nor an operation result");
994 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
995 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
998 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1005 if (isa<arith::ConstantOp>(ancestor))
1009 for (
auto op : ancestor->getOperands())
1033 bool &foundIndexOp, VectorType resType) {
1035 assert(((llvm::count_if(resType.getShape(),
1036 [](
int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1037 "n-D vectors are not yet supported");
1043 auto *block = linalgOp.getBlock();
1044 if (isa<BlockArgument>(val))
1045 return !llvm::is_contained(block->getArguments(), val);
1048 assert(defOp &&
"This is neither a block argument nor an operation result");
1050 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1053 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1057 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1064 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1068 for (
auto op : ancestor->getOperands())
1088 LinalgOp &linalgOp, VectorType resType) {
1090 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1093 if (inputShape.getShape().empty())
1098 if (resType.getRank() == 0)
1103 bool isOutput1DVector =
1104 (llvm::count_if(resType.getShape(),
1105 [](
int64_t dimSize) { return dimSize > 1; }) == 1);
1107 if (!isOutput1DVector)
1110 bool leadingIdxsLoopInvariant =
true;
1116 auto indices = extractOp.getIndices();
1117 auto leadIndices =
indices.drop_back(1);
1119 for (
auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1120 if (inputShape.getShape()[i] == 1)
1126 if (!leadingIdxsLoopInvariant) {
1127 LDBG() <<
"Found gather load: " << extractOp;
1135 auto extractOpTrailingIdx =
indices.back();
1139 if (leadingIdxsLoopInvariant &&
1141 LDBG() <<
"Found scalar broadcast load: " << extractOp;
1150 bool foundIndexOp =
false;
1152 foundIndexOp, resType);
1155 bool isRowVector = resType.getShape().back() != 1;
1156 isContiguousLoad &= (foundIndexOp && isRowVector);
1158 if (isContiguousLoad) {
1159 LDBG() <<
"Found contigous load: " << extractOp;
1164 LDBG() <<
"Found gather load: " << extractOp;
1172static VectorizationHookResult
1175 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1178 auto loc = extractOp.getLoc();
1181 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1182 auto maskConstantOp = arith::ConstantOp::create(
1186 auto passThruConstantOp = arith::ConstantOp::create(
1192 extractOp.getIndices().size(),
1203 Operation *gatherOp = vector::GatherOp::create(
1204 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1205 maskConstantOp, passThruConstantOp);
1206 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1208 LDBG() <<
"Vectorised as gather load: " << extractOp;
1231 for (
size_t i = 0; i < extractOp.getIndices().size(); i++) {
1232 Value idx = bvm.
lookup(extractOp.getIndices()[i]);
1234 transferReadIdxs.push_back(idx);
1238 auto indexAs1dVector = vector::ShapeCastOp::create(
1240 VectorType::get(resultType.getShape().back(), rewriter.
getIndexType(),
1241 resultType.getScalableDims().back()),
1243 transferReadIdxs.push_back(
1244 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1248 auto dstRank = resultType.getRank();
1249 auto srcRank = extractOp.getTensor().getType().getRank();
1258 auto transferReadOp = vector::TransferReadOp::create(
1259 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1260 std::nullopt, permutationMap, inBounds);
1262 Operation *readOrMaskedReadOp = transferReadOp;
1268 auto readMaskType = VectorType::get(readMaskShape, rewriter.
getI1Type());
1269 auto allTrue = vector::ConstantMaskOp::create(
1271 readOrMaskedReadOp =
1275 LDBG() <<
"Vectorised as scalar broadcast load: " << extractOp;
1277 readOrMaskedReadOp};
1282 srcRank, std::min(dstRank, srcRank), rewriter.
getContext());
1284 int32_t rankDiff = dstRank - srcRank;
1292 while (rankDiff > 0) {
1293 permutationMap = permutationMap.insertResult(
1298 auto transferReadOp = vector::TransferReadOp::create(
1299 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1300 std::nullopt, permutationMap, inBounds);
1302 LDBG() <<
"Vectorised as contiguous load: " << extractOp;
1316 auto reduceType = dyn_cast<VectorType>(reduceVec.
getType());
1317 auto outputType = dyn_cast<VectorType>(outputVec.
getType());
1321 (outputType && reduceType.getShape() == outputType.getShape()))
1346static VectorizationHookResult
1350 LDBG() <<
"vectorize op " << *op;
1353 if (!customVectorizationHooks.empty()) {
1354 for (
auto &customFunc : customVectorizationHooks) {
1364 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1366 rewriter.
clone(*op)};
1375 auto blockArg = dyn_cast<BlockArgument>(operand);
1376 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1377 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1381 linalgOp.getRegionOutputArgs(),
1382 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1385 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1387 if (!reductionOperands.empty()) {
1388 assert(reductionOperands.size() == 1);
1390 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1391 reductionOperands[0].second, bvm);
1398 VectorType firstMaxRankedType;
1400 auto vecOperand = bvm.
lookup(operand);
1401 assert(vecOperand &&
"Vector operand couldn't be found");
1403 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1404 if (vecType && (!firstMaxRankedType ||
1405 firstMaxRankedType.getRank() < vecType.getRank()))
1406 firstMaxRankedType = vecType;
1412 assert(vecOperand &&
"Vector operand couldn't be found");
1414 if (firstMaxRankedType) {
1415 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1417 firstMaxRankedType.getScalableDims());
1420 vecOperands.push_back(vecOperand);
1426 resultTypes.push_back(
1428 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1429 firstMaxRankedType.getScalableDims())
1465 LDBG() <<
"Vectorizing operation as linalg generic/n";
1466 Block *block = linalgOp.getBlock();
1473 bvm.
map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1475 if (linalgOp.getNumDpsInits() == 0)
1481 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1482 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1483 if (linalgOp.isScalar(opOperand)) {
1484 bvm.
map(bbarg, opOperand->get());
1490 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1493 VectorType readType;
1495 if (linalgOp.isDpsInput(opOperand)) {
1498 readType = state.getCanonicalVecType(elemType);
1505 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
1510 Operation *read = vector::TransferReadOp::create(
1511 rewriter, loc, readType, opOperand->get(),
indices,
1512 std::nullopt, readMap);
1513 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1518 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1520 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1526 if (readType.getRank() == 0)
1527 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1530 LDBG() <<
"New vectorized bbarg(" << bbarg.
getArgNumber()
1531 <<
"): " << readValue;
1532 bvm.
map(bbarg, readValue);
1533 bvm.
map(opOperand->get(), readValue);
1542 hooks.push_back(vectorizeYield);
1549 hooks.push_back(vectorizeIndex);
1556 hooks.push_back(vectorizeExtract);
1563 LDBG() <<
"failed to vectorize: " << op;
1568 state.maskOperation(rewriter,
result.newOp, linalgOp);
1569 LDBG() <<
"New vector op: " << *maybeMaskedOp;
1628 if (ShapedType::isDynamicShape(destShape))
1633 for (
auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1635 cstMaskSizes.push_back(*intSize);
1640 if (cstMaskSizes.size() != maskShape.size())
1645 for (
auto [i, idx] : llvm::enumerate(writeIdxs)) {
1648 cstWriteIdxs.push_back(intVal.getSExtValue());
1653 if (cstWriteIdxs.size() != destShape.size())
1662 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1663 for (
auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1664 if ( maskShape[i] > destShape[rankDiff + i] ||
1665 destShape[rankDiff + i] <
1666 (std::clamp(cstMaskSizes[i],
int64_t(0), maskShape[i]) +
1702 bool useInBoundsInsteadOfMasking =
false) {
1704 ShapedType destType = cast<ShapedType>(dest.
getType());
1705 int64_t destRank = destType.getRank();
1706 auto destShape = destType.getShape();
1708 VectorType vecToStoreType = cast<VectorType>(vecToStore.
getType());
1709 int64_t vecToStoreRank = vecToStoreType.getRank();
1710 auto vecToStoreShape = vecToStoreType.getShape();
1713 SmallVector<bool> inBoundsVal(vecToStoreRank,
true);
1714 if (useInBoundsInsteadOfMasking) {
1717 for (
unsigned i = 0; i < vecToStoreRank; i++)
1719 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1720 ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1724 bool useDefaultWriteIdxs = writeIndices.empty();
1725 assert((useDefaultWriteIdxs ||
1726 writeIndices.size() ==
static_cast<size_t>(destRank)) &&
1727 "Invalid number of write indices!");
1728 if (writeIndices.empty()) {
1730 writeIndices.assign(destRank, zero);
1734 Operation *write = vector::TransferWriteOp::create(builder, loc,
1741 if (useInBoundsInsteadOfMasking)
1745 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1749 auto writeMaskType = VectorType::get(vecToStoreShape, builder.
getI1Type(),
1750 vecToStoreType.getScalableDims());
1752 SmallVector<OpFoldResult> destSizes =
1753 isa<MemRefType>(dest.
getType())
1758 SmallVector<OpFoldResult> maskSizes;
1759 if (useDefaultWriteIdxs) {
1760 maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
1763 size_t diff = destShape.size() - vecToStoreRank;
1764 for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
1768 builder.
createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
1769 maskSizes.push_back(OpFoldResult(neg));
1777 Value maskForWrite =
1778 builder.
createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1800 assert(type.getNumScalableDims() < 2 &&
1801 "Collapsing more than 1 scalable dim is not supported ATM");
1807 auto shape = type.getShape();
1808 auto scalableFlags = type.getScalableDims();
1812 unsigned currentDim = 0;
1814 unsigned dim = m.getNumResults();
1817 for (
unsigned d = 0; d < dim; ++d) {
1818 size *=
shape[currentDim + d];
1819 flag |= scalableFlags[currentDim + d];
1821 newShape.push_back(size);
1822 newScalableFlags.push_back(flag);
1826 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1859vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1860 ArrayRef<int64_t> inputVectorSizes,
1861 SmallVectorImpl<Value> &newResults) {
1862 if (!inputVectorSizes.empty()) {
1863 assert(inputVectorSizes.size() == packOp.getDestRank() &&
1864 "Invalid number of input vector sizes!");
1868 OpBuilder::InsertionGuard g(rewriter);
1871 Location loc = packOp.getLoc();
1872 std::optional<Value> padValue = packOp.getPaddingValue()
1873 ? std::optional(packOp.getPaddingValue())
1876 SmallVector<int64_t> destShape =
1877 SmallVector<int64_t>(packOp.getDestType().getShape());
1881 ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
1885 bool useInBoundsInsteadOfMasking =
false;
1886 if (writeVectorSizes.empty()) {
1887 if (ShapedType::isDynamicShape(destShape))
1889 "unable to infer vector sizes");
1891 writeVectorSizes = destShape;
1892 useInBoundsInsteadOfMasking =
true;
1901 PackingMetadata packMetadata;
1902 SmallVector<int64_t> preTransposeWriteVecSizses(writeVectorSizes);
1905 auto preTransposeWriteVecType =
1906 VectorType::get(preTransposeWriteVecSizses,
1907 packOp.getResult().getType().getElementType());
1913 preTransposeWriteVecType,
1915 rewriter.
getContext(), packMetadata.reassociations)));
1919 rewriter, loc, packOp.getSource(), readVecType, padValue,
1920 useInBoundsInsteadOfMasking);
1923 auto shapeCastOp = vector::ShapeCastOp::create(
1924 rewriter, loc, preTransposeWriteVecType, maskedRead);
1928 auto transposeOp = vector::TransposeOp::create(
1929 rewriter, loc, shapeCastOp.getResult(), destPermutation);
1933 rewriter, loc, transposeOp.getResult(), packOp.getDest());
1934 newResults.push_back(write->
getResult(0));
1968vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1969 ArrayRef<int64_t> inputVectorSizes,
1970 ArrayRef<bool> inputScalableVecDims,
1971 SmallVectorImpl<Value> &newResults) {
1972 if (!inputVectorSizes.empty()) {
1973 assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1974 "Invalid number of input vector sizes!");
1975 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1976 "Incompatible number of vector sizes and vector scalable flags!");
1980 OpBuilder::InsertionGuard g(rewriter);
1983 ShapedType unpackTensorType = unpackOp.getSourceType();
1985 ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1986 bool useInBoundsInsteadOfMasking =
false;
1988 Location loc = unpackOp->getLoc();
1991 SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1992 SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1995 if (inputVectorSizes.empty()) {
1996 if (ShapedType::isDynamicShape(sourceShape))
1998 "Unable to infer vector sizes!");
2000 readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
2001 useInBoundsInsteadOfMasking =
true;
2005 VectorType readVecType =
2006 VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
2007 readScalableVectorFlags);
2009 rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
2010 useInBoundsInsteadOfMasking);
2013 PackingMetadata packMetadata;
2014 SmallVector<int64_t> lastDimToInsertPosPerm =
2016 vector::TransposeOp transposeOp = vector::TransposeOp::create(
2017 rewriter, loc, readResult, lastDimToInsertPosPerm);
2021 transposeOp.getType(),
2023 rewriter.
getContext(), packMetadata.reassociations)));
2024 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
2025 rewriter, loc, collapsedVecType, transposeOp->getResult(0));
2029 rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
2030 {}, useInBoundsInsteadOfMasking);
2032 newResults.push_back(write->
getResult(0));
2040vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
2041 ArrayRef<int64_t> inputVectorSizes,
2042 SmallVectorImpl<Value> &newResults) {
2043 auto padValue = padOp.getConstantPaddingValue();
2044 Location loc = padOp.getLoc();
2047 OpBuilder::InsertionGuard g(rewriter);
2051 LogicalResult status =
2052 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
2053 .reifyResultShapes(rewriter, reifiedReturnShapes);
2055 assert(succeeded(status) &&
"failed to reify result shapes");
2056 auto readType = VectorType::get(inputVectorSizes, padValue.getType());
2058 rewriter, loc, padOp.getSource(), readType, padValue,
2062 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2063 padOp.getResultType().getElementType());
2065 newResults.push_back(write->
getResult(0));
2071static LogicalResult reductionPreconditions(LinalgOp op) {
2073 LDBG() <<
"reduction precondition failed: no reduction iterator";
2076 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2077 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2083 LDBG() <<
"reduction precondition failed: reduction detection failed";
2091vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2092 bool flatten1DDepthwiseConv) {
2093 if (flatten1DDepthwiseConv) {
2094 LDBG() <<
"Vectorization of flattened convs with dynamic shapes is not "
2100 LDBG() <<
"Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2106 Value
lhs = conv.getDpsInputOperand(0)->get();
2107 ArrayRef<int64_t> lhsShape = cast<ShapedType>(
lhs.getType()).getShape();
2108 auto shapeWithoutCh = lhsShape.drop_back(1);
2109 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2110 LDBG() <<
"Dynamically-shaped op vectorization precondition failed: only "
2111 "channel dim can be dynamic";
2119vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2120 bool flatten1DDepthwiseConv) {
2122 return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2125 return reductionPreconditions(op);
2130 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2134 LDBG() <<
"Dynamically-shaped op meets vectorization pre-conditions";
2144vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2145 ArrayRef<int64_t> inputVectorSizes) {
2147 if (!unpackOp.hasPureTensorSemantics())
2152 if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2153 unpackOp.getSourceType().hasStaticShape())
2158 if (!inputVectorSizes.empty() &&
2159 (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2160 LDBG() <<
"Incorrect number of input vector sizes";
2166 unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2167 LDBG() <<
"Invalid vector sizes for the read operation";
2175vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2176 ArrayRef<int64_t> inputVectorSizes) {
2179 auto sourceType = source.getType();
2180 if (!VectorType::isValidElementType(sourceType.getElementType()))
2196 bool isOutOfBoundsRead =
2197 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2199 if (!padValue && isOutOfBoundsRead) {
2200 LDBG() <<
"Failed to get a pad value for out-of-bounds read access";
2214vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2216 SmallVectorImpl<Value> &newResults) {
2217 Location loc = linalgOp.getLoc();
2218 MLIRContext *ctx = linalgOp.getContext();
2223 if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2226 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2230 LDBG() <<
"Failed to determine contraction combining kind.";
2237 AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2238 AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2240 LDBG() <<
"Contractions with broadcasts are not supported.";
2245 SmallVector<Value> vecOperands;
2246 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2250 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2254 VectorType readType =
2255 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
2258 rewriter, loc, opOperand.get(), readType,
2259 arith::getZeroConstant(rewriter, loc, elemType),
2261 vecOperands.push_back(read);
2265 SmallVector<Attribute> iterAttrs;
2266 auto iterators = linalgOp.getIteratorTypesArray();
2267 for (utils::IteratorType iter : iterators) {
2268 auto vecIter = iter == utils::IteratorType::parallel
2269 ? vector::IteratorType::parallel
2270 : vector::IteratorType::reduction;
2271 iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2275 Operation *contractOp = vector::ContractionOp::create(
2276 rewriter, loc, vecOperands[0],
2277 vecOperands[1], vecOperands[2],
2278 linalgOp.getIndexingMaps(), rewriter.
getArrayAttr(iterAttrs), *maybeKind);
2279 contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2283 rewriter, loc, contractOp->
getResult(0), outOperand->
get());
2287 newResults.push_back(write->
getResult(0));
// Classifies the reduction body of a convolution-like linalg op: Conv when the
// reduced value is produced by a multiply (or and-i) of block arguments, Pool
// when it is a (possibly cast) block argument fed directly into the reduction
// (see getConvOperationKind below).
 2293 enum class ConvOperationKind { Conv, Pool };
2296static bool isCastOfBlockArgument(Operation *op) {
2311static std::optional<ConvOperationKind>
2312getConvOperationKind(Operation *reduceOp) {
2313 int numBlockArguments =
2314 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
2316 switch (numBlockArguments) {
2322 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
2323 llvm::IsaPred<BlockArgument>);
2325 "Expected a non-block argument operand");
2326 Operation *feedOp = (*feedValIt).getDefiningOp();
2327 if (isCastOfBlockArgument(feedOp)) {
2328 return ConvOperationKind::Pool;
2331 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2332 (isa<arith::AndIOp>(feedOp) &&
2335 if (isa<BlockArgument>(v))
2337 if (Operation *op = v.getDefiningOp())
2338 return isCastOfBlockArgument(op);
2341 return std::nullopt;
2344 return ConvOperationKind::Conv;
2348 return ConvOperationKind::Pool;
2350 return std::nullopt;
2354static bool isSupportedPoolKind(vector::CombiningKind kind) {
2356 case vector::CombiningKind::ADD:
2357 case vector::CombiningKind::MAXNUMF:
2358 case vector::CombiningKind::MAXIMUMF:
2359 case vector::CombiningKind::MAXSI:
2360 case vector::CombiningKind::MAXUI:
2361 case vector::CombiningKind::MINNUMF:
2362 case vector::CombiningKind::MINIMUMF:
2363 case vector::CombiningKind::MINSI:
2364 case vector::CombiningKind::MINUI:
2371static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2372 auto getOperandType = [&](
auto operand) {
2373 return dyn_cast<ShapedType>((operand->get()).getType());
2375 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2376 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2377 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2381 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2382 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2389 auto maybeOper = getConvOperationKind(reduceOp);
2390 if (!maybeOper.has_value())
2397 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2398 *maybeKind != vector::CombiningKind::OR) &&
2399 (*maybeOper != ConvOperationKind::Pool ||
2400 !isSupportedPoolKind(*maybeKind)))) {
2404 auto rhsRank = rhsShapedType.getRank();
2405 if (*maybeOper == ConvOperationKind::Pool) {
2409 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2416static LogicalResult vectorizeLinalgOpPrecondition(
2417 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2418 bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
2420 if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2421 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2425 if (!inputVectorSizes.empty() &&
2430 if (linalgOp.hasDynamicShape() &&
failed(vectorizeDynamicLinalgOpPrecondition(
2431 linalgOp, flatten1DDepthwiseConv))) {
2432 LDBG() <<
"Dynamically-shaped op failed vectorization pre-conditions";
2436 SmallVector<CustomVectorizationPrecondition> customPreconditions;
2442 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2445 customPreconditions,
2448 customPrecondition(&innerOp, vectorizeNDExtract));
2452 if (!llvm::all_of(innerOp.getOperandTypes(),
2453 VectorType::isValidElementType)) {
2456 if (!llvm::all_of(innerOp.getResultTypes(),
2457 VectorType::isValidElementType)) {
2466 return vectorizeConvOpPrecondition(linalgOp);
2472 LDBG() <<
"precondition failed: not projected permutations";
2475 if (
failed(reductionPreconditions(linalgOp))) {
2476 LDBG() <<
"precondition failed: reduction preconditions";
2483vectorizePackOpPrecondition(linalg::PackOp packOp,
2484 ArrayRef<int64_t> inputVectorSizes) {
2486 if (!packOp.hasPureTensorSemantics())
2489 auto padValue = packOp.getPaddingValue();
2493 LDBG() <<
"pad value is not constant: " << packOp;
2497 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2498 bool satisfyEmptyCond =
true;
2499 if (inputVectorSizes.empty()) {
2500 if (!packOp.getDestType().hasStaticShape() ||
2501 !packOp.getSourceType().hasStaticShape())
2502 satisfyEmptyCond =
false;
2505 if (!satisfyEmptyCond &&
2507 resultTensorShape.take_front(packOp.getSourceRank()),
2511 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2512 return !getConstantIntValue(v).has_value();
2514 LDBG() <<
"inner_tiles must be constant: " << packOp;
2522vectorizePadOpPrecondition(tensor::PadOp padOp,
2523 ArrayRef<int64_t> inputVectorSizes) {
2524 auto padValue = padOp.getConstantPaddingValue();
2526 LDBG() <<
"pad value is not constant: " << padOp;
2530 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2546 if (llvm::any_of(llvm::enumerate(padOp.getMixedLowPad()),
2547 [&](
const auto &en) {
2548 OpFoldResult padValue = en.value();
2549 unsigned pos = en.index();
2550 std::optional<int64_t> pad = getConstantIntValue(padValue);
2551 return (!pad.has_value() || pad.value() != 0) &&
2552 resultTensorShape[pos] != 1;
2554 LDBG() <<
"low pad must all be zero for all non unit dims: " << padOp;
2568vectorizeScalableVectorPrecondition(Operation *op,
2569 ArrayRef<int64_t> inputVectorSizes,
2570 ArrayRef<bool> inputScalableVecDims) {
2571 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2572 "Number of input vector sizes and scalable dims doesn't match");
2574 size_t numOfScalableDims =
2575 llvm::count_if(inputScalableVecDims, [](
bool flag) {
return flag; });
2577 if (numOfScalableDims == 0)
2580 auto linalgOp = dyn_cast<LinalgOp>(op);
2585 return success(isa<linalg::UnPackOp>(op));
2589 if (numOfScalableDims > 2)
2609 bool seenNonUnitParallel =
false;
2610 auto iterators = linalgOp.getIteratorTypesArray();
2611 SmallVector<bool> scalableFlags(inputScalableVecDims);
2612 int64_t idx = scalableFlags.size() - 1;
2613 while (!scalableFlags[idx]) {
2614 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2615 seenNonUnitParallel |=
2616 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2618 iterators.pop_back();
2619 scalableFlags.pop_back();
2624 switch (iterators.back()) {
2625 case utils::IteratorType::reduction: {
2627 if (iterators.size() != inputVectorSizes.size()) {
2628 LDBG() <<
"Non-trailing reduction dim requested for scalable "
2632 if (isa<linalg::MatmulOp>(op)) {
2634 <<
"Scalable vectorization of the reduction dim in Matmul-like ops "
2640 case utils::IteratorType::parallel: {
2642 if (seenNonUnitParallel) {
2643 LDBG() <<
"Inner parallel dim not requested for scalable "
2655 if (numOfScalableDims == 2) {
2659 if (iterators.back() == utils::IteratorType::reduction) {
2660 LDBG() <<
"Higher dim than the trailing reduction dim requested for "
2665 scalableFlags.pop_back();
2666 iterators.pop_back();
2668 if (!scalableFlags.back() ||
2669 (iterators.back() != utils::IteratorType::parallel))
2677 isa<linalg::BatchMatmulOp>(op) ||
2679 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2684 Operation *op, ArrayRef<int64_t> inputVectorSizes,
2685 ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract,
2686 bool flatten1DDepthwiseConv) {
2691 if (
failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2692 inputScalableVecDims)))
2696 .Case([&](linalg::LinalgOp linalgOp) {
2697 return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2699 flatten1DDepthwiseConv);
2701 .Case([&](tensor::PadOp padOp) {
2702 return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2704 .Case([&](linalg::PackOp packOp) {
2705 return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2707 .Case([&](linalg::UnPackOp unpackOp) {
2708 return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2710 .Case([&](tensor::InsertSliceOp sliceOp) {
2711 return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2713 .Default(failure());
2717static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2718 OpBuilder::InsertionGuard g(rewriter);
2719 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2721 for (
auto op : make_early_inc_range(toReplace)) {
2723 auto expanded = affine::expandAffineExpr(
2725 op.
getOperands().take_front(op.getAffineMap().getNumDims()),
2726 op.
getOperands().take_back(op.getAffineMap().getNumSymbols()));
2732 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2733 tensor::InsertSliceOp>(op);
2737 RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2738 ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract,
2739 bool flatten1DDepthwiseConv,
bool assumeDynamicDimsMatchVecSizes,
2740 bool createNamedContraction) {
2741 LDBG() <<
"Attempting to vectorize: " << *op;
2742 LDBG() <<
"Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2743 LDBG() <<
"Input scalable vector dims: "
2744 << llvm::interleaved(inputScalableVecDims);
2748 flatten1DDepthwiseConv))) {
2749 LDBG() <<
"Vectorization pre-conditions failed";
2754 VectorizationState state(rewriter);
2755 if (
auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2756 if (
failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2757 inputScalableVecDims,
2758 assumeDynamicDimsMatchVecSizes))) {
2759 LDBG() <<
"Vectorization state couldn't be initialized";
2764 SmallVector<Value> results;
2765 auto vectorizeResult =
2767 .Case([&](linalg::LinalgOp linalgOp) {
2771 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2772 flatten1DDepthwiseConv);
2773 if (succeeded(convOr)) {
2774 llvm::append_range(results, (*convOr)->getResults());
2778 LDBG() <<
"Unsupported convolution can't be vectorized.";
2782 if (createNamedContraction &&
2783 isa<ContractionOpInterface>(linalgOp.getOperation()))
2784 return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2788 <<
"Vectorize generic by broadcasting to the canonical vector "
2792 convertAffineApply(rewriter, linalgOp);
2801 .Case([&](tensor::PadOp padOp) {
2802 return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2805 .Case([&](linalg::PackOp packOp) {
2806 return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2809 .Case([&](linalg::UnPackOp unpackOp) {
2810 return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2812 inputScalableVecDims, results);
2814 .Case([&](tensor::InsertSliceOp sliceOp) {
2818 .Default(failure());
2820 if (
failed(vectorizeResult)) {
2821 LDBG() <<
"Vectorization failed";
2825 return VectorizationResult{results};
2829 memref::CopyOp copyOp) {
2830 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2831 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2832 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2837 if (!VectorType::isValidElementType(srcElementType) ||
2838 !VectorType::isValidElementType(dstElementType))
2841 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2842 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2844 Location loc = copyOp->getLoc();
2846 SmallVector<Value>
indices(srcType.getRank(), zero);
2848 Value
readValue = vector::TransferReadOp::create(
2849 rewriter, loc, readType, copyOp.getSource(),
indices,
2852 if (cast<VectorType>(
readValue.getType()).getRank() == 0) {
2853 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2854 ArrayRef<int64_t>());
2856 vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2858 Operation *writeValue = vector::TransferWriteOp::create(
2859 rewriter, loc, readValue, copyOp.getTarget(),
indices,
2870template <
typename OpTy>
2871struct VectorizePadOpUserPattern :
public OpRewritePattern<tensor::PadOp> {
2872 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2874 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2875 PatternRewriter &rewriter)
const final {
2876 bool changed =
false;
2878 for (
auto *user : llvm::to_vector<4>(padOp->getUsers()))
2879 if (
auto op = dyn_cast<OpTy>(user))
2880 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2885 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2886 tensor::PadOp padOp, OpTy op)
const = 0;
2908struct PadOpVectorizationWithTransferReadPattern
2909 :
public VectorizePadOpUserPattern<vector::TransferReadOp> {
2910 using VectorizePadOpUserPattern<
2911 vector::TransferReadOp>::VectorizePadOpUserPattern;
2913 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2914 vector::TransferReadOp xferOp)
const override {
2916 if (!padOp.hasZeroLowPad())
2919 auto padValue = padOp.getConstantPaddingValue();
2923 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2927 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(),
false);
2928 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2930 xferOp.getBaseMutable().assign(padOp.getSource());
2931 xferOp.getPaddingMutable().assign(padValue);
2970struct PadOpVectorizationWithTransferWritePattern
2971 :
public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2972 using VectorizePadOpUserPattern<
2973 vector::TransferWriteOp>::VectorizePadOpUserPattern;
2975 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2976 vector::TransferWriteOp xferOp)
const override {
2978 if (xferOp.getTransferRank() == 0)
2982 if (!padOp.hasZeroLowPad())
2985 auto padValue = padOp.getConstantPaddingValue();
2989 if (!xferOp->hasOneUse())
2991 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2995 if (!trimPadding.hasZeroOffset())
2998 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
3004 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(),
false);
3006 xferOp, padOp.getSource().
getType(), xferOp.getVector(),
3007 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
3009 rewriter.
replaceOp(trimPadding, newXferOp->getResult(0));
3024 bool hasSameTensorSize(Value beforePadding,
3025 tensor::ExtractSliceOp afterTrimming)
const {
3028 if (
auto castOp = beforePadding.
getDefiningOp<tensor::CastOp>())
3029 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
3032 auto t1 = dyn_cast<RankedTensorType>(beforePadding.
getType());
3033 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
3038 if (t1.getRank() != t2.getRank())
3043 for (
unsigned i = 0; i < t1.getRank(); ++i) {
3044 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
3046 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
3051 if (t1.getNumDynamicDims() == 0)
3059 auto beforeSlice = beforePadding.
getDefiningOp<tensor::ExtractSliceOp>();
3063 assert(
static_cast<size_t>(t1.getRank()) ==
3064 beforeSlice.getMixedSizes().size());
3065 assert(
static_cast<size_t>(t2.getRank()) ==
3066 afterTrimming.getMixedSizes().size());
3068 for (
unsigned i = 0; i < t1.getRank(); ++i) {
3070 if (!t1.isDynamicDim(i))
3072 auto size1 = beforeSlice.getMixedSizes()[i];
3073 auto size2 = afterTrimming.getMixedSizes()[i];
3080 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3081 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3087 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3088 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3089 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3090 minOp1.getOperands() == minOp2.getOperands())
3116 if (
auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3117 auto source = bcast.getSource();
3118 if (llvm::dyn_cast<VectorType>(source.getType()))
3126 if (
auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3127 return fill.getInputs()[0];
3132 if (
auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3139 if (
auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3147 if (
auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3155 ArrayRef<int64_t> inputVectorSizes,
3156 SmallVectorImpl<Value> &newResults) {
3158 OpBuilder::InsertionGuard g(rewriter);
3162 auto sourceType = source.getType();
3163 auto resultType = sliceOp.getResultType();
3168 auto elemType = sourceType.getElementType();
3169 padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3174 SmallVector<int64_t> vecShape;
3175 size_t rankDiff = resultType.getRank() - sourceType.getRank();
3176 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3177 if (!inputVectorSizes.empty()) {
3178 vecShape.push_back(inputVectorSizes[i]);
3179 }
else if (!sourceType.isDynamicDim(i)) {
3180 vecShape.push_back(sourceType.getDimSize(i));
3181 }
else if (!resultType.isDynamicDim(i)) {
3187 vecShape.push_back(resultType.getDimSize(rankDiff + i));
3194 auto vecType = VectorType::get(vecShape, sourceType.getElementType());
3197 auto loc = sliceOp.getLoc();
3200 SmallVector<Value> readIndices(
3203 rewriter, loc, source, vecType, padValue,
3204 inputVectorSizes.empty());
3211 writeIndices, inputVectorSizes.empty());
3214 newResults.push_back(write->
getResult(0));
3242struct PadOpVectorizationWithInsertSlicePattern
3243 :
public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3244 using VectorizePadOpUserPattern<
3245 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3247 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3248 tensor::InsertSliceOp insertOp)
const override {
3250 if (!padOp.hasZeroLowPad())
3253 if (!insertOp.hasUnitStride())
3256 auto padValue = padOp.getConstantPaddingValue();
3260 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3263 if (insertOp.getDest() == padOp.getResult())
3266 auto vecType = VectorType::get(padOp.getType().getShape(),
3267 padOp.getType().getElementType());
3268 unsigned vecRank = vecType.getRank();
3269 unsigned tensorRank = insertOp.getType().getRank();
3273 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3274 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3276 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](
auto it) {
3277 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3287 SmallVector<Value> readIndices(
3289 auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3290 vecType, padOp.getSource(),
3291 readIndices, padValue);
3297 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3298 SmallVector<bool> inBounds(vecRank,
true);
3300 insertOp, read, insertOp.getDest(), writeIndices,
3301 ArrayRef<bool>{inBounds});
3308 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3309 patterns.
add<PadOpVectorizationWithTransferReadPattern,
3310 PadOpVectorizationWithTransferWritePattern,
3311 PadOpVectorizationWithInsertSlicePattern>(
3322static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3326 LDBG() <<
"interleavedUses precondition failed, firstOp: " << *firstOp
3327 <<
", second op: " << *secondOp;
3330 for (
auto v : values) {
3331 for (
auto &u : v.getUses()) {
3332 Operation *owner = u.getOwner();
3333 if (owner == firstOp || owner == secondOp)
3339 LDBG() <<
" found interleaved op " << *owner <<
", firstOp: " << *firstOp
3340 <<
", second op: " << *secondOp;
3349static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3350 memref::SubViewOp subViewOp;
3352 if (
auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3354 return memref::SubViewOp();
3355 subViewOp = newSubViewOp;
3364 vector::TransferReadOp xferOp, PatternRewriter &rewriter)
const {
3367 if (xferOp.getMask())
3371 Value viewOrAlloc = xferOp.getBase();
3377 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3380 Value subView = subViewOp.getResult();
3383 memref::CopyOp copyOp;
3384 for (
auto &u : subView.
getUses()) {
3385 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3386 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3387 if (newCopyOp.getTarget() != subView)
3389 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3401 for (
auto &u : viewOrAlloc.
getUses()) {
3402 if (
auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3403 assert(isa<MemRefType>(newFillOp.output().getType()));
3404 if (newFillOp.output() != viewOrAlloc)
3406 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3408 maybeFillOp = newFillOp;
3413 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3415 "padding value does not match fill");
3418 Value in = copyOp.getSource();
3424 auto vectorType = xferOp.getVectorType();
3425 Value res = vector::TransferReadOp::create(
3426 rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3427 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3429 SmallVector<bool>(vectorType.getRank(),
false)));
3432 rewriter.
eraseOp(maybeFillOp);
3442 vector::TransferWriteOp xferOp, PatternRewriter &rewriter)
const {
3444 if (xferOp.getMask())
3448 Value viewOrAlloc = xferOp.getBase();
3454 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3457 Value subView = subViewOp.getResult();
3460 memref::CopyOp copyOp;
3461 for (
auto &u : subViewOp.getResult().getUses()) {
3462 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3463 if (newCopyOp.getSource() != subView)
3465 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3475 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3476 Value out = copyOp.getTarget();
3483 auto vector = xferOp.getVector();
3484 vector::TransferWriteOp::create(
3485 rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3486 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3488 dyn_cast<VectorType>(vector.getType()).getRank(),
false)));
/// Base case of the variadic shape-dimension binding helpers: no output
/// variables remain to bind, so nothing is read from `shapedType`.
3501static void bindShapeDims(ShapedType shapedType) {}
3503template <
int N,
typename IntTy,
typename... IntTy2>
3504static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3505 val = shapedType.getShape()[N];
3506 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3510template <
typename... IntTy>
3511static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3512 bindShapeDims<0>(shapedType, vals...);
3517static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
3518#define MATCH_1D_CONV_POOL_OP(ConvOpTy) \
3519 if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op)) \
3541#undef MATCH_1D_CONV_POOL_OP
3543 return std::nullopt;
3581struct Conv1DGenerator
3582 :
public StructuredGenerator<LinalgOp, utils::IteratorType> {
3585 static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
3586 LinalgOp linalgOp) {
3589 std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
3593 int strideW =
static_cast<int>(convParams->strides.front());
3594 int dilationW =
static_cast<int>(convParams->dilations.front());
3595 return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
3599 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp,
int strideW,
3601 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3602 strideW(strideW), dilationW(dilationW) {
3604 lhsShaped = linalgOp.getDpsInputOperand(0)->
get();
3605 rhsShaped = linalgOp.getDpsInputOperand(1)->
get();
3606 resShaped = linalgOp.getDpsInitOperand(0)->
get();
3607 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3608 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3609 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3614 setConvOperationKind(reduceOp);
3617 reductionKind = maybeKind.value();
3640 int64_t nSize, wSize, cSize, kwSize, fSize;
3641 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3643 switch (conv1DOpOrder) {
3646 nSize = fSize = cSize = 0;
3648 bindShapeDims(resShapedType, wSize);
3650 bindShapeDims(rhsShapedType, kwSize);
3653 (wSize + kwSize - 1)};
3654 rhsShape = {kwSize};
3659 bindShapeDims(resShapedType, nSize, wSize, fSize);
3661 case ConvOperationKind::Conv:
3663 bindShapeDims(rhsShapedType, kwSize, cSize);
3665 case ConvOperationKind::Pool:
3667 bindShapeDims(rhsShapedType, kwSize);
3675 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3679 case ConvOperationKind::Conv:
3680 rhsShape = {kwSize, cSize, fSize};
3682 case ConvOperationKind::Pool:
3683 rhsShape = {kwSize};
3686 resShape = {nSize, wSize, fSize};
3690 bindShapeDims(resShapedType, nSize, fSize, wSize);
3692 case ConvOperationKind::Conv:
3694 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3696 case ConvOperationKind::Pool:
3698 bindShapeDims(rhsShapedType, kwSize);
3702 lhsShape = {nSize, cSize,
3706 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3709 case ConvOperationKind::Conv:
3710 rhsShape = {fSize, cSize, kwSize};
3712 case ConvOperationKind::Pool:
3713 rhsShape = {kwSize};
3716 resShape = {nSize, fSize, wSize};
3720 vector::TransferWriteOp write;
3726 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3728 Type lhsEltType = lhsShapedType.getElementType();
3729 Type rhsEltType = rhsShapedType.getElementType();
3730 Type resEltType = resShapedType.getElementType();
3731 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3732 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3733 auto resType = VectorType::get(resShape, resEltType);
3735 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3736 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3737 SmallVector<Value> resPadding(resShape.size(), zero);
3740 Value
lhs = vector::TransferReadOp::create(
3741 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3742 arith::getZeroConstant(rewriter, loc, lhsEltType));
3744 Value
rhs =
nullptr;
3745 if (oper == ConvOperationKind::Conv)
3746 rhs = vector::TransferReadOp::create(
3747 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3748 arith::getZeroConstant(rewriter, loc, rhsEltType));
3749 Value res = vector::TransferReadOp::create(
3750 rewriter, loc, resType, resShaped, resPadding,
3751 arith::getZeroConstant(rewriter, loc, resEltType));
3756 switch (conv1DOpOrder) {
3764 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3765 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, permLhs);
3767 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3770 if (oper == ConvOperationKind::Conv)
3771 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, permRhs);
3773 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3774 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3783 SmallVector<Value> lhsVals, rhsVals, resVals;
3785 kwSize, strideW, dilationW, wSizeStep,
3788 if (oper == ConvOperationKind::Conv)
3791 wSizeStep, isSingleChanneled);
3793 auto linearIndex = [&](int64_t kw, int64_t w) {
3794 return kw * (wSize / wSizeStep) + w;
3800 for (int64_t kw = 0; kw < kwSize; ++kw) {
3801 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3803 case ConvOperationKind::Conv:
3804 if (isSingleChanneled) {
3805 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3806 lhsVals[linearIndex(kw, w)],
3807 rhsVals[kw], resVals[w]);
3809 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3810 lhsVals[linearIndex(kw, w)],
3811 rhsVals[kw], resVals[w]);
3814 case ConvOperationKind::Pool:
3815 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3831 switch (conv1DOpOrder) {
3838 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3839 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3844 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3850 Value
promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3853 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3854 if (srcElementType == dstElementType)
3861 if (
auto shapedType = dyn_cast<ShapedType>(val.
getType()))
3862 dstType = shapedType.cloneWith(std::nullopt, dstElementType);
3864 dstType = dstElementType;
3866 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3867 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3870 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3871 srcWidth < dstWidth)
3872 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3874 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3875 srcWidth < dstWidth)
3876 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3878 assert(
false &&
"unhandled promotion case");
3883 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3884 Value
lhs, Value
rhs, Value res) {
3885 vector::IteratorType par = vector::IteratorType::parallel;
3886 vector::IteratorType red = vector::IteratorType::reduction;
3887 AffineExpr n, w, f, c;
3891 auto contrationOp = vector::ContractionOp::create(
3892 rewriter, loc,
lhs,
rhs, res,
3893 MapList{{n, w, c}, {c, f}, {n, w, f}},
3894 ArrayRef<vector::IteratorType>{par, par, par, red});
3895 contrationOp.setKind(reductionKind);
3896 return contrationOp;
3901 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3902 Value
lhs, Value
rhs, Value res) {
3905 return vector::OuterProductOp::create(rewriter, loc, res.
getType(),
lhs,
3906 rhs, res, vector::CombiningKind::ADD);
3910 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value
lhs,
3928 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3929 bool channelDimScalableFlag,
3931 bool scalableChDim =
false;
3932 bool useMasking =
false;
3933 int64_t nSize, wSize, cSize, kwSize;
3935 bindShapeDims(rhsShapedType, kwSize, cSize);
3936 if (ShapedType::isDynamic(cSize)) {
3937 assert(channelDimVecSize != 0 &&
"Channel dim vec size must be > 0");
3938 cSize = channelDimVecSize;
3942 scalableChDim = channelDimScalableFlag;
3946 assert(!(useMasking && flatten) &&
3947 "Unsupported flattened conv with dynamic shapes");
3950 bindShapeDims(resShapedType, nSize, wSize);
3952 vector::TransferWriteOp write;
3958 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3960 Type lhsEltType = lhsShapedType.getElementType();
3961 Type rhsEltType = rhsShapedType.getElementType();
3962 Type resEltType = resShapedType.getElementType();
3963 VectorType lhsType = VectorType::get(
3967 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3969 lhsEltType, {
false,
false, scalableChDim});
3970 VectorType rhsType =
3971 VectorType::get({kwSize, cSize}, rhsEltType,
3972 {
false, scalableChDim});
3973 VectorType resType =
3974 VectorType::get({nSize, wSize, cSize}, resEltType,
3975 {
false,
false, scalableChDim});
3979 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3980 ArrayRef<bool> scalableDims,
3981 Operation *opToMask) {
3985 VectorType::get(maskShape, rewriter.
getI1Type(), scalableDims);
3987 SmallVector<bool> inBounds(maskShape.size(),
true);
3988 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3989 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3993 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3996 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
4003 Value
lhs = vector::TransferReadOp::create(
4004 rewriter, loc, lhsType, lhsShaped,
ValueRange{zero, zero, zero},
4005 arith::getZeroConstant(rewriter, loc, lhsEltType));
4006 auto *maybeMaskedLhs = maybeMaskXferOp(
4007 lhsType.getShape(), lhsType.getScalableDims(),
lhs.getDefiningOp());
4010 Value
rhs = vector::TransferReadOp::create(
4011 rewriter, loc, rhsType, rhsShaped,
ValueRange{zero, zero},
4012 arith::getZeroConstant(rewriter, loc, rhsEltType));
4013 auto *maybeMaskedRhs = maybeMaskXferOp(
4014 rhsType.getShape(), rhsType.getScalableDims(),
rhs.getDefiningOp());
4017 Value res = vector::TransferReadOp::create(
4018 rewriter, loc, resType, resShaped,
ValueRange{zero, zero, zero},
4019 arith::getZeroConstant(rewriter, loc, resEltType));
4020 auto *maybeMaskedRes = maybeMaskXferOp(
4021 resType.getShape(), resType.getScalableDims(), res.
getDefiningOp());
4027 SmallVector<Value> lhsVals, rhsVals, resVals;
4028 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
4029 SmallVector<int64_t> inOutStrides = {1, 1, 1};
4033 for (int64_t kw = 0; kw < kwSize; ++kw) {
4034 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4035 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
4036 rewriter, loc, maybeMaskedLhs->getResult(0),
4037 ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
4038 inOutSliceSizes, inOutStrides));
4042 for (int64_t kw = 0; kw < kwSize; ++kw) {
4044 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
4045 ArrayRef<int64_t>{kw}));
4048 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4049 resVals.push_back(vector::ExtractStridedSliceOp::create(
4050 rewriter, loc, maybeMaskedRes->getResult(0),
4051 ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
4055 auto linearIndex = [&](int64_t kw, int64_t w) {
4056 return kw * (wSize / wSizeStep) + w;
4061 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
4062 auto lhsTypeAfterFlattening =
4063 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
4064 auto resTypeAfterFlattening =
4065 VectorType::get(inOutFlattenSliceSizes, resEltType);
4068 for (int64_t kw = 0; kw < kwSize; ++kw) {
4069 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4070 Value lhsVal = lhsVals[linearIndex(kw, w)];
4071 Value resVal = resVals[w];
4076 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
4077 lhsVals[linearIndex(kw, w)]);
4078 resVal = vector::ShapeCastOp::create(
4079 rewriter, loc, resTypeAfterFlattening, resVals[w]);
4081 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
4082 rhsVals[kw], resVal, flatten);
4085 resVals[w] = vector::ShapeCastOp::create(
4086 rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
4093 if (!llvm::all_of(resVals, [](Value v) {
return v; })) {
4095 for (
auto &collection :
4096 {resVals, rhsVals, lhsVals, {res,
rhs,
lhs, zero}})
4097 for (Value v : collection)
4104 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4105 maybeMaskedRes = vector::InsertStridedSliceOp::create(
4106 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
4107 ArrayRef<int64_t>{0, w, 0},
4108 ArrayRef<int64_t>{1, 1, 1});
4115 Operation *resOut = vector::TransferWriteOp::create(
4116 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4118 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4126 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4127 Value
lhs, Value
rhs, Value res,
4129 auto rhsTy = cast<ShapedType>(
rhs.getType());
4130 auto resTy = cast<ShapedType>(res.
getType());
4144 auto rhsSize = cast<VectorType>(
rhs.getType()).getShape()[0];
4145 auto resSize = cast<VectorType>(res.
getType()).getShape()[1];
4147 SmallVector<int64_t, 16>
indices;
4148 for (
int i = 0; i < resSize / rhsSize; ++i) {
4149 for (
int j = 0; j < rhsSize; ++j)
4156 rhs = vector::BroadcastOp::create(rewriter, loc,
4157 resTy.clone(rhsTy.getElementType()),
rhs);
4164 if (isa<FloatType>(resTy.getElementType()))
4165 return vector::FMAOp::create(rewriter, loc,
lhs,
rhs, res);
4167 auto mul = arith::MulIOp::create(rewriter, loc,
lhs,
rhs);
4168 return arith::AddIOp::create(rewriter, loc,
mul, res);
4173 FailureOr<Operation *> generateNonChanneledConv() {
4176 if (!iters({Par(), Red()}))
4178 "failed to match conv::W 1-par 1-red");
4181 if (layout({ {w + kw},
4191 FailureOr<Operation *> generateNwcConv() {
4192 AffineExpr n, w, f, kw, c;
4194 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4196 op,
"failed to match conv::Nwc 3-par 2-red");
4199 if (layout({ {n, strideW * w + dilationW * kw, c},
4209 FailureOr<Operation *> generateNcwConv() {
4210 AffineExpr n, w, f, kw, c;
4212 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4214 op,
"failed to match conv::Ncw 3-par 2-red");
4216 if (layout({ {n, c, strideW * w + dilationW * kw},
4226 FailureOr<Operation *> generateNwcPooling() {
4227 AffineExpr n, w, c, kw;
4229 if (!iters({Par(), Par(), Par(), Red()}))
4231 "failed to match pooling 3-par 1-red");
4234 if (layout({ {n, strideW * w + dilationW * kw, c},
4244 FailureOr<Operation *> generateNcwPooling() {
4245 AffineExpr n, w, c, kw;
4247 if (!iters({Par(), Par(), Par(), Red()}))
4249 "failed to match pooling 3-par 1-red");
4251 if (layout({ {n, c, strideW * w + dilationW * kw},
4261 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4262 bool vecChDimScalableFlag =
false,
4263 bool flatten =
false) {
4264 AffineExpr n, w, c, kw;
4266 if (!iters({Par(), Par(), Par(), Red()}))
4268 op,
"failed to match depthwise::Nwc conv 3-par 1-red");
4271 if (layout({ {n, strideW * w + dilationW * kw, c},
4274 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4280 ConvOperationKind oper = ConvOperationKind::Conv;
4282 StringAttr poolExtOp;
4283 bool isPoolExt =
false;
4284 int strideW, dilationW;
4285 Value lhsShaped, rhsShaped, resShaped;
4286 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4287 vector::CombiningKind reductionKind;
4290 void setConvOperationKind(Operation *reduceOp) {
4291 int numBlockArguments =
4292 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
4293 if (numBlockArguments == 1) {
4298 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
4299 llvm::IsaPred<BlockArgument>);
4300 Operation *feedOp = (*feedValIt).getDefiningOp();
4301 if (isCastOfBlockArgument(feedOp)) {
4302 oper = ConvOperationKind::Pool;
4307 oper = ConvOperationKind::Conv;
4311 oper = ConvOperationKind::Pool;
4320 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4321 ArrayRef<bool> inputScalableVecDims,
bool flatten1DDepthwiseConv) {
4322 FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
4325 auto res = conv1dGen->generateNonChanneledConv();
4328 res = conv1dGen->generateNwcConv();
4331 res = conv1dGen->generateNcwConv();
4334 res = conv1dGen->generateNwcPooling();
4337 res = conv1dGen->generateNcwPooling();
4344 uint64_t vecChDimSize = ShapedType::kDynamic;
4345 bool vecChDimScalableFlag =
false;
4346 if (!inputVecSizes.empty()) {
4351 "Not a 1D depthwise conv!");
4352 size_t chDimIdx = 0;
4358 vecChDimSize = inputVecSizes[chDimIdx];
4359 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4361 return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4362 flatten1DDepthwiseConv);
4365struct VectorizeConvolution :
public OpInterfaceRewritePattern<LinalgOp> {
4368 LogicalResult matchAndRewrite(LinalgOp op,
4369 PatternRewriter &rewriter)
const override {
4371 if (
failed(resultOrFail))
4373 Operation *newOp = *resultOrFail;
4375 rewriter.
eraseOp(op.getOperation());
4378 assert(newOp->
getNumResults() == 1 &&
"expected single result");
4385 RewritePatternSet &patterns, PatternBenefit benefit) {
4386 patterns.
add<VectorizeConvolution>(patterns.
getContext(), benefit);
static std::optional< VectorShape > vectorShape(Type type)
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static bool isMaskTriviallyFoldable(SmallVector< OpFoldResult > &maskSizes, SmallVector< Value > &writeIdxs, ArrayRef< int64_t > destShape, ArrayRef< int64_t > maskShape)
Determines whether a mask for xfer_write is trivially "all true".
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static VectorizationHookResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
static LogicalResult vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::InsertSliceOp with:
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, Value dest, SmallVector< Value > writeIndices={}, bool useInBoundsInsteadOfMasking=false)
Creates an optionally masked TransferWriteOp.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
#define MATCH_1D_CONV_POOL_OP(ConvOpTy)
static VectorizationHookResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
static VectorType getCollapsedVecType(VectorType type, ArrayRef< AffineMap > reassociation)
Given the re-associations, "collapses" the input Vector type.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
A dimensional identifier appearing in an affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
operand_iterator operand_end()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
bool isaConvolutionOpOfType(LinalgOp op)
Returns true if the linalg op is a convolution op of type ConvOpTy.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, const VectorType &vecToReadTy, std::optional< Value > padValue=std::nullopt, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
llvm::SetVector< T, Vector, Set, N > SetVector
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims, bool assumeDynamicDimsMatchVecSizes=false)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
Operation * maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, std::optional< AffineMap > maybeIndexingMap=std::nullopt)
Masks an operation with the canonical vector mask if the operation needs masking.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorizationState(RewriterBase &rewriter)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override