#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
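// linalg::peelLoop (fragment below): try to peel and canonicalize the loop op
// and return the new result.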
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
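// linalg::peelLoops (fragment below): peel each of `loops` and apply
// affine_min/max bounds simplification on the fly where relevant.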
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);

    if (!e.isFunctionOfDim(dim))
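// Helper (fragment below): append one packed dimension to the iterator types
// and indexing maps of a generic op. The new iterator copies the type of
// `dim`; each map with a result depending on `dim` gets an extra
// AffineDimExpr result for the new dimension. Returns, per indexing map, the
// operand dimension that was packed (or std::nullopt), and fails if that
// result is not a plain AffineDimExpr.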
static FailureOr<SmallVector<std::optional<int64_t>>>

  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

      indexingMaps.size(), std::nullopt);

  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;

    AffineMap map = indexingMaps[operandIdx];

    assert(map.getNumDims() == newDim && "num dims invariant violation");

           "num results invariant violation");

    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);

    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))

    newMaps.push_back(map);

    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;

  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
struct PackedOperandsDim {

struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
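// linalg::lowerPack (fragment below): rewrite tensor.pack as tensor.pad +
// tensor.expand_shape + linalg.transpose, or as a rank-reducing
// tensor.insert_slice when the pack is pad-like.
//
// Illustrative sketch only (not part of this file): a rewrite pattern driving
// this helper could look roughly like the following; the pattern name is
// hypothetical.
//
//   struct LowerPackAsPadExpandTranspose
//       : public OpRewritePattern<tensor::PackOp> {
//     using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
//     LogicalResult matchAndRewrite(tensor::PackOp packOp,
//                                   PatternRewriter &rewriter) const override {
//       FailureOr<LowerPackResult> res = lowerPack(rewriter, packOp);
//       return success(succeeded(res));
//     }
//   };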
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                             tensor::PackOp packOp) {
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());

  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
        rewriter, loc, map, {outerSize, origSize, innerSize});

  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);

  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
                                      DBGS() << "outerPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGSNL());

  if (packOp.isLikePad()) {
    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
    rewriter.replaceOp(packOp, insertSliceOp->getResults());

  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
             DBGSNL(););

  rewriter.replaceOp(packOp, transposeOp->getResults());
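// linalg::lowerUnPack (fragment below): rewrite tensor.unpack as tensor.empty
// + linalg.transpose + tensor.collapse_shape + tensor.extract_slice, or as a
// rank-reducing tensor.extract_slice when the unpack is unpad-like.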
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
                                                   tensor::UnPackOp unPackOp) {
  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (unPackOp.isLikeUnPad()) {
    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  PackingMetadata packingMetadata;

  RankedTensorType stripMinedTensorType =
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGSNL());

  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),

  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  rewriter.replaceOp(unPackOp, copyOp->getResults());
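// PackedOperandsDimList helpers (fragments below): collect, for a given
// operand position, the packed dimensions and the corresponding pack sizes
// accumulated over all packing steps.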
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());

PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
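// linalg::pack (fragment below): implement packing of a single LinalgOp by
// `packedSizes`, which must provide one size per iteration loop; a packed
// size of 0 leaves the corresponding loop unpacked. Operands are packed with
// tensor.pack, the body is rebuilt as a linalg.generic on the packed
// operands, and packed inits are unpacked back with tensor.unpack.
//
// Hypothetical call site, for illustration only (the matmul handle and tile
// sizes are assumptions):
//
//   SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(8),
//                                      rewriter.getIndexAttr(16),
//                                      rewriter.getIndexAttr(32)};
//   FailureOr<PackResult> packed = pack(rewriter, matmulOp, sizes);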
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
             DBGSNL(););

  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    if (maybeConstant.has_value() && maybeConstant.value() == 0)

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
    if (failed(maybePackedDimForEachOperand))
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n";
        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
                              DBGS() << "packedDimForEachOperand: ");
        DBGSNL(););

  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(DBGS() << "operand: " << operand << "\n";
                 llvm::interleaveComma(innerPos, DBGS() << "innerPos: ");
                 DBGSNL();
                 llvm::interleaveComma(innerPackSizes,
                                       DBGS() << "innerPackSizes: ");
                 DBGSNL(););
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);

      Value dest = tensor::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
      if (areConstantTiles && operandType.hasStaticShape() &&
          !tensor::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));

        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));

      inputsAndInits.push_back(packOps.back());

  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,

  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    tensor::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<tensor::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);

    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());

      cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
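// Static helper (fragment below): replace one operand of `linalgOp` by an
// already-transposed value, composing the permutation into that operand's
// indexing map and rebuilding the op as a linalg.generic with the updated
// maps.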
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;

  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());

  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
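// linalg::packTranspose (fragment below): transpose a single
// tensor.pack -> LinalgOp -> tensor.unpack chain with the given outer and
// inner permutations, and return the transposed pack op, linalg op and
// (optional) unpack op.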
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  tensor::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();

  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));

  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }

      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  tensor::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);
    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  rewriter.replaceOp(packOp, transposedPackOp->getResults());
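// linalg::packMatmulGreedily (fragment below): pack a LinalgOp by greedily
// inferring matmul dimensions (m, n, k), where m and n are proper parallel
// dimensions and k is a reduction, then packing the op along those dimensions
// by mnkPackedSizes (optionally padded to the next multiple of
// mnkPaddedSizesNextMultipleOf).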
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
                      << numLoops << "\nin: " << linalgOp << "\n");
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  int64_t numPackedDims = mnkPackedSizes.size();
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
                      << "\n");
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LLVM_DEBUG(DBGS() << "Start packing generic op greedily with (m@" << mPos
                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
                    << "\n";);

  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, permutation);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);

  SmallVector<Range> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
                                   DBGS() << "paddedSizesNextMultipleOf: ");
             DBGSNL(););
  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
                                   [](Range r) { llvm::dbgs() << r.size; });
             DBGSNL(););

  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
                                   DBGS() << "adjustedPackedSizes: ");
             DBGSNL(););

  return pack(rewriter, genericOp, adjustedPackedSizes);
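// LinalgTilingOptions::setTileSizes (fragment below): set the
// tileSizeComputationFunction to return constants materialized from the given
// tile sizes.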
  assert(!tileSizeComputationFunction && "tile sizes already set");
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
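// The tensor.pad generalization helpers (fragments below): fill the
// destination with linalg.fill when the pad value is a constant, otherwise
// clone the pad op's region into a tensor.generate; the pad is then rewritten
// as tensor.empty + fill/generate + an insert_slice of the source.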
  auto padValue = padOp.getConstantPaddingValue();
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);

    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
        padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt())

  auto resultType = padOp.getResultType();
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
          padOp.getSource(), dim));
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    staticSizes.push_back(resultType.getDimSize(dim));

  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);

  auto sourceType = padOp.getSourceType();
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
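// The matchAndRewrite fragment below swaps a tensor.extract_slice of a
// tensor.pad by calling tensor::bubbleUpPadSlice, which takes the slice first
// and then performs the padding; an optional control function decides whether
// a zero-slice guard is generated.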
  if (!sliceOp.hasUnitStride())

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();

  bool zeroSliceGuard = true;
  if (std::optional<bool> control = controlFn(sliceOp))
    zeroSliceGuard = *control;

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))

  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
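// Helper (fragment below): return the source of the tensor.pack, padded up to
// the static inner tile sizes with tensor::createPadHighOp when the pack op
// carries a padding value; only static inner tiles are supported here.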
                                       tensor::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {

  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();
  assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
                      [](int64_t val) { return val == 1; }));

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();
  for (int64_t dim = 0; dim < inputRank; ++dim) {
    int64_t size = inputType.getDimSize(dim);
    if (!tileAndPosMapping.count(dim)) {
      paddedShape.push_back(size);

    std::optional<int64_t> tileSize =
        getConstantIntValue(tileAndPosMapping.lookup(dim));
    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
    paddedShape.push_back(tileSize.value());

                                 /*nofold=*/false, loc, builder);
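// Helpers (fragments below): compute the transpose permutations needed when
// decomposing pack/unpack ops whose tiled outer dimensions are all unit
// (cf. getPackInverseDestPerm / getUnPackInverseSrcPerm), dropping
// rank-reduced unit dimensions from the outer permutation.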
  constexpr int64_t kNonTiledMarker = -1;
      vec, [&](int64_t v) { return v != kNonTiledMarker; }));

  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);

    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);

  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  rankReducedOuterDimsPerm =
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
  perm.append(innerDims);
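// tensor.pack decomposition (fragment below): for packs whose inner tiles are
// static and whose tiled outer result dimensions are all 1, rewrite the pack
// as tensor.extract_slice + linalg.transpose + tensor.insert_slice.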
  if (llvm::any_of(packOp.getMixedTiles(),
                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
    return rewriter.notifyMatchFailure(packOp,
                                       "require inner tile sizes being static");
  }

  auto innerDimsPos = packOp.getInnerDimsPos();
  int64_t srcRank = packOp.getSourceRank();
  auto destShape = packOp.getDestType().getShape();
  if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
        return destShape[index] != 1;
      })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  auto inputShape = packOp.getSourceType().getShape();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (dimAndTileMapping.count(i)) {
          .value_or(ShapedType::kDynamic));
      readSizes.push_back(dimAndTileMapping[i]);

    if (ShapedType::isDynamic(inputShape[i])) {
      readSizes.push_back(
    } else {
      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
    }
    if (inputShape[i] != 1)
      readShape.push_back(inputShape[i]);

  Type elemType = packOp.getSourceType().getElementType();
  Value tile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, input, readOffsets, readSizes, readStrides);

      inputShape, innerDimsPos, packOp.getOuterDimsPerm());
  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););

  applyPermutationToVector<int64_t>(transpShape, perm);

  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);

  int64_t destRank = packOp.getDestRank();
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
      writeSizes, writeStrides);
  rewriter.replaceOp(packOp, insert.getResult());
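// tensor.unpack decomposition (fragment below): for unpacks whose tiled outer
// source dimensions are all 1, extract the inner tile, transpose it, extract
// the possibly partial tile and insert it into the destination.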
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
        return srcShape[index] != 1;
      })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i)) {
      readSizes.push_back(oneIdxAttr);

    if (ShapedType::isDynamic(srcShape[i])) {
      Value dynamicDim =
          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
      readSizes.push_back(dynamicDim);
      dynamicDims.push_back(dynamicDim);
    } else {
      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    if (srcShape[i] != 1)
      readShape.push_back(srcShape[i]);

  auto mixedTiles = unpackOp.getMixedTiles();
  readSizes.append(mixedTiles.begin(), mixedTiles.end());

  auto tileShape = srcShape.drop_front(destRank);
  readShape.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);

      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  applyPermutationToVector<int64_t>(transpShape, perm);

  Value empty =
      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);

  int numLoops = transpShape.size();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(

  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);

  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
      writeStrides);
  rewriter.replaceOp(unpackOp, insert.getResult());
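// DownscaleSizeOneWindowed2DConvolution (fragment below): rewrite 2-D
// convolution and pooling ops with a size-1 window dimension (and matching
// size-1 output dimension) into their 1-D counterparts by dropping the unit
// dimension from the operands.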
template <typename Conv2DOp, typename Conv1DOp>

  if (convOp.hasPureBufferSemantics())

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected conv2d/pool2d operation.");
            return std::make_tuple(0, 0, 0, 0);
          });

  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)

  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));

  auto conv1DOp = rewriter.create<Conv1DOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);

template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
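// DownscaleDepthwiseConv2DNhwcHwcOp (fragment below): rewrite 2-D depthwise
// convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D
// depthwise convolution ops.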
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)

  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
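// The fragment below applies the same size-1 window rewrite to a plain 2-D
// convolution (no batch or channel dimensions), producing a 1-D convolution.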
  if (convOp.hasPureBufferSemantics())

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)

  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,

  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);

                                                     PoolingNwcMaxUnsignedOp>,
                                                     PoolingNwcMinUnsignedOp>,
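// linalg::populateDecomposeConvolutionPatterns (fragment above) registers the
// downscale patterns in this file.
//
// Illustrative sketch only (not part of this file): applying them with the
// standard greedy driver could look roughly like this, where `funcOp` and
// `ctx` are assumed to be available.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));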