#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#define DEBUG_TYPE "linalg-transforms"

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
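/// Try to peel and canonicalize loop `op` and return the new result.
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)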
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return SmallVector<Value>(partialIteration->getResults());
        assert(!partialIteration && "expected that loop was not peeled");
        return SmallVector<Value>(forOp->getResults());
      })
      .Default([&](Operation *op) { return SmallVector<Value>(); });
}
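/// Peel `loops` and apply affine_min/max bounds simplification on the fly
/// where relevant.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {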
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}
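/// Return true if `map` has 0 or 1 result that is a function of
/// AffineDimExpr(dim).
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
  bool found = false;
  for (AffineExpr e : map.getResults()) {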
    if (!e.isFunctionOfDim(dim))
      continue;
    if (found)
      return false;
    found = true;
  }
  return true;
}
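/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), std::nullopt otherwise.
static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
                                                            int64_t dim) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    if (!map.getResult(i).isFunctionOfDim(dim))
      continue;
    return i;
  }
  return std::nullopt;
}

/// Perform one step of packing a LinalgOp's metadata along `dim` into the
/// `newDim` at `iteratorTypes.size()`: update `indexingMaps` and
/// `iteratorTypes` in place, and return, for each indexing map, the operand
/// dimension that got packed (if any).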
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];
    // Add `newDim` to the map, whatever the case.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-one index of the result that is a function of `dim`.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }
    // We can only pack AffineDimExpr at the moment.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();
    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);
    // Record that `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}
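/// Helper struct to encode packing along one dimension of a LinalgOp.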
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};
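/// Rewrite pack as pad + reshape + transpose.
// Illustrative sketch (the shapes below are hypothetical, not taken from the
// original source): with inner tiles [8, 32] on both dims,
//   %pack = tensor.pack %src padding_value(%cst : f32)
//       inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest
//       : tensor<100x200xf32> -> tensor<13x7x8x32xf32>
// lowers, roughly, to:
//   %padded = tensor.pad %src low[0, 0] high[4, 24]
//       : tensor<100x200xf32> to tensor<104x224xf32>
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       : tensor<104x224xf32> into tensor<13x8x7x32xf32>
//   linalg.transpose ins(%expanded) outs(%dest) permutation = [0, 2, 1, 3]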
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                             tensor::PackOp packOp) {
  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  // 2. Compute the permutation vector to shuffle the packed shape into the
  // shape before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getPackInverseDestPerm(packOp);

  // 3. Compute the stripMinedShape: the packed shape without outer and inner
  // permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);
  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
                                      DBGS() << "outerPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(););
  if (packOp.isLikePad()) {
    // This pack is just a plain pad: insert the padded value into the
    // higher-rank destination tensor.
    SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                    rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                   rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, packOp.getDest());

    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
        /*offsets=*/zeros, sizes, /*strides=*/ones);

    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

    rewriter.replaceOp(packOp, insertSliceOp->getResults());

    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                           /*transposeOp=*/nullptr};
  }
  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);
  // 6. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}
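/// Rewrite unpack as empty + transpose + collapse_shape + extract_slice, with
/// a trailing linalg.copy to preserve destination-passing style.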
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
                                                   tensor::UnPackOp unPackOp) {
  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad: extract the slice from the
    // higher-rank tensor. The inner dims match the destination tensor; the
    // outer ones are additional 1s.
    SmallVector<OpFoldResult> sizes(packedRank - destTensorType.getRank(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }
  // 1. Compute the permutation vector to shuffle the packed shape into the
  // shape before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: the packed shape without outer and inner
  // permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);
  // Get dynamic dims from the source, permuted accordingly.
  SmallVector<OpFoldResult, 4> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(););
  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);
  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));
  // 6. Inject a copy to preserve destination-passing style.
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
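/// Return all the dims that have been packed (resp. the pack sizes used) for
/// the operand at `operandPos`.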
SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}
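/// Implement packing of a single LinalgOp by `packedSizes`: there must be one
/// packed size per loop. Returns the packed op together with the tensor.pack
/// and tensor.unpack ops inserted around it.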
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
             DBGSNL(););
  SmallVector<tensor::PackOp> packOps;
  SmallVector<tensor::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n";
        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
                              DBGS() << "packedDimForEachOperand: ");
        DBGSNL(););
  }
  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(DBGS() << "operand: " << operand << "\n";
                 llvm::interleaveComma(innerPos, DBGS() << "innerPos: ");
                 DBGSNL();
                 llvm::interleaveComma(innerPackSizes,
                                       DBGS() << "innerPackSizes: ");
                 DBGSNL(););
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = tensor::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !tensor::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }
  // Step 3. Build the packed op, using the types of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    tensor::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<tensor::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
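/// Clone `linalgOp` into a linalg.generic where `opOperand` is replaced by
/// `transposedValue` and the matching indexing map is composed with the
/// permutation; the original op is replaced by the clone.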
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Compute the transposed indexing map.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}
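/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the
/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements.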
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  tensor::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }
  // Step 2. Transpose linalgOp.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand: the leading part
  // reuses outerPerm; the trailing part reindexes innerPerm by
  // `numLeadingDims`.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }

  // Step 2.b. Save the operand number, then perform the transposition.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
  // Step 3. Maybe transpose unPackOp.
  tensor::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
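/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction
/// dimension. Packing occurs along the respective most-minor dimensions, in
/// the order given by `mnkOrder`; `mnkPaddedSizesNextMultipleOf` optionally
/// rounds each packed size up to the next multiple.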
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
  assert(isPermutationVector(mnkOrder) && "expected a permutation");
  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
                      << numLoops << "\nin: " << linalgOp << "\n");
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }
  // Locally adjust the desired iterator position of mnk and the packing
  // sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }
  // 1. Infer the dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
                      << "\n");
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }
  // 2. Normalize linalgOp to a kmn-matmul-like form with [red, par, par]
  // most-minor iterators. With multiple candidates for m, n and k, bias
  // towards the most-minor embedding.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "Start packing generic op greedily with (m@" << mPos
                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
                    << "\n";);
  // 2.a. Rewrite as a generic.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }
  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. This only normalizes the iteration order; it does not change
  // the indexing of any operand.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: ");
             DBGSNL(););
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) &&
         "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
  // 3. Compute the packed sizes from the loop ranges.
  SmallVector<Range, 4> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());
  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
                                   DBGS() << "paddedSizesNextMultipleOf: ");
             DBGSNL(););
  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
                                   [](Range r) { llvm::dbgs() << r.size; });
             DBGSNL(););
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    // Round the loop size up to the next multiple: (d0 ceildiv s0) * s0.
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
                                   DBGS() << "adjustedPackedSizes: ");
             DBGSNL(););

  return pack(rewriter, genericOp, adjustedPackedSizes);
}
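/// Set the tileSizeComputationFunction to return the values `ts`. The
/// constants are materialized at the start of the enclosing function so that
/// they do not fold away during tiling.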
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
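/// Fill `dest` with the padding value: use a linalg.fill when the pad op has
/// a constant padding value, and fall back to a tensor.generate that clones
/// the pad op's region otherwise.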
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: lower to a tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the pad op's region into the new op.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}
LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt())
        .getResult();
  };
  auto resultType = padOp.getResultType();
  // Compute the sizes of the EmptyOp; any combination of static and dynamic
  // dims is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(
          rewriter, padOp.getLoc(), padOp.getSource(), dim));
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate an InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}
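/// Swap extract_slice(pad(x)) into pad(extract_slice(x)) by bubbling the
/// slice up through the pad (see tensor::bubbleUpPadSlice).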
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();

  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
  return success();
}
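/// If a padding value is set, return a tensor.pad of the pack op's source
/// whose shape matches the inner tiles; otherwise return the source directly.
/// Assumes all outer dims of the pack op are 1.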
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           tensor::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();
  SmallVector<Value> dynamicTileSizes;
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims: these should all be 1 and are preserved.
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims: since all outer dims are 1, it is safe to use the
    // tile size for the padded shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1. Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2. Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}
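/// Normalize a permutation on a higher-rank space to its actual size, e.g.
///   perm = [1, 4, 2]
/// becomes
///   norm = [0, 2, 1].
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {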
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
      vec, [&](int64_t v) { return v != kNonTiledMarker; }));
  // This inverts the permutation in addition to normalizing, so invert back.
  return invertPermutationVector(normalizedPerm);
}
/// Get the normalized permutation implied by innerDimsPos and outerDimsPerm,
/// assuming rank reduction of unit outer dims.
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;
  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the innermost dims after packing.
  perm.append(innerDims);

  return perm;
}
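/// Rewrites a tensor.pack whose outer dims are all 1s into a (pad +)
/// linalg.transpose + tensor.insert_slice sequence.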
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
    tensor::PackOp packOp, PatternRewriter &rewriter) const {
  if (llvm::any_of(packOp.getAllOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  Location loc = packOp.getLoc();

  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  int64_t srcRank = packOp.getSourceRank();
  int64_t destRank = packOp.getDestRank();
  int64_t numTiles = destRank - srcRank;

  if (!llvm::all_of(packOp.getInnerDimsPos(),
                    [&srcRank, &numTiles](int64_t dimPos) {
                      return dimPos >= (srcRank - numTiles - 1);
                    }))
    return rewriter.notifyMatchFailure(
        packOp, "Attempting to tile non-trailing source dims!");
  // 1. Extract the inner tile sizes, replacing values with constant
  // attributes where possible.
  SmallVector<OpFoldResult> tileSizes;
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (dimAndTileMapping.count(i)) {
      auto [_, tileSize] =
          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
      tileSizes.push_back(tileSize);
    }
  }
  // 2. Transpose the input so that the tiled source dims come last.
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < (srcRank - numTiles); i++)
    srcPermForTranspose.push_back(i);
  ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
  srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());

  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
             DBGSNL(););
  // 2.1. Create tensor.empty as the init value for the TransposeOp.
  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
                                                 oneIdxAttr);
  transShapeForEmptyOp.append(tileSizes);

  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
                                         srcPermForTranspose);
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());

  // 2.2. Create linalg.transpose.
  auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
                                                           srcPermForTranspose);
  // 3. Compute the offsets, sizes and strides for the insert_slice.
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  SmallVector<int64_t> writeShape;
  SmallVector<OpFoldResult> writeSizes;
  for (auto tileSize : packOp.getMixedTiles()) {
    auto [tileSizeStatic, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
    writeShape.push_back(tileSizeStatic);
  }
  // 4. Replace tensor.pack with the tensor.insert_slice created above.
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
      writeSizes, writeStrides);
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}
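/// Rewrites a tensor.unpack whose tiled outer dims are all 1s into a
/// rank-reduced tensor.extract_slice + linalg.transpose +
/// tensor.insert_slice sequence.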
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  // 1. Use a rank-reduced tensor.extract_slice to extract the tile.
  SmallVector<int64_t> readShape;
  SmallVector<Value> dynamicDims;
  SmallVector<OpFoldResult> readSizes;
  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i)) {
      readSizes.push_back(oneIdxAttr);
      continue;
    }

    if (ShapedType::isDynamic(srcShape[i])) {
      Value dynamicDim =
          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
      readSizes.push_back(dynamicDim);
      dynamicDims.push_back(dynamicDim);
    } else {
      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    if (srcShape[i] != 1)
      readShape.push_back(srcShape[i]);
  }
  auto mixedTiles = unpackOp.getMixedTiles();
  readSizes.append(mixedTiles.begin(), mixedTiles.end());
  // Explicitly create the type for extract_slice because the inner tile size
  // could be 1: we want to represent the whole inner tile in that case.
  auto tileShape = srcShape.drop_front(destRank);
  // Append the inner tile shape to the permuted and rank-reduced outer shape.
  readShape.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShape, elemType);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);

  // 2. Transpose the tile to match the outer/inner order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space, so invert the permutation.
  perm = invertPermutationVector(perm);
  SmallVector<int64_t> transpShape(readShape);
  applyPermutationToVector<int64_t>(transpShape, perm);

  Value empty =
      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
  // 3. Handle incomplete tiles by truncating trailing data from the
  // transposed tile.
  int numLoops = transpShape.size();
  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
  SmallVector<OpFoldResult> tileSizes;
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }

  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
  // 4. Insert the result into the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
      writeStrides);
  rewriter.replaceOp(unpackOp, insert.getResult());

  return success();
}
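/// Rewrites 2-D convolution (and pooling) ops with size-1 window dimensions
/// into their 1-D counterparts by rank-reducing the unit dimension.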
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.
  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();
  // Get domain indices based on the conv2D layout.
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected conv2d/pool2d operation.");
            return std::make_tuple(0, 0, 0, 0);
          });
  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();
  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);
  // Rank-reduce strides and dilations too.
  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
  auto conv1DOp = rewriter.create<Conv1DOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
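/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.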
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.
  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();
  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();
  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);
  // Rank-reduce strides and dilations too.
  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
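/// Rewrites linalg.conv_2d with a size-1 window dimension into linalg.conv_1d
/// by rank-reducing the unit dimension from all operands.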
FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                            PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();
  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();
  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);
  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
                                            ValueRange{newInput, newKernel},
                                            ValueRange{newOutput});

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  // Registers the Downscale* patterns defined above; the excerpt preserves
  // only the unsigned max/min pooling entries of the full list.
  patterns.add<DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                                     PoolingNwcMaxUnsignedOp>,
               DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                                     PoolingNwcMinUnsignedOp>>(
      patterns.getContext(), benefit);
}