#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#include <type_traits>
#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
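// Illustrative usage of the macro above (an added example, not an original
// comment): with `-debug-only=linalg-vectorization` active,
//   LDBG("Masking map: " << maskingMap);
// prints a line of the form "[linalg-vectorization] Masking map: ..." to
// llvm::dbgs().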
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    // ...
  });
  return res;
}
/// Helper function to extract the input slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slices of size {wSizeStep} @ [w + kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            /*...*/));
      }
    }
  } else {
    // Extract input slices of size {n, wSizeStep, c} @
    // [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            /*...*/));
      }
    }
  }
  return result;
}

/// Helper function to extract the filter slices after the filter is unrolled
/// along kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(
        /*...*/));
  }
  return result;
}
/// Helper function to extract the result slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slices of size {wSizeStep} @ [w].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          /*...*/));
    }
  } else {
    // Extract res slices of size {n, wSizeStep, f} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          /*...*/));
    }
  }
  return result;
}
/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize,
                                    int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          /*...*/);
    }
  } else {
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          /*...*/);
    }
  }
  return res;
}
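// Illustrative IR (a sketch with assumed shapes, not from the original file):
// unrolling along kw turns one wide read into per-w slices that are recombined
// after the per-slice computation, e.g.
//   %s = vector.extract_strided_slice %lhs
//          {offsets = [0, 0, 0], sizes = [1, 4, 8], strides = [1, 1, 1]}
//          : vector<1x8x8xf32> to vector<1x4x8xf32>
//   ...
//   %r = vector.insert_strided_slice %val, %acc
//          {offsets = [0, 4, 0], strides = [1, 1, 1]}
//          : vector<1x4x8xf32> into vector<1x8x8xf32>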
  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  /// Returns the canonical vector type for `elementType`, optionally permuted
  /// by `dimPermutation`.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

  /// Initializes the iteration space static sizes.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for all the
  /// static and dynamic sizes of the iteration space.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Creates or retrieves an existing mask value to mask `opToMask` in the
  /// canonical vector iteration space.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Checks whether `maskingMap` is a valid masking map, i.e., a projected
  /// permutation.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.isProjectedPermutation();
  }
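// Illustrative example (assumed maps, not an original comment): for an
// indexing map (d0, d1, d2) -> (d0, d2), only iteration dims d0 and d2
// participate in the access, so the mask for the corresponding transfer op is
// built over the (d0, d2) projection of the canonical iteration space.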
LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    // This path should be taken to vectorize code with dynamic shapes, with
    // scalable vector sizes, or when vector sizes greater than the iteration
    // space are required.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there
    // are dynamic shapes, the operation won't be vectorized.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for all the
  // static and dynamic sizes of the iteration space, needed to compute a mask
  // during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute the permuted projection of the iteration space to be masked and
  // the corresponding mask shape. If the resulting iteration space dimensions
  // are static and identical to the mask shape, masking is not needed for this
  // operation.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}
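// Illustrative IR for the masking path above (a sketch with assumed values):
// for a dynamic iteration space of size %d0 x 16 and canonical shape 8x16, the
// cached mask and its use look like
//   %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
//   %r = vector.mask %mask { vector.transfer_read ... } : ...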
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}
507 "expected projected permutation");
509 assert(res.getNumDims() ==
510 (res.getNumResults() - res.getNumOfZeroResults()) &&
511 "expected reindexed map with same number of dims and results");
/// Return the combining kind of `combinerOp`, if known.
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(
             combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, or nullptr
/// otherwise.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
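// Illustrative example (assumed IR): in
//   linalg.generic {iterator_types = ["parallel", "reduction"]} ... {
//   ^bb0(%in: f32, %acc: f32):
//     %0 = arith.addf %in, %acc : f32
//     linalg.yield %0 : f32
//   }
// matchLinalgReduction returns the single arith.addf combiner, which
// getCombinerOpKind then maps to vector::CombiningKind::ADD.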
/// Broadcast `value` to a vector of `dstType` shape if possible. Return
/// `value` otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}

/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the
/// reduction initial value.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}

static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}

/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true. Masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
/// Custom hook used to vectorize an operation, given already-vectorized
/// operands.
using CustomVectorizationHook =
    std::function<VectorizationResult(Operation *, const IRMapping &)>;

/// Custom precondition hook; the bool carries `vectorizeNDExtract`.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;

/// Helper function to vectorize the terminator of a linalgOp. New result
/// vector values are appended to `newResults`.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a linalgOp. Return the
/// generated vector op if vectorization succeeded.
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute the static loop sizes of the index op.
  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
  auto dim = indexOp.getDim();
  // Compute a one-dimensional index vector for the index op dimension.
  auto indexVectorType =
      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space, since the vectorization algorithm in
  // this case can handle the broadcast.
  if (dim == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
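// Illustrative lowering (assumed 2-D iteration space 4x8, dim = 0): the index
// vector is materialized as
//   %step  = vector.step : vector<4xindex>
//   %bcast = vector.broadcast %step : vector<4xindex> to vector<8x4xindex>
//   %idx   = vector.transpose %bcast, [1, 0]
//            : vector<8x4xindex> to vector<4x8xindex>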
/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();
  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type of the operand. This is needed to statically compute
  // the vector type for the transfer read.
  if (not extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      })) {
    return failure();
  }

  return success();
}
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}
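// Illustrative example (assumed shapes): for a tensor<4x8xf32> and indices
// (%i, %j), the gather offset works out to %i * 8 + %j, computed elementwise
// on index vectors shaped like the output vector.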
/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
/// when checking whether a `tensor.extract` Op (within a `linalg.generic` Op)
/// represents a contiguous load operation.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}

/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its result remains constant across
  // iterations, i.e. the corresponding loop dim is 1.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp` that are constant are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}

/// Checks whether `val` could be used for calculating the trailing index for a
/// contiguous load operation.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);
  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
/// Infer the memory access pattern for the input ExtractOp.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
  // otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`. Look at the way each index
  // is calculated and decide whether it is suitable for a contiguous load,
  // i.e. whether it's loop invariant.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index of `extractOp`.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load: the trailing index is loop invariant.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG("Found scalar broadcast load: " << extractOp);
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous load: the trailing `extractOp` index should increment with
  // every loop iteration, i.e. it must be based on the trailing loop index.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  // TODO: Support generating contiguous loads for column vectors.
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}
/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }

  // 2. Handle:
  //  a. scalar loads + broadcast,
  //  b. contiguous loads.
  // Both cases use vector.transfer_read.
  assert(llvm::count_if(resultType.getShape(),
                        [](uint64_t dim) { return dim != 1; }) &&
         "Contiguous loads and scalar loads + broadcast only support 1-D "
         "vectors ATM!");

  // Collect indices for `vector.transfer_read`. For vector indices, only the
  // trailing index matters; extract its first element.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
  }

  // `tensor.extract_element` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        permutationMap, inBounds);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary leading
  // dims so that the ranks match. This is done by extending the map with 0s.
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}
/// Emit reduction operations if the shape of the value to reduce is different
/// from the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp,
                                 Operation *op, Value reduceValue,
                                 Value initialValue, const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already-vectorized
/// operands carried by `bvm`.
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path: broadcast all operands to the first
  // max-ranked vector shape and rebuild the op with vector operands/results.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");

    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType = VectorType::get(
          firstMaxRankedType.getShape(),
          getElementTypeOrSelf(vecOperand.getType()),
          firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  // Values defined above the region can only be broadcast for now. Make them
  // map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // Convert the indexing map for this input/output to a transfer read
    // permutation map and masking map.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // For input reads we use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // For output reads (iteration-carried dependence, e.g. reductions), the
      // shape is computed from the output indexing map.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // If masked, set in_bounds to true. Masking guarantees that the access
    // will be in bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // Not all ops support 0-d vectors, extract the scalar for now.
    if (readType.getRank() == 0)
      readValue = rewriter.create<vector::ExtractOp>(loc, readValue);

    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  // Register CustomVectorizationHooks for yieldOp, indexOp and extractOp.
  SmallVector<CustomVectorizationHook> hooks;
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp,
                                newResults);
  };
  hooks.push_back(vectorizeYield);
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);
  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // Iteratively call `vectorizeOneOp` on the region operations.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
/// Given an input, the mixed destSizes, and the vector sizes for
/// vectorization, create an empty destination tensor and create a
/// TransferWriteOp from the input to the empty tensor. If the destination
/// shape is not the same as the inputVectorSizes for the first
/// rank(inputVectorSizes) dims, then create a mask for the write. If
/// `useInBoundsInsteadOfMasking` is set, update the inBounds attribute of the
/// transfer write op instead of masking.
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
                                           Value input,
                                           SmallVector<OpFoldResult> destSizes,
                                           ArrayRef<int64_t> inputVectorSizes,
                                           bool useInBoundsInsteadOfMasking) {
  auto inputType = cast<VectorType>(input.getType());
  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                               inputType.getElementType());
  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  auto destShape = cast<ShapedType>(dest.getType()).getShape();
  SmallVector<bool> inBoundsVal(rank, true);
  if (useInBoundsInsteadOfMasking) {
    // Update the inBounds attribute.
    for (unsigned i = 0; i < rank; i++)
      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
                       !ShapedType::isDynamic(destShape[i]);
  }
  Operation *write = builder.create<vector::TransferWriteOp>(
      loc,
      /*vector=*/input,
      /*source=*/dest,
      /*indices=*/SmallVector<Value>(rank, zero),
      /*inBounds=*/inBoundsVal);
  assert(llvm::none_of(
             destShape.drop_front(inputVectorSizes.size()),
             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
         "Only dims aligned with inputVectorSizes may be dynamic");
  if (useInBoundsInsteadOfMasking)
    return write;
  bool needMaskForWrite = !llvm::equal(
      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
  if (needMaskForWrite) {
    SmallVector<int64_t> writeMaskShape;
    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
                          destShape.end());
    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
    Value maskForWrite =
        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
    write = mlir::vector::maskOperation(builder, write, maskForWrite);
  }
  return write;
}
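// Illustrative example (assumed sizes): with inputVectorSizes = [8] and
// destShape = [?, 4], the write mask shape is [8, 4]; the leading dims come
// from the vector sizes and the remaining (static) dims from the destination
// shape.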
/// Vectorize tensor::PackOp with (1) static innerTiles, (2) constant padding
/// value and (3) input vector sizes.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  auto padValue = packOp.getPaddingValue();
  if (!padValue) {
    padValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
  }
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");

  // If the input vector sizes are not provided, then the vector sizes are
  // determined by the result tensor shape; in that case update the inBounds
  // attribute instead of masking.
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }

  // Create masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  auto innerDimsPos = packOp.getInnerDimsPos();
  auto outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty())
    applyPermutationToVector(inputShape,
                             invertPermutationVector(outerDimsPerm));
  for (auto [idx, size] : enumerate(innerTiles))
    inputShape[innerDimsPos[idx]] *= size;
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking);

  // Create ShapeCastOp.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
  auto shapeCastOp =
      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

  // Create TransposeOp.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}
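// Illustrative pipeline (assumed shapes, not from the original comments):
// packing tensor<32x8xf32> into tensor<2x4x16x2xf32> with
// inner_dims_pos = [0, 1], inner_tiles = [16, 2] is emitted roughly as
//   vector.transfer_read             -> vector<32x8xf32>
//   vector.shape_cast                -> vector<2x16x4x2xf32>
//   vector.transpose [0, 2, 1, 3]    -> vector<2x4x16x2xf32>
//   vector.transfer_write            -> into the reified destination tensor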
/// Vectorize a `tensor::UnPackOp` to these 4 Ops:
///   vector::TransferReadOp  - reads a vector from the source tensor,
///   vector::TransposeOp     - transposes the read vector,
///   vector::ShapeCastOp     - reshapes the data based on the target,
///   vector::TransferWriteOp - writes the result back to the destination.
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;
  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();

  auto destSize = unpackOp.getDestRank();

  if (!inputVectorSizes.empty())
    assert(inputVectorSizes.size() == destSize &&
           "Incorrect number of input vector sizes");

  // vectorSizes is the shape of the vector used for the final write on the
  // destination tensor; when not provided, infer it from the source shape.
  SmallVector<int64_t> vectorSizes(inputVectorSizes);
  if (vectorSizes.empty()) {
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
    if (!outerDimsPerm.empty())
      applyPermutationToVector(vectorSizes, outerDimsPerm);
    for (auto [i, pos] : llvm::enumerate(innerDimPos))
      vectorSizes[pos] *= innerTiles[i];

    useInBoundsInsteadOfMasking = true;
  }

  // readVectorSizes is the size of the vector used to read (and mask) the
  // source tensor: split the tiled dims and append the trailing source dims.
  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
  for (auto [index, size] : enumerate(innerTiles)) {
    readVectorSizes[innerDimPos[index]] =
        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(readVectorSizes, outerDimsPerm);
  }
  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());

  ReifiedRankedShapedTypeDims reifiedRetShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
          .reifyResultShapes(rewriter, reifiedRetShapes);
  if (status.failed()) {
    LDBG("Unable to reify result shapes of " << unpackOp);
    return failure();
  }
  Location loc = unpackOp->getLoc();

  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

  // Read result, mask if necessary.
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
  RankedTensorType stripMineTensorType =
      RankedTensorType::get(stripMineShape, stripMineElemType);
  // Transpose the appropriate rows to match the output.
  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
      loc, readResult, lastDimToInsertPosPerm);

  // Collapse the vector to the size required by the result.
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMineTensorType, packMetadata.reassociations);
  mlir::VectorType vecCollapsedType =
      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
      loc, vecCollapsedType, transposeOp->getResult(0));

  // writeVectorSizes has to match the shape_cast shape for dynamic sizes;
  // otherwise the validator complains that the mask size is invalid.
  SmallVector<int64_t> writeVectorSizes(
      unpackOp.getDestType().hasStaticShape()
          ? vectorSizes
          : shapeCastOp.getResultVectorType().getShape());
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
      writeVectorSizes, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
  return success();
}
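// Illustrative pipeline (assumed shapes; mirrors the pack example above in
// reverse): unpacking tensor<2x4x16x2xf32> (inner_dims_pos = [0, 1],
// inner_tiles = [16, 2]) back to tensor<32x8xf32> is roughly
//   vector.transfer_read             -> vector<2x4x16x2xf32>
//   vector.transpose [0, 2, 1, 3]    -> vector<2x16x4x2xf32>
//   vector.shape_cast                -> vector<32x8xf32>
//   vector.transfer_write            -> tensor<32x8xf32>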
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero low pad.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(padOp);

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
      /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}

static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG("Vectorization of flattened convs with dynamic shapes is not "
         "supported\n");
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG("Dynamically-shaped op vectorization precondition failed: only "
         "channel dim can be dynamic\n");
    return failure();
  }

  return success();
}

static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (op.hasReductionIterator())
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement ContractionOpInterface for now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}
/// Need to check if the inner-tiles are static/constant.
static LogicalResult
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
      })) {
    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
  bool satisfyEmptyCond = inputVectorSizes.empty() &&
                          unpackOp.getDestType().hasStaticShape() &&
                          unpackOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
    return failure();

  return success();
}
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              bool vectorizeNDExtract,
                              bool flatten1DDepthwiseConv) {
  // A tensor with a dimension of 0 cannot be vectorized.
  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() &&
      failed(vectorizeDynamicLinalgOpPrecondition(
          linalgOp, flatten1DDepthwiseConv))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(customPreconditions,
                     [&](const CustomVectorizationPrecondition
                             &customPrecondition) {
                       return succeeded(
                           customPrecondition(&innerOp, vectorizeNDExtract));
                     })) {
      continue;
    }
    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // TODO: isaConvolutionOpInterface that can also infer from generic features.
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return success();
  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations.
  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG("inner_tiles must be constant: " << packOp << "\n");
    return failure();
  }

  return success();
}

static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  // Padding with non-zero low pad values is not supported.
  if (llvm::any_of(padOp.getLow(), [](Value v) {
        std::optional<int64_t> res = getConstantIntValue(v);
        return !res.has_value() || res.value() != 0;
      })) {
    LDBG("low pad must all be zero: " << padOp << "\n");
    return failure();
  }

  return success();
}
/// Preconditions for scalable vectors.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();

  // Cond 1: There's been no need for more than 2 scalable dims so far.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 2: Analyze the iterators corresponding to the trailing scalable dims.
  bool seenParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  while (!scalableFlags.back()) {
    seenParallel |= (iterators.back() == utils::IteratorType::parallel);

    iterators.pop_back();
    scalableFlags.pop_back();
  }

  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // The trailing scalable dim must be the trailing reduction dim.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG("Non-trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
      LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
           "is not supported\n");
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // The trailing scalable dim must be the innermost parallel dim.
    if (seenParallel) {
      LDBG("Inner parallel dim not requested for scalable vectorization\n");
      return failure();
    }
    break;
  }
  }

  // If present, check the 2nd scalable dim.
  if (numOfScalableDims == 2) {
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG("Higher dim than the trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Don't let matmuls with extended semantics through this transform.
  if (linalgOp.hasUserDefinedMaps())
    return failure();

  // Cond 4: Only the following ops are supported in the presence of scalable
  // vectors.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::MatmulTransposeAOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 hasReductionIterator(linalgOp));
}
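// Illustrative configuration (assumed values): for linalg.matmul with
// iterators (parallel, parallel, reduction), vector sizes [8, [16], 4] with
// scalable flags [false, true, false] satisfy the conditions above: the only
// scalable dim ([16]) is the innermost parallel dim.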
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<tensor::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<tensor::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
/// Converts affine.apply Ops to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceAllUsesWith(op.getResult(), expanded);
    rewriter.eraseOp(op);
  }
}

bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
      op);
}
/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      ArrayRef<bool> inputScalableVecDims,
                                      bool vectorizeNDExtract,
                                      bool flatten1DDepthwiseConv) {
  LDBG("Attempting to vectorize:\n" << *op << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Input scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            // TODO: isaConvolutionOpInterface that can also infer from
            // generic features. Will require stride/dilation attributes
            // inference.
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG("Unsupported convolution can't be vectorized.\n");
              return failure();
            }

            LDBG("Vectorize generic by broadcasting to the canonical vector "
                 "shape\n");

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<tensor::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<tensor::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG("Vectorization failed\n");
    return failure();
  }

  if (!results.empty())
    rewriter.replaceOp(op, results);
  else
    rewriter.eraseOp(op);

  return success();
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  SmallVector<Value> indices(srcType.getRank(), zero);

  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = rewriter.create<vector::ExtractOp>(loc, readValue);
    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
  }
  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in a vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getSourceMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at the position of the old one.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }
  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., the same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both the CastOp
    // result and the CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. Apart from CastOp, only
    // ExtractSliceOp is supported.
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Only consider dynamic dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same value or same constant.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Other cases: Take a deeper look at the defining ops of the values.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      // Case 2: Both values are identical AffineMinOps. (Should not happen if
      // CSE is run.)
      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
      return false;
    }

    // All tests passed.
    return true;
  }
/// Returns the effective Pad value for the input op, provided it's a scalar.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast - return the value that's being broadcast, provided
  // that it's a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  // 2. linalg.fill - use the scalar input value used to fill the output
  // tensor.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generate - continue the analysis.
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    // ...
  }

  // 4. vector.transfer_write - inspect the vector that's written.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // 5. tensor.insert_slice - inspect the destination.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
    auto sourceType = sliceOp.getSource().getType();
    if (!VectorType::isValidElementType(sourceType.getElementType()))
      return failure();

    auto resultType = sliceOp.getResultType();

    // 1. Get the pad value. TransferReadOp requires a scalar padding value;
    // for in-bounds accesses the value is irrelevant.
    Value padValue = getStaticPadVal(sliceOp);
    bool isOutOfBoundsRead = !sourceType.hasStaticShape();

    if (!padValue && isOutOfBoundsRead) {
      LDBG("Failed to get a pad value for out-of-bounds read access\n");
      return failure();
    }

    if (!padValue) {
      auto elemType = sourceType.getElementType();
      padValue = rewriter.create<arith::ConstantOp>(
          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
    }

    // 2. Get the vector shape and in-bounds attributes.
    SmallVector<int64_t> vecShape;
    SmallVector<bool> readInBounds;
    SmallVector<bool> writeInBounds;
    size_t rankDiff = resultType.getRank() - sourceType.getRank();
    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
      if (!sourceType.isDynamicDim(i)) {
        vecShape.push_back(sourceType.getDimSize(i));
        // Source shape is statically known: read and write are in-bounds.
        readInBounds.push_back(true);
        writeInBounds.push_back(true);
      } else if (!resultType.isDynamicDim(i)) {
        // Source shape is not statically known, but the result shape is.
        // Vectorize with the size of the result dim; the read may be
        // out-of-bounds and the write may be as well.
        vecShape.push_back(resultType.getDimSize(rankDiff + i));
        readInBounds.push_back(false);
        writeInBounds.push_back(false);
      } else {
        // Neither source nor result dim is static. Cannot vectorize the copy.
        return failure();
      }
    }
    auto vecType = VectorType::get(vecShape, sourceType.getElementType());

    // 3. Generate TransferReadOp.
    SmallVector<Value> readIndices(
        vecType.getRank(),
        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
        ArrayRef<bool>{readInBounds});

    // 4. Generate TransferWriteOp and finalize.
    auto writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        sliceOp, read, sliceOp.getDest(), writeIndices,
        ArrayRef<bool>{writeInBounds});

    return success();
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result not used as destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if sizes match: insert the entire tensor into the most minor dims.
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Generate TransferReadOp: read the entire source tensor with high
    // padding.
    rewriter.setInsertionPoint(insertOp);
    SmallVector<Value> readIndices(
        vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

    // Generate TransferWriteOp: write to the InsertSliceOp's dest tensor at
    // the specified offsets. The write is fully in-bounds because an
    // InsertSliceOp's source must fit into the destination.
    auto writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices, inBounds);

    return success();
  }
/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
/// is in a different block.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
    }
  }
  return false;
}

/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // The `masked` attribute is only valid on the padded buffer; reset it
  // conservatively when forwarding to vector.transfer_read.
  auto vectorType = xferOp.getVectorType();
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the subview copied into that we replace.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward vector.transfer into copy. Reset the `masked` attribute
  // conservatively (it is only valid on the padded buffer).
  auto vector = xferOp.getVector();
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vector.getType().getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2...>(shapedType, vals...);
}

/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
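// Illustrative usage (mirrors the conv generators below): unpack the leading
// dims of a ShapedType into scalars, e.g.
//   int64_t nSize, wSize, cSize;
//   bindShapeDims(lhsShapedType, nSize, wSize, cSize); // N, W, C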
bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
/// Helper StructuredGenerator class to manipulate and rewrite ops with
/// `ConvolutionOpInterface` structure.
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {
    // Determine whether `linalgOp` can be generated with this generator.
    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
      return;
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());
    if (!lhsShapedType || !rhsShapedType || !resShapedType)
      return;
    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
    // (non-channeled convolution -> LHS and RHS both have single dimensions).
    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
      return;

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    if (!reduceOp)
      return;

    // Check for reduction `add` preceded by `mul`.
    if (!setOperKind(reduceOp))
      return;

    auto maybeKind = getCombinerOpKind(reduceOp);
    // Typically convolution will have an `Add` CombiningKind, but for i1 it
    // can get strength-reduced to `OR`, which is also supported.
    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                        *maybeKind != vector::CombiningKind::OR) &&
                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
      return;
    }
    reductionKind = maybeKind.value();

    auto rhsRank = rhsShapedType.getRank();
    // ...
    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
      return;
    // The op is now known to be valid.
    valid = true;
  }
  /// Generate the IR for a 1-D convolution or pooling, unrolled along kw.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Initialize unused dimensions.
      nSize = fSize = cSize = 0;
      // out{W}, lhs{W}, rhs{KW}.
      bindShapeDims(resShapedType, wSize);
      bindShapeDims(rhsShapedType, kwSize);
      lhsShape = {// iw = ow + kw - 1.
                  (wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      // out{N, W, F}, lhs{N, SW * W + DW * KW, C}, rhs{KW, C, F}.
      // ...
      lhsShape = {nSize,
                  // iw = (ow - 1) * sw + (kw - 1) * dw + 1.
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      // out{N, F, W}, lhs{N, C, SW * W + DW * KW}, rhs{F, C, KW}.
      // ...
      lhsShape = {nSize, cSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    // ...

    // Read the whole lhs, rhs and res in one shot (with zero padding).
    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
                                                        lhsPadding);
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == Conv)
      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                    rhsPadding);
    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                        resPadding);

    // The base vectorization case for channeled convolutions is input {n,w,c},
    // weight {kw,c,f}, output {n,w,f}. For the NCW case, transpose to NWC.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
      // fcw -> wcf
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
      // This is needed only for Conv.
      if (oper == Conv)
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
      break;
    }
    }

    // Unroll along kw and read slices of lhs, rhs and res.
    SmallVector<Value> lhsVals =
        extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                               kwSize, strideW, dilationW, wSizeStep,
                               isSingleChanneled);
    SmallVector<Value> rhsVals;
    if (oper == Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    SmallVector<Value> resVals = extractConvResultSlices(
        rewriter, loc, res, nSize, wSize, fSize, wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute contraction/outerproduct/pooling per slice.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }

    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // For the NCW case, transpose the result back: nwf -> nfw.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      break;
    case Conv1DOpOrder::Ncw: {
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
      break;
    }
    }

    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
  }
  // Take a value of element type T and widen it to the destination type.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contractionOp = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contractionOp.setKind(reductionKind);
    return contractionOp;
  }

  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single-channel
  // convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
  }
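  // Illustrative IR for one unrolled kw slice (shapes assumed): with
  // lhs : vector<1x4x3xf32>, rhs : vector<3x8xf32>, res : vector<1x4x8xf32>,
  // conv1dSliceAsContraction emits roughly
  //   %0 = vector.contract {
  //          indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
  //                           affine_map<(n, w, f, c) -> (c, f)>,
  //                           affine_map<(n, w, f, c) -> (n, w, f)>],
  //          iterator_types = ["parallel", "parallel", "parallel",
  //                            "reduction"], kind = #vector.kind<add>}
  //        %lhs, %rhs, %res : vector<1x4x3xf32>, vector<3x8xf32> into
  //        vector<1x4x8xf32>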
  /// Generate the IR for a 1-D depthwise convolution; kw is unrolled, w is
  /// unrolled iff dilationW > 1.
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}, out{n, w, c}.
    bindShapeDims(resShapedType, nSize, wSize, cSize);
    bindShapeDims(rhsShapedType, kwSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when the channel dim is dynamic and
      // the scalable flag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow - 1) * sw + (kw - 1) * dw + 1.
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});

    // Masks the input xfer Op along the channel dim, iff the corresponding
    // scalable flag is set.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };

    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                        ValueRange{zero, zero});
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

    // Unroll along kw and read slices of lhs and rhs.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
    auto inOutStrides = SmallVector<int64_t>{1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, maybeMaskedRhs->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Note - the scalable flags are ignored as flattening combined with
    // scalable vectorization is not supported.
    auto inOutFlattenSliceSizes =
        SmallVector<int64_t>{nSize, wSizeStep * cSize};
    auto lhsTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
    auto resTypeAfterFlattening =
        VectorType::get(inOutFlattenSliceSizes, resEltType);

    // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        Value lhsVal = lhsVals[linearIndex(kw, w)];
        Value resVal = resVals[w];
        if (flatten) {
          // Flatten the input and output vectors (collapse the channel dim).
          lhsVal = rewriter.create<vector::ShapeCastOp>(
              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
          resVal = rewriter.create<vector::ShapeCastOp>(
              loc, resTypeAfterFlattening, resVals[w]);
        }
        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                  rhsVals[kw], resVal, flatten);
        if (flatten) {
          // Un-flatten the output vector (restore the channel dimension).
          resVals[w] = rewriter.create<vector::ShapeCastOp>(
              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
        }
      }
    }

    // It's possible we failed to create the Fma.
    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
      // Manually revert (in reverse order) to avoid leaving a bad IR state.
      for (auto &collection :
           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
        for (Value v : collection)
          rewriter.eraseOp(v.getDefiningOp());
      return rewriter.notifyMatchFailure(op, "failed to create FMA");
    }

    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back res slice of size {n, w, c} @ [0, 0, 0].
    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
        loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c}     (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c}   (flatten = true)
  /// to MulAcc.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: This logic won't work for scalable vectors; for this reason,
      // "flattening" is not supported when shapes are dynamic.

      // Duplicate the filter to match the flattened output:
      // broadcast(shuffle(filter)) avoids a shape_cast.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }
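  // Illustrative IR for one slice (shapes assumed, flatten = false):
  //   %rhs_b = vector.broadcast %rhs : vector<4xf32> to vector<1x8x4xf32>
  //   %out   = vector.fma %lhs, %rhs_b, %res : vector<1x8x4xf32>
  // For integer element types the same slice lowers to arith.muli followed by
  // arith.addi instead.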
  /// Entry point that transposes into the common form:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);
    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point that transposes into the common form:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");
    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point for NWC pooling.
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);
    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point for NCW pooling.
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");
    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);
    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point for dilated (depthwise) NWC convolution.
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
private:
  enum OperKind { Conv, Pool };
  bool valid = false;
  OperKind oper = Conv;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;
  // Returns true iff it is a valid conv/pooling op.
  bool setOperKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    switch (numBlockArguments) {
    case 1: {
      // Will be convolution if the feeder is a MulOp. A strength-reduced
      // version of MulOp for i1 type is AndOp, which is also supported.
      // Otherwise, it can be pooling.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
                    (isa<arith::AndIOp>(feedOp) &&
                     feedOp->getResultTypes()[0].isInteger(1))) &&
                   llvm::all_of(feedOp->getOperands(), [](Value v) {
                     if (isa<BlockArgument>(v))
                       return true;
                     if (Operation *op = v.getDefiningOp())
                       return isCastOfBlockArgument(op);
                     return false;
                   }))) {
        return false;
      }
      return true;
    }
    case 2:
      // Must be pooling.
      oper = Pool;
      isPoolExt = false;
      return true;
    default:
      return false;
    }
  }
/// Helper function to vectorize a LinalgOp with convolution semantics.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
                     ArrayRef<int64_t> inputVecSizes,
                     ArrayRef<bool> inputScalableVecDims,
                     bool flatten1DDepthwiseConv) {
  // The ConvolutionOpInterface gives us guarantees of existence for
  // strides/dilations, but we do not need to rely on those: use them if
  // present, otherwise use the default and let the generic conv matcher in
  // the Conv1DGenerator succeed or fail.
  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  Conv1DGenerator e(rewriter, op, stride, dilation);
  auto res = e.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcConv();
  if (succeeded(res))
    return res;
  res = e.generateNcwConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = e.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left - these can be vectorized using
  // masks and scalable vectors. Note that ATM the only dim that can be dynamic
  // (i.e. vectorized with masking) is the channel dim (the trailing dim).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                               flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};
static std::optional< VectorShape > vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create a MultiDimReductionOp to compute the reduction for reduceOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp to these 4 Ops: vector::TransferReadOp (reads a vector from the source tensor), vector::TransposeOp (transposes the read vector), vector::ShapeCastOp (reshapes the data based on the target shape), and vector::TransferWriteOp (writes the result back to the destination tensor).
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles, (2) constant padding value and (3) input vector sizes.
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after the filter is unrolled along kw.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation, compress the unused dims so the map can serve as a permutation map for a vector transfer operation.
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorExtract.
VectorizationStatus
Helper data structure to represent the result of vectorization.
Failure
Op failed to vectorize.
NewOp
Op vectorized into a new Op whose results will replace original Op's results.
NoReplace
Op vectorized and custom function took care of replacement logic.
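To make the contract concrete, here is a minimal sketch of a CustomVectorizationHook that reports these statuses; the op test is illustrative and buildIndexVector is a hypothetical helper, not part of this file:
CustomVectorizationHook exampleHook =
    [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
  auto indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  // A real hook would build the vector form here, typically looking up
  // already-vectorized operands through `bvm`.
  Operation *replacement = buildIndexVector(indexOp, bvm); // hypothetical
  return VectorizationResult{VectorizationStatus::NewOp, replacement};
};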
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Need to check if the inner-tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of dstType if possible.
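For instance, broadcasting a scalar to a fixed 2-D vector is a single vector.broadcast; a minimal sketch, assuming builder, loc and scalar (an f32 value) are in scope:
auto vecTy = VectorType::get({4, 8}, builder.getF32Type());
Value bcast = builder.create<vector::BroadcastOp>(loc, vecTy, scalar);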
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor.extract.
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value and (3) all-zero lowPad, to a vector::TransferWriteOp.
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter: a dimension is kept in the results only if the filter returns true for it.
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e. dims indicated by a constant 0 in the results).
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
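A short sketch exercising the factory and query methods above, assuming ctx is an MLIRContext*:
// (d0, d1, d2) -> (d0, d1, d2)
AffineMap id3 = AffineMap::getMultiDimIdentityMap(3, ctx);
// (d0, d1, d2) -> (d1, d2): identity on the two most minor dims.
AffineMap minorId = AffineMap::getMinorIdentityMap(3, 2, ctx);
// (d0, d1, d2) -> (d2, d0, d1)
AffineMap perm = AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx);
assert(perm.isPermutation());
assert(minorId.isProjectedPermutation());
// compose applies the argument map first, then this map.
assert(perm.compose(id3) == perm);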
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callback.
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
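A minimal sketch building such an attribute for a small index vector, assuming builder is in scope:
auto attrTy = VectorType::get({4}, builder.getIndexType());
DenseIntElementsAttr indices =
    DenseIntElementsAttr::get(attrTy, ArrayRef<int64_t>{0, 1, 2, 3});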
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
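A minimal usage sketch, assuming oldVal/newVal are Values and op/builder are in scope:
IRMapping bvm;
bvm.map(oldVal, newVal);           // record a replacement
Value mapped = bvm.lookup(oldVal); // == newVal
// Clone with operands remapped through the mapping (see OpBuilder::clone
// below).
Operation *copy = builder.clone(*op, bvm);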
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
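A minimal sketch of the two creation paths, assuming ctx and loc are in scope:
OpBuilder b(ctx);
// `create` always materializes a new operation.
Value c0 = b.create<arith::ConstantIndexOp>(loc, 0);
// `createOrFold` attempts to fold at creation time and may return an
// existing value instead of inserting a fresh op.
Value sum = b.createOrFold<arith::AddIOp>(loc, c0, c0);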
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (extremely useful).
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification.
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
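For example, inspecting a Value's producer and consumers with the accessors above, assuming v is a Value:
if (Operation *def = v.getDefiningOp())
  LDBG("defined by: " << def->getName());
for (OpOperand &use : v.getUses())
  LDBG("operand #" << use.getOperandNumber() << " of "
                   << use.getOwner()->getName());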
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors, and systematize conversion between these forms.
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to the provided dimension and symbol values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next integer value.
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op, false otherwise.
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
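Putting the two entry points together, a minimal driver sketch inside a function returning LogicalResult; target is an assumed candidate op (e.g. a linalg.matmul) and ctx its context:
IRRewriter rewriter(ctx);
if (failed(linalg::vectorizeOpPrecondition(target)))
  return failure(); // not vectorizable as-is
// With no explicit sizes, the canonical vector shape is derived from the
// static loop ranges; pass inputVectorSizes/inputScalableVecDims for
// masked or scalable vectorization.
if (failed(linalg::vectorize(rewriter, target)))
  return failure();
return success();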
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns)
Populates patterns with vectorization patterns for tensor.insert_slice.
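These populate functions compose; a minimal sketch wiring them into the greedy rewrite driver, assuming funcOp and ctx are in scope:
RewritePatternSet patterns(ctx);
linalg::populatePadOpVectorizationPatterns(patterns);
linalg::populateConvolutionVectorizationPatterns(patterns);
linalg::populateInsertSliceVectorizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  return failure();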
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the destination permutation of a PackOp. This function uses the helper function computePackUnPackPerm to get the permutation vector.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape, i.e., the two arrays have the same number of elements, inputVectorSizes has no dynamic dimensions, and each vector size is greater than or equal to the corresponding static size in shape.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
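A minimal masking sketch, assuming readOp is a maskable vector.transfer_read, maskTy its matching i1 vector type, and upperBounds the per-dimension index values:
Value mask =
    builder.create<vector::CreateMaskOp>(loc, maskTy, upperBounds);
Operation *masked = vector::maskOperation(builder, readOp, mask);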
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
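For example, assuming source is a tensor value and padValue a matching scalar; the read shape must be static:
Value read = vector::createReadOrMaskedRead(
    builder, loc, source, /*readShape=*/{8, 16}, padValue,
    /*useInBoundsInsteadOfMasking=*/false);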
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and destination respectively).
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
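A small worked example, assuming ctx is in scope:
// perm: (d0, d1, d2) -> (d2, d0, d1)
AffineMap perm = AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx);
SmallVector<int64_t> source = {4, 8, 16};
// Each result picks source[i] for its dim di: yields {16, 4, 8}.
SmallVector<int64_t> permuted = applyPermutationMap<int64_t>(perm, source);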
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into 0s.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
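For instance:
// (d0, d1, d2) -> (d2, d0, d1)
AffineMap perm = AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx);
// Inverse: (d0, d1, d2) -> (d1, d2, d0); composing the two gives the
// identity.
AffineMap inv = inversePermutation(perm);
assert(inv.compose(perm).isIdentity());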
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
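A worked example of the two permutation-vector helpers:
SmallVector<int64_t> perm = {2, 0, 1};
// inv[perm[i]] == i, so inv == {1, 2, 0}.
SmallVector<int64_t> inv = invertPermutationVector(perm);
SmallVector<int64_t> shape = {4, 8, 16};
// In place: shape[i] = oldShape[perm[i]], so shape becomes {16, 4, 8}.
applyPermutationToVector(shape, perm);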
Rewrite tensor.insert_slice as a vector.transfer_read + vector.transfer_write pair.
LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp, PatternRewriter &rewriter) const final
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given operation.
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the corresponding fixed/scalable dimensions bit.
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vectorization.
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.