33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <type_traits>
40 #define DEBUG_TYPE "linalg-transforms"
45 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
46 #define DBGSNL() (llvm::dbgs() << "\n")
62 .Case<scf::ForOp>([&](scf::ForOp forOp) {
63 scf::ForOp partialIteration;
66 return partialIteration->getResults();
67 assert(!partialIteration &&
"expected that loop was not peeled");
68 return forOp->getResults();
77 for (
auto loopOp : loops)
90 if (!e.isFunctionOfDim(dim))
148 static FailureOr<SmallVector<std::optional<int64_t>>>
152 int64_t newDim = iteratorTypes.size();
153 iteratorTypes.push_back(iteratorTypes[dim]);
156 indexingMaps.size(), std::nullopt);
158 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
160 AffineMap map = indexingMaps[operandIdx];
163 assert(map.
getNumDims() == newDim &&
"num dims invariant violation");
171 "num results invariant violation");
173 if (!maybeOperandDimensionToPack.has_value()) {
174 newMaps.push_back(map);
179 if (!isa<AffineDimExpr>(map.
getResult(maybeOperandDimensionToPack.value())))
185 newMaps.push_back(map);
188 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
190 indexingMaps = newMaps;
192 return packedDimPerIndexingMap;
198 struct PackedOperandsDim {
204 struct PackedOperandsDimList {
205 void pushBack(PackedOperandsDim &&packedOperandsDims) {
206 spec.emplace_back(packedOperandsDims);
220 tensor::PackOp packOp) {
222 auto packedTensorType =
223 cast<RankedTensorType>(packOp->getResultTypes().front());
224 if (llvm::any_of(packOp.getStaticInnerTiles(),
225 [](int64_t size) { return ShapedType::isDynamic(size); })) {
228 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
237 PackingMetadata packingMetadata = computePackingMetadata(
238 packedTensorType.getRank(), packOp.getInnerDimsPos());
252 for (
auto [pos, innerSize] :
253 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
255 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
265 rewriter, loc, map, {outerSize, origSize, innerSize});
267 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
269 packingMetadata.reassociations);
270 Value paddingValue = packOp.getPaddingValue();
272 paddingValue = rewriter.
create<arith::ConstantOp>(
276 rewriter.
create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
277 highs, paddingValue,
false);
280 DBGSNL();
DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
281 DBGS() <<
"insertPositions: ");
282 DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
283 DBGS() <<
"outerPositions: ");
284 DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
285 DBGS() <<
"packedShape: ");
287 llvm::interleaveComma(packedToStripMinedShapePerm,
288 DBGS() <<
"packedToStripMinedShapePerm: ");
289 DBGSNL(); llvm::interleaveComma(
290 packingMetadata.reassociations,
DBGS() <<
"reassociations: ",
292 llvm::interleaveComma(ri, llvm::dbgs() <<
"|");
295 llvm::interleaveComma(stripMinedShape,
DBGS() <<
"stripMinedShape: ");
298 if (packOp.isLikePad()) {
317 auto insertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
318 loc, padOp, packOp.getDest(),
321 LLVM_DEBUG(
DBGS() <<
"insert_slice op: " << insertSliceOp;
DBGSNL(););
323 rewriter.
replaceOp(packOp, insertSliceOp->getResults());
331 auto expandShapeResultType =
333 auto reshapeOp = rewriter.
create<tensor::ExpandShapeOp>(
334 loc, expandShapeResultType, padOp.getResult(),
335 packingMetadata.reassociations);
340 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
341 loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
344 DBGS() <<
"reshape op: " << reshapeOp;
DBGSNL();
345 llvm::interleaveComma(transpPerm,
DBGS() <<
"transpPerm: ");
349 rewriter.
replaceOp(packOp, transposeOp->getResults());
355 tensor::UnPackOp unPackOp) {
360 RankedTensorType packedTensorType = unPackOp.getSourceType();
361 int64_t packedRank = packedTensorType.getRank();
364 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
365 if (unPackOp.isLikeUnPad()) {
374 auto extractSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
375 loc, destTensorType, unPackOp.getSource(),
379 rewriter.
replaceOp(unPackOp, extractSliceOp->getResults());
382 nullptr, extractSliceOp};
387 PackingMetadata packingMetadata;
397 RankedTensorType stripMinedTensorType =
399 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
400 stripMinedTensorType, packingMetadata.reassociations);
407 auto emptyOp = rewriter.
create<tensor::EmptyOp>(
408 loc, dims, stripMinedTensorType.getElementType());
409 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
410 loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
413 DBGSNL();
DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
414 DBGS() <<
"insertPositions: ");
415 DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
416 DBGS() <<
"packedShape: ");
418 llvm::interleaveComma(packedToStripMinedShapePerm,
419 DBGS() <<
"packedToStripMinedShapePerm: ");
420 DBGSNL(); llvm::interleaveComma(
421 packingMetadata.reassociations,
DBGS() <<
"reassociations: ",
423 llvm::interleaveComma(ri, llvm::dbgs() <<
"|");
426 llvm::interleaveComma(stripMinedShape,
DBGS() <<
"stripMinedShape: ");
430 auto reshapeOp = rewriter.
create<tensor::CollapseShapeOp>(
431 loc, collapsedType, transposeOp->getResult(0),
432 packingMetadata.reassociations);
435 int64_t destRank = destTensorType.getRank();
436 auto extractSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
437 loc, destTensorType, reshapeOp->getResult(0),
443 auto copyOp = rewriter.
create<linalg::CopyOp>(
444 loc, extractSliceOp->getResult(0), unPackOp.getDest());
447 rewriter.
replaceOp(unPackOp, copyOp->getResults());
453 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
455 for (
auto &i : spec) {
456 if (!i.packedDimForEachOperand[operandPos].has_value())
458 res.push_back(i.packedDimForEachOperand[operandPos].value());
464 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
466 for (
auto &i : spec) {
467 if (!i.packedDimForEachOperand[operandPos].has_value())
469 res.push_back(i.packedSize);
478 linalg::LinalgOp linalgOp,
480 if (packedSizes.size() != linalgOp.getNumLoops()) {
482 "incorrect number of pack sizes");
488 linalgOp.getIteratorTypesArray();
489 LLVM_DEBUG(
DBGS() <<
"Start packing: " << linalgOp <<
"\n";
490 llvm::interleaveComma(indexingMaps,
DBGS() <<
"maps: ");
DBGSNL();
491 llvm::interleaveComma(iteratorTypes,
DBGS() <<
"iterators: ");
497 PackedOperandsDimList listOfPackedOperandsDim;
498 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
501 if (maybeConstant.has_value() && maybeConstant.value() == 0)
504 PackedOperandsDim packedOperandsDims;
505 packedOperandsDims.packedSize = packedSizes[i];
506 FailureOr<SmallVector<std::optional<int64_t>>>
507 maybePackedDimForEachOperand =
509 if (failed(maybePackedDimForEachOperand))
511 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
512 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
515 DBGS() <<
"++++ After pack size #" << i <<
": " << packedSizes[i]
517 llvm::interleaveComma(indexingMaps,
DBGS() <<
"maps: ");
DBGSNL();
518 llvm::interleaveComma(iteratorTypes,
DBGS() <<
"iterators: ");
DBGSNL();
519 llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
520 DBGS() <<
"packedDimForEachOperand: ");
527 linalgOp.getDpsInitsMutable(), [](
OpOperand &o) { return &o; }));
529 for (
const auto &operandsList : {inputOperands, initOperands}) {
530 for (
OpOperand *opOperand : operandsList) {
531 int64_t pos = opOperand->getOperandNumber();
532 Value operand = opOperand->get();
534 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
536 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
538 DBGS() <<
"operand: " << operand <<
"\n";
539 llvm::interleaveComma(innerPos,
DBGS() <<
"innerPos: ");
DBGSNL();
540 llvm::interleaveComma(innerPackSizes,
DBGS() <<
"innerPackSizes: ");
542 if (innerPackSizes.empty()) {
543 inputsAndInits.push_back(operand);
546 Value dest = tensor::PackOp::createDestinationTensor(
547 rewriter, loc, operand, innerPackSizes, innerPos,
549 ShapedType operandType = cast<ShapedType>(operand.
getType());
550 bool areConstantTiles =
554 if (areConstantTiles && operandType.hasStaticShape() &&
555 !tensor::PackOp::requirePaddingValue(
556 operandType.getShape(), innerPos,
557 cast<ShapedType>(dest.
getType()).getShape(), {},
559 packOps.push_back(rewriter.
create<tensor::PackOp>(
560 loc, operand, dest, innerPos, innerPackSizes));
566 Value zero = rewriter.
create<arith::ConstantOp>(loc, zeroAttr);
567 packOps.push_back(rewriter.
create<tensor::PackOp>(
568 loc, operand, dest, innerPos, innerPackSizes, zero));
570 inputsAndInits.push_back(packOps.back());
576 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
578 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
579 auto packedLinalgOp = rewriter.
create<linalg::GenericOp>(
580 linalgOp.getLoc(), inits.
getTypes(), inputs, inits, indexingMaps,
585 for (
OpResult result : packedLinalgOp->getResults()) {
586 int64_t resultNum = result.getResultNumber();
587 tensor::PackOp maybePackedInit =
588 inits[resultNum].getDefiningOp<tensor::PackOp>();
589 if (!maybePackedInit) {
590 results.push_back(result);
594 unPackOps.push_back(rewriter.
create<tensor::UnPackOp>(
595 packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
596 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
597 results.push_back(unPackOps.back());
605 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
634 assert(linalgOp == opOperand.
getOwner() &&
"linalg op must own the operand");
638 cast<RankedTensorType>(opOperand.
get().
getType()), permutation);
640 assert(tensorType == transposedValue.
getType() &&
641 "expected tensor type mismatch");
646 llvm::map_range(permutation, [](int64_t i) ->
unsigned {
return i; }));
650 permutationMap.
compose(linalgOp.getMatchingIndexingMap(&opOperand));
654 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
660 auto transposedGenericOp = rewriter.
create<linalg::GenericOp>(
663 operandsRef.drop_front(linalgOp.getNumDpsInputs()).
getTypes(),
664 operandsRef.take_front(linalgOp.getNumDpsInputs()),
665 operandsRef.drop_front(linalgOp.getNumDpsInputs()),
667 linalgOp.getIteratorTypesArray());
669 rewriter.
replaceOp(linalgOp, transposedGenericOp->getResults());
671 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
674 FailureOr<PackTransposeResult>
676 linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
683 tensor::PackOp transposedPackOp =
684 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
686 if (!packOp.getResult().hasOneUse())
689 OpOperand &packUse = *packOp->getUses().begin();
690 if (packUse.
getOwner() != linalgOp) {
692 linalgOp,
"not a single use by the LinalgOp target");
695 (!linalgOp.isDpsInit(&packUse) ||
696 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
698 "not produced by the LinalgOp target");
704 int64_t numLeadingDims = packOp.getSourceRank();
705 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
709 if (permutation.empty())
710 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
712 if (innerPerm.empty()) {
715 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
717 llvm::append_range(permutation,
718 llvm::map_range(innerPerm, [&](int64_t pos) {
719 return numLeadingDims + pos;
731 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
734 tensor::UnPackOp transposedUnPackOp;
737 transposedLinalgOp->getOpOperand(packUseOperandNumber);
738 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
740 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
741 rewriter, loc, transposedResult, innerPerm, outerPerm);
743 rewriter.
replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
747 rewriter.
replaceOp(packOp, transposedPackOp->getResults());
765 FailureOr<PackResult>
770 assert(mnkPackedSizes.size() == 3 &&
"unexpected num of packing sizes");
771 assert((mnkPaddedSizesNextMultipleOf.empty() ||
772 mnkPaddedSizesNextMultipleOf.size() == 3) &&
773 "num of packing sizes next multiple should be empty or of size 3");
774 assert(mnkOrder.size() == 3 &&
"unexpected mnkOrder size");
777 int64_t numLoops = linalgOp.getNumLoops();
779 LLVM_DEBUG(
DBGS() <<
"need 3+ loops to find a matmul to pack, got "
780 << numLoops <<
"\nin: " << linalgOp <<
"\n");
782 linalgOp,
"need 3+ loops to find a matmul to pack");
786 int64_t numPackedDims = mnkPackedSizes.size();
788 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
789 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
791 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
792 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
794 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
795 paddedSizesNextMultipleOf[mnkOrder[i]] =
796 mnkPaddedSizesNextMultipleOf.empty() ? 0
797 : mnkPaddedSizesNextMultipleOf[i];
801 FailureOr<ContractionDimensions> maybeDimensions =
803 if (failed(maybeDimensions)) {
804 LLVM_DEBUG(
DBGS() <<
"couldn't infer matmul iterators in: " << linalgOp
807 "couldn't infer matmul iterators");
815 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
816 kPos = maybeDimensions->k.back();
818 DBGS() <<
"Start packing generic op greedily with (m@" << mPos
819 <<
", n@" << nPos <<
", k@" << kPos <<
"): " << linalgOp
823 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
825 FailureOr<GenericOp> generalizeResult =
827 assert(succeeded(generalizeResult) &&
"unexpected failure generalizing op");
828 genericOp = *generalizeResult;
836 LLVM_DEBUG(llvm::interleaveComma(permutation,
DBGS() <<
"perm: ");
DBGSNL(););
839 FailureOr<GenericOp> interchangeResult =
841 assert(succeeded(interchangeResult) &&
"unexpected failure interchanging op");
842 genericOp = *interchangeResult;
843 LLVM_DEBUG(
DBGS() <<
"Generalized Op to pack: " << genericOp <<
"\n";);
860 cast<LinalgOp>(genericOp.getOperation())
861 .createLoopRanges(rewriter, genericOp.getLoc());
865 LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
866 DBGS() <<
"paddedSizesNextMultipleOf: ");
868 LLVM_DEBUG(llvm::interleaveComma(loopRanges,
DBGS() <<
"loopRanges: ",
869 [](
Range r) { llvm::dbgs() << r.size; });
873 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
874 if (paddedSizesNextMultipleOf[i] == 0) {
875 adjustedPackedSizes.push_back(packedSizes[i]);
882 rewriter, genericOp->getLoc(), d0.
ceilDiv(s0) * s0,
883 {loopRanges[adjustedPackedSizes.size()].size,
884 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
886 LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
887 DBGS() <<
"adjustedPackedSizes: ");
894 return pack(rewriter, genericOp, adjustedPackedSizes);
903 assert(!tileSizeComputationFunction &&
"tile sizes already set");
909 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
927 auto padValue = padOp.getConstantPaddingValue();
929 return rewriter.
create<FillOp>(padOp.getLoc(), padValue, dest).result();
932 auto generateOp = rewriter.
create<tensor::GenerateOp>(
933 padOp.getLoc(), padOp.getResultType(), dynSizes);
936 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
945 if (
auto val = llvm::dyn_cast_if_present<Value>(ofr))
949 padOp.getLoc(), cast<IntegerAttr>(ofr.get<
Attribute>()).getInt())
953 auto resultType = padOp.getResultType();
957 for (
unsigned dim = 0; dim < resultType.getRank(); ++dim) {
958 if (resultType.isDynamicDim(dim)) {
960 padOp.getSource(), dim));
963 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
965 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
966 dynSizes.push_back(plusHigh);
968 staticSizes.push_back(resultType.getDimSize(dim));
972 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
973 padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
982 auto sourceType = padOp.getSourceType();
990 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
998 if (!sliceOp.hasUnitStride())
1001 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
1005 bool zeroSliceGuard =
true;
1007 if (std::optional<bool> control = controlFn(sliceOp))
1008 zeroSliceGuard = *control;
1013 FailureOr<TilingResult> tilingResult =
1015 sliceOp.getMixedSizes(), zeroSliceGuard);
1016 if (failed(tilingResult))
1020 rewriter.
replaceOp(sliceOp, tilingResult->tiledValues);
1030 tensor::PackOp packOp) {
1031 Value input = packOp.getSource();
1032 if (!packOp.getPaddingValue()) {
1036 assert(llvm::all_of(packOp.getAllOuterDims(),
1037 [](int64_t val) { return val == 1; }) &&
1038 "some outer dims are != 1");
1041 ShapedType inputType = packOp.getSourceType();
1042 int64_t inputRank = inputType.getRank();
1045 packOp.getDimAndTileMapping();
1052 for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1055 if (!tileAndPosMapping.count(dimIdx)) {
1056 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1057 assert(inputDimSize == 1 &&
1058 "with all outer dims == 1, this non-tiled input dim should be 1!");
1059 paddedShape.push_back(inputDimSize);
1066 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1070 if (cstTileSize.has_value()) {
1071 paddedShape.push_back(cstTileSize.value());
1076 paddedShape.push_back(ShapedType::kDynamic);
1079 dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1084 false, loc, builder,
1094 constexpr int64_t kNonTiledMarker = -1;
1099 vec, [&](int64_t v) {
return v != kNonTiledMarker; }));
1114 int64_t unpackedRank = shape.size();
1115 for (
auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1116 if (llvm::is_contained(innerDimsPos, i)) {
1117 innerDims.push_back(dim++);
1122 outerDims.push_back(dim++);
1123 if (!outerDimsPerm.empty())
1124 rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1130 applyPermutationToVector<int64_t>(innerDims, innerPerm);
1135 rankReducedOuterDimsPerm =
1137 if (!rankReducedOuterDimsPerm.empty())
1138 applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1141 perm.append(innerDims);
1150 if (llvm::any_of(packOp.getTiledOuterDims(),
1151 [](int64_t dim) { return dim != 1; })) {
1153 packOp,
"require the tiled outer dimensions of the result are all 1s");
1160 auto inputShape = packOp.getSourceType().getShape();
1162 packOp.getDimAndTileMapping();
1165 int64_t srcRank = packOp.getSourceRank();
1171 for (
auto i : llvm::seq<unsigned>(0, srcRank)) {
1172 if (dimAndTileMapping.count(i)) {
1173 readShapeForExtractSlice.push_back(
1175 .value_or(ShapedType::kDynamic));
1176 readSizes.push_back(dimAndTileMapping[i]);
1177 transShapeForEmpty.push_back(dimAndTileMapping[i]);
1180 if (ShapedType::isDynamic(inputShape[i])) {
1181 readSizes.push_back(
1184 readSizes.push_back(rewriter.
getIndexAttr(inputShape[i]));
1186 if (inputShape[i] != 1) {
1187 readShapeForExtractSlice.push_back(inputShape[i]);
1188 transShapeForEmpty.push_back(rewriter.
getIndexAttr(inputShape[i]));
1192 Type elemType = packOp.getSourceType().getElementType();
1196 loc, readType, input, readOffsets, readSizes, readStrides);
1200 inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
1202 LLVM_DEBUG(
DBGS() <<
"Pack permutation: " << packOp <<
"\n";
1203 llvm::interleaveComma(perm,
DBGS() <<
"perm: ");
DBGSNL(););
1205 applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
1208 rewriter.
create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
1210 rewriter.
create<linalg::TransposeOp>(loc,
tile, empty, perm);
1213 int64_t destRank = packOp.getDestRank();
1219 auto insert = rewriter.
create<tensor::InsertSliceOp>(
1220 loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1221 writeSizes, writeStrides);
1222 rewriter.
replaceOp(packOp, insert.getResult());
1229 int64_t srcRank = unpackOp.getSourceRank();
1230 int64_t destRank = unpackOp.getDestRank();
1233 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1234 [](int64_t dim) { return dim != 1; })) {
1237 "require the tiled outer dimensions of the result are all 1s");
1242 Value source = unpackOp.getSource();
1244 unpackOp.getDimAndTileMapping();
1252 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1253 if (dimAndTileMapping.count(i)) {
1254 readSizes.push_back(oneIdxAttr);
1258 if (ShapedType::isDynamic(srcShape[i])) {
1260 rewriter.
create<tensor::DimOp>(loc, source, i).getResult();
1261 readSizes.push_back(dynamicDim);
1262 dynamicDims.push_back(dynamicDim);
1264 readSizes.push_back(rewriter.
getIndexAttr(srcShape[i]));
1266 if (srcShape[i] != 1)
1267 readShape.push_back(srcShape[i]);
1269 auto mixedTiles = unpackOp.getMixedTiles();
1270 readSizes.append(mixedTiles.begin(), mixedTiles.end());
1274 auto tileShape = srcShape.drop_front(destRank);
1276 readShape.append(tileShape.begin(), tileShape.end());
1277 Type elemType = unpackOp.getSourceType().getElementType();
1279 Value innerTile = rewriter.
create<tensor::ExtractSliceOp>(
1280 loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
1284 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1288 applyPermutationToVector<int64_t>(transpShape, perm);
1291 rewriter.
create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
1293 rewriter.
create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1297 int numLoops = transpShape.size();
1302 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1303 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1304 tileSizes.push_back(
1308 auto partialTile = rewriter.
create<tensor::ExtractSliceOp>(
1309 loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1315 for (
int i = 0, idx = 0; i < destRank; ++i) {
1316 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1317 writeSizes.push_back(tileSizes[idx++]);
1319 writeSizes.push_back(oneIdxAttr);
1321 auto insert = rewriter.
create<tensor::InsertSliceOp>(
1322 loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1324 rewriter.
replaceOp(unpackOp, insert.getResult());
1337 template <
typename Conv2DOp,
typename Conv1DOp>
1340 if (convOp.hasPureBufferSemantics())
1343 Value input = convOp.getInputs().front();
1344 Value kernel = convOp.getInputs().back();
1345 Value output = convOp.getOutputs().front();
1347 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1348 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1349 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1351 auto kernelShape = kernelType.getShape();
1352 auto outputShape = outputType.getShape();
1355 auto [khIndex, kwIndex, ohIndex, owIndex] =
1358 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1359 return std::make_tuple(0, 1, 1, 2);
1361 .Case([&](linalg::Conv2DNchwFchwOp op) {
1362 return std::make_tuple(2, 3, 2, 3);
1364 .Case([&](linalg::PoolingNhwcSumOp op) {
1365 return std::make_tuple(0, 1, 1, 2);
1367 .Case([&](linalg::PoolingNchwSumOp op) {
1368 return std::make_tuple(0, 1, 2, 3);
1370 .Case([&](linalg::PoolingNhwcMaxOp op) {
1371 return std::make_tuple(0, 1, 1, 2);
1373 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1374 return std::make_tuple(0, 1, 1, 2);
1376 .Case([&](linalg::PoolingNhwcMinOp op) {
1377 return std::make_tuple(0, 1, 1, 2);
1379 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1380 return std::make_tuple(0, 1, 1, 2);
1382 .Case([&](linalg::PoolingNchwMaxOp op) {
1383 return std::make_tuple(0, 1, 2, 3);
1386 llvm_unreachable(
"unexpected conv2d/pool2d operation.");
1387 return std::make_tuple(0, 0, 0, 0);
1392 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1393 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1394 bool removeH = (khSize == 1 && ohSize == 1);
1395 bool removeW = (kwSize == 1 && owSize == 1);
1396 if (!removeH && !removeW)
1402 RankedTensorType newInputType =
1403 RTTBuilder(inputType).
dropDim((removeH ? ohIndex : owIndex));
1404 RankedTensorType newKernelType =
1405 RTTBuilder(kernelType).
dropDim((removeH ? khIndex : kwIndex));
1406 RankedTensorType newOutputType =
1407 RTTBuilder(outputType).
dropDim((removeH ? ohIndex : owIndex));
1412 rewriter, loc, input, newInputType);
1414 rewriter, loc, kernel, newKernelType);
1416 rewriter, loc, output, newOutputType);
1421 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1422 strides.erase(strides.begin() + (removeH ? 0 : 1));
1426 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1427 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1430 auto conv1DOp = rewriter.
create<Conv1DOp>(
1431 loc, newOutputType,
ValueRange{newInput, newKernel},
1432 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1436 rewriter, loc, conv1DOp.getResult(0), output);
1453 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1457 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1461 FailureOr<DepthwiseConv1DNwcWcOp>
1464 if (convOp.hasPureBufferSemantics())
1467 Value input = convOp.getInputs().front();
1468 Value kernel = convOp.getInputs().back();
1469 Value output = convOp.getOutputs().front();
1471 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1472 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1473 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1475 auto kernelShape = kernelType.getShape();
1476 auto outputShape = outputType.getShape();
1480 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1481 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1482 bool removeH = (khSize == 1 && ohSize == 1);
1483 bool removeW = (kwSize == 1 && owSize == 1);
1484 if (!removeH && !removeW)
1490 RankedTensorType newInputType =
1491 RTTBuilder(inputType).
dropDim((removeH ? 1 : 2));
1492 RankedTensorType newKernelType =
1493 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1494 RankedTensorType newOutputType =
1495 RTTBuilder(outputType).
dropDim(removeH ? 1 : 2);
1500 rewriter, loc, input, newInputType);
1502 rewriter, loc, kernel, newKernelType);
1504 rewriter, loc, output, newOutputType);
1508 auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1509 strides.erase(strides.begin() + (removeH ? 0 : 1));
1513 llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1514 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1517 auto conv1DOp = rewriter.
create<DepthwiseConv1DNwcWcOp>(
1518 loc, newOutputType,
ValueRange{newInput, newKernel},
1519 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1523 rewriter, loc, conv1DOp.getResult(0), output);
1532 if (convOp.hasPureBufferSemantics())
1535 Value input = convOp.getInputs().front();
1536 Value kernel = convOp.getInputs().back();
1537 Value output = convOp.getOutputs().front();
1539 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1540 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1541 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1543 auto kernelShape = kernelType.getShape();
1544 auto outputShape = outputType.getShape();
1548 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1549 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1550 bool removeH = (khSize == 1 && ohSize == 1);
1551 bool removeW = (kwSize == 1 && owSize == 1);
1552 if (!removeH && !removeW)
1558 RankedTensorType newInputType =
1559 RTTBuilder(inputType).
dropDim((removeH ? 0 : 1));
1560 RankedTensorType newKernelType =
1561 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1562 RankedTensorType newOutputType =
1563 RTTBuilder(outputType).
dropDim(removeH ? 0 : 1);
1568 rewriter, loc, input, newInputType);
1570 rewriter, loc, kernel, newKernelType);
1572 rewriter, loc, output, newOutputType);
1574 auto conv1DOp = rewriter.
create<Conv1DOp>(loc, newOutputType,
1580 rewriter, loc, conv1DOp.getResult(0), output);
1599 PoolingNwcMaxUnsignedOp>,
1602 PoolingNwcMinUnsignedOp>,
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Builder & setShape(ArrayRef< int64_t > newShape)
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite unpack as empty + transpose + reshape + extract_slice.
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and applies affine_min/max bounds simplification on the fly where relevant.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapt the index accesses of op.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
SmallVector< int64_t > getUnPackInverseSrcPerm(tensor::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, SmallVector< Value > dynOutDim={})
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
FailureOr< DepthwiseConv1DNwcWcOp > returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
LogicalResult matchAndRewrite(tensor::PackOp packOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const override
OptimizeCopyFn optimizeCopyFn
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.