38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/Sequence.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/SmallVectorExtras.h"
42#include "llvm/ADT/TypeSwitch.h"
43#include "llvm/Support/DebugLog.h"
44#include "llvm/Support/InterleavedRange.h"
45#include "llvm/Support/MathExtras.h"
46#include "llvm/Support/raw_ostream.h"
52#define DEBUG_TYPE "linalg-vectorization"
55static FailureOr<Operation *>
59 bool flatten1DDepthwiseConv =
false);
94template <
typename OpType>
97 block.
walk([&](OpType op) {
113 int64_t kwSize,
int strideW,
int dilationW,
114 int64_t wSizeStep,
bool isSingleChanneled) {
116 if (isSingleChanneled) {
121 for (
int64_t kw = 0; kw < kwSize; ++kw) {
122 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
123 result.push_back(vector::ExtractStridedSliceOp::create(
133 for (
int64_t kw = 0; kw < kwSize; ++kw) {
134 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
135 result.push_back(vector::ExtractStridedSliceOp::create(
136 rewriter, loc, input,
153 for (
int64_t kw = 0; kw < kwSize; ++kw) {
154 result.push_back(vector::ExtractOp::create(
165 int64_t wSizeStep,
bool isSingleChanneled) {
167 if (isSingleChanneled) {
171 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
172 result.push_back(vector::ExtractStridedSliceOp::create(
181 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
182 result.push_back(vector::ExtractStridedSliceOp::create(
194 bool isSingleChanneled) {
196 if (isSingleChanneled) {
200 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
201 res = vector::InsertStridedSliceOp::create(
209 for (
int64_t w = 0; w < wSize; w += wSizeStep) {
210 res = vector::InsertStridedSliceOp::create(
211 rewriter, loc, resVals[w], res,
225 LogicalResult initState(
RewriterBase &rewriter, LinalgOp linalgOp,
228 bool assumeDynamicDimsMatchVecSizes =
false);
243 std::optional<AffineMap> dimPermutation = std::nullopt)
const {
246 if (dimPermutation.has_value()) {
252 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
253 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
256 return VectorType::get(
vectorShape, elementType, scalableDims);
265 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
270 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
271 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
277 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
284 Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
286 std::optional<AffineMap> maybeMaskingMap);
291 bool isValidMaskingMap(AffineMap maskingMap) {
310 AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
316 SmallVector<int64_t> iterSpaceStaticSizes;
321 SmallVector<Value> iterSpaceValueSizes;
324 SmallVector<int64_t> canonicalVecShape;
328 SmallVector<bool> scalableVecDims;
336 OpBuilder::InsertionGuard rewriterGuard;
344 bool assumeDynamicDimsMatchVecSizes =
false;
348VectorizationState::precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
351 for (
int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
352 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
355 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
362 unsigned operandDimPos;
363 if (
failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
368 linalgOp.hasPureTensorSemantics()
369 ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
371 : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
373 iterSpaceValueSizes.push_back(dynamicDim);
386 bool assumeDimsMatchVec) {
387 assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
391 if (!inputVectorSizes.empty()) {
395 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
396 scalableVecDims.append(inputScalableVecDims.begin(),
397 inputScalableVecDims.end());
402 canonicalVecShape = linalgOp.getStaticLoopRanges();
403 scalableVecDims.append(linalgOp.getNumLoops(),
false);
406 LDBG() <<
"Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
407 LDBG() <<
"Scalable vector dims: " << llvm::interleaved(scalableVecDims);
409 if (ShapedType::isDynamicShape(canonicalVecShape))
413 initIterSpaceStaticSizes(linalgOp);
418 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
428Value VectorizationState::getOrCreateMaskFor(
430 std::optional<AffineMap> maybeMaskingMap) {
432 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
433 "Ill-formed masking map.");
436 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
440 assert(!maskableOp.isMasked() &&
441 "Masking an operation that is already masked");
444 assert((!maybeMaskingMap || *maybeMaskingMap) &&
445 "Unexpected null mask permutation map");
447 maybeMaskingMap ? *maybeMaskingMap
449 linalgOp.getNumLoops(), rewriter.
getContext());
451 LDBG() <<
"Masking map: " << maskingMap;
455 auto activeMaskIt = activeMaskCache.find(maskingMap);
456 if (activeMaskIt != activeMaskCache.end()) {
457 Value mask = activeMaskIt->second;
458 LDBG() <<
"Reusing mask: " << mask;
468 SmallVector<int64_t> permutedStaticSizes =
470 auto maskType = getCanonicalVecType(rewriter.
getI1Type(), maskingMap);
471 auto maskShape = maskType.getShape();
473 LDBG() <<
"Mask shape: " << llvm::interleaved(maskShape);
475 if (permutedStaticSizes == maskShape) {
476 LDBG() <<
"Masking is not needed for masking map: " << maskingMap;
477 activeMaskCache[maskingMap] = Value();
481 if (assumeDynamicDimsMatchVecSizes) {
485 if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
487 return std::get<0>(it) == ShapedType::kDynamic
489 : std::get<0>(it) == std::get<1>(it);
492 <<
"Dynamic + static dimensions match vector sizes, masking is not "
494 activeMaskCache[maskingMap] = Value();
500 SmallVector<Value> upperBounds =
502 assert(!maskShape.empty() && !upperBounds.empty() &&
503 "Masked 0-d vectors are not supported yet");
506 Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
507 maskType, upperBounds);
508 LDBG() <<
"Creating new mask: " << mask;
509 activeMaskCache[maskingMap] = mask;
516 std::optional<AffineMap> maybeIndexingMap) {
517 LDBG() <<
"Trying to mask: " << *opToMask;
519 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
520 if (maybeIndexingMap)
521 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
525 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
528 LDBG() <<
"No mask required";
529 if (assumeDynamicDimsMatchVecSizes) {
531 .Case<vector::TransferReadOp, vector::TransferWriteOp>(
537 LDBG() <<
"Assuming dynamic dimensions match vector sizes and "
538 "setting their in-bounds to true!";
540 ShapedType xferType = xferOp.getShapedType();
545 for (
unsigned i = 0; i < xferOp.getTransferRank(); i++) {
546 auto dimExpr = dyn_cast<AffineDimExpr>(permMap.
getResult(i));
550 unsigned pos = dimExpr.getPosition();
551 if (xferType.isDynamicDim(pos))
552 inBoundsMap[i] =
true;
555 xferOp.setInBoundsAttr(
567 assert(opToMask &&
"Expected a valid operation to mask");
568 auto maskOp = cast<vector::MaskOp>(
570 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
572 for (
auto [resIdx, resVal] : llvm::enumerate(opToMask->
getResults()))
576 LDBG() <<
"Masked operation: " << *maskOp;
599 "expected projected permutation");
601 assert(res.getNumDims() ==
602 (res.getNumResults() - res.getNumOfZeroResults()) &&
603 "expected reindexed map with same number of dims and results");
639std::optional<vector::CombiningKind>
641 using ::mlir::vector::CombiningKind;
646 .Case<arith::AddIOp, arith::AddFOp>(
647 [&](
auto op) {
return CombiningKind::ADD; })
648 .Case([&](arith::AndIOp op) {
return CombiningKind::AND; })
649 .Case([&](arith::MaxSIOp op) {
return CombiningKind::MAXSI; })
650 .Case([&](arith::MaxUIOp op) {
return CombiningKind::MAXUI; })
651 .Case([&](arith::MaximumFOp op) {
return CombiningKind::MAXIMUMF; })
652 .Case([&](arith::MaxNumFOp op) {
return CombiningKind::MAXNUMF; })
653 .Case([&](arith::MinSIOp op) {
return CombiningKind::MINSI; })
654 .Case([&](arith::MinUIOp op) {
return CombiningKind::MINUI; })
655 .Case([&](arith::MinimumFOp op) {
return CombiningKind::MINIMUMF; })
656 .Case([&](arith::MinNumFOp op) {
return CombiningKind::MINNUMF; })
657 .Case<arith::MulIOp, arith::MulFOp>(
658 [&](
auto op) {
return CombiningKind::MUL; })
659 .Case([&](arith::OrIOp op) {
return CombiningKind::OR; })
660 .Case([&](arith::XOrIOp op) {
return CombiningKind::XOR; })
661 .Default(std::nullopt);
672 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
677 if (!
matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
678 combinerOps.size() != 1)
682 return combinerOps[0];
688 auto dstVecType = dyn_cast<VectorType>(dstType);
690 if (dstVecType.getRank() == 0)
695 Location loc =
b.getInsertionPoint()->getLoc();
696 return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
708 assert(maybeKind &&
"Failed precondition: could not get reduction kind");
709 return vector::MultiDimReductionOp::create(
710 b, reduceOp->
getLoc(), valueToReduce,
acc, dimsToMask, *maybeKind);
714 return llvm::map_to_vector(linalgOp.getIteratorTypesArray(),
721 return isa<linalg::ReduceOp>(op) ||
722 (isa<linalg::GenericOp>(op) &&
734 VectorizationState &state) {
736 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
737 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
746 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
748 auto vectorType = state.getCanonicalVecType(
755 if (vectorType.getRank() > 0) {
758 assert(value.
getType() == vectorType &&
"Incorrect type");
759 write = vector::TransferWriteOp::create(
760 rewriter, loc, value, outputOperand->
get(),
indices, writeMap);
763 if (!isa<VectorType>(value.
getType()))
764 value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
765 assert(value.
getType() == vectorType &&
"Incorrect type");
766 write = vector::TransferWriteOp::create(rewriter, loc, value,
770 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
774 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
775 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
780 LDBG() <<
"vectorized op: " << *write;
790 std::function<LogicalResult(
Operation *,
bool)>;
807 const IRMapping &bvm, VectorizationState &state,
809 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
812 for (
const auto &output : llvm::enumerate(yieldOp.getValues())) {
818 linalgOp.getDpsInitOperand(output.index()), state);
820 newResults.push_back(newResult);
831 VectorizationState &state,
834 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
837 auto loc = indexOp.getLoc();
840 auto dim = indexOp.getDim();
842 auto indexVectorType =
843 VectorType::get({targetShape[dim]}, rewriter.
getIndexType(),
844 state.getScalableVecDims()[dim]);
845 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
849 if (dim == targetShape.size() - 1)
855 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
856 std::swap(permPattern[dim], permPattern.back());
860 auto broadCastOp = vector::BroadcastOp::create(
862 state.getCanonicalVecType(rewriter.
getIndexType(), permMap), indexSteps);
864 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
865 std::swap(transposition.back(), transposition[dim]);
867 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
875 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
879 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
884 if (not extractOp.getIndices().empty()) {
885 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
889 if (!llvm::all_of(extractOp->getResultTypes(),
890 VectorType::isValidElementType)) {
908 VectorizationState &state,
909 tensor::ExtractOp extractOp,
912 auto indexVecType = state.getCanonicalVecType(rewriter.
getIndexType());
913 auto loc = extractOp.getLoc();
916 rewriter, bvm.
lookup(extractOp.getIndices()[0]), indexVecType);
918 const size_t numIndices = extractOp.getIndices().size();
919 for (
size_t i = 1; i < numIndices; i++) {
924 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
927 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
930 rewriter, bvm.
lookup(extractOp.getIndices()[i]), indexVecType);
932 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
958 (linalgOp.hasDynamicShape() ||
959 llvm::count_if(loopRanges, [](
int64_t dim) { return dim != 1; }) == 1) &&
960 "For statically shaped Linalg Ops, only one "
961 "non-unit loop dim is expected");
962 assert(!loopRanges.empty() &&
"Empty loops, nothing to analyse.");
964 size_t idx = loopRanges.size() - 1;
965 for (; idx != 0; idx--)
966 if (loopRanges[idx] != 1)
974 VectorType resType) {
976 assert(((llvm::count_if(resType.getShape(),
977 [](
int64_t dimSize) { return dimSize > 1; }) == 1)) &&
978 "n-D vectors are not yet supported");
984 auto *block = linalgOp.getBlock();
985 if (isa<BlockArgument>(val))
986 return !llvm::is_contained(block->getArguments(), val);
989 assert(defOp &&
"This is neither a block argument nor an operation result");
994 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
995 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
998 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1005 if (isa<arith::ConstantOp>(ancestor))
1009 for (
auto op : ancestor->getOperands())
1033 bool &foundIndexOp, VectorType resType) {
1035 assert(((llvm::count_if(resType.getShape(),
1036 [](
int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1037 "n-D vectors are not yet supported");
1043 auto *block = linalgOp.getBlock();
1044 if (isa<BlockArgument>(val))
1045 return !llvm::is_contained(block->getArguments(), val);
1048 assert(defOp &&
"This is neither a block argument nor an operation result");
1050 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1053 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1057 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1064 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1068 for (
auto op : ancestor->getOperands())
1088 LinalgOp &linalgOp, VectorType resType) {
1090 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1093 if (inputShape.getShape().empty())
1098 if (resType.getRank() == 0)
1103 bool isOutput1DVector =
1104 (llvm::count_if(resType.getShape(),
1105 [](
int64_t dimSize) { return dimSize > 1; }) == 1);
1107 if (!isOutput1DVector)
1110 bool leadingIdxsLoopInvariant =
true;
1116 auto indices = extractOp.getIndices();
1117 auto leadIndices =
indices.drop_back(1);
1119 for (
auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1120 if (inputShape.getShape()[i] == 1)
1126 if (!leadingIdxsLoopInvariant) {
1127 LDBG() <<
"Found gather load: " << extractOp;
1135 auto extractOpTrailingIdx =
indices.back();
1139 if (leadingIdxsLoopInvariant &&
1141 LDBG() <<
"Found scalar broadcast load: " << extractOp;
1150 bool foundIndexOp =
false;
1152 foundIndexOp, resType);
1155 bool isRowVector = resType.getShape().back() != 1;
1156 isContiguousLoad &= (foundIndexOp && isRowVector);
1158 if (isContiguousLoad) {
1159 LDBG() <<
"Found contigous load: " << extractOp;
1164 LDBG() <<
"Found gather load: " << extractOp;
1172static VectorizationHookResult
1175 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1178 auto loc = extractOp.getLoc();
1181 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1182 auto maskConstantOp = arith::ConstantOp::create(
1186 auto passThruConstantOp = arith::ConstantOp::create(
1192 extractOp.getIndices().size(),
1203 Operation *gatherOp = vector::GatherOp::create(
1204 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1205 maskConstantOp, passThruConstantOp);
1206 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1208 LDBG() <<
"Vectorised as gather load: " << extractOp;
1231 for (
size_t i = 0; i < extractOp.getIndices().size(); i++) {
1232 Value idx = bvm.
lookup(extractOp.getIndices()[i]);
1234 transferReadIdxs.push_back(idx);
1238 auto indexAs1dVector = vector::ShapeCastOp::create(
1240 VectorType::get(resultType.getShape().back(), rewriter.
getIndexType(),
1241 resultType.getScalableDims().back()),
1243 transferReadIdxs.push_back(
1244 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1248 auto dstRank = resultType.getRank();
1249 auto srcRank = extractOp.getTensor().getType().getRank();
1258 auto transferReadOp = vector::TransferReadOp::create(
1259 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1260 std::nullopt, permutationMap, inBounds);
1262 Operation *readOrMaskedReadOp = transferReadOp;
1268 auto readMaskType = VectorType::get(readMaskShape, rewriter.
getI1Type());
1269 auto allTrue = vector::ConstantMaskOp::create(
1271 readOrMaskedReadOp =
1275 LDBG() <<
"Vectorised as scalar broadcast load: " << extractOp;
1277 readOrMaskedReadOp};
1282 srcRank, std::min(dstRank, srcRank), rewriter.
getContext());
1284 int32_t rankDiff = dstRank - srcRank;
1292 while (rankDiff > 0) {
1293 permutationMap = permutationMap.insertResult(
1298 auto transferReadOp = vector::TransferReadOp::create(
1299 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1300 std::nullopt, permutationMap, inBounds);
1302 LDBG() <<
"Vectorised as contiguous load: " << extractOp;
1316 auto reduceType = dyn_cast<VectorType>(reduceVec.
getType());
1317 auto outputType = dyn_cast<VectorType>(outputVec.
getType());
1321 (outputType && reduceType.getShape() == outputType.getShape()))
1346static VectorizationHookResult
1350 LDBG() <<
"vectorize op " << *op;
1353 if (!customVectorizationHooks.empty()) {
1354 for (
auto &customFunc : customVectorizationHooks) {
1364 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1366 rewriter.
clone(*op)};
1375 auto blockArg = dyn_cast<BlockArgument>(operand);
1376 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1377 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1381 linalgOp.getRegionOutputArgs(),
1382 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1385 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1387 if (!reductionOperands.empty()) {
1388 assert(reductionOperands.size() == 1);
1390 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1391 reductionOperands[0].second, bvm);
1398 VectorType firstMaxRankedType;
1400 auto vecOperand = bvm.
lookup(operand);
1401 assert(vecOperand &&
"Vector operand couldn't be found");
1403 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1404 if (vecType && (!firstMaxRankedType ||
1405 firstMaxRankedType.getRank() < vecType.getRank()))
1406 firstMaxRankedType = vecType;
1412 assert(vecOperand &&
"Vector operand couldn't be found");
1414 if (firstMaxRankedType) {
1415 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1417 firstMaxRankedType.getScalableDims());
1420 vecOperands.push_back(vecOperand);
1426 resultTypes.push_back(
1428 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1429 firstMaxRankedType.getScalableDims())
1465 LDBG() <<
"Vectorizing operation as linalg generic/n";
1466 Block *block = linalgOp.getBlock();
1473 bvm.
map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1475 if (linalgOp.getNumDpsInits() == 0)
1481 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1482 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1483 if (linalgOp.isScalar(opOperand)) {
1484 bvm.
map(bbarg, opOperand->get());
1490 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1493 VectorType readType;
1495 if (linalgOp.isDpsInput(opOperand)) {
1498 readType = state.getCanonicalVecType(elemType);
1505 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
1510 Operation *read = vector::TransferReadOp::create(
1511 rewriter, loc, readType, opOperand->get(),
indices,
1512 std::nullopt, readMap);
1513 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1518 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1520 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1526 if (readType.getRank() == 0)
1527 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1530 LDBG() <<
"New vectorized bbarg(" << bbarg.
getArgNumber()
1531 <<
"): " << readValue;
1532 bvm.
map(bbarg, readValue);
1533 bvm.
map(opOperand->get(), readValue);
1542 hooks.push_back(vectorizeYield);
1549 hooks.push_back(vectorizeIndex);
1556 hooks.push_back(vectorizeExtract);
1563 LDBG() <<
"failed to vectorize: " << op;
1568 state.maskOperation(rewriter,
result.newOp, linalgOp);
1569 LDBG() <<
"New vector op: " << *maybeMaskedOp;
1595 assert(type.getNumScalableDims() < 2 &&
1596 "Collapsing more than 1 scalable dim is not supported ATM");
1602 auto shape = type.getShape();
1603 auto scalableFlags = type.getScalableDims();
1607 unsigned currentDim = 0;
1609 unsigned dim = m.getNumResults();
1612 for (
unsigned d = 0; d < dim; ++d) {
1613 size *=
shape[currentDim + d];
1614 flag |= scalableFlags[currentDim + d];
1616 newShape.push_back(size);
1617 newScalableFlags.push_back(flag);
1621 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1654vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1655 ArrayRef<int64_t> inputVectorSizes,
1656 SmallVectorImpl<Value> &newResults) {
1657 if (!inputVectorSizes.empty()) {
1658 assert(inputVectorSizes.size() == packOp.getDestRank() &&
1659 "Invalid number of input vector sizes!");
1663 OpBuilder::InsertionGuard g(rewriter);
1666 Location loc = packOp.getLoc();
1667 std::optional<Value> padValue = packOp.getPaddingValue()
1668 ? std::optional(packOp.getPaddingValue())
1671 SmallVector<int64_t> destShape =
1672 SmallVector<int64_t>(packOp.getDestType().getShape());
1676 ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
1680 bool useInBoundsInsteadOfMasking =
false;
1681 if (writeVectorSizes.empty()) {
1682 if (ShapedType::isDynamicShape(destShape))
1684 "unable to infer vector sizes");
1686 writeVectorSizes = destShape;
1687 useInBoundsInsteadOfMasking =
true;
1696 PackingMetadata packMetadata;
1697 SmallVector<int64_t> preTransposeWriteVecSizses(writeVectorSizes);
1700 auto preTransposeWriteVecType =
1701 VectorType::get(preTransposeWriteVecSizses,
1702 packOp.getResult().getType().getElementType());
1708 preTransposeWriteVecType,
1710 rewriter.
getContext(), packMetadata.reassociations)));
1714 rewriter, loc, packOp.getSource(), readVecType, padValue,
1715 useInBoundsInsteadOfMasking);
1718 auto shapeCastOp = vector::ShapeCastOp::create(
1719 rewriter, loc, preTransposeWriteVecType, maskedRead);
1723 auto transposeOp = vector::TransposeOp::create(
1724 rewriter, loc, shapeCastOp.getResult(), destPermutation);
1728 rewriter, loc, transposeOp.getResult(), packOp.getDest());
1729 newResults.push_back(write->
getResult(0));
1763vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1764 ArrayRef<int64_t> inputVectorSizes,
1765 ArrayRef<bool> inputScalableVecDims,
1766 SmallVectorImpl<Value> &newResults) {
1767 if (!inputVectorSizes.empty()) {
1768 assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1769 "Invalid number of input vector sizes!");
1770 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1771 "Incompatible number of vector sizes and vector scalable flags!");
1775 OpBuilder::InsertionGuard g(rewriter);
1778 ShapedType unpackTensorType = unpackOp.getSourceType();
1780 ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1781 bool useInBoundsInsteadOfMasking =
false;
1783 Location loc = unpackOp->getLoc();
1786 SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1787 SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1790 if (inputVectorSizes.empty()) {
1791 if (ShapedType::isDynamicShape(sourceShape))
1793 "Unable to infer vector sizes!");
1795 readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1796 useInBoundsInsteadOfMasking =
true;
1800 VectorType readVecType =
1801 VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
1802 readScalableVectorFlags);
1804 rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
1805 useInBoundsInsteadOfMasking);
1808 PackingMetadata packMetadata;
1809 SmallVector<int64_t> lastDimToInsertPosPerm =
1811 vector::TransposeOp transposeOp = vector::TransposeOp::create(
1812 rewriter, loc, readResult, lastDimToInsertPosPerm);
1816 transposeOp.getType(),
1818 rewriter.
getContext(), packMetadata.reassociations)));
1819 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1820 rewriter, loc, collapsedVecType, transposeOp->getResult(0));
1824 rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
1825 {}, useInBoundsInsteadOfMasking);
1827 newResults.push_back(write->
getResult(0));
1835vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1836 ArrayRef<int64_t> inputVectorSizes,
1837 SmallVectorImpl<Value> &newResults) {
1838 auto padValue = padOp.getConstantPaddingValue();
1839 Location loc = padOp.getLoc();
1842 OpBuilder::InsertionGuard g(rewriter);
1846 LogicalResult status =
1847 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1848 .reifyResultShapes(rewriter, reifiedReturnShapes);
1850 assert(succeeded(status) &&
"failed to reify result shapes");
1851 auto readType = VectorType::get(inputVectorSizes, padValue.getType());
1853 rewriter, loc, padOp.getSource(), readType, padValue,
1857 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
1858 padOp.getResultType().getElementType());
1861 newResults.push_back(write->
getResult(0));
1867static LogicalResult reductionPreconditions(LinalgOp op) {
1869 LDBG() <<
"reduction precondition failed: no reduction iterator";
1872 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1873 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1879 LDBG() <<
"reduction precondition failed: reduction detection failed";
1887vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
1888 bool flatten1DDepthwiseConv) {
1889 if (flatten1DDepthwiseConv) {
1890 LDBG() <<
"Vectorization of flattened convs with dynamic shapes is not "
1896 LDBG() <<
"Not a 1D depth-wise WC conv, dynamic shapes are not supported";
1902 Value
lhs = conv.getDpsInputOperand(0)->get();
1903 ArrayRef<int64_t> lhsShape = cast<ShapedType>(
lhs.getType()).getShape();
1904 auto shapeWithoutCh = lhsShape.drop_back(1);
1905 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
1906 LDBG() <<
"Dynamically-shaped op vectorization precondition failed: only "
1907 "channel dim can be dynamic";
1915vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
1916 bool flatten1DDepthwiseConv) {
1918 return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
1921 return reductionPreconditions(op);
1926 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
1930 LDBG() <<
"Dynamically-shaped op meets vectorization pre-conditions";
1940vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
1941 ArrayRef<int64_t> inputVectorSizes) {
1943 if (!unpackOp.hasPureTensorSemantics())
1948 if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
1949 unpackOp.getSourceType().hasStaticShape())
1954 if (!inputVectorSizes.empty() &&
1955 (inputVectorSizes.size() != unpackOp.getSourceRank())) {
1956 LDBG() <<
"Incorrect number of input vector sizes";
1962 unpackOp.getSourceType().getShape(), inputVectorSizes))) {
1963 LDBG() <<
"Invalid vector sizes for the read operation";
1971vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
1972 ArrayRef<int64_t> inputVectorSizes) {
1975 auto sourceType = source.getType();
1976 if (!VectorType::isValidElementType(sourceType.getElementType()))
1992 bool isOutOfBoundsRead =
1993 !sourceType.hasStaticShape() && inputVectorSizes.empty();
1995 if (!padValue && isOutOfBoundsRead) {
1996 LDBG() <<
"Failed to get a pad value for out-of-bounds read access";
2010vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2012 SmallVectorImpl<Value> &newResults) {
2013 Location loc = linalgOp.getLoc();
2014 MLIRContext *ctx = linalgOp.getContext();
2019 if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2022 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2026 LDBG() <<
"Failed to determine contraction combining kind.";
2033 AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2034 AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2036 LDBG() <<
"Contractions with broadcasts are not supported.";
2041 SmallVector<Value> vecOperands;
2042 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2046 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2050 VectorType readType =
2051 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
2054 rewriter, loc, opOperand.get(), readType,
2055 arith::getZeroConstant(rewriter, loc, elemType),
2057 vecOperands.push_back(read);
2061 SmallVector<Attribute> iterAttrs;
2062 auto iterators = linalgOp.getIteratorTypesArray();
2063 for (utils::IteratorType iter : iterators) {
2064 auto vecIter = iter == utils::IteratorType::parallel
2065 ? vector::IteratorType::parallel
2066 : vector::IteratorType::reduction;
2067 iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2071 Operation *contractOp = vector::ContractionOp::create(
2072 rewriter, loc, vecOperands[0],
2073 vecOperands[1], vecOperands[2],
2074 linalgOp.getIndexingMaps(), rewriter.
getArrayAttr(iterAttrs), *maybeKind);
2075 contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2079 rewriter, loc, contractOp->
getResult(0), outOperand->
get());
2083 newResults.push_back(write->
getResult(0));
// Discriminates the two kinds of reduction bodies recognized when vectorizing
// convolution-like linalg ops: a true convolution (multiply-accumulate over a
// kernel window) vs. a pooling op (reduction over the window with no
// multiplication). Determined from the reduction body by
// getConvOperationKind() and used to select the vectorization strategy.
2089enum class ConvOperationKind { Conv, Pool };
2092static bool isCastOfBlockArgument(Operation *op) {
2107static std::optional<ConvOperationKind>
2108getConvOperationKind(Operation *reduceOp) {
2109 int numBlockArguments =
2110 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
2112 switch (numBlockArguments) {
2118 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
2119 llvm::IsaPred<BlockArgument>);
2121 "Expected a non-block argument operand");
2122 Operation *feedOp = (*feedValIt).getDefiningOp();
2123 if (isCastOfBlockArgument(feedOp)) {
2124 return ConvOperationKind::Pool;
2127 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2128 (isa<arith::AndIOp>(feedOp) &&
2131 if (isa<BlockArgument>(v))
2133 if (Operation *op = v.getDefiningOp())
2134 return isCastOfBlockArgument(op);
2137 return std::nullopt;
2140 return ConvOperationKind::Conv;
2144 return ConvOperationKind::Pool;
2146 return std::nullopt;
2150static bool isSupportedPoolKind(vector::CombiningKind kind) {
2152 case vector::CombiningKind::ADD:
2153 case vector::CombiningKind::MAXNUMF:
2154 case vector::CombiningKind::MAXIMUMF:
2155 case vector::CombiningKind::MAXSI:
2156 case vector::CombiningKind::MAXUI:
2157 case vector::CombiningKind::MINNUMF:
2158 case vector::CombiningKind::MINIMUMF:
2159 case vector::CombiningKind::MINSI:
2160 case vector::CombiningKind::MINUI:
2167static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2168 auto getOperandType = [&](
auto operand) {
2169 return dyn_cast<ShapedType>((operand->get()).getType());
2171 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2172 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2173 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2177 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2178 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2185 auto maybeOper = getConvOperationKind(reduceOp);
2186 if (!maybeOper.has_value())
2193 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2194 *maybeKind != vector::CombiningKind::OR) &&
2195 (*maybeOper != ConvOperationKind::Pool ||
2196 !isSupportedPoolKind(*maybeKind)))) {
2200 auto rhsRank = rhsShapedType.getRank();
2201 if (*maybeOper == ConvOperationKind::Pool) {
2205 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
// Precondition check for vectorizing a generic LinalgOp:
// - reject ops with zero-sized operand dimensions;
// - dynamically shaped ops must additionally pass
//   vectorizeDynamicLinalgOpPrecondition;
// - every inner op must satisfy registered custom preconditions and have
//   operand/result types that are valid vector element types;
// - conv-like ops are delegated to vectorizeConvOpPrecondition;
// - indexing maps must be projected permutations and reduction
//   preconditions must hold (per the LDBG diagnostics below).
// NOTE(review): success/failure returns between checks are elided by
// extraction here.
2212static LogicalResult vectorizeLinalgOpPrecondition(
2213 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2214 bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
2216 if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2217 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2221 if (!inputVectorSizes.empty() &&
2226 if (linalgOp.hasDynamicShape() &&
failed(vectorizeDynamicLinalgOpPrecondition(
2227 linalgOp, flatten1DDepthwiseConv))) {
2228 LDBG() <<
"Dynamically-shaped op failed vectorization pre-conditions";
2232 SmallVector<CustomVectorizationPrecondition> customPreconditions;
2238 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2241 customPreconditions,
2244 customPrecondition(&innerOp, vectorizeNDExtract));
2248 if (!llvm::all_of(innerOp.getOperandTypes(),
2249 VectorType::isValidElementType)) {
2252 if (!llvm::all_of(innerOp.getResultTypes(),
2253 VectorType::isValidElementType)) {
2262 return vectorizeConvOpPrecondition(linalgOp);
2268 LDBG() <<
"precondition failed: not projected permutations";
2271 if (
failed(reductionPreconditions(linalgOp))) {
2272 LDBG() <<
"precondition failed: reduction preconditions";
// Precondition check for vectorizing linalg.pack:
// - pure tensor semantics only;
// - any padding value must be a constant;
// - when no vector sizes are given, both source and dest must be statically
//   shaped; otherwise the given sizes must cover the leading dest dims;
// - all inner_tiles must be constant.
// NOTE(review): the corresponding failure returns are elided by extraction.
2279vectorizePackOpPrecondition(linalg::PackOp packOp,
2280 ArrayRef<int64_t> inputVectorSizes) {
2282 if (!packOp.hasPureTensorSemantics())
2285 auto padValue = packOp.getPaddingValue();
2289 LDBG() <<
"pad value is not constant: " << packOp;
2293 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2294 bool satisfyEmptyCond =
true;
2295 if (inputVectorSizes.empty()) {
2296 if (!packOp.getDestType().hasStaticShape() ||
2297 !packOp.getSourceType().hasStaticShape())
2298 satisfyEmptyCond =
false;
2301 if (!satisfyEmptyCond &&
2303 resultTensorShape.take_front(packOp.getSourceRank()),
2307 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2308 return !getConstantIntValue(v).has_value();
2310 LDBG() <<
"inner_tiles must be constant: " << packOp;
// Precondition check for vectorizing tensor.pad:
// - the padding value must be a constant;
// - every low-pad amount must be a constant zero, except on result dims of
//   size 1 (where any pad is tolerated per the predicate below).
// NOTE(review): intermediate checks/returns are elided by extraction.
2318vectorizePadOpPrecondition(tensor::PadOp padOp,
2319 ArrayRef<int64_t> inputVectorSizes) {
2320 auto padValue = padOp.getConstantPaddingValue();
2322 LDBG() <<
"pad value is not constant: " << padOp;
2326 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2342 if (llvm::any_of(llvm::enumerate(padOp.getMixedLowPad()),
2343 [&](
const auto &en) {
2344 OpFoldResult padValue = en.value();
2345 unsigned pos = en.index();
2346 std::optional<int64_t> pad = getConstantIntValue(padValue);
// Non-zero (or non-constant) low pad is only acceptable on unit dims.
2347 return (!pad.has_value() || pad.value() != 0) &&
2348 resultTensorShape[pos] != 1;
2350 LDBG() <<
"low pad must all be zero for all non unit dims: " << padOp;
// Precondition check for *scalable* vectorization requests:
// - no scalable dims => trivially fine; non-LinalgOp => only linalg.unpack
//   is allowed; at most 2 scalable dims;
// - walks the trailing non-scalable dims (recording whether a non-unit
//   parallel dim was seen), then switches on the innermost scalable dim:
//   a reduction dim must be trailing and not belong to a Matmul-like op,
//   a parallel dim must be the innermost non-unit parallel dim;
// - with 2 scalable dims, the second-innermost must also be scalable and
//   parallel;
// - finally restricted to a whitelist of ops (BatchMatmul, Matvec, Mmt4D,
//   ... — the list is partially elided here).
// NOTE(review): several returns and parts of the op whitelist are elided by
// extraction in this excerpt.
2364vectorizeScalableVectorPrecondition(Operation *op,
2365 ArrayRef<int64_t> inputVectorSizes,
2366 ArrayRef<bool> inputScalableVecDims) {
2367 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2368 "Number of input vector sizes and scalable dims doesn't match");
2370 size_t numOfScalableDims =
2371 llvm::count_if(inputScalableVecDims, [](
bool flag) {
return flag; });
2373 if (numOfScalableDims == 0)
2376 auto linalgOp = dyn_cast<LinalgOp>(op);
2381 return success(isa<linalg::UnPackOp>(op));
2385 if (numOfScalableDims > 2)
2405 bool seenNonUnitParallel =
false;
2406 auto iterators = linalgOp.getIteratorTypesArray();
2407 SmallVector<bool> scalableFlags(inputScalableVecDims);
2408 int64_t idx = scalableFlags.size() - 1;
// Pop trailing non-scalable dims until the innermost scalable dim is last.
2409 while (!scalableFlags[idx]) {
2410 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2411 seenNonUnitParallel |=
2412 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2414 iterators.pop_back();
2415 scalableFlags.pop_back();
2420 switch (iterators.back()) {
2421 case utils::IteratorType::reduction: {
// Scalable reduction dims are only supported when trailing.
2423 if (iterators.size() != inputVectorSizes.size()) {
2424 LDBG() <<
"Non-trailing reduction dim requested for scalable "
2428 if (isa<linalg::MatmulOp>(op)) {
2430 <<
"Scalable vectorization of the reduction dim in Matmul-like ops "
2436 case utils::IteratorType::parallel: {
2438 if (seenNonUnitParallel) {
2439 LDBG() <<
"Inner parallel dim not requested for scalable "
2451 if (numOfScalableDims == 2) {
2455 if (iterators.back() == utils::IteratorType::reduction) {
2456 LDBG() <<
"Higher dim than the trailing reduction dim requested for "
2461 scalableFlags.pop_back();
2462 iterators.pop_back();
2464 if (!scalableFlags.back() ||
2465 (iterators.back() != utils::IteratorType::parallel))
2473 isa<linalg::BatchMatmulOp>(op) ||
2475 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
// Top-level precondition dispatcher (the function name line is elided by
// extraction; presumably `vectorizeOpPrecondition` — confirm against the
// full file). First runs the scalable-vector precondition, then
// TypeSwitches on the op to the per-op precondition helpers; anything else
// fails.
2480 Operation *op, ArrayRef<int64_t> inputVectorSizes,
2481 ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract,
2482 bool flatten1DDepthwiseConv) {
2487 if (
failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2488 inputScalableVecDims)))
2492 .Case([&](linalg::LinalgOp linalgOp) {
2493 return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2495 flatten1DDepthwiseConv);
2497 .Case([&](tensor::PadOp padOp) {
2498 return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2500 .Case([&](linalg::PackOp packOp) {
2501 return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2503 .Case([&](linalg::UnPackOp unpackOp) {
2504 return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2506 .Case([&](tensor::InsertSliceOp sliceOp) {
2507 return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2509 .Default(failure());
// Rewrites every affine.apply inside the linalg op's body via
// affine::expandAffineExpr, splitting operands into dims
// (take_front(numDims)) and symbols (take_back(numSymbols)).
// Uses make_early_inc_range because ops are replaced while iterating.
2513static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2514 OpBuilder::InsertionGuard g(rewriter);
2515 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2517 for (
auto op : make_early_inc_range(toReplace)) {
2519 auto expanded = affine::expandAffineExpr(
2521 op.
getOperands().take_front(op.getAffineMap().getNumDims()),
2522 op.
getOperands().take_back(op.getAffineMap().getNumSymbols()));
// NOTE(review): the lines below belong to a separate small predicate whose
// signature was elided by extraction; it reports whether `op` is one of the
// op kinds this vectorizer supports.
2528 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2529 tensor::InsertSliceOp>(op);
// Main vectorization entry point (its name/return-type line is elided by
// extraction). Flow:
//   1. check preconditions (vectorizeNDExtract/flatten1DDepthwiseConv);
//   2. for LinalgOps, initialize VectorizationState from the requested
//      vector sizes / scalable dims;
//   3. TypeSwitch to the per-op vectorizer (conv special-casing, optional
//      named-contraction path, generic broadcast path after
//      convertAffineApply; pad/pack/unpack/insert_slice helpers);
//   4. collect results into `results` and return VectorizationResult.
2533 RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2534 ArrayRef<bool> inputScalableVecDims,
bool vectorizeNDExtract,
2535 bool flatten1DDepthwiseConv,
bool assumeDynamicDimsMatchVecSizes,
2536 bool createNamedContraction) {
2537 LDBG() <<
"Attempting to vectorize: " << *op;
2538 LDBG() <<
"Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2539 LDBG() <<
"Input scalable vector dims: "
2540 << llvm::interleaved(inputScalableVecDims);
2544 flatten1DDepthwiseConv))) {
2545 LDBG() <<
"Vectorization pre-conditions failed";
2550 VectorizationState state(rewriter);
2551 if (
auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2552 if (
failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2553 inputScalableVecDims,
2554 assumeDynamicDimsMatchVecSizes))) {
2555 LDBG() <<
"Vectorization state couldn't be initialized";
2560 SmallVector<Value> results;
2561 auto vectorizeResult =
2563 .Case([&](linalg::LinalgOp linalgOp) {
// Convolutions are tried first via the dedicated conv vectorizer.
2567 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2568 flatten1DDepthwiseConv);
2569 if (succeeded(convOr)) {
2570 llvm::append_range(results, (*convOr)->getResults());
2574 LDBG() <<
"Unsupported convolution can't be vectorized.";
2578 if (createNamedContraction &&
2579 isa<ContractionOpInterface>(linalgOp.getOperation()))
2580 return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2584 <<
"Vectorize generic by broadcasting to the canonical vector "
2588 convertAffineApply(rewriter, linalgOp);
2597 .Case([&](tensor::PadOp padOp) {
2598 return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2601 .Case([&](linalg::PackOp packOp) {
2602 return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2605 .Case([&](linalg::UnPackOp unpackOp) {
2606 return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2608 inputScalableVecDims, results);
2610 .Case([&](tensor::InsertSliceOp sliceOp) {
2614 .Default(failure());
2616 if (
failed(vectorizeResult)) {
2617 LDBG() <<
"Vectorization failed";
2621 return VectorizationResult{results};
// Vectorizes a memref.copy (signature head elided by extraction): requires
// statically shaped source/target with valid vector element types, then
// emits a full-shape transfer_read, broadcasts a rank-0 read up to the
// write type, and emits a transfer_write.
2625 memref::CopyOp copyOp) {
2626 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2627 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2628 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2633 if (!VectorType::isValidElementType(srcElementType) ||
2634 !VectorType::isValidElementType(dstElementType))
2637 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2638 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2640 Location loc = copyOp->getLoc();
2642 SmallVector<Value>
indices(srcType.getRank(), zero);
2644 Value
readValue = vector::TransferReadOp::create(
2645 rewriter, loc, readType, copyOp.getSource(),
indices,
// A rank-0 read is extracted to a scalar and broadcast to the write type.
2648 if (cast<VectorType>(
readValue.getType()).getRank() == 0) {
2649 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2650 ArrayRef<int64_t>());
2652 vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2654 Operation *writeValue = vector::TransferWriteOp::create(
2655 rewriter, loc, readValue, copyOp.getTarget(),
indices,
// CRTP-style base pattern: matches a tensor.pad and rewrites each of its
// users of type OpTy via the pure-virtual rewriteUser hook. Succeeds if at
// least one user was rewritten (users are snapshotted with to_vector since
// rewriting mutates the use list).
2666template <
typename OpTy>
2667struct VectorizePadOpUserPattern :
public OpRewritePattern<tensor::PadOp> {
2668 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2670 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2671 PatternRewriter &rewriter)
const final {
2672 bool changed =
false;
2674 for (
auto *user : llvm::to_vector<4>(padOp->getUsers()))
2675 if (
auto op = dyn_cast<OpTy>(user))
2676 changed |= rewriteUser(rewriter, padOp, op).succeeded();
// Hook implemented by each concrete pad-user pattern below.
2681 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2682 tensor::PadOp padOp, OpTy op)
const = 0;
// Folds tensor.pad into a consuming vector.transfer_read: requires zero low
// pad, a constant pad value, and an unmasked in-bounds-free read; then
// redirects the read to the pad's source, marks all dims out-of-bounds, and
// uses the pad constant as the transfer padding value.
2704struct PadOpVectorizationWithTransferReadPattern
2705 :
public VectorizePadOpUserPattern<vector::TransferReadOp> {
2706 using VectorizePadOpUserPattern<
2707 vector::TransferReadOp>::VectorizePadOpUserPattern;
2709 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2710 vector::TransferReadOp xferOp)
const override {
2712 if (!padOp.hasZeroLowPad())
2715 auto padValue = padOp.getConstantPaddingValue();
2719 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
// All dims become potentially out-of-bounds once reading from the
// (smaller) un-padded source.
2723 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(),
false);
2724 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2726 xferOp.getBaseMutable().assign(padOp.getSource());
2727 xferOp.getPaddingMutable().assign(padValue);
// Folds tensor.pad into a producing vector.transfer_write whose only use is
// a tensor.extract_slice that trims the padding back off: requires zero low
// pad, constant pad value, zero-offset trim, and that the trimmed slice has
// the same tensor size as the pad's source (hasSameTensorSize). The write
// is then retargeted at the un-padded source and the trim is replaced by
// the new write's result.
2766struct PadOpVectorizationWithTransferWritePattern
2767 :
public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2768 using VectorizePadOpUserPattern<
2769 vector::TransferWriteOp>::VectorizePadOpUserPattern;
2771 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2772 vector::TransferWriteOp xferOp)
const override {
2774 if (xferOp.getTransferRank() == 0)
2778 if (!padOp.hasZeroLowPad())
2781 auto padValue = padOp.getConstantPaddingValue()
2785 if (!xferOp->hasOneUse())
2787 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2791 if (!trimPadding.hasZeroOffset())
2794 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2800 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(),
false);
2802 xferOp, padOp.getSource().
getType(), xferOp.getVector(),
2803 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2805 rewriter.
replaceOp(trimPadding, newXferOp->getResult(0));
// Conservatively checks that `beforePadding` and the trimmed slice denote
// tensors of identical size: looks through tensor.cast, compares ranks and
// static dims, and for dynamic dims compares the mixed sizes of matching
// extract_slice producers (including equal affine.min expressions).
// Returns false whenever equality cannot be proven.
2820 bool hasSameTensorSize(Value beforePadding,
2821 tensor::ExtractSliceOp afterTrimming)
const {
2824 if (
auto castOp = beforePadding.
getDefiningOp<tensor::CastOp>())
2825 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2828 auto t1 = dyn_cast<RankedTensorType>(beforePadding.
getType());
2829 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2834 if (t1.getRank() != t2.getRank())
2839 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2840 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2842 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
// All dims static and equal => provably the same size.
2847 if (t1.getNumDynamicDims() == 0)
2855 auto beforeSlice = beforePadding.
getDefiningOp<tensor::ExtractSliceOp>();
2859 assert(
static_cast<size_t>(t1.getRank()) ==
2860 beforeSlice.getMixedSizes().size());
2861 assert(
static_cast<size_t>(t2.getRank()) ==
2862 afterTrimming.getMixedSizes().size());
2864 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2866 if (!t1.isDynamicDim(i))
2868 auto size1 = beforeSlice.getMixedSizes()[i];
2869 auto size2 = afterTrimming.getMixedSizes()[i];
2876 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2877 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
// Same affine.min map with the same operands => same dynamic size.
2883 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2884 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2885 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2886 minOp2.getOperands() == minOp2.getOperands())
// NOTE(review): L1094-L1106 are fragments of a helper (signature elided by
// extraction) that walks producers — vector.broadcast, linalg.fill,
// tensor.generate, vector.transfer_write, tensor.insert_slice — presumably
// to recover a static pad/fill value; confirm against the full file.
2912 if (
auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2913 auto source = bcast.getSource();
2914 if (llvm::dyn_cast<VectorType>(source.getType()))
2922 if (
auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2923 return fill.getInputs()[0];
2928 if (
auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2935 if (
auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2943 if (
auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
// Vectorizes tensor.insert_slice (entry lines partially elided): builds a
// vector shape per source dim — explicit input size if given, else the
// static source dim, else the static result dim at rankDiff + i — then
// emits a masked/padded read of the source and a write into the dest.
2951 ArrayRef<int64_t> inputVectorSizes,
2952 SmallVectorImpl<Value> &newResults) {
2954 OpBuilder::InsertionGuard g(rewriter);
2958 auto sourceType = source.getType();
2959 auto resultType = sliceOp.getResultType();
2964 auto elemType = sourceType.getElementType();
2965 padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
2970 SmallVector<int64_t> vecShape;
2971 size_t rankDiff = resultType.getRank() - sourceType.getRank();
2972 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2973 if (!inputVectorSizes.empty()) {
2974 vecShape.push_back(inputVectorSizes[i]);
2975 }
else if (!sourceType.isDynamicDim(i)) {
2976 vecShape.push_back(sourceType.getDimSize(i));
2977 }
else if (!resultType.isDynamicDim(i)) {
2983 vecShape.push_back(resultType.getDimSize(rankDiff + i));
2990 auto vecType = VectorType::get(vecShape, sourceType.getElementType());
2993 auto loc = sliceOp.getLoc();
2996 SmallVector<Value> readIndices(
2999 rewriter, loc, source, vecType, padValue,
3000 inputVectorSizes.empty());
3007 writeIndices, inputVectorSizes.empty());
3010 newResults.push_back(write->
getResult(0));
// Folds tensor.pad into a consuming tensor.insert_slice: requires zero low
// pad, unit strides, constant pad value, a statically shaped pad result,
// no dest==pad self-insertion, and insert sizes that match the padded shape
// (leading dims of size 1 allowed). Replaces the pair with an out-of-bounds
// transfer_read of the pad source followed by an in-bounds transfer_write
// into the insert's destination at the insert's offsets.
3038struct PadOpVectorizationWithInsertSlicePattern
3039 :
public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3040 using VectorizePadOpUserPattern<
3041 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3043 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3044 tensor::InsertSliceOp insertOp)
const override {
3046 if (!padOp.hasZeroLowPad())
3049 if (!insertOp.hasUnitStride())
3052 auto padValue = padOp.getConstantPaddingValue();
3056 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3059 if (insertOp.getDest() == padOp.getResult())
3062 auto vecType = VectorType::get(padOp.getType().getShape(),
3063 padOp.getType().getElementType());
3064 unsigned vecRank = vecType.getRank();
3065 unsigned tensorRank = insertOp.getType().getRank();
// Expected insert sizes: leading unit dims then the padded vector shape.
3069 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3070 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3072 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](
auto it) {
3073 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3083 SmallVector<Value> readIndices(
3085 auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3086 vecType, padOp.getSource(),
3087 readIndices, padValue);
3093 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3094 SmallVector<bool> inBounds(vecRank,
true);
3096 insertOp, read, insertOp.getDest(), writeIndices,
3097 ArrayRef<bool>{inBounds});
// Registers the three pad-folding patterns above (entry line elided by
// extraction; presumably `populatePadOpVectorizationPatterns`).
3104 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3105 patterns.
add<PadOpVectorizationWithTransferReadPattern,
3106 PadOpVectorizationWithTransferWritePattern,
3107 PadOpVectorizationWithInsertSlicePattern>(
// Returns whether any use of `values` may be interleaved between firstOp
// and secondOp (other than the two ops themselves). Used by the copy
// forwarding patterns below to prove it is safe to forward across the copy.
3118static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3122 LDBG() <<
"interleavedUses precondition failed, firstOp: " << *firstOp
3123 <<
", second op: " << *secondOp;
3126 for (
auto v : values) {
3127 for (
auto &u : v.getUses()) {
3128 Operation *owner = u.getOwner();
3129 if (owner == firstOp || owner == secondOp)
3135 LDBG() <<
" found interleaved op " << *owner <<
", firstOp: " << *firstOp
3136 <<
", second op: " << *secondOp;
// Returns the unique memref.subview user of `v`, or a null SubViewOp when
// there is none or more than one (a second subview use returns the null op).
3145static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3146 memref::SubViewOp subViewOp;
3148 if (
auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3150 return memref::SubViewOp();
3151 subViewOp = newSubViewOp;
// Forwards a vector.transfer_read through a memref.copy (and optional
// linalg.fill) into a subview: an unmasked read from `viewOrAlloc` whose
// unique subview is the target of a copy, with no interleaved uses, is
// rewritten to read directly from the copy source; a matching fill must
// agree with the transfer's padding value. The read is rebuilt with all
// dims out-of-bounds and the fill is erased.
// NOTE(review): the enclosing pattern class declaration is elided by
// extraction.
 vector::TransferReadOp xferOp, PatternRewriter &rewriter)
const {
3163 if (xferOp.getMask())
3167 Value viewOrAlloc = xferOp.getBase();
3173 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3176 Value subView = subViewOp.getResult();
// Find the unique copy that writes the subview, with nothing interleaved.
3179 memref::CopyOp copyOp;
3180 for (
auto &u : subView.
getUses()) {
3181 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3182 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3183 if (newCopyOp.getTarget() != subView)
3185 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
// Optionally find a fill of the full view that precedes the copy.
3197 for (
auto &u : viewOrAlloc.
getUses()) {
3198 if (
auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3199 assert(isa<MemRefType>(newFillOp.output().getType()));
3200 if (newFillOp.output() != viewOrAlloc)
3202 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3204 maybeFillOp = newFillOp;
3209 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3211 "padding value does not match fill");
3214 Value in = copyOp.getSource();
3220 auto vectorType = xferOp.getVectorType();
3221 Value res = vector::TransferReadOp::create(
3222 rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3223 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3225 SmallVector<bool>(vectorType.getRank(),
false)));
3228 rewriter.
eraseOp(maybeFillOp);
// Forwards a vector.transfer_write through a memref.copy: an unmasked write
// into `viewOrAlloc` whose unique subview is the source of a copy, with no
// interleaved uses, is rewritten to write directly into the copy target,
// with all dims marked out-of-bounds.
// NOTE(review): the enclosing pattern class declaration is elided by
// extraction.
 vector::TransferWriteOp xferOp, PatternRewriter &rewriter)
const {
3240 if (xferOp.getMask())
3244 Value viewOrAlloc = xferOp.getBase();
3250 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3253 Value subView = subViewOp.getResult();
3256 memref::CopyOp copyOp;
3257 for (
auto &u : subViewOp.getResult().getUses()) {
3258 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3259 if (newCopyOp.getSource() != subView)
3261 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3271 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3272 Value out = copyOp.getTarget();
3279 auto vector = xferOp.getVector();
3280 vector::TransferWriteOp::create(
3281 rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3282 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3284 dyn_cast<VectorType>(vector.getType()).getRank(),
false)));
// Recursive helper set binding a pack of integer references to the leading
// dimensions of shapedType.getShape(): val_i = shape[i]. The first overload
// is the recursion base case (a `template <int N>` prefix appears to have
// been elided by extraction — confirm against the full file); the last
// overload is the public entry point starting at dimension 0.
3297static void bindShapeDims(ShapedType shapedType) {}
3299template <
int N,
typename IntTy,
typename... IntTy2>
3300static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3301 val = shapedType.getShape()[N];
3302 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3306template <
typename... IntTy>
3307static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3308 bindShapeDims<0>(shapedType, vals...);
// Tries to match `op` against each supported 1-D conv/pool op type via the
// MATCH_1D_CONV_POOL_OP macro, returning its dilations/strides on success
// and std::nullopt otherwise.
// NOTE(review): the macro invocation list (the op types tried) is elided by
// extraction.
3313static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
3314#define MATCH_1D_CONV_POOL_OP(ConvOpTy) \
3315 if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op)) \
3337#undef MATCH_1D_CONV_POOL_OP
3339 return std::nullopt;
3377struct Conv1DGenerator
3378 :
public StructuredGenerator<LinalgOp, utils::IteratorType> {
3381 static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
3382 LinalgOp linalgOp) {
3385 std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
3389 int strideW =
static_cast<int>(convParams->strides.front());
3390 int dilationW =
static_cast<int>(convParams->dilations.front());
3391 return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
3395 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp,
int strideW,
3397 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3398 strideW(strideW), dilationW(dilationW) {
3400 lhsShaped = linalgOp.getDpsInputOperand(0)->
get();
3401 rhsShaped = linalgOp.getDpsInputOperand(1)->
get();
3402 resShaped = linalgOp.getDpsInitOperand(0)->
get();
3403 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3404 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3405 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3410 setConvOperationKind(reduceOp);
3413 reductionKind = maybeKind.value();
3436 int64_t nSize, wSize, cSize, kwSize, fSize;
3437 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3439 switch (conv1DOpOrder) {
3442 nSize = fSize = cSize = 0;
3444 bindShapeDims(resShapedType, wSize);
3446 bindShapeDims(rhsShapedType, kwSize);
3449 (wSize + kwSize - 1)};
3450 rhsShape = {kwSize};
3455 bindShapeDims(resShapedType, nSize, wSize, fSize);
3457 case ConvOperationKind::Conv:
3459 bindShapeDims(rhsShapedType, kwSize, cSize);
3461 case ConvOperationKind::Pool:
3463 bindShapeDims(rhsShapedType, kwSize);
3471 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3475 case ConvOperationKind::Conv:
3476 rhsShape = {kwSize, cSize, fSize};
3478 case ConvOperationKind::Pool:
3479 rhsShape = {kwSize};
3482 resShape = {nSize, wSize, fSize};
3486 bindShapeDims(resShapedType, nSize, fSize, wSize);
3488 case ConvOperationKind::Conv:
3490 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3492 case ConvOperationKind::Pool:
3494 bindShapeDims(rhsShapedType, kwSize);
3498 lhsShape = {nSize, cSize,
3502 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3505 case ConvOperationKind::Conv:
3506 rhsShape = {fSize, cSize, kwSize};
3508 case ConvOperationKind::Pool:
3509 rhsShape = {kwSize};
3512 resShape = {nSize, fSize, wSize};
3516 vector::TransferWriteOp write;
3522 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3524 Type lhsEltType = lhsShapedType.getElementType();
3525 Type rhsEltType = rhsShapedType.getElementType();
3526 Type resEltType = resShapedType.getElementType();
3527 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3528 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3529 auto resType = VectorType::get(resShape, resEltType);
3531 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3532 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3533 SmallVector<Value> resPadding(resShape.size(), zero);
3536 Value
lhs = vector::TransferReadOp::create(
3537 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3538 arith::getZeroConstant(rewriter, loc, lhsEltType));
3540 Value
rhs =
nullptr;
3541 if (oper == ConvOperationKind::Conv)
3542 rhs = vector::TransferReadOp::create(
3543 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3544 arith::getZeroConstant(rewriter, loc, rhsEltType));
3545 Value res = vector::TransferReadOp::create(
3546 rewriter, loc, resType, resShaped, resPadding,
3547 arith::getZeroConstant(rewriter, loc, resEltType));
3552 switch (conv1DOpOrder) {
3560 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3561 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, permLhs);
3563 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3566 if (oper == ConvOperationKind::Conv)
3567 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, permRhs);
3569 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3570 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3579 SmallVector<Value> lhsVals, rhsVals, resVals;
3581 kwSize, strideW, dilationW, wSizeStep,
3584 if (oper == ConvOperationKind::Conv)
3587 wSizeStep, isSingleChanneled);
3589 auto linearIndex = [&](int64_t kw, int64_t w) {
3590 return kw * (wSize / wSizeStep) + w;
3596 for (int64_t kw = 0; kw < kwSize; ++kw) {
3597 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3599 case ConvOperationKind::Conv:
3600 if (isSingleChanneled) {
3601 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3602 lhsVals[linearIndex(kw, w)],
3603 rhsVals[kw], resVals[w]);
3605 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3606 lhsVals[linearIndex(kw, w)],
3607 rhsVals[kw], resVals[w]);
3610 case ConvOperationKind::Pool:
3611 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3627 switch (conv1DOpOrder) {
3634 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3635 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3640 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3646 Value
promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3649 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3650 if (srcElementType == dstElementType)
3657 if (
auto shapedType = dyn_cast<ShapedType>(val.
getType()))
3658 dstType = shapedType.cloneWith(std::nullopt, dstElementType);
3660 dstType = dstElementType;
3662 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3663 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3666 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3667 srcWidth < dstWidth)
3668 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3670 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3671 srcWidth < dstWidth)
3672 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3674 assert(
false &&
"unhandled promotion case");
3679 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3680 Value
lhs, Value
rhs, Value res) {
3681 vector::IteratorType par = vector::IteratorType::parallel;
3682 vector::IteratorType red = vector::IteratorType::reduction;
3683 AffineExpr n, w, f, c;
3687 auto contrationOp = vector::ContractionOp::create(
3688 rewriter, loc,
lhs,
rhs, res,
3689 MapList{{n, w, c}, {c, f}, {n, w, f}},
3690 ArrayRef<vector::IteratorType>{par, par, par, red});
3691 contrationOp.setKind(reductionKind);
3692 return contrationOp;
3697 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3698 Value
lhs, Value
rhs, Value res) {
3701 return vector::OuterProductOp::create(rewriter, loc, res.
getType(),
lhs,
3702 rhs, res, vector::CombiningKind::ADD);
3706 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value
lhs,
3724 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3725 bool channelDimScalableFlag,
3727 bool scalableChDim =
false;
3728 bool useMasking =
false;
3729 int64_t nSize, wSize, cSize, kwSize;
3731 bindShapeDims(rhsShapedType, kwSize, cSize);
3732 if (ShapedType::isDynamic(cSize)) {
3733 assert(channelDimVecSize != 0 &&
"Channel dim vec size must be > 0");
3734 cSize = channelDimVecSize;
3738 scalableChDim = channelDimScalableFlag;
3742 assert(!(useMasking && flatten) &&
3743 "Unsupported flattened conv with dynamic shapes");
3746 bindShapeDims(resShapedType, nSize, wSize);
3748 vector::TransferWriteOp write;
3754 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3756 Type lhsEltType = lhsShapedType.getElementType();
3757 Type rhsEltType = rhsShapedType.getElementType();
3758 Type resEltType = resShapedType.getElementType();
3759 VectorType lhsType = VectorType::get(
3763 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3765 lhsEltType, {
false,
false, scalableChDim});
3766 VectorType rhsType =
3767 VectorType::get({kwSize, cSize}, rhsEltType,
3768 {
false, scalableChDim});
3769 VectorType resType =
3770 VectorType::get({nSize, wSize, cSize}, resEltType,
3771 {
false,
false, scalableChDim});
3775 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3776 ArrayRef<bool> scalableDims,
3777 Operation *opToMask) {
3781 VectorType::get(maskShape, rewriter.
getI1Type(), scalableDims);
3783 SmallVector<bool> inBounds(maskShape.size(),
true);
3784 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3785 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3789 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3792 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3799 Value
lhs = vector::TransferReadOp::create(
3800 rewriter, loc, lhsType, lhsShaped,
ValueRange{zero, zero, zero},
3801 arith::getZeroConstant(rewriter, loc, lhsEltType));
3802 auto *maybeMaskedLhs = maybeMaskXferOp(
3803 lhsType.getShape(), lhsType.getScalableDims(),
lhs.getDefiningOp());
3806 Value
rhs = vector::TransferReadOp::create(
3807 rewriter, loc, rhsType, rhsShaped,
ValueRange{zero, zero},
3808 arith::getZeroConstant(rewriter, loc, rhsEltType));
3809 auto *maybeMaskedRhs = maybeMaskXferOp(
3810 rhsType.getShape(), rhsType.getScalableDims(),
rhs.getDefiningOp());
3813 Value res = vector::TransferReadOp::create(
3814 rewriter, loc, resType, resShaped,
ValueRange{zero, zero, zero},
3815 arith::getZeroConstant(rewriter, loc, resEltType));
3816 auto *maybeMaskedRes = maybeMaskXferOp(
3817 resType.getShape(), resType.getScalableDims(), res.
getDefiningOp());
3823 SmallVector<Value> lhsVals, rhsVals, resVals;
3824 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3825 SmallVector<int64_t> inOutStrides = {1, 1, 1};
3829 for (int64_t kw = 0; kw < kwSize; ++kw) {
3830 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3831 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3832 rewriter, loc, maybeMaskedLhs->getResult(0),
3833 ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3834 inOutSliceSizes, inOutStrides));
3838 for (int64_t kw = 0; kw < kwSize; ++kw) {
3840 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3841 ArrayRef<int64_t>{kw}));
3844 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3845 resVals.push_back(vector::ExtractStridedSliceOp::create(
3846 rewriter, loc, maybeMaskedRes->getResult(0),
3847 ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3851 auto linearIndex = [&](int64_t kw, int64_t w) {
3852 return kw * (wSize / wSizeStep) + w;
3857 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3858 auto lhsTypeAfterFlattening =
3859 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3860 auto resTypeAfterFlattening =
3861 VectorType::get(inOutFlattenSliceSizes, resEltType);
3864 for (int64_t kw = 0; kw < kwSize; ++kw) {
3865 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3866 Value lhsVal = lhsVals[linearIndex(kw, w)];
3867 Value resVal = resVals[w];
3872 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3873 lhsVals[linearIndex(kw, w)]);
3874 resVal = vector::ShapeCastOp::create(
3875 rewriter, loc, resTypeAfterFlattening, resVals[w]);
3877 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3878 rhsVals[kw], resVal, flatten);
3881 resVals[w] = vector::ShapeCastOp::create(
3882 rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
3889 if (!llvm::all_of(resVals, [](Value v) {
return v; })) {
3891 for (
auto &collection :
3892 {resVals, rhsVals, lhsVals, {res,
rhs,
lhs, zero}})
3893 for (Value v : collection)
3900 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3901 maybeMaskedRes = vector::InsertStridedSliceOp::create(
3902 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
3903 ArrayRef<int64_t>{0, w, 0},
3904 ArrayRef<int64_t>{1, 1, 1});
3911 Operation *resOut = vector::TransferWriteOp::create(
3912 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
3914 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3922 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3923 Value
lhs, Value
rhs, Value res,
3925 auto rhsTy = cast<ShapedType>(
rhs.getType());
3926 auto resTy = cast<ShapedType>(res.
getType());
3940 auto rhsSize = cast<VectorType>(
rhs.getType()).getShape()[0];
3941 auto resSize = cast<VectorType>(res.
getType()).getShape()[1];
3943 SmallVector<int64_t, 16>
indices;
3944 for (
int i = 0; i < resSize / rhsSize; ++i) {
3945 for (
int j = 0; j < rhsSize; ++j)
3952 rhs = vector::BroadcastOp::create(rewriter, loc,
3953 resTy.clone(rhsTy.getElementType()),
rhs);
3960 if (isa<FloatType>(resTy.getElementType()))
3961 return vector::FMAOp::create(rewriter, loc,
lhs,
rhs, res);
3963 auto mul = arith::MulIOp::create(rewriter, loc,
lhs,
rhs);
3964 return arith::AddIOp::create(rewriter, loc,
mul, res);
3969 FailureOr<Operation *> generateNonChanneledConv() {
3972 if (!iters({Par(), Red()}))
3974 "failed to match conv::W 1-par 1-red");
3977 if (layout({ {w + kw},
3987 FailureOr<Operation *> generateNwcConv() {
3988 AffineExpr n, w, f, kw, c;
3990 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3992 op,
"failed to match conv::Nwc 3-par 2-red");
3995 if (layout({ {n, strideW * w + dilationW * kw, c},
4005 FailureOr<Operation *> generateNcwConv() {
4006 AffineExpr n, w, f, kw, c;
4008 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4010 op,
"failed to match conv::Ncw 3-par 2-red");
4012 if (layout({ {n, c, strideW * w + dilationW * kw},
4022 FailureOr<Operation *> generateNwcPooling() {
4023 AffineExpr n, w, c, kw;
4025 if (!iters({Par(), Par(), Par(), Red()}))
4027 "failed to match pooling 3-par 1-red");
4030 if (layout({ {n, strideW * w + dilationW * kw, c},
4040 FailureOr<Operation *> generateNcwPooling() {
4041 AffineExpr n, w, c, kw;
4043 if (!iters({Par(), Par(), Par(), Red()}))
4045 "failed to match pooling 3-par 1-red");
4047 if (layout({ {n, c, strideW * w + dilationW * kw},
4057 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4058 bool vecChDimScalableFlag =
false,
4059 bool flatten =
false) {
4060 AffineExpr n, w, c, kw;
4062 if (!iters({Par(), Par(), Par(), Red()}))
4064 op,
"failed to match depthwise::Nwc conv 3-par 1-red");
4067 if (layout({ {n, strideW * w + dilationW * kw, c},
4070 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4076 ConvOperationKind oper = ConvOperationKind::Conv;
4078 StringAttr poolExtOp;
4079 bool isPoolExt =
false;
4080 int strideW, dilationW;
4081 Value lhsShaped, rhsShaped, resShaped;
4082 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4083 vector::CombiningKind reductionKind;
4086 void setConvOperationKind(Operation *reduceOp) {
4087 int numBlockArguments =
4088 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
4089 if (numBlockArguments == 1) {
4094 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
4095 llvm::IsaPred<BlockArgument>);
4096 Operation *feedOp = (*feedValIt).getDefiningOp();
4097 if (isCastOfBlockArgument(feedOp)) {
4098 oper = ConvOperationKind::Pool;
4103 oper = ConvOperationKind::Conv;
4107 oper = ConvOperationKind::Pool;
4116 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4117 ArrayRef<bool> inputScalableVecDims,
bool flatten1DDepthwiseConv) {
4118 FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
4121 auto res = conv1dGen->generateNonChanneledConv();
4124 res = conv1dGen->generateNwcConv();
4127 res = conv1dGen->generateNcwConv();
4130 res = conv1dGen->generateNwcPooling();
4133 res = conv1dGen->generateNcwPooling();
4140 uint64_t vecChDimSize = ShapedType::kDynamic;
4141 bool vecChDimScalableFlag =
false;
4142 if (!inputVecSizes.empty()) {
4147 "Not a 1D depthwise conv!");
4148 size_t chDimIdx = 0;
4154 vecChDimSize = inputVecSizes[chDimIdx];
4155 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4157 return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4158 flatten1DDepthwiseConv);
4161struct VectorizeConvolution :
public OpInterfaceRewritePattern<LinalgOp> {
4164 LogicalResult matchAndRewrite(LinalgOp op,
4165 PatternRewriter &rewriter)
const override {
4167 if (
failed(resultOrFail))
4169 Operation *newOp = *resultOrFail;
4171 rewriter.
eraseOp(op.getOperation());
4174 assert(newOp->
getNumResults() == 1 &&
"expected single result");
4181 RewritePatternSet &patterns, PatternBenefit benefit) {
4182 patterns.
add<VectorizeConvolution>(patterns.
getContext(), benefit);
static std::optional< VectorShape > vectorShape(Type type)
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static VectorizationHookResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
static LogicalResult vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::InsertSliceOp with:
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
#define MATCH_1D_CONV_POOL_OP(ConvOpTy)
static VectorizationHookResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
static VectorType getCollapsedVecType(VectorType type, ArrayRef< AffineMap > reassociation)
Given the re-associations, "collapses" the input Vector type.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
A dimensional identifier appearing in an affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
OpListType & getOperations()
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
operand_iterator operand_end()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
bool isaConvolutionOpOfType(LinalgOp op)
Returns true if the linalg op is a convolution op of type ConvOpTy.
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, Value dest, SmallVector< Value > writeIndices={}, bool useInBoundsInsteadOfMasking=false)
Create a TransferWriteOp of vecToStore into dest.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, const VectorType &vecToReadTy, std::optional< Value > padValue=std::nullopt, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
llvm::SetVector< T, Vector, Set, N > SetVector
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims, bool assumeDynamicDimsMatchVecSizes=false)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
Operation * maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, std::optional< AffineMap > maybeIndexingMap=std::nullopt)
Masks an operation with the canonical vector mask if the operation needs masking.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorizationState(RewriterBase &rewriter)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override