33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/InterleavedRange.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include <type_traits>
41 #define DEBUG_TYPE "linalg-transforms"
46 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
47 #define DBGSNL() (llvm::dbgs() << "\n")
63 .Case<scf::ForOp>([&](scf::ForOp forOp) {
64 scf::ForOp partialIteration;
67 return partialIteration->getResults();
68 assert(!partialIteration &&
"expected that loop was not peeled");
69 return forOp->getResults();
78 for (
auto loopOp : loops)
91 if (!e.isFunctionOfDim(dim))
101 return llvm::interleaved(ri,
", ",
"|",
"");
153 static FailureOr<SmallVector<std::optional<int64_t>>>
157 int64_t newDim = iteratorTypes.size();
158 iteratorTypes.push_back(iteratorTypes[dim]);
161 indexingMaps.size(), std::nullopt);
163 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
165 AffineMap map = indexingMaps[operandIdx];
168 assert(map.
getNumDims() == newDim &&
"num dims invariant violation");
176 "num results invariant violation");
178 if (!maybeOperandDimensionToPack.has_value()) {
179 newMaps.push_back(map);
184 if (!isa<AffineDimExpr>(map.
getResult(maybeOperandDimensionToPack.value())))
190 newMaps.push_back(map);
193 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
195 indexingMaps = newMaps;
197 return packedDimPerIndexingMap;
203 struct PackedOperandsDim {
209 struct PackedOperandsDimList {
210 void pushBack(PackedOperandsDim &&packedOperandsDims) {
211 spec.emplace_back(packedOperandsDims);
225 linalg::PackOp packOp,
226 bool lowerPadLikeWithInsertSlice) {
228 auto packedTensorType =
229 cast<RankedTensorType>(packOp->getResultTypes().front());
230 if (llvm::any_of(packOp.getStaticInnerTiles(),
231 [](int64_t size) { return ShapedType::isDynamic(size); })) {
234 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
243 PackingMetadata packingMetadata = computePackingMetadata(
244 packedTensorType.getRank(), packOp.getInnerDimsPos());
258 for (
auto [pos, innerSize] :
259 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
261 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
271 rewriter, loc, map, {outerSize, origSize, innerSize});
273 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
275 packingMetadata.reassociations);
276 Value paddingValue = packOp.getPaddingValue();
278 paddingValue = rewriter.
create<arith::ConstantOp>(
282 rewriter.
create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
283 highs, paddingValue,
false);
287 DBGS() <<
"insertPositions: "
288 << llvm::interleaved(packingMetadata.insertPositions);
290 << llvm::interleaved(packingMetadata.outerPositions);
292 << llvm::interleaved(packedTensorType.getShape());
293 DBGSNL();
DBGS() <<
"packedToStripMinedShapePerm: "
294 << llvm::interleaved(packedToStripMinedShapePerm);
296 DBGS() <<
"reassociations: "
297 << llvm::interleaved(llvm::map_range(
300 DBGS() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
303 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
322 auto insertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
323 loc, padOp, packOp.getDest(),
326 LLVM_DEBUG(
DBGS() <<
"insert_slice op: " << insertSliceOp;
DBGSNL(););
328 rewriter.
replaceOp(packOp, insertSliceOp->getResults());
336 auto expandShapeResultType =
338 auto reshapeOp = rewriter.
create<tensor::ExpandShapeOp>(
339 loc, expandShapeResultType, padOp.getResult(),
340 packingMetadata.reassociations);
345 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
346 loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
349 DBGS() <<
"reshape op: " << reshapeOp;
DBGSNL();
350 DBGS() <<
"transpPerm: " << llvm::interleaved(transpPerm);
354 rewriter.
replaceOp(packOp, transposeOp->getResults());
359 FailureOr<LowerUnPackOpResult>
361 bool lowerUnpadLikeWithExtractSlice) {
366 RankedTensorType packedTensorType = unPackOp.getSourceType();
367 int64_t packedRank = packedTensorType.getRank();
370 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
371 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
380 auto extractSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
381 loc, destTensorType, unPackOp.getSource(),
385 rewriter.
replaceOp(unPackOp, extractSliceOp->getResults());
388 nullptr, extractSliceOp};
393 PackingMetadata packingMetadata;
403 RankedTensorType stripMinedTensorType =
405 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
406 stripMinedTensorType, packingMetadata.reassociations);
413 auto emptyOp = rewriter.
create<tensor::EmptyOp>(
414 loc, dims, stripMinedTensorType.getElementType());
415 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
416 loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
420 DBGS() <<
"insertPositions: "
421 << llvm::interleaved(packingMetadata.insertPositions);
423 << llvm::interleaved(packedTensorType.getShape());
424 DBGSNL();
DBGS() <<
"packedToStripMinedShapePerm: "
425 << llvm::interleaved(packedToStripMinedShapePerm);
427 DBGS() <<
"reassociations: "
428 << llvm::interleaved(llvm::map_range(
431 DBGS() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
435 auto reshapeOp = rewriter.
create<tensor::CollapseShapeOp>(
436 loc, collapsedType, transposeOp->getResult(0),
437 packingMetadata.reassociations);
440 int64_t destRank = destTensorType.getRank();
441 auto extractSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
442 loc, destTensorType, reshapeOp->getResult(0),
448 auto copyOp = rewriter.
create<linalg::CopyOp>(
449 loc, extractSliceOp->getResult(0), unPackOp.getDest());
452 rewriter.
replaceOp(unPackOp, copyOp->getResults());
458 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
460 for (
auto &i : spec) {
461 if (!i.packedDimForEachOperand[operandPos].has_value())
463 res.push_back(i.packedDimForEachOperand[operandPos].value());
469 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
471 for (
auto &i : spec) {
472 if (!i.packedDimForEachOperand[operandPos].has_value())
474 res.push_back(i.packedSize);
483 linalg::LinalgOp linalgOp,
485 if (packedSizes.size() != linalgOp.getNumLoops()) {
487 "incorrect number of pack sizes");
493 linalgOp.getIteratorTypesArray();
494 LLVM_DEBUG(
DBGS() <<
"Start packing: " << linalgOp <<
"\n"
495 <<
"maps: " << llvm::interleaved(indexingMaps) <<
"\n"
496 <<
"iterators: " << llvm::interleaved(iteratorTypes)
502 PackedOperandsDimList listOfPackedOperandsDim;
503 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
506 if (maybeConstant.has_value() && maybeConstant.value() == 0)
509 PackedOperandsDim packedOperandsDims;
510 packedOperandsDims.packedSize = packedSizes[i];
511 FailureOr<SmallVector<std::optional<int64_t>>>
512 maybePackedDimForEachOperand =
514 if (failed(maybePackedDimForEachOperand))
516 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
517 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
520 DBGS() <<
"++++ After pack size #" << i <<
": " << packedSizes[i]
522 <<
"maps: " << llvm::interleaved(indexingMaps) <<
"\n"
523 <<
"iterators: " << llvm::interleaved(iteratorTypes) <<
"\n"
524 <<
"packedDimForEachOperand: "
525 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
532 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
534 for (
const auto &operandsList : {inputOperands, initOperands}) {
535 for (
OpOperand *opOperand : operandsList) {
536 int64_t pos = opOperand->getOperandNumber();
537 Value operand = opOperand->get();
539 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
541 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
542 LLVM_DEBUG(
DBGS() <<
"operand: " << operand <<
"\n"
543 <<
"innerPos: " << llvm::interleaved(innerPos) <<
"\n"
544 <<
"innerPackSizes: "
545 << llvm::interleaved(innerPackSizes) <<
"\n");
546 if (innerPackSizes.empty()) {
547 inputsAndInits.push_back(operand);
550 Value dest = linalg::PackOp::createDestinationTensor(
551 rewriter, loc, operand, innerPackSizes, innerPos,
553 ShapedType operandType = cast<ShapedType>(operand.
getType());
554 bool areConstantTiles =
558 if (areConstantTiles && operandType.hasStaticShape() &&
559 !linalg::PackOp::requirePaddingValue(
560 operandType.getShape(), innerPos,
561 cast<ShapedType>(dest.
getType()).getShape(), {},
563 packOps.push_back(rewriter.
create<linalg::PackOp>(
564 loc, operand, dest, innerPos, innerPackSizes));
570 Value zero = rewriter.
create<arith::ConstantOp>(loc, zeroAttr);
571 packOps.push_back(rewriter.
create<linalg::PackOp>(
572 loc, operand, dest, innerPos, innerPackSizes, zero));
574 inputsAndInits.push_back(packOps.back());
580 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
582 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
583 auto packedLinalgOp = rewriter.
create<linalg::GenericOp>(
584 linalgOp.getLoc(), inits.
getTypes(), inputs, inits, indexingMaps,
589 for (
OpResult result : packedLinalgOp->getResults()) {
590 int64_t resultNum = result.getResultNumber();
591 linalg::PackOp maybePackedInit =
592 inits[resultNum].getDefiningOp<linalg::PackOp>();
593 if (!maybePackedInit) {
594 results.push_back(result);
598 unPackOps.push_back(rewriter.
create<linalg::UnPackOp>(
599 packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
600 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
601 results.push_back(unPackOps.back());
609 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
638 assert(linalgOp == opOperand.
getOwner() &&
"linalg op must own the operand");
642 cast<RankedTensorType>(opOperand.
get().
getType()), permutation);
644 assert(tensorType == transposedValue.
getType() &&
645 "expected tensor type mismatch");
650 llvm::map_range(permutation, [](int64_t i) ->
unsigned {
return i; }));
654 permutationMap.
compose(linalgOp.getMatchingIndexingMap(&opOperand));
658 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
664 auto transposedGenericOp = rewriter.
create<linalg::GenericOp>(
667 operandsRef.drop_front(linalgOp.getNumDpsInputs()).
getTypes(),
668 operandsRef.take_front(linalgOp.getNumDpsInputs()),
669 operandsRef.drop_front(linalgOp.getNumDpsInputs()),
671 linalgOp.getIteratorTypesArray());
673 rewriter.
replaceOp(linalgOp, transposedGenericOp->getResults());
675 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
678 FailureOr<PackTransposeResult>
680 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
687 linalg::PackOp transposedPackOp =
688 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
690 if (!packOp.getResult().hasOneUse())
693 OpOperand &packUse = *packOp->getUses().begin();
694 if (packUse.
getOwner() != linalgOp) {
696 linalgOp,
"not a single use by the LinalgOp target");
699 (!linalgOp.isDpsInit(&packUse) ||
700 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
702 "not produced by the LinalgOp target");
708 int64_t numLeadingDims = packOp.getSourceRank();
709 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
713 if (permutation.empty())
714 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
716 if (innerPerm.empty()) {
719 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
721 llvm::append_range(permutation,
722 llvm::map_range(innerPerm, [&](int64_t pos) {
723 return numLeadingDims + pos;
735 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
738 linalg::UnPackOp transposedUnPackOp;
741 transposedLinalgOp->getOpOperand(packUseOperandNumber);
742 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
744 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
745 rewriter, loc, transposedResult, innerPerm, outerPerm);
747 rewriter.
replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
751 rewriter.
replaceOp(packOp, transposedPackOp->getResults());
769 FailureOr<PackResult>
774 assert(mnkPackedSizes.size() == 3 &&
"unexpected num of packing sizes");
775 assert((mnkPaddedSizesNextMultipleOf.empty() ||
776 mnkPaddedSizesNextMultipleOf.size() == 3) &&
777 "num of packing sizes next multiple should be empty or of size 3");
778 assert(mnkOrder.size() == 3 &&
"unexpected mnkOrder size");
781 int64_t numLoops = linalgOp.getNumLoops();
783 LLVM_DEBUG(
DBGS() <<
"need 3+ loops to find a matmul to pack, got "
784 << numLoops <<
"\nin: " << linalgOp <<
"\n");
786 linalgOp,
"need 3+ loops to find a matmul to pack");
790 int64_t numPackedDims = mnkPackedSizes.size();
792 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
793 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
795 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
796 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
798 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
799 paddedSizesNextMultipleOf[mnkOrder[i]] =
800 mnkPaddedSizesNextMultipleOf.empty() ? 0
801 : mnkPaddedSizesNextMultipleOf[i];
805 FailureOr<ContractionDimensions> maybeDimensions =
807 if (failed(maybeDimensions)) {
808 LLVM_DEBUG(
DBGS() <<
"couldn't infer matmul iterators in: " << linalgOp
811 "couldn't infer matmul iterators");
819 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
820 kPos = maybeDimensions->k.back();
822 DBGS() <<
"Start packing generic op greedily with (m@" << mPos
823 <<
", n@" << nPos <<
", k@" << kPos <<
"): " << linalgOp
827 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
829 FailureOr<GenericOp> generalizeResult =
831 assert(succeeded(generalizeResult) &&
"unexpected failure generalizing op");
832 genericOp = *generalizeResult;
840 LLVM_DEBUG(
DBGS() <<
"perm: " << llvm::interleaved(permutation) <<
"\n");
843 FailureOr<GenericOp> interchangeResult =
845 assert(succeeded(interchangeResult) &&
"unexpected failure interchanging op");
846 genericOp = *interchangeResult;
847 LLVM_DEBUG(
DBGS() <<
"Generalized Op to pack: " << genericOp <<
"\n";);
864 cast<LinalgOp>(genericOp.getOperation())
865 .createLoopRanges(rewriter, genericOp.getLoc());
869 LLVM_DEBUG(
DBGS() <<
"paddedSizesNextMultipleOf: "
870 << llvm::interleaved(paddedSizesNextMultipleOf) <<
"\n"
872 << llvm::interleaved(llvm::map_range(
873 loopRanges, [](
Range r) {
return r.size; }))
877 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
878 if (paddedSizesNextMultipleOf[i] == 0) {
879 adjustedPackedSizes.push_back(packedSizes[i]);
886 rewriter, genericOp->getLoc(), d0.
ceilDiv(s0) * s0,
887 {loopRanges[adjustedPackedSizes.size()].size,
888 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
890 LLVM_DEBUG(
DBGS() <<
"adjustedPackedSizes: "
891 << llvm::interleaved(adjustedPackedSizes) <<
"\n");
897 return pack(rewriter, genericOp, adjustedPackedSizes);
906 assert(!tileSizeComputationFunction &&
"tile sizes already set");
911 &op->getParentOfType<func::FuncOp>().getBody().front());
912 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
930 auto padValue = padOp.getConstantPaddingValue();
933 if (padValue.getParentBlock() == &padOp.getRegion().front())
935 return rewriter.
create<FillOp>(padOp.getLoc(), padValue, dest).result();
939 auto generateOp = rewriter.
create<tensor::GenerateOp>(
940 padOp.getLoc(), padOp.getResultType(), dynSizes);
943 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
952 if (
auto val = llvm::dyn_cast_if_present<Value>(ofr))
956 padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
960 auto resultType = padOp.getResultType();
964 for (
unsigned dim = 0; dim < resultType.getRank(); ++dim) {
965 if (resultType.isDynamicDim(dim)) {
967 padOp.getSource(), dim));
970 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
972 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
973 dynSizes.push_back(plusHigh);
975 staticSizes.push_back(resultType.getDimSize(dim));
979 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
980 padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
984 auto sourceType = padOp.getSourceType();
992 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
1000 if (!sliceOp.hasUnitStride())
1003 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
1007 bool zeroSliceGuard =
true;
1009 if (std::optional<bool> control = controlFn(sliceOp))
1010 zeroSliceGuard = *control;
1015 FailureOr<TilingResult> tilingResult =
1017 sliceOp.getMixedSizes(), zeroSliceGuard);
1018 if (failed(tilingResult))
1021 RankedTensorType sourceType = sliceOp.getSourceType();
1022 RankedTensorType resultType = sliceOp.getResultType();
1026 if (sourceType.getRank() == resultType.getRank()) {
1027 rewriter.
replaceOp(sliceOp, tilingResult->tiledValues);
1033 rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1035 rewriter.
replaceOp(sliceOp, rankReduced);
1045 linalg::PackOp packOp) {
1046 Value input = packOp.getSource();
1047 if (!packOp.getPaddingValue()) {
1051 assert(llvm::all_of(packOp.getAllOuterDims(),
1052 [](int64_t val) { return val == 1; }) &&
1053 "some outer dims are != 1");
1056 ShapedType inputType = packOp.getSourceType();
1057 int64_t inputRank = inputType.getRank();
1060 packOp.getDimAndTileMapping();
1067 for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1070 if (!tileAndPosMapping.count(dimIdx)) {
1071 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1072 assert(inputDimSize == 1 &&
1073 "with all outer dims == 1, this non-tiled input dim should be 1!");
1074 paddedShape.push_back(inputDimSize);
1081 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1085 if (cstTileSize.has_value()) {
1086 paddedShape.push_back(cstTileSize.value());
1091 paddedShape.push_back(ShapedType::kDynamic);
1094 dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1099 false, loc, builder,
1109 constexpr int64_t kNonTiledMarker = -1;
1114 vec, [&](int64_t v) {
return v != kNonTiledMarker; });
1129 int64_t unpackedRank = shape.size();
1130 for (
auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1132 innerDims.push_back(dim++);
1137 outerDims.push_back(dim++);
1145 applyPermutationToVector<int64_t>(innerDims, innerPerm);
1150 rankReducedOuterDimsPerm =
1152 if (!rankReducedOuterDimsPerm.empty())
1153 applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1156 perm.append(innerDims);
1165 if (llvm::any_of(packOp.getAllOuterDims(),
1166 [](int64_t dim) { return dim != 1; })) {
1168 packOp,
"not all outer dimensions of the result are 1s");
1177 packOp.getDimAndTileMapping();
1178 int64_t srcRank = packOp.getSourceRank();
1179 int64_t destRank = packOp.getDestRank();
1180 int64_t numTiles = destRank - srcRank;
1182 if (!llvm::all_of(packOp.getInnerDimsPos(),
1183 [&srcRank, &numTiles](int64_t dimPos) {
1184 return dimPos >= (srcRank - numTiles - 1);
1187 packOp,
"Attempting to tile non-trailing source dims!");
1193 for (
auto i : llvm::seq<unsigned>(0, srcRank)) {
1194 if (dimAndTileMapping.count(i)) {
1198 auto [_, tileSize] =
1200 tileSizes.push_back(tileSize);
1214 for (int64_t i = 0; i < (srcRank - numTiles); i++)
1215 srcPermForTranspose.push_back(i);
1219 LLVM_DEBUG(
DBGS() <<
"Pack permutation: " << packOp <<
"\n"
1220 <<
"perm: " << llvm::interleaved(srcPermForTranspose)
1226 transShapeForEmptyOp.append(tileSizes);
1228 applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1229 srcPermForTranspose);
1231 loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
1234 auto transposedOp = rewriter.
create<linalg::TransposeOp>(loc, input, empty,
1235 srcPermForTranspose);
1246 for (
auto tileSize : packOp.getMixedTiles()) {
1247 auto [tileSizeStatic, tileSizeOfr] =
1249 writeSizes.push_back(tileSizeOfr);
1250 writeShape.push_back(tileSizeStatic);
1254 auto insert = rewriter.
create<tensor::InsertSliceOp>(
1255 loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1256 writeSizes, writeStrides);
1257 rewriter.
replaceOp(packOp, insert.getResult());
1264 int64_t srcRank = unpackOp.getSourceRank();
1265 int64_t destRank = unpackOp.getDestRank();
1268 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1269 [](int64_t dim) { return dim != 1; })) {
1272 "require the tiled outer dimensions of the result are all 1s");
1278 Value source = unpackOp.getSource();
1280 unpackOp.getDimAndTileMapping();
1303 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1312 if (dimAndTileMapping.count(i)) {
1313 extractSliceSizes.push_back(oneIdxAttr);
1319 if (ShapedType::isDynamic(srcShape[i])) {
1321 rewriter.
create<tensor::DimOp>(loc, source, i).getResult();
1322 extractSliceSizes.push_back(dynamicDim);
1323 shapeForEmptyOp.push_back(dynamicDim);
1325 extractSliceSizes.push_back(rewriter.
getIndexAttr(srcShape[i]));
1326 if (srcShape[i] != 1)
1327 shapeForEmptyOp.push_back(rewriter.
getIndexAttr(srcShape[i]));
1331 if (srcShape[i] != 1) {
1332 readShapeForExtractSlice.push_back(srcShape[i]);
1337 auto mixedTiles = unpackOp.getMixedTiles();
1338 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1339 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1343 auto tileShape = srcShape.drop_front(destRank);
1345 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1346 Type elemType = unpackOp.getSourceType().getElementType();
1348 Value innerTile = rewriter.
create<tensor::ExtractSliceOp>(
1349 loc, readType, unpackOp.getSource(), extractSliceOffsets,
1350 extractSliceSizes, extractSliceStrides);
1354 srcShape.take_front(destRank),
innerDimsPos, unpackOp.getOuterDimsPerm());
1357 applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1360 rewriter.
create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
1362 rewriter.
create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1366 int numLoops = shapeForEmptyOp.size();
1371 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1372 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1373 tileSizes.push_back(
1377 auto partialTile = rewriter.
create<tensor::ExtractSliceOp>(
1378 loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1384 for (
int i = 0, idx = 0; i < destRank; ++i) {
1385 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1386 writeSizes.push_back(tileSizes[idx++]);
1388 writeSizes.push_back(oneIdxAttr);
1390 auto insert = rewriter.
create<tensor::InsertSliceOp>(
1391 loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1393 rewriter.
replaceOp(unpackOp, insert.getResult());
1406 template <
typename Conv2DOp,
typename Conv1DOp>
1409 if (convOp.hasPureBufferSemantics())
1412 Value input = convOp.getInputs().front();
1413 Value kernel = convOp.getInputs().back();
1414 Value output = convOp.getOutputs().front();
1416 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1417 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1418 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1420 auto kernelShape = kernelType.getShape();
1421 auto outputShape = outputType.getShape();
1424 auto [khIndex, kwIndex, ohIndex, owIndex] =
1427 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1428 return std::make_tuple(0, 1, 1, 2);
1430 .Case([&](linalg::Conv2DNchwFchwOp op) {
1431 return std::make_tuple(2, 3, 2, 3);
1433 .Case([&](linalg::PoolingNhwcSumOp op) {
1434 return std::make_tuple(0, 1, 1, 2);
1436 .Case([&](linalg::PoolingNchwSumOp op) {
1437 return std::make_tuple(0, 1, 2, 3);
1439 .Case([&](linalg::PoolingNhwcMaxOp op) {
1440 return std::make_tuple(0, 1, 1, 2);
1442 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1443 return std::make_tuple(0, 1, 1, 2);
1445 .Case([&](linalg::PoolingNhwcMinOp op) {
1446 return std::make_tuple(0, 1, 1, 2);
1448 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1449 return std::make_tuple(0, 1, 1, 2);
1451 .Case([&](linalg::PoolingNchwMaxOp op) {
1452 return std::make_tuple(0, 1, 2, 3);
1455 llvm_unreachable(
"unexpected conv2d/pool2d operation.");
1456 return std::make_tuple(0, 0, 0, 0);
1461 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1462 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1463 bool removeH = (khSize == 1 && ohSize == 1);
1464 bool removeW = (kwSize == 1 && owSize == 1);
1465 if (!removeH && !removeW)
1471 RankedTensorType newInputType =
1472 RTTBuilder(inputType).
dropDim((removeH ? ohIndex : owIndex));
1473 RankedTensorType newKernelType =
1474 RTTBuilder(kernelType).
dropDim((removeH ? khIndex : kwIndex));
1475 RankedTensorType newOutputType =
1476 RTTBuilder(outputType).
dropDim((removeH ? ohIndex : owIndex));
1481 rewriter, loc, input, newInputType);
1483 rewriter, loc, kernel, newKernelType);
1485 rewriter, loc, output, newOutputType);
1490 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1491 strides.erase(strides.begin() + (removeH ? 0 : 1));
1495 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1496 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1499 auto conv1DOp = rewriter.
create<Conv1DOp>(
1500 loc, newOutputType,
ValueRange{newInput, newKernel},
1501 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1505 rewriter, loc, conv1DOp.getResult(0), output);
1522 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1526 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1530 FailureOr<DepthwiseConv1DNwcWcOp>
1533 if (convOp.hasPureBufferSemantics())
1536 Value input = convOp.getInputs().front();
1537 Value kernel = convOp.getInputs().back();
1538 Value output = convOp.getOutputs().front();
1540 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1541 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1542 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1544 auto kernelShape = kernelType.getShape();
1545 auto outputShape = outputType.getShape();
1549 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1550 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1551 bool removeH = (khSize == 1 && ohSize == 1);
1552 bool removeW = (kwSize == 1 && owSize == 1);
1553 if (!removeH && !removeW)
1559 RankedTensorType newInputType =
1560 RTTBuilder(inputType).
dropDim((removeH ? 1 : 2));
1561 RankedTensorType newKernelType =
1562 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1563 RankedTensorType newOutputType =
1564 RTTBuilder(outputType).
dropDim(removeH ? 1 : 2);
1569 rewriter, loc, input, newInputType);
1571 rewriter, loc, kernel, newKernelType);
1573 rewriter, loc, output, newOutputType);
1577 auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1578 strides.erase(strides.begin() + (removeH ? 0 : 1));
1582 llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1583 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1586 auto conv1DOp = rewriter.
create<DepthwiseConv1DNwcWcOp>(
1587 loc, newOutputType,
ValueRange{newInput, newKernel},
1588 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1592 rewriter, loc, conv1DOp.getResult(0), output);
1601 if (convOp.hasPureBufferSemantics())
1604 Value input = convOp.getInputs().front();
1605 Value kernel = convOp.getInputs().back();
1606 Value output = convOp.getOutputs().front();
1608 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1609 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1610 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1612 auto kernelShape = kernelType.getShape();
1613 auto outputShape = outputType.getShape();
1617 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1618 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1619 bool removeH = (khSize == 1 && ohSize == 1);
1620 bool removeW = (kwSize == 1 && owSize == 1);
1621 if (!removeH && !removeW)
1627 RankedTensorType newInputType =
1628 RTTBuilder(inputType).
dropDim((removeH ? 0 : 1));
1629 RankedTensorType newKernelType =
1630 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1631 RankedTensorType newOutputType =
1632 RTTBuilder(outputType).
dropDim(removeH ? 0 : 1);
1637 rewriter, loc, input, newInputType);
1639 rewriter, loc, kernel, newKernelType);
1641 rewriter, loc, output, newOutputType);
1643 auto conv1DOp = rewriter.
create<Conv1DOp>(loc, newOutputType,
1649 rewriter, loc, conv1DOp.getResult(0), output);
1668 PoolingNwcMaxUnsignedOp>,
1671 PoolingNwcMinUnsignedOp>,
SmallVector< int64_t > outerDimsPerm
SmallVector< int64_t > innerDimsPos
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Builder & setShape(ArrayRef< int64_t > newShape)
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp)
Shell function to compute the Destination Permutation of PackOp This function uses the helper functio...
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and applies affine_min/max bounds simplification on the fly where relevant.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp unpackOp)
Shell function to compute the Source Permutation of unPackOp.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, SmallVector< Value > dynOutDims={})
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
ArrayRef< int64_t > ReassociationIndicesRef
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of rank-reduced.
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
FailureOr< DepthwiseConv1DNwcWcOp > returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.