36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/Sequence.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Support/MathExtras.h"
42 #include "llvm/Support/raw_ostream.h"
48 #define DEBUG_TYPE "linalg-vectorization"
50 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
51 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
54 static FailureOr<Operation *>
58 bool flatten1DDepthwiseConv =
false);
93 template <
typename OpType>
96 block.
walk([&](OpType op) {
111 int64_t nSize, int64_t wSize, int64_t cSize,
112 int64_t kwSize,
int strideW,
int dilationW,
113 int64_t wSizeStep,
bool isSingleChanneled) {
115 if (isSingleChanneled) {
120 for (int64_t kw = 0; kw < kwSize; ++kw) {
121 for (int64_t w = 0; w < wSize; w += wSizeStep) {
122 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
131 for (int64_t kw = 0; kw < kwSize; ++kw) {
132 for (int64_t w = 0; w < wSize; w += wSizeStep) {
133 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
151 for (int64_t kw = 0; kw < kwSize; ++kw) {
152 result.push_back(rewriter.
create<vector::ExtractOp>(
162 int64_t nSize, int64_t wSize, int64_t fSize,
163 int64_t wSizeStep,
bool isSingleChanneled) {
165 if (isSingleChanneled) {
169 for (int64_t w = 0; w < wSize; w += wSizeStep) {
170 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
178 for (int64_t w = 0; w < wSize; w += wSizeStep) {
179 result.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
188 Value res, int64_t wSize, int64_t wSizeStep,
190 bool isSingleChanneled) {
192 if (isSingleChanneled) {
196 for (int64_t w = 0; w < wSize; w += wSizeStep) {
197 res = rewriter.
create<vector::InsertStridedSliceOp>(
204 for (int64_t w = 0; w < wSize; w += wSizeStep) {
205 res = rewriter.
create<vector::InsertStridedSliceOp>(
220 LogicalResult initState(
RewriterBase &rewriter, LinalgOp linalgOp,
237 std::optional<AffineMap> dimPermutation = std::nullopt)
const {
240 if (dimPermutation.has_value()) {
242 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
244 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
246 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
247 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
259 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
264 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
265 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
271 LogicalResult precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
280 std::optional<AffineMap> maybeMaskingMap);
285 bool isValidMaskingMap(
AffineMap maskingMap) {
334 VectorizationState::precomputeIterSpaceValueSizes(
RewriterBase &rewriter,
337 for (
int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
338 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
340 iterSpaceValueSizes.push_back(rewriter.
create<arith::ConstantIndexOp>(
341 linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
348 unsigned operandDimPos;
349 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
353 Value dynamicDim = linalgOp.hasPureTensorSemantics()
355 linalgOp.getLoc(), operand, operandDimPos)
357 linalgOp.getLoc(), operand, operandDimPos);
358 iterSpaceValueSizes.push_back(dynamicDim);
374 if (!inputVectorSizes.empty()) {
378 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
379 scalableVecDims.append(inputScalableVecDims.begin(),
380 inputScalableVecDims.end());
385 canonicalVecShape = linalgOp.getStaticLoopRanges();
386 scalableVecDims.append(linalgOp.getNumLoops(),
false);
389 LDBG(
"Canonical vector shape: ");
390 LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
391 LLVM_DEBUG(llvm::dbgs() <<
"\n");
392 LDBG(
"Scalable vector dims: ");
393 LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
394 LLVM_DEBUG(llvm::dbgs() <<
"\n");
396 if (ShapedType::isDynamicShape(canonicalVecShape))
400 initIterSpaceStaticSizes(linalgOp);
405 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
415 Value VectorizationState::getOrCreateMaskFor(
417 std::optional<AffineMap> maybeMaskingMap) {
419 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
420 "Ill-formed masking map.");
423 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
427 assert(!maskableOp.isMasked() &&
428 "Masking an operation that is already masked");
431 assert((!maybeMaskingMap || *maybeMaskingMap) &&
432 "Unexpected null mask permutation map");
434 maybeMaskingMap ? *maybeMaskingMap
436 linalgOp.getNumLoops(), rewriter.
getContext());
438 LDBG(
"Masking map: " << maskingMap <<
"\n");
442 auto activeMaskIt = activeMaskCache.find(maskingMap);
443 if (activeMaskIt != activeMaskCache.end()) {
444 Value mask = activeMaskIt->second;
445 LDBG(
"Reusing mask: " << mask <<
"\n");
456 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
457 auto maskType = getCanonicalVecType(rewriter.
getI1Type(), maskingMap);
458 auto maskShape = maskType.getShape();
460 LDBG(
"Mask shape: ");
461 LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
462 LLVM_DEBUG(llvm::dbgs() <<
"\n");
464 if (permutedStaticSizes == maskShape) {
465 LDBG(
"Masking is not needed for masking map: " << maskingMap <<
"\n");
466 activeMaskCache[maskingMap] =
Value();
473 assert(!maskShape.empty() && !upperBounds.empty() &&
474 "Masked 0-d vectors are not supported yet");
477 Value mask = rewriter.
create<vector::CreateMaskOp>(linalgOp.getLoc(),
478 maskType, upperBounds);
479 LDBG(
"Creating new mask: " << mask <<
"\n");
480 activeMaskCache[maskingMap] = mask;
487 std::optional<AffineMap> maybeIndexingMap) {
488 LDBG(
"Trying to mask: " << *opToMask <<
"\n");
490 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
491 if (maybeIndexingMap)
492 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
496 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
499 LDBG(
"No mask required\n");
504 assert(opToMask &&
"Expected a valid operation to mask");
505 auto maskOp = cast<vector::MaskOp>(
507 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
513 LDBG(
"Masked operation: " << *maskOp <<
"\n");
536 "expected projected permutation");
538 assert(res.getNumDims() ==
539 (res.getNumResults() - res.getNumOfZeroResults()) &&
540 "expected reindexed map with same number of dims and results");
576 std::optional<vector::CombiningKind>
578 using ::mlir::vector::CombiningKind;
583 .Case<arith::AddIOp, arith::AddFOp>(
584 [&](
auto op) {
return CombiningKind::ADD; })
585 .Case<arith::AndIOp>([&](
auto op) {
return CombiningKind::AND; })
586 .Case<arith::MaxSIOp>([&](
auto op) {
return CombiningKind::MAXSI; })
587 .Case<arith::MaxUIOp>([&](
auto op) {
return CombiningKind::MAXUI; })
588 .Case<arith::MaximumFOp>([&](
auto op) {
return CombiningKind::MAXIMUMF; })
589 .Case<arith::MaxNumFOp>([&](
auto op) {
return CombiningKind::MAXNUMF; })
590 .Case<arith::MinSIOp>([&](
auto op) {
return CombiningKind::MINSI; })
592 .Case<arith::MinimumFOp>([&](
auto op) {
return CombiningKind::MINIMUMF; })
593 .Case<arith::MinNumFOp>([&](
auto op) {
return CombiningKind::MINNUMF; })
594 .Case<arith::MulIOp, arith::MulFOp>(
595 [&](
auto op) {
return CombiningKind::MUL; })
596 .Case<arith::OrIOp>([&](
auto op) {
return CombiningKind::OR; })
597 .Case<arith::XOrIOp>([&](
auto op) {
return CombiningKind::XOR; })
598 .Default([&](
auto op) {
return std::nullopt; });
609 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
614 if (!
matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
615 combinerOps.size() != 1)
619 return combinerOps[0];
625 auto dstVecType = dyn_cast<VectorType>(dstType);
627 if (dstVecType.getRank() == 0)
633 return b.
createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
645 assert(maybeKind &&
"Failed precondition: could not get reduction kind");
646 return b.
create<vector::MultiDimReductionOp>(
647 reduceOp->
getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
651 return llvm::to_vector(
658 return isa<linalg::ReduceOp>(op) ||
659 (isa<linalg::GenericOp>(op) &&
673 auto linalgOp = cast<LinalgOp>(outputOperand->
getOwner());
674 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
683 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
685 auto vectorType = state.getCanonicalVecType(
689 if (vectorType.getRank() > 0) {
692 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
694 assert(value.
getType() == vectorType &&
"Incorrect type");
695 write = rewriter.
create<vector::TransferWriteOp>(
696 loc, value, outputOperand->
get(), indices, writeMap);
699 if (!isa<VectorType>(value.
getType()))
700 value = rewriter.
create<vector::BroadcastOp>(loc, vectorType, value);
701 assert(value.
getType() == vectorType &&
"Incorrect type");
702 write = rewriter.
create<vector::TransferWriteOp>(
706 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
710 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
711 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
716 LDBG(
"vectorized op: " << *write <<
"\n");
726 std::function<LogicalResult(
Operation *,
bool)>;
745 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
754 linalgOp.getDpsInitOperand(output.index()), state);
756 newResults.push_back(newResult);
770 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
773 auto loc = indexOp.getLoc();
776 auto dim = indexOp.getDim();
778 auto indexVectorType =
780 state.getScalableVecDims()[dim]);
781 auto indexSteps = rewriter.
create<vector::StepOp>(loc, indexVectorType);
785 if (dim == targetShape.size() - 1)
791 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
792 std::swap(permPattern[dim], permPattern.back());
796 auto broadCastOp = rewriter.
create<vector::BroadcastOp>(
797 loc, state.getCanonicalVecType(rewriter.
getIndexType(), permMap),
800 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
801 std::swap(transposition.back(), transposition[dim]);
803 rewriter.
create<vector::TransposeOp>(loc, broadCastOp, transposition);
811 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
815 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
820 if (not extractOp.getIndices().empty()) {
821 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
825 if (!llvm::all_of(extractOp->getResultTypes(),
826 VectorType::isValidElementType)) {
845 tensor::ExtractOp extractOp,
848 auto indexVecType = state.getCanonicalVecType(rewriter.
getIndexType());
849 auto loc = extractOp.getLoc();
852 rewriter, bvm.
lookup(extractOp.getIndices()[0]), indexVecType);
854 const size_t numIndices = extractOp.getIndices().size();
855 for (
size_t i = 1; i < numIndices; i++) {
856 Value dimIdx = rewriter.
create<arith::ConstantIndexOp>(loc, i);
860 rewriter.
create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
863 offset = rewriter.
create<arith::MulIOp>(loc, offset, dimSize);
866 rewriter, bvm.
lookup(extractOp.getIndices()[i]), indexVecType);
868 offset = rewriter.
create<arith::AddIOp>(loc, extractOpIndex, offset);
894 (linalgOp.hasDynamicShape() ||
895 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
896 "For statically shaped Linalg Ops, only one "
897 "non-unit loop dim is expected");
898 assert(loopRanges.size() != 0 &&
"Empty loops, nothing to analyse.");
900 size_t idx = loopRanges.size() - 1;
901 for (; idx != 0; idx--)
902 if (loopRanges[idx] != 1)
910 VectorType resType) {
912 assert(((llvm::count_if(resType.getShape(),
913 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
914 "n-D vectors are not yet supported");
920 auto *block = linalgOp.getBlock();
921 if (isa<BlockArgument>(val))
922 return !llvm::is_contained(block->getArguments(), val);
925 assert(defOp &&
"This is neither a block argument nor an operation result");
930 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
931 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
934 auto *ancestor = block->findAncestorOpInBlock(*defOp);
941 if (isa<arith::ConstantOp>(ancestor))
945 for (
auto op : ancestor->getOperands())
969 bool &foundIndexOp, VectorType resType) {
971 assert(((llvm::count_if(resType.getShape(),
972 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
973 "n-D vectors are not yet supported");
979 auto *block = linalgOp.getBlock();
980 if (isa<BlockArgument>(val))
981 return !llvm::is_contained(block->getArguments(), val);
984 assert(defOp &&
"This is neither a block argument nor an operation result");
986 if (
auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
989 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
993 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1000 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1003 bool result =
false;
1004 for (
auto op : ancestor->getOperands())
1024 LinalgOp &linalgOp, VectorType resType) {
1026 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1029 if (inputShape.getShape().empty())
1034 bool isOutput1DVector =
1035 (llvm::count_if(resType.getShape(),
1036 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1038 if (!isOutput1DVector)
1041 bool leadingIdxsLoopInvariant =
true;
1047 auto indices = extractOp.getIndices();
1048 auto leadIndices = indices.drop_back(1);
1051 if (inputShape.getShape()[i] == 1)
1057 if (!leadingIdxsLoopInvariant) {
1058 LDBG(
"Found gather load: " << extractOp);
1066 auto extractOpTrailingIdx = indices.back();
1070 if (leadingIdxsLoopInvariant &&
1072 LDBG(
"Found scalar broadcast load: " << extractOp);
1081 bool foundIndexOp =
false;
1083 foundIndexOp, resType);
1086 bool isRowVector = resType.getShape().back() != 1;
1087 isContiguousLoad &= (foundIndexOp && isRowVector);
1089 if (isContiguousLoad) {
1090 LDBG(
"Found contigous load: " << extractOp);
1095 LDBG(
"Found gather load: " << extractOp);
1106 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1109 auto loc = extractOp.getLoc();
1112 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1113 auto maskConstantOp = rewriter.
create<arith::ConstantOp>(
1117 auto passThruConstantOp =
1123 extractOp.getIndices().size(),
1124 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
1135 loc, resultType, extractOp.getTensor(), baseIndices, offset,
1136 maskConstantOp, passThruConstantOp);
1137 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1139 LDBG(
"Vectorised as gather load: " << extractOp <<
"\n");
1162 for (
size_t i = 0; i < extractOp.getIndices().size(); i++) {
1163 Value idx = bvm.
lookup(extractOp.getIndices()[i]);
1165 transferReadIdxs.push_back(idx);
1169 auto indexAs1dVector = rewriter.
create<vector::ShapeCastOp>(
1172 resultType.getScalableDims().back()),
1174 transferReadIdxs.push_back(
1175 rewriter.
create<vector::ExtractOp>(loc, indexAs1dVector, 0));
1179 auto dstRank = resultType.getRank();
1180 auto srcRank = extractOp.getTensor().getType().getRank();
1189 auto transferReadOp = rewriter.
create<vector::TransferReadOp>(
1190 loc, resultType, extractOp.getTensor(), transferReadIdxs,
1191 std::nullopt, permutationMap, inBounds);
1198 auto allTrue = rewriter.
create<vector::ConstantMaskOp>(
1200 auto *maskedReadOp =
1203 LDBG(
"Vectorised as scalar broadcast load: " << extractOp <<
"\n");
1212 int32_t rankDiff = dstRank - srcRank;
1220 while (rankDiff > 0) {
1221 permutationMap = permutationMap.insertResult(
1226 auto transferReadOp = rewriter.
create<vector::TransferReadOp>(
1227 loc, resultType, extractOp.getTensor(), transferReadIdxs,
1228 std::nullopt, permutationMap, inBounds);
1230 LDBG(
"Vectorised as contiguous load: " << extractOp);
1244 auto reduceType = dyn_cast<VectorType>(reduceVec.
getType());
1245 auto outputType = dyn_cast<VectorType>(outputVec.
getType());
1249 (outputType && reduceType.getShape() == outputType.getShape()))
1278 LDBG(
"vectorize op " << *op <<
"\n");
1281 if (!customVectorizationHooks.empty()) {
1282 for (
auto &customFunc : customVectorizationHooks) {
1292 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1294 rewriter.
clone(*op)};
1303 auto blockArg = dyn_cast<BlockArgument>(operand);
1304 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1305 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1309 linalgOp.getRegionOutputArgs(),
1310 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1313 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1315 if (!reductionOperands.empty()) {
1316 assert(reductionOperands.size() == 1);
1318 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1319 reductionOperands[0].second, bvm);
1326 VectorType firstMaxRankedType;
1328 auto vecOperand = bvm.
lookup(operand);
1329 assert(vecOperand &&
"Vector operand couldn't be found");
1331 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1332 if (vecType && (!firstMaxRankedType ||
1333 firstMaxRankedType.getRank() < vecType.getRank()))
1334 firstMaxRankedType = vecType;
1340 assert(vecOperand &&
"Vector operand couldn't be found");
1342 if (firstMaxRankedType) {
1345 firstMaxRankedType.getScalableDims());
1348 vecOperands.push_back(vecOperand);
1354 resultTypes.push_back(
1357 firstMaxRankedType.getScalableDims())
1389 static LogicalResult
1393 LDBG(
"Vectorizing operation as linalg generic\n");
1394 Block *block = linalgOp.getBlock();
1401 bvm.
map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1403 if (linalgOp.getNumDpsInits() == 0)
1408 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
1409 for (
OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1410 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1411 if (linalgOp.isScalar(opOperand)) {
1412 bvm.
map(bbarg, opOperand->get());
1418 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1421 VectorType readType;
1423 if (linalgOp.isDpsInput(opOperand)) {
1426 readType = state.getCanonicalVecType(elemType);
1433 state.getCanonicalVecType(elemType, readMap.
compose(indexingMap));
1439 loc, readType, opOperand->get(), indices,
1440 std::nullopt, readMap);
1441 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1446 if (
auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1448 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1454 if (readType.getRank() == 0)
1470 hooks.push_back(vectorizeYield);
1477 hooks.push_back(vectorizeIndex);
1484 hooks.push_back(vectorizeExtract);
1491 LDBG(
"failed to vectorize: " << op <<
"\n");
1496 state.maskOperation(rewriter, result.
newOp, linalgOp);
1497 LDBG(
"New vector op: " << *maybeMaskedOp <<
"\n");
1563 if (ShapedType::isDynamicShape(destShape))
1570 cstMaskSizes.push_back(*intSize);
1575 if (cstMaskSizes.size() != maskShape.size())
1583 cstWriteIdxs.push_back(intVal.getSExtValue());
1588 if (cstWriteIdxs.size() != destShape.size())
1597 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1599 if ( maskShape[i] > destShape[rankDiff + i] ||
1600 destShape[rankDiff + i] <
1601 (
std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1637 bool useInBoundsInsteadOfMasking =
false) {
1639 ShapedType destType = cast<ShapedType>(dest.
getType());
1640 int64_t destRank = destType.getRank();
1641 auto destShape = destType.getShape();
1643 VectorType vecToStoreType = cast<VectorType>(vecToStore.
getType());
1644 int64_t vecToStoreRank = vecToStoreType.getRank();
1645 auto vecToStoreShape = vecToStoreType.getShape();
1649 if (useInBoundsInsteadOfMasking) {
1652 for (
unsigned i = 0; i < vecToStoreRank; i++)
1654 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1655 ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1659 assert((writeIndices.empty() ||
1660 writeIndices.size() ==
static_cast<size_t>(destRank)) &&
1661 "Invalid number of write indices!");
1662 if (writeIndices.empty()) {
1663 auto zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
1664 writeIndices.assign(destRank, zero);
1669 builder.
create<vector::TransferWriteOp>(loc,
1676 if (useInBoundsInsteadOfMasking)
1680 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1695 Value maskForWrite =
1696 builder.
createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1734 static LogicalResult
1743 auto padValue = packOp.getPaddingValue();
1745 padValue = rewriter.
create<arith::ConstantOp>(
1746 loc, rewriter.
getZeroAttr(packOp.getSourceType().getElementType()));
1749 LogicalResult status =
1750 cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
1751 .reifyResultShapes(rewriter, reifiedReturnShapes);
1753 assert(succeeded(status) &&
"failed to reify result shapes");
1758 bool useInBoundsInsteadOfMasking =
false;
1759 if (inputVectorSizes.empty()) {
1761 inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
1762 useInBoundsInsteadOfMasking =
true;
1767 auto innerTiles = packOp.getStaticInnerTiles();
1776 rewriter, loc, packOp.getSource(), inputShape, padValue,
1777 useInBoundsInsteadOfMasking);
1783 packOp.getDestType().getElementType());
1785 rewriter.
create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
1788 auto destPermutation =
1790 auto transposeOp = rewriter.
create<vector::TransposeOp>(
1791 loc, shapeCastOp.getResult(), destPermutation);
1795 loc, reifiedReturnShapes[0],
1796 transposeOp.getResult().getType().getElementType());
1799 newResults.push_back(write->getResult(0));
1812 static LogicalResult
1821 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1826 bool useInBoundsInsteadOfMasking =
false;
1829 auto destSize = unpackOp.getDestRank();
1831 if (!inputVectorSizes.empty())
1832 assert(inputVectorSizes.size() == destSize &&
1833 "Incorrect number of input vector sizes");
1844 if (vectorSizes.empty()) {
1845 llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1851 useInBoundsInsteadOfMasking =
true;
1876 readVectorSizes[innerDimPos[index]] =
1882 readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1886 LogicalResult status =
1887 cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
1888 .reifyResultShapes(rewriter, reifiedRetShapes);
1889 if (status.failed()) {
1890 LDBG(
"Unable to reify result shapes of " << unpackOp);
1895 auto padValue = rewriter.
create<arith::ConstantOp>(
1896 loc, rewriter.
getZeroAttr(unpackOp.getSourceType().getElementType()));
1901 rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1904 PackingMetadata packMetadata;
1907 ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1909 mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1911 RankedTensorType stripMineTensorType =
1914 vector::TransposeOp transposeOp = rewriter.
create<vector::TransposeOp>(
1915 loc, readResult, lastDimToInsertPosPerm);
1918 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1919 stripMineTensorType, packMetadata.reassociations);
1920 mlir::VectorType vecCollapsedType =
1921 VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1922 vector::ShapeCastOp shapeCastOp = rewriter.
create<vector::ShapeCastOp>(
1923 loc, vecCollapsedType, transposeOp->getResult(0));
1928 unpackOp.getDestType().hasStaticShape()
1930 : shapeCastOp.getResultVectorType().getShape());
1932 loc, reifiedRetShapes[0],
1933 shapeCastOp.getResult().getType().getElementType());
1935 rewriter, loc, shapeCastOp.getResult(), dest,
1936 {}, useInBoundsInsteadOfMasking);
1937 newResults.push_back(write->getResult(0));
1944 static LogicalResult
1948 auto padValue = padOp.getConstantPaddingValue();
1956 LogicalResult status =
1957 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1958 .reifyResultShapes(rewriter, reifiedReturnShapes);
1960 assert(succeeded(status) &&
"failed to reify result shapes");
1962 rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1967 loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
1969 newResults.push_back(write->getResult(0));
1977 LDBG(
"reduction precondition failed: no reduction iterator\n");
1980 for (
OpOperand &opOperand : op.getDpsInitsMutable()) {
1981 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1987 LDBG(
"reduction precondition failed: reduction detection failed\n");
1994 static LogicalResult
1996 bool flatten1DDepthwiseConv) {
1997 if (flatten1DDepthwiseConv) {
1998 LDBG(
"Vectorization of flattened convs with dynamic shapes is not "
2003 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2004 LDBG(
"Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
2010 Value lhs = conv.getDpsInputOperand(0)->get();
2012 auto shapeWithoutCh = lhsShape.drop_back(1);
2013 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2014 LDBG(
"Dynamically-shaped op vectorization precondition failed: only "
2015 "channel dim can be dynamic\n");
2022 static LogicalResult
2024 bool flatten1DDepthwiseConv) {
2025 if (isa<ConvolutionOpInterface>(op.getOperation()))
2034 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2038 LDBG(
"Dynamically-shaped op meets vectorization pre-conditions\n");
2043 static LogicalResult
2047 if (llvm::any_of(unpackOp.getInnerTiles(), [](
OpFoldResult res) {
2048 return !getConstantIntValue(res).has_value();
2050 LDBG(
"Inner-tiles must be constant: " << unpackOp <<
"\n");
2054 bool satisfyEmptyCond = inputVectorSizes.empty() &&
2055 unpackOp.getDestType().hasStaticShape() &&
2056 unpackOp.getSourceType().hasStaticShape();
2057 if (!satisfyEmptyCond &&
2064 static LogicalResult
2069 auto sourceType = source.getType();
2070 if (!VectorType::isValidElementType(sourceType.getElementType()))
2086 bool isOutOfBoundsRead =
2087 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2089 if (!padValue && isOutOfBoundsRead) {
2090 LDBG(
"Failed to get a pad value for out-of-bounds read access\n");
2097 enum class ConvOperationKind { Conv, Pool };
2115 static std::optional<ConvOperationKind>
2117 int numBlockArguments =
2118 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
2120 switch (numBlockArguments) {
2126 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
2127 llvm::IsaPred<BlockArgument>);
2129 "Expected a non-block argument operand");
2130 Operation *feedOp = (*feedValIt).getDefiningOp();
2132 return ConvOperationKind::Pool;
2135 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2136 (isa<arith::AndIOp>(feedOp) &&
2139 if (isa<BlockArgument>(v))
2141 if (Operation *op = v.getDefiningOp())
2142 return isCastOfBlockArgument(op);
2145 return std::nullopt;
2148 return ConvOperationKind::Conv;
2152 return ConvOperationKind::Pool;
2154 return std::nullopt;
2160 case vector::CombiningKind::ADD:
2161 case vector::CombiningKind::MAXNUMF:
2162 case vector::CombiningKind::MAXIMUMF:
2163 case vector::CombiningKind::MAXSI:
2164 case vector::CombiningKind::MAXUI:
2165 case vector::CombiningKind::MINNUMF:
2166 case vector::CombiningKind::MINIMUMF:
2167 case vector::CombiningKind::MINSI:
2176 auto getOperandType = [&](
auto operand) {
2177 return dyn_cast<ShapedType>((operand->get()).getType());
2179 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2180 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2181 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2185 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2186 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2194 if (!maybeOper.has_value())
2201 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2202 *maybeKind != vector::CombiningKind::OR) &&
2203 (*maybeOper != ConvOperationKind::Pool ||
2208 auto rhsRank = rhsShapedType.getRank();
2209 if (*maybeOper == ConvOperationKind::Pool) {
2213 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2222 bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
2224 if (llvm::any_of(linalgOp->getOpOperands(), [&](
OpOperand &operand) {
2225 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2229 if (!inputVectorSizes.empty() &&
2235 linalgOp, flatten1DDepthwiseConv))) {
2236 LDBG(
"Dynamically-shaped op failed vectorization pre-conditions\n");
2249 customPreconditions,
2252 customPrecondition(&innerOp, vectorizeNDExtract));
2256 if (!llvm::all_of(innerOp.getOperandTypes(),
2257 VectorType::isValidElementType)) {
2260 if (!llvm::all_of(innerOp.getResultTypes(),
2261 VectorType::isValidElementType)) {
2271 if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2278 LDBG(
"precondition failed: not projected permutations\n");
2282 LDBG(
"precondition failed: reduction preconditions\n");
2288 static LogicalResult
2291 auto padValue = packOp.getPaddingValue();
2294 LDBG(
"pad value is not constant: " << packOp <<
"\n");
2298 bool satisfyEmptyCond =
true;
2299 if (inputVectorSizes.empty()) {
2300 if (!packOp.getDestType().hasStaticShape() ||
2301 !packOp.getSourceType().hasStaticShape())
2302 satisfyEmptyCond =
false;
2305 if (!satisfyEmptyCond &&
2307 resultTensorShape.take_front(packOp.getSourceRank()),
2311 if (llvm::any_of(packOp.getInnerTiles(), [](
OpFoldResult v) {
2312 return !getConstantIntValue(v).has_value();
2314 LDBG(
"inner_tiles must be constant: " << packOp <<
"\n");
2321 static LogicalResult
2324 auto padValue = padOp.getConstantPaddingValue();
2326 LDBG(
"pad value is not constant: " << padOp <<
"\n");
2346 if (llvm::any_of(
llvm::enumerate(padOp.getLow()), [&](
const auto &en) {
2347 Value padValue = en.value();
2348 unsigned pos = en.index();
2349 std::optional<int64_t> pad = getConstantIntValue(padValue);
2350 return (!pad.has_value() || pad.value() != 0) &&
2351 resultTensorShape[pos] != 1;
2353 LDBG(
"low pad must all be zero for all non unit dims: " << padOp <<
"\n");
2362 static LogicalResult
2366 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2367 "Number of input vector sizes and scalable dims doesn't match");
2369 size_t numOfScalableDims =
2370 llvm::count_if(inputScalableVecDims, [](
bool flag) {
return flag; });
2372 if (numOfScalableDims == 0)
2375 auto linalgOp = dyn_cast<LinalgOp>(op);
2383 if (numOfScalableDims > 2)
2403 bool seenNonUnitParallel =
false;
2404 auto iterators = linalgOp.getIteratorTypesArray();
2406 int64_t idx = scalableFlags.size() - 1;
2407 while (!scalableFlags[idx]) {
2408 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2409 seenNonUnitParallel |=
2410 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2412 iterators.pop_back();
2413 scalableFlags.pop_back();
2418 switch (iterators.back()) {
2419 case utils::IteratorType::reduction: {
2421 if (iterators.size() != inputVectorSizes.size()) {
2422 LDBG(
"Non-trailing reduction dim requested for scalable "
2426 if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
2427 LDBG(
"Scalable vectorization of the reduction dim in Matmul-like ops "
2428 "is not supported\n");
2433 case utils::IteratorType::parallel: {
2435 if (seenNonUnitParallel) {
2436 LDBG(
"Inner parallel dim not requested for scalable "
2448 if (numOfScalableDims == 2) {
2452 if (iterators.back() == utils::IteratorType::reduction) {
2453 LDBG(
"Higher dim than the trailing reduction dim requested for scalable "
2457 scalableFlags.pop_back();
2458 iterators.pop_back();
2460 if (!scalableFlags.back() ||
2461 (iterators.back() != utils::IteratorType::parallel))
2467 if (linalgOp.hasUserDefinedMaps())
2472 return success(
isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2473 isa<linalg::MatmulTransposeAOp>(op) ||
2474 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2481 bool flatten1DDepthwiseConv) {
2487 inputScalableVecDims)))
2491 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
2494 flatten1DDepthwiseConv);
2496 .Case<tensor::PadOp>([&](
auto padOp) {
2499 .Case<linalg::PackOp>([&](
auto packOp) {
2502 .Case<linalg::UnPackOp>([&](
auto unpackOp) {
2505 .Case<tensor::InsertSliceOp>([&](
auto sliceOp) {
2508 .Default([](
auto) {
return failure(); });
2514 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2516 for (
auto op : make_early_inc_range(toReplace)) {
2519 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2520 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2521 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2527 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2528 tensor::InsertSliceOp>(op);
2531 FailureOr<VectorizationResult>
2535 bool vectorizeNDExtract,
bool flatten1DDepthwiseConv) {
2536 LDBG(
"Attempting to vectorize:\n" << *op <<
"\n");
2537 LDBG(
"Input vector sizes: ");
2538 LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
2539 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2540 LDBG(
"Input scalable vector dims: ");
2541 LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
2542 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2546 flatten1DDepthwiseConv))) {
2547 LDBG(
"Vectorization pre-conditions failed\n");
2553 if (
auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2554 if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2555 inputScalableVecDims))) {
2556 LDBG(
"Vectorization state couldn't be initialized\n");
2562 auto vectorizeResult =
2564 .Case<linalg::LinalgOp>([&](
auto linalgOp) {
2568 if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2570 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2571 flatten1DDepthwiseConv);
2572 if (succeeded(convOr)) {
2573 llvm::append_range(results, (*convOr)->getResults());
2577 LDBG(
"Unsupported convolution can't be vectorized.\n");
2581 LDBG(
"Vectorize generic by broadcasting to the canonical vector "
2594 .Case<tensor::PadOp>([&](
auto padOp) {
2598 .Case<linalg::PackOp>([&](
auto packOp) {
2602 .Case<linalg::UnPackOp>([&](
auto unpackOp) {
2604 inputVectorSizes, results);
2606 .Case<tensor::InsertSliceOp>([&](
auto sliceOp) {
2610 .Default([](
auto) {
return failure(); });
2612 if (failed(vectorizeResult)) {
2613 LDBG(
"Vectorization failed\n");
2621 memref::CopyOp copyOp) {
2622 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2623 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2624 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2629 if (!VectorType::isValidElementType(srcElementType) ||
2630 !VectorType::isValidElementType(dstElementType))
2641 loc, readType, copyOp.getSource(), indices,
2644 if (cast<VectorType>(
readValue.getType()).getRank() == 0) {
2650 loc,
readValue, copyOp.getTarget(), indices,
2661 template <
typename OpTy>
2669 for (
auto *user : llvm::to_vector<4>(padOp->getUsers()))
2670 if (
auto op = dyn_cast<OpTy>(user))
2671 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2677 tensor::PadOp padOp, OpTy op)
const = 0;
2705 vector::TransferReadOp xferOp)
const override {
2707 if (!padOp.hasZeroLowPad())
2710 auto padValue = padOp.getConstantPaddingValue();
2714 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2719 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2721 xferOp.getBaseMutable().assign(padOp.getSource());
2722 xferOp.getPaddingMutable().assign(padValue);
2767 vector::TransferWriteOp xferOp)
const override {
2769 if (xferOp.getTransferRank() == 0)
2773 if (!padOp.hasZeroLowPad())
2776 auto padValue = padOp.getConstantPaddingValue();
2780 if (!xferOp->hasOneUse())
2782 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2786 if (!trimPadding.hasZeroOffset())
2789 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2797 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2798 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2800 rewriter.
replaceOp(trimPadding, newXferOp->getResult(0));
2816 tensor::ExtractSliceOp afterTrimming)
const {
2819 if (
auto castOp = beforePadding.
getDefiningOp<tensor::CastOp>())
2820 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2823 auto t1 = dyn_cast<RankedTensorType>(beforePadding.
getType());
2824 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
2829 if (t1.getRank() != t2.getRank())
2834 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2835 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
2837 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
2842 if (t1.getNumDynamicDims() == 0)
2850 auto beforeSlice = beforePadding.
getDefiningOp<tensor::ExtractSliceOp>();
2854 assert(
static_cast<size_t>(t1.getRank()) ==
2855 beforeSlice.getMixedSizes().size());
2856 assert(
static_cast<size_t>(t2.getRank()) ==
2857 afterTrimming.getMixedSizes().size());
2859 for (
unsigned i = 0; i < t1.getRank(); ++i) {
2861 if (!t1.isDynamicDim(i))
2863 auto size1 = beforeSlice.getMixedSizes()[i];
2864 auto size2 = afterTrimming.getMixedSizes()[i];
2871 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
2872 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
2878 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
2879 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
2880 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
2881 minOp1.getOperands() == minOp2.getOperands())
2907 if (
auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2908 auto source = bcast.getSource();
2909 if (llvm::dyn_cast<VectorType>(source.getType()))
2917 if (
auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2918 return fill.getInputs()[0];
2923 if (
auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2930 if (
auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2938 if (
auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2944 static LogicalResult
2953 auto sourceType = source.getType();
2954 auto resultType = sliceOp.getResultType();
2959 auto elemType = sourceType.getElementType();
2960 padValue = rewriter.
create<arith::ConstantOp>(
2961 sliceOp.getLoc(), elemType, rewriter.
getZeroAttr(elemType));
2966 size_t rankDiff = resultType.getRank() - sourceType.getRank();
2967 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
2968 if (!inputVectorSizes.empty()) {
2969 vecShape.push_back(inputVectorSizes[i]);
2970 }
else if (!sourceType.isDynamicDim(i)) {
2971 vecShape.push_back(sourceType.getDimSize(i));
2972 }
else if (!resultType.isDynamicDim(i)) {
2978 vecShape.push_back(resultType.getDimSize(rankDiff + i));
2985 auto vecType =
VectorType::get(vecShape, sourceType.getElementType());
2988 auto loc = sliceOp.getLoc();
2992 vecType.getRank(), rewriter.
create<arith::ConstantIndexOp>(loc, 0));
2994 rewriter, loc, source, vecType.getShape(), padValue,
2995 inputVectorSizes.empty());
3002 writeIndices, inputVectorSizes.empty());
3005 newResults.push_back(write->
getResult(0));
3039 tensor::InsertSliceOp insertOp)
const override {
3041 if (!padOp.hasZeroLowPad())
3044 if (!insertOp.hasUnitStride())
3047 auto padValue = padOp.getConstantPaddingValue();
3051 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3054 if (insertOp.getDest() == padOp.getResult())
3058 padOp.getType().getElementType());
3059 unsigned vecRank = vecType.getRank();
3060 unsigned tensorRank = insertOp.getType().getRank();
3065 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3067 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](
auto it) {
3068 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3079 vecRank, rewriter.
create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
3080 auto read = rewriter.
create<vector::TransferReadOp>(
3081 padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
3087 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3090 insertOp, read, insertOp.getDest(), writeIndices,
3116 LDBG(
"interleavedUses precondition failed, firstOp: "
3117 << *firstOp <<
", second op: " << *secondOp <<
"\n");
3120 for (
auto v : values) {
3121 for (
auto &u : v.getUses()) {
3123 if (owner == firstOp || owner == secondOp)
3129 LDBG(
" found interleaved op " << *owner <<
", firstOp: " << *firstOp
3130 <<
", second op: " << *secondOp <<
"\n");
3140 memref::SubViewOp subViewOp;
3142 if (
auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3144 return memref::SubViewOp();
3145 subViewOp = newSubViewOp;
3157 if (xferOp.getMask())
3161 Value viewOrAlloc = xferOp.getBase();
3170 Value subView = subViewOp.getResult();
3173 memref::CopyOp copyOp;
3174 for (
auto &u : subView.
getUses()) {
3175 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3176 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3177 if (newCopyOp.getTarget() != subView)
3191 for (
auto &u : viewOrAlloc.
getUses()) {
3192 if (
auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3193 assert(isa<MemRefType>(newFillOp.output().getType()));
3194 if (newFillOp.output() != viewOrAlloc)
3198 maybeFillOp = newFillOp;
3203 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3205 "padding value does not match fill");
3208 Value in = copyOp.getSource();
3214 auto vectorType = xferOp.getVectorType();
3215 Value res = rewriter.
create<vector::TransferReadOp>(
3216 xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3217 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3222 rewriter.
eraseOp(maybeFillOp);
3234 if (xferOp.getMask())
3238 Value viewOrAlloc = xferOp.getBase();
3247 Value subView = subViewOp.getResult();
3250 memref::CopyOp copyOp;
3251 for (
auto &u : subViewOp.getResult().getUses()) {
3252 if (
auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3253 if (newCopyOp.getSource() != subView)
3265 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3266 Value out = copyOp.getTarget();
3273 auto vector = xferOp.getVector();
3274 rewriter.
create<vector::TransferWriteOp>(
3275 xferOp.getLoc(), vector, out, xferOp.getIndices(),
3276 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3278 dyn_cast<VectorType>(vector.getType()).getRank(),
false)));
3293 template <
int N,
typename IntTy,
typename... IntTy2>
3295 val = shapedType.getShape()[N];
3300 template <
typename... IntTy>
3302 bindShapeDims<0>(shapedType, vals...);
3340 struct Conv1DGenerator
3342 Conv1DGenerator(
RewriterBase &rewriter, LinalgOp linalgOp)
3345 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3346 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3347 resShaped = linalgOp.getDpsInitOperand(0)->get();
3348 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3349 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3350 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3355 setConvOperationKind(reduceOp);
3358 reductionKind = maybeKind.value();
3366 strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3367 dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3389 int64_t nSize, wSize, cSize, kwSize, fSize;
3392 switch (conv1DOpOrder) {
3395 nSize = fSize = cSize = 0;
3402 (wSize + kwSize - 1)};
3403 rhsShape = {kwSize};
3410 case ConvOperationKind::Conv:
3414 case ConvOperationKind::Pool:
3424 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3428 case ConvOperationKind::Conv:
3429 rhsShape = {kwSize, cSize, fSize};
3431 case ConvOperationKind::Pool:
3432 rhsShape = {kwSize};
3435 resShape = {nSize, wSize, fSize};
3441 case ConvOperationKind::Conv:
3445 case ConvOperationKind::Pool:
3451 lhsShape = {nSize, cSize,
3455 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3458 case ConvOperationKind::Conv:
3459 rhsShape = {fSize, cSize, kwSize};
3461 case ConvOperationKind::Pool:
3462 rhsShape = {kwSize};
3465 resShape = {nSize, fSize, wSize};
3469 vector::TransferWriteOp write;
3470 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
3475 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3477 Type lhsEltType = lhsShapedType.getElementType();
3478 Type rhsEltType = rhsShapedType.getElementType();
3479 Type resEltType = resShapedType.getElementType();
3489 Value lhs = rewriter.
create<vector::TransferReadOp>(
3490 loc, lhsType, lhsShaped, lhsPadding,
3493 Value rhs =
nullptr;
3494 if (oper == ConvOperationKind::Conv)
3495 rhs = rewriter.
create<vector::TransferReadOp>(
3496 loc, rhsType, rhsShaped, rhsPadding,
3498 Value res = rewriter.
create<vector::TransferReadOp>(
3499 loc, resType, resShaped, resPadding,
3505 switch (conv1DOpOrder) {
3513 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3514 lhs = rewriter.
create<vector::TransposeOp>(loc, lhs, permLhs);
3516 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3519 if (oper == ConvOperationKind::Conv)
3520 rhs = rewriter.
create<vector::TransposeOp>(loc, rhs, permRhs);
3522 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3523 res = rewriter.
create<vector::TransposeOp>(loc, res, permRes);
3534 kwSize, strideW, dilationW, wSizeStep,
3537 if (oper == ConvOperationKind::Conv)
3540 wSizeStep, isSingleChanneled);
3542 auto linearIndex = [&](int64_t kw, int64_t w) {
3543 return kw * (wSize / wSizeStep) + w;
3549 for (int64_t kw = 0; kw < kwSize; ++kw) {
3550 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3552 case ConvOperationKind::Conv:
3553 if (isSingleChanneled) {
3554 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3555 lhsVals[linearIndex(kw, w)],
3556 rhsVals[kw], resVals[w]);
3558 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3559 lhsVals[linearIndex(kw, w)],
3560 rhsVals[kw], resVals[w]);
3563 case ConvOperationKind::Pool:
3564 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3580 switch (conv1DOpOrder) {
3587 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3588 res = rewriter.
create<vector::TransposeOp>(loc, res, perm);
3594 .
create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
3602 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3603 if (srcElementType == dstElementType)
3608 const Type dstType =
3609 cast<ShapedType>(val.
getType()).cloneWith(std::nullopt, dstElementType);
3611 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3612 return rewriter.
create<arith::SIToFPOp>(loc, dstType, val);
3615 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3616 srcWidth < dstWidth)
3617 return rewriter.
create<arith::ExtFOp>(loc, dstType, val);
3619 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3620 srcWidth < dstWidth)
3621 return rewriter.
create<arith::ExtSIOp>(loc, dstType, val);
3623 assert(
false &&
"unhandled promotion case");
3630 vector::IteratorType par = vector::IteratorType::parallel;
3631 vector::IteratorType red = vector::IteratorType::reduction;
3636 auto contrationOp = rewriter.
create<vector::ContractionOp>(
3638 MapList{{n, w, c}, {c, f}, {n, w, f}},
3640 contrationOp.setKind(reductionKind);
3641 return contrationOp;
3648 return rewriter.
create<vector::OuterProductOp>(
3649 loc, res.
getType(), lhs, rhs, res, vector::CombiningKind::ADD);
3671 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3672 bool channelDimScalableFlag,
3674 bool scalableChDim =
false;
3675 bool useMasking =
false;
3676 int64_t nSize, wSize, cSize, kwSize;
3679 if (ShapedType::isDynamic(cSize)) {
3680 assert(channelDimVecSize != 0 &&
"Channel dim vec size must be > 0");
3681 cSize = channelDimVecSize;
3685 scalableChDim = channelDimScalableFlag;
3689 assert(!(useMasking && flatten) &&
3690 "Unsupported flattened conv with dynamic shapes");
3695 vector::TransferWriteOp write;
3696 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
3701 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3703 Type lhsEltType = lhsShapedType.getElementType();
3704 Type rhsEltType = rhsShapedType.getElementType();
3705 Type resEltType = resShapedType.getElementType();
3710 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3712 lhsEltType, {
false,
false, scalableChDim});
3713 VectorType rhsType =
3715 {
false, scalableChDim});
3716 VectorType resType =
3718 {
false,
false, scalableChDim});
3731 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3732 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3736 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3739 rewriter.
create<vector::CreateMaskOp>(loc, maskType, mixedDims);
3746 Value lhs = rewriter.
create<vector::TransferReadOp>(
3747 loc, lhsType, lhsShaped,
ValueRange{zero, zero, zero},
3749 auto maybeMaskedLhs = maybeMaskXferOp(
3750 lhsType.getShape(), lhsType.getScalableDims(), lhs.
getDefiningOp());
3753 Value rhs = rewriter.
create<vector::TransferReadOp>(
3754 loc, rhsType, rhsShaped,
ValueRange{zero, zero},
3756 auto maybeMaskedRhs = maybeMaskXferOp(
3757 rhsType.getShape(), rhsType.getScalableDims(), rhs.
getDefiningOp());
3760 Value res = rewriter.
create<vector::TransferReadOp>(
3761 loc, resType, resShaped,
ValueRange{zero, zero, zero},
3763 auto maybeMaskedRes = maybeMaskXferOp(
3764 resType.getShape(), resType.getScalableDims(), res.
getDefiningOp());
3776 for (int64_t kw = 0; kw < kwSize; ++kw) {
3777 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3778 lhsVals.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
3779 loc, maybeMaskedLhs->getResult(0),
3781 inOutSliceSizes, inOutStrides));
3785 for (int64_t kw = 0; kw < kwSize; ++kw) {
3786 rhsVals.push_back(rewriter.
create<vector::ExtractOp>(
3787 loc, maybeMaskedRhs->getResult(0),
3791 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3792 resVals.push_back(rewriter.
create<vector::ExtractStridedSliceOp>(
3793 loc, maybeMaskedRes->getResult(0),
3798 auto linearIndex = [&](int64_t kw, int64_t w) {
3799 return kw * (wSize / wSizeStep) + w;
3805 auto lhsTypeAfterFlattening =
3807 auto resTypeAfterFlattening =
3811 for (int64_t kw = 0; kw < kwSize; ++kw) {
3812 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3813 Value lhsVal = lhsVals[linearIndex(kw, w)];
3814 Value resVal = resVals[w];
3818 lhsVal = rewriter.
create<vector::ShapeCastOp>(
3819 loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3820 resVal = rewriter.
create<vector::ShapeCastOp>(
3821 loc, resTypeAfterFlattening, resVals[w]);
3823 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3824 rhsVals[kw], resVal, flatten);
3827 resVals[w] = rewriter.
create<vector::ShapeCastOp>(
3834 if (!llvm::all_of(resVals, [](
Value v) {
return v; })) {
3836 for (
auto &collection :
3837 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3838 for (
Value v : collection)
3845 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3846 maybeMaskedRes = rewriter.
create<vector::InsertStridedSliceOp>(
3847 loc, resVals[w], maybeMaskedRes->getResult(0),
3857 loc, maybeMaskedRes->getResult(0), resShaped,
3859 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3870 auto rhsTy = cast<ShapedType>(rhs.
getType());
3871 auto resTy = cast<ShapedType>(res.
getType());
3874 lhs =
promote(rewriter, loc, lhs, resTy);
3885 auto rhsSize = cast<VectorType>(rhs.
getType()).getShape()[0];
3886 auto resSize = cast<VectorType>(res.
getType()).getShape()[1];
3889 for (
int i = 0; i < resSize / rhsSize; ++i) {
3890 for (
int j = 0;
j < rhsSize; ++
j)
3891 indices.push_back(
j);
3894 rhs = rewriter.
create<vector::ShuffleOp>(loc, rhs, rhs, indices);
3897 rhs = rewriter.
create<vector::BroadcastOp>(
3898 loc, resTy.clone(rhsTy.getElementType()), rhs);
3900 rhs =
promote(rewriter, loc, rhs, resTy);
3905 if (isa<FloatType>(resTy.getElementType()))
3906 return rewriter.
create<vector::FMAOp>(loc, lhs, rhs, res);
3908 auto mul = rewriter.
create<arith::MulIOp>(loc, lhs, rhs);
3909 return rewriter.
create<arith::AddIOp>(loc, mul, res);
3914 FailureOr<Operation *> generateNonChanneledConv() {
3917 if (!iters({Par(), Red()}))
3919 "failed to match conv::W 1-par 1-red");
3922 if (layout({ {w + kw},
3932 FailureOr<Operation *> generateNwcConv() {
3935 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3937 op,
"failed to match conv::Nwc 3-par 2-red");
3940 if (layout({ {n, strideW * w + dilationW * kw, c},
3950 FailureOr<Operation *> generateNcwConv() {
3953 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3955 op,
"failed to match conv::Ncw 3-par 2-red");
3957 if (layout({ {n, c, strideW * w + dilationW * kw},
3967 FailureOr<Operation *> generateNwcPooling() {
3970 if (!iters({Par(), Par(), Par(), Red()}))
3972 "failed to match pooling 3-par 1-red");
3975 if (layout({ {n, strideW * w + dilationW * kw, c},
3985 FailureOr<Operation *> generateNcwPooling() {
3988 if (!iters({Par(), Par(), Par(), Red()}))
3990 "failed to match pooling 3-par 1-red");
3992 if (layout({ {n, c, strideW * w + dilationW * kw},
4002 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4003 bool vecChDimScalableFlag =
false,
4004 bool flatten =
false) {
4007 if (!iters({Par(), Par(), Par(), Red()}))
4009 op,
"failed to match depthwise::Nwc conv 3-par 1-red");
4012 if (layout({ {n, strideW * w + dilationW * kw, c},
4015 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4021 ConvOperationKind oper = ConvOperationKind::Conv;
4023 StringAttr poolExtOp;
4024 bool isPoolExt =
false;
4025 int strideW, dilationW;
4026 Value lhsShaped, rhsShaped, resShaped;
4027 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4028 vector::CombiningKind reductionKind;
4031 void setConvOperationKind(
Operation *reduceOp) {
4032 int numBlockArguments =
4033 llvm::count_if(reduceOp->
getOperands(), llvm::IsaPred<BlockArgument>);
4034 if (numBlockArguments == 1) {
4039 auto feedValIt = llvm::find_if_not(reduceOp->
getOperands(),
4040 llvm::IsaPred<BlockArgument>);
4041 Operation *feedOp = (*feedValIt).getDefiningOp();
4043 oper = ConvOperationKind::Pool;
4048 oper = ConvOperationKind::Conv;
4052 oper = ConvOperationKind::Pool;
4062 ArrayRef<bool> inputScalableVecDims,
bool flatten1DDepthwiseConv) {
4063 Conv1DGenerator conv1dGen(rewriter, op);
4064 auto res = conv1dGen.generateNonChanneledConv();
4067 res = conv1dGen.generateNwcConv();
4070 res = conv1dGen.generateNcwConv();
4073 res = conv1dGen.generateNwcPooling();
4076 res = conv1dGen.generateNcwPooling();
4083 uint64_t vecChDimSize = ShapedType::kDynamic;
4084 bool vecChDimScalableFlag =
false;
4085 if (!inputVecSizes.empty()) {
4088 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4089 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4090 "Not a 1D depthwise conv!");
4093 .Case<linalg::DepthwiseConv1DNwcWcOp>([](
auto conv) {
return 2; })
4094 .Case<linalg::DepthwiseConv1DNcwCwOp>([](
auto conv) {
return 1; });
4096 vecChDimSize = inputVecSizes[chDimIdx];
4097 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4099 return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4100 flatten1DDepthwiseConv);
4109 if (failed(resultOrFail))
4113 rewriter.
eraseOp(op.getOperation());
4116 assert(newOp->
getNumResults() == 1 &&
"expected single result");