33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/InterleavedRange.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <type_traits>
41 #define DEBUG_TYPE "linalg-transforms"
46 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
47 #define DBGSNL() (llvm::dbgs() << "\n")
63 .Case<scf::ForOp>([&](scf::ForOp forOp) {
64 scf::ForOp partialIteration;
67 return partialIteration->getResults();
68 assert(!partialIteration &&
"expected that loop was not peeled");
69 return forOp->getResults();
78 for (
auto loopOp : loops)
91 if (!e.isFunctionOfDim(dim))
101 return llvm::interleaved(ri,
", ",
"|",
"");
153 static FailureOr<SmallVector<std::optional<int64_t>>>
157 int64_t newDim = iteratorTypes.size();
158 iteratorTypes.push_back(iteratorTypes[dim]);
161 indexingMaps.size(), std::nullopt);
163 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
165 AffineMap map = indexingMaps[operandIdx];
168 assert(map.
getNumDims() == newDim &&
"num dims invariant violation");
176 "num results invariant violation");
178 if (!maybeOperandDimensionToPack.has_value()) {
179 newMaps.push_back(map);
184 if (!isa<AffineDimExpr>(map.
getResult(maybeOperandDimensionToPack.value())))
190 newMaps.push_back(map);
193 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
195 indexingMaps = newMaps;
197 return packedDimPerIndexingMap;
203 struct PackedOperandsDim {
209 struct PackedOperandsDimList {
210 void pushBack(PackedOperandsDim &&packedOperandsDims) {
211 spec.emplace_back(packedOperandsDims);
225 linalg::PackOp packOp,
226 bool lowerPadLikeWithInsertSlice) {
228 auto packedTensorType =
229 cast<RankedTensorType>(packOp->getResultTypes().front());
230 if (llvm::any_of(packOp.getStaticInnerTiles(),
231 [](int64_t size) { return ShapedType::isDynamic(size); })) {
234 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
243 PackingMetadata packingMetadata = computePackingMetadata(
244 packedTensorType.getRank(), packOp.getInnerDimsPos());
258 for (
auto [pos, innerSize] :
259 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
261 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
271 rewriter, loc, map, {outerSize, origSize, innerSize});
273 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
275 packingMetadata.reassociations);
276 Value paddingValue = packOp.getPaddingValue();
278 paddingValue = rewriter.
create<arith::ConstantOp>(
282 rewriter.
create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
283 highs, paddingValue,
false);
287 DBGS() <<
"insertPositions: "
288 << llvm::interleaved(packingMetadata.insertPositions);
290 << llvm::interleaved(packingMetadata.outerPositions);
292 << llvm::interleaved(packedTensorType.getShape());
293 DBGSNL();
DBGS() <<
"packedToStripMinedShapePerm: "
294 << llvm::interleaved(packedToStripMinedShapePerm);
296 DBGS() <<
"reassociations: "
297 << llvm::interleaved(llvm::map_range(
300 DBGS() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
303 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
322 auto insertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
323 loc, padOp, packOp.getDest(),
326 LLVM_DEBUG(
DBGS() <<
"insert_slice op: " << insertSliceOp;
DBGSNL(););
328 rewriter.
replaceOp(packOp, insertSliceOp->getResults());
336 auto expandShapeResultType =
338 auto reshapeOp = rewriter.
create<tensor::ExpandShapeOp>(
339 loc, expandShapeResultType, padOp.getResult(),
340 packingMetadata.reassociations);
345 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
346 loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
349 DBGS() <<
"reshape op: " << reshapeOp;
DBGSNL();
350 DBGS() <<
"transpPerm: " << llvm::interleaved(transpPerm);
354 rewriter.
replaceOp(packOp, transposeOp->getResults());
359 FailureOr<LowerUnPackOpResult>
361 bool lowerUnpadLikeWithExtractSlice) {
366 RankedTensorType packedTensorType = unPackOp.getSourceType();
367 int64_t packedRank = packedTensorType.getRank();
370 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
371 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
380 auto extractSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
381 loc, destTensorType, unPackOp.getSource(),
385 rewriter.
replaceOp(unPackOp, extractSliceOp->getResults());
388 nullptr, extractSliceOp};
393 PackingMetadata packingMetadata;
403 RankedTensorType stripMinedTensorType =
405 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
406 stripMinedTensorType, packingMetadata.reassociations);
413 auto emptyOp = rewriter.
create<tensor::EmptyOp>(
414 loc, dims, stripMinedTensorType.getElementType());
415 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
416 loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
420 DBGS() <<
"insertPositions: "
421 << llvm::interleaved(packingMetadata.insertPositions);
423 << llvm::interleaved(packedTensorType.getShape());
424 DBGSNL();
DBGS() <<
"packedToStripMinedShapePerm: "
425 << llvm::interleaved(packedToStripMinedShapePerm);
427 DBGS() <<
"reassociations: "
428 << llvm::interleaved(llvm::map_range(
431 DBGS() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
435 auto reshapeOp = rewriter.
create<tensor::CollapseShapeOp>(
436 loc, collapsedType, transposeOp->getResult(0),
437 packingMetadata.reassociations);
440 int64_t destRank = destTensorType.getRank();
441 auto extractSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
442 loc, destTensorType, reshapeOp->getResult(0),
448 auto copyOp = rewriter.
create<linalg::CopyOp>(
449 loc, extractSliceOp->getResult(0), unPackOp.getDest());
452 rewriter.
replaceOp(unPackOp, copyOp->getResults());
458 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
460 for (
auto &i : spec) {
461 if (!i.packedDimForEachOperand[operandPos].has_value())
463 res.push_back(i.packedDimForEachOperand[operandPos].value());
469 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
471 for (
auto &i : spec) {
472 if (!i.packedDimForEachOperand[operandPos].has_value())
474 res.push_back(i.packedSize);
483 linalg::LinalgOp linalgOp,
485 if (packedSizes.size() != linalgOp.getNumLoops()) {
487 "incorrect number of pack sizes");
493 linalgOp.getIteratorTypesArray();
494 LLVM_DEBUG(
DBGS() <<
"Start packing: " << linalgOp <<
"\n"
495 <<
"maps: " << llvm::interleaved(indexingMaps) <<
"\n"
496 <<
"iterators: " << llvm::interleaved(iteratorTypes)
502 PackedOperandsDimList listOfPackedOperandsDim;
503 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
506 if (maybeConstant.has_value() && maybeConstant.value() == 0)
509 PackedOperandsDim packedOperandsDims;
510 packedOperandsDims.packedSize = packedSizes[i];
511 FailureOr<SmallVector<std::optional<int64_t>>>
512 maybePackedDimForEachOperand =
514 if (failed(maybePackedDimForEachOperand))
516 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
517 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
520 DBGS() <<
"++++ After pack size #" << i <<
": " << packedSizes[i]
522 <<
"maps: " << llvm::interleaved(indexingMaps) <<
"\n"
523 <<
"iterators: " << llvm::interleaved(iteratorTypes) <<
"\n"
524 <<
"packedDimForEachOperand: "
525 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
532 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
534 for (
const auto &operandsList : {inputOperands, initOperands}) {
535 for (
OpOperand *opOperand : operandsList) {
536 int64_t pos = opOperand->getOperandNumber();
537 Value operand = opOperand->get();
539 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
541 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
542 LLVM_DEBUG(
DBGS() <<
"operand: " << operand <<
"\n"
543 <<
"innerPos: " << llvm::interleaved(innerPos) <<
"\n"
544 <<
"innerPackSizes: "
545 << llvm::interleaved(innerPackSizes) <<
"\n");
546 if (innerPackSizes.empty()) {
547 inputsAndInits.push_back(operand);
550 Value dest = linalg::PackOp::createDestinationTensor(
551 rewriter, loc, operand, innerPackSizes, innerPos,
553 ShapedType operandType = cast<ShapedType>(operand.
getType());
554 bool areConstantTiles =
558 if (areConstantTiles && operandType.hasStaticShape() &&
559 !linalg::PackOp::requirePaddingValue(
560 operandType.getShape(), innerPos,
561 cast<ShapedType>(dest.
getType()).getShape(), {},
563 packOps.push_back(rewriter.
create<linalg::PackOp>(
564 loc, operand, dest, innerPos, innerPackSizes));
570 Value zero = rewriter.
create<arith::ConstantOp>(loc, zeroAttr);
571 packOps.push_back(rewriter.
create<linalg::PackOp>(
572 loc, operand, dest, innerPos, innerPackSizes, zero));
574 inputsAndInits.push_back(packOps.back());
580 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
582 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
583 auto packedLinalgOp = rewriter.
create<linalg::GenericOp>(
584 linalgOp.getLoc(), inits.
getTypes(), inputs, inits, indexingMaps,
589 for (
OpResult result : packedLinalgOp->getResults()) {
590 int64_t resultNum = result.getResultNumber();
591 linalg::PackOp maybePackedInit =
592 inits[resultNum].getDefiningOp<linalg::PackOp>();
593 if (!maybePackedInit) {
594 results.push_back(result);
598 unPackOps.push_back(rewriter.
create<linalg::UnPackOp>(
599 packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
600 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
601 results.push_back(unPackOps.back());
609 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
638 assert(linalgOp == opOperand.
getOwner() &&
"linalg op must own the operand");
642 cast<RankedTensorType>(opOperand.
get().
getType()), permutation);
644 assert(tensorType == transposedValue.
getType() &&
645 "expected tensor type mismatch");
650 llvm::map_range(permutation, [](int64_t i) ->
unsigned {
return i; }));
654 permutationMap.
compose(linalgOp.getMatchingIndexingMap(&opOperand));
658 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
664 auto transposedGenericOp = rewriter.
create<linalg::GenericOp>(
667 operandsRef.drop_front(linalgOp.getNumDpsInputs()).
getTypes(),
668 operandsRef.take_front(linalgOp.getNumDpsInputs()),
669 operandsRef.drop_front(linalgOp.getNumDpsInputs()),
671 linalgOp.getIteratorTypesArray());
673 rewriter.
replaceOp(linalgOp, transposedGenericOp->getResults());
675 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
678 FailureOr<PackTransposeResult>
680 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
687 linalg::PackOp transposedPackOp =
688 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
690 if (!packOp.getResult().hasOneUse())
693 OpOperand &packUse = *packOp->getUses().begin();
694 if (packUse.
getOwner() != linalgOp) {
696 linalgOp,
"not a single use by the LinalgOp target");
699 (!linalgOp.isDpsInit(&packUse) ||
700 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
702 "not produced by the LinalgOp target");
708 int64_t numLeadingDims = packOp.getSourceRank();
709 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
713 if (permutation.empty())
714 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
716 if (innerPerm.empty()) {
719 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
721 llvm::append_range(permutation,
722 llvm::map_range(innerPerm, [&](int64_t pos) {
723 return numLeadingDims + pos;
735 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
738 linalg::UnPackOp transposedUnPackOp;
741 transposedLinalgOp->getOpOperand(packUseOperandNumber);
742 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
744 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
745 rewriter, loc, transposedResult, innerPerm, outerPerm);
747 rewriter.
replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
751 rewriter.
replaceOp(packOp, transposedPackOp->getResults());
769 FailureOr<PackResult>
774 assert(mnkPackedSizes.size() == 3 &&
"unexpected num of packing sizes");
775 assert((mnkPaddedSizesNextMultipleOf.empty() ||
776 mnkPaddedSizesNextMultipleOf.size() == 3) &&
777 "num of packing sizes next multiple should be empty or of size 3");
778 assert(mnkOrder.size() == 3 &&
"unexpected mnkOrder size");
781 int64_t numLoops = linalgOp.getNumLoops();
783 LLVM_DEBUG(
DBGS() <<
"need 3+ loops to find a matmul to pack, got "
784 << numLoops <<
"\nin: " << linalgOp <<
"\n");
786 linalgOp,
"need 3+ loops to find a matmul to pack");
790 int64_t numPackedDims = mnkPackedSizes.size();
792 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
793 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
795 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
796 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
798 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
799 paddedSizesNextMultipleOf[mnkOrder[i]] =
800 mnkPaddedSizesNextMultipleOf.empty() ? 0
801 : mnkPaddedSizesNextMultipleOf[i];
805 FailureOr<ContractionDimensions> maybeDimensions =
807 if (failed(maybeDimensions)) {
808 LLVM_DEBUG(
DBGS() <<
"couldn't infer matmul iterators in: " << linalgOp
811 "couldn't infer matmul iterators");
819 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
820 kPos = maybeDimensions->k.back();
822 DBGS() <<
"Start packing generic op greedily with (m@" << mPos
823 <<
", n@" << nPos <<
", k@" << kPos <<
"): " << linalgOp
827 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
829 FailureOr<GenericOp> generalizeResult =
831 assert(succeeded(generalizeResult) &&
"unexpected failure generalizing op");
832 genericOp = *generalizeResult;
840 LLVM_DEBUG(
DBGS() <<
"perm: " << llvm::interleaved(permutation) <<
"\n");
843 FailureOr<GenericOp> interchangeResult =
845 assert(succeeded(interchangeResult) &&
"unexpected failure interchanging op");
846 genericOp = *interchangeResult;
847 LLVM_DEBUG(
DBGS() <<
"Generalized Op to pack: " << genericOp <<
"\n";);
864 cast<LinalgOp>(genericOp.getOperation())
865 .createLoopRanges(rewriter, genericOp.getLoc());
869 LLVM_DEBUG(
DBGS() <<
"paddedSizesNextMultipleOf: "
870 << llvm::interleaved(paddedSizesNextMultipleOf) <<
"\n"
872 << llvm::interleaved(llvm::map_range(
873 loopRanges, [](
Range r) {
return r.size; }))
877 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
878 if (paddedSizesNextMultipleOf[i] == 0) {
879 adjustedPackedSizes.push_back(packedSizes[i]);
886 rewriter, genericOp->getLoc(), d0.
ceilDiv(s0) * s0,
887 {loopRanges[adjustedPackedSizes.size()].size,
888 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
890 LLVM_DEBUG(
DBGS() <<
"adjustedPackedSizes: "
891 << llvm::interleaved(adjustedPackedSizes) <<
"\n");
897 return pack(rewriter, genericOp, adjustedPackedSizes);
906 assert(!tileSizeComputationFunction &&
"tile sizes already set");
911 &op->getParentOfType<func::FuncOp>().getBody().front());
912 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
930 auto padValue = padOp.getConstantPaddingValue();
933 if (padValue.getParentBlock() == &padOp.getRegion().front())
935 return rewriter.
create<FillOp>(padOp.getLoc(), padValue, dest).result();
939 auto generateOp = rewriter.
create<tensor::GenerateOp>(
940 padOp.getLoc(), padOp.getResultType(), dynSizes);
943 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
952 if (
auto val = llvm::dyn_cast_if_present<Value>(ofr))
956 padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
960 auto resultType = padOp.getResultType();
964 for (
unsigned dim = 0; dim < resultType.getRank(); ++dim) {
965 if (resultType.isDynamicDim(dim)) {
967 padOp.getSource(), dim));
970 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
972 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
973 dynSizes.push_back(plusHigh);
975 staticSizes.push_back(resultType.getDimSize(dim));
979 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
980 padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
984 auto sourceType = padOp.getSourceType();
992 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
1000 if (!sliceOp.hasUnitStride())
1003 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
1007 bool zeroSliceGuard =
true;
1009 if (std::optional<bool> control = controlFn(sliceOp))
1010 zeroSliceGuard = *control;
1015 FailureOr<TilingResult> tilingResult =
1017 sliceOp.getMixedSizes(), zeroSliceGuard);
1018 if (failed(tilingResult))
1022 rewriter.
replaceOp(sliceOp, tilingResult->tiledValues);
1032 linalg::PackOp packOp) {
1033 Value input = packOp.getSource();
1034 if (!packOp.getPaddingValue()) {
1038 assert(llvm::all_of(packOp.getAllOuterDims(),
1039 [](int64_t val) { return val == 1; }) &&
1040 "some outer dims are != 1");
1043 ShapedType inputType = packOp.getSourceType();
1044 int64_t inputRank = inputType.getRank();
1047 packOp.getDimAndTileMapping();
1054 for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1057 if (!tileAndPosMapping.count(dimIdx)) {
1058 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1059 assert(inputDimSize == 1 &&
1060 "with all outer dims == 1, this non-tiled input dim should be 1!");
1061 paddedShape.push_back(inputDimSize);
1068 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1072 if (cstTileSize.has_value()) {
1073 paddedShape.push_back(cstTileSize.value());
1078 paddedShape.push_back(ShapedType::kDynamic);
1081 dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1086 false, loc, builder,
1096 constexpr int64_t kNonTiledMarker = -1;
1101 vec, [&](int64_t v) {
return v != kNonTiledMarker; });
1116 int64_t unpackedRank = shape.size();
1117 for (
auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1119 innerDims.push_back(dim++);
1124 outerDims.push_back(dim++);
1132 applyPermutationToVector<int64_t>(innerDims, innerPerm);
1137 rankReducedOuterDimsPerm =
1139 if (!rankReducedOuterDimsPerm.empty())
1140 applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1143 perm.append(innerDims);
1152 if (llvm::any_of(packOp.getAllOuterDims(),
1153 [](int64_t dim) { return dim != 1; })) {
1155 packOp,
"not all outer dimensions of the result are 1s");
1164 packOp.getDimAndTileMapping();
1165 int64_t srcRank = packOp.getSourceRank();
1166 int64_t destRank = packOp.getDestRank();
1167 int64_t numTiles = destRank - srcRank;
1169 if (!llvm::all_of(packOp.getInnerDimsPos(),
1170 [&srcRank, &numTiles](int64_t dimPos) {
1171 return dimPos >= (srcRank - numTiles - 1);
1174 packOp,
"Attempting to tile non-trailing source dims!");
1180 for (
auto i : llvm::seq<unsigned>(0, srcRank)) {
1181 if (dimAndTileMapping.count(i)) {
1185 auto [_, tileSize] =
1187 tileSizes.push_back(tileSize);
1201 for (int64_t i = 0; i < (srcRank - numTiles); i++)
1202 srcPermForTranspose.push_back(i);
1206 LLVM_DEBUG(
DBGS() <<
"Pack permutation: " << packOp <<
"\n"
1207 <<
"perm: " << llvm::interleaved(srcPermForTranspose)
1213 transShapeForEmptyOp.append(tileSizes);
1215 applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1216 srcPermForTranspose);
1218 loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
1221 auto transposedOp = rewriter.
create<linalg::TransposeOp>(loc, input, empty,
1222 srcPermForTranspose);
1233 for (
auto tileSize : packOp.getMixedTiles()) {
1234 auto [tileSizeStatic, tileSizeOfr] =
1236 writeSizes.push_back(tileSizeOfr);
1237 writeShape.push_back(tileSizeStatic);
1241 auto insert = rewriter.
create<tensor::InsertSliceOp>(
1242 loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1243 writeSizes, writeStrides);
1244 rewriter.
replaceOp(packOp, insert.getResult());
1251 int64_t srcRank = unpackOp.getSourceRank();
1252 int64_t destRank = unpackOp.getDestRank();
1255 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1256 [](int64_t dim) { return dim != 1; })) {
1259 "require the tiled outer dimensions of the result are all 1s");
1265 Value source = unpackOp.getSource();
1267 unpackOp.getDimAndTileMapping();
1290 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1299 if (dimAndTileMapping.count(i)) {
1300 extractSliceSizes.push_back(oneIdxAttr);
1306 if (ShapedType::isDynamic(srcShape[i])) {
1308 rewriter.
create<tensor::DimOp>(loc, source, i).getResult();
1309 extractSliceSizes.push_back(dynamicDim);
1310 shapeForEmptyOp.push_back(dynamicDim);
1312 extractSliceSizes.push_back(rewriter.
getIndexAttr(srcShape[i]));
1313 if (srcShape[i] != 1)
1314 shapeForEmptyOp.push_back(rewriter.
getIndexAttr(srcShape[i]));
1318 if (srcShape[i] != 1) {
1319 readShapeForExtractSlice.push_back(srcShape[i]);
1324 auto mixedTiles = unpackOp.getMixedTiles();
1325 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1326 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1330 auto tileShape = srcShape.drop_front(destRank);
1332 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1333 Type elemType = unpackOp.getSourceType().getElementType();
1335 Value innerTile = rewriter.
create<tensor::ExtractSliceOp>(
1336 loc, readType, unpackOp.getSource(), extractSliceOffsets,
1337 extractSliceSizes, extractSliceStrides);
1341 srcShape.take_front(destRank),
innerDimsPos, unpackOp.getOuterDimsPerm());
1344 applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1347 rewriter.
create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
1349 rewriter.
create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1353 int numLoops = shapeForEmptyOp.size();
1358 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1359 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1360 tileSizes.push_back(
1364 auto partialTile = rewriter.
create<tensor::ExtractSliceOp>(
1365 loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1371 for (
int i = 0, idx = 0; i < destRank; ++i) {
1372 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1373 writeSizes.push_back(tileSizes[idx++]);
1375 writeSizes.push_back(oneIdxAttr);
1377 auto insert = rewriter.
create<tensor::InsertSliceOp>(
1378 loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1380 rewriter.
replaceOp(unpackOp, insert.getResult());
1393 template <
typename Conv2DOp,
typename Conv1DOp>
1396 if (convOp.hasPureBufferSemantics())
1399 Value input = convOp.getInputs().front();
1400 Value kernel = convOp.getInputs().back();
1401 Value output = convOp.getOutputs().front();
1403 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1404 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1405 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1407 auto kernelShape = kernelType.getShape();
1408 auto outputShape = outputType.getShape();
1411 auto [khIndex, kwIndex, ohIndex, owIndex] =
1414 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1415 return std::make_tuple(0, 1, 1, 2);
1417 .Case([&](linalg::Conv2DNchwFchwOp op) {
1418 return std::make_tuple(2, 3, 2, 3);
1420 .Case([&](linalg::PoolingNhwcSumOp op) {
1421 return std::make_tuple(0, 1, 1, 2);
1423 .Case([&](linalg::PoolingNchwSumOp op) {
1424 return std::make_tuple(0, 1, 2, 3);
1426 .Case([&](linalg::PoolingNhwcMaxOp op) {
1427 return std::make_tuple(0, 1, 1, 2);
1429 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1430 return std::make_tuple(0, 1, 1, 2);
1432 .Case([&](linalg::PoolingNhwcMinOp op) {
1433 return std::make_tuple(0, 1, 1, 2);
1435 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1436 return std::make_tuple(0, 1, 1, 2);
1438 .Case([&](linalg::PoolingNchwMaxOp op) {
1439 return std::make_tuple(0, 1, 2, 3);
1442 llvm_unreachable(
"unexpected conv2d/pool2d operation.");
1443 return std::make_tuple(0, 0, 0, 0);
1448 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1449 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1450 bool removeH = (khSize == 1 && ohSize == 1);
1451 bool removeW = (kwSize == 1 && owSize == 1);
1452 if (!removeH && !removeW)
1458 RankedTensorType newInputType =
1459 RTTBuilder(inputType).
dropDim((removeH ? ohIndex : owIndex));
1460 RankedTensorType newKernelType =
1461 RTTBuilder(kernelType).
dropDim((removeH ? khIndex : kwIndex));
1462 RankedTensorType newOutputType =
1463 RTTBuilder(outputType).
dropDim((removeH ? ohIndex : owIndex));
1468 rewriter, loc, input, newInputType);
1470 rewriter, loc, kernel, newKernelType);
1472 rewriter, loc, output, newOutputType);
1477 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1478 strides.erase(strides.begin() + (removeH ? 0 : 1));
1482 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1483 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1486 auto conv1DOp = rewriter.
create<Conv1DOp>(
1487 loc, newOutputType,
ValueRange{newInput, newKernel},
1488 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1492 rewriter, loc, conv1DOp.getResult(0), output);
1509 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1513 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1517 FailureOr<DepthwiseConv1DNwcWcOp>
1520 if (convOp.hasPureBufferSemantics())
1523 Value input = convOp.getInputs().front();
1524 Value kernel = convOp.getInputs().back();
1525 Value output = convOp.getOutputs().front();
1527 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1528 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1529 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1531 auto kernelShape = kernelType.getShape();
1532 auto outputShape = outputType.getShape();
1536 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1537 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1538 bool removeH = (khSize == 1 && ohSize == 1);
1539 bool removeW = (kwSize == 1 && owSize == 1);
1540 if (!removeH && !removeW)
1546 RankedTensorType newInputType =
1547 RTTBuilder(inputType).
dropDim((removeH ? 1 : 2));
1548 RankedTensorType newKernelType =
1549 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1550 RankedTensorType newOutputType =
1551 RTTBuilder(outputType).
dropDim(removeH ? 1 : 2);
1556 rewriter, loc, input, newInputType);
1558 rewriter, loc, kernel, newKernelType);
1560 rewriter, loc, output, newOutputType);
1564 auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1565 strides.erase(strides.begin() + (removeH ? 0 : 1));
1569 llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1570 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1573 auto conv1DOp = rewriter.
create<DepthwiseConv1DNwcWcOp>(
1574 loc, newOutputType,
ValueRange{newInput, newKernel},
1575 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1579 rewriter, loc, conv1DOp.getResult(0), output);
1588 if (convOp.hasPureBufferSemantics())
1591 Value input = convOp.getInputs().front();
1592 Value kernel = convOp.getInputs().back();
1593 Value output = convOp.getOutputs().front();
1595 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1596 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1597 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1599 auto kernelShape = kernelType.getShape();
1600 auto outputShape = outputType.getShape();
1604 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1605 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1606 bool removeH = (khSize == 1 && ohSize == 1);
1607 bool removeW = (kwSize == 1 && owSize == 1);
1608 if (!removeH && !removeW)
1614 RankedTensorType newInputType =
1615 RTTBuilder(inputType).
dropDim((removeH ? 0 : 1));
1616 RankedTensorType newKernelType =
1617 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1618 RankedTensorType newOutputType =
1619 RTTBuilder(outputType).
dropDim(removeH ? 0 : 1);
1624 rewriter, loc, input, newInputType);
1626 rewriter, loc, kernel, newKernelType);
1628 rewriter, loc, output, newOutputType);
1630 auto conv1DOp = rewriter.
create<Conv1DOp>(loc, newOutputType,
1636 rewriter, loc, conv1DOp.getResult(0), output);
1655 PoolingNwcMaxUnsignedOp>,
1658 PoolingNwcMinUnsignedOp>,
SmallVector< int64_t > outerDimsPerm
SmallVector< int64_t > innerDimsPos
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Builder & setShape(ArrayRef< int64_t > newShape)
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp. This function uses the helper functio...
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and apply affine_min/max bounds simplification on the fly where relevant.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, SmallVector< Value > dynOutDims={})
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
ArrayRef< int64_t > ReassociationIndicesRef
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of rank-reduced.
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
FailureOr< DepthwiseConv1DNwcWcOp > returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.