#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#include <type_traits>
#define DEBUG_TYPE "linalg-vectorization"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                     ArrayRef<int64_t> inputVecSizes = {},
                     ArrayRef<bool> inputVecScalableFlags = {},
                     bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than one instance exists.
template <typename OpType>
static OpType getSingleOpOfType(Block &block) {
  OpType res;
  block.walk([&](OpType op) {
    if (res) {
      res = nullptr;
      return WalkResult::interrupt();
    }
    res = op;
    return WalkResult::advance();
  });
  return res;
}
/// Helper function to extract the input slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
                       int64_t nSize, int64_t wSize, int64_t cSize,
                       int64_t kwSize, int strideW, int dilationW,
                       int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract input slices of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw},
            /*sizes=*/ArrayRef<int64_t>{wSizeStep}, /*strides=*/{1}));
      }
    }
  } else {
    // Extract input slices of size {nSize, wSizeStep, cSize} @
    // [0, w * strideW + kw * dilationW, 0] for channeled convolution.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
            /*strides=*/{1, 1, 1}));
      }
    }
  }
  return result;
}
/// Helper function to extract the filter slices after the filter is unrolled
/// along kw.
static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
                                                  Location loc, Value filter,
                                                  int64_t kwSize) {
  SmallVector<Value> result;
  // Extract rhs slice of size {c, f} (channeled) or {1} (non-channeled) @ [kw].
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    result.push_back(rewriter.create<vector::ExtractOp>(
        loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
  }
  return result;
}
/// Helper function to extract the result slices after the filter is unrolled
/// along kw.
static SmallVector<Value>
extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
                        int64_t nSize, int64_t wSize, int64_t fSize,
                        int64_t wSizeStep, bool isSingleChanneled) {
  SmallVector<Value> result;
  if (isSingleChanneled) {
    // Extract res slice of size {wSizeStep} @ [w] for non-channeled conv.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{w},
          /*sizes=*/ArrayRef<int64_t>{wSizeStep}, /*strides=*/{1}));
    }
  } else {
    // Extract res slices of size {nSize, wSizeStep, fSize} @ [0, w, 0] for
    // channeled convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
          /*strides=*/{1, 1, 1}));
    }
  }
  return result;
}
/// Helper function to insert the computed result slices.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
                                    Value res, int64_t wSize,
                                    int64_t wSizeStep,
                                    SmallVectorImpl<Value> &resVals,
                                    bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Write back res slice of size {wSizeStep} @ [w] for non-channeled conv.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
          /*strides=*/{1});
    }
  } else {
    // Write back res slices of size {nSize, wSizeStep, fSize} @ [0, w, 0] for
    // channeled convolution.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/{1, 1, 1});
    }
  }
  return res;
}
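// Illustrative example (not from the original source; shapes assumed): for a
// channeled conv with nSize = 1, wSize = 2, cSize = 4, kwSize = 2,
// strideW = dilationW = 1 and wSizeStep = 2, extractConvInputSlices produces
// IR along the lines of:
//
//   %s0 = vector.extract_strided_slice %lhs
//       {offsets = [0, 0, 0], sizes = [1, 2, 4], strides = [1, 1, 1]}
//       : vector<1x3x4xf32> to vector<1x2x4xf32>
//   %s1 = vector.extract_strided_slice %lhs
//       {offsets = [0, 1, 0], sizes = [1, 2, 4], strides = [1, 1, 1]}
//       : vector<1x3x4xf32> to vector<1x2x4xf32>
//
// and insertConvResultSlices writes each computed slice back with a matching
// vector.insert_strided_slice at offsets [0, w, 0].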
  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims);

  /// Returns the canonical vector type used to vectorize the iteration space,
  /// with `dimPermutation` applied to the canonical vector shape if provided.
  VectorType getCanonicalVecType(
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<int64_t> vectorShape;
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
      vectorShape =
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }
    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);
  /// Initializes the iteration space static sizes using the Linalg op
  /// information.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }
  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for all
  /// the static and dynamic sizes of the iteration space.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Creates or retrieves the mask for `opToMask`, using the masking map
  /// `maybeMaskingMap` if provided.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Checks whether this permutation map can be used for masking. At the
  /// moment we only make sure that there are no broadcast dimensions.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.getBroadcastDims().size() == 0;
  }
LogicalResult
VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  // TODO: Support 0-d vectors.
  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
      // Create constant index op for static dimensions.
      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
      continue;
    }

    // Find an operand defined on this dimension of the iteration space to
    // extract the runtime dimension size.
    Value operand;
    unsigned operandDimPos;
    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
                                                         operandDimPos)))
      return failure();

    Value dynamicDim = linalgOp.hasPureTensorSemantics()
                           ? (Value)rewriter.create<tensor::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos)
                           : (Value)rewriter.create<memref::DimOp>(
                                 linalgOp.getLoc(), operand, operandDimPos);
    iterSpaceValueSizes.push_back(dynamicDim);
  }

  return success();
}
LogicalResult
VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              ArrayRef<bool> inputScalableVecDims) {
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided.
    // This path should be taken to vectorize code with dynamic shapes and
    // when using vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there
    // are dynamic shapes, the operation won't be vectorized.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG("Canonical vector shape: ");
  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for all the
  // static and dynamic sizes of the iteration space, needed to compute a mask
  // during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
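// Illustrative example (assumed op and sizes): vectorizing a dynamically
// shaped linalg.matmul with inputVectorSizes = [8, 16, 4] makes the canonical
// vector shape 8x16x4; each operand's vector type is then derived by
// permuting that shape with the operand's indexing map, e.g.
// (m, n, k) -> (m, k) yields vector<8x4xf32> for the LHS (assuming f32
// elements).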
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                      : AffineMap::getMultiDimIdentityMap(
                            linalgOp.getNumLoops(), rewriter.getContext());

  LDBG("Masking map: " << maskingMap << "\n");

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG("Reusing mask: " << mask << "\n");
    return mask;
  }

  // Compute permuted projection of the iteration space to be masked and the
  // corresponding mask shape. If the resulting iteration space dimensions are
  // static and identical to the mask shape, masking is not needed for this
  // operation.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG("Mask shape: ");
  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (permutedStaticSizes == maskShape) {
    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values.
  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
                                                     maskType, upperBounds);
  LDBG("Creating new mask: " << mask << "\n");
  activeMaskCache[maskingMap] = mask;
  return mask;
}
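// Illustrative example (assumed types): for a read of a ?x?xf32 tensor with
// canonical vector shape 4x8, the code above materializes
//
//   %d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
//   %d1 = tensor.dim %src, %c1 : tensor<?x?xf32>
//   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
//
// and caches %mask under its masking map, so that other operands sharing the
// same map reuse it instead of creating a duplicate vector.create_mask.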
Operation *
VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
                                  LinalgOp linalgOp,
                                  std::optional<AffineMap> maybeIndexingMap) {
  LDBG("Trying to mask: " << *opToMask << "\n");

  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG("No mask required\n");
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG("Masked operation: " << *maskOp << "\n");
  return maskOp;
}
507 "expected projected permutation");
509 assert(res.getNumDims() ==
510 (res.getNumResults() - res.getNumOfZeroResults()) &&
511 "expected reindexed map with same number of dims and results");
/// Check whether `op` could be involved in a reduction and return the
/// combining kind.
static std::optional<vector::CombiningKind>
getCombinerOpKind(Operation *combinerOp) {
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
      .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
      .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
      .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
      .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
      .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
      .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
      .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
      .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
      .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
      .Default([&](auto op) { return std::nullopt; });
}
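// Illustrative example (assumed IR): an `arith.addf` combining a linalg
// output block argument is recognised as CombiningKind::ADD, which later
// materializes as e.g.
//
//   %red = vector.multi_reduction <add>, %vec, %acc [1]
//       : vector<4x8xf32> to vector<4xf32>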
/// Check whether `outputOperand` is a reduction with a single combiner
/// operation. Return the combiner operation of the reduction, or nullptr
/// otherwise.
static Operation *matchLinalgReduction(OpOperand *outputOperand) {
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  unsigned outputPos =
      outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
  // Only single combiner operations are supported for now.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
      combinerOps.size() != 1)
    return nullptr;

  // Return the combiner operation.
  return combinerOps[0];
}
/// Broadcast `value` to a vector of `dstType` if this is possible; the helper
/// returns `value` unchanged otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If there is no shape to broadcast to, just return `value`.
  if (dstVecType.getRank() == 0)
    return value;
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
      vector::BroadcastableToResult::Success)
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the
/// reduction initial value.
static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
                                      Value valueToReduce, Value acc,
                                      ArrayRef<bool> dimsToMask) {
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return b.create<vector::MultiDimReductionOp>(
      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::to_vector(
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
/// reduction iterator.
static bool hasReductionIterator(LinalgOp &op) {
  return isa<linalg::ReduceOp>(op) ||
         (isa<linalg::GenericOp>(op) &&
          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
}
/// Build a vector.transfer_write of `value` into `outputOperand` at indices
/// set to all 0; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store. This type should be an
  // identity or projection of the canonical vector type without any
  // permutation applied, given that any permutation in a transfer write
  // happens as part of the write itself.
  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    SmallVector<Value> indices(
        linalgOp.getRank(outputOperand),
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = rewriter.create<vector::TransferWriteOp>(
        loc, value, outputOperand->get(), ValueRange{});
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in_bounds to true. Masking guarantees that the access will
  // be in bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG("vectorized op: " << *write << "\n");
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
// Custom vectorization precondition function type. This is intended to be
// used with CustomVectorizationHook.
using CustomVectorizationPrecondition =
    std::function<LogicalResult(Operation *, bool)>;
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. This function is meant to be
/// used as a CustomVectorizationHook.
static VectorizationResult
vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
                     const IRMapping &bvm, VectorizationState &state,
                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    Value vectorValue = bvm.lookup(output.value());
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }
  return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
}
/// Helper function to vectorize the index operations of a `linalgOp`. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
                                                VectorizationState &state,
                                                Operation *op,
                                                LinalgOp linalgOp) {
  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
  if (!indexOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = indexOp.getLoc();
  // Compute a one-dimensional index vector for the index op dimension.
  auto dim = indexOp.getDim();
  auto indexVectorType =
      VectorType::get({state.getCanonicalVecShape()[dim]},
                      rewriter.getIndexType(),
                      state.getScalableVecDims()[dim]);
  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
  // Return the one-dimensional index vector if it lives in the trailing
  // dimension of the iteration space, since the vectorization algorithm in
  // this case can handle the broadcast.
  auto targetShape = state.getCanonicalVecShape();
  if (dim == targetShape.size() - 1)
    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
  // Otherwise permute the targetShape to move the index dimension last,
  // broadcast the one-dimensional index vector to the permuted shape, and
  // finally transpose the broadcasted index vector to undo the permutation.
  auto permPattern =
      llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
  std::swap(permPattern[dim], permPattern.back());
  auto permMap =
      AffineMap::getPermutationMap(permPattern, linalgOp.getContext());

  auto broadCastOp = rewriter.create<vector::BroadcastOp>(
      loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap),
      indexSteps);
  SmallVector<int64_t> transposition =
      llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
  std::swap(transposition.back(), transposition[dim]);
  auto transposeOp =
      rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
}
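// Illustrative example (assumed canonical vector shape 4x8): `linalg.index 0`
// (a non-trailing dim) vectorizes to
//
//   %step  = vector.step : vector<4xindex>
//   %bcast = vector.broadcast %step : vector<4xindex> to vector<8x4xindex>
//   %idx   = vector.transpose %bcast, [1, 0]
//       : vector<8x4xindex> to vector<4x8xindex>
//
// while `linalg.index 1` (the trailing dim) is just `vector.step :
// vector<8xindex>`, which the broadcast rules handle directly.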
/// Helper function to check if the tensor.extract can be vectorized by the
/// custom hook vectorizeTensorExtract.
static LogicalResult
tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return failure();

  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
    return failure();

  // Check the index type of the operand.
  if (!extractOp.getIndices().empty()) {
    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
      return failure();
  }

  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
        return !VectorType::isValidElementType(type);
      }))
    return failure();

  return success();
}
/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
/// generated from `tensor.extract`. The offset is calculated as follows
/// (example using scalar values):
///
///    offset = extractOp.indices[0]
///    for (i = 1; i < numIndices; i++)
///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
static Value calculateGatherOffset(RewriterBase &rewriter,
                                   VectorizationState &state,
                                   tensor::ExtractOp extractOp,
                                   const IRMapping &bvm) {
  // The vector of indices for GatherOp should be shaped as the output vector.
  auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
  auto loc = extractOp.getLoc();

  Value offset = broadcastIfNeeded(
      rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);

  const size_t numIndices = extractOp.getIndices().size();
  for (size_t i = 1; i < numIndices; i++) {
    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);

    auto dimSize = broadcastIfNeeded(
        rewriter,
        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
        indexVecType);

    offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);

    auto extractOpIndex = broadcastIfNeeded(
        rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);

    offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset);
  }

  return offset;
}
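// Illustrative example (assumed shapes): for `tensor.extract %t[%i, %j]` from
// a tensor<?x8xf32>, the loop above computes offset = %i * dim(%t, 1) + %j,
// i.e. the row-major linearization of the extract indices, with each term
// broadcast to the canonical index vector type.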
/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
/// when checking whether `tensor.extract` (within a `linalg.generic`)
/// represents a contiguous load operation.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
  assert(
      (linalgOp.hasDynamicShape() ||
       llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
      "For statically shaped Linalg Ops, only one "
      "non-unit loop dim is expected");
  assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse.");

  size_t idx = loopRanges.size() - 1;
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;

  return idx;
}
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                               VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic is a bit
  // tricky, so just bail out in that case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // IndexOp is loop invariant as long as its result remains constant across
  // iterations, i.e. the corresponding loop dim is 1.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Values defined outside `linalgOp` are loop invariant.
  if (!ancestor)
    return true;

  // Values defined inside `linalgOp` that are constant are loop invariant.
  if (isa<arith::ConstantOp>(ancestor))
    return true;

  bool result = true;
  for (auto op : ancestor->getOperands())
    result &= isLoopInvariantIdx(linalgOp, op, resType);

  return result;
}
/// Checks whether `val` could be used for calculating the trailing index for
/// a contiguous load operation.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic is a bit
  // tricky, so just bail out in that case.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return llvm::all_of(block->getArguments(),
                        [&val](Value v) { return (v != val); });

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  if (!ancestor)
    return false;

  // Conservatively reject ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
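// Illustrative inputs for the memory-access classification below (assumed
// IR): inside a linalg.generic over a 1x4 iteration space,
//   * tensor.extract %t[%c0, %i] with %i = linalg.index 1     -> Contiguous
//   * tensor.extract %t[%i, %c0] (trailing index invariant)   -> ScalarBroadcast
//   * tensor.extract %t[%c0, %v] with %v computed from data   -> Gather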
/// Infer the memory access pattern for the input ExtractOp.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp, VectorType resType) {

  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`,
  // false otherwise.
  bool isOutput1DVector =
      (llvm::count_if(resType.getShape(),
                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
  // 1. Assume that it's a gather load when reading non-1D vector.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  bool leadingIdxsLoopInvariant = true;

  // 2. Analyze the leading indices of `extractOp`. If any of them is not loop
  // invariant, it's a gather load.
  auto indices = extractOp.getIndices();
  auto leadIndices = indices.drop_back(1);

  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
    if (inputShape.getShape()[i] == 1)
      continue;

    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
  }

  if (!leadingIdxsLoopInvariant) {
    LDBG("Found gather load: " << extractOp);
    return VectorMemoryAccessKind::Gather;
  }

  // 3. Analyze the trailing index of `extractOp`. At this point the leading
  // indices are known to be loop invariant, so this is potentially a scalar
  // or a contiguous load.
  auto extractOpTrailingIdx = indices.back();

  // 3a. Scalar broadcast load: the trailing index is also loop invariant.
  if (leadingIdxsLoopInvariant &&
      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
    LDBG("Found scalar broadcast load: " << extractOp);
    return VectorMemoryAccessKind::ScalarBroadcast;
  }

  // 3b. Contiguous load: the trailing index increments by 1 with every loop
  // iteration, i.e. it is based on the trailing loop index.
  bool foundIndexOp = false;
  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                              foundIndexOp, resType);
  // TODO: Support generating contiguous loads for column vectors - that will
  // require adding a permutation map to transfer_read ops.
  bool isRowVector = resType.getShape().back() != 1;
  isContiguousLoad &= (foundIndexOp && isRowVector);

  if (isContiguousLoad) {
    LDBG("Found contiguous load: " << extractOp);
    return VectorMemoryAccessKind::Contiguous;
  }

  // 4. Fallback case - gather load.
  LDBG("Found gather load: " << extractOp);
  return VectorMemoryAccessKind::Gather;
}
/// Helper function to vectorize the tensor.extract operations. Return
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
/// CustomVectorizationHook.
static VectorizationResult
vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                       Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
  if (!extractOp)
    return VectorizationResult{VectorizationStatus::Failure, nullptr};
  auto loc = extractOp.getLoc();

  // Compute the static loop sizes of the extract op.
  auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
                                /*value=*/true));
  auto passThruConstantOp =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

  // Base indices are currently set to 0. We will need to re-visit if more
  // generic scenarios are to be supported.
  SmallVector<Value> baseIndices(
      extractOp.getIndices().size(),
      rewriter.create<arith::ConstantIndexOp>(loc, 0));

  VectorMemoryAccessKind memAccessKind =
      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

  // 1. Handle gather access.
  if (memAccessKind == VectorMemoryAccessKind::Gather) {
    Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);

    // Generate the gather load.
    Operation *gatherOp = rewriter.create<vector::GatherOp>(
        loc, resultType, extractOp.getTensor(), baseIndices, offset,
        maskConstantOp, passThruConstantOp);
    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);

    LDBG("Vectorised as gather load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
  }

  // 2. Handle (a) scalar loads + broadcast and (b) contiguous loads. Both
  // cases use vector.transfer_read.

  assert(llvm::count_if(resultType.getShape(),
                        [](uint64_t dim) { return dim != 1; }) &&
         "Contiguous loads and scalar loads + broadcast only support 1-D "
         "vectors ATM");

  // Collect indices for `vector.transfer_read`. At this point, the indices
  // are either scalars or have been broadcast to vectors during their
  // vectorization; extract the scalar from the vector in the latter case.
  SmallVector<Value> transferReadIdxs;
  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
    Value idx = bvm.lookup(extractOp.getIndices()[i]);
    if (idx.getType().isIndex()) {
      transferReadIdxs.push_back(idx);
      continue;
    }

    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
        loc,
        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
                        resultType.getScalableDims().back()),
        idx);
    transferReadIdxs.push_back(
        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
  }

  // `tensor.extract` is always in-bounds, hence the following holds.
  auto dstRank = resultType.getRank();
  auto srcRank = extractOp.getTensor().getType().getRank();
  SmallVector<bool> inBounds(dstRank, true);

  // 2a. Handle scalar broadcast access.
  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, resultType, extractOp.getTensor(), transferReadIdxs,
        permutationMap, inBounds);

    // Mask the broadcasting xfer_read here rather than relying on the generic
    // path (which assumes an identity masking map that won't work here).
    SmallVector<int64_t> readMaskShape = {1};
    auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
    auto allTrue = rewriter.create<vector::ConstantMaskOp>(
        loc, readMaskType, vector::ConstantMaskKind::AllTrue);
    auto *maskedReadOp =
        mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);

    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
    return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
  }

  // 2b. Handle contiguous access.
  auto permutationMap = AffineMap::getMinorIdentityMap(
      srcRank, std::min(dstRank, srcRank), rewriter.getContext());

  int32_t rankDiff = dstRank - srcRank;
  // When dstRank > srcRank, broadcast the source tensor to the unitary
  // leading dims so that the ranks match, by extending the map with 0s, e.g.
  // for dstRank = 3, srcRank = 2: (d0, d1) --> (0, d0, d1).
  while (rankDiff > 0) {
    permutationMap = permutationMap.insertResult(
        getAffineConstantExpr(0, rewriter.getContext()), 0);
    rankDiff--;
  }

  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
      loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
      inBounds);

  LDBG("Vectorised as contiguous load: " << extractOp);
  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
}
/// Emit reduction operations if the shape of the value to reduce differs from
/// the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp,
                                 Operation *op, Value reduceValue,
                                 Value initialValue, const IRMapping &bvm) {
  Value reduceVec = bvm.lookup(reduceValue);
  Value outputVec = bvm.lookup(initialValue);
  auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
  auto outputType = dyn_cast<VectorType>(outputVec.getType());
  // Reduce only if needed as the value may already have been reduced for
  // contraction vectorization.
  if (!reduceType ||
      (outputType && reduceType.getShape() == outputType.getShape()))
    return nullptr;
  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already
/// vectorized operands carried by `bvm`.
static VectorizationResult
vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
               LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
               ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
  LDBG("vectorize op " << *op << "\n");

  // 1. Try to apply any CustomVectorizationHook.
  if (!customVectorizationHooks.empty()) {
    for (auto &customFunc : customVectorizationHooks) {
      VectorizationResult result = customFunc(op, bvm);
      if (result.status == VectorizationStatus::Failure)
        continue;
      return result;
    }
  }

  // 2. Constant ops don't get vectorized but rather broadcasted at their
  // users. Clone so that the constant is not confined to the linalgOp block.
  if (isa<arith::ConstantOp, func::ConstantOp>(op))
    return VectorizationResult{VectorizationStatus::NewOp,
                               rewriter.clone(*op)};

  // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return VectorizationResult{VectorizationStatus::Failure, nullptr};

  // 4. Check if the operation is a reduction.
  SmallVector<std::pair<Value, Value>> reductionOperands;
  for (Value operand : op->getOperands()) {
    auto blockArg = dyn_cast<BlockArgument>(operand);
    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
      continue;
    SmallVector<Operation *> reductionOps;
    Value reduceValue = matchReduction(
        linalgOp.getRegionOutputArgs(),
        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
    if (!reduceValue)
      continue;
    reductionOperands.push_back(std::make_pair(reduceValue, operand));
  }
  if (!reductionOperands.empty()) {
    assert(reductionOperands.size() == 1);
    Operation *reduceOp =
        reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
                       reductionOperands[0].second, bvm);
    if (reduceOp)
      return VectorizationResult{VectorizationStatus::NewOp, reduceOp};
  }

  // 5. Generic vectorization path for ElementwiseMappable ops.
  //   a. Get the first max ranked shape.
  VectorType firstMaxRankedType;
  for (Value operand : op->getOperands()) {
    auto vecOperand = bvm.lookup(operand);
    assert(vecOperand && "Vector operand couldn't be found");
    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
    if (vecType && (!firstMaxRankedType ||
                    firstMaxRankedType.getRank() < vecType.getRank()))
      firstMaxRankedType = vecType;
  }
  //   b. Broadcast each operand if needed.
  SmallVector<Value> vecOperands;
  for (Value scalarOperand : op->getOperands()) {
    Value vecOperand = bvm.lookup(scalarOperand);
    assert(vecOperand && "Vector operand couldn't be found");

    if (firstMaxRankedType) {
      auto vecType =
          VectorType::get(firstMaxRankedType.getShape(),
                          getElementTypeOrSelf(vecOperand.getType()),
                          firstMaxRankedType.getScalableDims());
      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
    } else {
      vecOperands.push_back(vecOperand);
    }
  }
  //   c. For elementwise ops, the result is the vector with the
  //      firstMaxRankedType shape.
  SmallVector<Type> resultTypes;
  for (Type resultType : op->getResultTypes()) {
    resultTypes.push_back(
        firstMaxRankedType
            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
                              firstMaxRankedType.getScalableDims())
            : resultType);
  }
  //   d. Build and return the new op.
  return VectorizationResult{
      VectorizationStatus::NewOp,
      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
                      resultTypes, op->getAttrs())};
}
/// Generic vectorization function that rewrites the body of a `linalgOp` into
/// vector form: block arguments become vector reads, region ops are
/// vectorized one by one via `vectorizeOneOp`, and results are written back
/// via the linalg.yield hook.
static LogicalResult
vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
                         LinalgOp linalgOp,
                         SmallVectorImpl<Value> &newResults) {
  LDBG("Vectorizing operation as linalg generic\n");
  Block *block = linalgOp.getBlock();

  // 2. Values defined above the region can only be broadcast for now; make
  // them map to themselves.
  IRMapping bvm;
  SetVector<Value> valuesSet;
  mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
  bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());

  if (linalgOp.getNumDpsInits() == 0)
    return failure();

  // 3. Turn all BBArgs into vector.transfer_read / load.
  Location loc = linalgOp.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
    if (linalgOp.isScalar(opOperand)) {
      bvm.map(bbarg, opOperand->get());
      continue;
    }

    // 3.a. Convert the indexing map for this input/output to a transfer read
    // permutation map.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);

    AffineMap readMap;
    VectorType readType;
    Type elemType = getElementTypeOrSelf(opOperand->get());
    if (linalgOp.isDpsInput(opOperand)) {
      // 3.a.i. For input reads we use the canonical vector shape.
      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
      readType = state.getCanonicalVecType(elemType);
    } else {
      // 3.a.ii. For output reads (iteration-carried dependence, e.g.
      // reductions), the vector shape is computed by mapping the canonical
      // vector shape to the output indexing map.
      readMap = inversePermutation(reindexIndexingMap(indexingMap));
      readType =
          state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
    }

    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
    Value readValue = read->getResult(0);

    // 3.b. If masked, set in_bounds to true. Masking guarantees that the
    // access will be in bounds.
    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
      SmallVector<bool> inBounds(readType.getRank(), true);
      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
    }

    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
    if (readType.getRank() == 0)
      readValue = rewriter.create<vector::ExtractOp>(loc, readValue);

    bvm.map(bbarg, readValue);
    bvm.map(opOperand->get(), readValue);
  }

  SmallVector<CustomVectorizationHook> hooks;
  // 4a. Register CustomVectorizationHook for yieldOp.
  CustomVectorizationHook vectorizeYield =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp,
                                newResults);
  };
  hooks.push_back(vectorizeYield);

  // 4b. Register CustomVectorizationHook for indexOp.
  CustomVectorizationHook vectorizeIndex =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
  };
  hooks.push_back(vectorizeIndex);

  // 4c. Register CustomVectorizationHook for extractOp.
  CustomVectorizationHook vectorizeExtract =
      [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
    return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
  };
  hooks.push_back(vectorizeExtract);

  // 5. Iteratively call `vectorizeOneOp` on the region operations.
  for (Operation &op : block->getOperations()) {
    VectorizationResult result =
        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
    if (result.status == VectorizationStatus::Failure) {
      LDBG("failed to vectorize: " << op << "\n");
      return failure();
    }
    if (result.status == VectorizationStatus::NewOp) {
      Operation *maybeMaskedOp =
          state.maskOperation(rewriter, result.newOp, linalgOp);
      LDBG("New vector op: " << *maybeMaskedOp << "\n");
      bvm.map(op.getResults(), maybeMaskedOp->getResults());
    }
  }

  return success();
}
/// Given an input, the mixed destSizes, and the vector sizes for
/// vectorization, create an empty destination tensor and create a
/// TransferWriteOp from the input to the empty tensor. If the destination
/// shape is not the same as the inputVectorSizes for the first rank-many
/// dims, then create a mask for the write. If `useInBoundsInsteadOfMasking`
/// is set, update the in_bounds attribute of the transfer write op instead
/// of masking.
static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
                                           Value input,
                                           SmallVector<OpFoldResult> destSizes,
                                           ArrayRef<int64_t> inputVectorSizes,
                                           bool useInBoundsInsteadOfMasking) {
  auto inputType = cast<VectorType>(input.getType());
  Value dest = builder.create<tensor::EmptyOp>(loc, destSizes,
                                               inputType.getElementType());
  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  auto destShape = cast<ShapedType>(dest.getType()).getShape();
  SmallVector<bool> inBoundsVal(rank, true);
  if (useInBoundsInsteadOfMasking) {
    // In this case, assume that all the required vector sizes have been
    // provided.
    for (unsigned i = 0; i < rank; i++)
      inBoundsVal[i] = (destShape[i] == inputVectorSizes[i]) &&
                       !ShapedType::isDynamic(destShape[i]);
  }
  Operation *write = builder.create<vector::TransferWriteOp>(
      loc,
      /*vector=*/input,
      /*source=*/dest,
      /*indices=*/SmallVector<Value>(rank, zero),
      /*inBounds=*/inBoundsVal);
  assert(llvm::none_of(
             destShape.drop_front(inputVectorSizes.size()),
             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
         "Only dims aligned with inputVectorSizes may be dynamic");
  if (useInBoundsInsteadOfMasking)
    return write;
  bool needMaskForWrite = !llvm::equal(
      inputVectorSizes, destShape.take_front(inputVectorSizes.size()));
  if (needMaskForWrite) {
    SmallVector<int64_t> writeMaskShape;
    writeMaskShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    writeMaskShape.append(destShape.begin() + inputVectorSizes.size(),
                          destShape.end());
    auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
    Value maskForWrite =
        builder.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
    write = mlir::vector::maskOperation(builder, write, maskForWrite);
  }
  return write;
}
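// Illustrative example (assumed sizes): with inputVectorSizes = [8] and a
// dynamic destination dim %d0, the helper above produces
//
//   %empty = tensor.empty(%d0) : tensor<?xf32>
//   %mask  = vector.create_mask %d0 : vector<8xi1>
//   %res   = vector.mask %mask {
//     vector.transfer_write %input, %empty[%c0] {in_bounds = [true]}
//         : vector<8xf32>, tensor<?xf32>
//   } : vector<8xi1> -> tensor<?xf32>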
/// Vectorize tensor::PackOp with (1) static innerTiles, (2) constant padding
/// value and (3) input vector sizes into:
///   masked vector.transfer_read -> vector.shape_cast -> vector.transpose ->
///   masked vector.transfer_write.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  auto padValue = packOp.getPaddingValue();
  if (!padValue) {
    padValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
  }
  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");

  // If the input vector sizes are not provided, the vector sizes are
  // determined by the result tensor shape, and we update the in_bounds
  // attribute instead of masking.
  bool useInBoundsInsteadOfMasking = false;
  if (inputVectorSizes.empty()) {
    ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
    useInBoundsInsteadOfMasking = true;
  }

  // Create masked TransferReadOp.
  SmallVector<int64_t> inputShape(inputVectorSizes);
  auto innerTiles = packOp.getStaticInnerTiles();
  auto innerDimsPos = packOp.getInnerDimsPos();
  auto outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty())
    applyPermutationToVector(inputShape,
                             invertPermutationVector(outerDimsPerm));
  for (auto [idx, size] : enumerate(innerTiles))
    inputShape[innerDimsPos[idx]] *= size;
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), inputShape, padValue,
      useInBoundsInsteadOfMasking);

  // Create ShapeCastOp.
  SmallVector<int64_t> destShape(inputVectorSizes);
  destShape.append(innerTiles.begin(), innerTiles.end());
  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
                                       packOp.getDestType().getElementType());
  auto shapeCastOp =
      rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);

  // Create TransposeOp.
  auto destPermutation =
      invertPermutationVector(getPackInverseDestPerm(packOp));
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), reifiedReturnShapes[0],
      inputVectorSizes, /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}
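// Illustrative summary (assumed static shapes): the pack vectorization above
// emits the chain
//
//   (masked) vector.transfer_read      // read + pad the source
//     -> vector.shape_cast             // expand to the tiled dest shape
//     -> vector.transpose              // apply the pack permutation
//     -> (masked) vector.transfer_write // into the tensor.empty destination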
/// Vectorize a `tensor::UnPackOp` to these 4 Ops:
///   vector::TransferReadOp - reads a vector from the source tensor
///   vector::TransposeOp - transposes the read vector
///   vector::ShapeCastOp - reshapes the data based on the target
///   vector::TransferWriteOp - writes the result back to the destination.
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          SmallVectorImpl<Value> &newResults) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  RankedTensorType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;
  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();

  auto destSize = unpackOp.getDestRank();

  if (!inputVectorSizes.empty())
    assert(inputVectorSizes.size() == destSize &&
           "Incorrect number of input vector sizes");

  // vectorSizes is the shape of the vector that will be used for the final
  // write on the destination tensor: take the leading destSize dims of the
  // source shape, apply outer_dims_perm if present, and multiply the
  // positions pointed to by innerDimPos by the inner tile sizes.
  SmallVector<int64_t> vectorSizes(inputVectorSizes);
  if (vectorSizes.empty()) {
    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
    if (!outerDimsPerm.empty())
      applyPermutationToVector(vectorSizes, outerDimsPerm);
    for (auto [i, pos] : llvm::enumerate(innerDimPos))
      vectorSizes[pos] *= innerTiles[i];

    useInBoundsInsteadOfMasking = true;
  }

  // readVectorSizes is the size of the vector used to read (and mask) the
  // source: divide the positions pointed to by innerDimPos by the inner tile
  // sizes, apply outer_dims_perm, and append the remaining source dims.
  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());

  for (auto [index, size] : enumerate(innerTiles)) {
    readVectorSizes[innerDimPos[index]] =
        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector(readVectorSizes, outerDimsPerm);
  }
  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
                         sourceShape.end());

  ReifiedRankedShapedTypeDims reifiedRetShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
          .reifyResultShapes(rewriter, reifiedRetShapes);
  if (status.failed()) {
    LDBG("Unable to reify result shapes of " << unpackOp);
    return failure();
  }
  Location loc = unpackOp->getLoc();

  auto padValue = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

  // Read result; if the transfer read shape does not equal the source shape,
  // a mask is used.
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);

  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
  RankedTensorType stripMineTensorType =
      RankedTensorType::get(stripMineShape, stripMineElemType);
  // Transpose the appropriate rows to match the output.
  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
      loc, readResult, lastDimToInsertPosPerm);

  // Collapse the vector to the size required by the result.
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMineTensorType, packMetadata.reassociations);
  mlir::VectorType vecCollapsedType =
      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
      loc, vecCollapsedType, transposeOp->getResult(0));

  // writeVectorSizes must match the shape_cast shape for dynamic sizes,
  // otherwise the verifier complains that the mask size is invalid.
  SmallVector<int64_t> writeVectorSizes(
      unpackOp.getDestType().hasStaticShape()
          ? vectorSizes
          : shapeCastOp.getResultVectorType().getShape());
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), reifiedRetShapes[0],
      writeVectorSizes, useInBoundsInsteadOfMasking);
  newResults.push_back(write->getResult(0));
  return success();
}
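// Illustrative summary: unpack vectorization is the mirror image of pack: a
// (possibly masked) vector.transfer_read of the packed source, a
// vector.transpose undoing the packing permutation, a vector.shape_cast
// collapsing the tile dims, and a (possibly masked) vector.transfer_write
// into a tensor.empty of the unpacked shape.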
/// Vectorize a `padOp` with (1) static result type, (2) constant padding
/// value and (3) all-zero lowPad to
///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
static LogicalResult
vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
                       ArrayRef<int64_t> inputVectorSizes,
                       SmallVectorImpl<Value> &newResults) {
  auto padValue = padOp.getConstantPaddingValue();
  Location loc = padOp.getLoc();

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(padOp);

  ReifiedRankedShapedTypeDims reifiedReturnShapes;
  LogicalResult status =
      cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
          .reifyResultShapes(rewriter, reifiedReturnShapes);
  (void)status; // prevent unused variable warning on non-assert builds.
  assert(succeeded(status) && "failed to reify result shapes");
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false);
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes,
      /*useInBoundsInsteadOfMasking=*/false);
  newResults.push_back(write->getResult(0));
  return success();
}
static LogicalResult reductionPreconditions(LinalgOp op) {
  if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
    LDBG("reduction precondition failed: no reduction iterator\n");
    return failure();
  }
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
    if (indexingMap.isPermutation())
      continue;

    Operation *reduceOp = matchLinalgReduction(&opOperand);
    if (!reduceOp || !getCombinerOpKind(reduceOp)) {
      LDBG("reduction precondition failed: reduction detection failed\n");
      return failure();
    }
  }
  return success();
}
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  if (flatten1DDepthwiseConv) {
    LDBG("Vectorization of flattened convs with dynamic shapes is not "
         "supported\n");
    return failure();
  }

  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG("Dynamically-shaped op vectorization precondition failed: only "
         "channel dim can be dynamic\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement ContractionOpInterface for now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
  return success();
}
/// Need to check if the inner-tiles are static/constant.
static LogicalResult
vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {

  if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
        return !getConstantIntValue(res).has_value();
      })) {
    LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
  bool satisfyEmptyCond = inputVectorSizes.empty() &&
                          unpackOp.getDestType().hasStaticShape() &&
                          unpackOp.getSourceType().hasStaticShape();
  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
    return failure();

  return success();
}
static LogicalResult
vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                              ArrayRef<int64_t> inputVectorSizes,
                              bool vectorizeNDExtract,
                              bool flatten1DDepthwiseConv) {
  // Tensors with a dimension of 0 cannot be vectorized.
  if (llvm::is_contained(linalgOp.getStaticShape(), 0))
    return failure();
  // Check API contract for input vector sizes.
  if (!inputVectorSizes.empty() &&
      failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
                                              inputVectorSizes)))
    return failure();

  if (linalgOp.hasDynamicShape() &&
      failed(vectorizeDynamicLinalgOpPrecondition(
          linalgOp, flatten1DDepthwiseConv))) {
    LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
    return failure();
  }

  SmallVector<CustomVectorizationPrecondition> customPreconditions;

  // Register CustomVectorizationPrecondition for extractOp.
  customPreconditions.push_back(tensorExtractVectorizationPrecondition);

  // All types in the body should be a supported element type for VectorType.
  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
    // Check if any custom hook can vectorize the inner op.
    if (llvm::any_of(
            customPreconditions,
            [&](const CustomVectorizationPrecondition &customPrecondition) {
              return succeeded(
                  customPrecondition(&innerOp, vectorizeNDExtract));
            })) {
      continue;
    }
    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
          return !VectorType::isValidElementType(type);
        })) {
      return failure();
    }
  }
  if (isElementwise(linalgOp))
    return success();

  // TODO: isaConvolutionOpInterface that can also infer from generic
  // features. But we will still need stride/dilation attributes that will be
  // annoying to reverse-engineer...
  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
    return success();
  // TODO: the common vector shape is equal to the static loop sizes only when
  // all indexing maps are projected permutations. For convs and stencils the
  // logic will need to evolve.
  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
    LDBG("precondition failed: not projected permutations\n");
    return failure();
  }
  if (failed(reductionPreconditions(linalgOp))) {
    LDBG("precondition failed: reduction preconditions\n");
    return failure();
  }
  return success();
}
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }
  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
      failed(vector::isValidMaskedInputVector(
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG("inner_tiles must be constant: " << packOp << "\n");
    return failure();
  }

  return success();
}
static LogicalResult
vectorizePadOpPrecondition(tensor::PadOp padOp,
                           ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = padOp.getConstantPaddingValue();
  if (!padValue) {
    LDBG("pad value is not constant: " << padOp << "\n");
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
                                              inputVectorSizes)))
    return failure();

  if (llvm::any_of(padOp.getLow(), [](Value v) {
        std::optional<int64_t> res = getConstantIntValue(v);
        return !res.has_value() || res.value() != 0;
      })) {
    LDBG("low pad must all be zero: " << padOp << "\n");
    return failure();
  }

  return success();
}
/// Preconditions for scalable vectors. This is quite restrictive: it models
/// the fact that in practice we would only make selected dims scalable.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  // Cond 1: there's been no need for scalable vectorisation of non-linalg
  // ops so far.
  if (!linalgOp)
    return failure();

  // Cond 2: there's been no need for more than 2 scalable dims so far.
  if (numOfScalableDims > 2)
    return failure();

  // Cond 3: look at the configuration in `inputScalableVecDims` and verify
  // that it matches one of the supported cases. Skip the trailing non-scalable
  // dims first, tracking whether a non-unit parallel dim has been seen.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  // Analyze the iterator corresponding to the trailing scalable dim.
  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // The trailing scalable reduction dim must also be the trailing loop dim.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG("Non-trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
      LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
           "is not supported\n");
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // The trailing scalable parallel dim must be the innermost non-unit
    // parallel dim.
    if (seenNonUnitParallel) {
      LDBG("Inner parallel dim not requested for scalable "
           "vectorization\n");
      return failure();
    }
    break;
  }
  }

  // If present, check the second scalable dim: it must immediately precede a
  // trailing scalable parallel dim.
  if (numOfScalableDims == 2) {
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG("Higher dim than the trailing reduction dim requested for scalable "
           "vectorization\n");
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Don't let a matmul with extended semantics through this transform.
  if (linalgOp.hasUserDefinedMaps())
    return failure();

  // Cond 4: only the following ops are supported in the presence of scalable
  // vectors.
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::MatmulTransposeAOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  return TypeSwitch<Operation *, LogicalResult>(op)
      .Case<linalg::LinalgOp>([&](auto linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case<tensor::PadOp>([&](auto padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case<tensor::PackOp>([&](auto packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case<tensor::UnPackOp>([&](auto unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Default([](auto) { return failure(); });
}
/// Converts affine.apply ops to arithmetic operations.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
  OpBuilder::InsertionGuard g(rewriter);
  auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();

  for (auto op : make_early_inc_range(toReplace)) {
    rewriter.setInsertionPoint(op);
    auto expanded = affine::expandAffineExpr(
        rewriter, op->getLoc(), op.getAffineMap().getResult(0),
        op.getOperands().take_front(op.getAffineMap().getNumDims()),
        op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
    rewriter.replaceOp(op, expanded);
  }
}
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp,
             tensor::UnPackOp>(op);
}
/// Emit a suitable vector form for an operation. If provided,
/// `inputVectorSizes` are used to vectorize this operation; they must match
/// the rank of the iteration space and be greater than or equal to their
/// counterpart iteration space sizes, if static. They also allow the
/// vectorization of operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      ArrayRef<bool> inputScalableVecDims,
                                      bool vectorizeNDExtract,
                                      bool flatten1DDepthwiseConv) {
  LDBG("Attempting to vectorize:\n" << *op << "\n");
  LDBG("Input vector sizes: ");
  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  LDBG("Input scalable vector dims: ");
  LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");

  if (failed(vectorizeOpPrecondition(op, inputVectorSizes,
                                     inputScalableVecDims, vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG("Vectorization pre-conditions failed\n");
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims))) {
      LDBG("Vectorization state couldn't be initialized\n");
      return failure();
    }
  }

  SmallVector<Value> results;
  auto vectorizeResult =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::LinalgOp>([&](auto linalgOp) {
            // TODO: isaConvolutionOpInterface that can also infer from
            // generic features. Will require stride/dilation attribute
            // inference.
            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG("Unsupported convolution can't be vectorized.\n");
              return failure();
            }

            LDBG("Vectorize generic by broadcasting to the canonical vector "
                 "shape\n");

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
            // to 'OpBuilder' when it is passed over to some methods like
            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
            // erase an op within these methods, the actual rewriter won't be
            // notified and we could end up with read-after-free issues!
            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp,
                                            results);
          })
          .Case<tensor::PadOp>([&](auto padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case<tensor::PackOp>([&](auto packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case<tensor::UnPackOp>([&](auto unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes, results);
          })
          .Default([](auto) { return failure(); });

  if (failed(vectorizeResult)) {
    LDBG("Vectorization failed\n");
    return failure();
  }

  if (!results.empty())
    rewriter.replaceOp(op, results);
  else
    rewriter.eraseOp(op);

  return success();
}
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
                                          memref::CopyOp copyOp) {
  auto srcType = cast<MemRefType>(copyOp.getSource().getType());
  auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
  if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
    return failure();

  auto srcElementType = getElementTypeOrSelf(srcType);
  auto dstElementType = getElementTypeOrSelf(dstType);
  if (!VectorType::isValidElementType(srcElementType) ||
      !VectorType::isValidElementType(dstElementType))
    return failure();

  auto readType = VectorType::get(srcType.getShape(), srcElementType);
  auto writeType = VectorType::get(dstType.getShape(), dstElementType);

  Location loc = copyOp->getLoc();
  SmallVector<Value> indices(srcType.getRank(),
                             rewriter.create<arith::ConstantIndexOp>(loc, 0));
  Value readValue = rewriter.create<vector::TransferReadOp>(
      loc, readType, copyOp.getSource(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  if (cast<VectorType>(readValue.getType()).getRank() == 0) {
    readValue = rewriter.create<vector::ExtractOp>(loc, readValue);
    readValue =
        rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
  }
  Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
      loc, readValue, copyOp.getTarget(), indices,
      rewriter.getMultiDimIdentityMap(srcType.getRank()));
  rewriter.replaceOp(copyOp, writeValue->getResults());
  return success();
}
/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
/// given operation type OpTy.
template <typename OpTy>
struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    bool changed = false;
    // Insert users in vector, because some users may be replaced/removed.
    for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
      if (auto op = dyn_cast<OpTy>(user))
        changed |= rewriteUser(rewriter, padOp, op).succeeded();
    return success(changed);
  }

protected:
  virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
                                    tensor::PadOp padOp, OpTy op) const = 0;
};
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferReadOp xferOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Padding value of existing `xferOp` is unused.
    if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
      return failure();

    rewriter.modifyOpInPlace(xferOp, [&]() {
      SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));
      xferOp.getSourceMutable().assign(padOp.getSource());
      xferOp.getPaddingMutable().assign(padValue);
    });

    return success();
  }
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at position of the old TransferWriteOp.
    rewriter.setInsertionPoint(xferOp);

    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }
  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both CastOp result
    // and CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast
    // thereof).
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Only dynamic dims need to be checked.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: same integer attrs.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Other cases: take a deeper look at the defining ops of the values.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      // Case 2: both values are identical AffineMinOps. (Should not happen if
      // CSE is run.)
      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
      return false;
    }

    // All tests passed.
    return true;
  }
/// Returns the effective Pad value for the input op, provided it's a scalar.
static Value getStaticPadVal(Operation *op) {
  if (!op)
    return {};

  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value being
  // broadcast, provided that it's a scalar.
  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
    auto source = bcast.getSource();
    if (llvm::dyn_cast<VectorType>(source.getType()))
      return {};

    return source;
  }

  // 2. linalg.fill - use the scalar input value.
  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
    return fill.getInputs()[0];
  }

  // 3. tensor.generate - bail out conservatively (the yielded value is not
  // guaranteed to be fixed for all padded elements).
  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
    return {};
  }

  // 4. vector.transfer_write - recurse into the vector that's written.
  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
    return getStaticPadVal(xferWrite.getVector().getDefiningOp());

  // 5. tensor.insert_slice - recurse into the destination.
  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
    return getStaticPadVal(slice.getDest().getDefiningOp());

  return {};
}
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  auto sourceType = sliceOp.getSource().getType();
  if (!VectorType::isValidElementType(sourceType.getElementType()))
    return failure();

  auto resultType = sliceOp.getResultType();

  Value padValue = getStaticPadVal(sliceOp);
  bool isOutOfBoundsRead = !sourceType.hasStaticShape();

  if (!padValue && isOutOfBoundsRead) {
    LDBG("Failed to get a pad value for out-of-bounds read access\n");
    return failure();
  }

  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = rewriter.create<arith::ConstantOp>(
        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
  }

  // Compute the vector shape and in-bounds attributes.
  SmallVector<int64_t> vecShape;
  SmallVector<bool> readInBounds;
  SmallVector<bool> writeInBounds;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
    if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
      // Source shape is statically known: read and write are in-bounds.
      readInBounds.push_back(true);
      writeInBounds.push_back(true);
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but the result shape is.
      // Vectorize with the size of the result shape; the read may be
      // out-of-bounds and the write is conservatively not in-bounds.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
      readInBounds.push_back(false);
      writeInBounds.push_back(false);
    } else {
      // Neither source nor result dim is static: cannot vectorize the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // Generate TransferReadOp + TransferWriteOp.
  auto loc = sliceOp.getLoc();
  SmallVector<Value> readIndices(
      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
  auto read = rewriter.create<vector::TransferReadOp>(
      loc, vecType, sliceOp.getSource(), readIndices, padValue,
      ArrayRef<bool>{readInBounds});

  auto writeIndices = getValueOrCreateConstantIndexOp(
      rewriter, loc, sliceOp.getMixedOffsets());

  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      sliceOp, read, sliceOp.getDest(), writeIndices,
      ArrayRef<bool>{writeInBounds});

  return success();
}
  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes not supported.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result not used as destination.
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if sizes match: insert the entire tensor into the most minor
    // dims. (No permutations allowed.)
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Insert the TransferReadOp and TransferWriteOp at the position of the
    // InsertSliceOp.
    rewriter.setInsertionPoint(insertOp);

    // Generate TransferReadOp: read the entire source tensor and add high
    // padding.
    SmallVector<Value> readIndices(
        vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
    auto read = rewriter.create<vector::TransferReadOp>(
        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);

    // Generate TransferWriteOp: write to the InsertSliceOp's dest tensor at
    // the specified offsets. The write is fully in-bounds because an
    // InsertSliceOp's source must fit into the destination.
    auto writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
/// Check whether there is any interleaved use of any `values` between
/// `firstOp` and `secondOp`. Conservatively return `true` if `firstOp` and
/// `secondOp` are in different blocks.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                    ValueRange values) {
  if (firstOp->getBlock() != secondOp->getBlock() ||
      !firstOp->isBeforeInBlock(secondOp)) {
    LDBG("interleavedUses precondition failed, firstOp: "
         << *firstOp << ", second op: " << *secondOp << "\n");
    return true;
  }
  for (auto v : values) {
    for (auto &u : v.getUses()) {
      Operation *owner = u.getOwner();
      if (owner == firstOp || owner == secondOp)
        continue;
      // TODO: this is too conservative, use dominance info in the future.
      if (owner->getBlock() == firstOp->getBlock() &&
          (owner->isBeforeInBlock(firstOp) ||
           secondOp->isBeforeInBlock(owner)))
        continue;
      LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
                                    << ", second op: " << *secondOp << "\n");
      return true;
    }
  }
  return false;
}
/// Return the unique subview use of `v` if it is indeed unique, null
/// otherwise.
static memref::SubViewOp getSubViewUseIfUnique(Value v) {
  memref::SubViewOp subViewOp;
  for (auto &u : v.getUses()) {
    if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
      if (subViewOp)
        return memref::SubViewOp();
      subViewOp = newSubViewOp;
    }
  }
  return subViewOp;
}
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the
  // copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure the padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // memref.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer, so it must be
  // reset conservatively when forwarding to vector.transfer_read.
  auto vectorType = xferOp.getVectorType();
  Value res = rewriter.create<vector::TransferReadOp>(
      xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`.
  Value viewOrAlloc = xferOp.getSource();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the subview copied into that we replace.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward vector.transfer into copy. The `masked` attribute is only valid
  // on the padded buffer, so it must be reset conservatively.
  auto vector = xferOp.getVector();
  rewriter.create<vector::TransferWriteOp>(
      xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vector.getType().getRank(), false)));

  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

/// Bind a pack of int64_t& to the leading dimensions of
/// shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
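// Illustrative usage (mirrors the calls below): given an NWC-shaped lhs,
//
//   int64_t nSize, wSize, cSize;
//   bindShapeDims(lhsShapedType, nSize, wSize, cSize);
//
// binds nSize, wSize and cSize to dims 0, 1 and 2 of lhsShapedType.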
static bool isCastOfBlockArgument(Operation *op) {
  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
         isa<BlockArgument>(op->getOperand(0));
}

static bool isSupportedPoolKind(vector::CombiningKind kind) {
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::MAXNUMF:
  case vector::CombiningKind::MAXIMUMF:
  case vector::CombiningKind::MAXSI:
  case vector::CombiningKind::MAXUI:
  case vector::CombiningKind::MINNUMF:
  case vector::CombiningKind::MINIMUMF:
  case vector::CombiningKind::MINSI:
  case vector::CombiningKind::MINUI:
    return true;
  default:
    return false;
  }
}
/// Generate a vector implementation for 1-D convolutions and poolings (see
/// the supported layouts in the generate* entry points below).
struct Conv1DGenerator
    : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {
    // Determine whether `linalgOp` can be generated with this generator.
    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
      return;
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());
    if (!lhsShapedType || !rhsShapedType || !resShapedType)
      return;
    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
    // (non-channeled convolution -> LHS and RHS both have single dimensions).
    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
      return;

    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    if (!reduceOp)
      return;

    if (!setOperKind(reduceOp))
      return;
    auto maybeKind = getCombinerOpKind(reduceOp);
    // Typically convolution will have an `Add` CombiningKind, but for i1 type
    // it can get strength-reduced to `OR`, which is also supported.
    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                        *maybeKind != vector::CombiningKind::OR) &&
                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
      return;
    }
    reductionKind = maybeKind.value();

    auto rhsRank = rhsShapedType.getRank();
    switch (oper) {
    case Conv:
      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
        return;
      break;
    case Pool:
      if (rhsRank != 1)
        return;
      break;
    }
    // The op is now known to be valid.
    valid = true;
  }
  /// Generate a vector implementation for a 1-D conv/pooling. The kernel is
  /// always unrolled along kw; the lhs/rhs/res vector shapes are derived from
  /// the traversal order.
  FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
    int64_t nSize, wSize, cSize, kwSize, fSize;
    SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
    bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
      // Non-channeled convolution: the only dims are {w, kw}.
      nSize = fSize = cSize = 0;
      bindShapeDims(resShapedType, wSize);
      bindShapeDims(rhsShapedType, kwSize);
      // iw = ow + kw - 1 (e.g. 16 convolved with 3 -> 14).
      lhsShape = {(wSize + kwSize - 1)};
      rhsShape = {kwSize};
      resShape = {wSize};
      break;
    case Conv1DOpOrder::Nwc:
      bindShapeDims(resShapedType, nSize, wSize, fSize);
      switch (oper) {
      case Conv:
        bindShapeDims(rhsShapedType, kwSize, cSize);
        break;
      case Pool:
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      // iw = (ow - 1) * sw + 1 + (kw - 1) * dw.
      lhsShape = {nSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1,
                  cSize};
      switch (oper) {
      case Conv:
        rhsShape = {kwSize, cSize, fSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, wSize, fSize};
      break;
    case Conv1DOpOrder::Ncw:
      bindShapeDims(resShapedType, nSize, fSize, wSize);
      switch (oper) {
      case Conv:
        bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
        break;
      case Pool:
        bindShapeDims(rhsShapedType, kwSize);
        cSize = fSize;
        break;
      }
      lhsShape = {nSize, cSize,
                  ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                      1};
      switch (oper) {
      case Conv:
        rhsShape = {fSize, cSize, kwSize};
        break;
      case Pool:
        rhsShape = {kwSize};
        break;
      }
      resShape = {nSize, fSize, wSize};
      break;
    }
    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    auto lhsType = VectorType::get(lhsShape, lhsEltType);
    auto rhsType = VectorType::get(rhsShape, rhsEltType);
    auto resType = VectorType::get(resShape, resEltType);
    // Zero-index padding for lhs, rhs and res reads/writes.
    SmallVector<Value> lhsPadding(lhsShape.size(), zero);
    SmallVector<Value> rhsPadding(rhsShape.size(), zero);
    SmallVector<Value> resPadding(resShape.size(), zero);

    // Read the whole lhs, rhs and res in one shot (with zero padding).
    Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
                                                        lhsPadding);
    // This is needed only for Conv.
    Value rhs = nullptr;
    if (oper == Conv)
      rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                    rhsPadding);
    Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
                                                        resPadding);
    // The base vectorization case for channeled convolution is input:
    // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse that code, the Ncw
    // case is transposed to Nwc here.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, so no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // ncw -> nwc
      static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs);
      // fcw -> wcf
      static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};

      // This is needed only for Conv.
      if (oper == Conv)
        rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
      // nfw -> nwf
      static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, permRes);
      break;
    }
    }
    // Unroll along kw and read slices of lhs and rhs.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
                                     kwSize, strideW, dilationW, wSizeStep,
                                     isSingleChanneled);
    // Extract rhs slice of size {c, f} @ [kw].
    if (oper == Conv)
      rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
    resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                      wSizeStep, isSingleChanneled);

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };

    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
    // an outerproduct for non-channeled convolution, or a simple arith
    // operation for pooling.
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        switch (oper) {
        case Conv:
          if (isSingleChanneled) {
            resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                   lhsVals[linearIndex(kw, w)],
                                                   rhsVals[kw], resVals[w]);
          } else {
            resVals[w] = conv1dSliceAsContraction(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);
          }
          break;
        case Pool:
          resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                   resVals[w]);
          break;
        }
      }
    }
    // Write back the computed result slices.
    res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
                                 isSingleChanneled);

    // The base vectorization case for channeled convolution produces output
    // {n,w,f}. Transpose back to Ncw if necessary.
    switch (conv1DOpOrder) {
    case Conv1DOpOrder::W:
    case Conv1DOpOrder::Nwc:
      // Base case, so no transposes necessary.
      break;
    case Conv1DOpOrder::Ncw: {
      // nwf -> nfw
      static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
      res = rewriter.create<vector::TransposeOp>(loc, res, perm);
      break;
    }
    }

    return rewriter
        .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding)
        .getOperation();
  }
  // Take a value and widen it to have the same element type as `ty`.
  Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
    const Type srcElementType = getElementTypeOrSelf(val.getType());
    const Type dstElementType = getElementTypeOrSelf(ty);
    assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
    if (srcElementType == dstElementType)
      return val;

    const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
    const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
    const Type dstType =
        cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);

    if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
      return rewriter.create<arith::SIToFPOp>(loc, dstType, val);
    }

    if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtFOp>(loc, dstType, val);

    if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
        srcWidth < dstWidth)
      return rewriter.create<arith::ExtSIOp>(loc, dstType, val);

    assert(false && "unhandled promotion case");
    return nullptr;
  }
  // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}.
  Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
                                 Value lhs, Value rhs, Value res) {
    vector::IteratorType par = vector::IteratorType::parallel;
    vector::IteratorType red = vector::IteratorType::reduction;
    AffineExpr n, w, f, c;
    bindDims(ctx, n, w, f, c);
    lhs = promote(rewriter, loc, lhs, res.getType());
    rhs = promote(rewriter, loc, rhs, res.getType());
    auto contrationOp = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, res,
        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
    contrationOp.setKind(reductionKind);
    return contrationOp;
  }

  // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single-channel
  // convolution.
  Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
                                  Value lhs, Value rhs, Value res) {
    return rewriter.create<vector::OuterProductOp>(
        loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD);
  }
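  // Illustrative example (assumed types): for one unrolled (kw, w) step of a
  // channeled conv, conv1dSliceAsContraction emits something like
  //
  //   %r = vector.contract {
  //          indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
  //                           affine_map<(n, w, f, c) -> (c, f)>,
  //                           affine_map<(n, w, f, c) -> (n, w, f)>],
  //          iterator_types = ["parallel", "parallel", "parallel",
  //                            "reduction"],
  //          kind = #vector.kind<add>}
  //        %lhsSlice, %rhsSlice, %resSlice
  //        : vector<1x2x4xf32>, vector<4x8xf32> into vector<1x2x8xf32>
  //
  // while the single-channel variant uses a 1-D vector.outerproduct with
  // CombiningKind::ADD instead.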
  /// Generate a vector implementation for a depthwise 1-D NWC convolution:
  ///   Op def: (     n,     w,     c,    kw)
  ///    Iters: ({Par(), Par(), Par(), Red()})
  ///   Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  /// kw is always unrolled. If the channel dim is dynamic, it is vectorized
  /// with `channelDimVecSize` lanes (scalable if `channelDimScalableFlag`).
  FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                       bool channelDimScalableFlag,
                                       bool flatten) {
    bool scalableChDim = false;
    bool useMasking = false;
    int64_t nSize, wSize, cSize, kwSize;
    // kernel{kw, c}
    bindShapeDims(rhsShapedType, kwSize, cSize);
    if (ShapedType::isDynamic(cSize)) {
      assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
      cSize = channelDimVecSize;
      // Scalable vectors are only used when the channel dim is dynamic and
      // the scalable flag is set.
      scalableChDim = channelDimScalableFlag;
      useMasking = true;
    }

    assert(!(useMasking && flatten) &&
           "Unsupported flattened conv with dynamic shapes");
    // out{n, w, c}
    bindShapeDims(resShapedType, nSize, wSize);

    vector::TransferWriteOp write;
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

    // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
    int64_t wSizeStep = strideW == 1 ? wSize : 1;

    Type lhsEltType = lhsShapedType.getElementType();
    Type rhsEltType = rhsShapedType.getElementType();
    Type resEltType = resShapedType.getElementType();
    VectorType lhsType = VectorType::get(
        {nSize,
         // iw = (ow - 1) * sw + 1 + (kw - 1) * dw
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
    VectorType rhsType =
        VectorType::get({kwSize, cSize}, rhsEltType,
                        /*scalableDims=*/{false, scalableChDim});
    VectorType resType =
        VectorType::get({nSize, wSize, cSize}, resEltType,
                        /*scalableDims=*/{false, false, scalableChDim});
    // Masks the input xfer op along the channel dim, iff the corresponding
    // scalable flag is set.
    auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                               ArrayRef<bool> scalableDims,
                               Operation *opToMask) {
      if (!useMasking)
        return opToMask;
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

      SmallVector<bool> inBounds(maskShape.size(), true);
      auto xferOp = cast<VectorTransferOpInterface>(opToMask);
      xferOp->setAttr(xferOp.getInBoundsAttrName(),
                      rewriter.getBoolArrayAttr(inBounds));

      SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
          cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

      Value maskOp =
          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims);

      return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
    };
    // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
    Value lhs = rewriter.create<vector::TransferReadOp>(
        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedLhs = maybeMaskXferOp(
        lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

    // Read rhs slice of size {kw, c} @ [0, 0].
    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                        ValueRange{zero, zero});
    auto maybeMaskedRhs = maybeMaskXferOp(
        rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

    // Read res slice of size {n, w, c} @ [0, 0, 0].
    Value res = rewriter.create<vector::TransferReadOp>(
        loc, resType, resShaped, ValueRange{zero, zero, zero});
    auto maybeMaskedRes = maybeMaskXferOp(
        resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
    // Unroll along kw and read slices of lhs and rhs.
    SmallVector<Value> lhsVals, rhsVals, resVals;
    SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> inOutStrides = {1, 1, 1};

    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
            loc, maybeMaskedLhs->getResult(0),
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            inOutSliceSizes, inOutStrides));
      }
    }
    // Extract rhs slice of size {c} @ [kw].
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
          loc, maybeMaskedRhs->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{kw}));
    }
    // Extract res slice of size {n, wSizeStep, c} @ [0, w, 0].
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
          loc, maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
          inOutStrides));
    }

    auto linearIndex = [&](int64_t kw, int64_t w) {
      return kw * (wSize / wSizeStep) + w;
    };
3548 auto lhsTypeAfterFlattening =
3550 auto resTypeAfterFlattening =
3554 for (int64_t kw = 0; kw < kwSize; ++kw) {
3555 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3556 Value lhsVal = lhsVals[linearIndex(kw, w)];
3557 Value resVal = resVals[w];
3561 lhsVal = rewriter.
create<vector::ShapeCastOp>(
3562 loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
3563 resVal = rewriter.
create<vector::ShapeCastOp>(
3564 loc, resTypeAfterFlattening, resVals[w]);
3566 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3567 rhsVals[kw], resVal, flatten);
3570 resVals[w] = rewriter.
create<vector::ShapeCastOp>(
3577 if (!llvm::all_of(resVals, [](
Value v) {
return v; })) {
3579 for (
auto &collection :
3580 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3581 for (
Value v : collection)
    // Write back res slice of size {n, wSizeStep, c} @ [0, w, 0]; this does
    // not depend on kw.
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>(
          loc, resVals[w], maybeMaskedRes->getResult(0),
          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
          /*strides=*/ArrayRef<int64_t>{1, 1, 1});
    }

    // Write back res slice of size {n, w, c} @ [0, 0, 0].
    Operation *resOut = rewriter.create<vector::TransferWriteOp>(
        loc, maybeMaskedRes->getResult(0), resShaped,
        ValueRange{zero, zero, zero});
    return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
                           resOut);
  }
  /// Lower:
  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c}     (flatten = false)
  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c}   (flatten = true)
  /// to MulAcc.
  Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                     Value lhs, Value rhs, Value res,
                                     bool flatten) {
    auto rhsTy = cast<ShapedType>(rhs.getType());
    auto resTy = cast<ShapedType>(res.getType());

    // TODO(suderman): Change this to use a vector.ima intrinsic.
    lhs = promote(rewriter, loc, lhs, resTy);

    if (flatten) {
      // NOTE: this logic won't work for scalable vectors; for this reason
      // "flattening" is not supported when shapes are dynamic (this should be
      // captured by one of the pre-conditions).

      // There are two options for handling the filter:
      //  * shape_cast(broadcast(filter))
      //  * broadcast(shuffle(filter))
      // Opt for the option without shape_cast to simplify the codegen.
      auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
      auto resSize = cast<VectorType>(res.getType()).getShape()[1];

      SmallVector<int64_t, 16> indices;
      for (int i = 0; i < resSize / rhsSize; ++i) {
        for (int j = 0; j < rhsSize; ++j)
          indices.push_back(j);
      }

      rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices);
    }
    // Broadcast the filter to match the output vector.
    rhs = rewriter.create<vector::BroadcastOp>(
        loc, resTy.clone(rhsTy.getElementType()), rhs);

    rhs = promote(rewriter, loc, rhs, resTy);

    if (!lhs || !rhs)
      return nullptr;

    if (isa<FloatType>(resTy.getElementType()))
      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);

    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
    return rewriter.create<arith::AddIOp>(loc, mul, res);
  }
  /// Entry point for non-channeled convolution:
  ///   {{w + kw}, {kw}, {w}}
  FailureOr<Operation *> generateNonChanneledConv() {
    AffineExpr w, kw;
    bindDims(ctx, w, kw);
    if (!iters({Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match conv::W 1-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {w + kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {w}}))
      return conv(Conv1DOpOrder::W);

    return rewriter.notifyMatchFailure(op, "not a conv::W layout");
  }

  /// Entry point that matches the Nwc layout:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
  FailureOr<Operation *> generateNwcConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, w, f, kw, c);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Nwc 3-par 2-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c, f},
                /*resIndex*/ {n, w, f}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
  }

  /// Entry point that matches the Ncw layout:
  ///   {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
  FailureOr<Operation *> generateNcwConv() {
    AffineExpr n, w, f, kw, c;
    bindDims(ctx, n, f, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match conv::Ncw 3-par 2-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {f, c, kw},
                /*resIndex*/ {n, f, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
  }

  /// Entry point that matches the Nwc pooling layout:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}}
  FailureOr<Operation *> generateNwcPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, w, c}}))
      return conv(Conv1DOpOrder::Nwc);

    return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
  }

  /// Entry point that matches the Ncw pooling layout:
  ///   {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}}
  FailureOr<Operation *> generateNcwPooling() {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, c, w, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(op,
                                         "failed to match pooling 3-par 1-red");

    if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
                /*rhsIndex*/ {kw},
                /*resIndex*/ {n, c, w}}))
      return conv(Conv1DOpOrder::Ncw);

    return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
  }

  /// Entry point that matches the depthwise Nwc layout:
  ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
  FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
                                             bool vecChDimScalableFlag = false,
                                             bool flatten = false) {
    AffineExpr n, w, c, kw;
    bindDims(ctx, n, w, c, kw);
    if (!iters({Par(), Par(), Par(), Red()}))
      return rewriter.notifyMatchFailure(
          op, "failed to match depthwise::Nwc conv 3-par 1-red");

    // No transposition needed.
    if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                /*rhsIndex*/ {kw, c},
                /*resIndex*/ {n, w, c}}))
      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);

    return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
  }
private:
  enum OperKind { Conv, Pool };
  bool valid = false;
  OperKind oper = Conv;
  StringAttr redOp;
  StringAttr poolExtOp;
  bool isPoolExt = false;
  int strideW, dilationW;
  Value lhsShaped, rhsShaped, resShaped;
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  vector::CombiningKind reductionKind;
  // Returns true iff it is a valid conv/pooling op. If the region has 2 ops
  // (reduction + yield) or 3 ops (extension + reduction + yield) and the rhs
  // is not used, then it is the body of a pooling. If conv, check for a
  // single `mul` predecessor whose operands are block arguments or extensions
  // of block arguments.
  bool setOperKind(Operation *reduceOp) {
    int numBlockArguments =
        llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
    switch (numBlockArguments) {
    case 1: {
      // This will be a convolution if the feeder is a MulOp. A
      // strength-reduced version of MulOp for i1 type is AndOp, which is also
      // supported. Otherwise, it can be a pooling.
      auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                         llvm::IsaPred<BlockArgument>);
      Operation *feedOp = (*feedValIt).getDefiningOp();
      if (isCastOfBlockArgument(feedOp)) {
        oper = Pool;
        isPoolExt = true;
        poolExtOp = feedOp->getName().getIdentifier();
      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
                    (isa<arith::AndIOp>(feedOp) &&
                     feedOp->getResultTypes()[0].isInteger(1))) &&
                   llvm::all_of(feedOp->getOperands(), [](Value v) {
                     if (isa<BlockArgument>(v))
                       return true;
                     if (Operation *op = v.getDefiningOp())
                       return isCastOfBlockArgument(op);
                     return false;
                   }))) {
        return false;
      }
      return true;
    }
    case 2:
      // Must be a pooling.
      oper = Pool;
      isPoolExt = false;
      return true;
    default:
      return false;
    }
  }
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
                     ArrayRef<int64_t> inputVecSizes,
                     ArrayRef<bool> inputScalableVecDims,
                     bool flatten1DDepthwiseConv) {
  // The ConvolutionOpInterface gives us guarantees of existence for
  // strides/dilations. However, we do not need to rely on those: simply use
  // them if present, otherwise use the default and let the generic conv
  // matcher in the ConvGenerator succeed or fail.
  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
  Conv1DGenerator e(rewriter, op, stride, dilation);
  auto res = e.generateNonChanneledConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcConv();
  if (succeeded(res))
    return res;
  res = e.generateNcwConv();
  if (succeeded(res))
    return res;
  res = e.generateNwcPooling();
  if (succeeded(res))
    return res;
  res = e.generateNcwPooling();
  if (succeeded(res))
    return res;

  // Only depthwise 1D NWC convs are left - these can be vectorized using
  // masks and scalable vectors. Note that ATM the only dim that can be
  // dynamic (i.e. vectorized with masking) is the channel dim (the trailing
  // dim).
  uint64_t vecChDimSize = ShapedType::kDynamic;
  bool vecChDimScalableFlag = false;
  if (!inputVecSizes.empty()) {
    // Only use the input vector size corresponding to the channel dim. Other
    // vector dims will be inferred from the Ops.
    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
           "Not a 1D depthwise conv!");
    size_t chDimIdx =
        TypeSwitch<Operation *, size_t>(op)
            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });

    vecChDimSize = inputVecSizes[chDimIdx];
    vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
  }
  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
                               flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
    if (failed(resultOrFail))
      return failure();
    Operation *newOp = *resultOrFail;
    if (newOp->getNumResults() == 0) {
      rewriter.eraseOp(op.getOperation());
      return success();
    }
    assert(newOp->getNumResults() == 1 && "expected single result");
    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
    return success();
  }
};
static std::optional< VectorShape > vectorShape(Type type)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static memref::SubViewOp getSubViewUseIfUnique(Value v)
Return the unique subview use of v if it is indeed unique, null otherwise.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static LogicalResult reductionPreconditions(LinalgOp op)
static VectorizationResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value input, SmallVector< OpFoldResult > destSizes, ArrayRef< int64_t > inputVectorSizes, bool useInBoundsInsteadOfMasking)
Given an input, the mixed destSizes, and the vector sizes for vectorization, create an empty destinat...
static LogicalResult vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes)
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
static VectorizationResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
static SmallVector< int64_t > getTiledPackShape(tensor::PackOp packOp, ArrayRef< int64_t > destShape)
Given a tensor::PackOp, return the dest shape before any packing permutations.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes)
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shape of the value to reduce differs from the result shape.
static void bindShapeDims(ShapedType shapedType)
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values)
Check whether there is any interleaved use of any values between firstOp and secondOp.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static VectorizationResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create a MultiDimReductionOp to compute the reduction for reduceOp.
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
std::function< VectorizationResult(Operation *, const IRMapping &)> CustomVectorizationHook
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv)
static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a tensor::UnPackOp into these 4 ops: a vector::TransferReadOp that reads a vector from the source tensor, a vector::TransposeOp that transposes the read vector, a vector::ShapeCastOp that reshapes the data to the target shape, and a vector::TransferWriteOp that writes the result back to the destination tensor.
static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp)
Converts affine.apply Ops to arithmetic operations.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
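A minimal usage sketch (the block variable and the choice of linalg.yield are hypothetical):

// Hypothetical: fetch the unique linalg.yield in `block`, if any.
if (auto yieldOp = getSingleOpOfType<linalg::YieldOp>(block)) {
  // `block` contains exactly one linalg.yield; use it here.
}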
static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::PackOp with (1) static innerTiles, (2) constant padding value, and (3) input vector sizes.
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after the filter is unrolled along kw.
static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, bool vectorizeNDExtract, bool flatten1DDepthwiseConv)
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation, compress the unused dimensions to serve as a permutation_map for a vector transfer operation.
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorExtract.
VectorizationStatus
Helper data structure to represent the result of vectorization.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
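As a hedged sketch of how these pieces fit together: a CustomVectorizationHook inspects one op at a time and reports one of the three statuses above. A trivial hook that declines every op, leaving it to the generic path, might look like:

// Sketch only: a hook that never matches. Real hooks (e.g. for linalg.index
// or tensor.extract) return NewOp together with the replacement they built
// from the already-vectorized operands carried by `bvm`.
CustomVectorizationHook declineAll =
    [](Operation *op, const IRMapping &bvm) -> VectorizationResult {
      (void)op;
      (void)bvm;
      return VectorizationResult{VectorizationStatus::Failure, nullptr};
    };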
static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv)
static LogicalResult vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp, ArrayRef< int64_t > inputVectorSizes)
Checks that the inner tiles are static/constant.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor.extract.
static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Preconditions for scalable vectors.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static LogicalResult vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize a padOp with (1) static result type, (2) constant padding value, and (3) all-zero lowPad, into transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value)).
A dimensional identifier appearing in an affine expression.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e., a projection) of a symbol-less permutation map.
unsigned getNumInputs() const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions whose results are filtered using keepDimFilter.
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e., dims indicated by value 0 in the result).
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
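A short sketch of these AffineMap helpers, assuming an existing MLIRContext *ctx:

#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>

void affineMapSketch(mlir::MLIRContext *ctx) {
  // (d0, d1, d2) -> (d0, d1, d2)
  auto id = mlir::AffineMap::getMultiDimIdentityMap(/*numDims=*/3, ctx);

  // permutation = [1, 2, 0] yields the map (d0, d1, d2) -> (d1, d2, d0).
  llvm::SmallVector<unsigned> permVec = {1, 2, 0};
  auto perm = mlir::AffineMap::getPermutationMap(permVec, ctx);
  assert(perm.isPermutation() && perm.isProjectedPermutation());

  // Compose the permutation with the identity (a no-op here).
  auto composed = perm.compose(id);
  (void)composed;
}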
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callback provided.
OpListType & getOperations()
AffineMap getMultiDimIdentityMap(unsigned rank)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
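A minimal sketch of the IRMapping protocol as the vectorizer uses it (this is the bvm passed to the hooks above); scalarVal and vectorVal are hypothetical Values:

IRMapping bvm;
bvm.map(scalarVal, vectorVal);            // record scalar -> vector
Value vectorized = bvm.lookup(scalarVal); // returns vectorVal; asserts if unmapped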
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation via the provided map (leaving them alone if no entry is present).
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before it in the operation list of the parent block.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (extremely useful).
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason for the failure.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification.
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors.
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to the provided dimension and symbol values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next integer value.
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op, false otherwise.
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
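A hedged sketch of the typical driver sequence for the two entry points above; rewriter and op are assumed to exist, and the vector sizes are illustrative:

SmallVector<int64_t> vectorSizes = {8, 16};
SmallVector<bool> scalableDims = {false, true}; // trailing dim scalable: 8 x [16]
if (succeeded(linalg::vectorizeOpPrecondition(op, vectorSizes, scalableDims,
                                              /*vectorizeNDExtract=*/false,
                                              /*flatten1DDepthwiseConv=*/false)) &&
    succeeded(linalg::vectorize(rewriter, op, vectorSizes, scalableDims,
                                /*vectorizeNDExtract=*/false,
                                /*flatten1DDepthwiseConv=*/false))) {
  // `op` has been rewritten into vector form.
}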
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns)
Populates patterns with vectorization patterns for tensor.insert_slice.
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper function computePackUnPackPerm to get the permutation vector.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape, ...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
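A sketch of the masked-read utility (cf. maskOperation above); builder, loc, source, and padValue are assumed to exist in the caller:

// Read an 8x16 vector from `source`, masking out-of-bounds lanes against the
// runtime sizes instead of relying solely on in_bounds attributes.
SmallVector<int64_t> readShape = {8, 16};
Value vec = vector::createReadOrMaskedRead(builder, loc, source, readShape,
                                           padValue,
                                           /*useInBoundsInsteadOfMasking=*/false);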
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper around getMixedSizes for vector.transfer_read and vector.transfer_write ops (for the source and destination, respectively).
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
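Sketch: this is essentially how a permuted canonical vector shape is computed (cf. getCanonicalVecType below); ctx is an assumed MLIRContext *:

// (d0, d1, d2) -> (d2, d0, d1) applied to [4, 8, 16] yields [16, 4, 8].
llvm::SmallVector<unsigned> permVec = {2, 0, 1};
AffineMap map = AffineMap::getPermutationMap(permVec, ctx);
SmallVector<int64_t> shape = {4, 8, 16};
SmallVector<int64_t> permuted =
    applyPermutationMap<int64_t>(map, ArrayRef<int64_t>(shape));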
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into zeros.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs, and the position of the potential reduction argument within the list, redPos.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region or its descendants.
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
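A two-line sketch (b and loc assumed): an OpFoldResult may carry either a constant Attribute or an SSA Value, and this helper yields an index Value in both cases:

OpFoldResult ofr = b.getIndexAttr(4);                     // attribute case
Value idx = getValueOrCreateConstantIndexOp(b, loc, ofr); // arith.constant 4 : index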
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
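A small sketch of these two vector-based permutation helpers, as used by the pack/unpack vectorization paths with the permutations computed above:

SmallVector<int64_t> perm = {2, 0, 1};
SmallVector<int64_t> shape = {4, 8, 16};
applyPermutationToVector(shape, perm);                    // shape is now {16, 4, 8}
SmallVector<int64_t> inv = invertPermutationVector(perm); // {1, 2, 0}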
Rewrite tensor.insert_slice as a vector.transfer_read + vector.transfer_write pair.
LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp, PatternRewriter &rewriter) const final
Rewrite use of tensor::PadOp result in InsertSliceOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override
Rewrite use of tensor::PadOp result in TransferReadOp.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override
Rewrite use of tensor::PadOp result in TransferWriteOp.
bool hasSameTensorSize(Value beforePadding, tensor::ExtractSliceOp afterTrimming) const
Check if beforePadding and afterTrimming have the same tensor size, i.e., same dimensions.
LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override
Operation * newOp
New vectorized operation to replace the current op.
enum VectorizationStatus status
Return status from vectorizing the current op.
Contains the vectorization state and related methods used across the vectorization process of a given operation.
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the corresponding fixed/scalable dimension flags.
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims)
Initializes the vectorization state, including the computation of the canonical vector shape for vectorization.
VectorizationState(RewriterBase &rewriter)
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Base pattern for rewriting tensor::PadOps whose result is consumed by a given operation type OpTy.
virtual LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, OpTy op) const =0
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const final
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against instances of an operation interface instead of a raw Operation.
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override
TODO: use interfaces, side-effects and aliasing analysis as appropriate, when available.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.