#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                   partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
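// Illustrative sketch (hypothetical bounds, not from a test): peeling
//   scf.for %iv = %c0 to %c10 step %c4 { ... }
// yields a main loop over [0, 8) (the largest multiple of the step) plus a
// separate partial iteration for [8, 10), so affine.min/max bounds in the
// main loop body simplify away.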
    // In hasAtMostOneResultFunctionOfDim: skip results independent of `dim`.
    if (!e.isFunctionOfDim(dim))
      continue;
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the `newDim` to the map whatever the case.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-one result of the map that is a function of `dim`.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    std::optional<int64_t> maybeOperandDimensionToPack =
        getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // We can only pack AffineDimExpr results atm.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}
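// Illustrative sketch of one metadata-packing step: given matmul-style maps
//   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n),
// packing dim `k` appends a new iterator `kk` and rewrites the maps whose
// results reference `k`:
//   (m, n, k, kk) -> (m, k, kk) and (m, n, k, kk) -> (k, n, kk),
// while the third map keeps its results (its packed dim stays std::nullopt).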
/// One packing step: the size being packed and, for each indexing map, the
/// map result (if any) that the step packs.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// The ordered list of packing steps applied to a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  // ...
private:
  SmallVector<PackedOperandsDim> spec;
};
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                             tensor::PackOp packOp,
                                             bool lowerPadLikeWithInsertSlice) {
  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  // ...
  // 2. Compute the permutation vector to shuffle the packed shape into the
  //    shape before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getPackInverseDestPerm(packOp);

  // 3. Compute the stripMinedShape: the packed shape before any outer or
  //    inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);
  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL();
      llvm::interleaveComma(packingMetadata.outerPositions,
                            DBGS() << "outerPositions: ");
      DBGSNL();
      llvm::interleaveComma(packedTensorType.getShape(),
                            DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL();
      llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(););
  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // This pack is just a plain pad: insert the padded source into the
    // higher-rank destination tensor.
    SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                    rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                   rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, packOp.getDest());

    auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
        /*offsets=*/zeros, sizes, /*strides=*/ones);

    LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

    rewriter.replaceOp(packOp, insertSliceOp->getResults());
    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                           /*transposeOp=*/nullptr};
  }
  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGSNL(); DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
             DBGSNL(););

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}
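// Illustrative sketch (hypothetical shapes, not from a test): lowering
//   %pack = tensor.pack %src padding_value(%cst : f32)
//       inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %dest
//       : tensor<17x25xf32> -> tensor<3x2x8x16xf32>
// produces approximately:
//   %padded = tensor.pad %src low[0, 0] high[7, 7] { ... }   // 24x32
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       : tensor<24x32xf32> into tensor<3x8x2x16xf32>
//   %transposed = linalg.transpose ins(%expanded) outs(%dest)
//       permutation = [0, 2, 1, 3]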
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  // ...
  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad: extract the slice from the
    // higher-rank source tensor.
    // ...
    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }
  // 1. Compute the permutation vector to move the last `numPackedDims` into
  //    the `innerPosDims` of a shape of rank `packedRank`.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: the packed shape without outer and inner
  //    permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Fetch the sizes of the source, permuted into strip-mined order.
  SmallVector<OpFoldResult> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL();
      llvm::interleaveComma(packedTensorType.getShape(),
                            DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL();
      llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(););
  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. Extract the destination-sized slice out of the collapsed result.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
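// Illustrative sketch (hypothetical shapes): lowering
//   %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dest : tensor<3x2x8x16xf32> -> tensor<17x25xf32>
// produces approximately:
//   %empty = tensor.empty() : tensor<3x8x2x16xf32>
//   %transposed = linalg.transpose ins(%src) outs(%empty)
//       permutation = [0, 2, 1, 3]
//   %collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3]]
//       : tensor<3x8x2x16xf32> into tensor<24x32xf32>
//   %slice = tensor.extract_slice %collapsed[0, 0] [17, 25] [1, 1]
//   linalg.copy ins(%slice) outs(%dest)   // keeps DPS intact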
SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
             DBGSNL(););
  SmallVector<tensor::PackOp> packOps;
  SmallVector<tensor::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n";
        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
                              DBGS() << "packedDimForEachOperand: ");
        DBGSNL(););
  }
  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(DBGS() << "operand: " << operand << "\n";
                 llvm::interleaveComma(innerPos, DBGS() << "innerPos: ");
                 DBGSNL();
                 llvm::interleaveComma(innerPackSizes,
                                       DBGS() << "innerPackSizes: ");
                 DBGSNL(););
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = tensor::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !tensor::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        // TODO: the value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }
  // Step 3. Build the packed op, using the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    tensor::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<tensor::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
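// Illustrative sketch (hypothetical sizes): packing a 128x256x512 matmul with
// packedSizes = {8, 16, 32} for (m, n, k) creates, modulo padding:
//   LHS  tensor<128x512xf32> -> tensor<16x16x8x32xf32>  (mm = 8,  kk = 32)
//   RHS  tensor<512x256xf32> -> tensor<16x16x16x32xf32> (nn = 16, kk = 32)
//   INIT tensor<128x256xf32> -> tensor<16x16x8x16xf32>  (mm = 8,  nn = 16)
// followed by a linalg.generic with iterators (m, n, k, mm, nn, kk) on the
// packed operands and a tensor.unpack of the packed init result.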
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Compute the transposed indexing map.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  tensor::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }
  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand. The leading part
  // just reuses the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // The trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  // Step 2.b. Save the operand number: `packUse` is invalidated once linalgOp
  // is replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 2.c. Actually perform the transposition.
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  tensor::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);
    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
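// Illustrative sketch: given a chain
//   tensor.pack -> linalg.generic -> tensor.unpack,
// packTranspose rewrites all three ops so that the pack/unpack carry the
// requested outerPerm/innerPerm, and the generic's indexing map for the
// packed operand is composed with the same permutation, preserving semantics.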
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
                      << numLoops << "\nin: " << linalgOp << "\n");
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }
  // Locally adjust the desired iterator position of mnk and the packing
  // sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer the dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
                      << "\n");
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }
  // 2. Normalize linalgOp to a kmn-matmul-like form with the (k, m, n)
  // iterators most minor. In cases with multiple candidates for m, n, k, bias
  // towards the most minor embedding.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LLVM_DEBUG(DBGSNL();
             DBGS() << "Start packing generic op greedily with (m@" << mPos
                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
                    << "\n";);

  // 2.a. Rewrite as a generic.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. This only normalizes the iteration order; it does not change
  // the indexing of any operand.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
  // Sign .. unsigned pollution.
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);

  // Compute the packed loop ranges.
  SmallVector<Range> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());
  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
                                   DBGS() << "paddedSizesNextMultipleOf: ");
             DBGSNL(););
  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
                                   [](Range r) { llvm::dbgs() << r.size; });
             DBGSNL(););

  // Adjust the packed sizes to the next requested multiple. Only the last
  // `numPackedDims` dimensions are packed post-interchange.
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
                                   DBGS() << "adjustedPackedSizes: ");
             DBGSNL(););

  return pack(rewriter, genericOp, adjustedPackedSizes);
}
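// Illustrative sketch (hypothetical sizes, assuming mnkOrder = {0, 1, 2}): for
// a 130x260x500 matmul with mnkPackedSizes = {8, 16, 32} and
// mnkPaddedSizesNextMultipleOf = {0, 0, 64}, m and n keep their tile sizes
// while the adjusted k size becomes ceilDiv(500, 64) * 64 = 512, i.e. the
// whole k dimension padded up to the next multiple of 64, before delegating
// to pack().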
  // In LinalgTilingOptions::setTileSizes: wrap the constant tile sizes in a
  // computation function that materializes them at the start of the enclosing
  // function.
  assert(!tileSizeComputationFunction && "tile sizes already set");
  // ...
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to the new op.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}
LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the size of the EmptyOp. Any combination of static/dynamic is
  // supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate an InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);
  return success();
}
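// Illustrative sketch (hypothetical shapes): the pattern rewrites
//   %0 = tensor.pad %src low[1, 2] high[2, 3] { tensor.yield %cst }
//       : tensor<4x5xf32> to tensor<7x10xf32>
// into
//   %empty = tensor.empty() : tensor<7x10xf32>
//   %fill  = linalg.fill ins(%cst : f32) outs(%empty : tensor<7x10xf32>)
//   %res   = tensor.insert_slice %src into %fill[1, 2] [4, 5] [1, 1]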
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();
  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
  return success();
}
/// If a padding value is set, return a tensor.pad of the source with the
/// output shape of `packOp`; otherwise return the source directly. Assumes
/// all outer dims of `packOp` are 1.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           tensor::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles.
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims: preserve them (they are 1 by assumption).
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims: as all outer dims are 1, use the tile size.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1 Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2 Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  return llvm::to_vector(llvm::make_filter_range(
      vec, [&](int64_t v) { return v != kNonTiledMarker; }));
}

static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm, outerDims, innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;
  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the inner-most dims after packing.
  perm.append(innerDims);
  return perm;
}
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
    tensor::PackOp packOp, PatternRewriter &rewriter) const {
  if (llvm::any_of(packOp.getAllOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }

  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  Location loc = packOp.getLoc();

  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  int64_t srcRank = packOp.getSourceRank();
  int64_t destRank = packOp.getDestRank();
  int64_t numTiles = destRank - srcRank;

  if (!llvm::all_of(packOp.getInnerDimsPos(),
                    [&srcRank, &numTiles](int64_t dimPos) {
                      return dimPos >= (srcRank - numTiles - 1);
                    }))
    return rewriter.notifyMatchFailure(
        packOp, "Attempting to tile non-trailing source dims!");
  // 1. Extract the inner tile sizes, replacing Values with constant
  // Attributes where possible.
  SmallVector<OpFoldResult> tileSizes;
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (dimAndTileMapping.count(i)) {
      // Extract the constant value Attribute where possible, e.g.:
      //    [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
      auto [_, tileSize] =
          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
      tileSizes.push_back(tileSize);
    }
  }

  // 2. Transpose the input so that the inner tile dims are trailing, in
  // inner_dims_pos order.
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < (srcRank - numTiles); i++)
    srcPermForTranspose.push_back(i);
  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));

  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
             DBGSNL(););

  // 2.1 Create tensor.empty (the init value for the TransposeOp).
  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
                                                 oneIdxAttr);
  transShapeForEmptyOp.append(tileSizes);
  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
                                         srcPermForTranspose);
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());

  // 2.2 Create linalg.transpose.
  auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
                                                           srcPermForTranspose);
  // 3. Insert the inner tile into the destination. Outer dims are all 1s, so
  // the offsets are all 0s and the sizes are (1, ..., 1, tile sizes).
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
                                       oneIdxAttr);
  SmallVector<int64_t> writeShape;
  for (auto tileSize : packOp.getMixedTiles()) {
    auto [tileSizeStatic, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
    writeShape.push_back(tileSizeStatic);
  }

  // 4. Replace the tensor.pack with the tensor.insert_slice created above.
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
      writeSizes, writeStrides);
  rewriter.replaceOp(packOp, insert.getResult());
  return success();
}
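// Illustrative sketch (hypothetical shapes): with all outer dims == 1,
//   %pack = tensor.pack %src inner_dims_pos = [1, 0] inner_tiles = [4, 5]
//       into %dest : tensor<5x4xf32> -> tensor<1x1x4x5xf32>
// decomposes into
//   %empty = tensor.empty() : tensor<4x5xf32>
//   %transposed = linalg.transpose ins(%src) outs(%empty) permutation = [1, 0]
//   tensor.insert_slice %transposed into %dest[0, 0, 0, 0] [1, 1, 4, 5]
//       [1, 1, 1, 1]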
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use a rank-reduced tensor.extract_slice to extract the tile.
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);

  SmallVector<int64_t> readShapeForExtractSlice;
  SmallVector<OpFoldResult> extractSliceSizes;
  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
  SmallVector<OpFoldResult> shapeForEmptyOp;

  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    // Tiled outer dims are all 1s and are rank-reduced away.
    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);
      continue;
    }

    // Compute the sizes attribute for ExtractSliceOp + EmptyOp.
    if (ShapedType::isDynamic(srcShape[i])) {
      Value dynamicDim =
          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);
    } else {
      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    // Compute the shape of the result of ExtractSliceOp.
    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);
    }
  }

  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
  // The tile sizes occupy the trailing dims of the source.
  auto tileShape = srcShape.drop_front(destRank);
  // Append the tile sizes to the shape of the extracted slice.
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), extractSliceOffsets,
      extractSliceSizes, extractSliceStrides);

  // 2. Transpose the tile to match the outer and inner dims order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space, so invert the permutation.
  perm = invertPermutationVector(perm);
  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);

  Value empty =
      rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);

  // 3. Handle incomplete tiles by truncating the trailing data to the
  // destination size.
  int numLoops = shapeForEmptyOp.size();
  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
  SmallVector<OpFoldResult> tileSizes;
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }

  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);

  // 4. Insert the result into the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
      writeStrides);
  rewriter.replaceOp(unpackOp, insert.getResult());
  return success();
}
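// Illustrative sketch (hypothetical shapes): with unit tiled outer dims,
//   %unpack = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [32, 8]
//       into %dest : tensor<1x1x32x8xf32> -> tensor<32x8xf32>
// decomposes into a rank-reduced tensor.extract_slice of the 32x8 inner tile,
// a linalg.transpose (the identity in this case), and a tensor.insert_slice
// into %dest.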
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();
  // Get domain indices based on the conv2D layout.
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Default([](Operation *op) {
            llvm_unreachable("unexpected conv2d/pool2d operation.");
            return std::make_tuple(0, 0, 0, 0);
          });
  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<Conv1DOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);
  return conv1DOp;
}

template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);
  return conv1DOp;
}
FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                            PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
                                            ValueRange{newInput, newKernel},
                                            ValueRange{newOutput});

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);
  return conv1DOp;
}
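// Illustrative sketch: a convolution whose H window is statically unit, e.g.
//   linalg.conv_2d_nhwc_hwcf ins(%in : tensor<?x1x?x?xf32>,
//                                %ker : tensor<1x?x?x?xf32>) ...
// is rewritten into rank-reducing extract_slices of the operands, a
// linalg.conv_1d_nwc_wcf over the remaining W dimension, and an insert_slice
// of the 1-D result back into the original output tensor.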
  // In populateDecomposeConvolutionPatterns: register the unsigned pooling
  // downscaling patterns alongside the others.
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                            PoolingNwcMaxUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                            PoolingNwcMinUnsignedOp>,